In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import cv2

# First, let's analyze the dataset structure
# base_dir is the root folder that all later path joins and directory listings start from.
base_dir = "COVID-19_Radiography_Dataset"  # Change this to your dataset location

# Print out directory structure to understand the dataset organization
def print_directory_structure(base_path, level=0, max_files=3):
    """Print an indented outline of ``base_path``, at most three levels deep.

    Directories are listed at every visited level and recursed into while
    ``level < 2``. Files are listed in full at levels 0 and 1. At level 2
    only the first ``max_files`` files are printed, followed by a single
    ``...`` line when more files exist, so large image folders do not flood
    the output.

    Args:
        base_path: Directory to describe.
        level: Current depth; callers normally leave this at 0.
        max_files: Maximum number of files printed per level-2 directory.

    Returns:
        None. The outline is written to stdout.
    """
    if not os.path.exists(base_path):
        print(f"Path does not exist: {base_path}")
        return

    if level == 0:
        print(f"Directory structure of {base_path}:")

    prefix = "  " * level + "- "
    files_shown = 0
    ellipsis_printed = False

    # Sorted so the outline (and which files are shown) is stable across runs.
    for item in sorted(os.listdir(base_path)):
        item_path = os.path.join(base_path, item)

        if os.path.isdir(item_path):
            print(f"{prefix}{item} (dir)")
            if level < 2:  # Limit depth to avoid too much output
                print_directory_structure(item_path, level + 1, max_files)
        elif level <= 1:
            print(f"{prefix}{item}")
        elif level == 2:
            if files_shown < max_files:
                print(f"{prefix}{item}")
                files_shown += 1
            elif not ellipsis_printed:
                # One marker for all remaining files in this directory.
                print(f"{prefix}...")
                ellipsis_printed = True

# Show the on-disk layout before attempting to pair images with masks.
print_directory_structure(base_dir)

# Based on the typical structure of this dataset, let's modify our approach
# The dataset usually has these folders: COVID, Normal, Lung_Opacity, Viral Pneumonia
# And for each, there should be images and corresponding masks

# Function to load and preprocess images
def load_data(img_paths, mask_paths, img_size=(256, 256)):
    """Read image/mask pairs from disk as scaled grayscale arrays.

    Paths are consumed pairwise. Each file is loaded in grayscale at
    ``img_size`` and divided by 255.0. A pair is kept only when both files
    load; otherwise the error is printed and the pair is skipped.

    Args:
        img_paths: Iterable of image file paths.
        mask_paths: Iterable of mask file paths, aligned with ``img_paths``.
        img_size: Target size passed to ``load_img``.

    Returns:
        Tuple ``(X, y)`` of NumPy arrays holding the loaded images and masks.
    """
    def read_scaled(path):
        # Grayscale load at the requested size, scaled by 1/255.
        picture = load_img(path, target_size=img_size, color_mode='grayscale')
        return img_to_array(picture) / 255.0

    images, masks = [], []

    for img_path, mask_path in zip(img_paths, mask_paths):
        try:
            image_arr = read_scaled(img_path)
            mask_arr = read_scaled(mask_path)
        except Exception as e:
            print(f"Error loading {img_path} or {mask_path}: {e}")
            continue

        # Append together so both lists stay index-aligned.
        images.append(image_arr)
        masks.append(mask_arr)

    return np.array(images), np.array(masks)

# Accumulators for the discovered pairs; index i of one list matches index i of the other.
image_files, mask_files = [], []

# Folder names this dataset is commonly laid out with
possible_class_folders = ["COVID", "COVID-19", "Normal", "NORMAL", "Lung_Opacity", "Viral Pneumonia"]
possible_image_folders = ["images", "Images"]
possible_mask_folders = ["masks", "Masks", "mask", "Mask"]

# Report spreadsheet-style metadata sitting directly under the dataset root
metadata_files = [
    entry for entry in os.listdir(base_dir)
    if entry.endswith(('.csv', '.xlsx'))
]
if metadata_files:
    print(f"Found metadata files: {metadata_files}")
    # You could parse these files to understand dataset structure better

# Let's try multiple approaches to find the images and masks
# Approach 1: Look for images and masks directories
# Layout handled here: <base_dir>/images/<class>/<file> with the mask stored at
# <base_dir>/masks/<class>/<file> under the same file name.
if os.path.exists(os.path.join(base_dir, "images")) and os.path.exists(os.path.join(base_dir, "masks")):
    images_dir = os.path.join(base_dir, "images")
    masks_dir = os.path.join(base_dir, "masks")
    
    for class_name in os.listdir(images_dir):
        class_images_dir = os.path.join(images_dir, class_name)
        class_masks_dir = os.path.join(masks_dir, class_name)
        
        # Only classes that exist as directories under both roots can yield pairs.
        if os.path.isdir(class_images_dir) and os.path.isdir(class_masks_dir):
            for filename in os.listdir(class_images_dir):
                if filename.endswith(('.png', '.jpg', '.jpeg')):
                    image_path = os.path.join(class_images_dir, filename)
                    mask_path = os.path.join(class_masks_dir, filename)
                    
                    # Append only complete pairs so both lists stay index-aligned.
                    if os.path.exists(mask_path):
                        image_files.append(image_path)
                        mask_files.append(mask_path)

# Approach 2: Look for class directories with image and mask subdirectories
# Taken only when the top-level images/ + masks/ layout above is absent.
else:
    for class_name in os.listdir(base_dir):
        class_dir = os.path.join(base_dir, class_name)
        
        if os.path.isdir(class_dir):
            # Check for image and mask folders within each class directory
            found_images_dir = None
            found_masks_dir = None
            
            # The first candidate folder name that exists wins.
            for folder in possible_image_folders:
                if os.path.exists(os.path.join(class_dir, folder)):
                    found_images_dir = os.path.join(class_dir, folder)
                    break
            
            for folder in possible_mask_folders:
                if os.path.exists(os.path.join(class_dir, folder)):
                    found_masks_dir = os.path.join(class_dir, folder)
                    break
            
            # If we found both, add the image-mask pairs
            if found_images_dir and found_masks_dir:
                for filename in os.listdir(found_images_dir):
                    if filename.endswith(('.png', '.jpg', '.jpeg')):
                        image_path = os.path.join(found_images_dir, filename)
                        mask_path = os.path.join(found_masks_dir, filename)
                        
                        if os.path.exists(mask_path):
                            image_files.append(image_path)
                            mask_files.append(mask_path)
            
            # Also check if images and masks are stored directly in the class directory
            # with different prefixes or suffixes
            # NOTE: this scan runs for every class directory, whether or not the
            # subfolder pairing above found anything.
            image_prefix_candidates = ["", "image_", "img_"]
            mask_prefix_candidates = ["mask_", "seg_", "segmentation_"]
            
            for img_prefix in image_prefix_candidates:
                for mask_prefix in mask_prefix_candidates:
                    # Count potential matches
                    # First pass: count files whose prefixed counterpart exists.
                    # With img_prefix "" every image-extension file is a candidate.
                    matches = 0
                    for filename in os.listdir(class_dir):
                        if filename.startswith(img_prefix) and filename.endswith(('.png', '.jpg', '.jpeg')):
                            base_name = filename[len(img_prefix):]
                            mask_name = f"{mask_prefix}{base_name}"
                            if os.path.exists(os.path.join(class_dir, mask_name)):
                                matches += 1
                    
                    if matches > 10:  # If we find enough matches, this is likely the pattern
                        print(f"Found pattern: images with prefix '{img_prefix}' and masks with prefix '{mask_prefix}' in {class_dir}")
                        
                        # Second pass: same matching rule, this time recording the pairs.
                        for filename in os.listdir(class_dir):
                            if filename.startswith(img_prefix) and filename.endswith(('.png', '.jpg', '.jpeg')):
                                base_name = filename[len(img_prefix):]
                                mask_name = f"{mask_prefix}{base_name}"
                                
                                image_path = os.path.join(class_dir, filename)
                                mask_path = os.path.join(class_dir, mask_name)
                                
                                if os.path.exists(mask_path):
                                    image_files.append(image_path)
                                    mask_files.append(mask_path)

# Approach 3: For the Kaggle COVID dataset, check for the common structure
# COVID / COVID-19, Normal, Lung_Opacity, Viral Pneumonia folders each with
# images and masks subfolders.
# This pass runs regardless of which earlier approach executed, and the
# class-directory scan above can already have recorded these exact paths.
# Skip anything already paired so no image appears twice (duplicates could
# end up on both sides of the train/validation split).
already_paired = set(image_files)

for class_name in ["COVID", "COVID-19", "Normal", "Lung_Opacity", "Viral Pneumonia"]:
    class_dir = os.path.join(base_dir, class_name)
    images_subdir = os.path.join(class_dir, "images")
    masks_subdir = os.path.join(class_dir, "masks")

    # Both subfolders must be real directories; this also covers a missing class_dir.
    if not (os.path.isdir(images_subdir) and os.path.isdir(masks_subdir)):
        continue

    for filename in os.listdir(images_subdir):
        if not filename.endswith(('.png', '.jpg', '.jpeg')):
            continue

        image_path = os.path.join(images_subdir, filename)
        mask_path = os.path.join(masks_subdir, filename)

        if image_path in already_paired or not os.path.exists(mask_path):
            continue

        image_files.append(image_path)
        mask_files.append(mask_path)
        already_paired.add(image_path)

# Last resort, used only when nothing has been paired so far: look inside the
# known COVID-19 dataset class folders for "images" plus any mask-folder spelling.
if not image_files:
    for class_name in ["COVID-19", "Normal", "Lung_Opacity", "Viral Pneumonia"]:
        class_dir = os.path.join(base_dir, class_name)
        images_dir = os.path.join(class_dir, "images")

        # Skip classes that are missing or have no images folder.
        if not os.path.exists(class_dir) or not os.path.exists(images_dir):
            continue

        # First mask-folder spelling that exists, if any.
        mask_candidates = (
            os.path.join(class_dir, spelling)
            for spelling in ["masks", "mask", "Masks", "Mask"]
        )
        mask_dir = next(
            (candidate for candidate in mask_candidates if os.path.exists(candidate)),
            None,
        )
        if mask_dir is None:
            continue

        # Record every image that has a same-named mask.
        for filename in os.listdir(images_dir):
            if not filename.endswith(('.png', '.jpg', '.jpeg')):
                continue

            image_path = os.path.join(images_dir, filename)
            mask_path = os.path.join(mask_dir, filename)

            if os.path.exists(mask_path):
                image_files.append(image_path)
                mask_files.append(mask_path)

print(f"Found {len(image_files)} image-mask pairs.")

# Nothing paired: print diagnostics to help the user correct the layout.
if not image_files:
    print("Could not find image-mask pairs. Let's inspect the dataset further:")

    # Summarize what sits directly under the dataset root.
    print("\nTop-level contents:")
    for entry in os.listdir(base_dir):
        entry_path = os.path.join(base_dir, entry)
        if not os.path.isdir(entry_path):
            print(f"File: {entry}")
            continue
        print(f"Directory: {entry} - Contains {len(os.listdir(entry_path))} items")

    # Preview each CSV in the root and flag ones whose columns hint at file paths.
    csv_files = [name for name in os.listdir(base_dir) if name.endswith('.csv')]
    for csv_file in csv_files:
        try:
            df = pd.read_csv(os.path.join(base_dir, csv_file))
            print(f"\nContents of {csv_file}:")
            print(df.head())
            columns = df.columns
            mentions_paths = (
                'filename' in columns
                or 'path' in columns
                or any('path' in col.lower() for col in columns)
            )
            if mentions_paths:
                print("This CSV might contain file paths!")
        except Exception as e:
            print(f"Error reading {csv_file}: {e}")

    print("\nPlease check the dataset structure and update the code accordingly.")
    # Single-element placeholders keep the names defined for the code below
    # while still failing its "more than one pair" check.
    image_files = ["placeholder"]
    mask_files = ["placeholder"]

# Split data only if we found image-mask pairs
# The single-element placeholder lists assigned when discovery fails do not
# satisfy this check, so that case falls through to the final else branch.
if len(image_files) > 1:
    # Split data
    # Hold out 20% of the pairs for validation; images and masks are split together
    # so they stay aligned, and the fixed random_state seeds the shuffle.
    train_img_paths, val_img_paths, train_mask_paths, val_mask_paths = train_test_split(
        image_files, mask_files, test_size=0.2, random_state=42
    )
    
    # Create U-Net model
    def unet_model(input_size=(256, 256, 1)):
        """Assemble a three-level U-Net with a single-channel sigmoid output.

        The encoder applies two 3x3 ReLU convolutions followed by 2x2 max
        pooling at 64, 128 and 256 filters. The bridge uses 512 filters. The
        decoder upsamples by 2, concatenates the matching encoder output on
        the channel axis and convolves again at 256, 128 and 64 filters.

        Args:
            input_size: Shape passed to the Keras ``Input`` layer.

        Returns:
            An uncompiled Keras ``Model``.
        """
        def conv_pair(tensor, filters):
            # Two stacked 3x3 same-padded ReLU convolutions.
            for _ in range(2):
                tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
            return tensor

        inputs = Input(input_size)

        # Encoder: keep each pre-pooling feature map for the skip connections.
        skip_maps = []
        features = inputs
        for filters in (64, 128, 256):
            features = conv_pair(features, filters)
            skip_maps.append(features)
            features = MaxPooling2D(pool_size=(2, 2))(features)

        # Bridge
        features = conv_pair(features, 512)

        # Decoder: walk the skip connections from deepest to shallowest.
        for filters, skip in zip((256, 128, 64), reversed(skip_maps)):
            merged = concatenate([UpSampling2D(size=(2, 2))(features), skip], axis=3)
            features = conv_pair(merged, filters)

        # Output
        outputs = Conv2D(1, 1, activation='sigmoid')(features)

        return Model(inputs=inputs, outputs=outputs)
    
    # Create and compile model
    model = unet_model()
    # Binary cross-entropy is applied to the model's sigmoid mask output.
    model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
    
    # Load training and validation data
    print("Loading training and validation data...")
    train_images, train_masks = load_data(train_img_paths, train_mask_paths)
    val_images, val_masks = load_data(val_img_paths, val_mask_paths)
    
    # load_data skips pairs that fail to load, so either split can come back empty.
    if len(train_images) > 0 and len(val_images) > 0:
        # Reshape data for model
        # Force the (N, 256, 256, 1) layout that unet_model's default input expects.
        train_images = train_images.reshape(-1, 256, 256, 1)
        train_masks = train_masks.reshape(-1, 256, 256, 1)
        val_images = val_images.reshape(-1, 256, 256, 1)
        val_masks = val_masks.reshape(-1, 256, 256, 1)
        
        print(f"Training data shape: {train_images.shape}")
        print(f"Training masks shape: {train_masks.shape}")
        print(f"Validation data shape: {val_images.shape}")
        print(f"Validation masks shape: {val_masks.shape}")
        
        # Train the model
        # The returned history object supplies the per-epoch curves plotted below.
        history = model.fit(
            train_images,
            train_masks,
            batch_size=16,
            epochs=20,
            validation_data=(val_images, val_masks)
        )
        
        # Plot training history
        # Left panel: loss curves; right panel: accuracy curves.
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.plot(history.history['loss'], label='Training Loss')
        plt.plot(history.history['val_loss'], label='Validation Loss')
        plt.title('Loss')
        plt.legend()
        
        plt.subplot(1, 2, 2)
        plt.plot(history.history['accuracy'], label='Training Accuracy')
        plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
        plt.title('Accuracy')
        plt.legend()
        plt.show()
        
        # Function to make predictions and visualize results
        def visualize_predictions(model, images, masks, num_samples=5):
            """Show random samples as image / true mask / predicted mask rows.

            Args:
                model: Object whose ``predict`` accepts a batch of images and
                    returns one mask per image.
                images: Array of images indexed along the first axis. Each
                    sample is assumed to be single-channel, shaped (H, W, 1)
                    or (H, W), as produced by the grayscale loader.
                masks: Ground-truth masks aligned with ``images``.
                num_samples: Upper bound on the number of rows drawn.

            Returns:
                None. The figure is displayed with ``plt.show()``.
            """
            images = np.asarray(images)
            masks = np.asarray(masks)

            sample_count = min(num_samples, len(images))
            if sample_count <= 0:
                # Nothing to draw; skip creating a zero-height figure.
                print("No samples available to visualize.")
                return

            # Select random samples
            indices = np.random.choice(len(images), sample_count, replace=False)

            # Predict the whole selection in one batch rather than issuing a
            # separate model.predict call per sample inside the plotting loop.
            pred_masks = np.asarray(model.predict(images[indices]))

            def as_plane(arr):
                # Drop the trailing channel axis using the sample's own size
                # instead of a hard-coded 256x256.
                return arr.reshape(arr.shape[0], arr.shape[1])

            plt.figure(figsize=(15, 5 * sample_count))

            for row, idx in enumerate(indices):
                panels = (
                    ('Original Image', images[idx]),
                    ('True Mask', masks[idx]),
                    ('Predicted Mask', pred_masks[row]),
                )
                for col, (title, panel) in enumerate(panels, start=1):
                    plt.subplot(sample_count, 3, row * 3 + col)
                    plt.imshow(as_plane(panel), cmap='gray')
                    plt.title(title)
                    plt.axis('off')

            plt.tight_layout()
            plt.show()
        
        # Visualize some predictions
        visualize_predictions(model, val_images, val_masks)
        
        # Save the model
        model.save('covid_xray_segmentation_model.h5')
        print("Model saved successfully!")
    else:
        # Reached when loading left the training or validation set empty.
        print("Not enough data found to train the model.")
else:
    # Reached when fewer than two image-mask pairs were discovered.
    print("Insufficient data to train the model. Please check the dataset structure.")

Directory structure of COVID-19_Radiography_Dataset:
- COVID (dir)
  - images (dir)
    - COVID-1.png
    - ...
    - COVID-10.png
    - ...
    - COVID-100.png
    - ...
    - COVID-1000.png
    - ...
    - COVID-1001.png
    - ...
    - COVID-1002.png
    - ...
    - COVID-1003.png
    - ...
    - COVID-1004.png
    - ...
    - COVID-1005.png
    - ...
    - COVID-1006.png
    - ...
    - COVID-1007.png
    - ...
    - COVID-1008.png
    - ...
    - COVID-1009.png
    - ...
    - COVID-101.png
    - ...
    - COVID-1010.png
    - ...
    - COVID-1011.png
    - ...
    - COVID-1012.png
    - ...
    - COVID-1013.png
    - ...
    - COVID-1014.png
    - ...
    - COVID-1015.png
    - ...
    - COVID-1016.png
    - ...
    - COVID-1017.png
    - ...
    - COVID-1018.png
    - ...
    - COVID-1019.png
    - ...
    - COVID-102.png
    - ...
    - COVID-1020.png
    - ...
    - COVID-1021.png
    - ...
    - COVID-1022.png
    - ...
    - COVID-1023.png
    - ...
    - COVID-1024.png
    

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split
import cv2

# First, let's analyze the dataset structure
base_dir = "COVID-19_Radiography_Dataset"  # Change this to your dataset location

# Verify dataset path
# Stop early with a clear message: the code below lists and joins paths under base_dir.
if not os.path.exists(base_dir):
    print(f"Dataset directory not found: {base_dir}")
    print("Please update the base_dir path to the correct location.")
    # `exit` is a convenience object installed by the `site` module for
    # interactive shells and is not guaranteed to be available (e.g. under
    # `python -S`). Raising SystemExit is the portable equivalent of exit(1).
    raise SystemExit(1)

# Print out directory structure to understand the dataset organization
def print_directory_structure(base_path, level=0, max_files=3):
    """Print an indented outline of ``base_path``, at most three levels deep.

    Directories are listed at every visited level and recursed into while
    ``level < 2``. Files are listed in full at levels 0 and 1. At level 2
    only the first ``max_files`` files are printed, followed by a single
    ``...`` line when more files exist, so large image folders do not flood
    the output.

    Args:
        base_path: Directory to describe.
        level: Current depth; callers normally leave this at 0.
        max_files: Maximum number of files printed per level-2 directory.

    Returns:
        None. The outline is written to stdout.
    """
    if not os.path.exists(base_path):
        print(f"Path does not exist: {base_path}")
        return

    if level == 0:
        print(f"Directory structure of {base_path}:")

    prefix = "  " * level + "- "
    files_shown = 0
    ellipsis_printed = False

    # Sorted so the outline (and which files are shown) is stable across runs.
    for item in sorted(os.listdir(base_path)):
        item_path = os.path.join(base_path, item)

        if os.path.isdir(item_path):
            print(f"{prefix}{item} (dir)")
            if level < 2:  # Limit depth to avoid too much output
                print_directory_structure(item_path, level + 1, max_files)
        elif level <= 1:
            print(f"{prefix}{item}")
        elif level == 2:
            if files_shown < max_files:
                print(f"{prefix}{item}")
                files_shown += 1
            elif not ellipsis_printed:
                # One marker for all remaining files in this directory.
                print(f"{prefix}...")
                ellipsis_printed = True

# Show the on-disk layout before attempting to pair images with masks.
print_directory_structure(base_dir)

# Function to load and preprocess images
def load_data(img_paths, mask_paths, img_size=(256, 256)):
    """Read image/mask pairs from disk as scaled grayscale arrays.

    Paths are consumed pairwise. Each file is loaded in grayscale at
    ``img_size`` and divided by 255.0. A pair is kept only when both files
    load; otherwise the error is printed and the pair is skipped.

    Args:
        img_paths: Iterable of image file paths.
        mask_paths: Iterable of mask file paths, aligned with ``img_paths``.
        img_size: Target size passed to ``load_img``.

    Returns:
        Tuple ``(X, y)`` of NumPy arrays holding the loaded images and masks.
    """
    def read_scaled(path):
        # Grayscale load at the requested size, scaled by 1/255.
        picture = load_img(path, target_size=img_size, color_mode='grayscale')
        return img_to_array(picture) / 255.0

    images, masks = [], []

    for img_path, mask_path in zip(img_paths, mask_paths):
        try:
            image_arr = read_scaled(img_path)
            mask_arr = read_scaled(mask_path)
        except Exception as e:
            print(f"Error loading {img_path} or {mask_path}: {e}")
            continue

        # Append together so both lists stay index-aligned.
        images.append(image_arr)
        masks.append(mask_arr)

    return np.array(images), np.array(masks)

# Accumulators for the discovered pairs; index i of one list matches index i of the other.
image_files, mask_files = [], []

# Folder names this dataset is commonly laid out with
possible_class_folders = ["COVID", "COVID-19", "Normal", "NORMAL", "Lung_Opacity", "Viral Pneumonia"]
possible_image_folders = ["images", "Images"]
possible_mask_folders = ["masks", "Masks", "mask", "Mask"]

# Report spreadsheet-style metadata sitting directly under the dataset root
metadata_files = [
    entry for entry in os.listdir(base_dir)
    if entry.endswith(('.csv', '.xlsx'))
]
if metadata_files:
    print(f"Found metadata files: {metadata_files}")

# Try multiple approaches to find images and masks
# Approach 1: Look for images and masks directories
# Layout handled here: <base_dir>/images/<class>/<file> with the mask stored at
# <base_dir>/masks/<class>/<file> under the same file name.
if os.path.exists(os.path.join(base_dir, "images")) and os.path.exists(os.path.join(base_dir, "masks")):
    images_dir = os.path.join(base_dir, "images")
    masks_dir = os.path.join(base_dir, "masks")
    
    for class_name in os.listdir(images_dir):
        class_images_dir = os.path.join(images_dir, class_name)
        class_masks_dir = os.path.join(masks_dir, class_name)
        
        # Only classes that exist as directories under both roots can yield pairs.
        if os.path.isdir(class_images_dir) and os.path.isdir(class_masks_dir):
            for filename in os.listdir(class_images_dir):
                if filename.endswith(('.png', '.jpg', '.jpeg')):
                    image_path = os.path.join(class_images_dir, filename)
                    mask_path = os.path.join(class_masks_dir, filename)
                    
                    # Append only complete pairs so both lists stay index-aligned.
                    if os.path.exists(mask_path):
                        image_files.append(image_path)
                        mask_files.append(mask_path)

# Approach 2: Look for class directories with image and mask subdirectories
# Taken only when the top-level images/ + masks/ layout above is absent.
else:
    for class_name in os.listdir(base_dir):
        class_dir = os.path.join(base_dir, class_name)
        
        if os.path.isdir(class_dir):
            found_images_dir = None
            found_masks_dir = None
            
            # The first candidate folder name that exists wins.
            for folder in possible_image_folders:
                if os.path.exists(os.path.join(class_dir, folder)):
                    found_images_dir = os.path.join(class_dir, folder)
                    break
            
            for folder in possible_mask_folders:
                if os.path.exists(os.path.join(class_dir, folder)):
                    found_masks_dir = os.path.join(class_dir, folder)
                    break
            
            # Pair each image with the same-named file in the mask folder.
            if found_images_dir and found_masks_dir:
                for filename in os.listdir(found_images_dir):
                    if filename.endswith(('.png', '.jpg', '.jpeg')):
                        image_path = os.path.join(found_images_dir, filename)
                        mask_path = os.path.join(found_masks_dir, filename)
                        
                        if os.path.exists(mask_path):
                            image_files.append(image_path)
                            mask_files.append(mask_path)
            
            # Check for images and masks directly in class directory with prefixes
            # NOTE: this scan runs for every class directory, whether or not the
            # subfolder pairing above found anything.
            image_prefix_candidates = ["", "image_", "img_"]
            mask_prefix_candidates = ["mask_", "seg_", "segmentation_"]
            
            for img_prefix in image_prefix_candidates:
                for mask_prefix in mask_prefix_candidates:
                    # First pass: count files whose prefixed counterpart exists.
                    # With img_prefix "" every image-extension file is a candidate.
                    matches = 0
                    for filename in os.listdir(class_dir):
                        if filename.startswith(img_prefix) and filename.endswith(('.png', '.jpg', '.jpeg')):
                            base_name = filename[len(img_prefix):]
                            mask_name = f"{mask_prefix}{base_name}"
                            if os.path.exists(os.path.join(class_dir, mask_name)):
                                matches += 1
                    
                    # More than 10 hits is treated as evidence of a real naming pattern.
                    if matches > 10:
                        print(f"Found pattern: images with prefix '{img_prefix}' and masks with prefix '{mask_prefix}' in {class_dir}")
                        
                        # Second pass: same matching rule, this time recording the pairs.
                        for filename in os.listdir(class_dir):
                            if filename.startswith(img_prefix) and filename.endswith(('.png', '.jpg', '.jpeg')):
                                base_name = filename[len(img_prefix):]
                                mask_name = f"{mask_prefix}{base_name}"
                                
                                image_path = os.path.join(class_dir, filename)
                                mask_path = os.path.join(class_dir, mask_name)
                                
                                if os.path.exists(mask_path):
                                    image_files.append(image_path)
                                    mask_files.append(mask_path)

# Approach 3: Kaggle COVID dataset structure
# This pass runs regardless of which earlier approach executed, and the
# class-directory scan above can already have recorded these exact paths.
# Skip anything already paired so no image appears twice (duplicates could
# end up on both sides of the train/validation split).
already_paired = set(image_files)

for class_name in ["COVID", "COVID-19", "Normal", "Lung_Opacity", "Viral Pneumonia"]:
    class_dir = os.path.join(base_dir, class_name)
    images_subdir = os.path.join(class_dir, "images")
    masks_subdir = os.path.join(class_dir, "masks")

    # Both subfolders must be real directories; this also covers a missing class_dir.
    if not (os.path.isdir(images_subdir) and os.path.isdir(masks_subdir)):
        continue

    for filename in os.listdir(images_subdir):
        if not filename.endswith(('.png', '.jpg', '.jpeg')):
            continue

        image_path = os.path.join(images_subdir, filename)
        mask_path = os.path.join(masks_subdir, filename)

        if image_path in already_paired or not os.path.exists(mask_path):
            continue

        image_files.append(image_path)
        mask_files.append(mask_path)
        already_paired.add(image_path)

# Last resort, used only when nothing has been paired so far: look inside the
# known class folders for "images" plus any mask-folder spelling.
if not image_files:
    for class_name in ["COVID-19", "Normal", "Lung_Opacity", "Viral Pneumonia"]:
        class_dir = os.path.join(base_dir, class_name)
        images_dir = os.path.join(class_dir, "images")

        # Skip classes that are missing or have no images folder.
        if not os.path.exists(class_dir) or not os.path.exists(images_dir):
            continue

        # First mask-folder spelling that exists, if any.
        mask_candidates = (
            os.path.join(class_dir, spelling)
            for spelling in ["masks", "mask", "Masks", "Mask"]
        )
        mask_dir = next(
            (candidate for candidate in mask_candidates if os.path.exists(candidate)),
            None,
        )
        if mask_dir is None:
            continue

        # Record every image that has a same-named mask.
        for filename in os.listdir(images_dir):
            if not filename.endswith(('.png', '.jpg', '.jpeg')):
                continue

            image_path = os.path.join(images_dir, filename)
            mask_path = os.path.join(mask_dir, filename)

            if os.path.exists(mask_path):
                image_files.append(image_path)
                mask_files.append(mask_path)

print(f"Found {len(image_files)} image-mask pairs.")

# Nothing paired: print diagnostics to help the user correct the layout.
if not image_files:
    print("Could not find image-mask pairs. Inspecting dataset further:")

    # Summarize what sits directly under the dataset root.
    print("\nTop-level contents:")
    for entry in os.listdir(base_dir):
        entry_path = os.path.join(base_dir, entry)
        if not os.path.isdir(entry_path):
            print(f"File: {entry}")
            continue
        print(f"Directory: {entry} - Contains {len(os.listdir(entry_path))} items")

    # Preview each CSV in the root and flag ones whose columns hint at file paths.
    csv_files = [name for name in os.listdir(base_dir) if name.endswith('.csv')]
    for csv_file in csv_files:
        try:
            df = pd.read_csv(os.path.join(base_dir, csv_file))
            print(f"\nContents of {csv_file}:")
            print(df.head())
            columns = df.columns
            mentions_paths = (
                'filename' in columns
                or 'path' in columns
                or any('path' in col.lower() for col in columns)
            )
            if mentions_paths:
                print("This CSV might contain file paths!")
        except Exception as e:
            print(f"Error reading {csv_file}: {e}")

    print("\nPlease check the dataset structure and update the code accordingly.")
    # Single-element placeholders keep the names defined for the code below
    # while still failing its "more than one pair" check.
    image_files = ["placeholder"]
    mask_files = ["placeholder"]

# Split data only if we found image-mask pairs
# The single-element placeholder lists assigned when discovery fails do not
# satisfy this check, so that case falls through to the final else branch.
if len(image_files) > 1:
    # Split data
    # Hold out 20% of the pairs for validation; images and masks are split together
    # so they stay aligned, and the fixed random_state seeds the shuffle.
    train_img_paths, val_img_paths, train_mask_paths, val_mask_paths = train_test_split(
        image_files, mask_files, test_size=0.2, random_state=42
    )
    
    # Create U-Net model
    def unet_model(input_size=(256, 256, 1)):
        """Assemble a three-level U-Net with a single-channel sigmoid output.

        The encoder applies two 3x3 ReLU convolutions followed by 2x2 max
        pooling at 64, 128 and 256 filters. The bridge uses 512 filters. The
        decoder upsamples by 2, concatenates the matching encoder output on
        the channel axis and convolves again at 256, 128 and 64 filters.

        Args:
            input_size: Shape passed to the Keras ``Input`` layer.

        Returns:
            An uncompiled Keras ``Model``.
        """
        def conv_pair(tensor, filters):
            # Two stacked 3x3 same-padded ReLU convolutions.
            for _ in range(2):
                tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
            return tensor

        inputs = Input(input_size)

        # Encoder: keep each pre-pooling feature map for the skip connections.
        skip_maps = []
        features = inputs
        for filters in (64, 128, 256):
            features = conv_pair(features, filters)
            skip_maps.append(features)
            features = MaxPooling2D(pool_size=(2, 2))(features)

        # Bridge
        features = conv_pair(features, 512)

        # Decoder: walk the skip connections from deepest to shallowest.
        for filters, skip in zip((256, 128, 64), reversed(skip_maps)):
            merged = concatenate([UpSampling2D(size=(2, 2))(features), skip], axis=3)
            features = conv_pair(merged, filters)

        # Output
        outputs = Conv2D(1, 1, activation='sigmoid')(features)

        return Model(inputs=inputs, outputs=outputs)
    
    # Create and compile model
    model = unet_model()
    # Binary cross-entropy is applied to the model's sigmoid mask output.
    model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
    
    # Load training and validation data
    print("Loading training and validation data...")
    train_images, train_masks = load_data(train_img_paths, train_mask_paths)
    val_images, val_masks = load_data(val_img_paths, val_mask_paths)
    
    # load_data skips pairs that fail to load, so either split can come back empty.
    if len(train_images) > 0 and len(val_images) > 0:
        # Reshape data for model
        # Force the (N, 256, 256, 1) layout that unet_model's default input expects.
        train_images = train_images.reshape(-1, 256, 256, 1)
        train_masks = train_masks.reshape(-1, 256, 256, 1)
        val_images = val_images.reshape(-1, 256, 256, 1)
        val_masks = val_masks.reshape(-1, 256, 256, 1)
        
        print(f"Training data shape: {train_images.shape}")
        print(f"Training masks shape: {train_masks.shape}")
        print(f"Validation data shape: {val_images.shape}")
        print(f"Validation masks shape: {val_masks.shape}")
        
        # Define callbacks
        # Stop once val_loss has not improved for 5 epochs and restore the best weights.
        early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
        # Write the model to disk only on epochs where val_loss improves.
        checkpoint = ModelCheckpoint(
            'best_covid_xray_segmentation_model.h5',
            monitor='val_loss',
            save_best_only=True,
            mode='min',
            verbose=1
        )
        
        # Train the model with epoch progress
        print("Training U-Net model...")
        # The returned history object supplies the per-epoch curves plotted below.
        history = model.fit(
            train_images,
            train_masks,
            batch_size=16,
            epochs=20,
            validation_data=(val_images, val_masks),
            callbacks=[early_stopping, checkpoint],
            verbose=1  # Show progress bar for each epoch
        )
        
        # Plot training history
        # Left panel: loss curves; right panel: accuracy curves.
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.plot(history.history['loss'], label='Training Loss')
        plt.plot(history.history['val_loss'], label='Validation Loss')
        plt.title('Loss')
        plt.legend()
        
        plt.subplot(1, 2, 2)
        plt.plot(history.history['accuracy'], label='Training Accuracy')
        plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
        plt.title('Accuracy')
        plt.legend()
        # Written to a file and closed rather than shown interactively.
        plt.savefig('training_history.png')
        plt.close()
        
        # Function to make predictions and visualize results
        def visualize_predictions(model, images, masks, num_samples=5):
            """Save random samples as image / true mask / predicted mask rows.

            Args:
                model: Object whose ``predict`` accepts a batch of images and
                    returns one mask per image.
                images: Array of images indexed along the first axis. Each
                    sample is assumed to be single-channel, shaped (H, W, 1)
                    or (H, W), as produced by the grayscale loader.
                masks: Ground-truth masks aligned with ``images``.
                num_samples: Upper bound on the number of rows drawn.

            Returns:
                None. The figure is written to ``predictions.png``.
            """
            images = np.asarray(images)
            masks = np.asarray(masks)

            sample_count = min(num_samples, len(images))
            if sample_count <= 0:
                # Nothing to draw; skip creating a zero-height figure.
                print("No samples available to visualize.")
                return

            indices = np.random.choice(len(images), sample_count, replace=False)

            # Predict the whole selection in one batch rather than issuing a
            # separate model.predict call per sample inside the plotting loop.
            pred_masks = np.asarray(model.predict(images[indices]))

            def as_plane(arr):
                # Drop the trailing channel axis using the sample's own size
                # instead of a hard-coded 256x256.
                return arr.reshape(arr.shape[0], arr.shape[1])

            plt.figure(figsize=(15, 5 * sample_count))

            for row, idx in enumerate(indices):
                panels = (
                    ('Original Image', images[idx]),
                    ('True Mask', masks[idx]),
                    ('Predicted Mask', pred_masks[row]),
                )
                for col, (title, panel) in enumerate(panels, start=1):
                    plt.subplot(sample_count, 3, row * 3 + col)
                    plt.imshow(as_plane(panel), cmap='gray')
                    plt.title(title)
                    plt.axis('off')

            plt.tight_layout()
            plt.savefig('predictions.png')
            plt.close()
        
        # Visualize some predictions
        print("Visualizing predictions...")
        visualize_predictions(model, val_images, val_masks)
        
        # Save the final model
        # Separate file from the checkpoint callback's 'best_...' file.
        model.save('covid_xray_segmentation_model.h5')
        print("Final model saved successfully!")
    else:
        # Reached when loading left the training or validation set empty.
        print("Not enough data found to train the model.")
else:
    # Reached when fewer than two image-mask pairs were discovered.
    print("Insufficient data to train the model. Please check the dataset structure.")

Directory structure of COVID-19_Radiography_Dataset:
- COVID (dir)
  - images (dir)
    - COVID-1.png
    - ...
    - COVID-10.png
    - ...
    - COVID-100.png
    - ...
    - COVID-1000.png
    - ...
    - COVID-1001.png
    - ...
    - COVID-1002.png
    - ...
    - COVID-1003.png
    - ...
    - COVID-1004.png
    - ...
    - COVID-1005.png
    - ...
    - COVID-1006.png
    - ...
    - COVID-1007.png
    - ...
    - COVID-1008.png
    - ...
    - COVID-1009.png
    - ...
    - COVID-101.png
    - ...
    - COVID-1010.png
    - ...
    - COVID-1011.png
    - ...
    - COVID-1012.png
    - ...
    - COVID-1013.png
    - ...
    - COVID-1014.png
    - ...
    - COVID-1015.png
    - ...
    - COVID-1016.png
    - ...
    - COVID-1017.png
    - ...
    - COVID-1018.png
    - ...
    - COVID-1019.png
    - ...
    - COVID-102.png
    - ...
    - COVID-1020.png
    - ...
    - COVID-1021.png
    - ...
    - COVID-1022.png
    - ...
    - COVID-1023.png
    - ...
    - COVID-1024.png
    