In [1]:
import os
import time
import glob
import shutil

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import segmentation_models_pytorch as smp
from matplotlib import pyplot as plt
from PIL import Image
import kagglehub

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Parameters
IMG_SIZE = 256
BATCH_SIZE = 8
EPOCHS = 20
LEARNING_RATE = 1e-3
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

ENCODER_NAME = "resnext50_32x4d"
ENCODER_DEPTH = 5
ENCODER_WEIGHTS = "imagenet"
# ENCODER_WEIGHTS = None
DECODER_CHANNELS = [256, 128, 64, 32, 16]

In [3]:
# Model
model = smp.UnetPlusPlus(
    encoder_name=ENCODER_NAME,  # Backbone
    encoder_depth=ENCODER_DEPTH,
    encoder_weights=ENCODER_WEIGHTS,  # Pre-trained on ImageNet
    decoder_channels=DECODER_CHANNELS[:ENCODER_DEPTH],
    in_channels=3,  # RGB images
    classes=1,  # Binary segmentation
    activation=None,  # Logits for BCEWithLogitsLoss
).to(DEVICE)

In [4]:
class BrainMRIDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)

        # Convert mask to tensor (single channel)
        mask = (
            torch.tensor(np.array(mask), dtype=torch.float32) / 255.0
        )  # Normalize to [0, 1]
        mask = mask.unsqueeze(0)  # Add channel dimension

        return image, mask

In [5]:
# Transforms
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

In [6]:
# Paths
path = kagglehub.dataset_download("mateuszbuda/lgg-mri-segmentation")
data_dir = os.path.join(path, "kaggle_3m")
patients = os.listdir(data_dir)
image_paths = []
mask_paths = []

for patient in patients:
    patient_dir = os.path.join(data_dir, patient)
    if os.path.isdir(patient_dir):
        images = sorted(glob.glob(os.path.join(patient_dir, "*[!mask].tif")))
        masks = [f"{image.rsplit(".", 1)[0]}_mask.tif" for image in images]
        image_paths.extend(images)
        mask_paths.extend(masks)

In [7]:
# Train/Validation Split
split_idx = int(0.8 * len(image_paths))
train_images, val_images = image_paths[:split_idx], image_paths[split_idx:]
train_masks, val_masks = mask_paths[:split_idx], mask_paths[split_idx:]

In [8]:
# Datasets and Dataloaders
train_dataset = BrainMRIDataset(train_images, train_masks, transform=transform)
val_dataset = BrainMRIDataset(val_images, val_masks, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [9]:
# Loss and Optimizer
class Dice(torch.nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(Dice, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = F.sigmoid(inputs)

        # flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return dice


class DiceBCELoss(torch.nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = F.sigmoid(inputs)

        # flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.0 * intersection + smooth) / (
            inputs.sum() + targets.sum() + smooth
        )
        BCE = F.binary_cross_entropy(inputs, targets, reduction="mean")
        Dice_BCE = BCE + dice_loss

        return Dice_BCE

In [10]:
# criterion = torch.nn.BCEWithLogitsLoss()
# criterion = DiceLoss()
dice = Dice()
# criterion = DiceLoss()
criterion = DiceBCELoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
optimizer = torch.optim.Adamax(model.parameters(), lr=LEARNING_RATE)

In [11]:
# Training Loop
writer = SummaryWriter()
best_val_loss = float("inf")
for epoch in range(EPOCHS):
    tic = time.time()
    model.train()
    train_loss = 0
    train_dice = 0

    for images, masks in train_loader:
        images, masks = images.to(DEVICE), masks.to(DEVICE)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)
        train_loss += loss.item()
        train_dice += dice(outputs, masks).item()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss /= len(train_loader)
    train_dice /= len(train_loader)

    # Validation Loop
    model.eval()
    val_loss = 0
    val_dice = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()
            val_dice += dice(outputs, masks).item()

    val_loss /= len(val_loader)
    val_dice /= len(val_loader)

    print(
        f"Epoch {epoch + 1}/{EPOCHS}, "
        f"Train DICE: {train_dice:.4f}, "
        f"Train Loss: {train_loss:.4f}, "
        f"Val DICE: {val_dice:.4f}, "
        f"Val Loss: {val_loss:.4f}"
    )
    toc = time.time()
    print(f"Duration: {toc - tic}s")
    writer.add_scalar("train/dice", train_dice, epoch)
    writer.add_scalar("train/loss", train_loss, epoch)
    writer.add_scalar("validation/dice", val_dice, epoch)
    writer.add_scalar("validation/loss", val_loss, epoch)

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_unetpp_model.pth")

print("Training complete!")

Epoch 1/20, Train DICE: 0.5421, Train Loss: 0.5210, Val DICE: 0.3841, Val Loss: 0.6278
Duration: 159.87156796455383s
Epoch 2/20, Train DICE: 0.8088, Train Loss: 0.2061, Val DICE: 0.4049, Val Loss: 0.6105
Duration: 169.11816596984863s
Epoch 3/20, Train DICE: 0.8405, Train Loss: 0.1724, Val DICE: 0.4345, Val Loss: 0.5787
Duration: 169.9835970401764s
Epoch 4/20, Train DICE: 0.8480, Train Loss: 0.1644, Val DICE: 0.4474, Val Loss: 0.5906
Duration: 170.3609549999237s
Epoch 5/20, Train DICE: 0.8530, Train Loss: 0.1586, Val DICE: 0.4409, Val Loss: 0.6003
Duration: 169.8536229133606s
Epoch 6/20, Train DICE: 0.8575, Train Loss: 0.1533, Val DICE: 0.5618, Val Loss: 0.4510
Duration: 169.5381441116333s
Epoch 7/20, Train DICE: 0.8812, Train Loss: 0.1285, Val DICE: 0.6525, Val Loss: 0.3679
Duration: 169.78781485557556s
Epoch 8/20, Train DICE: 0.8694, Train Loss: 0.1410, Val DICE: 0.7106, Val Loss: 0.3072
Duration: 169.60931277275085s
Epoch 9/20, Train DICE: 0.8794, Train Loss: 0.1311, Val DICE: 0.7507

In [12]:
model.load_state_dict(torch.load("best_unetpp_model.pth", weights_only=True))
model.eval()
for img_path, mask_path in zip(val_images, val_masks):
    image = Image.open(img_path).convert("RGB")
    mask = Image.open(mask_path).convert("L")

    image = transform(image).to(DEVICE).unsqueeze(0)
    output = model(image)

    dir_name = img_path.rsplit("/", 1)[1].rsplit(".", 1)[0]
    dir_name = f"validation/{dir_name}"
    if not os.path.exists("validation"):
        os.mkdir("validation")
    if not os.path.exists(dir_name):
        os.mkdir(dir_name)
    shutil.copy2(img_path, dir_name)
    shutil.copy2(mask_path, dir_name)

    # Save the mask as an image
    plt.figure(figsize=(4, 4))
    plt.imshow(mask, cmap="gray")
    plt.axis("off")
    output_path = os.path.join(dir_name, "output.tif")
    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.close()