<a href="https://www.kaggle.com/code/nicholas33/stage1-aneurysmnet-intracranial-training-nb153?scriptVersionId=258046122" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# ====================================================
# RSNA INTRACRANIAL ANEURYSM - STAGE 1 TRAINING (v2)
# Uses Stage-0 prebuilt v2 cache (volumes, masks, pseudo_masks, brainmasks, manifest)
# Two-phase training with per-sample segmentation weights and rich progress logs:
#   Phase 1: real masks weighted (real_seg_weight), synthetic seg weight = 0.0
#   Phase 2: real same, synthetic seg weight = small (default 0.075)
# Saves: stage1_phase1_best.pth, stage1_phase2_best.pth, stage1_segmentation_best.pth
# ====================================================

import os
import math
import time
import random
import numpy as np
import pandas as pd
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

In [None]:
# ====================================================
# Config
# ====================================================
class Config:
    """Static settings for Stage-1 training, read everywhere as ``Config.<NAME>``."""

    # ---------------- Prebuilt cache layout ----------------
    PREBUILT_ROOT = "/kaggle/input/rsna2025-v2-intracranial-aneurysm-detection-nb153/stage1_AneurysmNet_prebuilt_v2"
    MANIFEST_PATH = os.path.join(PREBUILT_ROOT, "meta/manifest.csv")
    VOLUMES_DIR = os.path.join(PREBUILT_ROOT, "volumes")
    MASKS_DIR = os.path.join(PREBUILT_ROOT, "masks")  # real annotations
    PSEUDO_DIR = os.path.join(PREBUILT_ROOT, "pseudo_masks")  # synthetic annotations
    BRAINMASKS_DIR = os.path.join(PREBUILT_ROOT, "brainmasks")

    # ---------------- Volume handling ----------------
    TARGET_SIZE = (48, 112, 112)  # ordered as (D, H, W)
    USE_BRAINMASKS = True
    BRAINMASK_KEY = "m"  # array name looked up inside each brain-mask .npz
    BRAINMASK_MIN_FRAC = 0.02  # brain fraction under this => volume is left ungated

    # ---------------- Optimisation ----------------
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MIXED_PRECISION = True
    STAGE1_BATCH_SIZE = 8  # 12/16 is possible when GPU memory permits
    NUM_WORKERS = 2  # 4/8 helps when data loading is CPU-bound
    PREFETCH_FACTOR = 2
    PERSISTENT_WORKERS = True
    STAGE1_LR = 2e-4
    WEIGHT_DECAY = 1e-4
    EPOCHS_PHASE1 = 15
    EPOCHS_PHASE2 = 15
    EARLY_STOP_PATIENCE = 5  # epochs without val-loss improvement before stopping

    # ---------------- Segmentation loss weighting ----------------
    REAL_SEG_DEFAULT_W = 0.7  # fallback when the manifest has no real_seg_weight
    PHASE1_SYNTH_SEG_W = 0.0
    PHASE2_SYNTH_SEG_W = 0.075
    FOCAL_LOSS_WEIGHT = 0.2

    # ---------------- Data split ----------------
    FOLDS = 1  # raise above 1 later to enable cross-validation here
    SEED = 42

# ====================================================
# Utils
# ====================================================

def set_seed(seed: int = 42, deterministic: bool = True):
    """Seed every RNG used by the pipeline and configure cuDNN consistently.

    Args:
        seed: Value fed to Python's ``random``, NumPy and PyTorch (CPU and all
            CUDA devices).
        deterministic: When True (default) cuDNN is restricted to deterministic
            kernels and its auto-tuner is switched off, so repeated runs pick
            the same algorithms. Pass False to trade reproducibility for the
            speed-up that ``cudnn.benchmark`` gives on fixed-size inputs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The two cuDNN switches must agree: benchmark mode may select a different
    # algorithm on each run, which undermines run-to-run reproducibility even
    # when only deterministic kernels are allowed.
    torch.backends.cudnn.deterministic = bool(deterministic)
    torch.backends.cudnn.benchmark = not deterministic


def load_manifest_df() -> pd.DataFrame:
    """Read the manifest CSV at ``Config.MANIFEST_PATH`` and check its schema.

    Returns:
        The manifest as a DataFrame, unmodified.

    Raises:
        RuntimeError: naming the first of ``series_id``, ``label``,
            ``vol_relpath`` that is absent from the CSV header.
    """
    manifest = pd.read_csv(Config.MANIFEST_PATH)
    required = ("series_id", "label", "vol_relpath")
    first_missing = next((name for name in required if name not in manifest.columns), None)
    if first_missing is not None:
        raise RuntimeError(f"Manifest missing required column: {first_missing}")
    return manifest


def gpu_mem_str():
    """Return a short label describing current CUDA memory use.

    Returns:
        ``"cpu"`` when CUDA is unavailable, ``"<allocated>G/<reserved>G"``
        (GiB, two decimals) when the query succeeds, or ``"gpu"`` if the
        memory query itself raises.
    """
    if torch.cuda.is_available():
        gib = 1024**3
        try:
            allocated = torch.cuda.memory_allocated() / gib
            reserved = torch.cuda.memory_reserved() / gib
        except Exception:
            return "gpu"
        return f"{allocated:.2f}G/{reserved:.2f}G"
    return "cpu"

In [None]:
# ====================================================
# Dataset
# ====================================================
class PrebuiltDataset(Dataset):
    """Dataset over the prebuilt cache described by manifest rows.

    Each item bundles a volume, a binary segmentation mask, the series label
    and a per-sample segmentation-loss weight. Samples whose mask is synthetic
    (or missing for a positive label) get ``phase_synth_w``; all other samples
    get the manifest's ``real_seg_weight`` (or ``Config.REAL_SEG_DEFAULT_W``),
    clipped to ``[0.2, 1.0]``.

    Args:
        df: Manifest rows; ``series_id`` and ``label`` are required, while
            ``mask_relpath``, ``is_synthetic``, ``real_seg_weight`` and
            ``brain_voxel_fraction`` are read when present.
        phase_synth_w: Segmentation weight assigned to synthetic-mask samples.
    """

    # Manifest mask-path prefix -> whether masks under it count as synthetic.
    _MASK_PREFIXES = (("masks/", False), ("pseudo_masks/", True))

    def __init__(self, df: pd.DataFrame, phase_synth_w: float):
        self.df = df.reset_index(drop=True)
        self.phase_synth_w = float(phase_synth_w)

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _flag_to_int(value, default: int = 0) -> int:
        """Convert a manifest cell to int, returning ``default`` for NaN/None/garbage."""
        try:
            return int(value)
        except (TypeError, ValueError):
            # int(float('nan')) raises ValueError, int(None) raises TypeError.
            return default

    def _load_volume(self, sid: str) -> np.ndarray:
        """Load ``<VOLUMES_DIR>/<sid>.npy`` as float32 with NaN/inf replaced."""
        path = os.path.join(Config.VOLUMES_DIR, f"{sid}.npy")
        # Presumably (D,H,W) with intensities in 0..1 -- produced upstream, not checked here.
        v = np.load(path).astype(np.float32)
        v = np.nan_to_num(v, nan=0.0, posinf=1.0, neginf=0.0)
        return v

    def _load_brainmask(self, sid: str, frac: Optional[float], shp: Tuple[int, int, int]):
        """Return the brain mask for ``sid`` or None when it should not be applied.

        None is returned when brain masking is disabled, ``frac`` is below
        ``Config.BRAINMASK_MIN_FRAC``, the file is missing or unreadable, the
        mask shape differs from ``shp``, or the mask is empty.
        """
        if not Config.USE_BRAINMASKS:
            return None
        if frac is not None and float(frac) < Config.BRAINMASK_MIN_FRAC:
            return None
        p = os.path.join(Config.BRAINMASKS_DIR, f"{sid}_brainmask.npz")
        if not os.path.exists(p):
            return None
        try:
            # Context manager closes the archive's file handle deterministically.
            with np.load(p) as archive:
                bm = archive[Config.BRAINMASK_KEY].astype(np.float32)
            bm = np.nan_to_num(bm, nan=0.0, posinf=1.0, neginf=0.0)
            if bm.shape != shp or bm.sum() <= 0:
                return None
            return bm
        except Exception:
            # Best effort: an unusable brain mask simply means "no gating".
            return None

    def _load_mask(self, sid: str, mask_rel: str, is_synth: int, label: int,
                   shape: Optional[Tuple[int, int, int]] = None) -> Tuple[np.ndarray, bool]:
        """Load the segmentation mask referenced by the manifest.

        Args:
            sid: Series id (not used for path resolution; kept for the interface).
            mask_rel: Path relative to ``Config.PREBUILT_ROOT``; only paths
                starting with ``masks/`` or ``pseudo_masks/`` are honoured.
            is_synth: Manifest flag; not consulted, the path prefix decides.
            label: Series label, used only to classify the fallback mask.
            shape: Shape of the fallback all-zero mask. Defaults to
                ``Config.TARGET_SIZE`` when omitted.

        Returns:
            ``(mask, is_synthetic)`` where ``mask`` is float32 in {0, 1}. When
            no mask file is usable an all-zero mask is returned and it is
            flagged synthetic exactly when ``label == 1``.
        """
        if isinstance(mask_rel, str) and mask_rel:
            for prefix, synthetic in self._MASK_PREFIXES:
                if not mask_rel.startswith(prefix):
                    continue
                p = os.path.join(Config.PREBUILT_ROOT, mask_rel)
                if os.path.exists(p):
                    m = np.load(p).astype(np.float32)
                    m = np.nan_to_num(m, nan=0.0, posinf=1.0, neginf=0.0)
                    return (m > 0).astype(np.float32), synthetic
                break
        # Fallback: empty mask sized like the volume it will be paired with.
        fallback_shape = tuple(shape) if shape is not None else tuple(Config.TARGET_SIZE)
        return np.zeros(fallback_shape, dtype=np.float32), int(label) == 1

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            dict with ``series_id`` (str), ``volume`` and ``mask`` (float32
            tensors with a leading channel axis), and ``label``,
            ``seg_weight``, ``is_synthetic_mask`` (float32 tensors of shape [1]).
        """
        r = self.df.iloc[idx]
        sid = str(r['series_id'])
        label = int(r['label'])
        mask_rel = r.get('mask_relpath', '')
        if not isinstance(mask_rel, str):
            mask_rel = ''
        # A NaN cell must not crash item loading.
        is_synth_col = self._flag_to_int(r.get('is_synthetic', 0))
        real_seg_weight = r.get('real_seg_weight', np.nan)
        brain_frac = r.get('brain_voxel_fraction', np.nan)

        vol = self._load_volume(sid)
        bm = self._load_brainmask(sid, brain_frac if pd.notna(brain_frac) else None, vol.shape)
        if bm is not None:
            vol = vol * bm  # zero out everything outside the brain mask

        mask, is_synth = self._load_mask(sid, mask_rel, is_synth_col, label, shape=vol.shape)

        # Per-sample segmentation weight.
        if is_synth:
            seg_w = self.phase_synth_w
        else:
            rsw = Config.REAL_SEG_DEFAULT_W
            if pd.notna(real_seg_weight):
                try:
                    rsw = float(real_seg_weight)
                except Exception:
                    rsw = Config.REAL_SEG_DEFAULT_W
            seg_w = float(np.clip(rsw, 0.2, 1.0))

        return {
            'series_id': sid,
            'volume': torch.from_numpy(vol).unsqueeze(0),
            # `mask` is already binary float32 on every path out of _load_mask.
            'mask': torch.from_numpy(mask).unsqueeze(0),
            'label': torch.tensor([float(label)], dtype=torch.float32),
            'seg_weight': torch.tensor([float(seg_w)], dtype=torch.float32),
            'is_synthetic_mask': torch.tensor([1.0 if is_synth else 0.0], dtype=torch.float32),
        }

# ====================================================
# Simple 3D UNet + classifier head
# ====================================================
class ConvBlock3D(nn.Module):
    """Two stacked ``Conv3d(3x3x3, padding=1) -> GroupNorm(8 groups) -> ReLU`` stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. ``out_ch`` must be divisible by 8 for the GroupNorm layers.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Same layer order as a hand-written Sequential, so parameter names
        # (net.0, net.1, net.3, net.4) and init order are unchanged.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv3d(stage_in, out_ch, 3, padding=1))
            stages.append(nn.GroupNorm(num_groups=8, num_channels=out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

class UNet3D(nn.Module):
    """Three-level 3D U-Net with a linear classifier on the pooled bottleneck.

    ``forward`` returns ``(seg, cls)``: segmentation logits with one channel at
    the input's spatial size, and one classification logit per sample. Each
    spatial dimension is halved three times, so it should be divisible by 8.
    """

    def __init__(self, in_ch=1, base=24):
        super().__init__()
        c1, c2, c3, c4 = base, base * 2, base * 4, base * 8
        # Contracting path.
        self.enc1 = ConvBlock3D(in_ch, c1)
        self.pool1 = nn.MaxPool3d(2)
        self.enc2 = ConvBlock3D(c1, c2)
        self.pool2 = nn.MaxPool3d(2)
        self.enc3 = ConvBlock3D(c2, c3)
        self.pool3 = nn.MaxPool3d((2, 2, 2))
        self.bott = ConvBlock3D(c3, c4)
        # Expanding path; each decoder block sees upsampled + skip channels.
        self.up3 = nn.ConvTranspose3d(c4, c3, 2, stride=2)
        self.dec3 = ConvBlock3D(c4, c3)
        self.up2 = nn.ConvTranspose3d(c3, c2, 2, stride=2)
        self.dec2 = ConvBlock3D(c3, c2)
        self.up1 = nn.ConvTranspose3d(c2, c1, 2, stride=2)
        self.dec1 = ConvBlock3D(c2, c1)
        self.seg_head = nn.Conv3d(c1, 1, 1)
        # Classifier reads globally averaged bottleneck features.
        self.cls_pool = nn.AdaptiveAvgPool3d(1)
        self.cls_head = nn.Linear(c4, 1)

    def forward(self, x):  # x: [B,1,D,H,W]
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        bottleneck = self.bott(self.pool3(skip3))

        # Walk back up, fusing each upsampled map with its encoder skip.
        feat = bottleneck
        ladder = (
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, decode, skip in ladder:
            feat = decode(torch.cat([upsample(feat), skip], dim=1))

        seg = self.seg_head(feat)  # [B,1,D,H,W]
        pooled = self.cls_pool(bottleneck).flatten(1)
        cls = self.cls_head(pooled)  # [B,1]
        return seg, cls

# ====================================================
# Losses
# ====================================================
class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities with a squared-term denominator.

    Sums run over dims (2, 3, 4), so inputs are expected to be 5-D. With
    ``reduction='none'`` the result keeps the leading two dims; any other
    value returns the mean over them.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets, reduction='mean'):
        spatial = (2, 3, 4)
        p = torch.sigmoid(logits)
        overlap = (p * targets).sum(dim=spatial)
        energy = (p.pow(2) + targets.pow(2)).sum(dim=spatial)
        # eps in both terms keeps the ratio defined when prediction and target are empty.
        per_sample = 1 - (2 * overlap + self.eps) / (energy + self.eps)
        return per_sample if reduction == 'none' else per_sample.mean()

class FocalLoss(nn.Module):
    """Binary focal loss computed from logits in float32.

    The loss per voxel is ``alpha * (1 - p_t)**gamma * BCE`` where ``p_t`` is
    the predicted probability of the true class (``targets == 1`` selects the
    positive branch). Voxel losses are averaged over dims (2, 3, 4).

    Args:
        alpha: Constant scale applied to every voxel.
        gamma: Focusing exponent.
        eps: Probabilities are clamped to ``[eps, 1 - eps]`` before forming the
            modulating factor.
    """

    def __init__(self, alpha=0.25, gamma=2.0, eps=1e-6):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.eps = eps

    def forward(self, logits, targets, reduction='mean'):
        """Return the per-sample loss (``reduction='none'``) or its mean."""
        # Work in float32: in float16 the bound 1 - eps rounds to exactly 1.0,
        # so log(1 - p) becomes -inf and the loss turns into inf/NaN whenever
        # the sigmoid saturates under mixed precision.
        logits = logits.float()
        targets = targets.float()
        # Log-sum-exp form of cross-entropy; finite for any finite logit.
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        probs = torch.sigmoid(logits).clamp(self.eps, 1 - self.eps)
        pt = torch.where(targets == 1, probs, 1 - probs)
        loss = self.alpha * (1 - pt).pow(self.gamma) * ce
        loss = loss.mean(dim=(2, 3, 4))  # per-sample
        if reduction == 'none':
            return loss
        return loss.mean()

class EnhancedCombinedLoss(nn.Module):
    """Per-sample weighted segmentation loss plus classification BCE.

    The segmentation term for each sample is
    ``0.5 * dice + 0.3 * bce + Config.FOCAL_LOSS_WEIGHT * focal``; it is
    multiplied by that sample's entry in ``seg_weights`` and averaged over
    the batch. The classification term is BCE-with-logits on ``cls_logits``.
    """

    def __init__(self):
        super().__init__()
        self.dice_loss = DiceLoss()
        self.focal_loss = FocalLoss(alpha=0.25, gamma=2.0)
        self.bce_vox = nn.BCEWithLogitsLoss(reduction='none')
        self.bce_cls = nn.BCEWithLogitsLoss()

    def forward(self, seg_logits, cls_logits, seg_targets, cls_targets, seg_weights: torch.Tensor):
        """Compute the combined loss.

        Args:
            seg_logits: Segmentation logits; first dim is the batch.
            cls_logits: Classification logits, one per sample.
            seg_targets: Segmentation targets shaped like ``seg_logits``.
            cls_targets: Classification targets, one per sample.
            seg_weights: One segmentation weight per sample (any shape that
                flattens to ``[B]``).

        Returns:
            ``(total, seg_loss, cls_loss)``; the last two are detached.
        """
        # Replace NaN/inf logits (finite values are left untouched).
        seg_logits = torch.nan_to_num(seg_logits, nan=0.0, posinf=20.0, neginf=-20.0)
        B = seg_logits.shape[0]

        # Reduce every component to shape [B]. The dice/focal modules return
        # [B, C]; adding that to a [B] tensor would broadcast to [B, B] and
        # apply each sample's weight to every other sample's loss.
        dice_ps = self.dice_loss(seg_logits, seg_targets, reduction='none').reshape(B, -1).mean(dim=1)
        focal_ps = self.focal_loss(seg_logits, seg_targets, reduction='none').reshape(B, -1).mean(dim=1)
        bce_ps = self.bce_vox(seg_logits, seg_targets).reshape(B, -1).mean(dim=1)

        dice_ps = torch.nan_to_num(dice_ps, nan=0.0)
        focal_ps = torch.nan_to_num(focal_ps, nan=0.0)
        bce_ps = torch.nan_to_num(bce_ps, nan=0.0)
        seg_ps = 0.5 * dice_ps + 0.3 * bce_ps + Config.FOCAL_LOSS_WEIGHT * focal_ps
        seg_ps = torch.nan_to_num(seg_ps, nan=0.0)

        seg_w = seg_weights.reshape(-1)
        if (seg_w == 0).all():
            # No sample in the batch is supervised for segmentation.
            seg_loss = seg_ps.new_tensor(0.0)
        else:
            seg_loss = (seg_ps * seg_w).mean()

        # classification loss
        cls_logits = torch.nan_to_num(cls_logits, nan=0.0, posinf=20.0, neginf=-20.0)
        cls_loss = self.bce_cls(cls_logits.view(-1), cls_targets.view(-1))
        total = seg_loss + cls_loss
        return total, seg_loss.detach(), cls_loss.detach()

# ====================================================
# Train / Validate with progress bars
# ====================================================

def train_epoch(model, loader, optimizer, criterion, scaler, epoch=None, phase_name="P1"):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network returning ``(seg_logits, cls_logits)``.
        loader: Iterable of batches with keys ``volume``, ``mask``, ``label``
            and ``seg_weight``.
        optimizer: Optimiser stepped through ``scaler``.
        criterion: Callable returning ``(total, seg_loss, cls_loss)``.
        scaler: Gradient scaler used for backward and the optimiser step.
        epoch: Optional epoch number shown in the progress bar.
        phase_name: Tag shown in the progress bar.

    Returns:
        ``(loss, seg_loss, cls_loss)`` averaged per sample over the batches
        that completed an optimiser step; three NaNs when none did.
    """
    model.train()
    t_loss = t_seg = t_cls = 0.0
    n = 0
    failed = 0
    epoch_tag = '' if epoch is None else f' [ep {epoch}]'
    pbar = tqdm(loader, desc=f"Train {phase_name}{epoch_tag}", leave=False)
    for batch in pbar:
        vol   = batch['volume'].to(Config.DEVICE, non_blocking=True)
        mask  = batch['mask'].to(Config.DEVICE, non_blocking=True)
        label = batch['label'].to(Config.DEVICE, non_blocking=True)
        segw  = batch['seg_weight'].to(Config.DEVICE, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=Config.MIXED_PRECISION):
            seg_logits, cls_logits = model(vol)
            loss, seg_loss, cls_loss = criterion(seg_logits, cls_logits, mask, label, segw)
        try:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        except Exception as err:
            # Still best effort (the batch is dropped), but no longer silent:
            # report the first failure so systematic errors are visible.
            failed += 1
            if failed == 1:
                tqdm.write(f"[train {phase_name}] batch skipped after {type(err).__name__}: {err}")
            optimizer.zero_grad(set_to_none=True)
            continue
        bs = vol.size(0)
        t_loss += loss.item() * bs
        t_seg += seg_loss.item() * bs
        t_cls += cls_loss.item() * bs
        n += bs
        pbar.set_postfix(loss=f"{t_loss/max(n,1):.4f}", seg=f"{t_seg/max(n,1):.4f}", cls=f"{t_cls/max(n,1):.4f}", lr=f"{optimizer.param_groups[0]['lr']:.2e}")
    if failed:
        print(f"[train {phase_name}] skipped {failed} batch(es) due to backward/step errors")
    if n == 0:
        # Empty loader or every batch failed: avoid ZeroDivisionError.
        nan = float('nan')
        return nan, nan, nan
    return t_loss/n, t_seg/n, t_cls/n

@torch.no_grad()
def validate_epoch(model, loader, criterion, epoch=None, phase_name="P1"):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network returning ``(seg_logits, cls_logits)``.
        loader: Iterable of batches with keys ``volume``, ``mask``, ``label``
            and ``seg_weight``.
        criterion: Callable returning ``(total, seg_loss, cls_loss)``.
        epoch: Optional epoch number shown in the progress bar.
        phase_name: Tag shown in the progress bar.

    Returns:
        ``(loss, seg_loss, cls_loss, auc)``. The losses are per-sample
        averages (NaN when the loader yields nothing); ``auc`` is the ROC AUC
        of the sigmoid classification outputs, or NaN when it cannot be
        computed (e.g. only one class present).
    """
    model.eval()
    t_loss = t_seg = t_cls = 0.0
    n = 0
    all_probs = []
    all_labels = []
    epoch_tag = '' if epoch is None else f' [ep {epoch}]'
    pbar = tqdm(loader, desc=f"Valid {phase_name}{epoch_tag}", leave=False)
    for batch in pbar:
        vol   = batch['volume'].to(Config.DEVICE, non_blocking=True)
        mask  = batch['mask'].to(Config.DEVICE, non_blocking=True)
        label = batch['label'].to(Config.DEVICE, non_blocking=True)
        segw  = batch['seg_weight'].to(Config.DEVICE, non_blocking=True)
        seg_logits, cls_logits = model(vol)
        loss, seg_loss, cls_loss = criterion(seg_logits, cls_logits, mask, label, segw)
        bs = vol.size(0)
        t_loss += loss.item() * bs
        t_seg += seg_loss.item() * bs
        t_cls += cls_loss.item() * bs
        n += bs
        pbar.set_postfix(loss=f"{t_loss/max(n,1):.4f}", seg=f"{t_seg/max(n,1):.4f}", cls=f"{t_cls/max(n,1):.4f}")
        all_probs.append(torch.sigmoid(cls_logits).detach().cpu().view(-1).numpy())
        all_labels.append(label.detach().cpu().view(-1).numpy())
    all_probs = np.concatenate(all_probs) if all_probs else np.array([])
    all_labels = np.concatenate(all_labels) if all_labels else np.array([])
    auc = np.nan
    if len(all_probs) > 0 and len(np.unique(all_labels)) > 1:
        try:
            auc = float(roc_auc_score(all_labels, all_probs))
        except Exception as err:
            # AUC stays NaN, but say why instead of hiding the failure.
            print(f"[valid {phase_name}] AUC unavailable ({type(err).__name__}: {err})")
    if n == 0:
        # Empty loader: avoid ZeroDivisionError.
        nan = float('nan')
        return nan, nan, nan, auc
    return t_loss/n, t_seg/n, t_cls/n, auc

# ====================================================
# Main
# ====================================================

def _model_state_dict(model):
    """Return the state dict of the bare model, unwrapping ``nn.DataParallel``."""
    inner = model.module if isinstance(model, nn.DataParallel) else model
    return inner.state_dict()


def _make_loader(df, synth_w, shuffle):
    """Build a DataLoader over ``PrebuiltDataset(df)`` from the Config settings."""
    loader_kwargs = dict(
        batch_size=Config.STAGE1_BATCH_SIZE,
        shuffle=shuffle,
        num_workers=Config.NUM_WORKERS,
        pin_memory=True,
    )
    if Config.NUM_WORKERS > 0:
        # DataLoader rejects these two options when loading in the main process.
        loader_kwargs.update(
            prefetch_factor=Config.PREFETCH_FACTOR,
            persistent_workers=Config.PERSISTENT_WORKERS,
        )
    return DataLoader(PrebuiltDataset(df, phase_synth_w=synth_w), **loader_kwargs)


def _run_phase(model, optimizer, criterion, scaler, train_df, val_df,
               phase_no, synth_w, n_epochs, ckpt_path):
    """Train one phase with early stopping; return the best state dict or None.

    The loaders are local to this call, so their worker processes can be
    released when the phase returns instead of idling through the next phase.
    """
    tag = f"P{phase_no}"
    dl_train = _make_loader(train_df, synth_w, shuffle=True)
    dl_val = _make_loader(val_df, synth_w, shuffle=False)

    best_loss = float('inf')
    best_state = None
    patience = 0
    for epoch in range(1, n_epochs + 1):
        print(f"\n[Phase {phase_no}] Epoch {epoch}/{n_epochs}")
        tr_loss, tr_seg, tr_cls = train_epoch(model, dl_train, optimizer, criterion, scaler, epoch=epoch, phase_name=tag)
        va_loss, va_seg, va_cls, va_auc = validate_epoch(model, dl_val, criterion, epoch=epoch, phase_name=tag)
        print(f"Train Loss: {tr_loss:.4f} | Seg: {tr_seg:.4f} | Cls: {tr_cls:.4f} | GPU {gpu_mem_str()}")
        print(f" Val  Loss: {va_loss:.4f} | Seg: {va_seg:.4f} | Cls: {va_cls:.4f} | AUC: {va_auc if not np.isnan(va_auc) else 'NA'} | GPU {gpu_mem_str()}")
        if va_loss < best_loss - 1e-5:
            best_loss = va_loss
            patience = 0
            # clone(): .cpu() is a no-op for tensors already on the CPU, so
            # without a copy the snapshot would alias the live parameters and
            # keep changing as training continues.
            best_state = {k: v.detach().cpu().clone() for k, v in _model_state_dict(model).items()}
            torch.save(best_state, ckpt_path)
            print(f"💾 Saved Phase {phase_no} best checkpoint")
        else:
            patience += 1
            if patience >= Config.EARLY_STOP_PATIENCE:
                print(f"Early stopping Phase {phase_no}")
                break
    return best_state


def run_training():
    """Run the two-phase Stage-1 training and write the checkpoints.

    Phase 1 trains with ``Config.PHASE1_SYNTH_SEG_W`` for synthetic masks and
    Phase 2 continues from the best Phase-1 weights with
    ``Config.PHASE2_SYNTH_SEG_W`` at half the learning rate. Writes
    ``stage1_phase1_best.pth`` / ``stage1_phase2_best.pth`` whenever a phase
    improves its validation loss, and finally ``stage1_segmentation_best.pth``
    (best Phase-2 state, else best Phase-1 state, else the current weights).
    """
    set_seed(Config.SEED)
    df = load_manifest_df()

    # Single stratified split: fold 0 of a seeded 5-fold partition is validation.
    y = df['label'].astype(int).values
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=Config.SEED)
    train_idx, val_idx = next(skf.split(np.zeros_like(y), y))
    train_df = df.iloc[train_idx].reset_index(drop=True)
    val_df   = df.iloc[val_idx].reset_index(drop=True)

    model = UNet3D(in_ch=1, base=24).to(Config.DEVICE)
    # Multi-GPU (if available)
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs via DataParallel")
        model = nn.DataParallel(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.STAGE1_LR, weight_decay=Config.WEIGHT_DECAY)
    criterion = EnhancedCombinedLoss().to(Config.DEVICE)
    scaler = torch.amp.GradScaler('cuda', enabled=Config.MIXED_PRECISION)

    # Phase 1: synthetic masks carry the Phase-1 weight.
    best_state = _run_phase(
        model, optimizer, criterion, scaler, train_df, val_df,
        phase_no=1, synth_w=Config.PHASE1_SYNTH_SEG_W,
        n_epochs=Config.EPOCHS_PHASE1, ckpt_path='stage1_phase1_best.pth',
    )

    # Continue Phase 2 from the best Phase-1 weights rather than the last epoch.
    if best_state is not None:
        inner = model.module if isinstance(model, nn.DataParallel) else model
        inner.load_state_dict(best_state)

    # Phase 2: small synthetic weight
    print("\n====== PHASE 2: enabling small synthetic seg supervision ======")

    # optional: lower LR a bit for fine-tune
    for g in optimizer.param_groups:
        g['lr'] = Config.STAGE1_LR * 0.5

    best2_state = _run_phase(
        model, optimizer, criterion, scaler, train_df, val_df,
        phase_no=2, synth_w=Config.PHASE2_SYNTH_SEG_W,
        n_epochs=Config.EPOCHS_PHASE2, ckpt_path='stage1_phase2_best.pth',
    )

    final_state = best2_state or best_state or _model_state_dict(model)
    torch.save(final_state, 'stage1_segmentation_best.pth')
    print("\n✅ Stage 1 complete. Saved: stage1_segmentation_best.pth")


# Entry point: start the two-phase training only when this module is the
# program being executed, not when it is imported.
if __name__ == '__main__':
    run_training()


In [None]:
# import os, numpy as np, torch, matplotlib.pyplot as plt, cv2, pandas as pd
# from torch.utils.data import DataLoader, Subset

# device = Config.DEVICE

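# # Rebuild a validation split (first 20% of rows).
# # NOTE: this does not match run_training(), which validates on the first fold
# # of a seeded StratifiedKFold(n_splits=5, shuffle=True) split of the manifest.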
# df = pd.read_csv(Config.TRAIN_CSV_PATH)
# if Config.DEBUG_MODE:
#     df = df.head(Config.DEBUG_SAMPLES)
# val_size = len(df) // 5
# val_df = df[:val_size].reset_index(drop=True)

# # If you exported a prebuilt dataset, enable this for faster loading
# # Config.USE_EXTERNAL_CACHE = True
# # Config.EXTERNAL_CACHE_DIR = '/kaggle/input/<your-prebuilt-dataset-name>/stage1_prebuilt'

# processor = SimpleDICOMProcessor()
# val_dataset = SimpleSegmentationDataset(val_df, Config.SERIES_DIR, processor, 'val')

# def collate_segmentation_batch(batch_list):
#     volumes = torch.stack([sample['volume'].contiguous() for sample in batch_list], dim=0)
#     masks = torch.stack([sample['mask'].contiguous() for sample in batch_list], dim=0)
#     has = torch.stack([
#         sample['has_aneurysm'] if isinstance(sample['has_aneurysm'], torch.Tensor)
#         else torch.tensor(sample['has_aneurysm'], dtype=torch.float32)
#         for sample in batch_list
#     ], dim=0)
#     synth = torch.stack([
#         sample.get('is_synthetic_mask', torch.tensor(0.0, dtype=torch.float32))
#         for sample in batch_list
#     ], dim=0)
#     series_ids = [sample['series_id'] for sample in batch_list]
#     return {'volume': volumes, 'mask': masks, 'has_aneurysm': has, 'is_synthetic_mask': synth, 'series_id': series_ids}

# loader_kwargs = dict(batch_size=2, num_workers=4, pin_memory=True, persistent_workers=True, prefetch_factor=4)
# val_loader_small = DataLoader(Subset(val_dataset, list(range(min(20, len(val_dataset))))),
#                               shuffle=False, drop_last=False, collate_fn=collate_segmentation_batch, **loader_kwargs)

# # Load model
# model = Simple3DSegmentationNet().to(device)
# ckpt = torch.load('stage1_segmentation_best.pth', map_location=device)
# state = ckpt.get('model_state_dict', ckpt)
# if any(k.startswith('module.') for k in state.keys()):
#     state = {k.replace('module.', '', 1): v for k, v in state.items()}
# _ = model.load_state_dict(state, strict=False)
# model.eval()
# print("Model loaded for inference.")

In [None]:
# to_show = 20
# shown = 0
# thr = 0.5

# for batch in val_loader_small:
#     with torch.no_grad():
#         vol = batch['volume'].to(device, non_blocking=True)
#         seg_logits, cls_logits = model(vol)
#         seg_prob = torch.sigmoid(seg_logits).float().cpu().numpy()  # [B,1,D,H,W]
#         volumes = batch['volume'].float().cpu().numpy()              # [B,1,D,H,W]
#     B = volumes.shape[0]
#     for i in range(B):
#         series_id = batch['series_id'][i]
#         v = volumes[i, 0]  # [D,H,W]
#         p = seg_prob[i, 0] # [D,H,W]

#         # choose best slice by max predicted prob
#         slice_idx = int(np.argmax(p.max(axis=(1, 2))))
#         img = v[slice_idx]
#         mask = p[slice_idx]

#         # overlay heatmap
#         overlay = (np.clip(img, 0, 1) * 255).astype(np.uint8)
#         heat = (np.clip(mask, 0, 1) * 255).astype(np.uint8)
#         heat_color = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
#         heat_color = cv2.cvtColor(heat_color, cv2.COLOR_BGR2RGB)
#         overlay_rgb = np.stack([overlay]*3, axis=-1)
#         alpha = 0.4
#         blended = ((1 - alpha) * overlay_rgb + alpha * heat_color).astype(np.uint8)

#         # thresholded contours
#         binmask = (mask > thr).astype(np.uint8)
#         contours, _ = cv2.findContours(binmask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
#         blended_contour = blended.copy()
#         cv2.drawContours(blended_contour, contours, -1, (0, 255, 0), 1)

#         # quick stats
#         has_an = int(batch['has_aneurysm'][i].item())
#         print(f"{series_id} | label={has_an} | slice={slice_idx} | maxprob={mask.max():.3f} | posfrac={binmask.mean():.4f} | n_comp={len(contours)}")

#         # show and save
#         plt.figure(figsize=(9, 3))
#         plt.subplot(1, 3, 1); plt.imshow(img, cmap='gray'); plt.axis('off'); plt.title('Slice')
#         plt.subplot(1, 3, 2); plt.imshow(blended); plt.axis('off'); plt.title('Heatmap overlay')
#         plt.subplot(1, 3, 3); plt.imshow(blended_contour); plt.axis('off'); plt.title(f'> {thr} contours')
#         plt.tight_layout()
#         out_path = f"/kaggle/working/viz_{series_id}_s{slice_idx}.png"
#         plt.savefig(out_path, dpi=120, bbox_inches='tight'); plt.show()

#         shown += 1
#         if shown >= to_show:
#             break
#     if shown >= to_show:
#         break

# print("Saved visualizations to /kaggle/working for later inspection.")

In [None]:
# from scipy import ndimage

# def quick_3d_summary(prob_3d, thr=0.5):
#     bin3d = (prob_3d > thr).astype(np.uint8)
#     if bin3d.max() == 0:
#         return dict(posfrac=0.0, n_comp=0, largest=0)
#     lab, n = ndimage.label(bin3d)
#     sizes = [(lab==i).sum() for i in range(1, n+1)]
#     return dict(posfrac=bin3d.mean(), n_comp=n, largest=max(sizes) if sizes else 0)

# # Example on first few from val_loader_small
# with torch.no_grad():
#     for batch in val_loader_small:
#         vol = batch['volume'].to(device)
#         seg_prob = torch.sigmoid(model(vol)[0]).float().cpu().numpy()
#         for i in range(seg_prob.shape[0]):
#             s = quick_3d_summary(seg_prob[i,0], thr=0.5)
#             print(batch['series_id'][i], s)
#         break