# GMLFNet: Training on Kaggle

**Gradient-Based Meta-Learning with Fast Adaptation Weights for Robust Multi-Centre Polyp Segmentation**

## Setup Instructions
1. **Enable GPU**: Settings → Accelerator → GPU T4 x2
2. **Enable Internet**: Settings → Internet → On (required for dataset download & pretrained weights)
3. **Run All**: Click "Run All" to execute the full training pipeline

## What this notebook does
1. Installs dependencies (learn2learn, albumentations, timm, etc.)
2. Downloads the polyp segmentation datasets from Google Drive and copies each centre into its own folder
3. Creates all GMLFNet source modules inline
4. Builds the model and verifies architecture
5. Runs MAML-based meta-learning training (200 epochs)
6. Evaluates on the test centres (ETIS-LaribPolypDB, CVC-300)
7. Saves results and checkpoints to Kaggle output

In [None]:
# Cell 1: Install Dependencies
# Installs the third-party packages imported by later cells (learn2learn,
# albumentations, timm, gdown, headless OpenCV, PyYAML, tqdm, wandb).
# `!pip` is IPython shell-escape syntax, so this cell only runs inside a
# notebook kernel, not as a plain Python script. Requires Internet to be
# enabled in the Kaggle notebook settings.
# NOTE(review): a `!` command does not abort the cell when pip exits with an
# error, so the message below is printed regardless; check the pip output.
!pip install -q learn2learn albumentations timm gdown opencv-python-headless pyyaml tqdm wandb
print("Dependencies installed successfully!")

In [None]:
# Cell 2: Download and Organize Datasets
import os
import shutil
import zipfile
from pathlib import Path

import gdown

DATASET_ROOT = Path("/kaggle/working/datasets")
DATASET_ROOT.mkdir(parents=True, exist_ok=True)

# Google Drive archives. Each is expected to extract to a top-level folder
# named after its key (e.g. TrainDataset.zip -> TrainDataset/).
DATASET_URLS = {
    "TrainDataset": "https://drive.google.com/uc?id=1lODorfB33jbd-im-qrtUgWnZXxB94F55",
    "TestDataset": "https://drive.google.com/uc?id=1o8OfBvYE6K-EpDyvzsmMPndnUMwb540R",
}

# Same extensions the dataset loader accepts, so the statistics below count
# only files that will actually be loaded.
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tif"}


def _ensure_zip(name, url, zip_path):
    """Ensure ``zip_path`` is a valid zip archive, downloading it if needed.

    An interrupted earlier run can leave a truncated archive behind. Such a
    file fails ``zipfile.is_zipfile`` and is deleted and re-downloaded,
    instead of being skipped and later crashing ``ZipFile`` with
    ``BadZipFile``.

    Args:
        name: Dataset key, used only for log messages.
        url: Google Drive download URL.
        zip_path: Destination ``Path`` of the archive.

    Raises:
        RuntimeError: If the download does not produce a valid zip archive.
    """
    if zip_path.exists():
        if zipfile.is_zipfile(str(zip_path)):
            print(f"{name}.zip already exists, skipping download.")
            return
        print(f"{name}.zip is incomplete or corrupt, re-downloading...")
        zip_path.unlink()

    print(f"Downloading {name}...")
    result = gdown.download(url, str(zip_path), quiet=False)
    # gdown may report failure by returning None instead of raising, so
    # verify the file on disk explicitly.
    if result is None or not zipfile.is_zipfile(str(zip_path)):
        if zip_path.exists():
            zip_path.unlink()
        raise RuntimeError(
            f"Download of {name} from {url} did not produce a valid zip "
            "archive. Check that Internet is enabled for this notebook and "
            "that the Google Drive link is still accessible."
        )


for name, url in DATASET_URLS.items():
    zip_path = DATASET_ROOT / f"{name}.zip"
    extract_dir = DATASET_ROOT / name

    _ensure_zip(name, url, zip_path)

    # The extracted folder doubles as the "already extracted" marker.
    if not extract_dir.exists():
        print(f"Extracting {name}...")
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(str(DATASET_ROOT))
        if extract_dir.exists():
            print(f"Extracted to {extract_dir}")
        else:
            print(f"Warning: {name}.zip has no top-level '{name}/' folder; "
                  f"its contents were extracted directly into {DATASET_ROOT}")

# Organize per-center folders: copy each center out of TestDataset into
# DATASET_ROOT/<center>, skipping centers whose destination already exists.
centers = ["CVC-300", "CVC-ClinicDB", "CVC-ColonDB", "ETIS-LaribPolypDB", "Kvasir"]
test_dir = DATASET_ROOT / "TestDataset"

for center in centers:
    src = test_dir / center
    dst = DATASET_ROOT / center
    if src.exists() and not dst.exists():
        print(f"Organizing {center}...")
        shutil.copytree(str(src), str(dst))

# Print statistics (image files only, ignoring stray non-image files)
print("\n" + "=" * 50)
print("Dataset Statistics")
print("=" * 50)
for center in centers:
    img_dir = DATASET_ROOT / center / "images"
    if img_dir.exists():
        n = sum(1 for f in img_dir.iterdir() if f.suffix.lower() in IMAGE_EXTS)
        print(f"  {center:25s}: {n:4d} images")
    else:
        print(f"  {center:25s}: NOT FOUND")

print("\nDataset setup complete!")

In [None]:
# Cell 3: Create project directory structure
# Creates the package layout the module-writing cells below fill in, and
# puts PROJECT_ROOT on sys.path so `data`, `models`, `utils`, ... import as
# top-level packages.
import sys
from pathlib import Path  # imported here too so the cell runs on its own

PROJECT_ROOT = Path("/kaggle/working/GMLFNet")
for d in ["data", "models", "trainers", "utils", "configs"]:
    (PROJECT_ROOT / d).mkdir(parents=True, exist_ok=True)
    # An empty __init__.py makes each directory an importable package.
    (PROJECT_ROOT / d / "__init__.py").write_text("")

# Put the project root first on sys.path. Existing copies are removed first
# so re-running the cell does not accumulate duplicate entries.
project_path = str(PROJECT_ROOT)
sys.path[:] = [p for p in sys.path if p != project_path]
sys.path.insert(0, project_path)
print(f"Project root: {PROJECT_ROOT}")
print("Added to sys.path")

In [None]:
# Cell 4a: Write data/augmentations.py
# The module source below is kept as a string and written to
# PROJECT_ROOT/data/augmentations.py (PROJECT_ROOT is set in Cell 3), so
# later code can import it as `data.augmentations`. It defines:
#   * get_train_transforms - resize, horizontal/vertical flips, 90-degree
#     rotations, colour jitter, Gaussian blur, mean/std normalisation and
#     ToTensorV2;
#   * get_test_transforms - resize, normalisation and ToTensorV2 only.
# The normalisation mean/std values match the commonly used ImageNet
# statistics. `.strip()` drops the literal's leading/trailing newlines.
augmentations_code = '''
import albumentations as A
from albumentations.pytorch import ToTensorV2


def get_train_transforms(image_size=352):
    """Training augmentations for polyp segmentation."""
    return A.Compose([
        A.Resize(image_size, image_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ColorJitter(
            brightness=0.2, contrast=0.2,
            saturation=0.2, hue=0.1, p=0.5
        ),
        A.GaussianBlur(blur_limit=(3, 7), p=0.3),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ])


def get_test_transforms(image_size=352):
    """Test/validation transforms (resize + normalize only)."""
    return A.Compose([
        A.Resize(image_size, image_size),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ])
'''.strip()

# Write the module to disk so it can be imported as a regular package module.
(PROJECT_ROOT / "data" / "augmentations.py").write_text(augmentations_code)
print("Written: data/augmentations.py")

In [None]:
# Cell 4b: Write data/datasets.py
# Writes PROJECT_ROOT/data/datasets.py (PROJECT_ROOT is set in Cell 3): a
# per-center Dataset that pairs every image with the same-stem mask file.
datasets_code = '''
import os
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

from .augmentations import get_train_transforms, get_test_transforms


# Accepted image/mask extensions; the order is the mask lookup priority.
VALID_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".tif")


class PolypCenterDataset(Dataset):
    """Dataset for a single polyp segmentation center.

    Expects ``<root>/<center_name>/images`` and ``<root>/<center_name>/masks``,
    with one mask per image sharing the image's file stem.

    Each item is a dict with keys ``image`` (float tensor, C x H x W),
    ``mask`` (float tensor, 1 x H x W, values 0/1), ``center`` and
    ``filename``.

    Raises:
        FileNotFoundError: if a directory or a matching mask is missing.
        RuntimeError: if the image directory contains no images.
    """

    def __init__(self, root, center_name, transform=None, image_size=352):
        self.root = Path(root)
        self.center_name = center_name
        self.image_size = image_size
        self.transform = transform

        self.image_dir = self.root / center_name / "images"
        self.mask_dir = self.root / center_name / "masks"

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
        if not self.mask_dir.exists():
            raise FileNotFoundError(f"Mask directory not found: {self.mask_dir}")

        self.image_files = sorted(
            f for f in self.image_dir.iterdir()
            if f.suffix.lower() in VALID_EXTS
        )

        if len(self.image_files) == 0:
            raise RuntimeError(f"No images found in {self.image_dir}")

        self.mask_files = []
        for img_path in self.image_files:
            mask_path = self._find_mask(img_path)
            if mask_path is None:
                raise FileNotFoundError(
                    f"No matching mask for {img_path.name} in {self.mask_dir}"
                )
            self.mask_files.append(mask_path)

    def _find_mask(self, img_path):
        """Return the mask path sharing img_path's stem, or None."""
        stem = img_path.stem
        for ext in VALID_EXTS:
            mask_path = self.mask_dir / f"{stem}{ext}"
            if mask_path.exists():
                return mask_path
        return None

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        mask_path = self.mask_files[idx]

        image = cv2.imread(str(img_path))
        if image is None:
            # cv2.imread signals unreadable/corrupt files by returning None.
            raise OSError(f"Failed to read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise OSError(f"Failed to read mask: {mask_path}")
        # Binarise with a mid-range threshold (masks presumably 8-bit 0/255).
        mask = (mask > 128).astype(np.float32)

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]
        else:
            size = (self.image_size, self.image_size)
            image = cv2.resize(image, size)
            # Nearest-neighbour keeps the mask binary; the default bilinear
            # interpolation would blend boundary pixels into fractions.
            mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            mask = torch.from_numpy(mask).float()

        if mask.ndim == 2:
            mask = mask.unsqueeze(0)

        return {
            "image": image,
            "mask": mask,
            "center": self.center_name,
            "filename": img_path.name,
        }


def build_center_datasets(root, centers, transform=None, image_size=352):
    """Build one PolypCenterDataset per center name, keyed by center."""
    return {
        center: PolypCenterDataset(
            root=root, center_name=center,
            transform=transform, image_size=image_size,
        )
        for center in centers
    }
'''.strip()

(PROJECT_ROOT / "data" / "datasets.py").write_text(datasets_code)
print("Written: data/datasets.py")

In [None]:
# Cell 4c: Write data/meta_sampler.py
# Writes PROJECT_ROOT/data/meta_sampler.py (PROJECT_ROOT is set in Cell 3).
# The module defines:
#   * Task - a namedtuple with one center's support/query image and mask
#     batches plus the center name;
#   * CenterEpisodicSampler - sample_episode() returns one Task per center in
#     `center_datasets`. Each task draws `support_size + query_size` indices:
#     without replacement when the center has enough images, otherwise with
#     replacement, in which case support and query may share images. The
#     stacked tensors are moved to `device` (CPU by default).
# __len__ is the number of support+query draws that fit in the largest
# center (at least 1).
meta_sampler_code = '''
import random
from collections import namedtuple
from typing import Dict, List

import torch
from torch.utils.data import DataLoader

from .datasets import PolypCenterDataset


Task = namedtuple("Task", [
    "support_images", "support_masks",
    "query_images", "query_masks",
    "center_name",
])


class CenterEpisodicSampler:
    """Creates meta-learning episodes where each task = one center."""

    def __init__(self, center_datasets, support_size=16, query_size=16, device=None):
        self.center_datasets = center_datasets
        self.support_size = support_size
        self.query_size = query_size
        self.device = device or torch.device("cpu")
        self.center_names = list(center_datasets.keys())

        total_needed = support_size + query_size
        for name, ds in center_datasets.items():
            if len(ds) < total_needed:
                print(f"Warning: {name} has {len(ds)} images but "
                      f"{total_needed} needed. Sampling with replacement.")

    def sample_episode(self):
        tasks = []
        for center_name in self.center_names:
            task = self._sample_task(center_name)
            tasks.append(task)
        return tasks

    def _sample_task(self, center_name):
        dataset = self.center_datasets[center_name]
        total_needed = self.support_size + self.query_size
        n = len(dataset)

        if n >= total_needed:
            indices = random.sample(range(n), total_needed)
        else:
            indices = random.choices(range(n), k=total_needed)

        support_indices = indices[:self.support_size]
        query_indices = indices[self.support_size:]

        support_images, support_masks = self._collate(dataset, support_indices)
        query_images, query_masks = self._collate(dataset, query_indices)

        return Task(
            support_images=support_images.to(self.device),
            support_masks=support_masks.to(self.device),
            query_images=query_images.to(self.device),
            query_masks=query_masks.to(self.device),
            center_name=center_name,
        )

    def _collate(self, dataset, indices):
        images, masks = [], []
        for idx in indices:
            sample = dataset[idx]
            images.append(sample["image"])
            masks.append(sample["mask"])
        return torch.stack(images), torch.stack(masks)

    def __len__(self):
        max_size = max(len(ds) for ds in self.center_datasets.values())
        return max(1, max_size // (self.support_size + self.query_size))
'''.strip()

# Write the module to disk so it can be imported as `data.meta_sampler`.
(PROJECT_ROOT / "data" / "meta_sampler.py").write_text(meta_sampler_code)
print("Written: data/meta_sampler.py")

In [None]:
# Cell 4d: Write models/backbone.py
# Writes PROJECT_ROOT/models/backbone.py. Both encoders are timm models built
# with features_only=True. Each exposes a class-level `out_channels` list,
# which GMLFNet reads to size the FAW module and the decoder.
# get_backbone() maps "res2net50" / "pvt_v2_b2" to these classes and raises
# ValueError for any other name.
# NOTE(review): the Res2Net docstring says "v1b", but the timm model created
# is "res2net50_26w_4s"; confirm which variant is intended.
# NOTE(review): the hard-coded `out_channels` are assumed to match timm's
# feature_info for the chosen out_indices; verify against
# backbone.feature_info.channels().
backbone_code = '''
import torch
import torch.nn as nn

try:
    import timm
except ImportError:
    timm = None


class Res2NetBackbone(nn.Module):
    """Res2Net-50 (v1b, 26w, 4s) encoder."""
    out_channels = [256, 512, 1024, 2048]

    def __init__(self, pretrained=True):
        super().__init__()
        if timm is None:
            raise ImportError("timm is required. Install: pip install timm")
        self.backbone = timm.create_model(
            "res2net50_26w_4s", pretrained=pretrained,
            features_only=True, out_indices=(1, 2, 3, 4),
        )

    def forward(self, x):
        return self.backbone(x)


class PVTv2B2Backbone(nn.Module):
    """PVTv2-B2 transformer encoder."""
    out_channels = [64, 128, 320, 512]

    def __init__(self, pretrained=True):
        super().__init__()
        if timm is None:
            raise ImportError("timm is required. Install: pip install timm")
        self.backbone = timm.create_model(
            "pvt_v2_b2", pretrained=pretrained,
            features_only=True, out_indices=(0, 1, 2, 3),
        )

    def forward(self, x):
        return self.backbone(x)


def get_backbone(name="res2net50", pretrained=True):
    if name == "res2net50":
        return Res2NetBackbone(pretrained=pretrained)
    elif name == "pvt_v2_b2":
        return PVTv2B2Backbone(pretrained=pretrained)
    else:
        raise ValueError(f"Unknown backbone: {name}")
'''.strip()

# Write the module to disk so it can be imported as `models.backbone`.
(PROJECT_ROOT / "models" / "backbone.py").write_text(backbone_code)
print("Written: models/backbone.py")

In [None]:
# Cell 4e: Write models/decoder.py
# Writes PROJECT_ROOT/models/decoder.py (PROJECT_ROOT is set in Cell 3). The
# module contains RFB blocks, a partial decoder for the coarse map, and a
# cascade of residual reverse-attention stages that refine it from coarse to
# fine.
decoder_code = '''
import torch
import torch.nn as nn
import torch.nn.functional as F


class RFB(nn.Module):
    """Receptive Field Block for multi-scale feature enhancement.

    Four parallel branches (a 1x1 conv, and 1x1 + 3x3 convs with dilation
    1, 3 and 5) are concatenated, fused by a 1x1 conv, added to a 1x1
    projection of the input and passed through ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.branch0 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
        )
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, dilation=1),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=3, dilation=3),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=5, dilation=5),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
        )
        self.conv_cat = nn.Sequential(
            nn.Conv2d(4 * out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
        )
        self.conv_res = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        cat = torch.cat([x0, x1, x2, x3], dim=1)
        out = self.conv_cat(cat) + self.conv_res(x)
        return F.relu(out, inplace=True)


class PartialDecoder(nn.Module):
    """Aggregates high-level features for initial prediction.

    f3 and f4 are upsampled to f2's spatial size and refined, then
    concatenated with f2 and fused. Returns (1-channel logits, fused
    features), both at f2's resolution.
    """

    def __init__(self, channel):
        super().__init__()
        self.conv_upsample1 = nn.Sequential(
            nn.Conv2d(channel, channel, 3, padding=1),
            nn.BatchNorm2d(channel), nn.ReLU(inplace=True),
        )
        self.conv_upsample2 = nn.Sequential(
            nn.Conv2d(channel, channel, 3, padding=1),
            nn.BatchNorm2d(channel), nn.ReLU(inplace=True),
        )
        self.conv_concat = nn.Sequential(
            nn.Conv2d(3 * channel, channel, 3, padding=1),
            nn.BatchNorm2d(channel), nn.ReLU(inplace=True),
            nn.Conv2d(channel, channel, 3, padding=1),
            nn.BatchNorm2d(channel), nn.ReLU(inplace=True),
        )
        self.conv_out = nn.Conv2d(channel, 1, 1)

    def forward(self, f2, f3, f4):
        f4_up = F.interpolate(f4, size=f2.shape[2:], mode="bilinear", align_corners=False)
        f4_up = self.conv_upsample1(f4_up)
        f3_up = F.interpolate(f3, size=f2.shape[2:], mode="bilinear", align_corners=False)
        f3_up = self.conv_upsample2(f3_up)
        cat = torch.cat([f2, f3_up, f4_up], dim=1)
        fused = self.conv_concat(cat)
        out = self.conv_out(fused)
        return out, fused


class ReverseAttention(nn.Module):
    """Reverse Attention (RA) refinement stage, following PraNet.

    Projects x to channel maps and weights them by 1 - sigmoid(prev_pred),
    so the stage attends to regions the coarser prediction has not already
    claimed as foreground. The features are then refined, and the stage
    returns prev_pred plus the predicted residual logits.
    """

    def __init__(self, in_channels, channel):
        super().__init__()
        self.conv_input = nn.Sequential(
            nn.Conv2d(in_channels, channel, 1),
            nn.BatchNorm2d(channel), nn.ReLU(inplace=True),
        )
        self.conv_refine = nn.Sequential(
            nn.Conv2d(channel, channel, 3, padding=1),
            nn.BatchNorm2d(channel), nn.ReLU(inplace=True),
            nn.Conv2d(channel, channel, 3, padding=1),
            nn.BatchNorm2d(channel), nn.ReLU(inplace=True),
        )
        self.conv_out = nn.Conv2d(channel, 1, 1)

    def forward(self, x, prev_pred):
        """Refine prev_pred (logits at the same spatial size as x)."""
        reverse_mask = 1 - torch.sigmoid(prev_pred)
        x = self.conv_input(x)
        x = x * reverse_mask
        x = self.conv_refine(x)
        # Residual connection. The reverse-weighted features suppress the
        # already detected foreground, so on their own they cannot reproduce
        # it; this stage corrects the coarser map rather than replacing it.
        return self.conv_out(x) + prev_pred


class MultiScaleDecoder(nn.Module):
    """Multi-scale decoder with RFB, Partial Decoder, and Reverse Attention."""

    def __init__(self, encoder_channels, decoder_channel=32):
        super().__init__()
        self.rfb2 = RFB(encoder_channels[1], decoder_channel)
        self.rfb3 = RFB(encoder_channels[2], decoder_channel)
        self.rfb4 = RFB(encoder_channels[3], decoder_channel)
        self.partial_decoder = PartialDecoder(decoder_channel)
        self.ra4 = ReverseAttention(encoder_channels[3], decoder_channel)
        self.ra3 = ReverseAttention(encoder_channels[2], decoder_channel)
        self.ra2 = ReverseAttention(encoder_channels[1], decoder_channel)

    def forward(self, features, modulations=None):
        """Decode encoder features into segmentation logits.

        Args:
            features: four encoder maps (f1, f2, f3, f4), finest first.
            modulations: optional list of at least three (gamma, beta)
                pairs, applied FiLM-style to the RFB outputs of f2, f3, f4.

        Returns:
            (main_pred, side_preds): logits resized to 4x f1's spatial size.
            side_preds holds the partial-decoder, RA4 and RA3 outputs.
        """
        f1, f2, f3, f4 = features
        # Assumes f1 comes from a stride-4 stage (1/4 of the input size).
        input_size = f1.shape[2] * 4, f1.shape[3] * 4

        x2 = self.rfb2(f2)
        x3 = self.rfb3(f3)
        x4 = self.rfb4(f4)

        if modulations is not None and len(modulations) >= 3:
            gamma2, beta2 = modulations[0]
            gamma3, beta3 = modulations[1]
            gamma4, beta4 = modulations[2]
            x2 = gamma2 * x2 + beta2
            x3 = gamma3 * x3 + beta3
            x4 = gamma4 * x4 + beta4

        pred_init, fused = self.partial_decoder(x2, x3, x4)

        side_preds = []
        pred5 = F.interpolate(pred_init, size=input_size, mode="bilinear", align_corners=False)
        side_preds.append(pred5)

        pred4 = self.ra4(f4, F.interpolate(pred_init, size=f4.shape[2:], mode="bilinear", align_corners=False))
        pred4_up = F.interpolate(pred4, size=input_size, mode="bilinear", align_corners=False)
        side_preds.append(pred4_up)

        pred3 = self.ra3(f3, F.interpolate(pred4, size=f3.shape[2:], mode="bilinear", align_corners=False))
        pred3_up = F.interpolate(pred3, size=input_size, mode="bilinear", align_corners=False)
        side_preds.append(pred3_up)

        pred2 = self.ra2(f2, F.interpolate(pred3, size=f2.shape[2:], mode="bilinear", align_corners=False))
        main_pred = F.interpolate(pred2, size=input_size, mode="bilinear", align_corners=False)

        return main_pred, side_preds
'''.strip()

(PROJECT_ROOT / "models" / "decoder.py").write_text(decoder_code)
print("Written: models/decoder.py")

In [None]:
# Cell 4f: Write models/fast_adapt_weights.py
# Writes PROJECT_ROOT/models/fast_adapt_weights.py. FastAdaptationWeights
# works in four steps:
#   1. global-average-pools every encoder feature map;
#   2. concatenates the pooled vectors (length sum(encoder_channels));
#   3. runs them through a small MLP (`stats_encoder`);
#   4. emits one (gamma, beta) pair per modulated decoder level, each shaped
#      (B, modulation_channels, 1, 1) for a FiLM-style `gamma * x + beta`.
# The heads are initialised so that gamma == 1 and beta == 0, i.e. the module
# starts as an identity modulation.
# NOTE(review): num_layers=1 and num_layers=2 build the same stats_encoder
# (a single Linear + ReLU). num_layers < 1 yields an empty encoder whose
# output size (sum(encoder_channels)) does not match the heads' hidden_dim
# input.
faw_code = '''
import torch
import torch.nn as nn
import torch.nn.functional as F


class FastAdaptationWeights(nn.Module):
    """Fast Adaptation Weights (FAW) module - thesis contribution.
    
    Generates per-layer FiLM modulation parameters (gamma, beta) that enable
    rapid domain adaptation within the MAML inner loop.
    """

    def __init__(self, encoder_channels, num_modulation_layers=3,
                 modulation_channels=32, hidden_dim=64, num_layers=2):
        super().__init__()
        self.num_modulation_layers = num_modulation_layers
        total_stats_dim = sum(encoder_channels)

        layers = []
        in_dim = total_stats_dim
        for i in range(num_layers - 1):
            layers.extend([nn.Linear(in_dim, hidden_dim), nn.ReLU(inplace=True)])
            in_dim = hidden_dim

        if num_layers == 1:
            layers.extend([nn.Linear(total_stats_dim, hidden_dim), nn.ReLU(inplace=True)])

        self.stats_encoder = nn.Sequential(*layers)

        self.gamma_heads = nn.ModuleList([
            nn.Linear(hidden_dim, modulation_channels)
            for _ in range(num_modulation_layers)
        ])
        self.beta_heads = nn.ModuleList([
            nn.Linear(hidden_dim, modulation_channels)
            for _ in range(num_modulation_layers)
        ])
        self._init_identity()

    def _init_identity(self):
        for head in self.gamma_heads:
            nn.init.zeros_(head.weight)
            nn.init.ones_(head.bias)
        for head in self.beta_heads:
            nn.init.zeros_(head.weight)
            nn.init.zeros_(head.bias)

    def forward(self, encoder_features):
        stats = []
        for feat in encoder_features:
            pooled = F.adaptive_avg_pool2d(feat, 1).flatten(1)
            stats.append(pooled)
        stats = torch.cat(stats, dim=1)

        h = self.stats_encoder(stats)

        modulations = []
        for gamma_head, beta_head in zip(self.gamma_heads, self.beta_heads):
            gamma = gamma_head(h).unsqueeze(-1).unsqueeze(-1)
            beta = beta_head(h).unsqueeze(-1).unsqueeze(-1)
            modulations.append((gamma, beta))

        return modulations

    def get_param_count(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
'''.strip()

# Write the module to disk so it can be imported as `models.fast_adapt_weights`.
(PROJECT_ROOT / "models" / "fast_adapt_weights.py").write_text(faw_code)
print("Written: models/fast_adapt_weights.py")

In [None]:
# Cell 4g: Write models/gmlf_net.py
# Writes PROJECT_ROOT/models/gmlf_net.py (PROJECT_ROOT is set in Cell 3): the
# GMLFNet model (encoder -> optional FAW modulation -> decoder) and a
# config-driven build_model() factory.
gmlf_net_code = '''
import torch
import torch.nn as nn

from .backbone import get_backbone
from .decoder import MultiScaleDecoder
from .fast_adapt_weights import FastAdaptationWeights


class GMLFNet(nn.Module):
    """GMLFNet segmentation model.

    A timm encoder feeds both the optional Fast Adaptation Weights (FAW)
    module, which turns pooled encoder statistics into FiLM (gamma, beta)
    pairs, and the multi-scale decoder. forward() returns
    (main_pred, side_preds) as logits.
    """

    def __init__(self, backbone_name="res2net50", decoder_channel=32,
                 faw_hidden_dim=64, faw_num_layers=2,
                 pretrained=True, use_faw=True):
        super().__init__()
        self.use_faw = use_faw

        self.encoder = get_backbone(backbone_name, pretrained=pretrained)
        enc_channels = self.encoder.out_channels

        if use_faw:
            self.faw = FastAdaptationWeights(
                encoder_channels=enc_channels,
                num_modulation_layers=3,
                modulation_channels=decoder_channel,
                hidden_dim=faw_hidden_dim,
                num_layers=faw_num_layers,
            )
        else:
            self.faw = None

        self.decoder = MultiScaleDecoder(
            encoder_channels=enc_channels,
            decoder_channel=decoder_channel,
        )

    def forward(self, x):
        features = self.encoder(x)
        modulations = None
        if self.use_faw and self.faw is not None:
            modulations = self.faw(features)
        main_pred, side_preds = self.decoder(features, modulations=modulations)
        return main_pred, side_preds

    def get_faw_parameters(self):
        """Return the FAW parameters as a list (empty when FAW is disabled)."""
        if self.faw is not None:
            return list(self.faw.parameters())
        return []

    def get_non_faw_parameters(self):
        """Return encoder and decoder parameters as one list."""
        return list(self.encoder.parameters()) + list(self.decoder.parameters())

    def freeze_non_faw(self):
        """Freeze encoder and decoder; leave only FAW trainable."""
        for param in self.encoder.parameters():
            param.requires_grad_(False)
        for param in self.decoder.parameters():
            param.requires_grad_(False)
        if self.faw is not None:
            for param in self.faw.parameters():
                param.requires_grad_(True)

    def unfreeze_all(self):
        """Make every parameter trainable again."""
        for param in self.parameters():
            param.requires_grad_(True)

    def print_param_summary(self):
        """Print parameter counts per component (trainable or not)."""
        enc_params = sum(p.numel() for p in self.encoder.parameters())
        dec_params = sum(p.numel() for p in self.decoder.parameters())
        # Count FAW like the encoder/decoder (all parameters). The FAW
        # get_param_count() counts only trainable ones, which would skew the
        # ratio if FAW were ever frozen.
        faw_params = (
            sum(p.numel() for p in self.faw.parameters())
            if self.faw is not None else 0
        )
        total = enc_params + dec_params + faw_params
        faw_ratio = faw_params / total * 100 if total else 0.0
        print("Parameter Summary:")
        print(f"  Encoder:  {enc_params:>10,}")
        print(f"  Decoder:  {dec_params:>10,}")
        print(f"  FAW:      {faw_params:>10,}")
        print(f"  Total:    {total:>10,}")
        print(f"  FAW ratio: {faw_ratio:.2f}% of total")


def build_model(cfg):
    """Build a GMLFNet from cfg.model.

    Required: backbone, faw_hidden_dim, faw_num_layers.
    Optional: decoder_channels (last entry used, default 32), pretrained
    (default True) and use_faw (default True).
    """
    model_cfg = cfg.model
    decoder_channel = (
        model_cfg.decoder_channels[-1]
        if hasattr(model_cfg, "decoder_channels") else 32
    )
    return GMLFNet(
        backbone_name=model_cfg.backbone,
        decoder_channel=decoder_channel,
        faw_hidden_dim=model_cfg.faw_hidden_dim,
        faw_num_layers=model_cfg.faw_num_layers,
        pretrained=getattr(model_cfg, "pretrained", True),
        use_faw=getattr(model_cfg, "use_faw", True),
    )
'''.strip()

(PROJECT_ROOT / "models" / "gmlf_net.py").write_text(gmlf_net_code)
print("Written: models/gmlf_net.py")

In [None]:
# Cell 4h: Write models/losses.py
# Writes PROJECT_ROOT/models/losses.py, which defines:
#   * StructureLoss - pixel-weighted BCE plus weighted IoU. The weight is
#     1 + 5 * |31x31 local mean of the mask - mask|, so pixels near object
#     boundaries get up to 6x weight;
#   * BCEDiceLoss - weighted sum of BCE-with-logits and soft Dice;
#   * GMLFNetLoss - deep supervision: structure_weight * StructureLoss on the
#     main prediction plus a weighted StructureLoss on each side prediction.
#     Weights default to [0.5, 0.3, 0.2], and 0.1 applies to any extra
#     side outputs.
# All losses take raw logits (they apply sigmoid / *_with_logits internally)
# and 4-D (B, C, H, W) tensors, since reductions are over dims (2, 3).
losses_code = '''
import torch
import torch.nn as nn
import torch.nn.functional as F


class StructureLoss(nn.Module):
    """Structure-aware loss from PraNet."""

    def forward(self, pred, mask):
        weit = 1 + 5 * torch.abs(
            F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask
        )
        wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction="none")
        wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

        pred_sigmoid = torch.sigmoid(pred)
        inter = ((pred_sigmoid * mask) * weit).sum(dim=(2, 3))
        union = ((pred_sigmoid + mask) * weit).sum(dim=(2, 3))
        wiou = 1 - (inter + 1) / (union - inter + 1)

        return (wbce + wiou).mean()


class BCEDiceLoss(nn.Module):
    """Combined BCE + Dice loss."""

    def __init__(self, bce_weight=0.5, dice_weight=0.5, smooth=1.0):
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.smooth = smooth

    def forward(self, pred, mask):
        bce = F.binary_cross_entropy_with_logits(pred, mask)
        pred_sigmoid = torch.sigmoid(pred)
        inter = (pred_sigmoid * mask).sum(dim=(2, 3))
        total = pred_sigmoid.sum(dim=(2, 3)) + mask.sum(dim=(2, 3))
        dice = 1 - (2 * inter + self.smooth) / (total + self.smooth)
        dice = dice.mean()
        return self.bce_weight * bce + self.dice_weight * dice


class GMLFNetLoss(nn.Module):
    """Composite loss with deep supervision for GMLFNet."""

    def __init__(self, structure_weight=1.0, side_weights=None):
        super().__init__()
        self.structure_loss = StructureLoss()
        self.structure_weight = structure_weight
        self.side_weights = side_weights or [0.5, 0.3, 0.2]

    def forward(self, predictions, mask):
        main_pred, side_preds = predictions

        if main_pred.shape[2:] != mask.shape[2:]:
            mask_resized = F.interpolate(
                mask, size=main_pred.shape[2:],
                mode="bilinear", align_corners=False
            )
        else:
            mask_resized = mask

        total_loss = self.structure_weight * self.structure_loss(main_pred, mask_resized)

        for i, side_pred in enumerate(side_preds):
            weight = self.side_weights[i] if i < len(self.side_weights) else 0.1
            if side_pred.shape[2:] != mask_resized.shape[2:]:
                side_mask = F.interpolate(
                    mask, size=side_pred.shape[2:],
                    mode="bilinear", align_corners=False
                )
            else:
                side_mask = mask_resized
            total_loss += weight * self.structure_loss(side_pred, side_mask)

        return total_loss
'''.strip()

# Write the module to disk so it can be imported as `models.losses`.
(PROJECT_ROOT / "models" / "losses.py").write_text(losses_code)
print("Written: models/losses.py")

In [None]:
# Cell 4i: Write utils/metrics.py
# Writes PROJECT_ROOT/utils/metrics.py (PROJECT_ROOT is set in Cell 3): a
# per-image metric accumulator for binary segmentation.
metrics_code = '''
import numpy as np
import torch


class SegmentationMetrics:
    """Accumulates per-image segmentation metrics and averages them.

    update() consumes batches indexed as [i, 0], i.e. (B, 1, H, W) tensors.
    pred is treated as foreground probabilities: it is binarised at
    threshold and also used directly for MAE and the S-measure. Callers
    presumably pass sigmoid outputs, not logits. compute() returns the mean
    of every per-image score, plus an F-measure (weighted by beta_sq) built
    from the mean precision and recall.
    """

    def __init__(self, threshold=0.5, beta_sq=0.3):
        self.threshold = threshold
        self.beta_sq = beta_sq
        self.reset()

    def reset(self):
        """Clear all accumulated per-image scores."""
        self.dice_scores = []
        self.iou_scores = []
        self.precision_scores = []
        self.recall_scores = []
        self.mae_scores = []
        self.smeasure_scores = []
        self.emeasure_scores = []

    @torch.no_grad()
    def update(self, pred, mask):
        """Score every image in the batch and store the per-image results."""
        # detach() first: for CPU tensors .cpu() returns the same tensor, and
        # .numpy() raises on tensors that still require grad.
        pred = pred.detach().cpu().numpy()
        mask = mask.detach().cpu().numpy()

        for i in range(pred.shape[0]):
            p = pred[i, 0]
            m = mask[i, 0]
            p_bin = (p >= self.threshold).astype(np.float32)
            # Binarise the ground truth too, so soft masks (e.g. produced by
            # bilinear resizing) cannot yield fractional TP/FP/FN counts.
            m_bin = (m > 0.5).astype(np.float32)

            tp = (p_bin * m_bin).sum()
            fp = (p_bin * (1 - m_bin)).sum()
            fn = ((1 - p_bin) * m_bin).sum()

            # The 1e-8 smoothing makes an empty prediction on an empty mask
            # score 1 for dice/IoU/precision/recall.
            dice = (2 * tp + 1e-8) / (2 * tp + fp + fn + 1e-8)
            self.dice_scores.append(dice)

            iou = (tp + 1e-8) / (tp + fp + fn + 1e-8)
            self.iou_scores.append(iou)

            precision = (tp + 1e-8) / (tp + fp + 1e-8)
            self.precision_scores.append(precision)

            recall = (tp + 1e-8) / (tp + fn + 1e-8)
            self.recall_scores.append(recall)

            mae = np.abs(p - m_bin).mean()
            self.mae_scores.append(mae)

            sm = self._compute_smeasure(p, m_bin)
            self.smeasure_scores.append(sm)

            em = self._compute_emeasure(p_bin, m_bin)
            self.emeasure_scores.append(em)

    def compute(self):
        """Return a dict of averaged metrics (worst-case values if empty)."""
        n = len(self.dice_scores)
        if n == 0:
            return {"dice": 0.0, "iou": 0.0, "precision": 0.0,
                    "recall": 0.0, "fmeasure": 0.0, "mae": 1.0,
                    "smeasure": 0.0, "emeasure": 0.0}

        precision = np.mean(self.precision_scores)
        recall = np.mean(self.recall_scores)
        fmeasure = ((1 + self.beta_sq) * precision * recall + 1e-8) / (
            self.beta_sq * precision + recall + 1e-8
        )

        return {
            "dice": float(np.mean(self.dice_scores)),
            "iou": float(np.mean(self.iou_scores)),
            "precision": float(precision),
            "recall": float(recall),
            "fmeasure": float(fmeasure),
            "mae": float(np.mean(self.mae_scores)),
            "smeasure": float(np.mean(self.smeasure_scores)),
            "emeasure": float(np.mean(self.emeasure_scores)),
        }

    def _compute_smeasure(self, pred, mask, alpha=0.5):
        """Structure measure: alpha * object score + (1 - alpha) * region score."""
        y = mask.mean()
        if y == 0:
            return 1.0 - pred.mean()
        elif y == 1:
            return pred.mean()
        else:
            so = self._s_object(pred, mask)
            sr = self._s_region(pred, mask)
            return max(0.0, alpha * so + (1 - alpha) * sr)

    def _s_object(self, pred, mask):
        fg_pred = pred * mask
        o_fg = self._object_score(fg_pred, mask)
        bg_pred = (1 - pred) * (1 - mask)
        o_bg = self._object_score(bg_pred, 1 - mask)
        u = mask.mean()
        return u * o_fg + (1 - u) * o_bg

    def _object_score(self, pred, mask):
        x = pred[mask > 0.5]
        if len(x) == 0:
            return 0.0
        mu = x.mean()
        std = x.std() + 1e-8
        return 2 * mu / (mu ** 2 + 1 + std + 1e-8)

    def _s_region(self, pred, mask):
        # NOTE(review): quadrants are split at the image centre, whereas the
        # original S-measure (Fan et al., 2017) splits at the ground-truth
        # centroid; confirm this simplification is intended before comparing
        # with published numbers.
        h, w = mask.shape
        row_mid, col_mid = h // 2, w // 2
        score = 0.0
        for rows, cols in [(slice(0, row_mid), slice(0, col_mid)),
                           (slice(0, row_mid), slice(col_mid, w)),
                           (slice(row_mid, h), slice(0, col_mid)),
                           (slice(row_mid, h), slice(col_mid, w))]:
            p_region = pred[rows, cols]
            m_region = mask[rows, cols]
            weight = m_region.size / (h * w)
            score += weight * self._ssim_like(p_region, m_region)
        return score

    def _ssim_like(self, pred, mask):
        mu_p = pred.mean()
        mu_m = mask.mean()
        sigma_p = pred.std() + 1e-8
        sigma_m = mask.std() + 1e-8
        sigma_pm = ((pred - mu_p) * (mask - mu_m)).mean()
        c1, c2 = 0.01 ** 2, 0.03 ** 2
        luminance = (2 * mu_p * mu_m + c1) / (mu_p ** 2 + mu_m ** 2 + c1)
        contrast = (2 * sigma_p * sigma_m + c2) / (sigma_p ** 2 + sigma_m ** 2 + c2)
        structure = (sigma_pm + c2 / 2) / (sigma_p * sigma_m + c2 / 2)
        return luminance * contrast * structure

    def _compute_emeasure(self, pred_bin, mask):
        """Enhanced-alignment measure on the binarised prediction.

        Simplified edge cases: both maps empty -> 1.0; exactly one of them
        empty -> 0.0.
        """
        if mask.sum() == 0 and pred_bin.sum() == 0:
            return 1.0
        if mask.sum() == 0 or pred_bin.sum() == 0:
            return 0.0
        mu_pred = pred_bin.mean()
        mu_mask = mask.mean()
        align_pred = pred_bin - mu_pred
        align_mask = mask - mu_mask
        align_matrix = 2 * (align_pred * align_mask) / (
            align_pred ** 2 + align_mask ** 2 + 1e-8
        )
        enhanced = ((align_matrix + 1) ** 2) / 4
        return enhanced.mean()
'''.strip()

(PROJECT_ROOT / "utils" / "metrics.py").write_text(metrics_code)
print("Written: utils/metrics.py")

In [None]:
# Cell 4j: Write utils/misc.py
misc_code = '''
import os
import random
from pathlib import Path

import numpy as np
import torch
import yaml


class Config:
    """Simple nested config object built from a dictionary.

    Dict values become nested ``Config`` objects, and so do dicts that appear
    directly as list elements. ``to_dict`` reverses both conversions, so
    ``Config(d).to_dict() == d`` for plain dict/list/scalar data.
    """

    def __init__(self, d):
        for k, v in d.items():
            if isinstance(v, dict):
                setattr(self, k, Config(v))
            elif isinstance(v, list):
                setattr(self, k, [Config(i) if isinstance(i, dict) else i for i in v])
            else:
                setattr(self, k, v)

    def __repr__(self):
        return str(self.__dict__)

    def to_dict(self):
        """Return a plain-dict copy with every nested Config unwrapped.

        Lists are copied and their Config elements are converted back to
        dicts, mirroring what ``__init__`` does on the way in.
        """
        return {k: Config._unwrap(v) for k, v in self.__dict__.items()}

    @staticmethod
    def _unwrap(value):
        # Convert a Config (possibly inside a list) back to plain containers.
        if isinstance(value, Config):
            return value.to_dict()
        if isinstance(value, list):
            return [Config._unwrap(i) for i in value]
        return value


def load_config(config_path="configs/default.yaml"):
    """Read a YAML file and wrap its top-level mapping in a ``Config``."""
    with open(config_path) as handle:
        raw = yaml.safe_load(handle)
    return Config(raw)


def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs.

    When CUDA is available, also seed every GPU and force deterministic
    cuDNN kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade speed for reproducibility: no autotuned, non-deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_device():
    """Return the CUDA device when available (printing its name), else CPU."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("Using CPU")
    return device


def save_checkpoint(model, optimizer, epoch, metrics, path):
    """Serialise model/optimizer state, epoch and metrics to ``path``.

    Missing parent directories are created first.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "metrics": metrics,
        },
        target,
    )
    print(f"Checkpoint saved: {target}")


def load_checkpoint(model, optimizer=None, path=None):
    """Restore model (and optionally optimizer) state from ``path``.

    Returns ``(epoch, metrics)``, or ``(0, {})`` when ``path`` is None or
    does not exist. ``weights_only=False`` unpickles arbitrary objects, so
    only load checkpoints from trusted sources.
    """
    if path is None or not Path(path).exists():
        print("No checkpoint found, starting from scratch.")
        return 0, {}
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    restore_optimizer = optimizer is not None and "optimizer_state_dict" in ckpt
    if restore_optimizer:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    epoch, metrics = ckpt.get("epoch", 0), ckpt.get("metrics", {})
    print(f"Loaded checkpoint from epoch {epoch}")
    return epoch, metrics
'''.strip()

PROJECT_ROOT.joinpath("utils", "misc.py").write_text(misc_code)
print("Written: utils/misc.py")

In [None]:
# Cell 4k: Write trainers/meta_trainer.py
meta_trainer_code = '''
import time
from pathlib import Path

import torch
import torch.nn as nn
import learn2learn as l2l
from tqdm import tqdm

from data.meta_sampler import CenterEpisodicSampler
from models.losses import GMLFNetLoss
from utils.misc import save_checkpoint


class MAMLMetaTrainer:
    """MAML meta-trainer for GMLFNet.

    Each meta-step samples an episode of per-centre tasks. For each task it
    adapts a clone of the model on the support set. With
    ``selective_adaptation`` on, only parameters whose name contains
    ``faw`` are adapted. The adapted clone is then evaluated on the query
    set, and the mean query loss is back-propagated into the shared
    initialisation.
    """

    def __init__(self, model, sampler, cfg, device):
        self.cfg = cfg
        self.device = device
        self.sampler = sampler

        self.maml = l2l.algorithms.MAML(
            model, lr=cfg.meta.inner_lr,
            first_order=cfg.meta.first_order,
            allow_unused=True, allow_nograd=True,
        )
        self.maml.to(device)

        self.outer_optimizer = torch.optim.Adam(
            self.maml.parameters(), lr=cfg.meta.outer_lr,
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.outer_optimizer, T_max=cfg.training.epochs, eta_min=1e-6,
        )

        self.loss_fn = GMLFNetLoss()
        self.inner_steps = cfg.meta.inner_steps
        self.selective_adaptation = True

    @staticmethod
    def _adapt_selected(learner, loss, adapt_mask):
        """Take one inner-loop SGD step on the selected parameters only.

        ``adapt_mask`` holds one bool per parameter of ``learner.module``, in
        ``parameters()`` order. Only parameters flagged True (and requiring
        grad) are updated; all others are left untouched but stay connected
        to the outer graph.

        The parameters of a ``MAML.clone()`` are non-leaf tensors, so their
        ``requires_grad`` flag cannot be switched off (PyTorch raises
        RuntimeError). The selection is therefore done by differentiating
        only w.r.t. the chosen tensors and applying the step with
        ``maml_update``.
        """
        params = list(learner.module.parameters())
        keep = [flag and p.requires_grad for p, flag in zip(params, adapt_mask)]
        selected = [p for p, k in zip(params, keep) if k]
        if not selected:
            return
        # Build the higher-order graph only for full (second-order) MAML.
        second_order = not learner.first_order
        grads = iter(torch.autograd.grad(
            loss, selected,
            retain_graph=second_order,
            create_graph=second_order,
            allow_unused=True,
        ))
        full_grads = [next(grads) if k else None for k in keep]
        learner.module = l2l.algorithms.maml_update(
            learner.module, learner.lr, full_grads,
        )

    def meta_train_step(self):
        """Run one meta-update over a freshly sampled episode.

        Returns ``(mean_query_loss, {center_name: query_loss})``.
        """
        self.outer_optimizer.zero_grad()
        meta_loss = 0.0
        task_losses = {}

        episode = self.sampler.sample_episode()

        for task in episode:
            learner = self.maml.clone()
            adapt_mask = [
                "faw" in name for name, _ in learner.module.named_parameters()
            ]

            for _ in range(self.inner_steps):
                support_pred = learner(task.support_images)
                support_loss = self.loss_fn(support_pred, task.support_masks)
                if self.selective_adaptation:
                    self._adapt_selected(learner, support_loss, adapt_mask)
                else:
                    learner.adapt(support_loss)

            query_pred = learner(task.query_images)
            query_loss = self.loss_fn(query_pred, task.query_masks)
            meta_loss += query_loss
            task_losses[task.center_name] = query_loss.item()

        meta_loss /= len(episode)
        meta_loss.backward()

        torch.nn.utils.clip_grad_norm_(
            self.maml.parameters(), self.cfg.training.grad_clip,
        )
        self.outer_optimizer.step()

        return meta_loss.item(), task_losses

    def train_epoch(self, epoch):
        """Run ``len(sampler)`` meta-steps, step the LR scheduler, return the mean loss."""
        self.maml.train()
        steps_per_epoch = len(self.sampler)
        total_loss = 0.0

        pbar = tqdm(range(steps_per_epoch), desc=f"Epoch {epoch}")
        for step in pbar:
            loss, task_losses = self.meta_train_step()
            total_loss += loss
            task_str = " | ".join(
                f"{k[:8]}:{v:.4f}" for k, v in task_losses.items()
            )
            pbar.set_postfix_str(f"loss={loss:.4f} | {task_str}")

        self.scheduler.step()
        return total_loss / steps_per_epoch

    def train(self, evaluator=None, logger=None):
        """Full meta-training loop with periodic evaluation and checkpoints."""
        best_dice = 0.0
        start_epoch = 0

        if self.cfg.training.resume:
            # NOTE(review): only model and optimizer state are restored; the
            # LR scheduler is not checkpointed, so its step counter restarts
            # from 0 after a resume. Confirm that this is intended.
            checkpoint = torch.load(
                self.cfg.training.resume,
                map_location=self.device, weights_only=False,
            )
            self.maml.module.load_state_dict(checkpoint["model_state_dict"])
            self.outer_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            start_epoch = checkpoint.get("epoch", 0) + 1
            best_dice = checkpoint.get("metrics", {}).get("mean_dice", 0.0)
            print(f"Resumed from epoch {start_epoch}, best_dice={best_dice:.4f}")

        for epoch in range(start_epoch, self.cfg.training.epochs):
            t0 = time.time()
            avg_loss = self.train_epoch(epoch)
            elapsed = time.time() - t0

            print(f"Epoch {epoch}: avg_loss={avg_loss:.4f}, time={elapsed:.1f}s, "
                  f"lr={self.scheduler.get_last_lr()[0]:.6f}")

            if evaluator and (epoch + 1) % self.cfg.logging.save_interval == 0:
                results = evaluator.full_evaluation(self.maml)
                mean_dice = sum(
                    r["dice"] for r in results.values()
                ) / len(results)

                print(f"  Eval mean_dice={mean_dice:.4f}")
                for center, metrics in results.items():
                    # Single quotes inside the f-string: reusing the outer
                    # quote character is a SyntaxError before Python 3.12.
                    print(f"    {center}: dice={metrics['dice']:.4f}, "
                          f"iou={metrics['iou']:.4f}")

                if mean_dice > best_dice:
                    best_dice = mean_dice
                    save_checkpoint(
                        self.maml.module, self.outer_optimizer, epoch,
                        {"mean_dice": mean_dice, **results},
                        Path(self.cfg.logging.log_dir) / "best_model.pth",
                    )

            if (epoch + 1) % self.cfg.logging.save_interval == 0:
                save_checkpoint(
                    self.maml.module, self.outer_optimizer, epoch,
                    {"avg_loss": avg_loss},
                    Path(self.cfg.logging.log_dir) / f"checkpoint_epoch{epoch}.pth",
                )

        print(f"\\nTraining complete. Best mean Dice: {best_dice:.4f}")
'''.strip()

PROJECT_ROOT.joinpath("trainers", "meta_trainer.py").write_text(meta_trainer_code)
print("Written: trainers/meta_trainer.py")

In [None]:
# Cell 4l: Write trainers/evaluator.py
evaluator_code = '''
import torch
import torch.nn.functional as F
import learn2learn as l2l
from torch.utils.data import DataLoader

from utils.metrics import SegmentationMetrics
from models.losses import GMLFNetLoss


class Evaluator:
    """Evaluator for polyp segmentation models.

    Holds a mapping ``center_name -> dataset`` and reports the metrics of
    ``SegmentationMetrics.compute()`` per centre. Evaluation is either
    zero-shot (plain inference) or few-shot, where a MAML clone is first
    adapted on the first ``k_support`` images of the centre.
    """

    def __init__(self, test_datasets, cfg, device):
        self.test_datasets = test_datasets
        self.cfg = cfg
        self.device = device
        self.loss_fn = GMLFNetLoss()

    @staticmethod
    def _adapt_faw_only(learner, loss):
        """Take one inner-loop SGD step that updates only FAW parameters.

        Parameters whose name contains ``faw`` get a step of size
        ``learner.lr``; every other parameter is left untouched.

        The parameters of a ``MAML.clone()`` are non-leaf tensors whose
        ``requires_grad`` flag cannot be switched off (PyTorch raises
        RuntimeError). The selection is therefore done by differentiating
        only w.r.t. the chosen tensors and applying the step with
        ``maml_update``. No higher-order graph is built, because evaluation
        never back-propagates through the adaptation.
        """
        named = list(learner.module.named_parameters())
        keep = ["faw" in name and p.requires_grad for name, p in named]
        selected = [p for (_, p), k in zip(named, keep) if k]
        if not selected:
            return
        grads = iter(torch.autograd.grad(loss, selected, allow_unused=True))
        full_grads = [next(grads) if k else None for k in keep]
        learner.module = l2l.algorithms.maml_update(
            learner.module, learner.lr, full_grads,
        )

    @torch.no_grad()
    def evaluate_zero_shot(self, model, center_name):
        """Run plain inference on every image of ``center_name``.

        A ``MAML`` wrapper is unwrapped to its underlying module. Predictions
        are sigmoid probabilities, resized bilinearly to the mask resolution
        when the spatial shapes differ.
        """
        if isinstance(model, l2l.algorithms.MAML):
            eval_model = model.module
        else:
            eval_model = model

        eval_model.eval()
        dataset = self.test_datasets[center_name]
        loader = DataLoader(dataset, batch_size=1, shuffle=False)
        metrics = SegmentationMetrics()

        for batch in loader:
            image = batch["image"].to(self.device)
            mask = batch["mask"].to(self.device)
            main_pred, _ = eval_model(image)
            pred = torch.sigmoid(main_pred)
            if pred.shape[2:] != mask.shape[2:]:
                pred = F.interpolate(
                    pred, size=mask.shape[2:],
                    mode="bilinear", align_corners=False,
                )
            metrics.update(pred, mask)

        return metrics.compute()

    def evaluate_few_shot(self, maml_model, center_name,
                          k_support=5, adaptation_steps=5):
        """Adapt the FAW parameters on ``k_support`` images, then score the rest.

        The first ``k_support`` samples of the dataset form the support set.
        The remaining samples are evaluated with the adapted clone. If the
        centre has at most ``k_support`` images, this falls back to
        zero-shot evaluation.
        """
        dataset = self.test_datasets[center_name]
        n = len(dataset)

        if n <= k_support:
            print(f"Warning: {center_name} has only {n} images")
            return self.evaluate_zero_shot(maml_model, center_name)

        indices = list(range(n))
        support_indices = indices[:k_support]
        query_indices = indices[k_support:]

        support_images = []
        support_masks = []
        for idx in support_indices:
            sample = dataset[idx]
            support_images.append(sample["image"])
            support_masks.append(sample["mask"])
        support_images = torch.stack(support_images).to(self.device)
        support_masks = torch.stack(support_masks).to(self.device)

        learner = maml_model.clone()
        # NOTE(review): adaptation runs in whatever train/eval mode maml_model
        # is currently in. In train mode, BatchNorm layers use batch
        # statistics and may update running buffers shared with the original
        # model. Confirm that this is intended.
        for _ in range(adaptation_steps):
            pred = learner(support_images)
            loss = self.loss_fn(pred, support_masks)
            self._adapt_faw_only(learner, loss)

        learner.eval()
        metrics = SegmentationMetrics()

        with torch.no_grad():
            for idx in query_indices:
                sample = dataset[idx]
                image = sample["image"].unsqueeze(0).to(self.device)
                mask = sample["mask"].unsqueeze(0).to(self.device)
                main_pred, _ = learner(image)
                pred = torch.sigmoid(main_pred)
                if pred.shape[2:] != mask.shape[2:]:
                    pred = F.interpolate(
                        pred, size=mask.shape[2:],
                        mode="bilinear", align_corners=False,
                    )
                metrics.update(pred, mask)

        return metrics.compute()

    def full_evaluation(self, model, mode="zero_shot"):
        """Evaluate every centre; few-shot only applies to a MAML model."""
        results = {}
        for center_name in self.test_datasets:
            if mode == "few_shot" and isinstance(model, l2l.algorithms.MAML):
                results[center_name] = self.evaluate_few_shot(model, center_name)
            else:
                results[center_name] = self.evaluate_zero_shot(model, center_name)
        return results
'''.strip()

PROJECT_ROOT.joinpath("trainers", "evaluator.py").write_text(evaluator_code)
print("Written: trainers/evaluator.py")

In [None]:
# Cell 4m: Verify all modules are importable
import importlib
import sys

print("Verifying module imports...")

# Force a re-import of the project packages so the freshly written sources
# are used. Match on the top-level package name only. A substring test such
# as "'utils.' in name" also hits third-party modules (torch.utils.data.*,
# timm.models.*, learn2learn.data.*), and purging those leaves duplicate
# class objects behind.
PROJECT_PACKAGES = {"data", "models", "trainers", "utils"}
for mod_name in list(sys.modules):
    if mod_name.partition(".")[0] in PROJECT_PACKAGES:
        del sys.modules[mod_name]
# The .py files were just (re)written at runtime; drop the import system's
# cached directory listings so the new files are guaranteed to be found.
importlib.invalidate_caches()

from data.augmentations import get_train_transforms, get_test_transforms
from data.datasets import PolypCenterDataset, build_center_datasets
from data.meta_sampler import CenterEpisodicSampler
from models.backbone import get_backbone
from models.decoder import MultiScaleDecoder
from models.fast_adapt_weights import FastAdaptationWeights
from models.gmlf_net import GMLFNet, build_model
from models.losses import GMLFNetLoss, StructureLoss
from utils.metrics import SegmentationMetrics
from utils.misc import Config, set_seed, get_device, save_checkpoint

print("All modules imported successfully!")

In [None]:
# Cell 5: Configuration
import yaml

config_dict = {
    "data": {
        "root": "/kaggle/working/datasets",
        "image_size": 352,
        "train_centers": ["Kvasir", "CVC-ClinicDB", "CVC-ColonDB"],
        "test_centers": ["ETIS-LaribPolypDB", "CVC-300"],
        "num_workers": 2,
        "pin_memory": True,
    },
    "model": {
        "backbone": "res2net50",
        "decoder_channels": [256, 128, 64, 32],
        "faw_hidden_dim": 64,
        "faw_num_layers": 2,
    },
    "meta": {
        "algorithm": "maml",
        "inner_lr": 0.01,
        "inner_steps": 5,
        "outer_lr": 0.001,
        "tasks_per_batch": 3,
        "support_size": 16,
        "query_size": 16,
        "first_order": True,
    },
    "training": {
        "epochs": 200,
        "seed": 42,
        "grad_clip": 1.0,
        "scheduler": "cosine",
        "warmup_epochs": 10,
        "resume": "",
    },
    "loss": {
        "bce_weight": 0.5,
        "dice_weight": 0.5,
        "structure_weight": 0.2,
    },
    "logging": {
        "backend": "tensorboard",
        "log_dir": "/kaggle/working/runs",
        "save_interval": 10,
        "wandb_project": "GMLFNet",
    },
}

# Persist the config as YAML next to the project sources.
config_path = PROJECT_ROOT / "configs" / "default.yaml"
config_path.write_text(yaml.dump(config_dict, default_flow_style=False))

cfg = Config(config_dict)
summary_lines = (
    "Configuration loaded:",
    f"  Backbone: {cfg.model.backbone}",
    f"  Train centers: {cfg.data.train_centers}",
    f"  Test centers: {cfg.data.test_centers}",
    f"  Inner LR: {cfg.meta.inner_lr}, Inner Steps: {cfg.meta.inner_steps}",
    f"  Outer LR: {cfg.meta.outer_lr}",
    f"  Epochs: {cfg.training.epochs}",
    f"  FOMAML: {cfg.meta.first_order}",
)
for summary_line in summary_lines:
    print(summary_line)

In [None]:
# Cell 6: Build Model & Verify
import torch

set_seed(cfg.training.seed)
device = get_device()

model = build_model(cfg)
model.print_param_summary()

# Sanity check: forward pass with a dummy batch at the configured resolution.
# Run it in eval mode: in train mode, even a no_grad forward updates
# BatchNorm running statistics. That would pull the backbone's (pretrained)
# statistics towards those of random noise before training starts.
print("\nRunning sanity check...")
sanity_size = cfg.data.image_size
dummy_input = torch.randn(2, 3, sanity_size, sanity_size).to(device)
model = model.to(device)
was_training = model.training
model.eval()
with torch.no_grad():
    main_pred, side_preds = model(dummy_input)
model.train(was_training)  # restore the mode meta-training expects
side_list = [] if side_preds is None else list(side_preds)
print(f"  Input shape:  {dummy_input.shape}")
print(f"  Output shape: {main_pred.shape}")
print(f"  Side outputs: {len(side_list)} (shapes: {[s.shape for s in side_list]})")
print("Sanity check passed!")

# Move back to CPU for MAML wrapping
model = model.cpu()
del dummy_input
torch.cuda.empty_cache()

In [None]:
# Cell 7: Build Datasets & Sampler
print("Building datasets...")

train_transform = get_train_transforms(cfg.data.image_size)
test_transform = get_test_transforms(cfg.data.image_size)


def _datasets_for(centers, transform):
    """Build one dataset per centre under ``cfg.data.root``."""
    return build_center_datasets(
        root=cfg.data.root,
        centers=centers,
        transform=transform,
        image_size=cfg.data.image_size,
    )


train_datasets = _datasets_for(cfg.data.train_centers, train_transform)
test_datasets = _datasets_for(cfg.data.test_centers, test_transform)

for heading, group in (
    ("\nTraining datasets:", train_datasets),
    ("\nTest datasets:", test_datasets),
):
    print(heading)
    for name, ds in group.items():
        print(f"  {name}: {len(ds)} images")

# Episodic sampler over the training centres.
sampler = CenterEpisodicSampler(
    center_datasets=train_datasets,
    support_size=cfg.meta.support_size,
    query_size=cfg.meta.query_size,
    device=device,
)
print(f"\nEpisodic sampler: {len(sampler)} episodes/epoch")

In [None]:
# Cell 8: Train (Meta-Learning)
from trainers.meta_trainer import MAMLMetaTrainer
from trainers.evaluator import Evaluator

print("Initializing MAML meta-trainer...")

trainer = MAMLMetaTrainer(
    model=model,
    sampler=sampler,
    cfg=cfg,
    device=device,
)

evaluator = Evaluator(
    test_datasets=test_datasets,
    cfg=cfg,
    device=device,
)

print("Starting meta-learning training...")
print(f"  Epochs: {cfg.training.epochs}")
print(f"  Inner steps: {cfg.meta.inner_steps}")
print(f"  Inner LR: {cfg.meta.inner_lr}")
print(f"  Outer LR: {cfg.meta.outer_lr}")
print(f"  FOMAML: {cfg.meta.first_order}")
# Report the flag the trainer actually uses instead of a hard-coded value.
print(f"  Selective adaptation (FAW only): {trainer.selective_adaptation}")
print(f"  Eval interval: every {cfg.logging.save_interval} epochs")
print("=" * 60)

trainer.train(evaluator=evaluator)

In [None]:
# Cell 9: Final Evaluation
import json

print("=" * 60)
print("FINAL EVALUATION")
print("=" * 60)

# Load best model
best_path = Path(cfg.logging.log_dir) / "best_model.pth"
if best_path.exists():
    checkpoint = torch.load(best_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"Loaded best model from epoch {checkpoint['epoch']}")
    print(f"Best mean Dice during training: {checkpoint['metrics'].get('mean_dice', 'N/A')}")
else:
    print("No best model checkpoint found, using current model.")

model = model.to(device)
model.eval()

# Evaluate on all centers (train + test).
# `train_datasets` were built with the training transforms, so the seen
# centres are rebuilt here with the deterministic test transforms. Otherwise
# their scores would be computed on randomly augmented images.
all_centers = cfg.data.train_centers + cfg.data.test_centers
eval_train_datasets = build_center_datasets(
    root=cfg.data.root,
    centers=cfg.data.train_centers,
    transform=test_transform,
    image_size=cfg.data.image_size,
)
all_datasets = {**eval_train_datasets, **test_datasets}

all_evaluator = Evaluator(test_datasets=all_datasets, cfg=cfg, device=device)

print("\n--- Zero-Shot Evaluation ---")
results = {}
for center in all_centers:
    metrics = all_evaluator.evaluate_zero_shot(model, center)
    results[center] = metrics

# Print formatted table
print(f"\n{'Center':25s} | {'Dice':>7s} | {'IoU':>7s} | {'F-meas':>7s} | {'MAE':>7s} | {'S-meas':>7s} | {'E-meas':>7s}")
print("-" * 90)
for center, m in results.items():
    marker = " *" if center in cfg.data.test_centers else ""
    print(f"{center + marker:25s} | {m['dice']:7.4f} | {m['iou']:7.4f} | {m['fmeasure']:7.4f} | "
          f"{m['mae']:7.4f} | {m['smeasure']:7.4f} | {m['emeasure']:7.4f}")

# Compute mean over test centers (guarded: no test centres -> no mean).
test_results = {k: v for k, v in results.items() if k in cfg.data.test_centers}
if test_results:
    mean_dice = sum(m['dice'] for m in test_results.values()) / len(test_results)
    mean_iou = sum(m['iou'] for m in test_results.values()) / len(test_results)
    print(f"\nMean over test centers: Dice={mean_dice:.4f}, IoU={mean_iou:.4f}")
else:
    print("\nNo test centers configured; skipping mean over test centers.")
print("\n* = unseen test center (zero-shot generalization)")

In [None]:
# Cell 10: Save Results to Kaggle Output
output_dir = Path("/kaggle/working")


def _json_default(obj):
    """Make NumPy / PyTorch scalars JSON-serialisable via ``.item()``.

    The metric values come from ``SegmentationMetrics.compute()``, whose
    return types are not visible here. Values such as numpy.float32 or 0-d
    tensors would otherwise make ``json.dump`` raise TypeError.
    """
    item = getattr(obj, "item", None)
    if callable(item):
        try:
            return item()
        except (ValueError, RuntimeError):
            pass  # multi-element array/tensor: fall through to TypeError
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save evaluation results
results_path = output_dir / "evaluation_results.json"
with open(results_path, "w") as f:
    json.dump(results, f, indent=2, default=_json_default)
print(f"Results saved to: {results_path}")

# Copy best model to output
if best_path.exists():
    output_model = output_dir / "best_model.pth"
    shutil.copy2(best_path, output_model)
    print(f"Best model copied to: {output_model}")

# List all checkpoints
runs_dir = Path(cfg.logging.log_dir)
if runs_dir.exists():
    checkpoints = list(runs_dir.glob("*.pth"))
    print(f"\nAll checkpoints in {runs_dir}:")
    for cp in sorted(checkpoints):
        size_mb = cp.stat().st_size / (1024 * 1024)
        print(f"  {cp.name}: {size_mb:.1f} MB")

print("\nDone! Download results from the Output tab.")