# Brain MRI Tumor Segmentation — All-in-One Notebook

Run this cell to install dependencies if needed.


In [None]:
!pip install segmentation-models-pytorch nibabel albumentations scikit-image matplotlib scikit-learn


## config.py

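Running content from `config.py`. Note: later cells still execute `import config` and `from augmentations import ...`, so the corresponding `.py` files must also exist in the working directory — running a cell here does not create an importable module.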


In [None]:
"""
============================================================================
CONFIG — Central Configuration for Brain MRI Tumor Segmentation
============================================================================
All hyperparameters, paths, and device settings in one place.
============================================================================
"""

import torch
from pathlib import Path

# ── Paths ──────────────────────────────────────────────────────────────────
# Everything is resolved relative to the directory the process is started in.
PROJECT_ROOT = Path.cwd()
DATA_ROOT = PROJECT_ROOT / "Data"              # <patient>/<timepoint>/*.nii.gz
CHECKPOINT_DIR = PROJECT_ROOT / "checkpoints"
OUTPUT_DIR = PROJECT_ROOT / "outputs"
LOG_FILE = PROJECT_ROOT / "training_log.csv"

# Create output directories up front. parents=True keeps this working if
# PROJECT_ROOT is ever pointed at a location whose parents do not exist yet.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ── Device ─────────────────────────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = 2       # 2 for laptop (fewer CPU cores, shared thermal)
PIN_MEMORY = torch.cuda.is_available()  # pinned host memory only helps CUDA copies
GPU_VRAM_GB = 4       # RTX 3050 Laptop GPU

# ── Data ───────────────────────────────────────────────────────────────────
IMAGE_SIZE = 224               # Optimized for 4GB VRAM (224 vs 256 saves ~25% memory)
MODALITY_SUFFIXES = [          # Order matters — stacked as input channels
    "brain_t1c",
    "brain_t1n",
    "brain_t2f",
    "brain_t2w",
]
# Derived from MODALITY_SUFFIXES so the channel count can never disagree
# with the list of modalities actually loaded (T1c, T1n, T2-FLAIR, T2w).
NUM_MODALITIES = len(MODALITY_SUFFIXES)
MASK_SUFFIX = "tumorMask"      # substring identifying the mask file name

# Subset ratio for prototyping (0.15 = 15% of patients)
SUBSET_RATIO = 0.15

# Minimum fraction of non-zero voxels in a slice to include it
# (filters out mostly-empty slices that add noise)
MIN_BRAIN_FRACTION = 0.02

# Train/Val/Test split ratios (patient-level, not slice-level).
# Patients not assigned to train/val by truncation end up in test.
TRAIN_RATIO = 0.80
VAL_RATIO = 0.10
TEST_RATIO = 0.10

# ── Training ───────────────────────────────────────────────────────────────
BATCH_SIZE = 4                 # Small batch for 4GB VRAM
GRAD_ACCUMULATION_STEPS = 8    # Effective batch = 4 * 8 = 32
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 100
WARMUP_EPOCHS = 5
GRADIENT_CHECKPOINTING = True  # Trade compute for VRAM by recomputing activations
EARLY_STOPPING_PATIENCE = 10
USE_AMP = True                 # Mixed precision (FP16)

# ── Model ──────────────────────────────────────────────────────────────────
ENCODER_NAME = "efficientnet-b0"  # B0 instead of B3 — fits 4GB VRAM comfortably
ENCODER_WEIGHTS = "imagenet"      # Pre-trained weights
NUM_CLASSES = 1                   # Binary segmentation (tumor / no tumor)

# ── Augmentation ───────────────────────────────────────────────────────────
AUG_ROTATION_LIMIT = 15        # degrees
AUG_BRIGHTNESS_LIMIT = 0.1
AUG_CONTRAST_LIMIT = 0.1
AUG_ELASTIC_ALPHA = 50
AUG_ELASTIC_SIGMA = 10

# ── Logging ────────────────────────────────────────────────────────────────
PRINT_EVERY_N_BATCHES = 50     # Print progress every N batches



## augmentations.py

Running content from `augmentations.py`.


In [None]:
"""
============================================================================
AUGMENTATIONS — MRI-Appropriate On-The-Fly Data Augmentation
============================================================================
Uses albumentations for fast, GPU-friendly augmentations.
All transforms are applied during DataLoader iteration — no extra disk usage.
Medical imaging constraints:
  - No heavy warping that distorts anatomy
  - Conservative rotation (±15°)
  - Mild brightness/contrast jitter
  - Elastic deformation (helps segmentation generalization)
============================================================================
"""

import albumentations as A
from albumentations.pytorch import ToTensorV2
import config


def get_train_transforms():
    """
    Build the stochastic augmentation pipeline used for training slices.

    The 4-channel image and its mask pass through the same geometric ops,
    so they stay spatially aligned. The photometric ops (brightness/contrast,
    blur, noise) act on the image only. The last step converts the HWC
    arrays to a channels-first tensor.
    """
    side = config.IMAGE_SIZE

    # Geometric ops: applied jointly to image and mask.
    geometric = [
        A.Resize(side, side),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        # border_mode=0 is OpenCV's BORDER_CONSTANT: exposed corners become zeros.
        A.Rotate(limit=config.AUG_ROTATION_LIMIT, border_mode=0, p=0.5),
        # Mild elastic warp; helps segmentation generalize.
        A.ElasticTransform(
            alpha=config.AUG_ELASTIC_ALPHA,
            sigma=config.AUG_ELASTIC_SIGMA,
            p=0.3,
        ),
    ]

    # Photometric ops: image only.
    photometric = [
        A.RandomBrightnessContrast(
            brightness_limit=config.AUG_BRIGHTNESS_LIMIT,
            contrast_limit=config.AUG_CONTRAST_LIMIT,
            p=0.4,
        ),
        A.GaussianBlur(blur_limit=(3, 5), p=0.2),
        # NOTE(review): GaussNoise's default strength differs between
        # albumentations versions, and the input here is z-scored float32 —
        # confirm the default noise scale is sensible for this data.
        A.GaussNoise(p=0.2),
    ]

    return A.Compose(geometric + photometric + [ToTensorV2()])


def get_val_transforms():
    """
    Build the deterministic pipeline for validation/test slices.

    Resizes image and mask to ``config.IMAGE_SIZE`` and converts to tensors.
    There are no random ops, so the same input always yields the same output.
    """
    side = config.IMAGE_SIZE
    steps = [A.Resize(side, side), ToTensorV2()]
    return A.Compose(steps)



## dataset.py

Running content from `dataset.py`.


In [None]:
"""
============================================================================
DATASET — Memory-Efficient Brain MRI Dataset with 2D Slice Extraction
============================================================================
Key design decisions:
  1. LAZY LOADING: NIfTI volumes are NOT loaded into RAM at init time.
     Only metadata (paths, slice indices) are cached.
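  2. ON-THE-FLY SLICING: Each __getitem__ reads only ONE 2D slice from
     the 3D volume through nibabel's array proxy (`dataobj`). The dataset
     uses .nii.gz files, which cannot be memory-mapped, so nibabel may
     still have to decompress the stream up to the requested slice.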
  3. PATIENT-LEVEL SPLIT: Train/val/test are split by patient ID, not
     by slice, to prevent data leakage between sets.
  4. EMPTY SLICE FILTERING: Slices with <2% brain tissue are excluded.
============================================================================
"""

import os
import random
import numpy as np
import nibabel as nib
from pathlib import Path
from typing import List, Tuple, Optional, Dict
from collections import defaultdict

import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import config
from augmentations import get_train_transforms, get_val_transforms


# ═══════════════════════════════════════════════════════════════════════════
# HELPER: Scan dataset and build a slice index WITHOUT loading image data
# ═══════════════════════════════════════════════════════════════════════════

def _collect_timepoint_files(tp_dir: Path) -> Tuple[Dict[str, Path], Optional[Path]]:
    """
    Map each modality suffix in ``config.MODALITY_SUFFIXES`` to its ``.nii.gz``
    file in ``tp_dir``, and find the mask file (name contains ``config.MASK_SUFFIX``).

    Files are visited in sorted order so the result does not depend on the
    filesystem's listing order (matters when several files match one role).
    """
    modality_paths: Dict[str, Path] = {}
    mask_path: Optional[Path] = None
    for f in sorted(tp_dir.iterdir()):
        if not f.is_file() or not f.name.endswith('.nii.gz'):
            continue
        stem = f.name[:-len('.nii.gz')]
        for mod in config.MODALITY_SUFFIXES:
            if stem.endswith(mod):
                modality_paths[mod] = f
                break
        if config.MASK_SUFFIX in stem:
            mask_path = f
    return modality_paths, mask_path


def build_slice_index(
    data_root: Path,
    subset_ratio: float = 1.0,
    seed: int = 42,
    verbose: bool = True,
) -> List[Dict]:
    """
    Scan the dataset folder and build an index of all valid 2D slices.

    Expected layout: ``data_root/<patient_id>/<timepoint>/*.nii.gz``. Each
    timepoint folder holds one file per entry of ``config.MODALITY_SUFFIXES``
    and optionally a mask file whose name contains ``config.MASK_SUFFIX``.

    Args:
        data_root: Directory containing one sub-directory per patient.
        subset_ratio: Fraction of patients to keep (patient-level sampling);
            values >= 1.0 keep every patient.
        seed: Seed for the subset sampling. A private ``random.Random`` is
            used, so the global ``random`` state is left untouched.
        verbose: Print warnings, progress and summary statistics.

    Returns:
        A list of dicts, one per kept slice, each containing:
          - patient_id: str
          - timepoint: str
          - modality_paths: dict mapping modality suffix -> file path (str)
          - mask_path: path to tumor mask (str) or None
          - slice_idx: int (index along the third volume axis)
          - has_tumor: bool (whether the mask slice has any voxel > 0)
          - vol_shape: shape tuple of the reference modality volume

    Notes:
        Per timepoint, the reference modality (the first entry of
        ``config.MODALITY_SUFFIXES``) and the mask are fully loaded to compute
        per-slice non-zero fraction and tumor presence, then released.
        Timepoints with a missing modality, an unreadable volume/mask, or a
        mask whose shape differs from the volume are skipped.
    """
    # -- Collect patient directories and apply patient-level subsampling --
    all_patient_dirs = sorted(d for d in data_root.iterdir() if d.is_dir())
    patient_dirs = all_patient_dirs
    if subset_ratio < 1.0:
        # min() guards against an empty data_root (sample size > population).
        n_patients = min(len(all_patient_dirs),
                         max(1, int(len(all_patient_dirs) * subset_ratio)))
        # Same selection as random.seed(seed); random.sample(...), but local.
        rng = random.Random(seed)
        patient_dirs = sorted(rng.sample(all_patient_dirs, n_patients))
        if verbose:
            print(f"[SUBSET] Using {n_patients}/{len(all_patient_dirs)} "
                  f"patients ({subset_ratio*100:.0f}%)")

    slice_index: List[Dict] = []
    total_slices = 0
    tumor_slices = 0
    skipped_empty = 0

    for i, patient_dir in enumerate(patient_dirs):
        patient_id = patient_dir.name

        for tp_dir in sorted(patient_dir.iterdir()):
            if not tp_dir.is_dir():
                continue
            timepoint = tp_dir.name

            modality_paths, mask_path = _collect_timepoint_files(tp_dir)

            # Every configured modality is required: __getitem__ reads them all.
            if any(mod not in modality_paths for mod in config.MODALITY_SUFFIXES):
                if verbose:
                    print(f"[WARN] Missing modalities in {patient_id}/{timepoint}, skipping")
                continue

            # Fixed reference modality (first in config order) -> deterministic filtering.
            ref_path = modality_paths[config.MODALITY_SUFFIXES[0]]
            try:
                nii = nib.load(str(ref_path))
                vol_shape = nii.shape  # e.g., (240, 240, 155)
                num_slices = vol_shape[2]
                ref_data = nii.get_fdata(dtype=np.float32)
            except Exception as e:
                if verbose:
                    print(f"[WARN] Failed to read {ref_path}: {e}")
                continue

            # -- Load mask to determine per-slice tumor presence --
            mask_data = None
            if mask_path is not None:
                try:
                    mask_data = nib.load(str(mask_path)).get_fdata(dtype=np.float32)
                except Exception as e:
                    # Keeping the entry would label every slice tumor-free while
                    # __getitem__ later tries to read the same broken file.
                    if verbose:
                        print(f"[WARN] Failed to read mask {mask_path}: {e}, skipping")
                    continue
                if tuple(mask_data.shape) != tuple(ref_data.shape):
                    if verbose:
                        print(f"[WARN] Mask shape {mask_data.shape} != volume shape "
                              f"{ref_data.shape} in {patient_id}/{timepoint}, skipping")
                    continue

            modality_strs = {k: str(v) for k, v in modality_paths.items()}
            mask_str = str(mask_path) if mask_path is not None else None

            # -- Build slice entries --
            for slice_idx in range(num_slices):
                ref_slice = ref_data[:, :, slice_idx]
                brain_fraction = np.count_nonzero(ref_slice) / ref_slice.size
                if brain_fraction < config.MIN_BRAIN_FRACTION:
                    skipped_empty += 1
                    continue

                has_tumor = bool(
                    mask_data is not None and np.any(mask_data[:, :, slice_idx] > 0)
                )

                slice_index.append({
                    "patient_id": patient_id,
                    "timepoint": timepoint,
                    "modality_paths": dict(modality_strs),
                    "mask_path": mask_str,
                    "slice_idx": slice_idx,
                    "has_tumor": has_tumor,
                    "vol_shape": vol_shape,
                })
                total_slices += 1
                if has_tumor:
                    tumor_slices += 1

            # Release the full volumes before loading the next timepoint.
            del ref_data, mask_data

        if verbose and (i + 1) % 20 == 0:
            print(f"  Scanned {i+1}/{len(patient_dirs)} patients...")

    if verbose:
        non_tumor = total_slices - tumor_slices
        print(f"\n[INDEX] Slice index built:")
        print(f"  Total valid slices: {total_slices}")
        print(f"  Tumor slices:       {tumor_slices} ({tumor_slices/max(1,total_slices)*100:.1f}%)")
        print(f"  Non-tumor slices:   {non_tumor} ({non_tumor/max(1,total_slices)*100:.1f}%)")
        print(f"  Skipped empty:      {skipped_empty}")

    return slice_index


# ═══════════════════════════════════════════════════════════════════════════
# HELPER: Patient-level train/val/test split
# ═══════════════════════════════════════════════════════════════════════════

def patient_split(
    slice_index: List[Dict],
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Split ``slice_index`` into train/val/test lists by PATIENT ID.

    No patient appears in more than one split, which prevents data leakage
    between sets. Unique patient IDs are sorted, then shuffled with a private
    ``random.Random(seed)``, so the global ``random`` state is untouched.
    They are then cut into ``int(n * train_ratio)`` train patients and
    ``int(n * val_ratio)`` val patients; the remainder goes to test.

    If truncation would leave val (or test) empty although its ratio is > 0,
    one patient is moved there from train, provided train keeps at least one.
    This keeps tiny prototyping subsets usable.

    Args:
        slice_index: Entries that each carry a ``"patient_id"`` key.
        train_ratio, val_ratio, test_ratio: Non-negative split fractions
            whose sum must not exceed 1.
        seed: Seed for the patient shuffle.

    Returns:
        (train_slices, val_slices, test_slices). Each list keeps the original
        order of ``slice_index``.

    Raises:
        ValueError: if a ratio is negative or the ratios sum to more than 1.
    """
    ratios = (train_ratio, val_ratio, test_ratio)
    if min(ratios) < 0 or sum(ratios) > 1.0 + 1e-6:
        raise ValueError(f"Invalid split ratios {ratios}: must be >= 0 and sum to <= 1")

    patient_ids = sorted({entry["patient_id"] for entry in slice_index})
    # Same permutation as random.seed(seed); random.shuffle(...), but local.
    random.Random(seed).shuffle(patient_ids)

    n = len(patient_ids)
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)
    # Guarantee non-empty val/test on small subsets when possible.
    if val_ratio > 0 and n_val == 0 and n_train > 1:
        n_train -= 1
        n_val = 1
    if test_ratio > 0 and n - n_train - n_val == 0 and n_train > 1:
        n_train -= 1

    train_pids = set(patient_ids[:n_train])
    val_pids = set(patient_ids[n_train:n_train + n_val])
    test_pids = set(patient_ids[n_train + n_val:])

    train_slices = [e for e in slice_index if e["patient_id"] in train_pids]
    val_slices = [e for e in slice_index if e["patient_id"] in val_pids]
    test_slices = [e for e in slice_index if e["patient_id"] in test_pids]

    print(f"\n[SPLIT] Patient-level split (no leakage):")
    print(f"  Train: {len(train_pids)} patients, {len(train_slices)} slices")
    print(f"  Val:   {len(val_pids)} patients, {len(val_slices)} slices")
    print(f"  Test:  {len(test_pids)} patients, {len(test_slices)} slices")

    return train_slices, val_slices, test_slices


# ═══════════════════════════════════════════════════════════════════════════
# MAIN DATASET CLASS
# ═══════════════════════════════════════════════════════════════════════════

class BrainMRIDataset(Dataset):
    """
    Memory-efficient PyTorch Dataset for brain MRI tumor segmentation.

    Each sample is one 2D slice (index along the last volume axis):
      - Input:  (C, H, W) float tensor, C = len(config.MODALITY_SUFFIXES),
                modalities stacked as channels in config order
      - Target: (1, H, W) float tensor, binarized tumor mask (label > 0 -> 1)

    H and W are the raw slice size when ``transform`` is None; otherwise they
    are whatever the transform produces.

    Key features:
      - Lazy loading: only the requested slice is read per __getitem__,
        through nibabel's ``dataobj`` proxy. For gzipped files nibabel may
        still need to decompress the stream up to that slice.
      - A small LRU cache of nibabel image objects. Only the proxies are
        cached, never pixel arrays. With DataLoader workers, each worker
        process holds its own copy of the cache.
      - On-the-fly augmentation via the supplied albumentations transform.
      - Per-modality z-score normalization over non-zero voxels.
    """

    def __init__(
        self,
        slice_index: List[Dict],
        transform=None,
        normalize: bool = True,
    ):
        """
        Args:
            slice_index: Entries as produced by ``build_slice_index``.
            transform: Callable taking ``image=``/``mask=`` keyword arguments
                and returning a dict with ``"image"``/``"mask"`` tensors
                (albumentations ``Compose``), or None.
            normalize: Apply per-channel z-scoring before the transform.
        """
        self.slice_index = slice_index
        self.transform = transform
        self.normalize = normalize

        # path -> nibabel image (proxy only). Insertion order doubles as
        # recency order: a hit re-inserts the key at the end.
        self._nii_cache = {}
        self._cache_max_size = 50  # Keep at most 50 nii objects cached

    def __len__(self):
        return len(self.slice_index)

    def _load_nii(self, path: str):
        """Return the nibabel image for ``path``, using an LRU cache of proxies."""
        cache = self._nii_cache
        if path in cache:
            # Move to the end = mark as most recently used.
            cache[path] = cache.pop(path)
            return cache[path]
        # Evict least recently used entries (front of the dict) when full.
        while cache and len(cache) >= self._cache_max_size:
            del cache[next(iter(cache))]
        cache[path] = nib.load(path)
        return cache[path]

    @staticmethod
    def _zscore_nonzero(image: np.ndarray) -> np.ndarray:
        """
        Z-score each channel of an (H, W, C) array in place, using only the
        voxels > 0. Voxels <= 0 (background) are left unchanged.
        Returns the same array.
        """
        for c in range(image.shape[-1]):
            ch = image[..., c]  # a view: assignments write through to `image`
            brain = ch > 0
            if brain.any():
                vals = ch[brain]
                ch[brain] = (vals - vals.mean()) / (vals.std() + 1e-8)
        return image

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load slice ``idx`` and return ``(image, mask)`` float tensors."""
        entry = self.slice_index[idx]
        slice_idx = entry["slice_idx"]

        # ── Load one slice per modality ────────────────────────────────
        # dataobj slicing reads only the requested slice, with nibabel's
        # intensity scaling applied (unlike get_fdata, which loads everything).
        channels = []
        for mod in config.MODALITY_SUFFIXES:
            nii = self._load_nii(entry["modality_paths"][mod])
            channels.append(nii.dataobj[..., slice_idx].astype(np.float32))

        # albumentations expects HWC
        image = np.stack(channels, axis=-1)  # (H, W, C)

        # ── Load mask slice ────────────────────────────────────────────
        if entry["mask_path"] is not None:
            mask_nii = self._load_nii(entry["mask_path"])
            raw_mask = mask_nii.dataobj[..., slice_idx].astype(np.float32)
            # Binarize: any tumor label > 0 becomes 1
            mask = (raw_mask > 0).astype(np.float32)
        else:
            # No mask file = no tumor
            mask = np.zeros(image.shape[:2], dtype=np.float32)

        # ── Normalize per modality (z-score over non-zero voxels) ──────
        if self.normalize:
            self._zscore_nonzero(image)

        # ── Augment / convert to tensors ───────────────────────────────
        # `is not None`: an albumentations Compose defines __len__, so an
        # empty pipeline would be falsy and silently skipped otherwise.
        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]  # (C, H, W) tensor
            mask = transformed["mask"]    # (H, W) tensor
        else:
            image = torch.from_numpy(image.transpose(2, 0, 1))  # (C, H, W)
            mask = torch.from_numpy(mask)

        # Ensure mask has a channel dimension: (1, H, W)
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)

        return image.float(), mask.float()

    def get_tumor_weights(self) -> List[float]:
        """
        Per-sample weights for ``WeightedRandomSampler``.

        Weights are inversely proportional to class frequency, so tumor and
        non-tumor slices each carry half of the total weight. If either class
        is absent, all weights are 1.0.
        """
        n_total = len(self.slice_index)
        tumor_count = sum(1 for e in self.slice_index if e["has_tumor"])
        non_tumor_count = n_total - tumor_count

        if tumor_count == 0 or non_tumor_count == 0:
            return [1.0] * n_total

        w_tumor = n_total / (2.0 * tumor_count)
        w_non_tumor = n_total / (2.0 * non_tumor_count)
        return [w_tumor if e["has_tumor"] else w_non_tumor for e in self.slice_index]


# ═══════════════════════════════════════════════════════════════════════════
# DATALOADER FACTORY
# ═══════════════════════════════════════════════════════════════════════════

def create_dataloaders(
    subset_ratio: Optional[float] = None,
    batch_size: Optional[int] = None,
    num_workers: Optional[int] = None,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build train/val/test DataLoaders from ``config.DATA_ROOT``.

    Steps: scan the data into a slice index (with patient-level subsetting),
    split by patient, wrap each split in a ``BrainMRIDataset``, and create the
    loaders. The training loader draws slices with a ``WeightedRandomSampler``
    (with replacement) that gives tumor and non-tumor slices equal total weight.

    Args:
        subset_ratio: Fraction of patients; ``None`` -> ``config.SUBSET_RATIO``.
        batch_size: Samples per batch; ``None`` -> ``config.BATCH_SIZE``.
        num_workers: Worker processes; ``None`` -> ``config.NUM_WORKERS``.
        seed: Seeds the patient subset, the patient split and the training
            sampler, so a run is reproducible end to end.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: if the training split contains no slices.
    """
    if subset_ratio is None:
        subset_ratio = config.SUBSET_RATIO
    if batch_size is None:
        batch_size = config.BATCH_SIZE
    if num_workers is None:
        num_workers = config.NUM_WORKERS

    print("=" * 60)
    print("BUILDING DATA PIPELINE")
    print("=" * 60)

    # Step 1: Build slice index
    print("\n[1/3] Scanning dataset and building slice index...")
    slice_index = build_slice_index(
        config.DATA_ROOT,
        subset_ratio=subset_ratio,
        seed=seed,
    )

    # Step 2: Patient-level split
    print("\n[2/3] Splitting by patient ID...")
    train_idx, val_idx, test_idx = patient_split(
        slice_index,
        train_ratio=config.TRAIN_RATIO,
        val_ratio=config.VAL_RATIO,
        test_ratio=config.TEST_RATIO,
        seed=seed,
    )

    # Step 3: Create datasets
    print("\n[3/3] Creating DataLoaders...")
    train_dataset = BrainMRIDataset(train_idx, transform=get_train_transforms())
    val_dataset = BrainMRIDataset(val_idx, transform=get_val_transforms())
    test_dataset = BrainMRIDataset(test_idx, transform=get_val_transforms())

    if len(train_dataset) == 0:
        raise ValueError(
            "Training split is empty — check config.DATA_ROOT or increase subset_ratio"
        )

    # Weighted sampler for training (handles class imbalance). A dedicated
    # seeded generator makes the draw order reproducible without touching
    # torch's global RNG.
    train_weights = train_dataset.get_tumor_weights()
    sampler = WeightedRandomSampler(
        weights=train_weights,
        num_samples=len(train_weights),
        replacement=True,
        generator=torch.Generator().manual_seed(seed),
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,  # WeightedRandomSampler replaces shuffle
        num_workers=num_workers,
        pin_memory=config.PIN_MEMORY,
        drop_last=True,
    )

    # Val/test: deterministic order, keep the last partial batch.
    eval_kwargs = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=config.PIN_MEMORY,
    )
    val_loader = DataLoader(val_dataset, **eval_kwargs)
    test_loader = DataLoader(test_dataset, **eval_kwargs)

    print(f"\n[READY] DataLoaders created:")
    print(f"  Train: {len(train_dataset)} slices, {len(train_loader)} batches")
    print(f"  Val:   {len(val_dataset)} slices, {len(val_loader)} batches")
    print(f"  Test:  {len(test_dataset)} slices, {len(test_loader)} batches")
    print(f"  Batch size: {batch_size}")
    print(f"  Workers: {num_workers}")
    print(f"  Pin memory: {config.PIN_MEMORY}")

    return train_loader, val_loader, test_loader


# ═══════════════════════════════════════════════════════════════════════════
# QUICK TEST (run this file directly to verify pipeline)
# ═══════════════════════════════════════════════════════════════════════════

if False: # __name__ == "__main__":
    # Manual smoke test for the data pipeline. The `if False` guard keeps the
    # notebook cell from scanning the dataset when it runs.
    import sys
    sys.stdout.reconfigure(encoding='utf-8')

    print("Testing data pipeline with 5% subset...\n")
    train_loader, val_loader, test_loader = create_dataloaders(
        subset_ratio=0.05,
        batch_size=4,
        num_workers=0,  # 0 workers: load in the main process (easier to debug)
    )

    # Grab one batch
    print("\nLoading first batch...")
    images, masks = next(iter(train_loader))

    # The transforms resize to config.IMAGE_SIZE, so with batch_size=4 the
    # shapes are (4, NUM_MODALITIES, IMAGE_SIZE, IMAGE_SIZE) and
    # (4, 1, IMAGE_SIZE, IMAGE_SIZE), e.g. (4, 4, 224, 224) / (4, 1, 224, 224).
    expected_images = (4, config.NUM_MODALITIES, config.IMAGE_SIZE, config.IMAGE_SIZE)
    expected_masks = (4, 1, config.IMAGE_SIZE, config.IMAGE_SIZE)

    print(f"\n[BATCH]")
    print(f"  Images shape: {images.shape}")
    print(f"  Masks shape:  {masks.shape}")
    print(f"  Images dtype: {images.dtype}")
    print(f"  Masks dtype:  {masks.dtype}")
    print(f"  Images range: [{images.min():.4f}, {images.max():.4f}]")
    print(f"  Masks unique: {masks.unique().tolist()}")
    print(f"  Tumor pixels in batch: {masks.sum().item():.0f}")

    assert tuple(images.shape) == expected_images, f"images {tuple(images.shape)} != {expected_images}"
    assert tuple(masks.shape) == expected_masks, f"masks {tuple(masks.shape)} != {expected_masks}"
    print("\nData pipeline test PASSED!")



## model.py

Running content from `model.py`.


In [None]:
"""
============================================================================
MODEL — 2D U-Net with EfficientNet Encoder for Brain Tumor Segmentation
============================================================================
Architecture:
  - Encoder: EfficientNet-B0 (pretrained on ImageNet, optimized for 4GB VRAM)
  - Decoder: U-Net decoder with skip connections
  - Input: (B, 4, H, W) — 4 MRI modalities as channels
  - Output: (B, 1, H, W) — tumor-mask logits (no activation; apply sigmoid for probabilities)
  
Uses segmentation_models_pytorch (smp) for clean implementation.
The first conv layer is adapted from 3-channel (ImageNet) to 4-channel input.
============================================================================
"""

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
import segmentation_models_pytorch as smp

import config


class BrainTumorSegModel(nn.Module):
    """
    U-Net segmentation model for brain tumor detection, with an auxiliary
    image-level classification head.

    Segmentation uses an ``smp.Unet``. The classification head global-pools
    the deepest encoder feature map and outputs one cancer/no-cancer logit.
    When both outputs are requested, they come from a SINGLE encoder pass.
    """

    def __init__(
        self,
        encoder_name: str = None,
        encoder_weights: str = None,
        in_channels: int = None,
        num_classes: int = None,
        gradient_checkpointing: bool = None,
    ):
        """
        Args:
            encoder_name: smp encoder name; falsy -> ``config.ENCODER_NAME``.
            encoder_weights: Pretrained-weights tag; falsy ->
                ``config.ENCODER_WEIGHTS``. Passing None therefore does NOT
                disable pretraining.
            in_channels: Input channels; falsy -> ``config.NUM_MODALITIES``.
            num_classes: Segmentation output channels; falsy -> ``config.NUM_CLASSES``.
            gradient_checkpointing: None -> ``config.GRADIENT_CHECKPOINTING``
                (False if that attribute is missing).
        """
        super().__init__()

        encoder_name = encoder_name or config.ENCODER_NAME
        encoder_weights = encoder_weights or config.ENCODER_WEIGHTS
        in_channels = in_channels or config.NUM_MODALITIES
        num_classes = num_classes or config.NUM_CLASSES
        self.use_gradient_checkpointing = (
            gradient_checkpointing if gradient_checkpointing is not None
            else getattr(config, 'GRADIENT_CHECKPOINTING', False)
        )

        # ── Segmentation backbone (U-Net with pretrained encoder) ──────
        self.segmentation_model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
            activation=None,  # logits out; sigmoid is applied in predict()/losses
        )

        if self.use_gradient_checkpointing:
            self._enable_gradient_checkpointing()

        # ── Classification head (from the deepest encoder features) ────
        encoder_channels = self.segmentation_model.encoder.out_channels
        bottleneck_channels = encoder_channels[-1]

        self.classification_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),        # Global average pooling
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(bottleneck_channels, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, 1),               # Binary: cancer / no cancer
        )

    def _enable_gradient_checkpointing(self):
        """
        Enable activation checkpointing inside the encoder.

        Sets ``gradient_checkpointing = True`` on encoder children that expose
        that attribute. It also wraps every non-empty ``nn.Sequential``
        grandchild so its forward runs under ``torch.utils.checkpoint`` while
        training with autograd enabled.

        If nothing qualifies, a RuntimeWarning is emitted, because the VRAM
        saving would otherwise silently not happen. Whether a given encoder
        has ``nn.Sequential`` blocks depends on its smp implementation.

        NOTE(review): the wrapper is stored as an instance attribute and
        closes over ``self``. A ``copy.deepcopy`` of the model (e.g. for EMA)
        would keep calling the original modules — confirm before deep-copying.
        """
        import warnings

        def make_ckpt_forward(orig_forward):
            def ckpt_forward(*args, **kwargs):
                # Recompute-on-backward only makes sense while autograd records.
                if self.training and torch.is_grad_enabled():
                    return cp.checkpoint(orig_forward, *args, use_reentrant=False, **kwargs)
                return orig_forward(*args, **kwargs)
            return ckpt_forward

        n_flagged = 0
        n_wrapped = 0
        for module in self.segmentation_model.encoder.children():
            if hasattr(module, 'gradient_checkpointing'):
                module.gradient_checkpointing = True
                n_flagged += 1
            for child in module.children():
                if isinstance(child, nn.Sequential) and len(child) > 0:
                    child.forward = make_ckpt_forward(child.forward)
                    n_wrapped += 1

        if n_flagged == 0 and n_wrapped == 0:
            warnings.warn(
                "Gradient checkpointing was requested but no encoder blocks could be "
                "checkpointed; VRAM usage is unchanged.",
                RuntimeWarning,
            )

    def forward(self, x, return_classification=False):
        """
        Forward pass.

        Args:
            x: (B, in_channels, H, W) input tensor.
            return_classification: If True, also return classification logits.

        Returns:
            seg_logits: (B, num_classes, H, W) segmentation logits (pre-sigmoid).
            cls_logits: (B, 1) classification logits (only if return_classification).
        """
        if not return_classification:
            return self.segmentation_model(x)

        # Capture the encoder's feature list during the U-Net's own forward
        # pass instead of running the encoder a second time. A second pass
        # doubles compute/activation memory and updates BatchNorm statistics twice.
        captured = {}

        def _grab_features(module, inputs, output):
            captured["features"] = output

        handle = self.segmentation_model.encoder.register_forward_hook(_grab_features)
        try:
            seg_logits = self.segmentation_model(x)
        finally:
            handle.remove()

        features = captured.get("features")
        if features is None:
            # Fallback: the encoder was not invoked through nn.Module.__call__.
            features = self.segmentation_model.encoder(x)

        cls_logits = self.classification_head(features[-1])
        return seg_logits, cls_logits

    def predict(self, x):
        """
        Inference with sigmoid activation, without gradients.

        The model runs in eval mode for the call. The previous train/eval mode
        is restored afterwards, so calling this mid-training does not
        silently disable dropout/BatchNorm updates.

        Returns:
            seg_probs: (B, num_classes, H, W) probabilities in [0, 1].
            cls_probs: (B, 1) classification probabilities in [0, 1].
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                seg_logits, cls_logits = self.forward(x, return_classification=True)
                seg_probs = torch.sigmoid(seg_logits)
                cls_probs = torch.sigmoid(cls_logits)
        finally:
            self.train(was_training)
        return seg_probs, cls_probs


# ═══════════════════════════════════════════════════════════════════════════
# LOSS FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════

class DiceBCELoss(nn.Module):
    """
    Combined Dice Loss + Binary Cross-Entropy Loss.
    
    Dice Loss: directly optimizes the Dice coefficient (overlap metric).
    BCE Loss: provides stable gradient signal, especially for small tumors.
    
    Combined loss = alpha * DiceLoss + (1 - alpha) * BCELoss
    """
    
    def __init__(self, alpha: float = 0.5, smooth: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.smooth = smooth
        self.bce = nn.BCEWithLogitsLoss()
    
    def dice_loss(self, logits, targets):
        """Computes Dice loss from logits."""
        probs = torch.sigmoid(logits)
        
        # Flatten spatial dimensions
        probs_flat = probs.view(-1)
        targets_flat = targets.view(-1)
        
        intersection = (probs_flat * targets_flat).sum()
        dice = (2.0 * intersection + self.smooth) / (
            probs_flat.sum() + targets_flat.sum() + self.smooth
        )
        return 1.0 - dice
    
    def forward(self, logits, targets):
        dice = self.dice_loss(logits, targets)
        bce = self.bce(logits, targets)
        return self.alpha * dice + (1.0 - self.alpha) * bce


class CombinedLoss(nn.Module):
    """
    Full loss: segmentation loss + optional classification loss.

        total = seg_weight * DiceBCE(seg_logits, seg_targets)
              + cls_weight * BCE(cls_logits, cls_targets)

    ``forward`` always returns ``(total, seg_loss, cls_loss)``. When no
    classification inputs are given, ``total`` is just ``seg_loss`` and
    ``cls_loss`` is a zero scalar on the same device/dtype as ``seg_loss``.
    """

    def __init__(self, seg_weight: float = 0.8, cls_weight: float = 0.2):
        super().__init__()
        self.seg_weight = seg_weight
        self.cls_weight = cls_weight
        self.seg_loss_fn = DiceBCELoss()
        self.cls_loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, seg_logits, seg_targets, cls_logits=None, cls_targets=None):
        seg_loss = self.seg_loss_fn(seg_logits, seg_targets)

        if cls_logits is not None and cls_targets is not None:
            cls_logits = cls_logits.float()
            cls_targets = cls_targets.float()
            # Accept e.g. (B,) targets for (B, 1) logits; BCE requires equal shapes.
            if cls_targets.shape != cls_logits.shape and cls_targets.numel() == cls_logits.numel():
                cls_targets = cls_targets.reshape(cls_logits.shape)
            cls_loss = self.cls_loss_fn(cls_logits, cls_targets)
            total = self.seg_weight * seg_loss + self.cls_weight * cls_loss
            return total, seg_loss, cls_loss

        # Zero placeholder that lives where seg_loss lives, so it can be
        # combined with GPU tensors or logged without a device mismatch.
        return seg_loss, seg_loss, seg_loss.new_zeros(())


# ═══════════════════════════════════════════════════════════════════════════
# METRICS
# ═══════════════════════════════════════════════════════════════════════════

def compute_metrics(seg_logits, seg_targets, threshold=0.5):
    """
    Computes segmentation metrics from logits and targets.

    All pixels of the batch are pooled into a single set of TP/FP/FN/TN
    counts.

    Args:
        seg_logits: Raw (pre-sigmoid) logits. A sigmoid is applied here, so
            passing probabilities instead would apply it twice.
        seg_targets: Ground truth with the same number of elements as
            ``seg_logits``, expected to hold 0/1 values; bool, integer and
            float dtypes are accepted.
        threshold: Probability cutoff for a positive prediction.

    Returns:
        dict with float values for: dice, iou, precision, recall, accuracy.
        When a ratio's counts are all zero, the tiny smoothing term makes
        dice/iou/precision/recall evaluate to 1.
    """
    with torch.no_grad():
        probs = torch.sigmoid(seg_logits)

        # Flatten. reshape (not view) tolerates non-contiguous tensors, and
        # .float() lets bool/int targets take part in the `1 - x` arithmetic.
        preds_flat = (probs > threshold).reshape(-1).float()
        targets_flat = seg_targets.reshape(-1).float()

        tp = (preds_flat * targets_flat).sum()
        fp = (preds_flat * (1 - targets_flat)).sum()
        fn = ((1 - preds_flat) * targets_flat).sum()
        tn = ((1 - preds_flat) * (1 - targets_flat)).sum()

        smooth = 1e-7

        dice = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
        iou = (tp + smooth) / (tp + fp + fn + smooth)
        precision = (tp + smooth) / (tp + fp + smooth)
        recall = (tp + smooth) / (tp + fn + smooth)
        accuracy = (tp + tn) / (tp + tn + fp + fn + smooth)

        return {
            "dice": dice.item(),
            "iou": iou.item(),
            "precision": precision.item(),
            "recall": recall.item(),
            "accuracy": accuracy.item(),
        }


# ═══════════════════════════════════════════════════════════════════════════
# QUICK TEST
# ═══════════════════════════════════════════════════════════════════════════

if False: # __name__ == "__main__":
    # Smoke test: model construction, both forward modes, the combined loss
    # and the metrics on random tensors. Disabled inside the notebook.
    print("Testing model architecture...")

    net = BrainTumorSegModel()

    # Parameter counts
    params = list(net.parameters())
    n_total = sum(p.numel() for p in params)
    n_trainable = sum(p.numel() for p in params if p.requires_grad)
    print(f"Total parameters:     {n_total:,}")
    print(f"Trainable parameters: {n_trainable:,}")

    # Random input batch: 2 samples x 4 channels x 256 x 256
    batch = torch.randn(2, 4, 256, 256)

    # Segmentation-only forward pass
    seg_only = net(batch)
    print(f"\nSegmentation output:    {seg_only.shape}")

    # Segmentation + classification forward pass
    seg_out, cls_out = net(batch, return_classification=True)
    print(f"Segmentation output:    {seg_out.shape}")
    print(f"Classification output:  {cls_out.shape}")

    # Loss on random binary targets
    seg_target = torch.randint(0, 2, (2, 1, 256, 256)).float()
    cls_target = torch.randint(0, 2, (2, 1)).float()

    criterion = CombinedLoss()
    total_loss, seg_loss, cls_loss = criterion(seg_out, seg_target, cls_out, cls_target)
    print(f"\nTotal loss: {total_loss.item():.4f}")
    print(f"Seg loss:   {seg_loss.item():.4f}")
    print(f"Cls loss:   {cls_loss.item():.4f}")

    # Metrics
    print(f"\nMetrics: {compute_metrics(seg_out, seg_target)}")

    print("\nModel test PASSED!")



## train.py

Running content from `train.py`.


In [None]:
"""
============================================================================
TRAIN — Full Training Loop for Brain Tumor Segmentation
============================================================================
Features:
  - Mixed precision training (torch.amp autocast + GradScaler) — ~40% memory savings
  - Gradient accumulation for effective larger batch sizes
  - Cosine annealing LR scheduler with linear warmup
  - Checkpoint saving every epoch + best model tracking
  - Resume-from-checkpoint capability
  - Early stopping with patience
  - CSV metric logging every epoch
  - Memory management (empty_cache after validation)
  - Auto batch-size halving on OOM error
============================================================================
"""

import os
import sys
import csv
import time
import argparse
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
import numpy as np

import config
from dataset import create_dataloaders
from model import BrainTumorSegModel, CombinedLoss, compute_metrics

# Force UTF-8 console output where supported. Some stdout replacements
# (e.g. Jupyter/ipykernel's OutStream) do not implement `reconfigure`.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding='utf-8')


# ═══════════════════════════════════════════════════════════════════════════
# LEARNING RATE SCHEDULER: Cosine Annealing with Linear Warmup
# ═══════════════════════════════════════════════════════════════════════════

class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
    """
    Per-step (here: per-epoch) LR schedule.

    For the first ``warmup_epochs`` steps each base LR is ramped linearly
    (``base_lr * (step + 1) / warmup_epochs``); afterwards it follows a
    half-cosine from ``base_lr`` down towards ``min_lr`` over the remaining
    ``total_epochs - warmup_epochs`` steps.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, min_lr=1e-6, last_epoch=-1):
        # Set before the base constructor, which already computes the
        # initial learning rate through get_lr().
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step >= self.warmup_epochs:
            # Cosine phase
            decay_span = max(1, self.total_epochs - self.warmup_epochs)
            progress = (step - self.warmup_epochs) / decay_span
            cosine = 0.5 * (1.0 + np.cos(np.pi * progress))
            floor = self.min_lr
            return [floor + (peak - floor) * cosine for peak in self.base_lrs]
        # Linear warmup phase
        ramp = (step + 1) / self.warmup_epochs
        return [peak * ramp for peak in self.base_lrs]


# ═══════════════════════════════════════════════════════════════════════════
# CSV LOGGER
# ═══════════════════════════════════════════════════════════════════════════

class CSVLogger:
    """
    Append-only per-epoch metrics log.

    The header row is written only when the file does not exist yet; each
    ``log`` call appends one row. Keys absent from a row are left empty.
    """

    FIELDS = [
        "epoch", "lr", "train_loss", "train_dice", "train_iou",
        "val_loss", "val_dice", "val_iou",
        "val_precision", "val_recall", "val_accuracy",
        "epoch_time_sec",
    ]

    def __init__(self, filepath):
        self.filepath = filepath
        if Path(filepath).exists():
            return
        with open(filepath, 'w', newline='') as handle:
            csv.DictWriter(handle, fieldnames=self.FIELDS).writeheader()

    def log(self, row: dict):
        """Append ``row``; float values are written with six decimal places."""
        with open(self.filepath, 'a', newline='') as handle:
            formatted = {}
            for key, value in row.items():
                formatted[key] = f"{value:.6f}" if isinstance(value, float) else value
            csv.DictWriter(handle, fieldnames=self.FIELDS).writerow(formatted)


# ═══════════════════════════════════════════════════════════════════════════
# TRAINING FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════

def train_one_epoch(
    model, loader, criterion, optimizer, scaler, device,
    grad_accum_steps=1, epoch=0, total_epochs=0,
):
    """
    Train for one epoch with mixed precision and gradient accumulation.

    Args:
        model: Called as ``model(images, return_classification=True)`` and
            expected to return ``(seg_logits, cls_logits)``.
        loader: Iterable of ``(images, masks)`` batches.
        criterion: Loss returning ``(total, seg_loss, cls_loss)``.
        optimizer: Optimizer, stepped through ``scaler``.
        scaler: AMP gradient scaler.
        device: Device each batch is moved to.
        grad_accum_steps: Number of batches whose gradients are accumulated
            per optimizer step. A trailing group smaller than this is still
            applied at the end of the epoch.
        epoch, total_epochs: Used only in progress messages.

    Returns:
        dict with per-batch means of ``loss`` (undoing the accumulation
        scaling), ``dice`` and ``iou``.
    """
    model.train()
    running_loss = 0.0
    running_dice = 0.0
    running_iou = 0.0
    num_batches = 0

    optimizer.zero_grad()

    for batch_idx, (images, masks) in enumerate(loader):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        # Classification target: 1 if any tumor pixel in the sample
        cls_targets = (masks.sum(dim=(1, 2, 3)) > 0).float().unsqueeze(1)

        # Mixed precision forward pass
        with autocast(device_type='cuda', enabled=config.USE_AMP):
            seg_logits, cls_logits = model(images, return_classification=True)
            total_loss, seg_loss, cls_loss = criterion(
                seg_logits, masks, cls_logits, cls_targets
            )
            # Scale loss for gradient accumulation
            total_loss = total_loss / grad_accum_steps

        # Backward pass with gradient scaling
        scaler.scale(total_loss).backward()

        # Step optimizer every grad_accum_steps
        if (batch_idx + 1) % grad_accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Track metrics
        running_loss += total_loss.item() * grad_accum_steps
        metrics = compute_metrics(seg_logits.detach(), masks)
        running_dice += metrics["dice"]
        running_iou += metrics["iou"]
        num_batches += 1

        # Print progress
        if (batch_idx + 1) % config.PRINT_EVERY_N_BATCHES == 0:
            avg_loss = running_loss / num_batches
            avg_dice = running_dice / num_batches
            print(f"  Epoch [{epoch+1}/{total_epochs}] "
                  f"Batch [{batch_idx+1}/{len(loader)}] "
                  f"Loss: {avg_loss:.4f} | Dice: {avg_dice:.4f}")

        # Release batch from GPU
        del images, masks, seg_logits, cls_logits

        # GPU memory report after first batch
        if batch_idx == 0 and torch.cuda.is_available():
            alloc = torch.cuda.memory_allocated() / 1024**2
            reserved = torch.cuda.memory_reserved() / 1024**2
            total = torch.cuda.get_device_properties(0).total_memory / 1024**2
            print(f"  [GPU] After 1st batch: {alloc:.0f} MB allocated / "
                  f"{reserved:.0f} MB reserved / {total:.0f} MB total")

    # Apply a trailing partial accumulation group. Without this its gradients
    # would be wiped by the next epoch's zero_grad() and those batches never
    # trained on. (Each loss was divided by grad_accum_steps, so this final
    # update is proportionally smaller.)
    if num_batches % grad_accum_steps != 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    return {
        "loss": running_loss / max(num_batches, 1),
        "dice": running_dice / max(num_batches, 1),
        "iou": running_iou / max(num_batches, 1),
    }


@torch.no_grad()
def validate(model, loader, criterion, device):
    """
    Evaluate ``model`` on ``loader`` (val or test split) without gradients.

    Returns a dict with the per-batch mean of the loss and of each metric
    produced by ``compute_metrics`` (dice, iou, precision, recall, accuracy).
    """
    model.eval()
    metric_names = ("dice", "iou", "precision", "recall", "accuracy")
    totals = dict.fromkeys(metric_names, 0)
    loss_sum = 0.0
    batch_count = 0

    for images, masks in loader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        # Slice-level target: 1.0 when the mask contains any tumor pixel.
        cls_targets = (masks.sum(dim=(1, 2, 3)) > 0).float().unsqueeze(1)

        with autocast(device_type='cuda', enabled=config.USE_AMP):
            seg_logits, cls_logits = model(images, return_classification=True)
            total_loss, _, _ = criterion(seg_logits, masks, cls_logits, cls_targets)

        loss_sum += total_loss.item()
        batch_metrics = compute_metrics(seg_logits, masks)
        for name in totals:
            totals[name] += batch_metrics[name]
        batch_count += 1

        del images, masks, seg_logits, cls_logits

    denom = max(batch_count, 1)
    summary = {"loss": loss_sum / denom}
    summary.update({name: totals[name] / denom for name in metric_names})
    return summary


# ═══════════════════════════════════════════════════════════════════════════
# CHECKPOINT MANAGEMENT
# ═══════════════════════════════════════════════════════════════════════════

def save_checkpoint(model, optimizer, scheduler, scaler, epoch, best_dice, filepath):
    """Save training state for resumption."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "best_dice": best_dice,
    }, filepath)


def load_checkpoint(filepath, model, optimizer=None, scheduler=None, scaler=None):
    """
    Restore a snapshot written by ``save_checkpoint`` onto ``config.DEVICE``.

    The model weights are always loaded. The optimizer, scheduler and scaler
    are restored only when passed in and present in the file.

    Returns:
        ``(epoch, best_dice)``; ``best_dice`` is 0.0 if the file lacks it.

    NOTE: ``weights_only=False`` unpickles arbitrary objects; only load
    checkpoints from trusted sources.
    """
    checkpoint = torch.load(filepath, map_location=config.DEVICE, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])

    optional_parts = (
        (optimizer, "optimizer_state_dict"),
        (scheduler, "scheduler_state_dict"),
        (scaler, "scaler_state_dict"),
    )
    for component, key in optional_parts:
        if component and key in checkpoint:
            component.load_state_dict(checkpoint[key])

    return checkpoint["epoch"], checkpoint.get("best_dice", 0.0)


# ═══════════════════════════════════════════════════════════════════════════
# MAIN TRAINING LOOP
# ═══════════════════════════════════════════════════════════════════════════

def train(args):
    """
    Main training function.

    Builds the DataLoaders, model, ``CombinedLoss``, AdamW optimizer,
    warmup+cosine scheduler and AMP grad scaler, and optionally resumes from
    ``args.resume``. Each epoch it then trains, validates, steps the
    scheduler, logs to CSV, checkpoints, tracks the best validation Dice and
    applies early stopping. Finally the best checkpoint is evaluated on the
    test set and the model is exported to TorchScript (state_dict fallback).

    Args:
        args: Namespace with ``epochs``, ``batch_size``, ``lr``, ``subset``,
            ``num_workers`` and ``resume``. NOTE: ``args.batch_size`` is
            mutated (halved) when a CUDA out-of-memory error occurs.
    """

    print("=" * 70)
    print("BRAIN TUMOR SEGMENTATION — TRAINING")
    print("=" * 70)
    print(f"Device: {config.DEVICE}")

    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

    # ── Create DataLoaders ─────────────────────────────────────────────
    train_loader, val_loader, test_loader = create_dataloaders(
        subset_ratio=args.subset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # ── Create Model ───────────────────────────────────────────────────
    model = BrainTumorSegModel().to(config.DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nModel: U-Net + {config.ENCODER_NAME}")
    print(f"  Total params:     {total_params:,}")
    print(f"  Trainable params: {trainable_params:,}")

    # ── Loss, Optimizer, Scheduler ─────────────────────────────────────
    criterion = CombinedLoss(seg_weight=0.8, cls_weight=0.2)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=config.WEIGHT_DECAY,
    )
    scheduler = CosineWarmupScheduler(
        optimizer,
        warmup_epochs=config.WARMUP_EPOCHS,
        total_epochs=args.epochs,
    )
    scaler = GradScaler('cuda', enabled=config.USE_AMP)

    # ── Resume from checkpoint if available ────────────────────────────
    start_epoch = 0
    best_dice = 0.0

    if args.resume:
        ckpt_path = Path(args.resume)
        if ckpt_path.exists():
            print(f"\nResuming from checkpoint: {ckpt_path}")
            start_epoch, best_dice = load_checkpoint(
                ckpt_path, model, optimizer, scheduler, scaler
            )
            start_epoch += 1  # Start from next epoch
            print(f"  Resuming from epoch {start_epoch}, best dice: {best_dice:.4f}")
        else:
            print(f"\n[WARN] Resume checkpoint not found: {ckpt_path} — training from scratch.")

    # ── CSV Logger ─────────────────────────────────────────────────────
    logger = CSVLogger(config.LOG_FILE)

    # ── GPU memory check after model creation ──────────────────────────
    if torch.cuda.is_available():
        print("\n[GPU] Memory after model creation:")
        print(f"  Allocated: {torch.cuda.memory_allocated() / 1024**2:.0f} MB")
        print(f"  Cached:    {torch.cuda.memory_reserved() / 1024**2:.0f} MB")

    # ── Training loop ──────────────────────────────────────────────────
    patience_counter = 0

    print(f"\n{'=' * 70}")
    print(f"Starting training: {args.epochs - start_epoch} epochs")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Grad accumulation: {config.GRAD_ACCUMULATION_STEPS}")
    print(f"  Effective batch: {args.batch_size * config.GRAD_ACCUMULATION_STEPS}")
    print(f"  Mixed precision: {config.USE_AMP}")
    print(f"  Gradient checkpoint: {getattr(config, 'GRADIENT_CHECKPOINTING', False)}")
    print(f"  Image size: {config.IMAGE_SIZE}x{config.IMAGE_SIZE}")
    print(f"{'=' * 70}\n")

    for epoch in range(start_epoch, args.epochs):
        # ── Train ──────────────────────────────────────────────────
        # On CUDA OOM: halve the batch size and retry THIS epoch. A plain
        # `continue` would skip the epoch entirely and leave the scheduler
        # one step behind.
        while True:
            epoch_start = time.time()
            try:
                train_metrics = train_one_epoch(
                    model, train_loader, criterion, optimizer, scaler,
                    config.DEVICE, config.GRAD_ACCUMULATION_STEPS,
                    epoch, args.epochs,
                )
                break
            except RuntimeError as e:
                # Non-OOM errors propagate; so does OOM at batch size 1,
                # since there is nothing left to halve (avoids an endless retry).
                if "out of memory" not in str(e).lower() or args.batch_size <= 1:
                    raise
            # Reached only after an OOM. Handled outside the `except` block so
            # the exception and its traceback frames (which still reference the
            # failed batch's tensors) are released before empty_cache().
            print(f"\n[OOM] GPU out of memory! Halving batch size...")
            optimizer.zero_grad()  # discard partially accumulated gradients
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            args.batch_size = max(1, args.batch_size // 2)
            print(f"  New batch size: {args.batch_size}")
            print(f"  Recreating DataLoaders...")
            train_loader, val_loader, test_loader = create_dataloaders(
                subset_ratio=args.subset,
                batch_size=args.batch_size,
                num_workers=args.num_workers,
            )

        # ── Validate ───────────────────────────────────────────────
        val_metrics = validate(model, val_loader, criterion, config.DEVICE)

        # Free GPU memory after validation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # ── Step scheduler ─────────────────────────────────────────
        scheduler.step()
        current_lr = optimizer.param_groups[0]["lr"]

        epoch_time = time.time() - epoch_start

        # ── Log metrics ────────────────────────────────────────────
        logger.log({
            "epoch": epoch + 1,
            "lr": current_lr,
            "train_loss": train_metrics["loss"],
            "train_dice": train_metrics["dice"],
            "train_iou": train_metrics["iou"],
            "val_loss": val_metrics["loss"],
            "val_dice": val_metrics["dice"],
            "val_iou": val_metrics["iou"],
            "val_precision": val_metrics["precision"],
            "val_recall": val_metrics["recall"],
            "val_accuracy": val_metrics["accuracy"],
            "epoch_time_sec": epoch_time,
        })

        # ── Print epoch summary ────────────────────────────────────
        print(f"\n{'─' * 70}")
        print(f"Epoch {epoch+1}/{args.epochs}  ({epoch_time:.1f}s)  LR: {current_lr:.6f}")
        print(f"  Train | Loss: {train_metrics['loss']:.4f} | "
              f"Dice: {train_metrics['dice']:.4f} | IoU: {train_metrics['iou']:.4f}")
        print(f"  Val   | Loss: {val_metrics['loss']:.4f} | "
              f"Dice: {val_metrics['dice']:.4f} | IoU: {val_metrics['iou']:.4f} | "
              f"Prec: {val_metrics['precision']:.4f} | Rec: {val_metrics['recall']:.4f}")

        # ── Track best model ───────────────────────────────────────
        # Update best_dice BEFORE writing the epoch checkpoint. Otherwise a
        # checkpoint from an improving epoch stores the stale best, and resuming
        # from it could overwrite best_model.pt with a worse model.
        improved = val_metrics["dice"] > best_dice
        if improved:
            best_dice = val_metrics["dice"]
            patience_counter = 0
        else:
            patience_counter += 1

        # ── Save checkpoint ────────────────────────────────────────
        ckpt_path = config.CHECKPOINT_DIR / f"checkpoint_epoch_{epoch+1:03d}.pt"
        save_checkpoint(model, optimizer, scheduler, scaler, epoch, best_dice, ckpt_path)

        # Save best model
        if improved:
            best_path = config.CHECKPOINT_DIR / "best_model.pt"
            save_checkpoint(model, optimizer, scheduler, scaler, epoch, best_dice, best_path)
            print(f"  >> New best model! Dice: {best_dice:.4f}")
        else:
            print(f"  >> No improvement. Patience: {patience_counter}/{config.EARLY_STOPPING_PATIENCE}")

        # ── Early stopping ─────────────────────────────────────────
        if patience_counter >= config.EARLY_STOPPING_PATIENCE:
            print(f"\n[EARLY STOPPING] No improvement for {config.EARLY_STOPPING_PATIENCE} epochs.")
            break

    # ── Final evaluation on test set ───────────────────────────────────
    print(f"\n{'=' * 70}")
    print("FINAL EVALUATION ON TEST SET")
    print(f"{'=' * 70}")

    # Load best model
    best_path = config.CHECKPOINT_DIR / "best_model.pt"
    if best_path.exists():
        load_checkpoint(best_path, model)
        print(f"Loaded best model (Dice: {best_dice:.4f})")

    test_metrics = validate(model, test_loader, criterion, config.DEVICE)
    print(f"\nTest Results:")
    print(f"  Loss:      {test_metrics['loss']:.4f}")
    print(f"  Dice:      {test_metrics['dice']:.4f}")
    print(f"  IoU:       {test_metrics['iou']:.4f}")
    print(f"  Precision: {test_metrics['precision']:.4f}")
    print(f"  Recall:    {test_metrics['recall']:.4f}")
    print(f"  Accuracy:  {test_metrics['accuracy']:.4f}")

    # ── Export as TorchScript ──────────────────────────────────────────
    print(f"\nExporting model to TorchScript...")
    model.eval()
    model_cpu = model.to("cpu")
    example_input = torch.randn(1, 4, config.IMAGE_SIZE, config.IMAGE_SIZE)

    try:
        traced = torch.jit.trace(model_cpu, example_input)
        export_path = config.PROJECT_ROOT / "brain_tumor_segmentation.pt"
        traced.save(str(export_path))
        print(f"  Model exported to: {export_path}")
    except Exception as e:
        print(f"  [WARN] TorchScript tracing failed: {e}")
        print(f"  Saving state_dict instead...")
        torch.save(model_cpu.state_dict(), config.PROJECT_ROOT / "brain_tumor_segmentation_weights.pt")

    print(f"\nTraining complete! Best validation Dice: {best_dice:.4f}")
    print(f"Logs saved to: {config.LOG_FILE}")
    print(f"Checkpoints in: {config.CHECKPOINT_DIR}")


# ═══════════════════════════════════════════════════════════════════════════
# CLI ENTRY POINT
# ═══════════════════════════════════════════════════════════════════════════

if False: # __name__ == "__main__":
    # Command-line entry point (disabled inside the notebook).
    cli = argparse.ArgumentParser(description="Train brain tumor segmentation model")
    cli.add_argument("--epochs", type=int, default=config.EPOCHS,
                     help="Number of training epochs")
    cli.add_argument("--batch_size", type=int, default=config.BATCH_SIZE,
                     help="Batch size per GPU")
    cli.add_argument("--lr", type=float, default=config.LEARNING_RATE,
                     help="Initial learning rate")
    cli.add_argument("--subset", type=float, default=config.SUBSET_RATIO,
                     help="Fraction of patients to use (0.0-1.0)")
    cli.add_argument("--num_workers", type=int, default=config.NUM_WORKERS,
                     help="DataLoader num_workers")
    cli.add_argument("--resume", type=str, default=None,
                     help="Path to checkpoint to resume from")

    train(cli.parse_args())



## evaluate.py

Running content from `evaluate.py`.


In [None]:
"""
============================================================================
EVALUATE — Test Set Evaluation & Visualization
============================================================================
Features:
  - Test set evaluation with torch.no_grad()
  - Confusion matrix generation
  - Mask overlay visualization (predicted mask on original MRI)
  - Training curves from CSV log
  - Sample predictions grid
============================================================================
"""

import sys
import csv
import argparse
from pathlib import Path

import numpy as np
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend (no GUI needed)
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import torch
from torch.amp import autocast
from sklearn.metrics import (
    confusion_matrix, classification_report,
    roc_curve, auc, ConfusionMatrixDisplay,
)

import config
from dataset import create_dataloaders
from model import BrainTumorSegModel, CombinedLoss, compute_metrics
from train import load_checkpoint

# Force UTF-8 console output where supported. Some stdout replacements
# (e.g. Jupyter/ipykernel's OutStream) do not implement `reconfigure`.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding='utf-8')


# ═══════════════════════════════════════════════════════════════════════════
# 1. PLOT TRAINING CURVES
# ═══════════════════════════════════════════════════════════════════════════

def plot_training_curves(log_file: Path, output_dir: Path):
    """
    Plot loss, Dice, IoU and learning-rate curves from the CSV training log.

    The 2x2 figure is written to ``output_dir / "training_curves.png"``.
    """
    float_columns = ("train_loss", "val_loss", "train_dice", "val_dice",
                     "train_iou", "val_iou", "lr")
    epochs = []
    history = {name: [] for name in float_columns}

    with open(log_file, 'r') as f:
        for row in csv.DictReader(f):
            epochs.append(int(row["epoch"]))
            for name in float_columns:
                history[name].append(float(row[name]))

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle("Training Progress", fontsize=16, fontweight="bold")

    # Train-vs-validation panels share one layout.
    comparison_panels = (
        (axes[0, 0], "loss", "Loss"),
        (axes[0, 1], "dice", "Dice Coefficient"),
        (axes[1, 0], "iou", "IoU (Intersection over Union)"),
    )
    for ax, metric, title in comparison_panels:
        ax.plot(epochs, history[f"train_{metric}"], 'b-', label='Train', linewidth=2)
        ax.plot(epochs, history[f"val_{metric}"], 'r-', label='Validation', linewidth=2)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Learning-rate panel on a log scale
    lr_ax = axes[1, 1]
    lr_ax.plot(epochs, history["lr"], 'g-', linewidth=2)
    lr_ax.set_title("Learning Rate Schedule")
    lr_ax.set_xlabel("Epoch")
    lr_ax.set_yscale("log")
    lr_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    save_path = output_dir / "training_curves.png"
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"[SAVED] Training curves: {save_path}")


# ═══════════════════════════════════════════════════════════════════════════
# 2. CONFUSION MATRIX (per-slice classification)
# ═══════════════════════════════════════════════════════════════════════════

@torch.no_grad()
def generate_confusion_matrix(model, test_loader, device, output_dir: Path):
    """
    Evaluate slice-level tumor detection with the model's classification head.

    A slice's ground-truth label is 1 if its mask contains any tumor pixel.
    Its prediction is ``sigmoid(cls_logits) > 0.5``. Saves
    ``confusion_matrix.png`` to ``output_dir``. When both classes occur in
    the test labels it also saves ``roc_curve.png``. A classification
    report is printed.

    Returns:
        ROC AUC as a float, or ``nan`` when the test labels contain a single
        class (ROC/AUC are undefined then).
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    for images, masks in test_loader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        with autocast(device_type='cuda', enabled=config.USE_AMP):
            _, cls_logits = model(images, return_classification=True)

        # Slice-level label: 1 if any tumor pixel exists (int, to match preds)
        labels = (masks.sum(dim=(1, 2, 3)) > 0).long().cpu().numpy()

        # Slice-level prediction from the classification head. Cast to
        # float32 first: under AMP the logits may be float16.
        probs = torch.sigmoid(cls_logits.float()).squeeze(-1).cpu().numpy()
        preds = (probs > 0.5).astype(int)

        all_preds.extend(preds)
        all_labels.extend(labels)
        all_probs.extend(probs)

        del images, masks, cls_logits

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    # ── Confusion Matrix ──
    # labels=[0, 1] keeps the matrix 2x2 even when a class never occurs, so it
    # always matches the two display labels.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    disp = ConfusionMatrixDisplay(cm, display_labels=["No Tumor", "Tumor"])
    disp.plot(ax=ax, cmap="Blues", values_format="d")
    ax.set_title("Confusion Matrix (Slice-Level Classification)", fontsize=14)
    save_path = output_dir / "confusion_matrix.png"
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"[SAVED] Confusion matrix: {save_path}")

    # ── ROC Curve ──
    if np.unique(all_labels).size < 2:
        roc_auc = float("nan")
        print("[SKIP] ROC curve: test labels contain a single class; AUC is undefined.")
    else:
        fpr, tpr, _ = roc_curve(all_labels, all_probs)
        roc_auc = auc(fpr, tpr)

        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        ax.plot(fpr, tpr, 'b-', linewidth=2, label=f'AUC = {roc_auc:.4f}')
        ax.plot([0, 1], [0, 1], 'r--', linewidth=1)
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title("ROC Curve (Slice-Level)", fontsize=14)
        ax.legend(fontsize=12)
        ax.grid(True, alpha=0.3)
        save_path = output_dir / "roc_curve.png"
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"[SAVED] ROC curve: {save_path}")

    # ── Classification Report ──
    report = classification_report(all_labels, all_preds,
                                   labels=[0, 1],
                                   target_names=["No Tumor", "Tumor"],
                                   zero_division=0)
    print(f"\nClassification Report:\n{report}")
    print(f"AUC-ROC: {roc_auc:.4f}")

    return roc_auc


# ═══════════════════════════════════════════════════════════════════════════
# 3. SAMPLE PREDICTIONS WITH MASK OVERLAY
# ═══════════════════════════════════════════════════════════════════════════

@torch.no_grad()
def generate_sample_predictions(model, test_loader, device, output_dir: Path, n_samples=10):
    """
    Generates overlay visualizations: original MRI + ground truth + prediction.

    For each of the first ``n_samples`` test slices, one grid row shows:
      1. input channel 0 (labelled T1c) in grayscale,
      2. the ground-truth mask (red) over it,
      3. the thresholded prediction (blue) over it,
      4. an error overlay (green = TP, red = FN, blue = FP).
    The slice's Dice score is written beside the row. The grid is saved as
    ``sample_predictions.png``, and each error overlay is also saved
    individually under ``output_dir / "overlays"``.
    """
    model.eval()

    overlay_dir = output_dir / "overlays"
    overlay_dir.mkdir(exist_ok=True)

    samples_collected = 0
    # squeeze=False keeps `axes` 2-D so axes[idx, col] also works for n_samples == 1.
    fig, axes = plt.subplots(n_samples, 4, figsize=(16, 4 * n_samples), squeeze=False)
    fig.suptitle("Sample Predictions", fontsize=16, fontweight="bold", y=1.01)

    column_titles = ["T1c Input", "Ground Truth Mask", "Predicted Mask", "Overlay"]

    for images, masks in test_loader:
        if samples_collected >= n_samples:
            break

        images = images.to(device, non_blocking=True)
        with autocast(device_type='cuda', enabled=config.USE_AMP):
            seg_logits = model(images)

        # Work in float32 on the CPU (under AMP the logits may be float16).
        seg_logits = seg_logits.float().cpu()
        masks = masks.float().cpu()
        seg_probs = torch.sigmoid(seg_logits).numpy()
        images_np = images.float().cpu().numpy()
        masks_np = masks.numpy()

        for i in range(images_np.shape[0]):
            if samples_collected >= n_samples:
                break

            idx = samples_collected

            # T1c channel (first modality) for display
            t1c = images_np[i, 0]  # (H, W)
            gt_mask = masks_np[i, 0]  # (H, W)
            pred_binary = (seg_probs[i, 0] > 0.5).astype(float)

            # Normalize T1c for display
            t1c_display = (t1c - t1c.min()) / (t1c.max() - t1c.min() + 1e-8)

            # Column 1: T1c slice
            axes[idx, 0].imshow(t1c_display, cmap='gray')
            axes[idx, 0].set_title(column_titles[0] if idx == 0 else "")
            axes[idx, 0].axis('off')

            # Column 2: Ground truth mask
            axes[idx, 1].imshow(t1c_display, cmap='gray')
            axes[idx, 1].imshow(gt_mask, cmap='Reds', alpha=0.5)
            axes[idx, 1].set_title(column_titles[1] if idx == 0 else "")
            axes[idx, 1].axis('off')

            # Column 3: Predicted mask
            axes[idx, 2].imshow(t1c_display, cmap='gray')
            axes[idx, 2].imshow(pred_binary, cmap='Blues', alpha=0.5)
            axes[idx, 2].set_title(column_titles[2] if idx == 0 else "")
            axes[idx, 2].axis('off')

            # Column 4: error overlay on the grayscale slice
            gt_fg = gt_mask > 0
            pred_fg = pred_binary > 0
            overlay = np.stack([t1c_display] * 3, axis=-1)
            overlay[gt_fg & pred_fg] = [0, 1, 0]    # TP in green
            overlay[gt_fg & ~pred_fg] = [1, 0, 0]   # FN in red
            overlay[~gt_fg & pred_fg] = [0, 0, 1]   # FP in blue
            overlay = np.clip(overlay, 0, 1)

            axes[idx, 3].imshow(overlay)
            axes[idx, 3].set_title(column_titles[3] if idx == 0 else "")
            axes[idx, 3].axis('off')

            plt.imsave(str(overlay_dir / f"overlay_{idx + 1:02d}.png"), overlay)

            # Dice for this sample, computed from the LOGITS. compute_metrics
            # applies sigmoid itself; feeding it probabilities would apply the
            # sigmoid twice and count nearly every pixel as predicted tumor.
            dice = compute_metrics(seg_logits[i:i + 1], masks[i:i + 1])["dice"]
            # axis('off') hides axis labels, so draw the caption as a text artist.
            axes[idx, 0].text(
                -0.05, 0.5, f"Sample {idx+1}\nDice: {dice:.3f}",
                transform=axes[idx, 0].transAxes, rotation=90,
                ha='right', va='center', fontsize=10,
            )

            samples_collected += 1

        del images, masks, seg_logits

    # Hide rows left unused when the loader yielded fewer than n_samples slices.
    for ax in axes[samples_collected:].ravel():
        ax.axis('off')

    plt.tight_layout()
    save_path = output_dir / "sample_predictions.png"
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"[SAVED] Sample predictions: {save_path}")
    print(f"[INFO] Individual overlays saved to: {overlay_dir}")


# ═══════════════════════════════════════════════════════════════════════════
# MAIN EVALUATION ENTRY POINT
# ═══════════════════════════════════════════════════════════════════════════

def evaluate(args):
    """Evaluate the trained segmentation model on the test split.

    Pipeline:
      1. Plot training curves when ``config.LOG_FILE`` exists.
      2. Build ``BrainTumorSegModel`` and restore weights from a checkpoint.
      3. Compute test metrics with ``validate``, then produce the
         confusion-matrix/ROC output and the sample-prediction grid.

    Args:
        args: namespace providing ``checkpoint`` (path, or falsy to use
            ``config.CHECKPOINT_DIR / "best_model.pt"``), ``subset``,
            ``batch_size``, ``num_workers`` and ``n_samples``.

    Returns:
        None. Returns early, after printing an error, when the checkpoint
        file does not exist.
    """
    heavy_rule = "=" * 70
    print(heavy_rule)
    print("BRAIN TUMOR SEGMENTATION - EVALUATION")
    print(heavy_rule)

    out_dir = config.OUTPUT_DIR
    out_dir.mkdir(exist_ok=True)

    # Step 1: training curves are optional — only drawn if a log was written.
    log_path = config.LOG_FILE
    if not log_path.exists():
        print(f"[SKIP] No training log found at {log_path}")
    else:
        print("\n[1/3] Plotting training curves...")
        plot_training_curves(log_path, out_dir)

    # Step 2: restore the model weights.
    print("\n[2/3] Loading best model...")
    model = BrainTumorSegModel().to(config.DEVICE)

    if args.checkpoint:
        ckpt_path = Path(args.checkpoint)
    else:
        ckpt_path = config.CHECKPOINT_DIR / "best_model.pt"

    if not ckpt_path.exists():
        print(f"  [ERROR] Checkpoint not found: {ckpt_path}")
        return

    epoch, best_dice = load_checkpoint(ckpt_path, model)
    print(f"  Loaded from: {ckpt_path}")
    print(f"  Epoch: {epoch+1}, Best Dice: {best_dice:.4f}")

    # Only the test loader is used; train/val loaders are discarded.
    _train_loader, _val_loader, test_loader = create_dataloaders(
        subset_ratio=args.subset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # Step 3: scalar metrics on the test split.
    print("\n[3/3] Running evaluation...")
    criterion = CombinedLoss()

    from train import validate
    test_metrics = validate(model, test_loader, criterion, config.DEVICE)

    light_rule = "=" * 50
    print(f"\n{light_rule}")
    print("TEST RESULTS")
    print(light_rule)
    for metric_name, metric_value in test_metrics.items():
        print(f"  {metric_name:>12s}: {metric_value:.4f}")

    print("\nGenerating confusion matrix and ROC curve...")
    # The return value (presumably the ROC AUC) is not used further here.
    generate_confusion_matrix(model, test_loader, config.DEVICE, out_dir)

    print("\nGenerating sample predictions...")
    generate_sample_predictions(model, test_loader, config.DEVICE, out_dir,
                                n_samples=args.n_samples)

    print(f"\n{heavy_rule}")
    print(f"Evaluation complete! Results saved to: {out_dir}")
    print(heavy_rule)


if False:  # __name__ == "__main__":
    # Command-line entry point for evaluation. It is disabled in this notebook,
    # presumably so argparse does not try to parse the Jupyter kernel's own
    # command-line arguments; the notebook cells call evaluate() directly.
    cli = argparse.ArgumentParser(description="Evaluate brain tumor segmentation model")
    cli.add_argument("--checkpoint", type=str, default=None,
                     help="Path to model checkpoint")
    cli.add_argument("--subset", type=float, default=1.0,
                     help="Dataset subset ratio for evaluation")
    cli.add_argument("--batch_size", type=int, default=config.BATCH_SIZE,
                     help="Batch size")
    cli.add_argument("--num_workers", type=int, default=config.NUM_WORKERS,
                     help="DataLoader workers")
    cli.add_argument("--n_samples", type=int, default=10,
                     help="Number of sample predictions to visualize")

    evaluate(cli.parse_args())



## inference.py

Running content from `inference.py`.


In [None]:
"""
============================================================================
INFERENCE — Single-Image Prediction Pipeline
============================================================================
Usage:
    python inference.py --input "Data/PatientID_0003/Timepoint_1" --output "results/"

Input: Path to a timepoint folder containing 4 MRI modality NIfTI files
Output: 
  - Per-slice tumor masks saved as PNGs
  - Overlay visualizations
  - Classification result (cancerous / non-cancerous) with confidence
============================================================================
"""

import sys
import argparse
from pathlib import Path

import numpy as np
import nibabel as nib
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import torch
from torch.amp import autocast

import config

# Force UTF-8 console output where the stream supports it. Replacement streams
# (e.g. io.StringIO, notebook kernel output streams) may not implement
# reconfigure(); calling it unconditionally would raise AttributeError there.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding='utf-8')


def load_model(checkpoint_path: str = None):
    """
    Load the trained segmentation model for inference.

    Resolution order:
      1. If no ``checkpoint_path`` is given and a TorchScript export
         ``brain_tumor_segmentation.pt`` exists in ``config.PROJECT_ROOT``,
         load it with ``torch.jit.load``.
      2. Otherwise build ``BrainTumorSegModel`` and load weights from
         ``checkpoint_path`` (default ``config.CHECKPOINT_DIR / "best_model.pt"``).
         Both a training checkpoint dict containing ``"model_state_dict"`` and
         a bare ``state_dict`` are accepted.

    Args:
        checkpoint_path: optional path to a checkpoint file.

    Returns:
        ``(model, model_type)``: the model in eval mode on ``config.DEVICE`` and
        either ``"torchscript"`` or ``"checkpoint"``.

    Exits the process with status 1 if the checkpoint file is not found.
    """
    # Try the TorchScript export first (only when no explicit checkpoint was requested).
    ts_path = config.PROJECT_ROOT / "brain_tumor_segmentation.pt"
    if ts_path.exists() and checkpoint_path is None:
        print(f"[MODEL] Loading TorchScript model: {ts_path}")
        model = torch.jit.load(str(ts_path), map_location=config.DEVICE)
        model.eval()
        return model, "torchscript"

    # Fall back to an eager model + checkpoint weights.
    from model import BrainTumorSegModel
    model = BrainTumorSegModel().to(config.DEVICE)

    ckpt_path = Path(checkpoint_path) if checkpoint_path else config.CHECKPOINT_DIR / "best_model.pt"
    if not ckpt_path.exists():
        print(f"[ERROR] No model found at {ckpt_path}")
        sys.exit(1)

    print(f"[MODEL] Loading checkpoint: {ckpt_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can run
    # code embedded in the file. The path may come from the command line, so only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=config.DEVICE, weights_only=False)
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        # Bare state_dict, e.g. saved via torch.save(model.state_dict(), path).
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    model.eval()
    return model, "checkpoint"


def load_timepoint_data(timepoint_dir: Path):
    """
    Load every modality listed in ``config.MODALITY_SUFFIXES`` from a timepoint directory.

    For each suffix, files matching ``*<suffix>.nii.gz`` are located. If several
    files match, the lexicographically first one is used (with a warning), so
    the choice does not depend on filesystem listing order.

    Args:
        timepoint_dir: directory containing the NIfTI files (``str`` or ``Path``).

    Returns:
        ``(volume_4d, affine)``. ``volume_4d`` stacks the modality volumes on a
        new last axis in ``MODALITY_SUFFIXES`` order, i.e. (H, W, D, 4) for the
        default four 3-D modalities. ``affine`` is the affine of the last
        modality file loaded.

    Raises:
        FileNotFoundError: a modality file is missing.
        ValueError: ``config.MODALITY_SUFFIXES`` is empty.
    """
    timepoint_dir = Path(timepoint_dir)
    modality_volumes = []
    affine = None

    for mod in config.MODALITY_SUFFIXES:
        # sorted(): Path.glob order is filesystem-dependent; keep the pick reproducible.
        matching = sorted(timepoint_dir.glob(f"*{mod}.nii.gz"))
        if not matching:
            raise FileNotFoundError(f"Missing modality {mod} in {timepoint_dir}")
        if len(matching) > 1:
            print(f"[WARN] {len(matching)} files match modality {mod} in "
                  f"{timepoint_dir}; using {matching[0].name}")

        nii = nib.load(str(matching[0]))
        modality_volumes.append(nii.get_fdata(dtype=np.float32))
        affine = nii.affine

    if not modality_volumes:
        raise ValueError("config.MODALITY_SUFFIXES is empty; no modalities to load")

    # Stack as (H, W, D, n_modalities)
    volume_4d = np.stack(modality_volumes, axis=-1)
    return volume_4d, affine


def normalize_slice(slice_4ch):
    """Z-score normalize each channel of a 2-D multi-channel slice.

    For every channel, the mean and standard deviation are computed over the
    strictly positive pixels only. Zero or negative pixels are treated as
    background and left unchanged. Foreground pixels are replaced by
    ``(x - mean) / (std + 1e-8)``.

    Args:
        slice_4ch: array of shape (H, W, C). It is not modified.

    Returns:
        A new array of the same shape. Floating-point inputs keep their dtype.
        Integer/bool inputs are promoted to float32 so the z-scores are not
        truncated to whole numbers.
    """
    slice_4ch = np.asarray(slice_4ch)
    if np.issubdtype(slice_4ch.dtype, np.floating):
        normalized = np.copy(slice_4ch)
    else:
        # Writing z-scores into integer storage would silently truncate them.
        normalized = slice_4ch.astype(np.float32)

    for c in range(normalized.shape[-1]):
        ch = normalized[:, :, c]  # view: assignments write straight into `normalized`
        foreground = ch > 0
        if foreground.any():
            values = ch[foreground]
            ch[foreground] = (values - values.mean()) / (values.std() + 1e-8)
    return normalized


@torch.no_grad()
def predict_volume(model, volume_4d, model_type="checkpoint"):
    """
    Segment a multi-modal volume one slice (along axis 2) at a time.

    A slice is processed only if its first channel is non-zero in at least
    ``config.MIN_BRAIN_FRACTION`` of its pixels. Each processed slice is:
      - normalized with ``normalize_slice``;
      - resized to ``config.IMAGE_SIZE`` x ``config.IMAGE_SIZE``;
      - passed through the model.

    Skipped slices keep all-zero outputs.

    Args:
        model: trained model (expected in eval mode).
        volume_4d: numpy array of shape (H, W, D, C).
        model_type: ``"torchscript"`` (model returns segmentation logits only)
            or ``"checkpoint"`` (called with ``return_classification=True``).

    Returns:
        predictions: float32 array (IMAGE_SIZE, IMAGE_SIZE, D) of sigmoid
            tumor probabilities.
        classifications: float32 array (D,) of per-slice sigmoid classifier
            probabilities. Always 0.0 for TorchScript models and skipped slices.
    """
    height, width, depth, n_channels = volume_4d.shape
    side = config.IMAGE_SIZE
    predictions = np.zeros((side, side, depth), dtype=np.float32)
    classifications = np.zeros(depth, dtype=np.float32)

    for z in range(depth):
        slab = volume_4d[:, :, z, :]  # (H, W, C)

        # Guard: ignore slices that are mostly background in channel 0.
        if np.count_nonzero(slab[:, :, 0]) / (height * width) < config.MIN_BRAIN_FRACTION:
            continue

        slab_norm = normalize_slice(slab)

        from skimage.transform import resize
        slab_resized = resize(slab_norm, (side, side, n_channels),
                              preserve_range=True, anti_aliasing=True)

        # HWC -> CHW, add a batch axis of one.
        batch = torch.from_numpy(slab_resized.transpose(2, 0, 1)).unsqueeze(0).float()
        batch = batch.to(config.DEVICE)

        amp_on = config.USE_AMP and config.DEVICE.type == 'cuda'
        with autocast(device_type='cuda', enabled=amp_on):
            if model_type != "torchscript":
                seg_logits, cls_logits = model(batch, return_classification=True)
                cls_prob = torch.sigmoid(cls_logits).item()
            else:
                seg_logits = model(batch)
                cls_prob = 0.0  # TorchScript may not support dual output

        predictions[:, :, z] = torch.sigmoid(seg_logits).squeeze().cpu().numpy()
        classifications[z] = cls_prob

    return predictions, classifications


def generate_visualizations(volume_4d, predictions, classifications, output_dir: Path):
    """
    Save overlay figures for the slices with the strongest tumor predictions.

    Only slices whose peak probability exceeds 0.3 are considered, ordered by
    peak probability (highest first). Written into ``output_dir`` (created if
    needed):
      - ``tumor_overlay_grid.png``: up to 20 rows of
        [T1c | probability heatmap | region thresholded at 0.5];
      - ``overlay_slice_XXX.png``: single heatmap overlays for the top 5 slices.

    ``classifications`` is accepted for interface compatibility but not used.
    """
    _, _, depth, _ = volume_4d.shape
    output_dir.mkdir(parents=True, exist_ok=True)

    def _display_t1c(z):
        # Channel 0 of slice z, resized to the model grid and min-max scaled to [0, 1].
        from skimage.transform import resize
        resized = resize(volume_4d[:, :, z, 0], (config.IMAGE_SIZE, config.IMAGE_SIZE),
                         preserve_range=True, anti_aliasing=True)
        lo, hi = resized.min(), resized.max()
        return (resized - lo) / (hi - lo + 1e-8)

    # (slice index, peak probability) for slices with some tumor signal.
    tumor_slices = []
    for z in range(depth):
        peak = predictions[:, :, z].max()
        if peak > 0.3:  # At least some tumor probability
            tumor_slices.append((z, peak))
    tumor_slices = sorted(tumor_slices, key=lambda item: item[1], reverse=True)

    print(f"\n[VIS] Found {len(tumor_slices)} slices with tumor predictions")

    grid_rows = tumor_slices[:20]
    if grid_rows:
        n_rows = len(grid_rows)
        fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows))
        if n_rows == 1:
            axes = axes.reshape(1, -1)  # keep 2-D indexing for a single row

        for row, (z, _peak) in enumerate(grid_rows):
            t1c_display = _display_t1c(z)
            prob = predictions[:, :, z]
            mask = (prob > 0.5).astype(float)
            is_header = row == 0
            ax_img, ax_heat, ax_mask = axes[row]

            ax_img.imshow(t1c_display, cmap='gray')
            ax_img.set_title(f"Slice {z} - T1c" if is_header else f"Slice {z}")
            ax_img.axis('off')

            ax_heat.imshow(t1c_display, cmap='gray')
            ax_heat.imshow(prob, cmap='hot', alpha=0.6, vmin=0, vmax=1)
            ax_heat.set_title("Tumor Probability" if is_header else "")
            ax_heat.axis('off')

            rgb = np.stack([t1c_display] * 3, axis=-1)
            rgb[mask > 0] = [1, 0.2, 0.2]  # Red for tumor
            ax_mask.imshow(np.clip(rgb, 0, 1))
            ax_mask.set_title("Tumor Region" if is_header else "")
            ax_mask.axis('off')

        plt.tight_layout()
        grid_path = output_dir / "tumor_overlay_grid.png"
        fig.savefig(grid_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"[SAVED] Tumor overlay grid: {grid_path}")

    # One standalone heatmap overlay per top-5 slice.
    for z, peak in tumor_slices[:5]:
        t1c_display = _display_t1c(z)

        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
        ax.imshow(t1c_display, cmap='gray')
        ax.imshow(predictions[:, :, z], cmap='hot', alpha=0.5, vmin=0, vmax=1)
        ax.set_title(f"Slice {z} | Tumor Prob: {peak:.2f}")
        ax.axis('off')

        slice_path = output_dir / f"overlay_slice_{z:03d}.png"
        fig.savefig(slice_path, dpi=150, bbox_inches="tight")
        plt.close(fig)

    print(f"[SAVED] Individual overlays in: {output_dir}")


def main(args):
    """
    Run the end-to-end inference pipeline on one timepoint folder.

    Steps:
      1. Load the model (``load_model``) and the multi-modal volume
         (``load_timepoint_data``).
      2. Predict slice by slice (``predict_volume``) and print a summary.
      3. Write overlay figures (``generate_visualizations``).
      4. Save the mask thresholded at 0.5 as ``predicted_mask.nii.gz``,
         resized back to the input volume's grid.

    Args:
        args: namespace with ``input`` (timepoint directory), ``output``
            (results directory, created if missing) and ``checkpoint``
            (optional checkpoint path passed to ``load_model``).

    Exits with status 1 if the input directory does not exist.
    """
    print("=" * 70)
    print("BRAIN TUMOR SEGMENTATION - INFERENCE")
    print("=" * 70)

    input_dir = Path(args.input)
    output_dir = Path(args.output)

    if not input_dir.exists():
        print(f"[ERROR] Input directory not found: {input_dir}")
        sys.exit(1)
    # Create the output folder up front so the NIfTI export below does not rely
    # on generate_visualizations() having created it as a side effect.
    output_dir.mkdir(parents=True, exist_ok=True)

    # ── Load model ──
    model, model_type = load_model(args.checkpoint)
    print(f"  Model type: {model_type}")

    # ── Load MRI data ──
    print(f"\n[DATA] Loading MRI from: {input_dir}")
    volume_4d, affine = load_timepoint_data(input_dir)
    print(f"  Volume shape: {volume_4d.shape}")
    print(f"  Modalities: {config.MODALITY_SUFFIXES}")

    # ── Run prediction ──
    print(f"\n[PREDICT] Running inference on {volume_4d.shape[2]} slices...")
    predictions, classifications = predict_volume(model, volume_4d, model_type)

    # ── Segmentation summary ──
    # Counted on the model-resolution grid returned by predict_volume. That grid
    # includes background pixels and skipped slices, so it is neither the
    # original voxel grid nor a brain-only volume.
    tumor_mask = predictions > 0.5
    tumor_volume = int(tumor_mask.sum())
    tumor_fraction = tumor_volume / predictions.size
    affected_slices = int(tumor_mask.any(axis=(0, 1)).sum())

    # ── Overall classification ──
    # Average/max classifier probability over slices that produced one
    # (predict_volume leaves skipped slices at 0.0).
    valid_cls = classifications[classifications > 0]
    if len(valid_cls) > 0:
        avg_confidence = valid_cls.mean()
        max_confidence = valid_cls.max()
    else:
        avg_confidence = 0.0
        max_confidence = 0.0

    has_classifier = model_type != "torchscript"
    if has_classifier:
        is_cancerous = max_confidence > 0.5
    else:
        # predict_volume returns 0.0 placeholders for TorchScript models, which
        # would always read as NON-CANCEROUS even with a non-empty tumor mask.
        # Derive the verdict from the segmentation instead.
        is_cancerous = tumor_volume > 0

    print(f"\n{'=' * 50}")
    print("DIAGNOSIS RESULT")
    print(f"{'=' * 50}")
    print(f"  Classification:     {'CANCEROUS' if is_cancerous else 'NON-CANCEROUS'}")
    if has_classifier:
        print(f"  Max Confidence:     {max_confidence:.4f}")
        print(f"  Avg Confidence:     {avg_confidence:.4f}")
    else:
        print("  Confidence:         n/a (TorchScript model; verdict derived from segmentation)")
    print(f"  Tumor volume:       {tumor_volume} voxels "
          f"({tumor_fraction*100:.2f}% of model-resolution volume)")
    print(f"  Affected slices:    {affected_slices}/{volume_4d.shape[2]}")

    # ── Generate visualizations ──
    print("\n[VIS] Generating visualizations...")
    generate_visualizations(volume_4d, predictions, classifications, output_dir)

    # ── Save prediction as NIfTI ──
    pred_nii_path = output_dir / "predicted_mask.nii.gz"
    from skimage.transform import resize
    # Resize probabilities back to the original grid, then threshold.
    pred_full = resize(predictions, volume_4d.shape[:3], preserve_range=True, anti_aliasing=True)
    pred_nii = nib.Nifti1Image((pred_full > 0.5).astype(np.uint8), affine)
    nib.save(pred_nii, str(pred_nii_path))
    print(f"[SAVED] Predicted mask NIfTI: {pred_nii_path}")

    print(f"\n{'=' * 70}")
    print(f"Inference complete! Results saved to: {output_dir}")
    print(f"{'=' * 70}")

    # ── Clinical disclaimer ──
    print("""
 ** IMPORTANT DISCLAIMER **
 This model is for research and competition purposes only.
 It has NOT been clinically validated and should NEVER be used
 as the sole basis for medical diagnosis or treatment decisions.
 Always consult qualified medical professionals.
""")


if False:  # __name__ == "__main__":
    # Command-line entry point for inference. It is disabled in this notebook,
    # presumably so argparse does not try to parse the Jupyter kernel's own
    # command-line arguments.
    cli = argparse.ArgumentParser(description="Run inference on a brain MRI scan")
    cli.add_argument("--input", type=str, required=True,
                     help="Path to timepoint folder with MRI NIfTI files")
    cli.add_argument("--output", type=str, default="inference_output",
                     help="Output directory for results")
    cli.add_argument("--checkpoint", type=str, default=None,
                     help="Path to model checkpoint (optional)")

    main(cli.parse_args())



## Execution Instructions

The code above defines all necessary classes and functions.
Use the cells below to run training, evaluation, or inference.



In [None]:
# Train the model (Example: 5% subset for quick test)
# To train on full dataset, set subset=1.0 and epochs=50+

class Args:
    """Plain attribute container passed to train() in place of parsed CLI arguments."""
    epochs = 2
    batch_size = 4
    lr = 1e-3
    subset = 0.05
    num_workers = 0  # 0 for safe Windows interaction in notebook
    resume = None

args = Args()
if __name__ == "__main__":
    try:
        train(args)
    except SystemExit as exc:
        # Presumably train() may call sys.exit(). Keep the kernel alive, but do
        # not silently hide a failed run: report any non-zero exit status.
        if exc.code not in (None, 0):
            print(f"Training exited early with status {exc.code!r}")



In [None]:
# Evaluate the model
# (Assumes 'checkpoints/best_model.pt' exists after training)

if __name__ == "__main__":
    try:
        evaluate_args = argparse.Namespace(
            checkpoint=None,
            subset=0.05,
            batch_size=4,
            num_workers=0,
            n_samples=5,
        )
        evaluate(evaluate_args)
    except SystemExit as exc:
        # Keep the kernel alive, but report a non-zero exit instead of hiding it.
        if exc.code not in (None, 0):
            print(f"Evaluation exited early with status {exc.code!r}")
    except Exception as e:
        print(f"Evaluation error (maybe model not trained yet?): {type(e).__name__}: {e}")

