<a href="https://colab.research.google.com/github/iammuhammad41/Medical-Image-Segmentation/blob/main/skin-cancer-segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import random
import numpy as np
import pandas as pd
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import albumentations as A
from albumentations.pytorch import ToTensorV2


# Dataset locations on the Kaggle input mount.
DATASET_PATH = "/kaggle/input/a0-2025-medical-image-segmentation/Dataset"

TRAIN_IMAGE_DIR = os.path.join(DATASET_PATH, "Train/Image")
TRAIN_MASK_DIR = os.path.join(DATASET_PATH, "Train/Mask")
TEST_DIR = os.path.join(DATASET_PATH, "Test/Image")

# Every image/mask is resized to this resolution before entering the network.
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256

# Seed the Python, NumPy and PyTorch RNGs so weight init, shuffling and
# augmentation are repeatable across "Restart & Run All" (best effort: cuDNN
# kernels may still be non-deterministic on GPU).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


print("\nGPU check:")
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if 'failed' in gpu_info or 'command not found' in gpu_info:
  print('-> There is no gpu...')
else:
  print(gpu_info)

In [4]:
# Pair every PNG mask in Train/Mask with the same-stem JPEG in Train/Image.
# Pairs are stored as (mask_path, image_path) -- the mask comes FIRST.
mask_files = sorted(f for f in os.listdir(TRAIN_MASK_DIR) if f.lower().endswith('.png'))

image_mask_pairs = []
missing_images = []
for name in mask_files:
    # splitext swaps only the real extension; str.replace(".png", ".jpg") left
    # upper-case ".PNG" names untouched although the filter above accepts them.
    image_path = os.path.join(TRAIN_IMAGE_DIR, os.path.splitext(name)[0] + ".jpg")
    if os.path.isfile(image_path):
        image_mask_pairs.append((os.path.join(TRAIN_MASK_DIR, name), image_path))
    else:
        missing_images.append(image_path)

print(f"Total image-mask pairs found: {len(image_mask_pairs)}")
if missing_images:
    # Dropped here instead of failing later inside a DataLoader worker.
    print(f"Skipped {len(missing_images)} mask(s) with no matching image, e.g. {missing_images[0]}")
assert image_mask_pairs, f"No image-mask pairs found under {DATASET_PATH}"


# Reproducible 80/20 train/validation split.
random.seed(42)
random.shuffle(image_mask_pairs)

split_idx = int(len(image_mask_pairs) * 0.8)
train_pairs = image_mask_pairs[:split_idx]
val_pairs = image_mask_pairs[split_idx:]


print(f"Number of training pairs: {len(train_pairs)}")
print(f"Number of validation pairs: {len(val_pairs)}")
print(f"\nExample training pair (Mask, Image): {train_pairs[0]}")


# Show one training example: mask on the left, RGB image on the right.
print("\nDisplaying a sample pair from the training set...")
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
sample_mask_path, sample_image_path = train_pairs[0]

# `with` closes the file handles PIL keeps open; imshow converts eagerly.
with Image.open(sample_mask_path) as mask_sample:
    axs[0].imshow(mask_sample, cmap="gray")
axs[0].set_title("Image Mask")
axs[0].axis("off")

with Image.open(sample_image_path) as image_sample:
    axs[1].imshow(image_sample)
axs[1].set_title("Original Image")
axs[1].axis("off")

plt.show()


Total image-mask pairs found: 1087
Number of training pairs: 869
Number of validation pairs: 218


In [None]:
BATCH_SIZE = 8
NUM_WORKERS = 2

# ImageNet channel statistics; inputs are 8-bit, hence max_pixel_value=255.
_NORMALIZE_KWARGS = dict(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    max_pixel_value=255.0,
)


def _make_transforms(augmentations=()):
    """Build Resize -> optional augmentations -> Normalize -> CHW tensor."""
    return A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        *augmentations,
        A.Normalize(**_NORMALIZE_KWARGS),
        ToTensorV2(),
    ])


# Training adds random flips, rotation and colour jitter (applied jointly to the mask
# for the spatial ones); validation is deterministic.
train_transforms = _make_transforms([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=35, p=0.5),
    A.ColorJitter(p=0.2),
])
val_transforms = _make_transforms()



class SegmentationDataset(Dataset):
    """
    PyTorch Dataset yielding (image, mask) pairs for binary segmentation.

    Images are read as RGB and masks as single-channel grayscale. Masks are
    binarised (pixel / 255 > 0.5 -> 1.0, else 0.0) as float32. When a transform
    is given, Albumentations applies it to image and mask together so spatial
    augmentations stay aligned. The mask is always returned with a leading
    channel axis, (1, H, W), whether or not a transform is used and whether the
    transform ends in ToTensorV2 (tensor) or not (ndarray).

    Args:
        image_mask_pairs (list[tuple[str, str]]): (mask_path, image_path)
            tuples -- note the mask path comes FIRST.
        transform (callable | None): Albumentations-style transform called as
            ``transform(image=..., mask=...)`` returning a dict.
    """
    def __init__(self, image_mask_pairs, transform=None):
        self.image_mask_pairs = image_mask_pairs
        self.transform = transform

    def __len__(self):
        return len(self.image_mask_pairs)

    @staticmethod
    def _read(path, flags):
        """Read a file with cv2, raising instead of returning None on failure."""
        data = cv2.imread(path, flags)
        if data is None:
            # cv2.imread signals a missing/unreadable file by returning None,
            # which otherwise surfaces as an opaque error in cvtColor.
            raise FileNotFoundError(f"Could not read file: {path}")
        return data

    def __getitem__(self, idx):
        mask_path, image_path = self.image_mask_pairs[idx]

        image = cv2.cvtColor(self._read(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        mask = self._read(mask_path, cv2.IMREAD_GRAYSCALE)

        mask = (mask / 255.0 > 0.5).astype(np.float32)

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        # Add the channel axis in every case so batches are (B, 1, H, W).
        if isinstance(mask, torch.Tensor):
            mask = mask.unsqueeze(0)
        else:
            mask = np.expand_dims(mask, 0)

        return image, mask



# One Dataset per split; both loaders share batch size and worker count,
# and only the training loader reshuffles each epoch.
train_dataset = SegmentationDataset(train_pairs, train_transforms)
val_dataset = SegmentationDataset(val_pairs, val_transforms)

_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)



# Smoke test: pull one training batch and print tensor shapes.
# Any failure is reported rather than raised.
try:
    print("DataLoaders created successfully.")
    images, masks = next(iter(train_loader))
    for label, batch in (("Image", images), ("Mask", masks)):
        print(f"{label} batch shape from train_loader: {batch.shape}")
except Exception as err:
    print(f"An error occurred while verifying the DataLoader: {err}")

DataLoaders created successfully.
Image batch shape from train_loader: torch.Size([8, 3, 256, 256])
Mask batch shape from train_loader: torch.Size([8, 1, 256, 256])


In [5]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution stages: (Conv2d -> BatchNorm2d -> ReLU) x 2.

    The first conv maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. padding=1 preserves height and width. The convolutions use
    bias=False since each is immediately followed by a BatchNorm layer.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.double_conv(x)
        return out



class RollasUnet(nn.Module):
    """
    U-Net that maps an image to per-pixel logits.

    Architecture:
    - Encoder: four DoubleConv stages (64 -> 128 -> 256 -> 512 channels)
      separated by 2x2 max-pooling.
    - Bottleneck: a 1024-channel DoubleConv.
    - Decoder: four 2x transposed-conv upsamplings. Each is concatenated with
      the matching encoder map (skip connection) and refined by a DoubleConv.
    - Head: a 1x1 conv producing ``out_channels`` logits (no activation applied).

    Input height/width need not be multiples of 16. Max-pooling floors odd
    sizes, so an upsampled map can be one pixel smaller than its skip
    connection. It is zero-padded to match before concatenation, which makes
    the output spatial size equal the input's.

    Args:
        in_channels (int): channels of the input tensor (3 for RGB).
        out_channels (int): number of output logit channels.
    """
    def __init__(self, in_channels=3, out_channels=1):
        super(RollasUnet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Encoder (contracting path).
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = DoubleConv(64, 128)
        self.down2 = DoubleConv(128, 256)
        self.down3 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(512, 1024)

        # Decoder (expansive path): each conv sees upsampled + skip channels.
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(1024, 512)

        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(512, 256)

        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(256, 128)

        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv4 = DoubleConv(128, 64)

        self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _pad_to_match(x, reference):
        """Zero-pad ``x`` on H/W (split evenly per side) to ``reference``'s H/W."""
        # Non-negative: 2 * floor(s / 2) <= s for every skip size s.
        diff_h = reference.size(2) - x.size(2)
        diff_w = reference.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        return nn.functional.pad(
            x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
        )

    def _up_block(self, x, skip, up, conv):
        """Upsample ``x``, align it to ``skip``, concatenate on channels, refine."""
        x = self._pad_to_match(up(x), skip)
        return conv(torch.cat([x, skip], dim=1))

    def forward(self, x):
        """Map a (B, in_channels, H, W) tensor to (B, out_channels, H, W) logits."""
        skip1 = self.inc(x)
        skip2 = self.down1(self.pool(skip1))
        skip3 = self.down2(self.pool(skip2))
        skip4 = self.down3(self.pool(skip3))

        x = self.bottleneck(self.pool(skip4))

        x = self._up_block(x, skip4, self.up1, self.conv1)
        x = self._up_block(x, skip3, self.up2, self.conv2)
        x = self._up_block(x, skip2, self.up3, self.conv3)
        x = self._up_block(x, skip1, self.up4, self.conv4)

        logits = self.outc(x)
        return logits


device = "cuda" if torch.cuda.is_available() else "cpu"

# Sanity check: a fresh model must map (B, 3, H, W) images to (B, 1, H, W) logits.
model = RollasUnet(in_channels=3, out_channels=1).to(device)
dummy_input = torch.randn(BATCH_SIZE, 3, IMAGE_HEIGHT, IMAGE_WIDTH).to(device)

# Shape check only: skip building the autograd graph for this forward pass.
with torch.no_grad():
    preds = model(dummy_input)

expected_shape = (BATCH_SIZE, 1, IMAGE_HEIGHT, IMAGE_WIDTH)
assert tuple(preds.shape) == expected_shape, (
    f"Unexpected output shape {tuple(preds.shape)}, expected {expected_shape}"
)

print("--- Model Sanity Check ---")
print(f"Input tensor shape:  {dummy_input.shape}")
print(f"Output tensor shape: {preds.shape}")
print("Model successfully created and tested.")

LEARNING_RATE = 1e-4
NUM_EPOCHS = 25

In [6]:
class DiceLoss(nn.Module):
    """
    Soft Dice loss on logits: 1 - Dice(sigmoid(inputs), targets).

    The coefficient is pooled over every element of the batch: it is one scalar
    for the whole batch, not a mean of per-image scores. Useful for directly
    optimising overlap.
    """
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """
        Args:
            inputs (torch.Tensor): raw logits, any shape.
            targets (torch.Tensor): ground truth with the same number of elements.
            smooth (float): additive term keeping the ratio defined for empty masks.

        Returns:
            torch.Tensor: scalar loss.
        """
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. after
        # permute/transpose.
        probs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice_coeff = (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)

        return 1 - dice_coeff

def dice_coefficient(preds, targets, smooth=1e-6):
    """
    Hard Dice coefficient for evaluation.

    Logits go through a sigmoid and are binarised at 0.5. The score is then
    (2 * |P & T| + smooth) / (|P| + |T| + smooth), pooled over all elements of
    the batch.

    Args:
        preds (torch.Tensor): raw model logits.
        targets (torch.Tensor): ground-truth masks (same shape as ``preds``).
        smooth (float): keeps the ratio defined when both masks are empty.

    Returns:
        float: the Dice score as a Python float.
    """
    hard_mask = (torch.sigmoid(preds) > 0.5).float()
    overlap = (hard_mask * targets).sum()
    denominator = hard_mask.sum() + targets.sum() + smooth
    score = (2. * overlap + smooth) / denominator
    return score.item()

def combined_loss(pred, target):
    """
    Equal-weight (0.5 / 0.5) blend of BCE-with-logits and soft Dice loss.

    The BCE term scores each pixel independently; the Dice term scores overlap
    between the predicted and target regions. Both terms take raw logits.
    """
    bce_term = nn.BCEWithLogitsLoss()(pred, target)
    dice_term = DiceLoss()(pred, target)
    return 0.5 * bce_term + 0.5 * dice_term


# Re-declared here so this cell is self-contained after a kernel restart.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
NUM_EPOCHS = 25

model = RollasUnet(in_channels=3, out_channels=1).to(DEVICE)

# Plain BCE-with-logits is optimised; combined_loss (BCE + Dice) is a drop-in alternative.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_losses = []
val_losses = []
val_dice_scores = []
best_val_dice = 0.0


def _train_one_epoch(model, loader, loss_fn, optimizer, device, desc):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    model.train()
    loss_sum, n_samples = 0.0, 0
    progress = tqdm(loader, desc=desc, leave=False)

    for data, targets in progress:
        data = data.to(device=device)
        targets = targets.float().to(device=device)

        loss = loss_fn(model(data), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch does not skew the epoch mean.
        batch_size = data.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
        progress.set_postfix(loss=loss.item())

    return loss_sum / max(n_samples, 1)


def _evaluate(model, loader, loss_fn, device):
    """Return (mean loss, mean batch Dice) over ``loader``, both weighted by batch size."""
    model.eval()
    loss_sum, dice_sum, n_samples = 0.0, 0.0, 0

    with torch.no_grad():
        for data, targets in loader:
            data = data.to(device=device)
            targets = targets.float().to(device=device)
            predictions = model(data)

            batch_size = data.size(0)
            loss_sum += loss_fn(predictions, targets).item() * batch_size
            dice_sum += dice_coefficient(predictions, targets) * batch_size
            n_samples += batch_size

    n_samples = max(n_samples, 1)  # guard an empty validation split
    return loss_sum / n_samples, dice_sum / n_samples


print("--- Starting Training ---")

for epoch in range(NUM_EPOCHS):
    avg_train_loss = _train_one_epoch(
        model, train_loader, loss_fn, optimizer, DEVICE,
        desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Training]",
    )
    avg_val_loss, avg_val_dice = _evaluate(model, val_loader, loss_fn, DEVICE)

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    val_dice_scores.append(avg_val_dice)

    # Checkpoint on validation Dice, the metric the notebook reports.
    if avg_val_dice > best_val_dice:
        best_val_dice = avg_val_dice
        torch.save(model.state_dict(), "best_model.pth")
        print(f"--> New best model saved with Dice score: {avg_val_dice:.4f}")

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | "
          f"Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | "
          f"Val Dice: {avg_val_dice:.4f}")

print("\n--- Training Finished ---")
print(f"Best validation Dice score achieved: {best_val_dice:.4f}")

In [7]:
def visualize_predictions(model, loader, device, num_images_to_show=3):
    """
    Plot input image, ground-truth mask and predicted mask side by side.

    Only the first batch of ``loader`` is used. Predictions are sigmoid(logits)
    thresholded at 0.5. Input tensors are un-normalised with the ImageNet
    mean/std before display (presumably matching the loader's Normalize step --
    confirm if the transforms change).

    Args:
        model (torch.nn.Module): trained segmentation model (logit output).
        loader (Iterable): yields (images, masks) batches, e.g. val_loader.
        device (str): device for inference ("cuda" or "cpu").
        num_images_to_show (int): maximum number of rows to draw.

    Returns:
        None. Prints a notice and returns early if ``loader`` is empty.
    """
    print("\n--- Visualizing a few predictions... ---")
    model.eval()

    try:
        images, masks = next(iter(loader))
    except StopIteration:
        print("DataLoader is empty. Cannot visualize.")
        return

    images = images.to(device)
    with torch.no_grad():
        binary_preds = (torch.sigmoid(model(images)) > 0.5).float()

    images_np = images.cpu().numpy()
    masks_np = masks.cpu().numpy()
    preds_np = binary_preds.cpu().numpy()

    rows = min(num_images_to_show, len(images_np))
    fig, axs = plt.subplots(rows, 3, figsize=(15, rows * 5))

    for row in range(rows):
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        # CHW -> HWC, undo normalisation, clamp into the displayable range.
        rgb = np.clip(np.transpose(images_np[row], (1, 2, 0)) * std + mean, 0, 1)

        # plt.subplots returns a 1-D axes array (not 2-D) when there is a single row.
        panel = axs[row] if rows > 1 else axs
        columns = (
            (rgb, None, "Input Image"),
            (masks_np[row].squeeze(), "gray", "Ground Truth Mask"),
            (preds_np[row].squeeze(), "gray", "Predicted Mask"),
        )
        for ax, (picture, cmap, title) in zip(panel, columns):
            ax.imshow(picture, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()



visualize_predictions(model, val_loader, device=DEVICE, num_images_to_show=3)

In [None]:
class TestDataset(Dataset):
    """
    PyTorch Dataset over the image files in a test directory.

    Each item is ``(image, image_name, (original_height, original_width))``.
    ``image`` is the RGB image after ``transform`` (if any). The size tuple is
    measured before transforming, so predictions can be resized back.

    Only regular files with a recognised image extension (case-insensitive) are
    listed. Stray entries such as ``.DS_Store`` or sub-directories would
    otherwise make ``cv2.imread`` return None and crash ``__getitem__``.

    Args:
        test_dir (str): directory containing the test images.
        transform (callable | None): Albumentations-style transform called as
            ``transform(image=...)`` and returning a dict with key ``'image'``.
    """
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(self, test_dir, transform=None):
        self.test_dir = test_dir
        self.image_names = sorted(
            name for name in os.listdir(test_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(test_dir, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        image_path = os.path.join(self.test_dir, image_name)

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports unreadable/corrupt files by returning None.
            raise FileNotFoundError(f"Could not read test image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        original_height, original_width = image.shape[:2]

        if self.transform:
            transformed = self.transform(image=image)
            image = transformed['image']

        return image, image_name, (original_height, original_width)



def generate_submission_and_visualize(model, device, test_dir, num_to_show=5, output_dir=None):
    """
    Predict binary masks for every image in ``test_dir`` and plot a few of them.

    Processing steps:
    1. Each image is resized to IMAGE_HEIGHT x IMAGE_WIDTH and ImageNet-normalised.
    2. The model runs on it, and sigmoid(logits) is thresholded at 0.5.
    3. The mask is resized back to the image's original size with
       nearest-neighbour interpolation, so it stays binary (0/1).

    Args:
        model (torch.nn.Module): trained model returning one logit channel.
        device (str): device for inference ("cuda" or "cpu").
        test_dir (str): directory containing the test images.
        num_to_show (int): how many (image, predicted mask) pairs to plot.
        output_dir (str | None): if given, each full-size mask is written there
            as a PNG with values 0/255, named ``<image stem>.png``. When None
            (the default), nothing is written to disk.

    Raises:
        OSError: if a mask PNG cannot be written to ``output_dir``.
    """
    test_transforms = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
        ToTensorV2(),
    ])

    test_dataset = TestDataset(test_dir, transform=test_transforms)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    model.eval()

    images_to_display = []
    masks_to_display = []

    print("\n--- Generating predictions on the test set... ---")
    with torch.no_grad():
        for images, image_names, original_dims in tqdm(test_loader, desc="Predicting"):
            preds = torch.sigmoid(model(images.to(device)))
            preds = (preds > 0.5).cpu().numpy()

            # Default collation turns the per-sample (h, w) tuples into [heights, widths].
            heights, widths = original_dims
            for i, image_name in enumerate(image_names):
                pred_mask_small = preds[i].squeeze().astype(np.uint8)
                original_h, original_w = int(heights[i]), int(widths[i])
                # cv2.resize takes (width, height); nearest keeps the mask strictly 0/1.
                resized_mask = cv2.resize(pred_mask_small, (original_w, original_h), interpolation=cv2.INTER_NEAREST)

                if output_dir is not None:
                    mask_path = os.path.join(output_dir, os.path.splitext(image_name)[0] + ".png")
                    if not cv2.imwrite(mask_path, resized_mask * 255):
                        raise OSError(f"Failed to write predicted mask: {mask_path}")

                if len(images_to_display) < num_to_show:
                    images_to_display.append(os.path.join(test_dir, image_name))
                    masks_to_display.append(resized_mask)

    print("\n--- Visualizing a few test predictions (in original size) ---")
    for image_path, predicted_mask in zip(images_to_display, masks_to_display):
        original_image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)

        fig, axs = plt.subplots(1, 2, figsize=(15, 7))
        axs[0].imshow(original_image)
        axs[0].set_title(f"Original Test Image: {os.path.basename(image_path)}")
        axs[0].axis('off')

        axs[1].imshow(predicted_mask, cmap='gray')
        axs[1].set_title("Predicted Mask (Resized)")
        axs[1].axis('off')
        plt.show()
        plt.close(fig)  # release figure memory when many masks are displayed


In [8]:
# Rebuild the architecture, restore the checkpoint saved at the best validation
# Dice, then predict on the test set.
model = RollasUnet(in_channels=3, out_channels=1).to(DEVICE)

print("Loading best model weights from 'best_model.pth'...")
try:
    best_state = torch.load("best_model.pth", map_location=DEVICE)
except FileNotFoundError:
    print("Error: 'best_model.pth' not found. Please ensure the model was trained and saved correctly.")
else:
    model.load_state_dict(best_state)
    generate_submission_and_visualize(model, DEVICE, TEST_DIR)