In [1]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F

In [2]:
# ==========================================
# 1. Setup and Configuration
# ==========================================

def set_seed(seed, deterministic=False):
    """Seed every RNG this notebook uses: Python ``random``, NumPy's global RNG and PyTorch.

    Args:
        seed: integer seed applied to all generators.
        deterministic: if True, additionally request deterministic cuDNN kernels
            and disable cuDNN's autotuner. Seeding alone does not make cuDNN
            convolutions bit-reproducible. Off by default because deterministic
            kernels may be slower.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the CPU generator and the CUDA generators.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Explicitly seed every visible GPU (covers the current device as well).
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Hyperparameters: every downstream cell reads these names.
BATCH_SIZE = 8          # samples per batch for both the train and val DataLoaders
LEARNING_RATE = 1e-4    # step size handed to the Adam optimizer
EPOCHS = 10             # number of passes over the training split
IMG_SIZE = 256          # images and masks are resized to IMG_SIZE x IMG_SIZE
SEED = 42               # passed to set_seed() at the start of the main cell

# Use the CUDA GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# ==========================================
# 2. Dataset Definition
# ==========================================

class CovidDataset(Dataset):
    """Paired X-ray image / infection-mask dataset read from two directories.

    Every file ending in ``.png`` in ``image_dir`` is one sample. Its mask is the
    file with the same name in ``mask_dir``. Both are loaded as 8-bit grayscale
    (PIL mode "L").

    Args:
        image_dir: directory holding the input images.
        mask_dir: directory holding masks whose file names match the images.
        transform: optional callable applied to the image. It is also applied
            to the mask unless ``mask_transform`` is given. It should return a
            tensor (e.g. end a Compose with ``transforms.ToTensor()``), because
            the mask is thresholded with ``> 0`` afterwards.
        mask_transform: optional separate callable for the mask, e.g. one that
            resizes with nearest-neighbour interpolation so the mask is not
            blurred before binarization. Defaults to ``transform``.

    When no transform applies, the PIL image is converted to a float tensor of
    shape (1, H, W) scaled to [0, 1].

    ``__getitem__`` returns ``(image, mask)``. ``mask`` is a float tensor of
    0/1 values, where any positive pixel counts as foreground.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith('.png'))
        self.transform = transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _to_tensor(pil_image):
        """Convert an 8-bit grayscale image to a (1, H, W) float32 tensor in [0, 1]."""
        array = np.asarray(pil_image, dtype=np.float32) / 255.0
        return torch.from_numpy(array).unsqueeze(0)

    def __getitem__(self, index):
        img_name = self.images[index]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)

        # Context managers close the file handles as soon as the pixels are
        # loaded; convert() returns an in-memory copy.
        with Image.open(img_path) as img_file:
            image = img_file.convert("L")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        # Resolved per call, so a later reassignment of self.transform still
        # reaches the mask when no dedicated mask_transform was given.
        mask_tf = self.mask_transform if self.mask_transform is not None else self.transform

        # NOTE(review): the image and the mask are transformed by separate calls.
        # A random augmentation would sample different parameters for each call
        # and misalign them, so keep these transforms deterministic.
        image = self.transform(image) if self.transform is not None else self._to_tensor(image)
        mask = mask_tf(mask) if mask_tf is not None else self._to_tensor(mask)

        # Binarize mask: any positive value is foreground.
        mask = (mask > 0).float()

        return image, mask

In [4]:
# ==========================================
# 3. Model Architecture (U-Net)
# ==========================================

class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use padding=1 with the default stride of 1, so height and
    width are unchanged. Only the channel count goes from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # The attribute stays ``conv`` so state_dict keys (conv.0.weight, ...) match.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return out

class UNet(nn.Module):
    """Four-level U-Net that outputs raw logits (no activation is applied).

    Encoder: DoubleConv stages with 64, 128, 256 and 512 channels, each followed
    by 2x2 max-pooling, then a 1024-channel bottleneck. Decoder: four 2x2
    stride-2 transposed convolutions. After each one, the result is concatenated
    with the matching encoder output and refined by a DoubleConv. A final 1x1
    conv maps to ``out_channels``.

    The output has the same height and width as the input. Odd intermediate
    sizes are handled by zero-padding in the decoder, so any input whose sides
    are at least 16 px works (four poolings must leave at least one pixel).

    Args:
        in_channels: channels of the input image (1 for grayscale).
        out_channels: number of output logit maps.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()
        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(512, 1024)

        # Each up-conv halves the channels. The concat with the skip restores
        # them, hence DoubleConv(2*C, C).
        self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.up_conv1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up_conv2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up_conv3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up_conv4 = DoubleConv(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _pad_to(upsampled, skip):
        """Zero-pad ``upsampled`` on H/W so it matches ``skip``'s spatial size.

        MaxPool2d(2) floors odd sizes, so doubling afterwards can leave the
        decoder tensor smaller than the skip tensor. Without padding,
        torch.cat fails for inputs whose sides are not multiples of 16.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(upsampled, [diff_w // 2, diff_w - diff_w // 2,
                                 diff_h // 2, diff_h - diff_h // 2])

    def _up_block(self, x, skip, up, conv):
        """Upsample ``x``, align it to ``skip``, concatenate as (skip, x) and refine with ``conv``."""
        x = self._pad_to(up(x), skip)
        return conv(torch.cat((skip, x), dim=1))

    def forward(self, x):
        # Encoder: keep each stage's output for the skip connections.
        d1 = self.down1(x)
        d2 = self.down2(self.pool1(d1))
        d3 = self.down3(self.pool2(d2))
        d4 = self.down4(self.pool3(d3))

        b = self.bottleneck(self.pool4(d4))

        # Decoder: mirror the encoder, consuming skips from deepest to shallowest.
        u1 = self._up_block(b, d4, self.up1, self.up_conv1)
        u2 = self._up_block(u1, d3, self.up2, self.up_conv2)
        u3 = self._up_block(u2, d2, self.up3, self.up_conv3)
        u4 = self._up_block(u3, d1, self.up4, self.up_conv4)

        return self.final_conv(u4)


In [5]:
# ==========================================
# 4. Training Helper Functions
# ==========================================

def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient pooled over all elements of the inputs.

    Both tensors are flattened and treated as one mask. For a batch, this is
    the Dice of the whole batch, not the mean of per-image scores.

    Args:
        y_true: ground-truth mask tensor (0/1 values in this notebook).
        y_pred: prediction tensor with the same number of elements.
        smooth: added to numerator and denominator. It keeps the score defined
            (equal to 1) when both masks are empty.

    Returns:
        0-dim tensor ``(2*sum(t*p) + smooth) / (sum(t) + sum(p) + smooth)``.

    Raises:
        ValueError: if the inputs have different element counts. Otherwise a
            size-1 tensor would silently broadcast.
    """
    # reshape instead of view: view(-1) raises on non-contiguous tensors
    # (e.g. transposed or sliced ones).
    y_true_f = y_true.reshape(-1)
    y_pred_f = y_pred.reshape(-1)
    if y_true_f.numel() != y_pred_f.numel():
        raise ValueError(
            f"dice_coef: element count mismatch ({y_true_f.numel()} vs {y_pred_f.numel()})"
        )
    intersection = (y_true_f * y_pred_f).sum()
    return (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)

def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss (nan if empty)."""
    model.train()
    running_loss = 0.0
    num_batches = 0
    loop = tqdm(loader, desc=desc)
    for images, masks in loop:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        loop.set_postfix(loss=loss.item())
    # Guard: an empty loader used to raise ZeroDivisionError here.
    return running_loss / num_batches if num_batches else float('nan')


def _evaluate_dice(model, loader, device, threshold=0.5):
    """Mean per-batch Dice of ``sigmoid(logits) > threshold`` against the masks (nan if empty)."""
    model.eval()
    total_dice = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)
            preds = torch.sigmoid(model(images)) > threshold
            total_dice += dice_coef(masks, preds.float()).item()
            num_batches += 1
    return total_dice / num_batches if num_batches else float('nan')


def train_model(model, train_loader, val_loader, criterion, optimizer, epochs, device=None):
    """Train ``model`` for ``epochs`` epochs and print validation Dice after each.

    Args:
        model: network returning raw logits. Sigmoid + 0.5 threshold is applied
            only when computing validation Dice.
        train_loader: iterable of ``(images, masks)`` training batches.
        val_loader: iterable of ``(images, masks)`` validation batches.
        criterion: loss called as ``criterion(outputs, masks)``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``train_loader``.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if it has none), so inputs always sit next to
            the weights.

    Returns:
        List with one dict per epoch:
        ``{"epoch": int, "train_loss": float, "val_dice": float}``.
        The values are ``nan`` when the corresponding loader is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    history = []
    for epoch in range(epochs):
        avg_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            desc=f"Epoch {epoch+1}/{epochs}",
        )
        avg_dice = _evaluate_dice(model, val_loader, device)
        print(f"Epoch {epoch+1} Completed. Avg Loss: {avg_loss:.4f}, Val Dice Score: {avg_dice:.4f}")
        history.append({"epoch": epoch + 1, "train_loss": avg_loss, "val_dice": avg_dice})
    return history

def visualize_results(model, loader, device, num_samples=4):
    """Show input, ground truth and thresholded prediction for the first batch of ``loader``.

    Draws a 3-row grid (input X-ray / ground-truth mask / prediction) with one
    column per sample. It uses at most ``num_samples`` samples, and never more
    than the first batch actually contains. Predictions are
    ``sigmoid(logits) > 0.5``.

    Args:
        model: segmentation network returning logits.
        loader: iterable of ``(images, masks)`` batches. Only the first is used.
        device: device the images are moved to before the forward pass.
        num_samples: maximum number of columns to draw.

    Raises:
        ValueError: if ``loader`` yields no batches or there is nothing to draw.
    """
    try:
        images, masks = next(iter(loader))
    except StopIteration:
        raise ValueError("visualize_results: loader is empty, nothing to plot") from None

    n_cols = min(num_samples, images.shape[0])
    if n_cols < 1:
        raise ValueError("visualize_results: need at least one sample to plot")

    model.eval()
    with torch.no_grad():
        preds = torch.sigmoid(model(images.to(device))) > 0.5

    # Explicit Figure/Axes API; squeeze=False keeps `axes` 2-D even for one column.
    fig, axes = plt.subplots(3, n_cols, figsize=(15, 5), squeeze=False)
    rows = (("Input X-ray", images), ("Ground Truth", masks), ("Prediction", preds))
    for row_idx, (title, batch) in enumerate(rows):
        for col_idx in range(n_cols):
            ax = axes[row_idx, col_idx]
            ax.imshow(batch[col_idx].detach().cpu().float().squeeze(), cmap='gray')
            ax.set_title(title)
            ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
# ==========================================
# 5. Main Execution
# ==========================================

if __name__ == "__main__":
    set_seed(SEED)
    print(f'Using device: {device}')

    # Define Paths. The dataset root can be overridden with the COVID_SEG_DATA
    # environment variable; the original local path remains the default.
    BASE_PATH = os.environ.get("COVID_SEG_DATA", "D:\\ml_med_data\\Infection Segmentation Data")
    TRAIN_COVID_PATH = os.path.join(BASE_PATH, "Train/COVID-19")
    IMAGES_DIR = os.path.join(TRAIN_COVID_PATH, "images")
    MASKS_DIR = os.path.join(TRAIN_COVID_PATH, "infection masks")

    # Check if paths exist. Raise instead of calling exit(): in a Jupyter
    # kernel exit() asks the kernel itself to shut down.
    missing_dirs = [p for p in (IMAGES_DIR, MASKS_DIR) if not os.path.isdir(p)]
    if missing_dirs:
        raise FileNotFoundError(
            "Dataset paths not found. Please check your directory structure: "
            + ", ".join(missing_dirs)
        )

    # Transformations
    # NOTE(review): CovidDataset applies this same transform to the masks, so
    # masks are resized with Resize's default (bilinear) interpolation and then
    # binarized with `> 0`. That can thicken mask borders. Nearest-neighbour
    # resizing for the masks would avoid it.
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    # Dataset & DataLoader
    dataset = CovidDataset(IMAGES_DIR, MASKS_DIR, transform=transform)
    if len(dataset) == 0:
        raise RuntimeError(f"No .png images found in {IMAGES_DIR}")

    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    # A dedicated seeded generator keeps the split fixed even if code that
    # consumes the global RNG is later inserted before this line.
    split_generator = torch.Generator().manual_seed(SEED)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=split_generator
    )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print(f"Dataset Loaded. Training samples: {train_size}, Validation samples: {val_size}")

    # Model Init
    model = UNet(in_channels=1, out_channels=1).to(device)

    # Loss and Optimizer. The model emits logits, so use the numerically
    # stable sigmoid + BCE combination.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Train
    print("Starting training...")
    train_model(model, train_loader, val_loader, criterion, optimizer, EPOCHS)

    # Visualize
    print("Visualizing results...")
    visualize_results(model, val_loader, device)