In [None]:
# ====================================================
# CELL 1: IMPORTS & CONFIG
# ====================================================

import os
# Mitigate CUDA memory fragmentation (must be set before torch import)
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
import shutil
import math
import time
import numpy as np
import pandas as pd
import polars as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torch.nn.functional as F
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
import cv2
import pydicom
import nibabel as nib
from scipy import ndimage
from scipy.ndimage import label, center_of_mass
from PIL import Image
from sklearn.model_selection import StratifiedKFold, train_test_split
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from sklearn.metrics import roc_auc_score
import kaggle_evaluation.rsna_inference_server
from collections import defaultdict
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import gc

# Thread limits for Kaggle (idempotent & safe).
# cv2 is already imported above, so configure it directly instead of re-importing it
# under an alias. setNumThreads(0) disables OpenCV's internal threading so it does not
# compete with DataLoader worker processes for cores.
try:
    cv2.setNumThreads(0)
except Exception:
    pass
# NOTE(review): numpy/torch are already imported at this point, so these env vars
# presumably only influence processes started later (e.g. DataLoader workers);
# torch.set_num_threads is what limits this process's intra-op thread pool.
try:
    os.environ.setdefault('OMP_NUM_THREADS', '1')
    os.environ.setdefault('MKL_NUM_THREADS', '1')
    torch.set_num_threads(1)
except Exception:
    pass



In [None]:
# Competition configuration. All settings are class attributes read as `Config.X`
# throughout the notebook (some via getattr with a default).
class Config:
    # Paths
    TRAIN_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train.csv'
    SERIES_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/series/'
    SEGMENTATION_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/segmentations/'

    # Stage 2 configuration
    ROI_SIZE = (224, 224)
    ROIS_PER_SERIES = 5
    BATCH_SIZE = 8  # Per-step (micro) batch; see GRAD_ACCUM_STEPS for the effective batch
    VAL_BATCH_SIZE = 8  # Matches training batch size
    GRAD_ACCUM_STEPS = 6  # Effective batch size: 8 × 6 = 48
    VAL_MAX_ENCODER_TOKENS = 384
    EPOCHS = 3  # Quick test run to verify predictions vary; raise for a full run
    LEARNING_RATE = 2e-4  # Deliberately high to break out of uniform predictions
    N_FOLDS = 1  # Single-fold quick verification run

    # Competition constants
    ID_COL = 'SeriesInstanceUID'
    LABEL_COLS = [
        'Left Infraclinoid Internal Carotid Artery', 'Right Infraclinoid Internal Carotid Artery',
        'Left Supraclinoid Internal Carotid Artery', 'Right Supraclinoid Internal Carotid Artery',
        'Left Middle Cerebral Artery', 'Right Middle Cerebral Artery', 'Anterior Communicating Artery',
        'Left Anterior Cerebral Artery', 'Right Anterior Cerebral Artery',
        'Left Posterior Communicating Artery', 'Right Posterior Communicating Artery',
        'Basilar Tip', 'Other Posterior Circulation', 'Aneurysm Present',
    ]

    # Device and training
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    MIXED_PRECISION = True
    MICRO_BATCH_SIZE = 8  # Matches BATCH_SIZE
    AUTO_TUNE_MICRO = False  # Disabled: can cause hangs with DataLoader workers
    SLICE_SUBSAMPLE_TRAIN = 0  # Disabled: the model expects a fixed 32 slices (24 causes a shape mismatch)
    STAGE2_CACHE_DIR = '/kaggle/working/stage2_cache'
    # Optional read-only ROI cache; used by ROIExtractor only when it is an existing directory.
    STAGE2_CACHE_EXTERNAL_DIR = ''
    DEBUG_MODE = False
    DEBUG_SAMPLES = 0
    REUSE_EXISTING_ROIS = True  # If a cached training_df exists, reuse it to skip long ROI extraction
    DIRECT_VOLUME_MODE = True   # Direct volume mode (top_example-style) using Stage 0 32x384x384 volumes
    # STAGE0_PREBUILT_ROOT = '/kaggle/input/rsna2025-v2-intracranial-aneurysm-detection-nb153/stage1_AneurysmNet_prebuilt_v2'
    NUM_WORKERS = 4  # previously 12
    PREFETCH_FACTOR = 2  # previously 8
    PIN_MEMORY = True
    PERSISTENT_WORKERS = True
    CACHE_VOLUMES = True
    USE_SHARD_LOADER = False  # Disabled: .npz shards are loaded directly, not via the shard_loader.py wrapper
    CACHE_DIR = ''  # Empty string disables the single-directory (writable) volume cache
    CACHE_MAX_GB = 10.0  # Soft cap (GiB) on CACHE_DIR contents, read by the preprocessor's cache-write path
    CACHE_DIRS = [
        '/kaggle/input/rsna2025-shard0-fp16nb153/stage2_cache_vols',
        '/kaggle/input/rsna2025-shard1-fp16nb153/stage2_cache_vols',
        '/kaggle/input/rsna2025-shard2-fp16nb153/stage2_cache_vols',
    ]
    # 'float16' → cached/decoded volumes are returned as float16 (anything else → float32);
    # 'uint8' additionally quantizes newly written legacy cache files.
    CACHE_DTYPE = 'float16'
    CACHE_VERBOSE = False  # Set True to log cache hits/writes/skips
    CACHE_LOG_EVERY_N = 100

    # Shard loading configuration
    SHARD_CHANNEL_MODE = 'cta'  # 'cta' for rsna2p5d, 'best3' for mil2p5d
    SHARD_DEPTH_SAMPLING = 'uniform'  # 'uniform', 'center_weighted', 'interpolate'
    SHARD_TARGET_SPATIAL = (384, 384)  # Will be overridden to (320, 320) for mil2p5d
    MODEL_DIRS = [
        '/kaggle/working',  # Where models are saved during training
        # '/kaggle/input/rsna2025-stage2-5fold-32ch-f16/pytorch/default/3'
    ]
    # Control training in local/dev mode
    TRAIN_ON_START = True
    # Early stopping
    EARLY_STOPPING_PATIENCE = 3

    # Validation schedule (fast vs full)
    FAST_VAL = True
    FAST_VAL_EVERY = 1
    FULL_VAL_EVERY = 3
    RUN_FULL_ON_EPOCH_1 = False
    FAST_VAL_SUBSET_FRAC = 0.33
    FAST_VAL_MAX_TOKENS = 256
    FAST_VAL_IMPROVE_EPS = 0.002

    # Architecture selection: 'rsna2p5d' or 'mil2p5d'
    MODEL_ARCH = 'mil2p5d'  # Current session trains the MIL 2.5D model

    # MIL transformer params (tuned for speed)
    MIL_BACKBONE = 'tf_efficientnet_b0'  # B0 is roughly 3-4x faster than B3
    MIL_SPATIAL_SIZE = 224  # Reduced from 320 for speed
    MIL_D_MODEL = 512  # Reduced from 768 (cheaper attention)
    MIL_NHEAD = 8
    MIL_N_LAYERS = 2
    MIL_USE_GRAD_CHECKPOINT = True  # Saves VRAM, allows larger batches
    MIL_COMPILE_ENCODER = False  # Set True if PyTorch >= 2.0 is available

    # Training extras
    USE_WEIGHTED_SAMPLER = False
    USE_EMA = False
    EMA_DECAY = 0.999

    # Hierarchical consistency loss (sites <= "Aneurysm Present")
    ENABLE_HIER_LOSS = False
    HIER_LOSS_LAMBDA = 0.2

    # Asymmetric Loss (ASL) for site logits
    USE_ASL = False
    ASL_GAMMA_POS = 0.0
    ASL_GAMMA_NEG = 4.0
    ASL_CLIP = 0.05
    ASL_WEIGHT_SITES = 1.0
    MAIN_BCE_WEIGHT = 1.0

    # Upper bound on class weights for imbalanced data (prevents extreme values such as 93.5).
    CLASS_WEIGHT_CAP = 20.0
    # This lets pos_weight reflect true imbalance (rare sites + 13× "Aneurysm Present")

print(f"✅ Configuration loaded - Device: {Config.DEVICE}")

# Throughput-oriented backend switches: cuDNN autotuning of conv algorithms plus
# TF32 matmuls (effective on GPUs that support TF32). Applied best-effort; once a
# knob raises, the remaining ones are skipped.
try:
    torch.backends.cudnn.benchmark = torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision('high')
except Exception:
    pass
    
# ====================================================
# CELL 1.1: LIGHTWEIGHT DICOM PREPROCESSOR (32x384x384)
# ====================================================

class DICOMPreprocessorKaggle:
    """Build a normalized (D, H, W) volume for one series; default shape (32, 384, 384).

    ``process_series`` tries, in order:
      1. ``{sid}.npz`` shards in ``Config.CACHE_DIRS`` (only when ``Config.CACHE_VOLUMES``);
      2. legacy ``{sid}_{D}x{H}x{W}_f16.npy`` caches in ``Config.CACHE_DIRS`` then
         ``Config.CACHE_DIR`` (only when ``Config.CACHE_VOLUMES``);
      3. decoding the raw DICOMs under the series directory. That result is then
         optionally written to ``Config.CACHE_DIR``.
    NPZ and DICOM results are scaled to [0, 1] and returned as float16 when
    ``Config.CACHE_DTYPE == 'float16'``, float32 otherwise. Legacy hits stored as
    float16 are returned as float16.
    """

    def __init__(self, target_shape=(32, 384, 384)):
        self.target_depth, self.target_height, self.target_width = target_shape

    # ------------------------------------------------------------------ helpers
    def _shape_tag(self) -> str:
        """Filename tag for legacy caches, e.g. '32x384x384' for the default shape."""
        return f"{self.target_depth}x{self.target_height}x{self.target_width}"

    @staticmethod
    def _as_cache_dtype(arr: np.ndarray) -> np.ndarray:
        """Cast to float16 when Config.CACHE_DTYPE == 'float16', else to float32."""
        if getattr(Config, 'CACHE_DTYPE', 'float16') == 'float16':
            return arr.astype(np.float16)
        return arr.astype(np.float32)

    # ------------------------------------------------------------- entry point
    def process_series(self, series_path: str) -> np.ndarray:
        """Return the preprocessed (target_depth, target_height, target_width) volume.

        Args:
            series_path: series directory. Its basename (trailing '/' stripped) is used
                as the series id for cache lookups.

        Returns:
            Volume array (see class docstring for dtype/range). When the directory holds
            no readable DICOMs, an all-zero volume in the configured dtype is returned.
        """
        sid = os.path.basename(series_path.rstrip('/'))

        if getattr(Config, 'CACHE_VOLUMES', False):
            vol = self._load_npz_shard(sid)
            if vol is None:
                vol = self._load_legacy_cache(sid)
            if vol is not None:
                return vol

        out = self._volume_from_dicoms(series_path)
        if out is None:
            # No DICOM files with pixel data; not cached. Dtype follows CACHE_DTYPE like
            # every other path (it used to be float32 unconditionally).
            empty = np.zeros((self.target_depth, self.target_height, self.target_width), dtype=np.float32)
            return self._as_cache_dtype(empty)

        self._save_cache(sid, out)
        return self._as_cache_dtype(out)

    # ------------------------------------------------------------ cache readers
    def _load_npz_shard(self, sid: str):
        """Load ``{sid}.npz`` from the first ``Config.CACHE_DIRS`` entry that reads cleanly.

        The shard's ``'vol'`` array is assumed to be (D, H, W, C) uint8 (the original
        notes say (64, 384, 384, 4)). Channel 0 (CTA) is kept, depth is uniformly
        subsampled to ``target_depth``, and H/W are resized if they differ from the
        target. Values are divided by 255. Returns None when no shard is found or loads.
        """
        verbose = getattr(Config, 'CACHE_VERBOSE', False)
        for cache_dir in getattr(Config, 'CACHE_DIRS', []):
            if not isinstance(cache_dir, str) or not len(cache_dir):
                continue
            npz_path = os.path.join(cache_dir, f"{sid}.npz")
            if not os.path.exists(npz_path):
                continue
            try:
                # NpzFile keeps the zip file open; the context manager closes it.
                # data['vol'] is fully read into memory before the file is closed.
                with np.load(npz_path) as data:
                    vol_u8 = data['vol']

                vol_u8_cta = vol_u8[..., 0]
                indices = np.linspace(0, vol_u8_cta.shape[0] - 1, self.target_depth).astype(np.int32)
                vol_u8_d = vol_u8_cta[indices]
                if vol_u8_d.shape[1:] != (self.target_height, self.target_width):
                    # Keep the output shape consistent with the DICOM path.
                    vol_u8_d = np.stack([
                        cv2.resize(sl, (self.target_width, self.target_height)) for sl in vol_u8_d
                    ])

                if getattr(Config, 'CACHE_DTYPE', 'float16') == 'float16':
                    vol = (vol_u8_d.astype(np.float16)) / 255.0
                else:
                    vol = (vol_u8_d.astype(np.float32)) / 255.0

                if verbose:
                    print(f"[NPZ] Loaded {sid} from {os.path.basename(cache_dir)}")
                return vol
            except Exception as e:
                if verbose:
                    print(f"[NPZ] Error loading {npz_path}: {e}")
        return None

    def _load_legacy_cache(self, sid: str):
        """Return a legacy ``{sid}_{D}x{H}x{W}_f16.npy`` cache hit, or None.

        Searches ``Config.CACHE_DIRS`` first, then ``Config.CACHE_DIR``. Any error stops
        the search and returns None, so the caller falls back to decoding DICOMs.
        """
        try:
            read_dirs = []
            cache_dirs = getattr(Config, 'CACHE_DIRS', [])
            if isinstance(cache_dirs, list):
                read_dirs.extend(d for d in cache_dirs if isinstance(d, str) and len(d))
            single_dir = getattr(Config, 'CACHE_DIR', '')
            if isinstance(single_dir, str) and len(single_dir):
                read_dirs.append(single_dir)

            for d in read_dirs:
                cache_path = os.path.join(d, f"{sid}_{self._shape_tag()}_f16.npy")
                if not os.path.exists(cache_path):
                    continue
                cached = np.load(cache_path, mmap_mode='r')
                # Keep FP16 on CPU to reduce bandwidth.
                if cached.dtype == np.float16 or getattr(Config, 'CACHE_DTYPE', 'float16') == 'float16':
                    vol = cached.astype(np.float16)
                else:
                    vol = cached.astype(np.float32)
                if getattr(Config, 'CACHE_VERBOSE', False):
                    self._cache_hit_count = getattr(self, '_cache_hit_count', 0) + 1
                    if self._cache_hit_count % max(1, getattr(Config, 'CACHE_LOG_EVERY_N', 50)) == 0:
                        print(f"[CACHE] hit: {d}/{os.path.basename(cache_path)}")
                return vol
        except Exception:
            pass
        return None

    # ------------------------------------------------------------ DICOM decode
    @staticmethod
    def _read_sorted_dicoms(series_path: str) -> list:
        """Read every ``*.dcm`` with pixel data under ``series_path`` (recursively).

        Slices are sorted by the projection of ImagePositionPatient onto the slice
        normal (row × column direction of the first file's ImageOrientationPatient).
        If that fails, they are sorted by InstanceNumber. Unreadable files are skipped.
        """
        dicoms = []
        for root, _, files in os.walk(series_path):
            for fname in files:
                if not fname.endswith('.dcm'):
                    continue
                try:
                    ds = pydicom.dcmread(os.path.join(root, fname), force=True)
                    if hasattr(ds, 'PixelData'):
                        dicoms.append(ds)
                except Exception:
                    continue
        if not dicoms:
            return dicoms

        try:
            orient = np.array(dicoms[0].ImageOrientationPatient, dtype=np.float32)
            normal = np.cross(orient[:3], orient[3:])

            def sort_key(ds):
                ipp = np.array(getattr(ds, 'ImagePositionPatient', [0, 0, 0]), dtype=np.float32)
                return float(np.dot(ipp, normal))

            return sorted(dicoms, key=sort_key)
        except Exception:
            return sorted(dicoms, key=lambda ds: getattr(ds, 'InstanceNumber', 0))

    def _resample_to_target(self, volf: np.ndarray) -> np.ndarray:
        """Resample a (D, H, W) float32 volume to the target shape.

        Uses torch trilinear interpolation. If that raises, it falls back to uniformly
        spaced depth indices plus a per-slice cv2 resize.
        """
        size = (self.target_depth, self.target_height, self.target_width)
        try:
            v = torch.from_numpy(volf)[None, None].to(dtype=torch.float32)  # (1,1,D,H,W)
            v = F.interpolate(v, size=size, mode='trilinear', align_corners=False)
            return v[0, 0].numpy().astype(np.float32)
        except Exception:
            D = volf.shape[0]
            idx = (np.linspace(0, max(D - 1, 0), num=self.target_depth).astype(int)
                   if D > 0 else np.zeros(self.target_depth, dtype=int))
            vT = volf[idx]
            out = np.empty(size, dtype=np.float32)
            for i in range(self.target_depth):
                out[i] = cv2.resize(vT[i], (self.target_width, self.target_height))
            return out

    def _volume_from_dicoms(self, series_path: str):
        """Decode, normalize and resample the series' DICOMs.

        Returns None when no DICOM with pixel data is found. Otherwise returns a float32
        target-shape volume.
        - CT slices are clipped to the window [-125, 225] (center 50, width 350, applied
          after RescaleSlope/Intercept) and scaled to [0, 1].
        - Other modalities are z-scored per slice, clipped to ±3 and mapped to [0, 1].
        - The resampled volume is then clipped to its 1st-99th percentile and rescaled.
        """
        dicoms = self._read_sorted_dicoms(series_path)
        if not dicoms:
            return None

        base_h = int(getattr(dicoms[0], 'Rows', 256))
        base_w = int(getattr(dicoms[0], 'Columns', 256))
        win_center, win_width = 50.0, 350.0
        lo, hi = win_center - win_width / 2.0, win_center + win_width / 2.0
        modality = (getattr(dicoms[0], 'Modality', '') or '').upper()

        slices = []
        for ds in dicoms:
            try:
                fr = ds.pixel_array
            except Exception:
                continue
            if fr.ndim >= 3:
                # Multi-frame object: flatten leading axes into individual frames.
                h, w2 = fr.shape[-2], fr.shape[-1]
                frames = fr.reshape(int(np.prod(fr.shape[:-2])), h, w2)
            else:
                frames = fr[np.newaxis, ...]
            for sl in frames:
                sl = sl.astype(np.float32)
                if getattr(ds, 'PhotometricInterpretation', 'MONOCHROME2') == 'MONOCHROME1':
                    sl = sl.max() - sl
                slope = float(getattr(ds, 'RescaleSlope', 1.0))
                intercept = float(getattr(ds, 'RescaleIntercept', 0.0))
                sl = sl * slope + intercept
                if sl.shape != (base_h, base_w):
                    sl = cv2.resize(sl, (base_w, base_h))
                if modality == 'CT':
                    s = np.clip(sl, lo, hi)
                    s = (s - lo) / (hi - lo + 1e-6)
                else:
                    mean = float(sl.mean())
                    std = float(sl.std() + 1e-6)
                    zc = 3.0
                    s = np.clip((sl - mean) / std, -zc, zc)
                    s = (s + zc) / (2.0 * zc)
                slices.append(s.astype(np.float32))

        volf = np.stack(slices, axis=0) if slices else np.zeros((1, base_h, base_w), dtype=np.float32)
        out = self._resample_to_target(volf)

        p1, p99 = np.percentile(out, [1, 99])
        if p99 > p1:
            out = np.clip(out, p1, p99)
            out = (out - p1) / (p99 - p1 + 1e-8)
        return np.nan_to_num(out, nan=0.0, posinf=1.0, neginf=0.0)

    # ------------------------------------------------------------ cache writer
    @staticmethod
    def _cache_over_cap(cache_dir: str) -> bool:
        """True when files directly in ``cache_dir`` exceed ``Config.CACHE_MAX_GB`` GiB.

        Returns False when no cap is configured, the directory does not exist yet, or
        the size check itself fails.
        """
        cap_gb = getattr(Config, 'CACHE_MAX_GB', None)
        if cap_gb is None or not os.path.isdir(cache_dir):
            return False
        try:
            total_bytes = 0
            for name in os.listdir(cache_dir):
                try:
                    total_bytes += os.path.getsize(os.path.join(cache_dir, name))
                except OSError:
                    pass
            return total_bytes > float(cap_gb) * (1024 ** 3)
        except Exception:
            return False

    def _save_cache(self, sid: str, out: np.ndarray) -> None:
        """Best-effort write of ``out`` to ``Config.CACHE_DIR`` as ``{sid}_{D}x{H}x{W}_{f16|u8}.npy``.

        Skipped when caching is off, ``CACHE_DIR`` is empty, or the size cap is already
        exceeded. Previously the cap's RuntimeError was swallowed by its own handler, so
        the cap never took effect. Write errors are ignored.

        NOTE(review): ``_u8.npy`` files are written but never read by
        ``_load_legacy_cache`` (it looks only for ``_f16.npy``).
        """
        cache_dir = getattr(Config, 'CACHE_DIR', '')
        if not getattr(Config, 'CACHE_VOLUMES', False) or not isinstance(cache_dir, str) or not cache_dir:
            return
        verbose = getattr(Config, 'CACHE_VERBOSE', False)
        try:
            if self._cache_over_cap(cache_dir):
                if verbose:
                    print("[CACHE] size cap reached; skip saving")
                return
            os.makedirs(cache_dir, exist_ok=True)
            cache_base = os.path.join(cache_dir, f"{sid}_{self._shape_tag()}")
            if getattr(Config, 'CACHE_DTYPE', 'float16') == 'uint8':
                # Quantize to 8-bit [0, 255] after normalization.
                arr8 = (out * 255.0).round().astype(np.uint8)
                try:
                    diff_max = float(np.max(np.abs(out - (arr8.astype(np.float32) / 255.0))))
                    if diff_max > 0.02:
                        # Quantization too lossy: store float16 instead.
                        np.save(cache_base + '_f16.npy', out.astype(np.float16), allow_pickle=False)
                        if verbose:
                            print(f"[CACHE] wrote f16 (fallback): {os.path.basename(cache_base)}_f16.npy")
                    else:
                        np.save(cache_base + '_u8.npy', arr8, allow_pickle=False)
                        if verbose:
                            print(f"[CACHE] wrote u8: {os.path.basename(cache_base)}_u8.npy (max_err={diff_max:.4f})")
                except Exception:
                    np.save(cache_base + '_u8.npy', arr8, allow_pickle=False)
                    if verbose:
                        print(f"[CACHE] wrote u8 (no-check): {os.path.basename(cache_base)}_u8.npy")
            else:
                np.save(cache_base + '_f16.npy', out.astype(np.float16), allow_pickle=False)
                if verbose:
                    print(f"[CACHE] wrote f16: {os.path.basename(cache_base)}_f16.npy")
        except Exception:
            pass  # Caching is an optimization; never fail preprocessing over it.

def process_dicom_series_safe(series_path: str, target_shape=(32,384,384)) -> np.ndarray:
    """Preprocess one series with a fresh ``DICOMPreprocessorKaggle``, then free memory.

    Any exception from construction or preprocessing propagates to the caller. Either
    way, cleanup runs afterwards: the CUDA cache is emptied when CUDA is available,
    then a garbage-collection pass. Cleanup failures are ignored.
    """
    try:
        return DICOMPreprocessorKaggle(target_shape).process_series(series_path)
    finally:
        cleanup_steps = (
            lambda: torch.cuda.empty_cache() if torch.cuda.is_available() else None,
            gc.collect,
        )
        for step in cleanup_steps:
            try:
                step()
            except Exception:
                pass


# ====================================================
# CELL 2: DATA LOADING & ROI EXTRACTION
# ====================================================

class Simple3DSegmentationNet(nn.Module):
    """Placeholder Stage-1 network that returns its input unchanged.

    Constructor arguments are accepted and ignored. NOTE(review): presumably kept so
    code that builds the Stage-1 model keeps working in direct-volume mode.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.dummy = nn.Identity()

    def forward(self, x):
        passthrough = self.dummy
        return passthrough(x)

class Stage1Predictor:
    """No-op Stage-1 predictor; constructor arguments are accepted and ignored."""

    def __init__(self, *args, **kwargs):
        pass

    def predict_segmentation_with_volume(self, series_path):
        """Return ``(seg_mask, volume)`` placeholders and ignore ``series_path``.

        The mask is a (1, 1, 1) zero array and the volume a (32, 384, 384) zero array,
        both float32. Not used in the direct-volume path; kept for compatibility.
        """
        placeholder_mask = np.zeros((1, 1, 1), dtype=np.float32)
        placeholder_volume = np.zeros((32, 384, 384), dtype=np.float32)
        return placeholder_mask, placeholder_volume

class SimpleDICOMProcessor:
    """Empty placeholder; the constructor accepts and ignores any arguments."""

    def __init__(self, *args, **kwargs):
        return None
        
class ROIExtractor:
    """Research-backed ROI extraction with adaptive count and quality filtering"""
    def __init__(self, stage1_predictor, roi_size=(224, 224)):
        self.stage1_predictor = stage1_predictor
        self.roi_size = roi_size
        self.processor = SimpleDICOMProcessor()

        # Research-backed thresholds
        # Relaxed thresholds to avoid over-pruning when Stage 1 is weak
        self.min_confidence_threshold = 0.15
        self.high_confidence_threshold = 0.5
        self.max_rois_per_series = getattr(Config, 'ROIS_PER_SERIES', 3)
        # Post-process controls
        self.border_margin = 2            # suppress edge activations near skull
        self.min_region_size = 6         # minimum connected component size (pixels)
        self.morph_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    def extract_top3_rois(self, series_path):
        """Extract 0-5 ROIs based on segmentation quality (research-backed)"""
        # Cache ROI results per series to avoid recomputation
        try:
            # Prefer external cache dir when available
            cache_root = getattr(Config, 'STAGE2_CACHE_EXTERNAL_DIR', '')
            if not isinstance(cache_root, str) or not os.path.isdir(cache_root):
                cache_root = Config.STAGE2_CACHE_DIR
            os.makedirs(cache_root, exist_ok=True)
            sid = os.path.basename(series_path)
            cache_path = os.path.join(cache_root, f"{sid}_rois.npy")
            if os.path.exists(cache_path):
                arr = np.load(cache_path, allow_pickle=True)
                return list(arr)
        except Exception:
            cache_path = None
        rois = self.extract_adaptive_rois(series_path)
        try:
            if cache_path is not None:
                np.save(cache_path, np.array(rois, dtype=object), allow_pickle=True)
        except Exception:
            pass
        return rois

    def extract_adaptive_rois(self, series_path):
        """Extract 0-5 ROIs based on segmentation quality (research-backed)"""
        try:
            print(f"🔍 DEBUG: Quality-based ROI extraction for {os.path.basename(series_path)}")
            
            # Get Stage 1 seg mask and the preprocessed volume (avoid reloading original DICOMs here)
            seg_mask, original_volume = self.stage1_predictor.predict_segmentation_with_volume(series_path)
            print(f"🔍 DEBUG: Segmentation mask shape: {seg_mask.shape}; Volume shape: {original_volume.shape}")
            
            # STEP 1: Assess overall segmentation quality
            seg_quality = self._assess_segmentation_quality(seg_mask)
            print(f"🔍 DEBUG: Segmentation quality score: {seg_quality:.3f}")
            
            # STEP 2: If segmentation is poor, still attempt candidate extraction; fallback only if none
            low_quality = seg_quality < self.min_confidence_threshold
            if low_quality:
                print(f"🔍 DEBUG: Low segmentation quality ({seg_quality:.3f} < {self.min_confidence_threshold}), attempting candidate extraction anyway")
            
            # STEP 4: Extract ROIs with confidence-based filtering
            roi_candidates = self._find_quality_based_rois(seg_mask, original_volume)
            
            if low_quality and not roi_candidates:
                print("🔍 DEBUG: No candidates under low-quality mask, using volume-based fallback")
                return self._get_quality_fallback_rois_from_volume(original_volume, self.max_rois_per_series)

            # STEP 5: Adaptive ROI count
            selected_rois = self._select_adaptive_rois(roi_candidates, seg_quality, original_volume)
            
            print(f"🔍 DEBUG: Selected {len(selected_rois)} ROIs based on quality assessment")
            return selected_rois
            
        except Exception as e:
            print(f"❌ Error in quality-based ROI extraction: {e}")
            return self._get_emergency_fallback_rois()
    
    def _assess_segmentation_quality(self, seg_mask):
        """Assess segmentation quality using connected components and border penalties."""
        try:
            D, H, W = seg_mask.shape
            largest_area_frac = 0.0
            largest_mean_conf = 0.0
            total_components = 0
            border_touch_penalty = 0.0

            for z in range(D):
                sm = seg_mask[z]
                # suppress borders
                sm_proc = sm.copy()
                sm_proc[:self.border_margin, :] = 0
                sm_proc[-self.border_margin:, :] = 0
                sm_proc[:, :self.border_margin] = 0
                sm_proc[:, -self.border_margin:] = 0

               # Adaptive thresholding based on actual max values
                max_val = float(sm_proc.max())
                if max_val > 0.3:
                    thr = max(0.05, 0.3 * max_val)
                elif max_val > 0.1:
                    thr = max(0.03, 0.4 * max_val)
                else:
                    thr = max(0.02, 0.5 * max_val)
                binmask = (sm_proc > thr).astype(np.uint8)
                if binmask.max() == 0:
                    continue
                # small opening to remove speckle
                binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, self.morph_kernel)

                labeled, n = label(binmask)
                if n == 0:
                    continue
                total_components += int(n)

                # evaluate components
                for comp_id in range(1, n + 1):
                    comp = (labeled == comp_id)
                    comp_size = int(comp.sum())
                    if comp_size < self.min_region_size:
                        continue
                    mean_conf = float(sm[comp].mean())
                    area_frac = comp_size / float(H * W)
                    if area_frac > largest_area_frac:
                        largest_area_frac = area_frac
                    if mean_conf > largest_mean_conf:
                        largest_mean_conf = mean_conf

                    # simple border-touch penalty if component abuts image edge
                    ys, xs = np.where(comp)
                    if ys.size > 0:
                        if (ys.min() <= self.border_margin or ys.max() >= H - self.border_margin - 1 or
                            xs.min() <= self.border_margin or xs.max() >= W - self.border_margin - 1):
                            border_touch_penalty += 0.02

            # compose quality score
            area_score = min(largest_area_frac / 0.02, 1.0)  # cap around ~2% of slice (aneurysm-sized)
            comp_penalty = min(0.1, 0.0015 * total_components) + min(0.1, border_touch_penalty)
            quality_score = max(0.0, 0.6 * largest_mean_conf + 0.4 * area_score - comp_penalty)

            # robust floor based on global mask stats to avoid spurious 0.0 quality
            max_val = float(seg_mask.max())
            mean_val = float(seg_mask.mean())
            if max_val >= 0.55:
                quality_score = max(quality_score, 0.35)
            elif max_val >= 0.45:
                quality_score = max(quality_score, 0.25)
            elif mean_val >= 0.25:
                quality_score = max(quality_score, 0.22)

            return float(quality_score)
        except Exception:
            return 0.1
    
    def _find_quality_based_rois(self, seg_mask, original_volume):
        """Find ROI candidates with confidence scores (no hardcoded count)"""
        print("🔍 DEBUG: Finding quality-based ROI candidates...")
        
        # Resize segmentation mask to match original volume
        if seg_mask.shape != original_volume.shape:
            print("🔍 DEBUG: Resizing segmentation mask with cv2...")
            seg_mask_resized = np.zeros(original_volume.shape, dtype=np.float32)
            for i in range(min(seg_mask.shape[0], original_volume.shape[0])):
                if i < seg_mask.shape[0]:
                    resized_slice = cv2.resize(
                        seg_mask[i],
                        (original_volume.shape[2], original_volume.shape[1])
                    )
                    seg_mask_resized[i] = resized_slice
        else:
            seg_mask_resized = seg_mask
        
        # 3D peak proposals first (relative peak logic; does not lower thresholds)
        roi_candidates = self._proposals_from_3d_peaks(seg_mask_resized)
        if len(roi_candidates) == 0:
            # Fall back to 2D slice-wise CC method
            roi_candidates = []
        
        H, W = original_volume.shape[1], original_volume.shape[2]
        for slice_idx in range(seg_mask_resized.shape[0]):
            slice_mask = seg_mask_resized[slice_idx].copy()

            # Suppress borders to avoid skull/edge activations
            slice_mask[:self.border_margin, :] = 0
            slice_mask[-self.border_margin:, :] = 0
            slice_mask[:, :self.border_margin] = 0
            slice_mask[:, -self.border_margin:] = 0

            # Adaptive dynamic threshold tied to local max (aligned with quality assessment)
            max_val = float(slice_mask.max())
            if max_val > 0.2:
                thr = max(self.min_confidence_threshold, 0.25 * max_val)
            elif max_val > 0.1:
                thr = max(0.03, 0.30 * max_val)
            else:
                thr = max(0.02, 0.25 * max_val)
            high_conf_regions = (slice_mask > thr).astype(np.uint8)
            if high_conf_regions.max() == 0:
                # Percentile-based fallback with small dilation to form blobs
                p90 = float(np.percentile(slice_mask, 90))
                if p90 > 0:
                    mask_peaks = (slice_mask >= p90).astype(np.uint8)
                    # small dilation to merge nearby high pixels
                    mask_peaks = cv2.dilate(mask_peaks, self.morph_kernel, iterations=1)
                    labeled_regions, num_regions = label(mask_peaks)
                    for region_id in range(1, num_regions + 1):
                        region_mask = (labeled_regions == region_id)
                        region_size = int(region_mask.sum())
                        if region_size < 3:
                            continue
                        ys, xs = np.where(region_mask)
                        if ys.size == 0:
                            continue
                        # Skip borders
                        if (ys.min() <= self.border_margin or ys.max() >= H - self.border_margin - 1 or
                            xs.min() <= self.border_margin or xs.max() >= W - self.border_margin - 1):
                            continue
                        com = center_of_mass(region_mask)
                        y, x = int(com[0]), int(com[1])
                        region_confidence = float(slice_mask[region_mask].mean())
                        roi_candidates.append({
                            'slice_idx': slice_idx,
                            'y': y,
                            'x': x,
                            'confidence': region_confidence,
                            'region_size': region_size
                        })
                continue
            # Apply opening only if region is sufficiently large; avoid eroding tiny blobs
            if int(high_conf_regions.sum()) > 50:
                high_conf_regions = cv2.morphologyEx(high_conf_regions, cv2.MORPH_OPEN, self.morph_kernel)

            labeled_regions, num_regions = label(high_conf_regions)
            for region_id in range(1, num_regions + 1):
                region_mask = (labeled_regions == region_id)
                region_size = int(region_mask.sum())
                if region_size < self.min_region_size:
                    continue
                ys, xs = np.where(region_mask)
                if ys.size == 0:
                    continue
                # Skip border-touching components
                if (ys.min() <= self.border_margin or ys.max() >= H - self.border_margin - 1 or
                    xs.min() <= self.border_margin or xs.max() >= W - self.border_margin - 1):
                    continue

                com = center_of_mass(region_mask)
                y, x = int(com[0]), int(com[1])
                region_confidence = float(slice_mask[region_mask].mean())

                roi_candidates.append({
                    'slice_idx': slice_idx,
                    'y': y,
                    'x': x,
                    'confidence': region_confidence,
                    'region_size': region_size
                })
        
        # Sort by confidence (descending)
        if not roi_candidates:
            # Volume-wise peak fallback: pick top maxima per slice (excluding borders)
            print("🔍 DEBUG: No ROI components found; using volume-wise peak fallback")
            D = seg_mask_resized.shape[0]
            peak_candidates = []
            for z in range(D):
                m = seg_mask_resized[z].copy()
                # suppress borders
                m[:self.border_margin, :] = 0
                m[-self.border_margin:, :] = 0
                m[:, :self.border_margin] = 0
                m[:, -self.border_margin:] = 0
                yx = np.unravel_index(np.argmax(m), m.shape)
                y, x = int(yx[0]), int(yx[1])
                conf = float(m[y, x])
                if conf > 0:
                    peak_candidates.append({
                        'slice_idx': z,
                        'y': y,
                        'x': x,
                        'confidence': conf,
                        'region_size': 1
                    })
            # Keep strongest few peaks across volume
            peak_candidates.sort(key=lambda c: c['confidence'], reverse=True)
            roi_candidates.extend(peak_candidates[: max( self.max_rois_per_series * 3, 6)])

        roi_candidates.sort(key=lambda x: x['confidence'], reverse=True)
        
        print(f"🔍 DEBUG: Found {len(roi_candidates)} ROI candidates")
        return roi_candidates

    def _proposals_from_3d_peaks(self, seg_mask_zyx: np.ndarray):
        """3D local-max proposals with seeded relative growth (no absolute threshold lowering)."""
        try:
            D, H, W = seg_mask_zyx.shape
            # Light 3D smoothing to stabilize local maxima
            try:
                sm = ndimage.gaussian_filter(seg_mask_zyx.astype(np.float32), sigma=0.75)
            except Exception:
                sm = seg_mask_zyx.astype(np.float32)
            # 3D local maxima via maximum filter
            footprint = np.ones((3,3,3), dtype=np.uint8)
            max_f = ndimage.maximum_filter(sm, footprint=footprint, mode='nearest')
            peaks = (sm == max_f)
            # Suppress borders
            b = self.border_margin
            if b > 0:
                peaks[:, :b, :] = False; peaks[:, -b:, :] = False
                peaks[:, :, :b] = False; peaks[:, :, -b:] = False
            coords = np.argwhere(peaks)
            if coords.shape[0] == 0:
                return []
            # Rank peaks by value and keep top-K to control cost
            values = sm[peaks]
            order = np.argsort(values)[::-1]
            top_k = min(64, order.size)
            selected = coords[order[:top_k]]
            # Non-maximum suppression by 3D distance
            kept = []
            min_dist = 4.0
            for (cz, cy, cx) in selected:
                if any(((cz-kz)**2 + (cy-ky)**2 + (cx-kx)**2) ** 0.5 < min_dist for kz,ky,kx in kept):
                    continue
                kept.append((int(cz), int(cy), int(cx)))
                if len(kept) >= 64:
                    break
            # Seeded relative growth
            proposals = []
            for cz, cy, cx in kept:
                peak = float(sm[cz, cy, cx])
                if peak <= 0:
                    continue
                rel_thr = max(0.6*peak, 1e-6)  # relative to each peak
                # collect voxels that descend from the peak (thresholded region)
                region = sm >= rel_thr
                labeled, num = ndimage.label(region)
                cid = int(labeled[cz, cy, cx])
                if cid == 0:
                    continue
                comp = (labeled == cid)
                size = int(comp.sum())
                if size < self.min_region_size:
                    continue
                # score = peak * mean(comp)
                conf = peak * float(sm[comp].mean() + 1e-6)
                # project to a representative slice (peak slice)
                ys, xs = np.where(comp[cz])
                if ys.size == 0:
                    # fallback to COM over full comp
                    zc, yc, xc = ndimage.center_of_mass(comp)
                    zc = int(round(zc)); yc = int(round(yc)); xc = int(round(xc))
                    if yc <= self.border_margin or yc >= H - self.border_margin - 1 or xc <= self.border_margin or xc >= W - self.border_margin - 1:
                        continue
                    proposals.append({
                        'slice_idx': int(zc),
                        'y': int(yc),
                        'x': int(xc),
                        'confidence': float(conf),
                        'region_size': size,
                    })
                else:
                    y = int(ys.mean()); x = int(xs.mean())
                    if y <= self.border_margin or y >= H - self.border_margin - 1 or x <= self.border_margin or x >= W - self.border_margin - 1:
                        continue
                    proposals.append({
                        'slice_idx': int(cz),
                        'y': y,
                        'x': x,
                        'confidence': float(conf),
                        'region_size': size,
                    })
            proposals.sort(key=lambda c: c['confidence'], reverse=True)
            return proposals
        except Exception:
            return []
    
    def _select_adaptive_rois(self, roi_candidates, seg_quality, original_volume):
        """Adaptively select ROIs based on segmentation quality (research-backed)"""
        if not roi_candidates:
            print("🔍 DEBUG: No candidates found, using fallback")
            return self._get_quality_fallback_rois_from_volume(original_volume)
        
        # Adaptive selection based on segmentation quality
        if seg_quality >= self.high_confidence_threshold:
            max_rois = self.max_rois_per_series
            min_confidence = 0.3
        elif seg_quality >= self.min_confidence_threshold + 0.2:
            max_rois = self.max_rois_per_series
            min_confidence = 0.2
        else:
            max_rois = self.max_rois_per_series
            min_confidence = 0.05
        
        # Filter and select ROIs
        filtered = [c for c in roi_candidates if c['confidence'] >= min_confidence]
        selected_candidates = filtered[:max_rois]
        # If not enough, top-off with next best candidates
        if len(selected_candidates) < max_rois:
            for c in roi_candidates:
                if c in selected_candidates:
                    continue
                selected_candidates.append(c)
                if len(selected_candidates) >= max_rois:
                    break
        
        # Convert to ROI format
        rois = []
        for i, candidate in enumerate(selected_candidates):
            roi_patch = self._extract_roi_patch(
                original_volume,
                candidate['slice_idx'], 
                candidate['y'], 
                candidate['x']
            )
            
            rois.append({
                'roi_image': roi_patch,
                'slice_idx': candidate['slice_idx'],
                'coordinates': (candidate['y'], candidate['x']),
                'confidence': candidate['confidence'],
                'roi_id': i
            })
        # Ensure at least max_rois via center-based fallback if still short
        if len(rois) < self.max_rois_per_series:
            needed = self.max_rois_per_series - len(rois)
            center_fallbacks = self._get_quality_fallback_rois_from_volume(original_volume, needed)
            rois.extend(center_fallbacks)
        print(f"🔍 DEBUG: Adaptively selected {len(rois)} ROIs (quality: {seg_quality:.3f})")
        return rois[: self.max_rois_per_series]
    
    def _get_quality_fallback_rois(self, series_path, seg_mask):
        """Fallback for poor segmentation quality: generate multiple center-based ROIs"""
        print("🔍 DEBUG: Using quality-aware fallback (multi-center ROIs)")
        original_volume = self._load_efficient_volume(series_path)
        return self._get_quality_fallback_rois_from_volume(original_volume, self.max_rois_per_series)

    def _get_quality_fallback_rois_from_volume(self, original_volume, count: int = 3):
        D, H, W = original_volume.shape
        # Choose slice indices: center and quartiles
        slices = sorted(set([D // 2, max(0, D // 4), min(D - 1, 3 * D // 4)]))
        # Ensure desired count
        while len(slices) < count:
            # Add random slices if needed
            slices.append(np.random.randint(0, D))
            slices = list(dict.fromkeys(slices))
        rois = []
        cy, cx = H // 2, W // 2
        for i, s in enumerate(slices[:count]):
            roi_patch = self._extract_roi_patch(original_volume, s, cy, cx)
            rois.append({
                'roi_image': roi_patch,
                'slice_idx': s,
                'coordinates': (cy, cx),
                'confidence': 0.2,
                'roi_id': i
            })
        return rois
    
    def _get_simple_fallback_rois(self):
        """Return a single placeholder ROI for when no quality ROI could be found.

        The "image" is uniform random noise of shape (*Config.ROI_SIZE, 3), float32,
        with fixed metadata (slice 25, coordinates (128, 128), confidence 0.1, id 0).
        NOTE(review): a noise image paired with real series labels is a weak training
        signal; confirm this placeholder is never used for training.
        """
        print("🔍 DEBUG: Using simple fallback (single center ROI)")
        noise_shape = tuple(Config.ROI_SIZE) + (3,)
        placeholder = {
            'roi_image': np.random.random(noise_shape).astype(np.float32),
            'slice_idx': 25,
            'coordinates': (128, 128),
            'confidence': 0.1,
            'roi_id': 0,
        }
        return [placeholder]
    
    def _get_emergency_fallback_rois(self):
        """Last-resort placeholder ROI for when ROI extraction fails entirely.

        Same shape as the simple fallback — random float32 noise of shape
        (*Config.ROI_SIZE, 3) — but tagged with slice 0, coordinates (128, 128),
        confidence 0.1 and id 0.
        """
        print("🔍 DEBUG: Using emergency fallback ROI")
        noise_shape = tuple(Config.ROI_SIZE) + (3,)
        placeholder = {
            'roi_image': np.random.random(noise_shape).astype(np.float32),
            'slice_idx': 0,
            'coordinates': (128, 128),
            'confidence': 0.1,
            'roi_id': 0,
        }
        return [placeholder]

    
    def _load_efficient_volume(self, series_path):
        """Load a series as a float32 volume of at most 50 slices at 256x256, scaled to [0, 1].

        DICOM files are put in scan order before sub-sampling. The order comes from
        ImagePositionPatient z, else InstanceNumber, else file name; ``os.listdir``
        order is arbitrary. Because of this, neighbouring indices are adjacent slices,
        which ``_extract_roi_patch`` relies on when stacking s-1, s, s+1 as channels.
        NOTE(review): IPP z is used as the scan axis; for strongly oblique acquisitions,
        projecting IPP on the slice normal would be more exact.

        Intensities are clipped to the 1st-99th percentile and min-max scaled. The
        result is cached as ``<STAGE2_CACHE_DIR>/<series_id>_vol_sorted.npy``. The key
        differs from the old ``_vol.npy`` so that volumes cached in unsorted order are
        not reused.

        Returns:
            np.ndarray (n_slices, 256, 256) float32. If no slice can be read, a random
            (50, 256, 256) placeholder is returned (pre-existing behaviour).
        """
        max_slices = 50
        target_hw = (256, 256)
        try:
            os.makedirs(Config.STAGE2_CACHE_DIR, exist_ok=True)
            sid = os.path.basename(os.path.normpath(series_path))
            vcache = os.path.join(Config.STAGE2_CACHE_DIR, f"{sid}_vol_sorted.npy")
            if os.path.exists(vcache):
                return np.load(vcache, allow_pickle=False)

            def _scan_order_key(fname):
                # Header-only read; separate tiers so files ordered by different tags never interleave.
                try:
                    hdr = pydicom.dcmread(os.path.join(series_path, fname), stop_before_pixels=True, force=True)
                    ipp = getattr(hdr, 'ImagePositionPatient', None)
                    if ipp is not None and len(ipp) >= 3:
                        return (0, float(ipp[2]), fname)
                    inst = getattr(hdr, 'InstanceNumber', None)
                    if inst is not None:
                        return (1, float(inst), fname)
                except Exception:
                    pass  # unreadable header: order by file name only
                return (2, 0.0, fname)

            dicom_files = sorted((f for f in os.listdir(series_path) if f.endswith('.dcm')), key=_scan_order_key)

            # SMART SAMPLING: distribute up to 50 slices evenly across the ordered series
            total_files = len(dicom_files)
            if total_files > max_slices:
                step = total_files / max_slices
                selected_files = [dicom_files[int(i * step)] for i in range(max_slices)]
                print(f"🔍 DEBUG: Smart sampling - selected {len(selected_files)} files from {total_files} total (every {step:.1f})")
            else:
                selected_files = dicom_files
                print(f"🔍 DEBUG: Using all {len(selected_files)} files (less than 50)")

            slices = []
            for fname in selected_files:
                try:
                    arr = pydicom.dcmread(os.path.join(series_path, fname), force=True).pixel_array
                except Exception:
                    continue  # unreadable file or no pixel data: skip this slice
                if arr.ndim != 2:
                    continue  # multi-frame / colour data is not handled here
                arr = arr.astype(np.float32)
                if arr.shape != target_hw:
                    # cv2 dsize is (width, height)
                    arr = cv2.resize(arr, (target_hw[1], target_hw[0]))
                slices.append(arr)

            if slices:
                volume = np.stack(slices, axis=0)
                p1, p99 = np.percentile(volume, [1, 99])
                volume = np.clip(volume, p1, p99)
                # Cast once so fresh and cached loads share the same dtype (NumPy 2 would
                # otherwise promote via the float64 percentile scalars).
                volume = ((volume - p1) / (p99 - p1 + 1e-8)).astype(np.float32)
                try:
                    np.save(vcache, volume, allow_pickle=False)
                except Exception as e:
                    print(f"⚠️ Could not cache volume {vcache}: {e}")
                return volume

        except Exception as e:
            print(f"Error loading efficient volume: {e}")

        # Fallback volume (matches our smart sampling approach)
        return np.random.random((50, 256, 256)).astype(np.float32)

    
    def _extract_roi_patch(self, volume, slice_idx, center_y, center_x):
        """Extract an ROI with adjacent-slice context as RGB channels (s-1, s, s+1).

        A window of Config.ROI_SIZE centred on (center_y, center_x) is cut from each of
        the three slices. It is truncated at the image border and resized back to
        Config.ROI_SIZE. The first/last slice repeats itself as its missing neighbour.

        ``slice_idx`` and the center are clamped into the volume. Candidate coordinates
        may come from a mask with a different depth or resolution. Without clamping, a
        negative index silently wrapped and an index past the end raised. An
        out-of-image center also produced an empty patch that crashed cv2.resize.
        In-range inputs give exactly the previous result.

        Args:
            volume: array indexed (depth, height, width).
            slice_idx, center_y, center_x: ROI location in volume coordinates.

        Returns:
            np.ndarray of shape (ROI_SIZE[0], ROI_SIZE[1], 3).
        """
        D, H, W = volume.shape
        out_h, out_w = int(Config.ROI_SIZE[0]), int(Config.ROI_SIZE[1])
        s = min(max(int(slice_idx), 0), D - 1)
        cy = min(max(int(center_y), 0), H - 1)
        cx = min(max(int(center_x), 0), W - 1)

        half_h, half_w = out_h // 2, out_w // 2
        y1, y2 = max(0, cy - half_h), min(H, cy + half_h)
        x1, x2 = max(0, cx - half_w), min(W, cx + half_w)

        channels = []
        for z in (max(0, s - 1), s, min(D - 1, s + 1)):
            patch = volume[z, y1:y2, x1:x2]
            # cv2 dsize is (width, height)
            channels.append(cv2.resize(patch, (out_w, out_h)))
        return np.stack(channels, axis=2)
    

def _direct_volume_training_df(df):
    """One record per series pointing at the on-the-fly 32x384x384 volume (no ROI files)."""
    records = []
    for _, row in df.iterrows():
        sid = str(row[Config.ID_COL])
        rec = {
            'roi_id': f"{sid}_vol32",
            'roi_path': '',
            'series_id': sid,
            'roi_confidence': 1.0,
            'slice_idx': -1,
        }
        for col in Config.LABEL_COLS:
            rec[col] = row[col]
        records.append(rec)
    return pd.DataFrame(records)


def _load_cached_training_df(path):
    """Return the ROI dataframe cached at `path` if it is non-empty and has the expected columns, else None."""
    if not path or not os.path.exists(path):
        return None
    required = ['roi_path', 'roi_id', 'series_id'] + list(Config.LABEL_COLS)
    try:
        cached = pl.read_parquet(path).to_pandas()
    except Exception as e:
        print(f"⚠️ Could not read cached ROI dataframe {path}: {e}")
        return None
    if len(cached) > 0 and all(c in cached.columns for c in required):
        return cached
    return None


def _save_training_df(training_df, path):
    """Best-effort parquet write of the ROI dataframe; returns True on success."""
    try:
        pl.from_pandas(training_df).write_parquet(path)
        return True
    except Exception as e:
        print(f"⚠️ Could not write ROI dataframe cache {path}: {e}")
        return False


def _training_df_from_external_rois(df, ext_dir):
    """Rebuild training records from `{series_id}_roi_{k}.png` files in `ext_dir`; None if nothing usable."""
    candidates = sorted(f for f in os.listdir(ext_dir) if f.lower().endswith('.png'))
    if not candidates:
        return None
    print(f"🧩 Building training_df from external ROI images: {len(candidates)} files")
    label_cols = list(Config.LABEL_COLS)
    # Map labels by series_id for fast join. Duplicate ids would make to_dict('index') raise.
    df_labels = df[[Config.ID_COL] + label_cols].copy()
    df_labels[Config.ID_COL] = df_labels[Config.ID_COL].astype(str)
    df_labels = df_labels.drop_duplicates(subset=Config.ID_COL)
    label_map = df_labels.set_index(Config.ID_COL).to_dict('index')

    records = []
    for fname in candidates:
        base = os.path.splitext(fname)[0]
        if '_roi_' not in base:
            continue
        # Split on the LAST '_roi_' so the ROI number is always the suffix.
        series_id, roi_part = base.rsplit('_roi_', 1)
        try:
            roi_id_int = int(roi_part)
        except ValueError:
            roi_id_int = 0
        labs = label_map.get(series_id)
        if labs is None:
            continue  # series not in this dataframe
        rec = {
            'roi_id': f"{series_id}_roi_{roi_id_int}",
            'roi_path': os.path.join(ext_dir, fname),
            'series_id': series_id,
            'roi_confidence': 0.2,
            'slice_idx': -1,
        }
        for col in label_cols:
            rec[col] = labs[col]
        records.append(rec)
    return pd.DataFrame(records) if records else None


def create_training_data(df, stage1_predictor):
    """Create the classification training dataframe (one row per ROI, or per series in direct-volume mode).

    Resolution order:
      1. Config.DIRECT_VOLUME_MODE: one record per series; the dataset reads DICOMs on the fly.
      2. Config.REUSE_EXISTING_ROIS:
         a. `rois/training_df.parquet` (working cache);
         b. `<ROIS_EXTERNAL_DIR>/training_df.parquet` (copied into the working cache);
         c. a dataframe rebuilt from `<ROIS_EXTERNAL_DIR>/*_roi_*.png`.
         (b) and (c) are skipped when ROIS_EXTERNAL_DIR is empty or None. Previously an
         empty value probed `./training_df.parquet` and None crashed.
      3. Otherwise, run ROIExtractor per series. Each ROI is saved as
         `rois/{series_id}_roi_{k}.png` and the dataframe is cached.

    Args:
        df: series-level dataframe with Config.ID_COL and Config.LABEL_COLS.
        stage1_predictor: passed to ROIExtractor.

    Returns:
        pandas DataFrame with roi_id, roi_path, series_id, roi_confidence, slice_idx
        and the label columns.
    """
    print("🔄 Extracting ROIs for training data...")

    if getattr(Config, 'DIRECT_VOLUME_MODE', False):
        print("✅ DIRECT_VOLUME_MODE: using Stage 0 32x384x384 volumes (no ROI extraction)")
        training_df = _direct_volume_training_df(df)
        print(f"✅ DIRECT_VOLUME_MODE: built {len(training_df)} samples from {len(df)} series")
        return training_df

    cache_dir = 'rois'
    os.makedirs(cache_dir, exist_ok=True)
    cached_df_path_parquet = os.path.join(cache_dir, 'training_df.parquet')
    ext_dir = getattr(Config, 'ROIS_EXTERNAL_DIR', '') or ''

    if Config.REUSE_EXISTING_ROIS:
        cached = _load_cached_training_df(cached_df_path_parquet)
        if cached is not None:
            print(f"✅ Reusing cached training ROIs (working): {len(cached)} samples from {cached['series_id'].nunique()} series")
            return cached
        if ext_dir:
            cached = _load_cached_training_df(os.path.join(ext_dir, 'training_df.parquet'))
            if cached is not None:
                print(f"✅ Reusing cached training ROIs (external): {len(cached)} samples from {cached['series_id'].nunique()} series")
                _save_training_df(cached, cached_df_path_parquet)
                return cached
        if ext_dir and os.path.isdir(ext_dir):
            try:
                rebuilt = _training_df_from_external_rois(df, ext_dir)
            except Exception as e:
                print(f"⚠️ Could not rebuild ROI dataframe from {ext_dir}: {e}")
                rebuilt = None
            if rebuilt is not None:
                print(f"✅ Reconstructed training ROIs (external): {len(rebuilt)} samples from {rebuilt['series_id'].nunique()} series")
                _save_training_df(rebuilt, cached_df_path_parquet)
                return rebuilt

    roi_extractor = ROIExtractor(stage1_predictor)
    training_data = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Extracting ROIs"):
        series_id = str(row[Config.ID_COL])
        series_path = os.path.join(Config.SERIES_DIR, series_id)
        if not os.path.exists(series_path):
            continue

        rois = roi_extractor.extract_top3_rois(series_path)
        # Number ROIs by position: extractor ids can repeat (padded fallbacks restart
        # at 0), which would overwrite PNGs and duplicate roi_path rows.
        for k, roi_data in enumerate(rois):
            roi_filename = os.path.join(cache_dir, f"{series_id}_roi_{k}.png")
            # Clip before the cast so values slightly outside [0, 1] cannot wrap around uint8.
            roi_image = np.clip(np.asarray(roi_data['roi_image'], dtype=np.float32) * 255.0, 0, 255).astype(np.uint8)
            Image.fromarray(roi_image).save(roi_filename)

            sample = {
                'roi_id': f"{series_id}_roi_{k}",
                'roi_path': roi_filename,
                'series_id': series_id,
                'roi_confidence': roi_data['confidence'],
                'slice_idx': roi_data['slice_idx']
            }
            for col in Config.LABEL_COLS:
                sample[col] = row[col]
            training_data.append(sample)

    training_df = pd.DataFrame(training_data)
    print(f"✅ Created {len(training_df)} training samples from {len(df)} series")
    if _save_training_df(training_df, cached_df_path_parquet):
        print(f"💾 Saved training ROI dataframe → {cached_df_path_parquet}")
    return training_df

# Progress marker: printed when this cell's ROI-extraction / data-loading definitions have executed.
print("✅ Data loading and ROI extraction functions loaded")



# ====================================================
# CELL 3: DATASET, LOSS & MODEL DEFINITIONS
# ====================================================

class AneurysmClassificationDataset(Dataset):
    """Dataset for classification. In direct volume mode, builds 32x384x384 on-the-fly from DICOMs.

    Mode is chosen by Config.DIRECT_VOLUME_MODE:
      * direct volume: each row's series (Config.SERIES_DIR/<series_id>) is read with
        process_dicom_series_safe at target_shape (32, 384, 384). In train mode a random,
        depth-ordered subset of Config.SLICE_SUBSAMPLE_TRAIN slices is kept when set.
      * ROI images: each row's 'roi_path' PNG is read as RGB and passed through an
        Albumentations pipeline (flip/rotate/jitter + ImageNet normalisation for train;
        normalisation only otherwise).

    Items are dicts: 'image' (tensor), 'labels' (float32 over Config.LABEL_COLS),
    'roi_id', and 'confidence' (float32 tensor of row['roi_confidence']).
    """
    def __init__(self, df, mode='train'):
        self.df = df
        self.mode = mode
        self.direct_volume = getattr(Config, 'DIRECT_VOLUME_MODE', False)
        if self.direct_volume:
            # NOTE(review): __getitem__ uses process_dicom_series_safe, not this object;
            # kept in case other code reads `dataset.preprocessor`.
            self.preprocessor = DICOMPreprocessorKaggle(target_shape=(32, 384, 384))
            self.alb_transform = None  # Skip Albumentations in direct-volume path
        else:
            # ROI image pipeline (3-channel PNGs) - keep minimal and albumentations-based
            normalize = [
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(transpose_mask=False),
            ]
            if mode == 'train':
                self.alb_transform = A.Compose([
                    A.HorizontalFlip(p=0.5),
                    A.VerticalFlip(p=0.5),
                    A.Rotate(limit=15, p=0.5),
                    A.ColorJitter(brightness=0.2, contrast=0.2, p=0.5),
                    *normalize,
                ])
            else:
                self.alb_transform = A.Compose(normalize)

    def __len__(self):
        return len(self.df)

    def _load_volume(self, row):
        """Read the row's series as a depth-first tensor, optionally slice-subsampled in train mode."""
        series_path = os.path.join(Config.SERIES_DIR, str(row['series_id']))
        arr = process_dicom_series_safe(series_path, target_shape=(32, 384, 384))
        # Train-time slice subsample: keep quality by using full slices at val/test
        if self.mode == 'train':
            k = int(getattr(Config, 'SLICE_SUBSAMPLE_TRAIN', 0) or 0)
            if k and arr.shape[0] > k:
                keep = np.sort(np.random.choice(arr.shape[0], size=k, replace=False))
                arr = arr[keep]
        # Keep CPU tensor in fp16 if cache dtype is fp16
        if getattr(Config, 'CACHE_DTYPE', 'float16') == 'float16' and arr.dtype == np.float16:
            return torch.as_tensor(arr, dtype=torch.float16)
        return torch.as_tensor(arr, dtype=torch.float32)

    def _load_roi_image(self, row):
        """Read the row's ROI PNG as RGB and apply the Albumentations pipeline."""
        try:
            with Image.open(row['roi_path']) as pil_img:
                np_img = np.array(pil_img.convert('RGB'))
        except Exception:
            # Missing/unreadable ROI: substitute a noise image so the batch still forms.
            np_img = np.random.randint(0, 255, (*Config.ROI_SIZE, 3), dtype=np.uint8)
        return self.alb_transform(image=np_img)['image']

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = self._load_volume(row) if self.direct_volume else self._load_roi_image(row)
        labels = torch.tensor([row[col] for col in Config.LABEL_COLS], dtype=torch.float32)
        return {
            'image': image,
            'labels': labels,
            'roi_id': row['roi_id'],
            'confidence': torch.tensor(row['roi_confidence'], dtype=torch.float32)
        }


class Normalize(nn.Module):
    """Standardise a (B, 3, H, W) batch with ImageNet per-channel mean and std.

    The statistics live in registered buffers named 'mean' and 'std', shaped
    (1, 3, 1, 1) so they broadcast over batch and spatial dimensions.
    """
    def __init__(self):
        super().__init__()
        channel_mean = torch.tensor([0.485, 0.456, 0.406])
        channel_std = torch.tensor([0.229, 0.224, 0.225])
        self.register_buffer('mean', channel_mean.view(1, -1, 1, 1))
        self.register_buffer('std', channel_std.view(1, -1, 1, 1))

    def forward(self, x):
        centered = x - self.mean
        return centered / self.std

class SliceMixer(nn.Module):
    """Mix `in_slices` depth slices into `out_ch` channels with a learnable 1x1 conv (no bias).

    With init='gaussian', the first min(out_ch, 3) output channels start as normalised
    Gaussian bumps (sigma 0.35) over depth positions in [-1, 1], centred at -0.5, 0.0
    and 0.5. Any further channels keep the Kaiming-uniform init.
    """
    def __init__(self, in_slices=32, out_ch=3, init='gaussian'):
        super().__init__()
        self.proj = nn.Conv2d(in_slices, out_ch, kernel_size=1, bias=False)
        # Always run the Kaiming init first: it also keeps RNG consumption identical
        # whether or not the Gaussian init overwrites rows afterwards.
        nn.init.kaiming_uniform_(self.proj.weight, a=math.sqrt(5))
        if init != 'gaussian':
            return
        depth_axis = torch.linspace(-1, 1, in_slices)
        width = 0.35
        with torch.no_grad():
            for ch, mu in enumerate((-0.5, 0.0, 0.5)[:out_ch]):
                bump = torch.exp(-0.5 * ((depth_axis - mu) / width) ** 2)
                self.proj.weight[ch, :, 0, 0] = bump / bump.sum()

    def forward(self, x):
        return self.proj(x)


class AsymmetricLoss(nn.Module):
    """Asymmetric loss (ASL) for multi-label logits, mean-reduced over all elements.

    With p = sigmoid(logits), m = clip, y = targets:
        L+ = -y * (1 - p)^gamma_pos * log(p)
        L- = -(1 - y) * p_m^gamma_neg * log(1 - p_m),  where p_m = max(p - m, 0)

    The probability margin m shifts only the negative term. Easy negatives (p <= m)
    are discarded, and positives are left untouched. The previous version clamped p
    into [clip, 1 - clip] for both terms. That zeroed the gradient for any positive
    with p < clip, so confidently-missed positives got no learning signal. It also
    left a loss floor of -log(1 - clip) on perfectly predicted positives.

    Args:
        gamma_pos: focusing exponent for positives.
        gamma_neg: focusing exponent for negatives.
        clip: probability margin m for negatives (None or 0 disables shifting).
        eps: floor used inside logs and focusing bases for numerical safety.
    """
    def __init__(self, gamma_pos=0.0, gamma_neg=4.0, clip=0.05, eps=1e-8):
        super().__init__()
        self.gp = float(gamma_pos)
        self.gn = float(gamma_neg)
        self.clip = float(clip) if clip is not None else None
        self.eps = float(eps)

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        prob = torch.sigmoid(logits)

        # Positive term: log(p) via logsigmoid, stable for very negative logits.
        # Bases are floored before pow: with exponent in (0, 1) a zero base would give
        # an infinite derivative (NaN after the chain rule).
        w_pos = (1.0 - prob).clamp(min=self.eps) ** self.gp if self.gp > 0 else 1.0
        loss_pos = -targets * F.logsigmoid(logits) * w_pos

        # Negative term with probability shifting.
        if self.clip is not None and self.clip > 0:
            prob_m = (prob - self.clip).clamp(min=0.0)
            log_neg = torch.log((1.0 - prob_m).clamp(min=self.eps))
        else:
            prob_m = prob
            log_neg = F.logsigmoid(-logits)  # log(1 - p), stable
        w_neg = prob_m.clamp(min=self.eps) ** self.gn if self.gn > 0 else 1.0
        loss_neg = -(1.0 - targets) * log_neg * w_neg

        return (loss_pos + loss_neg).mean()


class Simple2D(nn.Module):
    """
    Simple 2D model matching Run 1 (best: 0.6279 AUC).
    Takes a slice stack, uses the 3 center slices as RGB channels, and classifies them
    with EfficientNetV2-S (exact match to working model).

    Weights are loaded from `weights_path` (defaults to the Kaggle dataset path used
    before) with strict=False; classifier weights in the checkpoint are skipped.
    """
    DEFAULT_WEIGHTS_PATH = "/kaggle/input/tf_efficientnetv2_s.in21k_ft_in1k/pytorch/default/1/pytorch_model.bin"

    def __init__(self, num_classes=len(Config.LABEL_COLS), weights_path=None):
        super().__init__()

        # Use EfficientNetV2-S with 3 input channels
        self.backbone = timm.create_model(
            'tf_efficientnetv2_s.in21k_ft_in1k',
            pretrained=False,
            in_chans=3,
            num_classes=num_classes,
            drop_rate=0.2,
            drop_path_rate=0.2
        )

        # Load offline weights from HuggingFace checkpoint
        weights_path = weights_path or self.DEFAULT_WEIGHTS_PATH
        print(f"🔄 Loading EfficientNetV2-S weights from: {weights_path}")
        sd = torch.load(weights_path, map_location="cpu")
        if isinstance(sd, dict) and "state_dict" in sd:
            sd = sd["state_dict"]

        clean = {}
        for k, v in sd.items():
            # Remove common prefixes
            if k.startswith("module."):
                k = k[7:]
            if k.startswith("model."):
                k = k[6:]
            # Skip classifier weights (1000 ImageNet classes vs our label count)
            if k.startswith("classifier."):
                continue
            clean[k] = v
        missing, unexpected = self.backbone.load_state_dict(clean, strict=False)
        print(f"✅ EfficientNetV2-S weights loaded: Missing={len(missing)}, Unexpected={len(unexpected)}")

        # ImageNet stats as non-persistent buffers. They follow the module's device
        # without a host->device copy every forward pass, and stay out of state_dict
        # so existing checkpoints still load.
        self.register_buffer('img_mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer('img_std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    def forward(self, x):
        """
        Args:
            x: [B, D, H, W] slice stack in [0,1] (D is 32 in direct-volume mode)
        Returns:
            logits: [B, num_classes]
        """
        _, D, _, _ = x.shape
        center = D // 2
        if D >= 3:
            x_3ch = x[:, center - 1:center + 2, :, :]  # [B, 3, H, W]
        else:
            # Fewer than 3 slices: repeat edge slices so the backbone still sees 3 channels.
            x_3ch = x[:, [max(center - 1, 0), center, min(center + 1, D - 1)], :, :]

        # Normalize in fp32 (ImageNet stats), then return to the input's half precision if needed
        x_3ch = (x_3ch.float() - self.img_mean.float()) / self.img_std.float()
        if x.dtype == torch.float16:
            x_3ch = x_3ch.half()

        return self.backbone(x_3ch)
        

class RSNA2p5D(nn.Module):
    """2.5D classifier: learnable 32-slice -> 3-channel mixer, ImageNet normalisation,
    a timm 3-channel backbone used as a pooled feature extractor, and a linear head.

    For `tf_efficientnet_b3*` backbones, weights are loaded (strict=False) from a fixed
    offline path when it exists. Otherwise the backbone keeps its random init and a
    warning is printed.
    """
    def __init__(self, num_classes=len(Config.LABEL_COLS), backbone_name='tf_efficientnet_b3'):
        super().__init__()
        self.mixer = SliceMixer(in_slices=32, out_ch=3, init='gaussian')
        self.norm = Normalize()
        self.backbone = timm.create_model(backbone_name, pretrained=False, in_chans=3, num_classes=0, global_pool='avg')

        weights_path = "/kaggle/input/tf-efficientnet/pytorch/tf-efficientnet-b3/1/tf_efficientnet_b3_aa-84b4657e.pth"
        if not backbone_name.startswith("tf_efficientnet_b3"):
            print(f"⚠️ Backbone '{backbone_name}' is not tf_efficientnet_b3; skipping offline load.")
        elif not os.path.exists(weights_path):
            print(f"⚠️ Offline weights not found at {weights_path}; using random init for backbone.")
        else:
            try:
                print(f"🔄 Loading offline TF-EfficientNet-B3 weights from: {weights_path}")
                state = torch.load(weights_path, map_location="cpu")
                if isinstance(state, dict) and "state_dict" in state:
                    state = state["state_dict"]
                cleaned = {}
                for key, tensor in state.items():
                    # Strip wrapper prefixes in order: 'module.' then 'model.'
                    for prefix in ("module.", "model."):
                        if key.startswith(prefix):
                            key = key[len(prefix):]
                    cleaned[key] = tensor
                missing, unexpected = self.backbone.load_state_dict(cleaned, strict=False)
                print(f"✅ Weights loaded. Missing: {len(missing)}, Unexpected: {len(unexpected)}")
            except Exception as e:
                print(f"❌ Error loading offline TF-EfficientNet-B3 weights: {e}")

        # Feature width: prefer timm's num_features, then the classifier, then a guess.
        feature_dim = getattr(self.backbone, "num_features", None)
        if not (isinstance(feature_dim, int) and feature_dim > 0):
            try:
                feature_dim = self.backbone.get_classifier().in_features
            except Exception:
                feature_dim = 1536 if "efficientnet_b3" in backbone_name else 1024

        self.head = nn.Linear(int(feature_dim), num_classes)

    def forward(self, x):  # x: (B,32,H,W) in [0,1]
        mixed = torch.clamp(self.mixer(x), 0, 1)
        normed = self.norm(mixed).to(memory_format=torch.channels_last)
        return self.head(self.backbone(normed))

# ====================================================
# MIL 2.5D MODEL: Shared 2D encoder per slice + Transformer over depth
# ====================================================
class PositionalEncoding1D(nn.Module):
    """Fixed sinusoidal positional encoding added to a (B, L, C) token sequence.

    Even columns hold ``sin(pos * freq)`` and odd columns ``cos(pos * freq)``.
    Odd ``d_model`` values are supported: the last sine column simply has no
    cosine partner.

    Args:
        d_model: token width C.
        max_len: longest sequence length L that can be encoded.
    """
    def __init__(self, d_model, max_len=128):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # There are ceil(d/2) sine columns but only floor(d/2) cosine columns,
        # so trim the frequency vector for odd d_model (no-op when even).
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer('pe', pe[None])  # (1, L, C)

    def forward(self, x):  # x: (B, L, C)
        """Return ``x`` plus the encoding of its first ``L`` positions.

        Raises:
            ValueError: if ``L`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

class SliceTriChannel(nn.Module):
    """Turn every depth slice into a 3-channel token ``[z-1, z, z+1]``.

    The first and last slices reuse themselves as the missing neighbour
    (edge replication), so the depth length D is preserved.
    """
    def forward(self, x):  # x: (B, D, H, W)
        depth = x.size(1)
        centre = torch.arange(depth, device=x.device)
        # Neighbour indices clamped into [0, D-1] reproduce replicate padding.
        prev_idx = (centre - 1).clamp(min=0)
        next_idx = (centre + 1).clamp(max=depth - 1)
        return torch.stack([x[:, prev_idx], x, x[:, next_idx]], dim=2)  # (B, D, 3, H, W)

class MIL2p5D(nn.Module):
    """
    2.5D multiple-instance classifier over a stack of slices.

    Pipeline: every depth slice becomes a [z-1, z, z+1] tri-slice image, a
    shared 2D timm encoder (average-pooled features) embeds each slice, a linear
    projection + sinusoidal positional encoding + Transformer encoder mixes the
    slice tokens, gated attention pooling collapses them into one bag vector,
    and a linear head produces ``num_classes`` logits.

    Speed-oriented choices (the speed-up figures are the original author's
    estimates and are not verified here):
      - Smaller backbone (B0 instead of B3): 3-4x faster
      - Reduced spatial resolution (224 instead of 320): 2x faster
      - Gradient checkpointing: Allows larger batches
      - Optional encoder compilation (PyTorch 2.0+)
    """
    def __init__(self, num_classes=len(Config.LABEL_COLS), 
                 backbone=None, d_model=None, nhead=None, n_layers=None, 
                 spatial_size=None, use_grad_checkpoint=None):
        """Build the slice encoder, token mixer and pooling head.

        Each ``None`` argument falls back to the matching ``Config.MIL_*``
        attribute and then to a hard-coded default (backbone
        'tf_efficientnet_b0', d_model 512, nhead 8, n_layers 2,
        spatial_size 224, grad checkpointing on). All but
        ``use_grad_checkpoint`` are resolved with ``or``, so falsy explicit
        values (e.g. 0) also fall back to the defaults.

        Args:
            num_classes: number of output logits (default ``len(Config.LABEL_COLS)``).
            backbone: timm model name for the shared slice encoder.
            d_model: token width fed to the Transformer.
            nhead: attention heads per Transformer layer.
            n_layers: number of Transformer encoder layers.
            spatial_size: square side each slice is resized to before encoding.
            use_grad_checkpoint: enable encoder gradient checkpointing if supported.
        """
        super().__init__()
        # Use Config defaults if not provided
        backbone = backbone or getattr(Config, 'MIL_BACKBONE', 'tf_efficientnet_b0')
        d_model = d_model or getattr(Config, 'MIL_D_MODEL', 512)
        nhead = nhead or getattr(Config, 'MIL_NHEAD', 8)
        n_layers = n_layers or getattr(Config, 'MIL_N_LAYERS', 2)
        self.spatial_size = spatial_size or getattr(Config, 'MIL_SPATIAL_SIZE', 224)
        use_grad_checkpoint = use_grad_checkpoint if use_grad_checkpoint is not None else getattr(Config, 'MIL_USE_GRAD_CHECKPOINT', True)
        
        self.tri = SliceTriChannel()
        # Shared 2D encoder (lightweight for speed). num_classes=0 removes the
        # classifier so the model returns pooled features; pretrained=False
        # because weights are loaded from local files below (offline Kaggle).
        self.encoder = timm.create_model(backbone, pretrained=False, in_chans=3, num_classes=0, global_pool='avg')
        
        # Try to load offline weights (support B0, B1, B3)
        # Any other backbone, or a missing file, silently keeps random init.
        weights_map = {
            'tf_efficientnet_b0': '/kaggle/input/tf-efficientnet/pytorch/tf-efficientnet-b0/1/tf_efficientnet_b0_aa-827b6e33.pth',
            'tf_efficientnet_b1': '/kaggle/input/tf-efficientnet/pytorch/tf-efficientnet-b1/1/tf_efficientnet_b1_aa-ea7a7bb7.pth',
            'tf_efficientnet_b3': '/kaggle/input/tf-efficientnet/pytorch/tf-efficientnet-b3/1/tf_efficientnet_b3_aa-84b4657e.pth',
        }
        weights_path = weights_map.get(backbone, None)
        if weights_path and os.path.exists(weights_path):
            try:
                print(f"🔄 Loading offline {backbone} weights from: {weights_path}")
                state = torch.load(weights_path, map_location="cpu")
                # Unwrap {"state_dict": ...} checkpoints.
                if isinstance(state, dict) and "state_dict" in state:
                    state = state["state_dict"]
                # Strip 'module.' / 'model.' prefixes; non-string keys are dropped.
                clean = {}
                for k, v in state.items():
                    if isinstance(k, str):
                        if k.startswith("module."): k = k[7:]
                        if k.startswith("model."):  k = k[6:]
                        clean[k] = v
                # strict=False: mismatched keys are tolerated and only counted
                # (presumably the checkpoint's classifier weights appear as
                # "unexpected" because the encoder has no classifier).
                missing, unexpected = self.encoder.load_state_dict(clean, strict=False)
                print(f"✅ Weights loaded. Missing: {len(missing)}, Unexpected: {len(unexpected)}")
            except Exception as e:
                # Best-effort: a failed load is reported and training continues
                # from random init.
                print(f"⚠️ Error loading weights: {e}")
        
        # Enable gradient checkpointing for encoder (saves VRAM)
        if use_grad_checkpoint and hasattr(self.encoder, 'set_grad_checkpointing'):
            try:
                self.encoder.set_grad_checkpointing(True)
                print(f"✅ Gradient checkpointing enabled for {backbone}")
            except Exception:
                pass
        
        # Optional: Compile encoder for speed (PyTorch 2.0+)
        # NOTE(review): after compiling, attribute lookups below (num_features)
        # and elsewhere (set_grad_checkpointing) go through the compiled
        # wrapper; presumably it forwards them to the original module — verify.
        if getattr(Config, 'MIL_COMPILE_ENCODER', False):
            try:
                self.encoder = torch.compile(self.encoder, mode='max-autotune')
                print("✅ Encoder compiled with torch.compile")
            except Exception as e:
                print(f"⚠️ torch.compile not available: {e}")
        
        # Pooled feature width of the encoder; no fallback if the backbone
        # lacks `num_features`.
        feat_dim = int(getattr(self.encoder, 'num_features'))
        self.proj = nn.Linear(feat_dim, d_model)
        # Supports stacks of at most 128 slices.
        self.pe   = PositionalEncoding1D(d_model, max_len=128)
        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True,
                                               dim_feedforward=4*d_model, norm_first=True)
        self.tx   = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        # Gated attention scorer: one scalar score per slice token.
        self.gate = nn.Sequential(nn.Linear(d_model, d_model), nn.Tanh(), nn.Linear(d_model, 1))
        self.head = nn.Linear(d_model, num_classes)
        
        print(f"🏗️  MIL2p5D: backbone={backbone}, spatial={self.spatial_size}, d_model={d_model}, layers={n_layers}")

    def forward(self, x):  # x: (B, D, H, W) in [0,1]
        """Return (B, num_classes) logits for a batch of D-slice stacks.

        ``x`` is unpacked as (B, D, H, W); the [0, 1] value range comes from
        the original annotation and is not checked here.
        """
        B, D, H, W = x.shape
        x = self.tri(x)                      # (B, D, 3, H, W)
        # Fold depth into the batch so the 2D encoder sees B*D independent images.
        x = x.view(B*D, 3, H, W)
        x = x.to(memory_format=torch.channels_last)
        # Downscale to target spatial size for speed
        if x.shape[-2] != self.spatial_size or x.shape[-1] != self.spatial_size:
            x = F.interpolate(x, size=(self.spatial_size, self.spatial_size), mode='bilinear', align_corners=False)
        feats = self.encoder(x)              # (B*D, feat_dim)
        feats = feats.view(B, D, -1)         # (B, D, feat_dim)
        z = self.proj(feats)                 # (B, D, d_model)
        z = self.pe(z)
        z = self.tx(z)                       # (B, D, d_model)
        # Softmax over the depth axis: per-bag attention weights sum to 1.
        a = torch.softmax(self.gate(z), dim=1)  # (B, D, 1)
        bag = (a * z).sum(dim=1)             # (B, d_model)
        logits = self.head(bag)              # (B, C)
        return logits


# Using original EfficientNet approach

def calculate_class_weights(df):
    """Per-class BCE ``pos_weight`` values derived from label frequencies.

    For each column of ``Config.LABEL_COLS`` the weight is
    ``negatives / (positives + 1e-8)``, clipped to ``Config.CLASS_WEIGHT_CAP``
    (default 100.0). The last column, assumed to be 'Aneurysm Present', is then
    multiplied by 13 and clipped to the cap again. The factor of 13 matches the
    weight the AUC helpers in this file give that column.

    Returns:
        float32 tensor with one weight per label column.
    """
    label_cols = Config.LABEL_COLS
    positives = df[label_cols].sum()
    negatives = len(df) - positives
    cap = getattr(Config, 'CLASS_WEIGHT_CAP', 100.0)

    ratio = negatives / (positives + 1e-8)
    capped = np.minimum(ratio, cap)
    capped.iloc[-1] = min(capped.iloc[-1] * 13.0, cap)
    return torch.tensor(capped.values, dtype=torch.float32)

# Let cuDNN benchmark candidate convolution algorithms and cache the fastest.
# This pays off when input sizes stay fixed between iterations. The setting is
# best-effort: any failure is ignored so the notebook keeps running.
try:
    torch.backends.cudnn.benchmark = True  # speed: autotune best cudnn algorithms for fixed input sizes
except Exception:
    pass
print("✅ Model definition loaded")

# ====================================================
# CELL 4: TRAINING PIPELINE
# ====================================================

def compute_hierarchical_penalty(logits: torch.Tensor) -> torch.Tensor:
    """Soft constraint that no site probability exceeds the main probability.

    The last column of ``logits`` (B, C) is treated as 'Aneurysm Present'; every
    other column is a site. The penalty is the mean of
    ``relu(sigmoid(site) - sigmoid(main))`` over all batch/site entries.
    Input that is not 2-D with at least two columns, or that makes the
    computation fail, yields a zero scalar on the same device/dtype.
    """
    def _zero() -> torch.Tensor:
        return torch.zeros((), device=logits.device, dtype=logits.dtype)

    try:
        has_sites_and_main = logits.ndim == 2 and logits.size(1) >= 2
        if not has_sites_and_main:
            return _zero()
        probs = torch.sigmoid(logits)
        excess = probs[:, :-1] - probs[:, -1:]
        return F.relu(excess).mean()
    except Exception:
        # Best-effort regulariser: never break training over the penalty.
        return _zero()

def compute_primary_loss(logits: torch.Tensor, labels: torch.Tensor, bce_criterion: nn.Module) -> torch.Tensor:
    """Compute the supervised loss for one batch.

    When ``Config.USE_ASL`` is true and there are at least two classes:
      * site logits (all columns but the last) use ``AsymmetricLoss``;
      * the last column, assumed to be 'Aneurysm Present', uses
        BCE-with-logits. It reuses the last entry of
        ``bce_criterion.pos_weight`` when one is set, and falls back to
        unweighted BCE if that lookup fails;
      * the result is ``ASL_WEIGHT_SITES * loss_sites + MAIN_BCE_WEIGHT * loss_main``.
    Otherwise ``bce_criterion(logits, labels)`` is returned.

    If the ASL path raises, the function still falls back to
    ``bce_criterion``. Previously this happened silently. Now the first
    fallback is reported, because it means the model is not being trained
    with the configured objective.

    Args:
        logits: (B, C) raw scores.
        labels: (B, C) float targets.
        bce_criterion: BCE-with-logits criterion (optionally with pos_weight).

    Returns:
        Scalar loss tensor.
    """
    try:
        use_asl = bool(getattr(Config, 'USE_ASL', False)) and logits.size(1) >= 2
    except Exception:
        use_asl = False
    if not use_asl:
        return bce_criterion(logits, labels)

    try:
        # ASL for site logits (all except last)
        asl = AsymmetricLoss(
            gamma_pos=getattr(Config, 'ASL_GAMMA_POS', 0.0),
            gamma_neg=getattr(Config, 'ASL_GAMMA_NEG', 4.0),
            clip=getattr(Config, 'ASL_CLIP', 0.05),
        )
        loss_sites = asl(logits[:, :-1], labels[:, :-1])

        # BCE for main logit (last column), keeping the pos_weight emphasis.
        main_logits = logits[:, -1]
        main_targets = labels[:, -1]
        try:
            pw_main = None
            if getattr(bce_criterion, 'pos_weight', None) is not None:
                pw_main = bce_criterion.pos_weight[-1].to(main_logits.device)
            loss_main = F.binary_cross_entropy_with_logits(main_logits, main_targets, pos_weight=pw_main)
        except Exception:
            loss_main = F.binary_cross_entropy_with_logits(main_logits, main_targets)

        w_sites = float(getattr(Config, 'ASL_WEIGHT_SITES', 1.0))
        w_main = float(getattr(Config, 'MAIN_BCE_WEIGHT', 1.0))
        return w_sites * loss_sites + w_main * loss_main
    except Exception as exc:
        # Keep training alive, but make the silent objective switch visible once.
        if not getattr(compute_primary_loss, '_asl_fallback_warned', False):
            compute_primary_loss._asl_fallback_warned = True
            print(f"⚠️ ASL loss path failed ({type(exc).__name__}: {exc}); falling back to BCE for all classes.")
        return bce_criterion(logits, labels)


def _val_forward_in_chunks(model, images, *, max_tokens:int, mixed_precision:bool):
    """
    images:
      • mil2p5d: (B, D, H, W) before tri-stacking in the model
      • 2D ROI: (B, C, H, W)
    We compute a chunk size so that (b_chunk * D) <= max_tokens for mil2p5d,
    or (b_chunk) <= max_tokens for 2D ROI, and concatenate outputs.
    """
    import torch
    from torch import nn

    if images.dim() == 4:
        B = images.size(0)
        second = images.size(1)
        D = 1 if second in (1, 3) else second
    elif images.dim() == 5:
        B = images.size(0)
        D = images.size(1)
    else:
        B, D = images.size(0), 1

    b_chunk = max(1, int(max_tokens // max(D, 1)))
    outs = []
    for i in range(0, B, b_chunk):
        x = images[i:i + b_chunk]
        with torch.cuda.amp.autocast(enabled=mixed_precision):
            outs.append(model(x))
    return torch.cat(outs, dim=0)



def compute_validation_auc(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Weighted per-column AUC over ``Config.LABEL_COLS``.

    ``logits`` are passed through a sigmoid before scoring. A column whose
    labels contain both classes gets ``roc_auc_score``; any other column
    scores a neutral 0.5. Columns are averaged with weight 1.0, except the last
    column (assumed to be 'Aneurysm Present'), which weighs 13.0.

    A short diagnostic summary is printed on the first call only. Any failure
    is reported and yields 0.5.
    """
    try:
        probs = torch.sigmoid(logits).float().cpu().numpy()
        y_true = labels.float().cpu().numpy()
        n_classes = len(Config.LABEL_COLS)

        def _score(col):
            truth_col = y_true[:, col]
            if len(np.unique(truth_col)) > 1:
                return roc_auc_score(truth_col, probs[:, col])
            return None  # single-class column: AUC undefined

        raw_scores = [_score(col) for col in range(n_classes)]
        valid_aucs = sum(score is not None for score in raw_scores)
        auc_scores = [0.5 if score is None else score for score in raw_scores]

        # DEBUG: Print stats on first validation
        if not hasattr(compute_validation_auc, '_debug_printed'):
            for line in (
                f"🔍 AUC DEBUG - Valid classes: {valid_aucs}/{n_classes}",
                f"🔍 AUC DEBUG - Pred range: [{probs.min():.3f}, {probs.max():.3f}]",
                f"🔍 AUC DEBUG - Pred mean: {probs.mean():.3f}",
                f"🔍 AUC DEBUG - Positive rate: {y_true.mean():.3f}",
            ):
                print(line)
            compute_validation_auc._debug_printed = True

        weights = [1.0] * (n_classes - 1) + [13.0]
        return float(np.average(auc_scores, weights=weights))
    except Exception as e:
        print(f"⚠️ AUC computation failed: {e}")
        return 0.5


def _run_validation(model, loader, criterion, device, *, max_tokens:int):
    """Evaluate ``model`` on ``loader`` with token-budgeted chunked forwards.

    Grad checkpointing on the model's ``encoder`` (or ``backbone``) is turned
    off for the pass when that module exposes ``set_grad_checkpointing``. It is
    restored afterwards, but only if it was on before, and also when
    validation raises.

    Raw logits (not probabilities) are collected and handed to
    ``compute_validation_auc``, which applies the sigmoid itself. The previous
    code fed it probabilities, so the sigmoid was applied twice.

    Args:
        model: model, optionally wrapped in ``nn.DataParallel``.
        loader: yields dicts with 'image'/'labels' or ``(images, labels)`` tuples.
        criterion: BCE criterion forwarded to ``compute_primary_loss``.
        device: target device for the inputs.
        max_tokens: encoder token budget per forward chunk. A value below
            ``Config.VAL_MAX_ENCODER_TOKENS`` is labelled a "proxy" pass.

    Returns:
        ``(mean_loss, weighted_auc)``; ``(0.0, 0.5)`` for an empty loader.
    """
    model.eval()
    core = model.module if isinstance(model, nn.DataParallel) else model
    enc_ref = getattr(core, 'encoder', None)
    if enc_ref is None:
        enc_ref = getattr(core, 'backbone', None)

    restore_gc = False
    if enc_ref is not None and hasattr(enc_ref, 'set_grad_checkpointing'):
        # timm models record the current state in `grad_checkpointing`. If the
        # attribute is absent, assume it was enabled (the previous behaviour).
        was_enabled = bool(getattr(enc_ref, 'grad_checkpointing', True))
        try:
            enc_ref.set_grad_checkpointing(False)
            restore_gc = was_enabled
        except Exception:
            restore_gc = False

    mixed_precision = bool(getattr(Config, 'MIXED_PRECISION', True))
    full_budget = int(getattr(Config, 'VAL_MAX_ENCODER_TOKENS', 384))
    desc = "Validating (proxy)" if int(max_tokens) < full_budget else "Validating (full)"

    total_loss = 0.0
    num_batches = 0
    all_logits = []
    all_labels = []
    try:
        with torch.inference_mode():
            for batch in tqdm(loader, desc=desc):
                if isinstance(batch, dict):
                    images = batch['image']
                    labels = batch['labels']
                else:
                    images, labels = batch
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                logits = _val_forward_in_chunks(
                    model,
                    images,
                    max_tokens=int(max_tokens),
                    mixed_precision=mixed_precision,
                )
                # Compute supervised loss (same routine as training)
                primary_loss = compute_primary_loss(logits, labels, criterion)
                if getattr(Config, 'ENABLE_HIER_LOSS', False):
                    hier_pen = compute_hierarchical_penalty(logits)
                    batch_loss = primary_loss + float(getattr(Config, 'HIER_LOSS_LAMBDA', 0.2)) * hier_pen
                else:
                    batch_loss = primary_loss
                total_loss += float(batch_loss.detach().cpu())
                num_batches += 1

                # Keep raw logits on CPU; the AUC helper applies the sigmoid.
                all_logits.append(logits.float().cpu())
                all_labels.append(labels.float().cpu())
    finally:
        if restore_gc:
            try:
                enc_ref.set_grad_checkpointing(True)
            except Exception:
                pass

    if not all_logits:
        return 0.0, 0.5
    logits_cat = torch.cat(all_logits, dim=0)
    labels_cat = torch.cat(all_labels, dim=0)
    val_loss = total_loss / max(1, num_batches)
    val_auc = compute_validation_auc(logits_cat, labels_cat)
    return val_loss, val_auc


def autotune_micro_batch(model, train_loader, criterion, device):
    """Probe micro-batch sizes on one real batch and return the largest that fits.

    Each candidate runs forward + backward on a single batch from
    ``train_loader``. It uses the same gradient-accumulation split that
    ``Config.BATCH_SIZE`` implies. Gradients are discarded afterwards and no
    optimizer step is taken.

    Candidates larger than the probe batch are clamped to its size. Previously
    a size that was never exercised (e.g. 16 with an 8-sample batch) could be
    reported as safe. The model is always returned to train mode, even on the
    early-exit and error paths.

    Returns:
        The chosen micro-batch size. If the probe batch cannot be loaded,
        returns ``Config.MICRO_BATCH_SIZE`` (default 12). If every candidate
        runs out of memory, returns ``max(1, MICRO_BATCH_SIZE)``.

    Raises:
        RuntimeError: any non-out-of-memory error from the probe is re-raised.
    """
    mixed_precision = getattr(Config, 'MIXED_PRECISION', True)
    model.eval()
    try:
        # Fetch a single probe batch; any failure falls back to the config default.
        try:
            print("[AutoTune] Fetching test batch...", flush=True)
            batch = next(iter(train_loader))
            if isinstance(batch, dict):
                imgs = batch['image']
                labels = batch['labels']
            else:
                imgs, labels = batch
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            print(f"[AutoTune] Test batch loaded: {imgs.shape}", flush=True)
        except Exception as e:
            print(f"[AutoTune] ERROR loading batch: {e}", flush=True)
            print("[AutoTune] Falling back to default MICRO_BATCH_SIZE", flush=True)
            return getattr(Config, 'MICRO_BATCH_SIZE', 12)

        n_avail = int(imgs.size(0))
        # Only sizes that can actually be exercised on this batch, largest first.
        candidates = sorted(
            {min(mb, n_avail) for mb in (16, 14, 12, 10, 8, 6, 4) if n_avail > 0},
            reverse=True,
        )
        chosen = None
        scaler_local = torch.cuda.amp.GradScaler(enabled=mixed_precision)

        for mb in candidates:
            try:
                try:
                    torch.cuda.empty_cache()
                except Exception:
                    pass
                accum = max(1, int((getattr(Config, 'BATCH_SIZE', 24) + mb - 1) // mb))
                model.zero_grad(set_to_none=True)
                for i in range(accum):
                    i0 = i * mb
                    i1 = min((i + 1) * mb, n_avail)
                    if i0 >= i1:
                        break
                    x = imgs[i0:i1].to(memory_format=torch.channels_last)
                    y = labels[i0:i1]
                    with torch.cuda.amp.autocast(enabled=mixed_precision):
                        logits = model(x)
                        loss = criterion(logits, y) / accum
                    scaler_local.scale(loss).backward()
                chosen = mb
                break
            except RuntimeError as e:
                if 'out of memory' not in str(e):
                    raise
            finally:
                # Drop probe gradients (also after an OOM) and release cached blocks.
                model.zero_grad(set_to_none=True)
                try:
                    torch.cuda.empty_cache()
                except Exception:
                    pass

        if chosen is None:
            chosen = max(1, int(getattr(Config, 'MICRO_BATCH_SIZE', 12)))
        return chosen
    finally:
        model.train()

def train_epoch(model, loader, optimizer, criterion, device, grad_scaler=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: model to train (switched to train mode here).
        loader: yields dicts with 'image'/'labels' or ``(images, labels)`` tuples.
        optimizer: optimizer stepped once per batch.
        criterion: BCE criterion forwarded to ``compute_primary_loss``.
        device: target device for the batch tensors.
        grad_scaler: optional GradScaler used when ``Config.MIXED_PRECISION`` is
            on. When omitted, the module-level ``scaler`` created by
            ``main_training`` is used. If that does not exist either, a local
            one is created.

    Returns:
        Mean loss over the epoch's batches, or 0.0 when ``loader`` is empty.
    """
    model.train()
    use_amp = Config.MIXED_PRECISION
    amp_scaler = None
    if use_amp:
        amp_scaler = grad_scaler if grad_scaler is not None else globals().get('scaler')
        if amp_scaler is None:
            amp_scaler = torch.cuda.amp.GradScaler(enabled=True)

    total_loss = 0.0
    num_batches = 0
    for batch in tqdm(loader, desc="Training"):
        if isinstance(batch, dict):
            images, labels = batch['image'], batch['labels']
        else:
            images, labels = batch
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # Forward pass
        with torch.cuda.amp.autocast(enabled=use_amp):
            images = images.to(memory_format=torch.channels_last)
            logits = model(images)
            primary_loss = compute_primary_loss(logits, labels, criterion)
            if getattr(Config, 'ENABLE_HIER_LOSS', False):
                hier_pen = compute_hierarchical_penalty(logits)
                loss = primary_loss + float(getattr(Config, 'HIER_LOSS_LAMBDA', 0.2)) * hier_pen
            else:
                loss = primary_loss

        # Backward pass
        if use_amp:
            amp_scaler.scale(loss).backward()
            amp_scaler.step(optimizer)
            amp_scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / max(1, num_batches)

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader``; return ``(mean_loss, weighted_auc)``.

    Each batch goes through ``_val_forward_in_chunks`` with the
    ``Config.VAL_MAX_ENCODER_TOKENS`` budget, and the loss matches the training
    objective. Per-column AUC is ``roc_auc_score`` where both classes are
    present and 0.5 otherwise. Columns are averaged with weight 13 on the last
    column ('Aneurysm Present') and 1 elsewhere. Diagnostics and per-class AUCs
    are printed on every call.

    The previous debug prints referenced undefined names. The resulting
    NameError was swallowed by a bare ``except``, so this function always
    returned an AUC of 0.5.

    Returns:
        ``(mean_loss, weighted_auc)``. The loss is 0.0 for an empty loader.
        The AUC is 0.5 when nothing was collected or the AUC computation fails.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []
    max_tokens = int(getattr(Config, 'VAL_MAX_ENCODER_TOKENS', 384))
    mixed_precision = bool(getattr(Config, 'MIXED_PRECISION', True))

    with torch.inference_mode():
        for batch in tqdm(loader, desc="Validating"):
            images = batch['image'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)
            # channels_last is only defined for 4-D tensors.
            if images.dim() == 4:
                images = images.to(memory_format=torch.channels_last)

            logits = _val_forward_in_chunks(
                model,
                images,
                max_tokens=max_tokens,
                mixed_precision=mixed_precision,
            )
            primary_loss = compute_primary_loss(logits, labels, criterion)
            if getattr(Config, 'ENABLE_HIER_LOSS', False):
                hier_pen = compute_hierarchical_penalty(logits)
                loss = primary_loss + float(getattr(Config, 'HIER_LOSS_LAMBDA', 0.2)) * hier_pen
            else:
                loss = primary_loss

            total_loss += loss.item()
            num_batches += 1

            # Collect predictions for AUC on CPU only and free GPU tensors promptly
            all_preds.append(torch.sigmoid(logits).float().cpu().numpy())
            all_labels.append(labels.float().cpu().numpy())
            del logits, images, labels

    mean_loss = total_loss / max(1, num_batches)
    if not all_preds:
        return mean_loss, 0.5

    preds = np.vstack(all_preds)
    truth = np.vstack(all_labels)
    n_classes = len(Config.LABEL_COLS)
    try:
        auc_scores = []
        valid_aucs = 0
        for i in range(n_classes):
            if len(np.unique(truth[:, i])) > 1:
                auc_scores.append(roc_auc_score(truth[:, i], preds[:, i]))
                valid_aucs += 1
            else:
                auc_scores.append(0.5)

        # DEBUG: Print stats EVERY fold to track prediction behavior
        print(f"🔍 AUC DEBUG - Valid classes: {valid_aucs}/{n_classes}")
        print(f"🔍 AUC DEBUG - Pred range: [{preds.min():.3f}, {preds.max():.3f}]")
        print(f"🔍 AUC DEBUG - Pred mean: {preds.mean():.3f}")
        print(f"🔍 AUC DEBUG - Positive rate: {truth.mean():.3f}")

        # Weighted AUC (13x weight for Aneurysm Present)
        weights = [1.0] * (n_classes - 1) + [13.0]
        weighted_auc = float(np.average(auc_scores, weights=weights))

        # Per-class AUC logging (rounded)
        per_class_auc = {Config.LABEL_COLS[i]: (round(auc_scores[i], 4) if not np.isnan(auc_scores[i]) else None)
                         for i in range(n_classes)}
        print("Per-class AUC:", per_class_auc)
    except Exception as e:
        print(f"⚠️ AUC computation failed: {e}")
        weighted_auc = 0.5

    return mean_loss, weighted_auc

def main_training():
    print("🚀 STAGE 2: ANEURYSM CLASSIFICATION")
    print("📦 Data: Loading .npz shards directly (64x384x384x4 uint8)")
    
    # Load data
    train_df = pd.read_csv(Config.TRAIN_CSV_PATH)
    
    if Config.DEBUG_MODE:
        train_df = train_df.head(Config.DEBUG_SAMPLES)
    
    print(f"Training samples: {len(train_df)}")
    print(f"Aneurysm cases: {train_df['Aneurysm Present'].sum()}")
    
    # Build training index (DIRECT_VOLUME_MODE: no ROIs)
    records = []
    for _, r in train_df.iterrows():
        rec = {
            'series_id': str(r[Config.ID_COL]),
            'roi_id': f"{str(r[Config.ID_COL])}_vol32",
            'roi_path': '',
            'roi_confidence': 1.0,
            'slice_idx': -1,
        }
        for col in Config.LABEL_COLS:
            rec[col] = r[col]
        records.append(rec)
    training_df = pd.DataFrame(records)
    
    # Calculate class weights
    class_weights = calculate_class_weights(training_df)
    print(f"Class weights: {class_weights}")
    
    # Create criterion with class weights
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights).to(Config.DEVICE)
    
    # Mixed precision scaler
    global scaler
    scaler = torch.cuda.amp.GradScaler(enabled=Config.MIXED_PRECISION)
    
    # Cross-validation or single split
    # Use Aneurysm Present for stratification
    fold_scores = []
    if Config.N_FOLDS <= 1:
        idx_all = np.arange(len(training_df))
        train_idx, val_idx = train_test_split(
            idx_all,
            test_size=0.2,
            stratify=training_df['Aneurysm Present'],
            random_state=42,
        )
        fold_splits = [(train_idx, val_idx)]
    else:
        skf = StratifiedKFold(n_splits=Config.N_FOLDS, shuffle=True, random_state=42)
        fold_splits = list(skf.split(training_df, training_df['Aneurysm Present']))
    
    for fold, (train_idx, val_idx) in enumerate(fold_splits):
        print(f"\n{'='*50}")
        print(f"FOLD {fold + 1}/{Config.N_FOLDS}")
        print(f"{'='*50}")
        
        # Split data
        train_fold_df = training_df.iloc[train_idx].reset_index(drop=True)
        val_fold_df = training_df.iloc[val_idx].reset_index(drop=True)
        
        print(f"Train samples: {len(train_fold_df)}, Val samples: {len(val_fold_df)}")
        
        # Create datasets
        train_dataset = AneurysmClassificationDataset(train_fold_df, mode='train')
        val_dataset = AneurysmClassificationDataset(val_fold_df, mode='val')

        # DEBUG: Verify data sanity
        if fold == 0:
            test_sample = train_dataset[0]
            print(f"🔍 DATA CHECK - Image shape: {test_sample['image'].shape}, dtype: {test_sample['image'].dtype}")
            print(f"🔍 DATA CHECK - Image range: [{test_sample['image'].min():.3f}, {test_sample['image'].max():.3f}]")
            print(f"🔍 DATA CHECK - Image mean: {test_sample['image'].mean():.3f}, std: {test_sample['image'].std():.3f}")
            print(f"🔍 DATA CHECK - Non-zero voxels: {(test_sample['image'] > 0).sum()}/{test_sample['image'].numel()}")
            print(f"🔍 DATA CHECK - Labels shape: {test_sample['labels'].shape}, sum: {test_sample['labels'].sum()}")

            # Check class distribution
            aneurysm_present = train_fold_df['Aneurysm Present'].sum()
            total_samples = len(train_fold_df)
            print(f"🔍 CLASS DISTRIBUTION - Aneurysm Present: {aneurysm_present}/{total_samples} ({100*aneurysm_present/total_samples:.1f}%)")
            
            # Check all label columns
            for col in Config.LABEL_COLS:
                if col in train_fold_df.columns:
                    pos = train_fold_df[col].sum()
                    print(f"  {col}: {pos} positives")

        # WeightedRandomSampler to emphasize positives (Aneurysm Present)
        sampler = None
        if getattr(Config, 'USE_WEIGHTED_SAMPLER', True):
            try:
                pos = train_fold_df['Aneurysm Present'].values.astype(np.int64)
                class_sample_count = np.array([np.sum(pos == 0), np.sum(pos == 1)])
                weight = 1.0 / (class_sample_count + 1e-8)
                samples_weight = weight[pos]
                samples_weight = torch.from_numpy(samples_weight).double()
                sampler = WeightedRandomSampler(samples_weight, num_samples=len(samples_weight), replacement=True)
            except Exception:
                sampler = None
        
        # Create loaders (tuned for throughput)
        train_loader = DataLoader(
            train_dataset,
            batch_size=Config.BATCH_SIZE,
            shuffle=(sampler is None),
            sampler=sampler,
            num_workers=Config.NUM_WORKERS,
            pin_memory=Config.PIN_MEMORY,
            persistent_workers=Config.PERSISTENT_WORKERS,
            prefetch_factor=Config.PREFETCH_FACTOR,
            drop_last=True,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=getattr(Config, 'VAL_BATCH_SIZE', 64),
            shuffle=False,
            num_workers=Config.NUM_WORKERS,
            pin_memory=Config.PIN_MEMORY,
            persistent_workers=Config.PERSISTENT_WORKERS,
            prefetch_factor=Config.PREFETCH_FACTOR,
        )
        # Build fast-val loader (subset) once
        def build_fast_val_loader(val_dataset):
            import numpy as _np
            from torch.utils.data import DataLoader as _DL, Subset as _Subset
            frac = float(getattr(Config, 'FAST_VAL_SUBSET_FRAC', 0.33))
            n = len(val_dataset)
            k = max(1, int(n * frac))
            rng = _np.random.RandomState(getattr(Config, 'SEED', 42))
            idx = rng.choice(n, size=k, replace=False)
            subset = _Subset(val_dataset, idx.tolist())
            return _DL(
                subset,
                batch_size=getattr(Config, 'VAL_BATCH_SIZE', 32),
                shuffle=False,
                num_workers=getattr(Config, 'NUM_WORKERS', 4),
                pin_memory=getattr(Config, 'PIN_MEMORY', True),
                persistent_workers=getattr(Config, 'PERSISTENT_WORKERS', True),
                prefetch_factor=getattr(Config, 'PREFETCH_FACTOR', 2),
            )
        val_loader_full = val_loader
        val_loader_fast = build_fast_val_loader(val_dataset) if getattr(Config, 'FAST_VAL', True) else None
        
        # Initialize model with architecture-specific shard configuration
        arch = getattr(Config, 'MODEL_ARCH', 'rsna2p5d')
        
        if arch == 'simple2d':
            # Simple 2D model (matching 0.69 working architecture)
            Config.SHARD_CHANNEL_MODE = 'cta'
            Config.SHARD_TARGET_SPATIAL = (384, 384)
            print(f"🔧 Simple2D mode: 3 center slices as RGB, spatial={Config.SHARD_TARGET_SPATIAL}")
            
            model = Simple2D(num_classes=len(Config.LABEL_COLS)).to(Config.DEVICE)
        elif arch == 'mil2p5d':
            # Override shard config for mil2p5d (multi-channel, smaller spatial)
            Config.SHARD_CHANNEL_MODE = 'best3'  # CTA + soft + vesselness
            Config.SHARD_TARGET_SPATIAL = (getattr(Config, 'MIL_SPATIAL_SIZE', 224), 
                                           getattr(Config, 'MIL_SPATIAL_SIZE', 224))
            print(f"🔧 MIL2p5D mode: shard channels={Config.SHARD_CHANNEL_MODE}, spatial={Config.SHARD_TARGET_SPATIAL}")
            
            model = MIL2p5D(num_classes=len(Config.LABEL_COLS)).to(Config.DEVICE)
        else:
            # rsna2p5d uses single CTA channel, full resolution
            Config.SHARD_CHANNEL_MODE = 'cta'
            Config.SHARD_TARGET_SPATIAL = (384, 384)
            print(f"🔧 RSNA2p5D mode: shard channels={Config.SHARD_CHANNEL_MODE}, spatial={Config.SHARD_TARGET_SPATIAL}")
            
            model = RSNA2p5D().to(Config.DEVICE)
        try:
            model = model.to(memory_format=torch.channels_last)
        except Exception:
            pass
        
        # Optimizer - SIMPLIFIED to match working 0.69 model
        # Use constant LR, no weight_decay, no LR scaling
        lr = Config.LEARNING_RATE  # 2e-4
        
        if isinstance(model, Simple2D):
            # Simple2D: single LR, no weight_decay (exact match to Run 1 - best: 0.6279)
            optimizer = optim.AdamW(model.parameters(), lr=lr)
        elif isinstance(model, RSNA2p5D):
            # RSNA2p5D: keep parameter groups for backward compatibility
            optimizer = optim.AdamW([
                {'params': model.mixer.parameters(), 'lr': lr},
                {'params': model.head.parameters(), 'lr': lr},
                {'params': model.backbone.parameters(), 'lr': lr},
            ])
        else:
            # MIL2p5D: keep parameter groups for backward compatibility
            optimizer = optim.AdamW([
                {'params': model.encoder.parameters(), 'lr': lr},
                {'params': model.proj.parameters(), 'lr': lr},
                {'params': model.tx.parameters(), 'lr': lr},
                {'params': model.gate.parameters(), 'lr': lr},
                {'params': model.head.parameters(), 'lr': lr},
            ])

        # Optional: 1-epoch freeze warmup for backbone
        freeze_backbone_for_epoch0 = True
        # EMA (CPU, float32) to avoid extra VRAM usage
        use_ema = bool(getattr(Config, 'USE_EMA', True))
        ema_decay = float(getattr(Config, 'EMA_DECAY', 0.999))
        ema_state = (
            {
                k: v.detach().float().cpu().clone()
                for k, v in model.state_dict().items()
                if isinstance(v, torch.Tensor) and torch.is_floating_point(v)
            }
            if use_ema else None
        )
        
        # Multi-GPU if available
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        
        # Scheduler: DISABLED to match working 0.69 model (constant LR)
        # steps_per_epoch = max(1, len(train_loader))
        # warmup_steps = min(200, steps_per_epoch)
        # scheduler = SequentialLR(
        #     optimizer,
        #     schedulers=[
        #         LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps),
        #         CosineAnnealingLR(optimizer, T_max=Config.EPOCHS * steps_per_epoch)
        #     ],
        #     milestones=[warmup_steps]
        # )
        scheduler = None  # No scheduler (constant LR)

        # Enable grad checkpointing on encoder/backbone if supported (saves VRAM)
        core_for_gc = model.module if isinstance(model, nn.DataParallel) else model
        enc_ref = None
        if isinstance(core_for_gc, RSNA2p5D):
            enc_ref = getattr(core_for_gc, 'backbone', None)
        else:
            enc_ref = getattr(core_for_gc, 'encoder', None)
        if enc_ref is not None and hasattr(enc_ref, 'set_grad_checkpointing'):
            try:
                enc_ref.set_grad_checkpointing(True)
            except Exception:
                pass
        
        # Auto-tune micro-batch size once before epoch 1
        if getattr(Config, 'AUTO_TUNE_MICRO', False):  # DISABLED by default to prevent hanging
            print("[AutoTune] Probing micro-batch size... this may take a few minutes on the first run", flush=True)
            try:
                mb = autotune_micro_batch(model, train_loader, criterion, Config.DEVICE)
                if mb != Config.MICRO_BATCH_SIZE:
                    print(f"[AutoTune] MICRO_BATCH_SIZE {Config.MICRO_BATCH_SIZE} -> {mb}", flush=True)
                else:
                    print(f"[AutoTune] MICRO_BATCH_SIZE remains {mb}", flush=True)
                Config.MICRO_BATCH_SIZE = mb
            except Exception as e:
                print(f"[AutoTune] ERROR: {e}", flush=True)
                print(f"[AutoTune] Using default MICRO_BATCH_SIZE={Config.MICRO_BATCH_SIZE}", flush=True)
        else:
            print(f"[AutoTune] Skipped (disabled). Using MICRO_BATCH_SIZE={Config.MICRO_BATCH_SIZE}", flush=True)
        # Training loop
        best_auc = 0
        patience = max(1, int(getattr(Config, 'EARLY_STOPPING_PATIENCE', 2)))
        no_improve = 0
        best_full_auc = -1.0
        full_val_metrics = None
        
        for epoch in range(Config.EPOCHS):
            print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")
            t0 = time.time()  
            
            # GPU utilization tracking
            gpu_idle_time = 0
            gpu_active_time = 0
            last_gpu_check = time.time()

            if freeze_backbone_for_epoch0:
                # Handle DataParallel and arch differences transparently
                core = model.module if isinstance(model, nn.DataParallel) else model
                # Skip freezing for Simple2D (backbone includes classifier, needs all gradients)
                if isinstance(core, RSNA2p5D):
                    backbone_ref = core.backbone
                elif not isinstance(core, Simple2D):
                    backbone_ref = getattr(core, 'encoder', None)
                else:
                    backbone_ref = None
                if backbone_ref is not None:
                    if epoch == 0:
                        for p in backbone_ref.parameters():
                            p.requires_grad = False
                    elif epoch == 1:
                        for p in backbone_ref.parameters():
                            p.requires_grad = True
            
            # Train
            train_loss = 0.0
            model.train()
            grad_accum_steps = getattr(Config, 'GRAD_ACCUM_STEPS', 1)
            
            for batch_idx, batch in enumerate(tqdm(train_loader, desc="Training")):
                iter_start = time.time()
                
                # Track GPU idle time (time waiting for data)
                if batch_idx > 0:
                    data_wait_time = iter_start - last_gpu_check
                    gpu_idle_time += data_wait_time
                    
                images = batch['image'].to(Config.DEVICE, non_blocking=True)
                labels = batch['labels'].to(Config.DEVICE, non_blocking=True)
                           
                # Zero grad only at start of accumulation cycle
                if batch_idx % grad_accum_steps == 0:
                    optimizer.zero_grad(set_to_none=True)
                
                # Forward pass with gradient accumulation scaling
                with torch.cuda.amp.autocast(enabled=Config.MIXED_PRECISION):
                    logits = model(images)

                    # Debug: verify ImageNet normalization on first batch of first epoch
                    if epoch == 0 and batch_idx == 0:
                        print(f"🔍 NORM CHECK - Input to model: range [{images.min():.3f}, {images.max():.3f}], mean {images.float().mean():.3f}")
                        # Extract 3 center slices like the model does
                        center = images.shape[1] // 2
                        x_3ch = images[:, center-1:center+2, :, :]
                        # Apply ImageNet normalization (3 channels RGB)
                        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(1, 3, 1, 1)
                        std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(1, 3, 1, 1)
                        x_norm = (x_3ch.float() - mean) / std
                        print(f"🔍 NORM CHECK - After ImageNet norm (3-ch RGB): range [{x_norm.min():.3f}, {x_norm.max():.3f}], mean {x_norm.float().mean():.3f}")
                    
                    primary_loss = compute_primary_loss(logits, labels, criterion)
                    if getattr(Config, 'ENABLE_HIER_LOSS', False):
                        hier_pen = compute_hierarchical_penalty(logits)
                        loss = primary_loss + float(getattr(Config, 'HIER_LOSS_LAMBDA', 0.2)) * hier_pen
                    else:
                        loss = primary_loss
                    
                    # Scale loss for gradient accumulation
                    loss = loss / grad_accum_steps

                # Backward pass
                if Config.MIXED_PRECISION:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                # Optimizer step only after accumulating gradients
                if (batch_idx + 1) % grad_accum_steps == 0:
                    if Config.MIXED_PRECISION:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    
                    # Step scheduler per optimizer step (not per batch)
                    if scheduler is not None:
                        try:
                            scheduler.step()
                        except Exception:
                            pass

                train_loss += loss.item() * grad_accum_steps  # Unscale for logging
                
                # Track GPU active time
                iter_end = time.time()
                gpu_active_time += (iter_end - iter_start)
                last_gpu_check = iter_end
            train_loss /= max(1, len(train_loader))
            t_train = time.time() - t0                         
            
            # GPU utilization report
            total_time = gpu_active_time + gpu_idle_time
            if total_time > 0:
                gpu_util_pct = (gpu_active_time / total_time) * 100
                print(f"⚡ GPU Utilization: {gpu_util_pct:.1f}% | Active: {gpu_active_time/60:.1f}min | Idle: {gpu_idle_time/60:.1f}min")

            # DEBUG: Check if model is learning
            print(f"🔍 DEBUG - Epoch {epoch+1}: Train Loss = {train_loss:.4f}")
            
            # Validation: run either fast or full (not both)
            current_state = None
            if use_ema and ema_state is not None:
                current_state = {k: (v.detach().cpu() if torch.is_tensor(v) else v) for k, v in model.state_dict().items()}
                try:
                    model.load_state_dict(ema_state, strict=False)
                except Exception:
                    pass

            do_full = False
            if getattr(Config, 'RUN_FULL_ON_EPOCH_1', False) and (epoch == 0):
                do_full = True
            elif (epoch + 1) % int(getattr(Config, 'FULL_VAL_EVERY', 3)) == 0:
                do_full = True

            # Disable encoder/backbone checkpointing for eval (faster)
            had_gc = False
            core_tmp = model.module if isinstance(model, nn.DataParallel) else model
            enc_tmp = getattr(core_tmp, 'encoder', None)
            if enc_tmp is None:
                enc_tmp = getattr(core_tmp, 'backbone', None)
            if enc_tmp is not None and hasattr(enc_tmp, 'set_grad_checkpointing'):
                try:
                    enc_tmp.set_grad_checkpointing(False)
                    had_gc = True
                except Exception:
                    had_gc = False

            if do_full:
                val_loss, val_auc = _run_validation(
                    model, val_loader_full, criterion, Config.DEVICE,
                    max_tokens=int(getattr(Config, 'VAL_MAX_ENCODER_TOKENS', 384))
                )
                full_val_metrics = (val_loss, val_auc)
            else:
                if val_loader_fast is None:
                    val_loader_fast = build_fast_val_loader(val_dataset)
                val_loss, val_auc = _run_validation(
                    model, val_loader_fast, criterion, Config.DEVICE,
                    max_tokens=int(getattr(Config, 'FAST_VAL_MAX_TOKENS', 256))
                )

            # Restore encoder/backbone checkpointing
            if had_gc and enc_tmp is not None and hasattr(enc_tmp, 'set_grad_checkpointing'):
                try:
                    enc_tmp.set_grad_checkpointing(True)
                except Exception:
                    pass

            # Restore non-EMA weights
            if current_state is not None:
                try:
                    model.load_state_dict(current_state, strict=False)
                except Exception:
                    pass

            t_total = time.time() - t0            
            print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val AUC: {val_auc:.4f}")
            print(f"[Timing] epoch={epoch+1} | train={t_train/60:.2f} min | val_type={'full' if do_full else 'fast'} | epoch_total={t_total/60:.2f} min")            
            
            # Save best model + early stopping (CPU-only tensors). Track best FULL AUC.
            if (do_full and (val_auc > best_full_auc)) or (not do_full and best_full_auc < 0 and val_auc > best_auc):
                if do_full:
                    best_full_auc = val_auc
                best_auc = val_auc
                no_improve = 0
                if use_ema and ema_state is not None:
                    best_state_cpu = {k: v.detach().cpu().clone() for k, v in ema_state.items()}
                else:
                    best_state_cpu = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                # Architecture-specific model naming
                if arch == 'simple2d':
                    arch_name = 'simple2d'
                elif arch == 'mil2p5d':
                    arch_name = 'mil'
                else:
                    arch_name = 'rsna'
                model_filename = f'stage2_{arch_name}_fold_{fold}_best.pth'
                
                torch.save({
                    'model_state_dict': best_state_cpu,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_auc': val_auc,
                    'epoch': epoch,
                    'fold': fold,
                    'architecture': arch  # Save architecture type for ensemble loading
                }, model_filename)
                print(f"💾 Saved best model: {model_filename} (AUC: {val_auc:.4f})")
            else:
                no_improve += 1
                if do_full and no_improve >= patience:
                    print(f"⏹️ Early stopping (patience={patience}) at epoch {epoch+1}")
                    break

            # EMA update (compute on CPU tensors)
            if use_ema and ema_state is not None:
                with torch.no_grad():
                    for k, v in model.state_dict().items():
                        if k in ema_state and isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                            tgt = ema_state[k]
                            ema_state[k] = tgt.mul(ema_decay).add(
                                v.detach().float().cpu(),
                                alpha=(1.0 - ema_decay),
                            )
            train_loss /= max(1, len(train_loader))
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        
        fold_scores.append(best_auc)
        print(f"Fold {fold + 1} best AUC: {best_auc:.4f}")
    
    # Final results
    mean_cv_score = np.mean(fold_scores)
    print(f"\n✅ Cross-validation complete!")
    print(f"Mean CV AUC: {mean_cv_score:.4f} ± {np.std(fold_scores):.4f}")
    print(f"Individual fold scores: {fold_scores}")

# Status line confirming this cell (training pipeline definitions) finished executing.
print("✅ Training pipeline loaded")

# ====================================================
# CELL 5: INFERENCE & SUBMISSION
# ====================================================

class InferenceConfig:
    """Static settings read by the inference entry points.

    DEVICE     -- torch device models run on (CUDA whenever it is available).
    ID_COL     -- name of the series-identifier column in prediction frames.
    LABEL_COLS -- ordered output columns; one probability is emitted per entry.
    """
    if torch.cuda.is_available():
        DEVICE = torch.device('cuda')
    else:
        DEVICE = torch.device('cpu')

    ID_COL = 'SeriesInstanceUID'

    LABEL_COLS = [
        'Left Infraclinoid Internal Carotid Artery',
        'Right Infraclinoid Internal Carotid Artery',
        'Left Supraclinoid Internal Carotid Artery',
        'Right Supraclinoid Internal Carotid Artery',
        'Left Middle Cerebral Artery',
        'Right Middle Cerebral Artery',
        'Anterior Communicating Artery',
        'Left Anterior Cerebral Artery',
        'Right Anterior Cerebral Artery',
        'Left Posterior Communicating Artery',
        'Right Posterior Communicating Artery',
        'Basilar Tip',
        'Other Posterior Circulation',
        'Aneurysm Present',
    ]


# ===============================
# Inference-time TTA helpers
# Label name -> logit column, following the order of Config.LABEL_COLS.
NAME_TO_IDX = dict((label, column) for column, label in enumerate(Config.LABEL_COLS))

# Lateralized vessels as (left, right) label pairs. A horizontal flip of the
# input exchanges the two sides, so these logit columns must be swapped back.
LR_PAIRS = [
    ("Left Infraclinoid Internal Carotid Artery", "Right Infraclinoid Internal Carotid Artery"),
    ("Left Supraclinoid Internal Carotid Artery", "Right Supraclinoid Internal Carotid Artery"),
    ("Left Middle Cerebral Artery", "Right Middle Cerebral Artery"),
    ("Left Anterior Cerebral Artery", "Right Anterior Cerebral Artery"),
    ("Left Posterior Communicating Artery", "Right Posterior Communicating Artery"),
]

# Labels without a left/right counterpart (absent from LR_PAIRS).
UNI_CLASSES = [
    "Anterior Communicating Artery",
    "Basilar Tip",
    "Other Posterior Circulation",
    "Aneurysm Present",
]

def swap_lr_logits(logits: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``logits`` with every Left/Right label pair exchanged.

    Only the last dimension is permuted: each column named in ``LR_PAIRS``
    trades places with its mirror, all other columns keep their position.
    The input tensor is not modified. Used to undo the side swap introduced by
    a horizontal flip during test-time augmentation.
    """
    order = list(range(logits.shape[-1]))
    for left_name, right_name in LR_PAIRS:
        a = NAME_TO_IDX[left_name]
        b = NAME_TO_IDX[right_name]
        order[a], order[b] = order[b], order[a]
    # Advanced indexing gathers into a fresh tensor in one step.
    return logits[..., order]

def predict_with_tta(models: list, x: torch.Tensor) -> torch.Tensor:
    """Ensemble probabilities averaged over 5 test-time-augmentation views.

    Views (stacked along the batch dimension, one forward pass per model):
        0. identity
        1. horizontal flip (last dim) -- its logits get Left/Right swapped back
        2. vertical flip (second-to-last dim)
        3. slice stack rolled by +2 along dim -3
        4. slice stack rolled by -2 along dim -3
    Logits are averaged over views, then over models, then passed through a
    sigmoid.

    Args:
        models: non-empty list of callables returning ``(N, C)`` logits for an
            ``(N, S, H, W)`` input.
        x: 4-D batch, expected ``(B, 32, H, W)`` in ``[0, 1]`` (TODO confirm
            against the preprocessor).

    Returns:
        float32 probabilities of shape ``(B, C)`` on ``x.device``.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("predict_with_tta requires at least one model")

    n_views = 5
    b = x.size(0)
    batch = torch.cat(
        [
            x,
            torch.flip(x, dims=[-1]),
            torch.flip(x, dims=[-2]),
            torch.roll(x, shifts=2, dims=[-3]),
            torch.roll(x, shifts=-2, dims=[-3]),
        ],
        dim=0,
    ).to(memory_format=torch.channels_last)  # loop-invariant: convert once

    amp_enabled = bool(getattr(Config, 'MIXED_PRECISION', True) and x.is_cuda)
    per_model_logits = []
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=amp_enabled):
        for m in models:
            # Upcast so view/model averaging and the sigmoid run in float32 even
            # when autocast produced fp16 logits; reshape (not view) tolerates a
            # non-contiguous model output.
            z = m(batch).float().reshape(n_views, b, -1)  # (5, B, C)
            z[1] = swap_lr_logits(z[1])                   # H-flip needs L<->R swap
            per_model_logits.append(z.mean(0))            # (B, C)
    logits = torch.stack(per_model_logits, dim=0).mean(0)  # (B, C)
    return torch.sigmoid(logits)


class ModelEnsemble:
    """Ensemble of Stage 2 models for inference.

    Every checkpoint in ``model_paths`` is rebuilt into its architecture,
    loaded onto ``device`` in eval mode (channels_last when supported) and
    appended to ``self.models``. Checkpoints that fail to load are reported and
    skipped, so ``self.models`` may be shorter than ``model_paths`` or empty.
    """

    def __init__(self, model_paths, device):
        self.device = device
        self.models = []

        for path in model_paths:
            try:
                model = self._load_model(path, device)
            except Exception as e:
                print(f"Error loading {path}: {e}")
                continue
            self.models.append(model)
            print(f"Loaded model: {path}")

        print(f"Loaded {len(self.models)} models for ensemble")

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the weight mapping stored in ``checkpoint``.

        Accepts ``{'model' | 'model_state_dict' | 'state_dict': weights}`` or a
        non-empty bare ``{name: tensor}`` dict, and strips a leading ``module.``
        prefix left by ``nn.DataParallel``.

        Raises:
            RuntimeError: for any other checkpoint layout.
        """
        state_dict = None
        if isinstance(checkpoint, dict):
            for key in ('model', 'model_state_dict', 'state_dict'):
                if isinstance(checkpoint.get(key), dict):
                    state_dict = checkpoint[key]
                    break
            else:
                # Some checkpoints save raw weights at top-level. An empty dict
                # would "load" nothing, so it is rejected below.
                if checkpoint and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
                    state_dict = checkpoint
        if state_dict is None:
            raise RuntimeError('Unsupported checkpoint format')

        prefix = 'module.'
        if any(k.startswith(prefix) for k in state_dict):
            # Strip only the leading DataParallel prefix; str.replace would also
            # mangle a nested submodule that happens to be named "module".
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }
        return state_dict

    @staticmethod
    def _resolve_architecture(checkpoint, path):
        """Return 'simple2d', 'mil2p5d' or 'rsna2p5d' for a checkpoint.

        Prefers the ``'architecture'`` entry the training loop stores in each
        checkpoint; falls back to filename keywords (legacy checkpoints), with
        RSNA2p5D as the default.
        """
        arch = checkpoint.get('architecture') if isinstance(checkpoint, dict) else None
        if not isinstance(arch, str) or not arch:
            filename = os.path.basename(path).lower()
            if 'simple2d' in filename:
                arch = 'simple2d'
            elif 'mil' in filename:
                arch = 'mil2p5d'
            else:
                arch = 'rsna2p5d'
        if arch in ('simple2d', 'mil2p5d'):
            return arch
        # Mirrors the training-side naming: any other tag is the RSNA2p5D model.
        return 'rsna2p5d'

    @staticmethod
    def _build_model(arch):
        """Instantiate the (untrained) network for an architecture tag."""
        if arch == 'simple2d':
            return Simple2D(num_classes=len(Config.LABEL_COLS))
        if arch == 'mil2p5d':
            return MIL2p5D(num_classes=len(Config.LABEL_COLS))
        return RSNA2p5D()

    @staticmethod
    def _load_weights(model, state_dict, path):
        """Load ``state_dict`` strictly, falling back to a checked non-strict load.

        The non-strict path is needed for EMA checkpoints written by the training
        loop, which keep only floating-point tensors (integer buffers are absent).
        A checkpoint sharing no key with the model is rejected instead of
        silently leaving the model at its initial weights.

        Raises:
            RuntimeError: if no checkpoint key matches the model.
        """
        try:
            model.load_state_dict(state_dict, strict=True)
            return
        except Exception:
            pass
        if not set(model.state_dict().keys()) & set(state_dict.keys()):
            raise RuntimeError('checkpoint has no parameters matching the model')
        result = model.load_state_dict(state_dict, strict=False)
        missing_params = [n for n, _ in model.named_parameters() if n not in state_dict]
        print(
            f"[warn] non-strict load for {os.path.basename(path)}: "
            f"{len(result.missing_keys)} missing / {len(result.unexpected_keys)} unexpected keys, "
            f"{len(missing_params)} learnable parameters left at initial values"
        )

    @classmethod
    def _load_model(cls, path, device):
        """Build the model for ``path``, load its weights and prepare it for eval."""
        # NOTE(security): weights_only=False unpickles arbitrary objects; only use
        # trusted checkpoints produced by this pipeline.
        checkpoint = torch.load(path, map_location=device, weights_only=False)
        state_dict = cls._extract_state_dict(checkpoint)
        arch = cls._resolve_architecture(checkpoint, path)
        model = cls._build_model(arch).to(device)
        cls._load_weights(model, state_dict, path)
        model.eval()
        try:
            model = model.to(memory_format=torch.channels_last)
        except Exception:
            # channels_last is only a speed optimization; keep default layout.
            pass
        return model

    def predict_single(self, series_path):
        """Predict one series from a 32x384x384 volume built directly from its DICOMs.

        Args:
            series_path: directory of the series, passed to the shared preprocessor.

        Returns:
            1-D float32 numpy array with one probability per label.

        Raises:
            RuntimeError: if the ensemble holds no loaded model.
        """
        if not self.models:
            raise RuntimeError('ModelEnsemble has no loaded models')
        vol = process_dicom_series_safe(series_path, target_shape=(32, 384, 384))
        x = torch.from_numpy(vol).unsqueeze(0)  # [1,32,384,384]
        # Ensure dtype matches FP32 weights; AMP will downcast compute where safe
        x = x.to(device=self.device, dtype=torch.float32, non_blocking=True)
        try:
            x = x.to(memory_format=torch.channels_last)
        except Exception:
            pass
        # Vectorized TTA across models (logit averaging with L<->R swap for H-flip)
        probs = predict_with_tta(self.models, x)
        # Upcast in case the TTA path returns half-precision probabilities.
        return probs.float().cpu().numpy()[0]

class InferenceDICOMProcessor:
    """Stateless placeholder for a DICOM processor used at inference time.

    It performs no preprocessing itself; an instance is only created so the
    module-level ``processor`` slot is populated once initialization has run.
    """

    def __init__(self):
        """No-op constructor: the object carries no state."""

# Global variables for model ensemble.
# Both start as None and are filled in by initialize_models(); predict() treats
# "both still None" as "initialization has not run yet".
model_ensemble, processor = None, None

def initialize_models():
    """Discover Stage 2 checkpoints and build the global inference ensemble.

    Looks for known checkpoint filenames in every directory listed in
    ``Config.MODEL_DIRS``:
      * architecture-specific names for folds 0-4;
      * the legacy ``stage2_fold_{k}_best.pth`` names.
    Each physical file is used at most once.

    Sets the module globals:
      * ``model_ensemble`` -- a ``ModelEnsemble`` holding at least one loaded
        model, or ``None`` (dummy predictions are then used);
      * ``processor`` -- always set, which marks initialization as done for
        ``predict``.
    """
    global model_ensemble, processor

    print("Initializing models...")

    # Candidate filenames, architecture-specific names first, then the legacy
    # naming kept for backward compatibility (same order as the explicit list).
    candidate_names = [
        f'{prefix}_fold_{fold}_best.pth'
        for prefix in ('stage2_simple2d', 'stage2_rsna', 'stage2_mil', 'stage2')
        for fold in range(5)
    ]

    available_models = []
    seen = set()
    for d in getattr(Config, 'MODEL_DIRS', None) or []:
        if not isinstance(d, str) or not d:
            continue
        for name in candidate_names:
            p = os.path.join(d, name)
            # Deduplicate by resolved path so a directory listed twice (or via
            # a symlink) does not double-weight its checkpoints in the ensemble.
            key = os.path.realpath(p)
            if key in seen or not os.path.exists(p):
                continue
            seen.add(key)
            available_models.append(p)

    model_ensemble = None
    if not available_models:
        print("Warning: No trained models found! Using dummy predictions.")
    else:
        try:
            ensemble = ModelEnsemble(available_models, InferenceConfig.DEVICE)
        except Exception as e:
            print(f"Error initializing models: {e}")
        else:
            if ensemble.models:
                model_ensemble = ensemble
                print("Models initialized successfully!")
            else:
                # An empty ensemble would make every prediction raise; use the
                # explicit dummy path instead.
                print("Warning: No checkpoint could be loaded! Using dummy predictions.")

    processor = InferenceDICOMProcessor()

def _fallback_predictions(series_path: str, series_id: str) -> np.ndarray:
    """Heuristic per-label probabilities used when no trained ensemble is loaded.

    Reads the ``Modality`` tag of the first ``.dcm`` file found under
    ``series_path``:
      * CT/MR series get a baseline of 0.1, any other modality gets 0.05;
      * Gaussian noise (sd 0.02) is added;
      * the result is clipped to ``[0.001, 0.999]``.
    Returns a constant 0.5 vector when the series contains no ``.dcm`` files.
    """
    print(f"Using fallback prediction for {series_id}")
    n_labels = len(InferenceConfig.LABEL_COLS)

    dicom_paths = [
        os.path.join(root, fname)
        for root, _, files in os.walk(series_path)
        for fname in files
        if fname.endswith('.dcm')
    ]
    if not dicom_paths:
        # No DICOM files found
        return np.full(n_labels, 0.5)

    ds = pydicom.dcmread(dicom_paths[0], force=True)
    modality = getattr(ds, 'Modality', 'UNKNOWN')
    # The DICOM Modality tag stores CT/MR angiography as plain 'CT'/'MR';
    # the non-standard 'CTA'/'MRA' spellings are still accepted.
    base_prob = 0.1 if modality in ('CT', 'MR', 'CTA', 'MRA') else 0.05
    predictions = np.random.normal(base_prob, 0.02, n_labels)
    return np.clip(predictions, 0.001, 0.999)


def predict(series_path: str) -> pl.DataFrame:
    """Make predictions for the competition API.

    Args:
        series_path: directory of one DICOM series; its basename is used as the
            series id in log messages.

    Returns:
        One-row polars DataFrame with a float column per entry of
        ``InferenceConfig.LABEL_COLS``. The id column is dropped, as the API
        requires. Any failure yields 0.5 for every label instead of raising.

    Side effects:
        Lazily initializes the global ensemble on first call and deletes
        ``/kaggle/shared`` after every call to free disk space.
    """
    global model_ensemble, processor

    series_id = os.path.basename(series_path)
    n_labels = len(InferenceConfig.LABEL_COLS)
    schema = [InferenceConfig.ID_COL, *InferenceConfig.LABEL_COLS]

    try:
        # Initialize models on first call (lazy loading). Done inside the try so
        # a failed initialization still produces a default prediction row.
        if model_ensemble is None and processor is None:
            initialize_models()

        if model_ensemble is not None:
            # Use trained ensemble
            predictions = model_ensemble.predict_single(series_path)
        else:
            predictions = _fallback_predictions(series_path, series_id)

        predictions = np.asarray(predictions, dtype=np.float64).reshape(-1)
        if predictions.size != n_labels:
            raise ValueError(f"expected {n_labels} predictions, got {predictions.size}")
        # Keep the submission valid even if a model emits NaN/inf or values
        # slightly outside [0, 1].
        predictions = np.clip(
            np.nan_to_num(predictions, nan=0.5, posinf=1.0, neginf=0.0), 0.0, 1.0
        )

        prediction_df = pl.DataFrame(
            data=[[series_id] + predictions.tolist()],
            schema=schema,
            orient='row',
        )

    except Exception as e:
        print(f"Error processing {series_id}: {e}")
        # Return safe default predictions
        prediction_df = pl.DataFrame(
            data=[[series_id] + [0.5] * n_labels],
            schema=schema,
            orient='row',
        )

    # IMPORTANT: Remove SeriesInstanceUID before returning (API requirement)
    prediction_df = prediction_df.drop(InferenceConfig.ID_COL)

    # IMPORTANT: Disk cleanup to prevent "out of disk space" errors
    shutil.rmtree('/kaggle/shared', ignore_errors=True)

    return prediction_df


# ====================================================
# SERVER EXECUTION
# ====================================================

# Wrap the module-level `predict` callable in the competition's inference
# server. It is built at import time so the main block below can either serve
# it (competition rerun) or drive it through the local gateway.
inference_server = kaggle_evaluation.rsna_inference_server.RSNAInferenceServer(
    predict,
)

print("✅ Inference and submission pipeline loaded")

# ====================================================
# CELL 6: MAIN EXECUTION
# ====================================================

def _validate_submission(pdf, label_cols, id_col):
    """GO / NO-GO checks on a submission frame.

    Args:
        pdf: pandas DataFrame loaded from the submission file.
        label_cols: expected label columns, in output order.
        id_col: optional identifier column; placed first when present.

    Returns:
        ``pdf`` reordered (extras dropped) to ``[id_col] + label_cols`` when
        ``id_col`` is present and all expected columns exist; otherwise ``pdf``
        unchanged.

    Raises:
        AssertionError: if the columns do not match, any label is NaN, or a
            probability lies outside ``[0, 1]``.
    """
    label_cols = list(label_cols)
    expected_cols = label_cols
    if id_col in pdf.columns:
        expected_cols = [id_col] + label_cols
        # Reorder to expected order if needed -- but only when every expected
        # column exists, so a missing one is reported by the assertion below
        # instead of an opaque KeyError.
        if set(expected_cols).issubset(pdf.columns):
            pdf = pdf[expected_cols]
    assert list(pdf.columns) == expected_cols, (
        f"Column mismatch.\nExpected ({len(expected_cols)}): {expected_cols}\n"
        f"Got ({len(pdf.columns)}): {list(pdf.columns)}"
    )
    # Validate only label columns for NaNs and range
    label_values = pdf[label_cols]
    assert not label_values.isna().any().any(), "Submission contains NaNs."
    if not label_values.empty:
        # min()/max() of a zero-size array raise, so skip this for 0 rows.
        values = label_values.to_numpy()
        vmin, vmax = values.min(), values.max()
        assert 0.0 <= vmin <= 1.0 and 0.0 <= vmax <= 1.0, (
            f"Out-of-range probabilities: min={vmin}, max={vmax}"
        )
    return pdf


if __name__ == "__main__":
    if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
        # Production mode - serve the API
        print("Starting inference server...")
        inference_server.serve()
    else:
        # Local/dev mode - default to inference only unless explicitly enabled
        print("Ready for Stage 2 training!")
        print("Set Config.TRAIN_ON_START=True to run training.")
        if getattr(Config, 'TRAIN_ON_START', False):
            main_training()

        # Then run local testing through the gateway
        print("Running local gateway for testing...")
        inference_server.run_local_gateway()

        # Display results if available + sanity checks + CSV
        results_path = '/kaggle/working/submission.parquet'
        if os.path.exists(results_path):
            results_df = pl.read_parquet(results_path)

            # GO / NO-GO checks (pandas is already imported at the top)
            pdf = _validate_submission(
                results_df.to_pandas(),
                list(Config.LABEL_COLS),
                getattr(Config, 'ID_COL', 'SeriesInstanceUID'),
            )
            test_meta = '/kaggle/working/test_series.parquet'
            if os.path.exists(test_meta):
                test_n = int(pl.read_parquet(test_meta).height)
                assert len(pdf) == test_n, f"Row count mismatch: got {len(pdf)}, expected {test_n}"

            print("Submission preview:")
            print(results_df.head())
            out_csv = '/kaggle/working/submission.csv'
            pdf.to_csv(out_csv, index=False)
            print(f"💾 Saved: {out_csv}")
        else:
            print("⚠️ submission.parquet not found. Did inference write it?")

