In [None]:
%%capture
!pip install -q segmentation-models-pytorch==0.3.3 albumentations tqdm

In [None]:
import os, random, numpy as np, pandas as pd, cv2, glob, json, shutil
from pathlib import Path
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from IPython.display import HTML, display, clear_output
from google.colab import files

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# **1- CONFIG**

In [None]:
class CFG:
    """Hyper-parameters and runtime settings shared by the cells below."""

    # Reproducibility
    seed = 42

    # Data
    img_size = 256
    batch_size = 16

    # Optimisation
    epochs = 35
    lr = 1e-4

    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Single settings instance referenced throughout the notebook.
cfg = CFG()

def seed_everything(seed=42):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    then sets cuDNN to deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed the Python, NumPy and PyTorch generators with the configured seed.
seed_everything(seed=cfg.seed)

# **2- DATA SETUP**

In [None]:
# Ask for the Kaggle API token through Colab's upload dialog.
# `uploaded` is not read again in this cell; the shell commands below pick the
# file up from the working directory instead.
print("Upload your kaggle.json")
uploaded = files.upload()
# Copy the token to ~/.kaggle/kaggle.json and make it readable by the owner only.
# NOTE(review): `cp kaggle*.json <file>` assumes exactly one matching file is
# present in the working directory — confirm for repeated uploads.
!mkdir -p ~/.kaggle
!cp kaggle*.json ~/.kaggle/kaggle.json
!chmod 600 ~/.kaggle/kaggle.json

# Download and unzip the LGG MRI segmentation dataset into /content.
# stdout and stderr are both redirected to /dev/null, so a failed download is silent here.
!kaggle datasets download -d mateuszbuda/lgg-mri-segmentation --unzip -p /content > /dev/null 2>&1

# Auto-detect dataset folder: the first directory under /content (in sorted
# order, so the choice is deterministic) that contains a "*_mask.tif" file.
data_dir = next(
    (
        p for p in sorted(Path("/content").iterdir())
        if p.is_dir() and next(p.rglob("*_mask.tif"), None) is not None
    ),
    None,
)
if data_dir is None:
    # The download command discards its output, so fail here with a clear
    # message instead of an AttributeError on None further down.
    raise FileNotFoundError(
        "No folder containing '*_mask.tif' files was found under /content; "
        "check that the Kaggle download succeeded."
    )
print(f"Dataset found at: {data_dir}")

# Build dataframe of (image, mask) path pairs.
# rglob yields files in filesystem order, which can differ between runs and
# machines; sorting keeps the row order, and therefore the seeded split, stable.
MASK_TAG = "_mask"
images, masks = [], []
for mask_path in tqdm(sorted(data_dir.rglob("*_mask.tif")), desc="Scanning images"):
    # "<name>_mask.tif" -> "<name>.tif": strip only the trailing tag.
    img_path = mask_path.with_name(mask_path.stem[: -len(MASK_TAG)] + mask_path.suffix)
    if img_path.exists():
        images.append(str(img_path))
        masks.append(str(mask_path))

df = pd.DataFrame({"image": images, "mask": masks})
print(f"Total images loaded: {len(df)}")
if df.empty:
    raise RuntimeError(f"Found masks under {data_dir} but no matching image files.")


def _mask_has_tumor(mask_file, threshold=10):
    """Return True when the brightest pixel of the grayscale mask exceeds `threshold`."""
    mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread returns None instead of raising on unreadable files.
        raise FileNotFoundError(f"Could not read mask image: {mask_file}")
    return bool(mask.max() > threshold)


# Tumor presence check (used only to stratify the splits)
has_tumor = [_mask_has_tumor(m) for m in tqdm(df["mask"], desc="Checking tumor presence")]

# Split: 70% train, then the remaining 30% halved into validation and test,
# each step stratified on tumor presence.
train_df, temp_df = train_test_split(df, test_size=0.3, stratify=has_tumor, random_state=cfg.seed)
# temp_df keeps df's original integer labels, so they index straight into has_tumor.
temp_has_tumor = [has_tumor[i] for i in temp_df.index]
val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_has_tumor, random_state=cfg.seed)

assert len(train_df) + len(val_df) + len(test_df) == len(df)
print(f"Split sizes: train={len(train_df)}, val={len(val_df)}, test={len(test_df)}")

# **3- AUGMENTATIONS**

In [None]:
def _resize_step():
    """Fresh resize transform targeting the configured square image size."""
    return A.Resize(cfg.img_size, cfg.img_size)


def _to_tensor_steps():
    """Fresh normalise + tensor-conversion steps that end both pipelines."""
    return [
        A.Normalize(mean=0.485, std=0.229, max_pixel_value=255.0),
        ToTensorV2(),
    ]


# Training pipeline: resize, random flips / 90-degree rotations, brightness and
# contrast jitter, then normalisation and tensor conversion.
train_aug = A.Compose(
    [
        _resize_step(),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        *_to_tensor_steps(),
    ]
)

# Validation / test pipeline: the same resize and normalisation, no random steps.
val_aug = A.Compose([_resize_step(), *_to_tensor_steps()])


# **4- DATASET**

In [None]:
class BrainDataset(Dataset):
    """Image / binary-mask pairs read from disk for segmentation.

    Args:
        df: DataFrame with "image" and "mask" columns holding file paths.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and returning a mapping with "image" and "mask" entries. When it is
            omitted the raw arrays are converted to tensors directly.
    """

    def __init__(self, df, transform=None):
        self.paths = df[["image", "mask"]].values
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _load_image(path):
        """Read `path` as an RGB array; raise if the file cannot be decoded."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _load_mask(path):
        """Read `path` as a single-channel grayscale array; raise if unreadable."""
        mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {path}")
        return mask

    def __getitem__(self, i):
        """Return ``(image, mask)`` where mask is a float tensor of shape (1, H, W) in {0, 1}."""
        img_path, mask_path = self.paths[i]
        img = self._load_image(img_path)
        mask = self._load_mask(mask_path)
        if self.transform is not None:
            aug = self.transform(image=img, mask=mask)
            img, mask = aug['image'], aug['mask']
        if not torch.is_tensor(img):
            # No tensor-producing transform ran: convert HWC array -> CHW tensor
            # so samples can still be batched by a DataLoader.
            img = torch.as_tensor(img).permute(2, 0, 1).contiguous()
        # as_tensor accepts both NumPy arrays and tensors, so binarising works
        # with or without a transform (previously .float() failed on arrays).
        mask = (torch.as_tensor(mask) > 127).float().unsqueeze(0)
        return img, mask

# Held-out test split with the non-random (validation) transforms, served one
# sample at a time in a fixed order.
test_ds = BrainDataset(df=test_df, transform=val_aug)
test_loader = DataLoader(dataset=test_ds, batch_size=1, shuffle=False)

# **5- MODELS**

In [None]:
# Both networks are built from identical constructor arguments, so only the
# architecture class differs between them.
_shared_model_kwargs = dict(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
)

print("Loading UNet++ (ResNet34)...")
model_unetpp = smp.UnetPlusPlus(**_shared_model_kwargs)
model_unetpp = model_unetpp.to(cfg.device)

print("Loading Classic UNet (ResNet34)...")
model_unet = smp.Unet(**_shared_model_kwargs)
model_unet = model_unet.to(cfg.device)

# **6- TRAINING FUNCTION**

In [None]:
def _soft_dice(probs, target, smooth=1.0):
    """Smoothed Dice between probabilities and a binary target, pooled over the whole batch."""
    intersection = (probs * target).sum()
    return (2 * intersection + smooth) / (probs.sum() + target.sum() + smooth)


def train_model(model, name):
    """Train `model` with Dice loss and checkpoint its best validation epoch.

    Uses ``cfg.epochs``, ``cfg.lr``, ``cfg.batch_size`` and ``cfg.device``, and
    builds its loaders from the module-level ``train_df`` / ``val_df``.

    The validation score is a *soft* Dice: it is computed on sigmoid
    probabilities (no 0.5 threshold), pooled over each batch, then averaged
    over batches.

    Args:
        model: Segmentation network producing one logit channel.
        name: Display name; also determines the checkpoint file
            ``best_<name.lower()>.pth``.

    Returns:
        ``(model, best_dice)`` — the model with its last-epoch weights and the
        best validation Dice as a float (0.0 when no epoch was run).

    Raises:
        ValueError: If the validation loader yields no batches.
    """
    criterion = smp.losses.DiceLoss(mode='binary')
    optimizer = AdamW(model.parameters(), lr=cfg.lr)
    # Plain float throughout, so the return value never depends on at least
    # one epoch having produced a tensor to call .item() on.
    best_dice = 0.0
    checkpoint_path = f"best_{name.lower()}.pth"

    train_loader = DataLoader(BrainDataset(train_df, train_aug), batch_size=cfg.batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(BrainDataset(val_df, val_aug), batch_size=cfg.batch_size, shuffle=False, num_workers=2)
    if len(val_loader) == 0:
        raise ValueError("Validation set is empty; cannot compute a validation Dice.")

    print(f"\nTraining {name}...")
    for epoch in range(1, cfg.epochs + 1):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(cfg.device), yb.to(cfg.device)
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        dice_sum = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                # Move each batch to the device once and reuse it.
                xb, yb = xb.to(cfg.device), yb.to(cfg.device)
                probs = torch.sigmoid(model(xb))
                dice_sum += _soft_dice(probs, yb).item()
        val_dice = dice_sum / len(val_loader)
        print(f"Epoch {epoch:02d} | {name} Val Dice: {val_dice:.4f}")

        if val_dice > best_dice:
            best_dice = val_dice
            torch.save(model.state_dict(), checkpoint_path)
            print(f"   → Best {name} saved! ({best_dice:.4f})")

    return model, best_dice

# **7- TRAIN BOTH MODELS**

In [None]:
print("Starting training comparison...")

# First contender: UNet++. Keep the trained network and its best validation Dice.
model_unetpp, dice_unetpp = train_model(model=model_unetpp, name="UNet++")

In [None]:
# Second contender: the classic UNet, trained through the same function and settings.
model_unet, dice_unet = train_model(model=model_unet, name="UNet")

# **8- FINAL TEST & COMPARISON GIF**

In [None]:
from PIL import Image, ImageDraw, ImageFont
import glob
import os

# Load the best checkpoints written during training ("best_<name.lower()>.pth").
# map_location remaps the saved tensors onto the active device, so the load
# also works when the runtime no longer has the device they were saved from.
model_unet.load_state_dict(torch.load("best_unet.pth", map_location=cfg.device))
model_unetpp.load_state_dict(torch.load("best_unet++.pth", map_location=cfg.device))
model_unet.eval()
model_unetpp.eval()

# Create folders
# NOTE(review): create_overlay saves its PNGs by bare file name (working
# directory), so these two folders are created but currently stay empty.
os.makedirs("unet_predictions", exist_ok=True)
os.makedirs("unetpp_predictions", exist_ok=True)

def create_overlay(img_np, pred_mask, gt_mask, title, idx):
    """Save one colour-coded prediction overlay as a PNG and return its file name.

    Pixels where ``pred_mask > 0.5`` are painted red, pixels where
    ``gt_mask > 0.5`` green, and pixels where both hold yellow.

    Args:
        img_np: Image array to draw on; it is copied, not modified.
        pred_mask: Predicted mask, thresholded at 0.5.
        gt_mask: Ground-truth mask, thresholded at 0.5.
        title: Figure title; also the stem of the output file name
            (lower-cased, spaces -> "_", "+" -> "pp").
        idx: Frame number appended to the file name, zero-padded to 3 digits.

    Returns:
        The file name of the PNG written to the working directory.
    """
    fig = plt.figure(figsize=(10, 10))

    pred_on = pred_mask > 0.5
    gt_on = gt_mask > 0.5

    # Later assignments overwrite earlier ones, so the order below matters.
    canvas = img_np.copy()
    canvas[pred_on] = [1, 0, 0]          # Red = Prediction
    canvas[gt_on] = [0, 1, 0]            # Green = Ground Truth
    canvas[pred_on & gt_on] = [1, 1, 0]  # Yellow = Overlap

    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(canvas)
    ax.axis("off")

    legend_style = dict(va="center", ha="center", size=22, weight="bold")
    fig.text(0.25, 0.78, "Red - Prediction", color="#ff0000", **legend_style)
    fig.text(0.75, 0.78, "Green - Ground Truth", color="#00ff00", **legend_style)
    fig.suptitle(title, y=0.86, fontsize=24, weight="bold", color="#00FFFF")

    slug = title.lower().replace(' ', '_').replace('+', 'pp')
    filename = f"{slug}_{idx:03d}.png"
    fig.savefig(filename, bbox_inches='tight', facecolor='black', dpi=150, pad_inches=0.3)
    plt.close(fig)
    return filename

print("Generating prediction overlays with Red=Pred, Green=GT...")
# random.sample raises ValueError when asked for more items than exist, so cap
# the number of overlays at the size of the test set.
n_overlays = min(40, len(test_ds))
test_indices = random.sample(range(len(test_ds)), n_overlays)

unet_files = []
unetpp_files = []

for i, idx in enumerate(tqdm(test_indices, desc="Creating overlays")):
    img, gt = test_ds[idx]
    img_dev = img.unsqueeze(0).to(cfg.device)  # add a batch dimension
    gt_np = gt[0].cpu().numpy()

    # Run both models on the same sample; [0, 0] drops batch and channel dims.
    with torch.no_grad():
        pred_unet = torch.sigmoid(model_unet(img_dev))[0,0].cpu().numpy()
        pred_unetpp = torch.sigmoid(model_unetpp(img_dev))[0,0].cpu().numpy()

    # Undo the A.Normalize(mean=0.485, std=0.229) step of val_aug to recover
    # displayable values in [0, 1], channels last.
    img_np = img.permute(1,2,0).numpy()
    img_np = np.clip(img_np * 0.229 + 0.485, 0, 1)

    u_file = create_overlay(img_np, pred_unet, gt_np, "Vanilla UNet", i)
    upp_file = create_overlay(img_np, pred_unetpp, gt_np, "UNet++", i)

    unet_files.append(u_file)
    unetpp_files.append(upp_file)

# **9- SIDE-BY-SIDE COMPARISON ANIMATION**

In [None]:
# CREATE INDIVIDUAL GIFs
def make_gif(file_list, output_name):
    """Assemble image files into an endlessly looping GIF, one second per frame.

    Args:
        file_list: Paths of the frame images, in playback order.
        output_name: Path of the GIF to write.

    Returns:
        `output_name` on success, or None when `file_list` is empty (nothing
        is written in that case).
    """
    if not file_list:
        print(f"No images for {output_name}")
        return None
    # Image.open keeps the underlying file open; copy each frame into memory
    # and close its handle straight away instead of leaking one per frame.
    frames = []
    for path in file_list:
        with Image.open(path) as frame:
            frames.append(frame.copy())
    frames[0].save(output_name, save_all=True, append_images=frames[1:], duration=1000, loop=0)
    print(f"Created: {output_name} ({len(frames)} frames)")
    return output_name

gif_unet = make_gif(unet_files, "unet_predictions.gif")
gif_unetpp = make_gif(unetpp_files, "unetpp_predictions.gif")

# SIDE-BY-SIDE COMPARISON GIF
print("Creating side-by-side comparison GIF...")


def _load_font(size):
    """Return Arial at `size`, falling back to Pillow's built-in font."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # Raised when the font file cannot be found or read.
        pass
    try:
        # Newer Pillow releases accept a size for the built-in font.
        return ImageFont.load_default(size=size)
    except TypeError:
        # Older releases only offer the fixed-size built-in font.
        return ImageFont.load_default()


# The fonts are the same for every frame, so load them once up front.
font_title = _load_font(60)
font_vs = _load_font(48)

side_by_side_frames = []
min_len = min(len(unet_files), len(unetpp_files))

for i in range(min_len):
    # The context managers close both PNG handles once they have been pasted.
    with Image.open(unet_files[i]) as left, Image.open(unetpp_files[i]) as right:
        # Layout: 50 px margin, left image, 50 px gap, right image. The canvas
        # is sized from both images so a differently sized right-hand PNG is
        # not cropped; for equally sized inputs this equals left.width * 2 + 100.
        new_width = left.width + right.width + 100
        new_height = max(left.height, right.height)
        new_img = Image.new('RGB', (new_width, new_height), (15, 15, 30))
        new_img.paste(left, (50, 0))
        new_img.paste(right, (left.width + 100, 0))

    # Centered caption stack drawn over the middle of the frame.
    draw = ImageDraw.Draw(new_img)
    draw.text((new_img.width//2, 80), "Vanilla UNet", fill="yellow", font=font_title, anchor="mm")
    draw.text((new_img.width//2, 160), "vs", fill="white", font=font_vs, anchor="mm")
    draw.text((new_img.width//2, 240), "UNet++", fill="#00ff00", font=font_title, anchor="mm")

    side_by_side_frames.append(new_img)

if side_by_side_frames:
    side_by_side_frames[0].save(
        "unet_vs_unetpp_comparison.gif",
        save_all=True,
        append_images=side_by_side_frames[1:],
        duration=1200,
        loop=0
    )
    print("Side-by-side comparison GIF created successfully!")

# **10- DISPLAY FINAL RESULT**

In [None]:
# Derive the headline from the measured scores instead of hard-coding a winner.
# dice_unet / dice_unetpp are the best *validation* Dice values returned by
# train_model, so the wording below says "validation", not "test".
if dice_unetpp > dice_unet:
    verdict = "UNet++ scored higher"
elif dice_unet > dice_unetpp:
    verdict = "Vanilla UNet scored higher"
else:
    verdict = "Both models scored the same"
score_line = f"best validation Dice: UNet {dice_unet:.4f} vs UNet++ {dice_unetpp:.4f}"

# NOTE(review): the <img>/<a> tags use paths relative to the working directory;
# whether the notebook front end can resolve them depends on the host — confirm
# in Colab.
display(HTML(f'''
<center>
<h1 style="color:#00ff00; font-size:50px">Comparison Complete!</h1>
<h2 style="color:#00ffff">Red = Prediction | Green = Ground Truth</h2>
<h2 style="color:#ffffff">{verdict} ({score_line})</h2>

<h3 style="color:yellow">Vanilla UNet Results (best val Dice: {dice_unet:.4f})</h3>
<img src="unet_predictions.gif" width="500"/>

<h3 style="color:#00ff00">UNet++ Results (best val Dice: {dice_unetpp:.4f})</h3>
<img src="unetpp_predictions.gif" width="500"/>

<h1 style="color:#00ffff; font-size:48px">Side-by-Side Comparison</h1>
<img src="unet_vs_unetpp_comparison.gif" width="1200"/>

<br><br>
<a href="unet_vs_unetpp_comparison.gif" style="font-size:24px; color:cyan">Download Final Comparison GIF</a>
</center>
'''))

files.download("unet_vs_unetpp_comparison.gif")
print(f"\nAll done! {verdict} ({score_line})")