In [1]:
# ====================================================
# CELL 1: IMPORTS & CONFIG
# ====================================================

import os
import shutil
import numpy as np
import pandas as pd
import polars as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
import cv2
import pydicom
import nibabel as nib
from scipy import ndimage
from scipy.ndimage import label, center_of_mass
from PIL import Image
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score
import kaggle_evaluation.rsna_inference_server
from collections import defaultdict
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import gc

  data = fetch_version_info()


In [None]:
# Competition Configuration
class Config:
    """Class-level settings shared by the Stage 2 pipeline.

    Everything is read straight off the class (``Config.NAME``). Entries that
    are commented out are optional overrides; code that consults them is
    expected to fall back when they are absent (for example
    ``STAGE2_CACHE_EXTERNAL_DIR`` is looked up with ``getattr``).
    """

    # ---- Input locations -------------------------------------------------
    TRAIN_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train.csv'
    SERIES_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/series/'
    SEGMENTATION_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/segmentations/'

    # ---- Stage 2 training knobs -------------------------------------------
    ROI_SIZE = (224, 224)
    ROIS_PER_SERIES = 5
    BATCH_SIZE = 2
    EPOCHS = 3
    LEARNING_RATE = 1e-4
    N_FOLDS = 1

    # ---- Competition schema -----------------------------------------------
    ID_COL = 'SeriesInstanceUID'
    LABEL_COLS = [
        'Left Infraclinoid Internal Carotid Artery',
        'Right Infraclinoid Internal Carotid Artery',
        'Left Supraclinoid Internal Carotid Artery',
        'Right Supraclinoid Internal Carotid Artery',
        'Left Middle Cerebral Artery',
        'Right Middle Cerebral Artery',
        'Anterior Communicating Artery',
        'Left Anterior Cerebral Artery',
        'Right Anterior Cerebral Artery',
        'Left Posterior Communicating Artery',
        'Right Posterior Communicating Artery',
        'Basilar Tip',
        'Other Posterior Circulation',
        'Aneurysm Present',
    ]

    # ---- Hardware / precision ---------------------------------------------
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    MIXED_PRECISION = True

    # ---- ROI cache (per-series ROI lists) ---------------------------------
    STAGE2_CACHE_DIR = '/kaggle/working/stage2_cache'
    # Optional: reuse Stage 1 external cache volumes directly for exact preprocessing parity
    # STAGE1_EXTERNAL_CACHE_DIR = '/kaggle/input/rsna2025-v2-intracranial-aneurysm-detection-nb153/stage1_AneurysmNet_prebuilt_v2'

    # ---- Debugging ----------------------------------------------------------
    DEBUG_MODE = False
    DEBUG_SAMPLES = 0

    # ---- External ROI/cache reuse (Kaggle dataset with precomputed outputs) --
    # ROIS_EXTERNAL_DIR = '/kaggle/input/rsna2025-brainnet-aneurysm-rois/rois'
    # STAGE2_CACHE_EXTERNAL_DIR = '/kaggle/input/rsna2025-brainnet-aneurysm-rois/stage2_cache'

    # If a cached training_df exists, reuse it and skip the long ROI extraction.
    REUSE_EXISTING_ROIS = True

    # ---- Direct volume mode (Stage 0 32x384x384 volumes) ------------------
    DIRECT_VOLUME_MODE = True
    # STAGE0_PREBUILT_ROOT = '/kaggle/input/rsna2025-v2-intracranial-aneurysm-detection-nb153/stage1_AneurysmNet_prebuilt_v2'

    # ---- DataLoader throughput --------------------------------------------
    NUM_WORKERS = 6
    PREFETCH_FACTOR = 6
    PIN_MEMORY = True
    PERSISTENT_WORKERS = True

    # ---- Preprocessed-volume cache ----------------------------------------
    CACHE_VOLUMES = True
    CACHE_DIR = '/kaggle/working/stage2_cache_vols'
    # 'uint8' (~4.7MB/series) or 'float16' (~9.4MB/series)
    CACHE_DTYPE = 'uint8'
    # Soft cap in GiB; new volumes are not written once it is exceeded.
    CACHE_MAX_GB = 20.0
    # Set True to log cache hits/writes/skips.
    CACHE_VERBOSE = False

# Report which device the configuration resolved to.
print(f"✅ Configuration loaded - Device: {Config.DEVICE}")

# Speed-friendly backend settings.
# All three are applied inside a single best-effort try: if one raises (for
# example on a torch build that lacks the attribute), the remaining ones are
# skipped and the notebook continues with default settings.
try:
    # Let cuDNN benchmark and pick convolution algorithms per input shape.
    torch.backends.cudnn.benchmark = True
    # Allow TF32 for CUDA matmuls.
    torch.backends.cuda.matmul.allow_tf32 = True
    # Request the 'high' float32 matmul precision mode.
    torch.set_float32_matmul_precision('high')
except Exception:
    pass

# ====================================================
# CELL 1.1: LIGHTWEIGHT DICOM PREPROCESSOR (32x384x384)
# ====================================================

class DICOMPreprocessorKaggle:
    """Minimal, memory-safe DICOM → (D, H, W) volume preprocessor (offline, no deps).

    ``process_series`` reads every ``.dcm`` file below a series directory,
    orders the slices, scales intensities to [0, 1], resamples to
    ``target_shape`` and, when ``Config.CACHE_VOLUMES`` is truthy, caches the
    result under ``Config.CACHE_DIR``.

    Cache files are named ``<series id>_<D>x<H>x<W>_{u8,f16}.npy``. For the
    default target shape this is the same ``_32x384x384`` name as before, so
    existing caches stay valid; other target shapes no longer collide with it.
    """

    # CT intensity window (center, width) applied before scaling to [0, 1].
    CT_WINDOW_CENTER = 300.0
    CT_WINDOW_WIDTH = 700.0
    # Non-CT slices are z-scored per slice and clipped to +/- this value.
    ZSCORE_CLIP = 3.0
    # Lookup order when reading the cache.
    CACHE_SUFFIXES = ('_u8.npy', '_f16.npy')
    # If 8-bit quantisation loses more than this, store float16 instead.
    QUANT_ERR_LIMIT = 0.02

    def __init__(self, target_shape=(32, 384, 384)):
        self.target_depth, self.target_height, self.target_width = target_shape

    def process_series(self, series_path: str) -> np.ndarray:
        """Load one DICOM series as a float32 volume of the target shape.

        Args:
            series_path: directory holding the series' ``.dcm`` files
                (searched recursively).

        Returns:
            float32 array of shape (target_depth, target_height, target_width).
            An all-zero volume is returned when no readable DICOM with pixel
            data is found.
        """
        caching = bool(getattr(Config, 'CACHE_VOLUMES', False))
        if caching:
            cached = self._load_cached_volume(series_path)
            if cached is not None:
                return cached

        dicoms = self._read_dicoms(series_path)
        if not dicoms:
            return np.zeros(self._target_shape(), dtype=np.float32)

        volume = self._build_volume(self._sort_dicoms(dicoms))
        out = self._resample_and_rescale(volume)
        if caching:
            self._save_cached_volume(series_path, out)
        return out.astype(np.float32)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _target_shape(self):
        """Return the configured output shape as a (D, H, W) tuple."""
        return (self.target_depth, self.target_height, self.target_width)

    def _cache_base(self, series_path):
        """Return the cache path for a series, without the dtype suffix."""
        sid = os.path.basename(series_path.rstrip('/'))
        tag = f"{self.target_depth}x{self.target_height}x{self.target_width}"
        return os.path.join(Config.CACHE_DIR, f"{sid}_{tag}")

    def _load_cached_volume(self, series_path):
        """Return the cached float32 volume, or None on a miss or unusable file."""
        try:
            os.makedirs(Config.CACHE_DIR, exist_ok=True)
            base = self._cache_base(series_path)
        except Exception:
            return None
        for suffix in self.CACHE_SUFFIXES:
            path = base + suffix
            if not os.path.exists(path):
                continue
            try:
                cached = np.load(path, mmap_mode='r')
                # A stale or foreign file of the wrong shape counts as a miss.
                if tuple(cached.shape) != self._target_shape():
                    continue
                vol = cached.astype(np.float32)
                if cached.dtype == np.uint8:
                    vol = vol / 255.0
            except Exception:
                continue
            if getattr(Config, 'CACHE_VERBOSE', False):
                print(f"[CACHE] hit: {os.path.basename(path)}")
            return vol
        return None

    def _read_dicoms(self, series_path):
        """Read every ``.dcm`` file under ``series_path`` that carries PixelData."""
        dicoms = []
        for root, _, files in os.walk(series_path):
            for fname in files:
                if not fname.endswith('.dcm'):
                    continue
                try:
                    ds = pydicom.dcmread(os.path.join(root, fname), force=True)
                except Exception:
                    continue
                if hasattr(ds, 'PixelData'):
                    dicoms.append(ds)
        return dicoms

    def _sort_dicoms(self, dicoms):
        """Order slices along the patient-space normal, else by InstanceNumber."""
        try:
            orient = np.array(dicoms[0].ImageOrientationPatient, dtype=np.float32)
            normal = np.cross(orient[:3], orient[3:])

            def position_along_normal(ds):
                ipp = np.array(getattr(ds, 'ImagePositionPatient', [0, 0, 0]), dtype=np.float32)
                return float(np.dot(ipp, normal))

            return sorted(dicoms, key=position_along_normal)
        except Exception:
            return sorted(dicoms, key=lambda ds: getattr(ds, 'InstanceNumber', 0))

    def _build_volume(self, dicoms):
        """Stack per-slice normalised frames into a (N, base_h, base_w) float32 volume."""
        base_h = int(getattr(dicoms[0], 'Rows', 256))
        base_w = int(getattr(dicoms[0], 'Columns', 256))
        lo = self.CT_WINDOW_CENTER - self.CT_WINDOW_WIDTH / 2.0
        hi = self.CT_WINDOW_CENTER + self.CT_WINDOW_WIDTH / 2.0
        is_ct = (getattr(dicoms[0], 'Modality', '') or '').upper() == 'CT'
        zc = self.ZSCORE_CLIP

        slices = []
        for ds in dicoms:
            try:
                fr = ds.pixel_array
            except Exception:
                continue
            # Multi-frame objects are flattened into a list of 2D frames.
            if fr.ndim >= 3:
                frames = fr.reshape(int(np.prod(fr.shape[:-2])), fr.shape[-2], fr.shape[-1])
            else:
                frames = fr[np.newaxis, ...]
            inverted = getattr(ds, 'PhotometricInterpretation', 'MONOCHROME2') == 'MONOCHROME1'
            slope = float(getattr(ds, 'RescaleSlope', 1.0))
            intercept = float(getattr(ds, 'RescaleIntercept', 0.0))
            for sl in frames:
                sl = sl.astype(np.float32)
                if inverted:
                    sl = sl.max() - sl
                sl = sl * slope + intercept
                if sl.shape != (base_h, base_w):
                    sl = cv2.resize(sl, (base_w, base_h))
                if is_ct:
                    s = np.clip(sl, lo, hi)
                    s = (s - lo) / (hi - lo + 1e-6)
                else:
                    mean = float(sl.mean())
                    std = float(sl.std() + 1e-6)
                    s = np.clip((sl - mean) / std, -zc, zc)
                    s = (s + zc) / (2.0 * zc)
                slices.append(s.astype(np.float32))

        if slices:
            return np.stack(slices, axis=0)
        return np.zeros((1, base_h, base_w), dtype=np.float32)

    def _resample_and_rescale(self, volf):
        """Resample to the target shape and stretch the 1-99 percentile range to [0, 1]."""
        depth = volf.shape[0]
        idx = np.linspace(0, max(depth - 1, 0), num=self.target_depth).astype(int)
        picked = volf[idx]
        out = np.empty(self._target_shape(), dtype=np.float32)
        for i in range(self.target_depth):
            out[i] = cv2.resize(picked[i], (self.target_width, self.target_height))
        p1, p99 = np.percentile(out, [1, 99])
        if p99 > p1:
            out = np.clip(out, p1, p99)
            out = (out - p1) / (p99 - p1 + 1e-8)
        return np.nan_to_num(out, nan=0.0, posinf=1.0, neginf=0.0)

    def _cache_dir_bytes(self):
        """Total size in bytes of the entries directly inside the cache directory."""
        total = 0
        with os.scandir(Config.CACHE_DIR) as entries:
            for entry in entries:
                try:
                    total += entry.stat().st_size
                except OSError:
                    pass
        return total

    def _save_cached_volume(self, series_path, out):
        """Best-effort write of ``out`` to the cache, honouring ``Config.CACHE_MAX_GB``.

        Any failure (read-only directory, full disk, ...) is ignored so that
        preprocessing itself never fails because of caching.
        """
        verbose = bool(getattr(Config, 'CACHE_VERBOSE', False))
        try:
            os.makedirs(Config.CACHE_DIR, exist_ok=True)

            # Soft size cap. Previously the "cap reached" signal was raised and
            # swallowed inside the same try block, so the cap never took effect.
            cap_gb = getattr(Config, 'CACHE_MAX_GB', None)
            if cap_gb is not None:
                try:
                    used = self._cache_dir_bytes()
                except Exception:
                    used = 0
                if used > cap_gb * (1024 ** 3):
                    if verbose:
                        print("[CACHE] size cap reached; skip saving")
                    return

            base = self._cache_base(series_path)
            name = os.path.basename(base)
            if getattr(Config, 'CACHE_DTYPE', 'float16') == 'uint8':
                # Quantize to 8-bit [0,255] after normalization.
                arr8 = (out * 255.0).round().astype(np.uint8)
                try:
                    diff_max = float(np.max(np.abs(out - (arr8.astype(np.float32) / 255.0))))
                except Exception:
                    diff_max = None
                if diff_max is None:
                    np.save(base + '_u8.npy', arr8, allow_pickle=False)
                    if verbose:
                        print(f"[CACHE] wrote u8 (no-check): {name}_u8.npy")
                elif diff_max > self.QUANT_ERR_LIMIT:
                    # Quantisation too lossy: keep float16 instead.
                    np.save(base + '_f16.npy', out.astype(np.float16), allow_pickle=False)
                    if verbose:
                        print(f"[CACHE] wrote f16 (fallback): {name}_f16.npy")
                else:
                    np.save(base + '_u8.npy', arr8, allow_pickle=False)
                    if verbose:
                        print(f"[CACHE] wrote u8: {name}_u8.npy (max_err={diff_max:.4f})")
            else:
                np.save(base + '_f16.npy', out.astype(np.float16), allow_pickle=False)
                if verbose:
                    print(f"[CACHE] wrote f16: {name}_f16.npy")
        except Exception:
            pass

def process_dicom_series_safe(series_path: str, target_shape=(32,384,384)) -> np.ndarray:
    """Preprocess one series, then release memory regardless of the outcome.

    Builds a fresh ``DICOMPreprocessorKaggle`` for ``target_shape`` and returns
    its ``process_series`` result. Errors from preprocessing propagate to the
    caller; errors raised by the cleanup steps are ignored.
    """
    def _free_gpu_cache():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    try:
        return DICOMPreprocessorKaggle(target_shape).process_series(series_path)
    finally:
        # Each cleanup step is independent and best-effort.
        for cleanup in (_free_gpu_cache, gc.collect):
            try:
                cleanup()
            except Exception:
                pass


# ====================================================
# CELL 2: DATA LOADING & ROI EXTRACTION
# ====================================================

class Simple3DSegmentationNet(nn.Module):
    """Placeholder network: accepts any constructor arguments and returns its input unchanged."""

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted and ignored.
        super(Simple3DSegmentationNet, self).__init__()
        self.dummy = nn.Identity()

    def forward(self, x):
        passthrough = self.dummy
        return passthrough(x)

class Stage1Predictor:
    """Compatibility stub for the Stage 1 predictor; it holds no model."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments."""

    def predict_segmentation_with_volume(self, series_path):
        """Return an all-zero (1,1,1) mask and an all-zero (32,384,384) volume.

        ``series_path`` is ignored. Not used in the direct-volume path; kept
        minimal for compatibility.
        """
        empty_mask = np.zeros((1, 1, 1), dtype=np.float32)
        blank_volume = np.zeros((32, 384, 384), dtype=np.float32)
        return empty_mask, blank_volume

class ROIExtractor:
    """Research-backed ROI extraction with adaptive count and quality filtering"""
    def __init__(self, stage1_predictor, roi_size=(224, 224)):
        self.stage1_predictor = stage1_predictor
        self.roi_size = roi_size
        self.processor = SimpleDICOMProcessor()

        # Research-backed thresholds
        # Relaxed thresholds to avoid over-pruning when Stage 1 is weak
        self.min_confidence_threshold = 0.15
        self.high_confidence_threshold = 0.5
        self.max_rois_per_series = getattr(Config, 'ROIS_PER_SERIES', 3)
        # Post-process controls
        self.border_margin = 2            # suppress edge activations near skull
        self.min_region_size = 6         # minimum connected component size (pixels)
        self.morph_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    def extract_top3_rois(self, series_path):
        """Extract 0-5 ROIs based on segmentation quality (research-backed)"""
        # Cache ROI results per series to avoid recomputation
        try:
            # Prefer external cache dir when available
            cache_root = getattr(Config, 'STAGE2_CACHE_EXTERNAL_DIR', '')
            if not isinstance(cache_root, str) or not os.path.isdir(cache_root):
                cache_root = Config.STAGE2_CACHE_DIR
            os.makedirs(cache_root, exist_ok=True)
            sid = os.path.basename(series_path)
            cache_path = os.path.join(cache_root, f"{sid}_rois.npy")
            if os.path.exists(cache_path):
                arr = np.load(cache_path, allow_pickle=True)
                return list(arr)
        except Exception:
            cache_path = None
        rois = self.extract_adaptive_rois(series_path)
        try:
            if cache_path is not None:
                np.save(cache_path, np.array(rois, dtype=object), allow_pickle=True)
        except Exception:
            pass
        return rois

    def extract_adaptive_rois(self, series_path):
        """Extract 0-5 ROIs based on segmentation quality (research-backed)"""
        try:
            print(f"🔍 DEBUG: Quality-based ROI extraction for {os.path.basename(series_path)}")
            
            # Get Stage 1 seg mask and the preprocessed volume (avoid reloading original DICOMs here)
            seg_mask, original_volume = self.stage1_predictor.predict_segmentation_with_volume(series_path)
            print(f"🔍 DEBUG: Segmentation mask shape: {seg_mask.shape}; Volume shape: {original_volume.shape}")
            
            # STEP 1: Assess overall segmentation quality
            seg_quality = self._assess_segmentation_quality(seg_mask)
            print(f"🔍 DEBUG: Segmentation quality score: {seg_quality:.3f}")
            
            # STEP 2: If segmentation is poor, still attempt candidate extraction; fallback only if none
            low_quality = seg_quality < self.min_confidence_threshold
            if low_quality:
                print(f"🔍 DEBUG: Low segmentation quality ({seg_quality:.3f} < {self.min_confidence_threshold}), attempting candidate extraction anyway")
            
            # STEP 4: Extract ROIs with confidence-based filtering
            roi_candidates = self._find_quality_based_rois(seg_mask, original_volume)
            
            if low_quality and not roi_candidates:
                print("🔍 DEBUG: No candidates under low-quality mask, using volume-based fallback")
                return self._get_quality_fallback_rois_from_volume(original_volume, self.max_rois_per_series)

            # STEP 5: Adaptive ROI count
            selected_rois = self._select_adaptive_rois(roi_candidates, seg_quality, original_volume)
            
            print(f"🔍 DEBUG: Selected {len(selected_rois)} ROIs based on quality assessment")
            return selected_rois
            
        except Exception as e:
            print(f"❌ Error in quality-based ROI extraction: {e}")
            return self._get_emergency_fallback_rois()
    
    def _assess_segmentation_quality(self, seg_mask):
        """Assess segmentation quality using connected components and border penalties."""
        try:
            D, H, W = seg_mask.shape
            largest_area_frac = 0.0
            largest_mean_conf = 0.0
            total_components = 0
            border_touch_penalty = 0.0

            for z in range(D):
                sm = seg_mask[z]
                # suppress borders
                sm_proc = sm.copy()
                sm_proc[:self.border_margin, :] = 0
                sm_proc[-self.border_margin:, :] = 0
                sm_proc[:, :self.border_margin] = 0
                sm_proc[:, -self.border_margin:] = 0

               # Adaptive thresholding based on actual max values
                max_val = float(sm_proc.max())
                if max_val > 0.3:
                    thr = max(0.05, 0.3 * max_val)
                elif max_val > 0.1:
                    thr = max(0.03, 0.4 * max_val)
                else:
                    thr = max(0.02, 0.5 * max_val)
                binmask = (sm_proc > thr).astype(np.uint8)
                if binmask.max() == 0:
                    continue
                # small opening to remove speckle
                binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, self.morph_kernel)

                labeled, n = label(binmask)
                if n == 0:
                    continue
                total_components += int(n)

                # evaluate components
                for comp_id in range(1, n + 1):
                    comp = (labeled == comp_id)
                    comp_size = int(comp.sum())
                    if comp_size < self.min_region_size:
                        continue
                    mean_conf = float(sm[comp].mean())
                    area_frac = comp_size / float(H * W)
                    if area_frac > largest_area_frac:
                        largest_area_frac = area_frac
                    if mean_conf > largest_mean_conf:
                        largest_mean_conf = mean_conf

                    # simple border-touch penalty if component abuts image edge
                    ys, xs = np.where(comp)
                    if ys.size > 0:
                        if (ys.min() <= self.border_margin or ys.max() >= H - self.border_margin - 1 or
                            xs.min() <= self.border_margin or xs.max() >= W - self.border_margin - 1):
                            border_touch_penalty += 0.02

            # compose quality score
            area_score = min(largest_area_frac / 0.02, 1.0)  # cap around ~2% of slice (aneurysm-sized)
            comp_penalty = min(0.1, 0.0015 * total_components) + min(0.1, border_touch_penalty)
            quality_score = max(0.0, 0.6 * largest_mean_conf + 0.4 * area_score - comp_penalty)

            # robust floor based on global mask stats to avoid spurious 0.0 quality
            max_val = float(seg_mask.max())
            mean_val = float(seg_mask.mean())
            if max_val >= 0.55:
                quality_score = max(quality_score, 0.35)
            elif max_val >= 0.45:
                quality_score = max(quality_score, 0.25)
            elif mean_val >= 0.25:
                quality_score = max(quality_score, 0.22)

            return float(quality_score)
        except Exception:
            return 0.1
    
    def _find_quality_based_rois(self, seg_mask, original_volume):
        """Find ROI candidates with confidence scores (no hardcoded count)"""
        print("🔍 DEBUG: Finding quality-based ROI candidates...")
        
        # Resize segmentation mask to match original volume
        if seg_mask.shape != original_volume.shape:
            print("🔍 DEBUG: Resizing segmentation mask with cv2...")
            seg_mask_resized = np.zeros(original_volume.shape, dtype=np.float32)
            for i in range(min(seg_mask.shape[0], original_volume.shape[0])):
                if i < seg_mask.shape[0]:
                    resized_slice = cv2.resize(
                        seg_mask[i],
                        (original_volume.shape[2], original_volume.shape[1])
                    )
                    seg_mask_resized[i] = resized_slice
        else:
            seg_mask_resized = seg_mask
        
        # 3D peak proposals first (relative peak logic; does not lower thresholds)
        roi_candidates = self._proposals_from_3d_peaks(seg_mask_resized)
        if len(roi_candidates) == 0:
            # Fall back to 2D slice-wise CC method
            roi_candidates = []
        
        H, W = original_volume.shape[1], original_volume.shape[2]
        for slice_idx in range(seg_mask_resized.shape[0]):
            slice_mask = seg_mask_resized[slice_idx].copy()

            # Suppress borders to avoid skull/edge activations
            slice_mask[:self.border_margin, :] = 0
            slice_mask[-self.border_margin:, :] = 0
            slice_mask[:, :self.border_margin] = 0
            slice_mask[:, -self.border_margin:] = 0

            # Adaptive dynamic threshold tied to local max (aligned with quality assessment)
            max_val = float(slice_mask.max())
            if max_val > 0.2:
                thr = max(self.min_confidence_threshold, 0.25 * max_val)
            elif max_val > 0.1:
                thr = max(0.03, 0.30 * max_val)
            else:
                thr = max(0.02, 0.25 * max_val)
            high_conf_regions = (slice_mask > thr).astype(np.uint8)
            if high_conf_regions.max() == 0:
                # Percentile-based fallback with small dilation to form blobs
                p90 = float(np.percentile(slice_mask, 90))
                if p90 > 0:
                    mask_peaks = (slice_mask >= p90).astype(np.uint8)
                    # small dilation to merge nearby high pixels
                    mask_peaks = cv2.dilate(mask_peaks, self.morph_kernel, iterations=1)
                    labeled_regions, num_regions = label(mask_peaks)
                    for region_id in range(1, num_regions + 1):
                        region_mask = (labeled_regions == region_id)
                        region_size = int(region_mask.sum())
                        if region_size < 3:
                            continue
                        ys, xs = np.where(region_mask)
                        if ys.size == 0:
                            continue
                        # Skip borders
                        if (ys.min() <= self.border_margin or ys.max() >= H - self.border_margin - 1 or
                            xs.min() <= self.border_margin or xs.max() >= W - self.border_margin - 1):
                            continue
                        com = center_of_mass(region_mask)
                        y, x = int(com[0]), int(com[1])
                        region_confidence = float(slice_mask[region_mask].mean())
                        roi_candidates.append({
                            'slice_idx': slice_idx,
                            'y': y,
                            'x': x,
                            'confidence': region_confidence,
                            'region_size': region_size
                        })
                continue
            # Apply opening only if region is sufficiently large; avoid eroding tiny blobs
            if int(high_conf_regions.sum()) > 50:
                high_conf_regions = cv2.morphologyEx(high_conf_regions, cv2.MORPH_OPEN, self.morph_kernel)

            labeled_regions, num_regions = label(high_conf_regions)
            for region_id in range(1, num_regions + 1):
                region_mask = (labeled_regions == region_id)
                region_size = int(region_mask.sum())
                if region_size < self.min_region_size:
                    continue
                ys, xs = np.where(region_mask)
                if ys.size == 0:
                    continue
                # Skip border-touching components
                if (ys.min() <= self.border_margin or ys.max() >= H - self.border_margin - 1 or
                    xs.min() <= self.border_margin or xs.max() >= W - self.border_margin - 1):
                    continue

                com = center_of_mass(region_mask)
                y, x = int(com[0]), int(com[1])
                region_confidence = float(slice_mask[region_mask].mean())

                roi_candidates.append({
                    'slice_idx': slice_idx,
                    'y': y,
                    'x': x,
                    'confidence': region_confidence,
                    'region_size': region_size
                })
        
        # Sort by confidence (descending)
        if not roi_candidates:
            # Volume-wise peak fallback: pick top maxima per slice (excluding borders)
            print("🔍 DEBUG: No ROI components found; using volume-wise peak fallback")
            D = seg_mask_resized.shape[0]
            peak_candidates = []
            for z in range(D):
                m = seg_mask_resized[z].copy()
                # suppress borders
                m[:self.border_margin, :] = 0
                m[-self.border_margin:, :] = 0
                m[:, :self.border_margin] = 0
                m[:, -self.border_margin:] = 0
                yx = np.unravel_index(np.argmax(m), m.shape)
                y, x = int(yx[0]), int(yx[1])
                conf = float(m[y, x])
                if conf > 0:
                    peak_candidates.append({
                        'slice_idx': z,
                        'y': y,
                        'x': x,
                        'confidence': conf,
                        'region_size': 1
                    })
            # Keep strongest few peaks across volume
            peak_candidates.sort(key=lambda c: c['confidence'], reverse=True)
            roi_candidates.extend(peak_candidates[: max( self.max_rois_per_series * 3, 6)])

        roi_candidates.sort(key=lambda x: x['confidence'], reverse=True)
        
        print(f"🔍 DEBUG: Found {len(roi_candidates)} ROI candidates")
        return roi_candidates

    def _proposals_from_3d_peaks(self, seg_mask_zyx: np.ndarray):
        """3D local-max proposals with seeded relative growth (no absolute threshold lowering)."""
        try:
            D, H, W = seg_mask_zyx.shape
            # Light 3D smoothing to stabilize local maxima
            try:
                sm = ndimage.gaussian_filter(seg_mask_zyx.astype(np.float32), sigma=0.75)
            except Exception:
                sm = seg_mask_zyx.astype(np.float32)
            # 3D local maxima via maximum filter
            footprint = np.ones((3,3,3), dtype=np.uint8)
            max_f = ndimage.maximum_filter(sm, footprint=footprint, mode='nearest')
            peaks = (sm == max_f)
            # Suppress borders
            b = self.border_margin
            if b > 0:
                peaks[:, :b, :] = False; peaks[:, -b:, :] = False
                peaks[:, :, :b] = False; peaks[:, :, -b:] = False
            coords = np.argwhere(peaks)
            if coords.shape[0] == 0:
                return []
            # Rank peaks by value and keep top-K to control cost
            values = sm[peaks]
            order = np.argsort(values)[::-1]
            top_k = min(64, order.size)
            selected = coords[order[:top_k]]
            # Non-maximum suppression by 3D distance
            kept = []
            min_dist = 4.0
            for (cz, cy, cx) in selected:
                if any(((cz-kz)**2 + (cy-ky)**2 + (cx-kx)**2) ** 0.5 < min_dist for kz,ky,kx in kept):
                    continue
                kept.append((int(cz), int(cy), int(cx)))
                if len(kept) >= 64:
                    break
            # Seeded relative growth
            proposals = []
            for cz, cy, cx in kept:
                peak = float(sm[cz, cy, cx])
                if peak <= 0:
                    continue
                rel_thr = max(0.6*peak, 1e-6)  # relative to each peak
                # collect voxels that descend from the peak (thresholded region)
                region = sm >= rel_thr
                labeled, num = ndimage.label(region)
                cid = int(labeled[cz, cy, cx])
                if cid == 0:
                    continue
                comp = (labeled == cid)
                size = int(comp.sum())
                if size < self.min_region_size:
                    continue
                # score = peak * mean(comp)
                conf = peak * float(sm[comp].mean() + 1e-6)
                # project to a representative slice (peak slice)
                ys, xs = np.where(comp[cz])
                if ys.size == 0:
                    # fallback to COM over full comp
                    zc, yc, xc = ndimage.center_of_mass(comp)
                    zc = int(round(zc)); yc = int(round(yc)); xc = int(round(xc))
                    if yc <= self.border_margin or yc >= H - self.border_margin - 1 or xc <= self.border_margin or xc >= W - self.border_margin - 1:
                        continue
                    proposals.append({
                        'slice_idx': int(zc),
                        'y': int(yc),
                        'x': int(xc),
                        'confidence': float(conf),
                        'region_size': size,
                    })
                else:
                    y = int(ys.mean()); x = int(xs.mean())
                    if y <= self.border_margin or y >= H - self.border_margin - 1 or x <= self.border_margin or x >= W - self.border_margin - 1:
                        continue
                    proposals.append({
                        'slice_idx': int(cz),
                        'y': y,
                        'x': x,
                        'confidence': float(conf),
                        'region_size': size,
                    })
            proposals.sort(key=lambda c: c['confidence'], reverse=True)
            return proposals
        except Exception:
            return []
    
    def _select_adaptive_rois(self, roi_candidates, seg_quality, original_volume):
        """Adaptively select ROIs based on segmentation quality (research-backed)"""
        if not roi_candidates:
            print("🔍 DEBUG: No candidates found, using fallback")
            return self._get_quality_fallback_rois_from_volume(original_volume)
        
        # Adaptive selection based on segmentation quality
        if seg_quality >= self.high_confidence_threshold:
            max_rois = self.max_rois_per_series
            min_confidence = 0.3
        elif seg_quality >= self.min_confidence_threshold + 0.2:
            max_rois = self.max_rois_per_series
            min_confidence = 0.2
        else:
            max_rois = self.max_rois_per_series
            min_confidence = 0.05
        
        # Filter and select ROIs
        filtered = [c for c in roi_candidates if c['confidence'] >= min_confidence]
        selected_candidates = filtered[:max_rois]
        # If not enough, top-off with next best candidates
        if len(selected_candidates) < max_rois:
            for c in roi_candidates:
                if c in selected_candidates:
                    continue
                selected_candidates.append(c)
                if len(selected_candidates) >= max_rois:
                    break
        
        # Convert to ROI format
        rois = []
        for i, candidate in enumerate(selected_candidates):
            roi_patch = self._extract_roi_patch(
                original_volume,
                candidate['slice_idx'], 
                candidate['y'], 
                candidate['x']
            )
            
            rois.append({
                'roi_image': roi_patch,
                'slice_idx': candidate['slice_idx'],
                'coordinates': (candidate['y'], candidate['x']),
                'confidence': candidate['confidence'],
                'roi_id': i
            })
        # Ensure at least max_rois via center-based fallback if still short
        if len(rois) < self.max_rois_per_series:
            needed = self.max_rois_per_series - len(rois)
            center_fallbacks = self._get_quality_fallback_rois_from_volume(original_volume, needed)
            rois.extend(center_fallbacks)
        print(f"🔍 DEBUG: Adaptively selected {len(rois)} ROIs (quality: {seg_quality:.3f})")
        return rois[: self.max_rois_per_series]
    
    def _get_quality_fallback_rois(self, series_path, seg_mask):
        """Fallback for poor segmentation quality: generate multiple center-based ROIs"""
        print("🔍 DEBUG: Using quality-aware fallback (multi-center ROIs)")
        original_volume = self._load_efficient_volume(series_path)
        return self._get_quality_fallback_rois_from_volume(original_volume, self.max_rois_per_series)

    def _get_quality_fallback_rois_from_volume(self, original_volume, count: int = 3):
        D, H, W = original_volume.shape
        # Choose slice indices: center and quartiles
        slices = sorted(set([D // 2, max(0, D // 4), min(D - 1, 3 * D // 4)]))
        # Ensure desired count
        while len(slices) < count:
            # Add random slices if needed
            slices.append(np.random.randint(0, D))
            slices = list(dict.fromkeys(slices))
        rois = []
        cy, cx = H // 2, W // 2
        for i, s in enumerate(slices[:count]):
            roi_patch = self._extract_roi_patch(original_volume, s, cy, cx)
            rois.append({
                'roi_image': roi_patch,
                'slice_idx': s,
                'coordinates': (cy, cx),
                'confidence': 0.2,
                'roi_id': i
            })
        return rois
    
    def _get_simple_fallback_rois(self):
        """Simple fallback when no quality ROIs found"""
        print("🔍 DEBUG: Using simple fallback (single center ROI)")
        dummy_roi = np.random.random((*Config.ROI_SIZE, 3)).astype(np.float32)
        return [{
            'roi_image': dummy_roi,
            'slice_idx': 25,
            'coordinates': (128, 128),
            'confidence': 0.1,
            'roi_id': 0
        }]
    
    def _get_emergency_fallback_rois(self):
        """Emergency fallback when everything fails"""
        print("🔍 DEBUG: Using emergency fallback ROI")
        dummy_roi = np.random.random((*Config.ROI_SIZE, 3)).astype(np.float32)
        return [{
            'roi_image': dummy_roi,
            'slice_idx': 0,
            'coordinates': (128, 128),
            'confidence': 0.1,
            'roi_id': 0
        }]

    
    def _load_efficient_volume(self, series_path):
        """Load volume with smart distributed sampling to cover entire brain"""
        try:
            # Cache original volume slices to reduce repeated I/O
            os.makedirs(Config.STAGE2_CACHE_DIR, exist_ok=True)
            sid = os.path.basename(series_path)
            vcache = os.path.join(Config.STAGE2_CACHE_DIR, f"{sid}_vol.npy")
            if os.path.exists(vcache):
                return np.load(vcache, allow_pickle=False)
            dicom_files = [f for f in os.listdir(series_path) if f.endswith('.dcm')]
            pixel_arrays = []
            
            # SMART SAMPLING: Distribute 50 slices across entire volume
            total_files = len(dicom_files)
            if total_files > 50:
                # Calculate step size to distribute slices evenly
                step = total_files / 50
                selected_indices = [int(i * step) for i in range(50)]
                selected_files = [dicom_files[i] for i in selected_indices]
                print(f"🔍 DEBUG: Smart sampling - selected {len(selected_files)} files from {total_files} total (every {step:.1f})")
            else:
                selected_files = dicom_files
                print(f"🔍 DEBUG: Using all {len(selected_files)} files (less than 50)")
            
            for f in selected_files:
                try:
                    ds = pydicom.dcmread(os.path.join(series_path, f), force=True)
                    if hasattr(ds, 'pixel_array'):
                        arr = ds.pixel_array
                        if arr.ndim == 2:
                            pixel_arrays.append(arr)
                except:
                    continue
            
            if pixel_arrays:
                # SMALLER target shape to reduce memory usage
                target_shape = (256, 256)  # Reduced from (512, 512)
                
                resized_arrays = []
                for arr in pixel_arrays:
                    # Use cv2.resize instead of ndimage.zoom (more reliable)
                    if arr.shape != target_shape:
                        resized_arr = cv2.resize(arr.astype(np.float32), target_shape)
                        resized_arrays.append(resized_arr)
                    else:
                        resized_arrays.append(arr.astype(np.float32))
                
                volume = np.stack(resized_arrays, axis=0)
                
                # Simple normalization
                p1, p99 = np.percentile(volume, [1, 99])
                volume = np.clip(volume, p1, p99)
                volume = (volume - p1) / (p99 - p1 + 1e-8)
                
                try:
                    np.save(vcache, volume.astype(np.float32), allow_pickle=False)
                except Exception:
                    pass
                return volume
            
        except Exception as e:
            print(f"Error loading efficient volume: {e}")
        
        # Fallback volume (matches our smart sampling approach)
        return np.random.random((50, 256, 256)).astype(np.float32)

    
    def _extract_roi_patch(self, volume, slice_idx, center_y, center_x):
        """Extract ROI with adjacent-slice context as RGB channels (s-1, s, s+1)."""
        D, H, W = volume.shape
        s_indices = [max(0, slice_idx - 1), slice_idx, min(D - 1, slice_idx + 1)]
        channels = []
        half_size = Config.ROI_SIZE[0] // 2
        for s in s_indices:
            slice_data = volume[s]
            h, w = slice_data.shape
            y1 = max(0, center_y - half_size)
            y2 = min(h, center_y + half_size)
            x1 = max(0, center_x - half_size)
            x2 = min(w, center_x + half_size)
            patch = slice_data[y1:y2, x1:x2]
            patch_resized = cv2.resize(patch, Config.ROI_SIZE)
            channels.append(patch_resized)
        patch_3ch = np.stack(channels, axis=2)
        return patch_3ch
    

def _direct_volume_training_df(df):
    """Build one training record per series for direct-volume mode (no ROI files involved)."""
    records = []
    for _, row in df.iterrows():
        sid = str(row[Config.ID_COL])
        rec = {
            'roi_id': f"{sid}_vol32",
            'roi_path': '',
            'series_id': sid,
            'roi_confidence': 1.0,
            'slice_idx': -1,
        }
        for col in Config.LABEL_COLS:
            rec[col] = row[col]
        records.append(rec)
    return pd.DataFrame(records)


def _read_cached_training_df(parquet_path):
    """Return the cached training dataframe at ``parquet_path``, or None if missing, unreadable or incomplete."""
    if not parquet_path or not os.path.exists(parquet_path):
        return None
    try:
        cached = pl.read_parquet(parquet_path).to_pandas()
    except Exception as exc:
        print(f"⚠️ Could not read cached training dataframe {parquet_path}: {exc}")
        return None
    required_cols = ['roi_path', 'roi_id', 'series_id'] + list(Config.LABEL_COLS)
    if len(cached) > 0 and all(c in cached.columns for c in required_cols):
        return cached
    return None


def _write_training_df_cache(training_df, parquet_path):
    """Best-effort parquet dump of the training dataframe; returns True on success."""
    try:
        pl.from_pandas(training_df).write_parquet(parquet_path)
        return True
    except Exception as exc:
        print(f"⚠️ Could not write training dataframe cache {parquet_path}: {exc}")
        return False


def _training_df_from_roi_images(df, roi_dir):
    """Rebuild a training dataframe from PNGs named ``{series_id}_roi_{k}.png`` in ``roi_dir``.

    Returns None when the directory has no PNGs or none can be matched to a
    labelled series in ``df``.
    """
    candidates = [f for f in os.listdir(roi_dir) if f.lower().endswith('.png')]
    if not candidates:
        return None
    print(f"🧩 Building training_df from external ROI images: {len(candidates)} files")

    # Map labels by series id for a fast join.
    label_cols = list(Config.LABEL_COLS)
    df_labels = df[[Config.ID_COL] + label_cols].copy()
    df_labels[Config.ID_COL] = df_labels[Config.ID_COL].astype(str)
    label_map = df_labels.set_index(Config.ID_COL).to_dict('index')

    records = []
    for fname in candidates:
        base = os.path.splitext(fname)[0]
        if '_roi_' not in base:
            continue
        series_id, roi_part = base.split('_roi_', 1)
        try:
            roi_id_int = int(roi_part)
        except ValueError:
            roi_id_int = 0
        labs = label_map.get(series_id)
        if labs is None:
            # No labels for this series (should not happen for train): skip it.
            continue
        rec = {
            'roi_id': f"{series_id}_roi_{roi_id_int}",
            'roi_path': os.path.join(roi_dir, fname),
            'series_id': series_id,
            'roi_confidence': 0.2,
            'slice_idx': -1,
        }
        for col in label_cols:
            rec[col] = labs[col]
        records.append(rec)
    return pd.DataFrame(records) if records else None


def create_training_data(df, stage1_predictor):
    """Build the per-sample training dataframe (one row per ROI, or per series in direct-volume mode).

    Resolution order:
      1. ``Config.DIRECT_VOLUME_MODE``: one row per series, no ROI extraction.
      2. ``Config.REUSE_EXISTING_ROIS``: the working cache (``rois/training_df.parquet``),
         then the parquet in ``Config.ROIS_EXTERNAL_DIR``, then a dataframe rebuilt
         from the PNGs found in that directory.
      3. Otherwise ROIs are extracted with ``ROIExtractor`` and written to ``rois/``.

    Args:
        df: dataframe with ``Config.ID_COL`` and all ``Config.LABEL_COLS``.
        stage1_predictor: passed to ``ROIExtractor``; unused when a cache is reused
            or in direct-volume mode.

    Returns:
        pandas DataFrame with 'roi_id', 'roi_path', 'series_id', 'roi_confidence',
        'slice_idx' and every label column.
    """
    print("🔄 Extracting ROIs for training data...")

    # Direct volume mode: the dataset reads DICOMs on the fly, so only an index is needed.
    if getattr(Config, 'DIRECT_VOLUME_MODE', False):
        print("✅ DIRECT_VOLUME_MODE: using Stage 0 32x384x384 volumes (no ROI extraction)")
        training_df = _direct_volume_training_df(df)
        print(f"✅ DIRECT_VOLUME_MODE: built {len(training_df)} samples from {len(df)} series")
        return training_df

    cache_dir = 'rois'
    os.makedirs(cache_dir, exist_ok=True)
    cached_df_path_parquet = os.path.join(cache_dir, 'training_df.parquet')

    ext_dir = getattr(Config, 'ROIS_EXTERNAL_DIR', '')
    if not isinstance(ext_dir, str):
        ext_dir = ''

    if Config.REUSE_EXISTING_ROIS:
        # Prefer the working cache.
        cached = _read_cached_training_df(cached_df_path_parquet)
        if cached is not None:
            print(f"✅ Reusing cached training ROIs (working): {len(cached)} samples from {cached['series_id'].nunique()} series")
            return cached

        # Fall back to the external cache. Only look when a directory is actually
        # configured: joining '' with the file name would otherwise resolve to a
        # stray ./training_df.parquet in the working directory.
        if ext_dir:
            cached = _read_cached_training_df(os.path.join(ext_dir, 'training_df.parquet'))
            if cached is not None:
                print(f"✅ Reusing cached training ROIs (external): {len(cached)} samples from {cached['series_id'].nunique()} series")
                # Copy into the working cache for faster subsequent access.
                _write_training_df_cache(cached, cached_df_path_parquet)
                return cached

        # External parquet missing but ROI images present: rebuild the dataframe.
        if ext_dir and os.path.isdir(ext_dir):
            try:
                training_df_ext = _training_df_from_roi_images(df, ext_dir)
            except Exception as exc:
                print(f"⚠️ Could not rebuild training dataframe from {ext_dir}: {exc}")
                training_df_ext = None
            if training_df_ext is not None:
                print(f"✅ Reconstructed training ROIs (external): {len(training_df_ext)} samples from {training_df_ext['series_id'].nunique()} series")
                _write_training_df_cache(training_df_ext, cached_df_path_parquet)
                return training_df_ext

    roi_extractor = ROIExtractor(stage1_predictor)
    training_data = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Extracting ROIs"):
        series_id = str(row[Config.ID_COL])
        series_path = os.path.join(Config.SERIES_DIR, series_id)

        if not os.path.exists(series_path):
            continue

        rois = roi_extractor.extract_top3_rois(series_path)

        # Number ROIs by position rather than trusting roi_data['roi_id']: fallback
        # ROIs may restart their ids at 0, which would overwrite earlier PNGs and
        # duplicate roi_id values within a series.
        for roi_index, roi_data in enumerate(rois):
            roi_name = f"{series_id}_roi_{roi_index}"
            roi_filename = f"rois/{roi_name}.png"

            # Clip before the uint8 cast so values slightly outside [0, 1] do not wrap.
            roi_image = (np.clip(roi_data['roi_image'], 0.0, 1.0) * 255).astype(np.uint8)
            Image.fromarray(roi_image).save(roi_filename)

            sample = {
                'roi_id': roi_name,
                'roi_path': roi_filename,
                'series_id': series_id,
                'roi_confidence': roi_data['confidence'],
                'slice_idx': roi_data['slice_idx']
            }
            for col in Config.LABEL_COLS:
                sample[col] = row[col]

            training_data.append(sample)

    training_df = pd.DataFrame(training_data)
    print(f"✅ Created {len(training_df)} training samples from {len(df)} series")

    # Save for reuse on later runs.
    if _write_training_df_cache(training_df, cached_df_path_parquet):
        print(f"💾 Saved training ROI dataframe → {cached_df_path_parquet}")

    return training_df

# Cell-completion marker: printed once the definitions above have been executed.
print("✅ Data loading and ROI extraction functions loaded")

# ====================================================
# CELL 3: MODEL DEFINITION
# ====================================================

class AneurysmClassificationDataset(Dataset):
    """Classification dataset yielding dicts with 'image', 'labels', 'roi_id' and 'confidence'.

    In direct-volume mode (``Config.DIRECT_VOLUME_MODE``) each item is a
    32x384x384 volume built on the fly from the series' DICOMs via
    ``process_dicom_series_safe``. Otherwise each item is a 3-channel ROI PNG
    read from the row's ``roi_path``.

    Args:
        df: dataframe with 'series_id', 'roi_id', 'roi_path', 'roi_confidence'
            and every column in ``Config.LABEL_COLS``.
        mode: 'train' enables flip/rotate/color augmentation for ROI images;
            any other value uses normalization only.
    """
    def __init__(self, df, mode='train'):
        self.df = df
        self.mode = mode
        self.direct_volume = getattr(Config, 'DIRECT_VOLUME_MODE', False)
        # Report only the first unreadable ROI so the log is not flooded.
        self._warned_unreadable_roi = False
        if self.direct_volume:
            self.preprocessor = DICOMPreprocessorKaggle(target_shape=(32, 384, 384))
            # Albumentations pipeline for 32-channel input: only Normalize + ToTensorV2
            self.alb_transform = A.Compose([
                A.Normalize(mean=0.0, std=1.0, max_pixel_value=1.0),
                ToTensorV2(transpose_mask=False),
            ])
        else:
            # ROI image pipeline (3-channel PNGs) - keep minimal and albumentations-based
            if mode == 'train':
                self.alb_transform = A.Compose([
                    A.HorizontalFlip(p=0.5),
                    A.VerticalFlip(p=0.5),
                    A.Rotate(limit=15, p=0.5),
                    A.ColorJitter(brightness=0.2, contrast=0.2, p=0.5),
                    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
                    ToTensorV2(transpose_mask=False),
                ])
            else:
                self.alb_transform = A.Compose([
                    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
                    ToTensorV2(transpose_mask=False),
                ])

    def __len__(self):
        return len(self.df)

    def _load_roi_image(self, roi_path):
        """Read an ROI PNG as an RGB uint8 array; substitute random noise if it cannot be read."""
        try:
            # The context manager closes the underlying file handle; convert()
            # returns a detached copy, so closing that copy would not do it.
            with Image.open(roi_path) as img:
                return np.array(img.convert('RGB'))
        except Exception as exc:
            if not self._warned_unreadable_roi:
                print(f"⚠️ Could not read ROI image '{roi_path}' ({exc}); substituting random noise")
                self._warned_unreadable_roi = True
            return np.random.randint(0, 255, (*Config.ROI_SIZE, 3), dtype=np.uint8)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        if self.direct_volume:
            # Build the volume directly from the series directory.
            series_id = str(row['series_id'])
            series_path = os.path.join(Config.SERIES_DIR, series_id)
            arr = process_dicom_series_safe(series_path, target_shape=(32,384,384))
            # Albumentations expects (H,W,C); ToTensorV2 then returns (C,H,W).
            vol_hwc = np.transpose(arr, (1,2,0))  # (384,384,32)
            image = self.alb_transform(image=vol_hwc)['image']  # tensor (32,384,384)
        else:
            np_img = self._load_roi_image(row['roi_path'])
            image = self.alb_transform(image=np_img)['image']

        labels = torch.tensor([row[col] for col in Config.LABEL_COLS], dtype=torch.float32)

        return {
            'image': image,
            'labels': labels,
            'roi_id': row['roi_id'],
            'confidence': torch.tensor(row['roi_confidence'], dtype=torch.float32)
        }

class AneurysmEfficientNet(nn.Module):
    """EfficientNet-B3 (ROI mode) or EfficientNetV2-S (32-channel volume mode) with an MLP head.

    The backbone is chosen from ``Config.DIRECT_VOLUME_MODE``. In ROI mode,
    offline EfficientNet-B3 weights are loaded when the weights file exists.

    Args:
        num_classes: number of output logits (defaults to ``len(Config.LABEL_COLS)``).
    """
    def __init__(self, num_classes=len(Config.LABEL_COLS)):
        super().__init__()
        self.direct_volume = getattr(Config, 'DIRECT_VOLUME_MODE', False)
        if self.direct_volume:
            # 32-channel EfficientNetV2-S
            self.backbone = timm.create_model('tf_efficientnetv2_s.in21k_ft_in1k', pretrained=False, num_classes=0, in_chans=32)
        else:
            # ROI classifier (3-channel EfficientNet-B3) with offline weights
            self.backbone = self._build_roi_backbone()
        feature_dim = self.backbone.num_features

        # Classification head with dropout
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    @staticmethod
    def _build_roi_backbone():
        """Create the EfficientNet-B3 backbone and load offline weights when available."""
        weights_path = '/kaggle/input/tf-efficientnet/pytorch/tf-efficientnet-b3/1/tf_efficientnet_b3_aa-84b4657e.pth'
        try:
            backbone = timm.create_model('efficientnet_b3', pretrained=False, num_classes=0)
            if os.path.exists(weights_path):
                print(f"🔄 Loading offline EfficientNet-B3 weights from: {weights_path}")
                # NOTE(security): weights_only=False unpickles arbitrary objects;
                # only point this at trusted checkpoint files.
                state_dict = torch.load(weights_path, map_location='cpu', weights_only=False)
                if isinstance(state_dict, dict) and isinstance(state_dict.get('state_dict'), dict):
                    state_dict = state_dict['state_dict']
                # strict=False ignores key mismatches, so inspect the result
                # instead of assuming the weights were applied.
                load_result = backbone.load_state_dict(state_dict, strict=False)
                expected = len(backbone.state_dict())
                matched = expected - len(load_result.missing_keys)
                if matched == 0:
                    print("⚠️ No checkpoint keys matched the backbone; it remains randomly initialized")
                else:
                    print("✅ Successfully loaded offline EfficientNet-B3 weights!")
                    print(f"   matched {matched}/{expected} tensors "
                          f"({len(load_result.missing_keys)} missing, {len(load_result.unexpected_keys)} unexpected)")
            else:
                print(f"⚠️ Weights file not found at {weights_path}, using random initialization")
            return backbone
        except Exception as e:
            print(f"❌ Error loading offline weights: {e}")
            print("🔄 Falling back to timm without pre-training...")
            return timm.create_model('efficientnet_b3', pretrained=False, num_classes=0)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes)."""
        features = self.backbone(x)
        logits = self.classifier(features)
        return logits

# Using original EfficientNet approach

def calculate_class_weights(df):
    """Compute per-label positive-class weights as a float32 tensor.

    Each weight is the negative/positive count ratio for that label, capped at
    100. The last label's weight is then multiplied by 13 (the last entry of
    ``Config.LABEL_COLS`` is expected to be "Aneurysm Present").
    """
    positives = df[Config.LABEL_COLS].sum()
    negatives = len(df) - positives

    # Frequency-based ratio; the epsilon guards labels with no positives.
    ratio = (negatives / (positives + 1e-8)).clip(upper=100.0)

    weight_values = ratio.to_numpy(copy=True)
    # Extra emphasis on the final label, applied after the cap.
    weight_values[-1] *= 13.0

    return torch.tensor(weight_values, dtype=torch.float32)

# Cell-completion marker: printed once the model definitions above have been executed.
print("✅ Model definition loaded")

# ====================================================
# CELL 4: TRAINING PIPELINE
# ====================================================

def train_epoch(model, loader, optimizer, criterion, device, grad_scaler=None):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: module mapping a (B, C, H, W) image batch to logits.
        loader: iterable of dicts with 'image' and 'labels' tensors.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss called as ``criterion(logits, labels)``.
        device: device the batches are moved to.
        grad_scaler: optional GradScaler used when ``Config.MIXED_PRECISION`` is on.
            If omitted, the module-level ``scaler`` published by ``main_training``
            is used; if that is missing too, a fresh scaler is created.

    Returns:
        Mean loss over the batches, or NaN if the loader yielded no batches.
    """
    model.train()
    use_amp = Config.MIXED_PRECISION
    if use_amp and grad_scaler is None:
        # Backward compatible with callers that rely on the global set by main_training.
        grad_scaler = globals().get('scaler')
        if grad_scaler is None:
            grad_scaler = torch.cuda.amp.GradScaler(enabled=True)

    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(loader, desc="Training"):
        images = batch['image'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)

        optimizer.zero_grad()

        # Forward pass
        with torch.cuda.amp.autocast(enabled=use_amp):
            images = images.to(memory_format=torch.channels_last)
            logits = model(images)
            loss = criterion(logits, labels)

        # Backward pass
        if use_amp:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # An empty loader used to raise ZeroDivisionError here.
        print("⚠️ train_epoch received no batches")
        return float('nan')
    return total_loss / num_batches

def validate_epoch(model, loader, criterion, device):
    """Evaluate the model and return ``(mean_loss, weighted_auc)``.

    The weighted AUC averages per-label ROC AUCs with weight 13 on the last label
    and 1 on every other label. Labels with a single class in the validation data
    contribute 0.5.

    Args:
        model: module mapping a (B, C, H, W) image batch to logits.
        loader: iterable of dicts with 'image' and 'labels' tensors.
        criterion: loss called as ``criterion(logits, labels)``.
        device: device the batches are moved to.

    Returns:
        Tuple of mean batch loss (NaN when the loader is empty) and weighted AUC
        (0.5 when it cannot be computed).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating"):
            images = batch['image'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=Config.MIXED_PRECISION):
                images = images.to(memory_format=torch.channels_last)
                logits = model(images)
                loss = criterion(logits, labels)

            total_loss += loss.item()
            num_batches += 1

            # Cast to float32 first: under autocast the logits may be half precision,
            # which would coarsen the scores used for ranking.
            probs = torch.sigmoid(logits.float()).cpu().numpy()
            all_preds.append(probs)
            all_labels.append(labels.cpu().numpy())

    weighted_auc = 0.5
    if len(all_preds) > 0:
        all_preds = np.vstack(all_preds)
        all_labels = np.vstack(all_labels)

        try:
            auc_scores = []
            for i in range(len(Config.LABEL_COLS)):
                if len(np.unique(all_labels[:, i])) > 1:
                    auc_scores.append(roc_auc_score(all_labels[:, i], all_preds[:, i]))
                else:
                    # AUC is undefined with a single class; use the chance level.
                    auc_scores.append(0.5)

            # Weighted AUC (13x weight for Aneurysm Present)
            weights = [1.0] * (len(Config.LABEL_COLS) - 1) + [13.0]
            weighted_auc = np.average(auc_scores, weights=weights)
        except Exception as exc:
            print(f"⚠️ AUC computation failed ({exc}); reporting 0.5")
            weighted_auc = 0.5

    if num_batches == 0:
        # An empty loader used to raise ZeroDivisionError here.
        print("⚠️ validate_epoch received no batches")
        return float('nan'), weighted_auc
    return total_loss / num_batches, weighted_auc

def main_training():
    """Train the stage-2 classifier with a single stratified split or K-fold cross-validation.

    Reads ``Config.TRAIN_CSV_PATH``, builds a per-series training index (direct
    volume records with no ROI files), trains one model per fold for
    ``Config.EPOCHS`` epochs and saves the best checkpoint of each fold to
    ``stage2_fold_{fold}_best.pth``. Publishes the GradScaler as the module-level
    ``scaler`` used by ``train_epoch``.
    """
    print("🚀 STAGE 2: ANEURYSM CLASSIFICATION WITH EFFICIENTNET-B3")

    # Load data
    train_df = pd.read_csv(Config.TRAIN_CSV_PATH)

    if Config.DEBUG_MODE:
        train_df = train_df.head(Config.DEBUG_SAMPLES)

    print(f"Training samples: {len(train_df)}")
    print(f"Aneurysm cases: {train_df['Aneurysm Present'].sum()}")

    # Build training index (DIRECT_VOLUME_MODE: no ROIs)
    records = []
    for _, r in train_df.iterrows():
        rec = {
            'series_id': str(r[Config.ID_COL]),
            'roi_id': f"{str(r[Config.ID_COL])}_vol32",
            'roi_path': '',
            'roi_confidence': 1.0,
            'slice_idx': -1,
        }
        for col in Config.LABEL_COLS:
            rec[col] = r[col]
        records.append(rec)
    training_df = pd.DataFrame(records)

    # Calculate class weights
    class_weights = calculate_class_weights(training_df)
    print(f"Class weights: {class_weights}")

    # Create criterion with class weights
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights).to(Config.DEVICE)

    # Mixed precision scaler (module-level so train_epoch can reach it)
    global scaler
    scaler = torch.cuda.amp.GradScaler(enabled=Config.MIXED_PRECISION)

    # Cross-validation or single split, stratified on Aneurysm Present
    fold_scores = []
    if Config.N_FOLDS <= 1:
        idx_all = np.arange(len(training_df))
        train_idx, val_idx = train_test_split(
            idx_all,
            test_size=0.2,
            stratify=training_df['Aneurysm Present'],
            random_state=42,
        )
        fold_splits = [(train_idx, val_idx)]
    else:
        skf = StratifiedKFold(n_splits=Config.N_FOLDS, shuffle=True, random_state=42)
        fold_splits = list(skf.split(training_df, training_df['Aneurysm Present']))

    # DataLoader rejects persistent_workers / prefetch_factor when num_workers == 0,
    # so those options are only passed when worker processes are actually used.
    loader_kwargs = dict(
        batch_size=Config.BATCH_SIZE,
        num_workers=Config.NUM_WORKERS,
        pin_memory=Config.PIN_MEMORY,
    )
    if Config.NUM_WORKERS > 0:
        loader_kwargs.update(
            persistent_workers=Config.PERSISTENT_WORKERS,
            prefetch_factor=Config.PREFETCH_FACTOR,
        )

    for fold, (train_idx, val_idx) in enumerate(fold_splits):
        print(f"\n{'='*50}")
        print(f"FOLD {fold + 1}/{Config.N_FOLDS}")
        print(f"{'='*50}")

        # Split data
        train_fold_df = training_df.iloc[train_idx].reset_index(drop=True)
        val_fold_df = training_df.iloc[val_idx].reset_index(drop=True)

        print(f"Train samples: {len(train_fold_df)}, Val samples: {len(val_fold_df)}")

        # Create datasets
        train_dataset = AneurysmClassificationDataset(train_fold_df, mode='train')
        val_dataset = AneurysmClassificationDataset(val_fold_df, mode='val')

        # Create loaders (tuned for throughput)
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

        # Initialize model (ensure in_chans=32 path is used)
        model = AneurysmEfficientNet().to(Config.DEVICE)
        try:
            model = model.to(memory_format=torch.channels_last)
        except Exception:
            pass

        # Optimizer with different learning rates
        optimizer = optim.AdamW([
            {'params': model.backbone.parameters(), 'lr': Config.LEARNING_RATE * 0.1},  # Lower LR for backbone
            {'params': model.classifier.parameters(), 'lr': Config.LEARNING_RATE}
        ], weight_decay=1e-4)

        # Multi-GPU if available
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

        # Scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.EPOCHS)

        # Training loop
        best_auc = 0

        for epoch in range(Config.EPOCHS):
            print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")

            # Train
            train_loss = train_epoch(model, train_loader, optimizer, criterion, Config.DEVICE)

            # Validate
            val_loss, val_auc = validate_epoch(model, val_loader, criterion, Config.DEVICE)

            # Step scheduler
            scheduler.step()

            print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val AUC: {val_auc:.4f}")

            # Save best model
            if val_auc > best_auc:
                best_auc = val_auc
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_auc': val_auc,
                    'epoch': epoch,
                    'fold': fold
                }, f'stage2_fold_{fold}_best.pth')
                print(f"💾 Saved best model (AUC: {val_auc:.4f})")

        fold_scores.append(best_auc)
        print(f"Fold {fold + 1} best AUC: {best_auc:.4f}")

        # Release this fold's model, optimizer state and worker processes before
        # the next fold allocates its own.
        del model, optimizer, scheduler, train_loader, val_loader, train_dataset, val_dataset
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Final results
    mean_cv_score = np.mean(fold_scores)
    print(f"\n✅ Cross-validation complete!")
    print(f"Mean CV AUC: {mean_cv_score:.4f} ± {np.std(fold_scores):.4f}")
    print(f"Individual fold scores: {fold_scores}")

# Cell-completion marker: printed once the training functions above have been defined.
print("✅ Training pipeline loaded")

# ====================================================
# CELL 5: INFERENCE & SUBMISSION
# ====================================================

class InferenceConfig:
    """Constants used by the inference server: device, id column and label order."""

    # Name of the series identifier column.
    ID_COL = 'SeriesInstanceUID'

    # Output labels, one per line, in prediction order.
    LABEL_COLS = [
        'Left Infraclinoid Internal Carotid Artery',
        'Right Infraclinoid Internal Carotid Artery',
        'Left Supraclinoid Internal Carotid Artery',
        'Right Supraclinoid Internal Carotid Artery',
        'Left Middle Cerebral Artery',
        'Right Middle Cerebral Artery',
        'Anterior Communicating Artery',
        'Left Anterior Cerebral Artery',
        'Right Anterior Cerebral Artery',
        'Left Posterior Communicating Artery',
        'Right Posterior Communicating Artery',
        'Basilar Tip',
        'Other Posterior Circulation',
        'Aneurysm Present',
    ]

    # CUDA when available, CPU otherwise.
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class ModelEnsemble:
    """Ensemble of Stage 2 models that averages per-model sigmoid probabilities.

    Checkpoint resolution: the published 5-fold 32-channel checkpoints under
    ``DEFAULT_CHECKPOINT_DIR`` are used whenever at least one of them exists.
    Otherwise the ``model_paths`` passed by the caller are loaded. Paths that
    fail to load are reported and skipped.
    """
    DEFAULT_CHECKPOINT_DIR = '/kaggle/input/rsna2025-effnetv2-32ch'
    DEFAULT_CHECKPOINT_TEMPLATE = 'tf_efficientnetv2_s.in21k_ft_in1k_fold{}_best.pth'

    def __init__(self, model_paths, device):
        """Load every resolvable checkpoint onto ``device`` in eval mode.

        Args:
            model_paths: fallback checkpoint paths, used only when none of the
                default dataset checkpoints exist.
            device: torch device the models are moved to.
        """
        self.device = device
        self.models = []

        default_paths = [
            os.path.join(self.DEFAULT_CHECKPOINT_DIR, self.DEFAULT_CHECKPOINT_TEMPLATE.format(i))
            for i in range(5)
        ]
        # Previously the argument was discarded unconditionally, leaving the
        # ensemble empty whenever the dataset checkpoints were absent.
        if any(os.path.exists(p) for p in default_paths) or not model_paths:
            model_paths = default_paths
        else:
            model_paths = list(model_paths)

        for path in model_paths:
            try:
                model = AneurysmEfficientNet().to(device)
                checkpoint = torch.load(path, map_location=device, weights_only=False)

                if 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                else:
                    state_dict = checkpoint

                # Handle DataParallel wrapper: strip only the leading 'module.' prefix.
                prefix = 'module.'
                if any(key.startswith(prefix) for key in state_dict.keys()):
                    state_dict = {
                        (key[len(prefix):] if key.startswith(prefix) else key): value
                        for key, value in state_dict.items()
                    }

                model.load_state_dict(state_dict)
                model.eval()
                self.models.append(model)
                print(f"Loaded model: {path}")
            except Exception as e:
                print(f"Error loading {path}: {e}")

        print(f"Loaded {len(self.models)} models for ensemble")

    def predict_single(self, series_path):
        """Predict label probabilities for one series.

        Builds a 32x384x384 volume with ``process_dicom_series_safe`` and returns
        the mean of the per-model sigmoid outputs as a 1-D numpy array.

        Raises:
            ValueError: if no model was loaded.
        """
        if not self.models:
            raise ValueError("ModelEnsemble has no loaded models; cannot predict")

        vol = process_dicom_series_safe(series_path, target_shape=(32,384,384))
        # Cast to float32 so a float64/float16 volume does not clash with the
        # models' float32 weights.
        x = torch.from_numpy(np.ascontiguousarray(vol)).unsqueeze(0)
        x = x.to(self.device, dtype=torch.float32)  # [1,32,384,384]

        preds = []
        with torch.no_grad():
            for model in self.models:
                logits = model(x)
                probs = torch.sigmoid(logits).cpu().numpy()[0]
                preds.append(probs)
        return np.mean(np.stack(preds, axis=0), axis=0)

class InferenceDICOMProcessor:
    """Stateless DICOM processor placeholder used during inference."""

    def __init__(self):
        # Nothing to configure: the processor carries no state.
        return None

# Module-level inference state; both stay None until initialize_models() runs
model_ensemble = processor = None

def initialize_models():
    """Initialize models - called once at startup.

    Sets the module-level ``model_ensemble`` and ``processor``:

    * ``model_ensemble`` becomes a ``ModelEnsemble`` only when at least one of the
      expected checkpoint files exists AND the ensemble actually loaded at least
      one model; otherwise it is ``None`` so ``predict`` uses its fallback path.
    * ``processor`` is always set to a new ``InferenceDICOMProcessor``.
    """
    global model_ensemble, processor

    print("Initializing models...")

    # Model paths - adjust these to match your uploaded dataset structure
    model_paths = [f'stage2_fold_{fold}_best.pth' for fold in range(5)]

    # Check if models exist, use available ones
    available_models = [path for path in model_paths if os.path.exists(path)]

    if not available_models:
        print("Warning: No trained models found! Using dummy predictions.")
        model_ensemble = None
    else:
        try:
            ensemble = ModelEnsemble(available_models, InferenceConfig.DEVICE)
            if ensemble.models:
                model_ensemble = ensemble
                print("Models initialized successfully!")
            else:
                # The constructor skips checkpoints that fail to load, so it can
                # come back empty; an empty ensemble cannot predict anything.
                print("Warning: No checkpoints could be loaded! Using dummy predictions.")
                model_ensemble = None
        except Exception as e:
            print(f"Error initializing models: {e}")
            model_ensemble = None

    processor = InferenceDICOMProcessor()

def _fallback_predictions(series_path: str, n_labels: int) -> np.ndarray:
    """Return heuristic predictions for a series when no trained ensemble is available.

    Walks ``series_path`` for ``.dcm`` files (case-insensitive). With no files the
    result is a constant 0.5 vector; otherwise a noisy vector around a baseline
    chosen from the first file's ``Modality`` tag, clipped to [0.001, 0.999].
    """
    dicom_paths = [
        os.path.join(root, name)
        for root, _, files in os.walk(series_path)
        for name in files
        if name.lower().endswith('.dcm')
    ]
    if not dicom_paths:
        # No DICOM files found
        return np.full(n_labels, 0.5)

    # Only the header is needed here, so skip decoding the pixel data.
    ds = pydicom.dcmread(dicom_paths[0], force=True, stop_before_pixels=True)
    modality = getattr(ds, 'Modality', 'UNKNOWN')

    # NOTE(review): the DICOM Modality tag normally holds codes such as 'CT'/'MR';
    # 'CTA'/'MRA' may never match here - confirm against real series before relying on it.
    if modality in ['CTA', 'MRA']:
        # Vascular imaging - slightly higher probability
        base_prob = 0.1
    else:
        # Other modalities - lower baseline
        base_prob = 0.05

    # Add some noise to make predictions more realistic
    noisy = np.random.normal(base_prob, 0.02, n_labels)
    return np.clip(noisy, 0.001, 0.999)


def _build_prediction_frame(values, label_cols) -> pl.DataFrame:
    """Build the single-row frame returned to the API (label columns only, no series ID)."""
    return pl.DataFrame(data=[list(values)], schema=list(label_cols), orient='row')


def predict(series_path: str) -> pl.DataFrame:
    """Make predictions for the competition API.

    Args:
        series_path: Directory containing the DICOM files of one series.

    Returns:
        One-row ``pl.DataFrame`` with one column per ``InferenceConfig.LABEL_COLS``
        entry. The series ID column is not included (API requirement).

    Any exception during prediction is reported and replaced by a constant 0.5
    prediction for every label, so this function does not raise for bad input.
    """
    global model_ensemble, processor

    # Initialize models on first call (lazy loading); `processor` is set by
    # initialize_models() even when no model could be loaded, so this runs once.
    if model_ensemble is None and processor is None:
        initialize_models()

    label_cols = list(InferenceConfig.LABEL_COLS)
    # normpath drops a trailing separator so basename is not empty for 'dir/'.
    series_id = os.path.basename(os.path.normpath(series_path))

    try:
        if model_ensemble is not None:
            # Use trained ensemble
            predictions = model_ensemble.predict_single(series_path)
        else:
            # Fallback: extract metadata and make informed dummy predictions
            print(f"Using fallback prediction for {series_id}")
            predictions = _fallback_predictions(series_path, len(label_cols))

        # Flatten to a 1-D float vector and replace NaN/inf with the neutral 0.5
        # so a numerically broken model output cannot poison the submission.
        predictions = np.asarray(predictions, dtype=np.float64).reshape(-1)
        predictions = np.where(np.isfinite(predictions), predictions, 0.5)

        # A length mismatch with label_cols raises here and is handled below.
        prediction_df = _build_prediction_frame(predictions.tolist(), label_cols)

    except Exception as e:
        print(f"Error processing {series_id}: {e}")
        # Return safe default predictions
        prediction_df = _build_prediction_frame([0.5] * len(label_cols), label_cols)

    # IMPORTANT: Disk cleanup to prevent "out of disk space" errors
    shutil.rmtree('/kaggle/shared', ignore_errors=True)

    return prediction_df



# ====================================================
# SERVER EXECUTION
# ====================================================

# Build the competition inference server around `predict`. Nothing is served at
# this point; the main-execution section below calls either .serve() or
# .run_local_gateway() on this object.
inference_server = kaggle_evaluation.rsna_inference_server.RSNAInferenceServer(predict)

print("✅ Inference and submission pipeline loaded")

# ====================================================
# CELL 6: MAIN EXECUTION
# ====================================================

# Set to False to skip Stage 2 training and go straight to the local gateway test.
RUN_TRAINING = True

if __name__ == "__main__":
    if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
        # Production mode - serve the API
        print("Starting inference server...")
        inference_server.serve()
    else:
        # Training mode: the messages state what actually happens instead of
        # asking the reader to uncomment a call that is already active.
        if RUN_TRAINING:
            print("Starting Stage 2 training...")
            main_training()
        else:
            print("Skipping Stage 2 training (RUN_TRAINING is False)")

        # Local testing of the inference pipeline through the gateway
        print("Running local gateway for testing...")
        inference_server.run_local_gateway()

        # Display results if available
        results_path = '/kaggle/working/submission.parquet'
        if os.path.exists(results_path):
            results_df = pl.read_parquet(results_path)
            print("Submission preview:")
            print(results_df.head())

✅ Configuration loaded - Device: cuda
✅ Data loading and ROI extraction functions loaded
✅ Model definition loaded
✅ Training pipeline loaded
✅ Inference and submission pipeline loaded
Ready for Stage 2 training!
Uncomment the line below to start training:
# main_training()
🚀 STAGE 2: ANEURYSM CLASSIFICATION WITH EFFICIENTNET-B3
Training samples: 4348
Aneurysm cases: 1864
Class weights: tensor([54.7436, 43.3673, 12.1360, 14.6968, 18.8539, 13.7891, 10.9780, 93.5217,
        76.6429, 49.5581, 42.0495, 38.5273, 37.4779, 17.3240])

FOLD 1/1
Train samples: 3478, Val samples: 870

Epoch 1/3


Training:   0%|          | 0/1739 [00:00<?, ?it/s]

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.27680122398901436027276783658914589954_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.20352838605781624312895197978664744075_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.12904246053955178641505906243733756576_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.10783586076403918900057381253415239230_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.13035832792413871820010907388005791076_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.11231019858377850021999891102731187707_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.10092666779602341135460882241562348436_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.25401566480135645158545753333376825827_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.147276533782

Training:   0%|          | 1/1739 [00:36<17:26:50, 36.14s/it]

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.89460822484126633248553997073630753402_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.12743083402126679385964805363054623625_32x384x384_u8.npy (max_err=0.0020)


Training:   0%|          | 6/1739 [00:39<1:12:13,  2.50s/it] 

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.11114613141735642199606043212646844886_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.86037975393556827852769300088670915080_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.99421822954919332641371697175982753182_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.99297218927715340305099097057004586774_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.31628002870565033361286640405875848972_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.15847980512533357707448321523314296455_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.97191581480677904605690881854234107618_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.59321031989170539770517544048571714746_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.887271333896

Training:   1%|          | 12/1739 [01:12<59:20,  2.06s/it]  

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.76301795615527602645756396708867809495_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.49504988543101147636991852267737251575_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.87865223651281657051189368725400341319_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.90230341788943218278385841963462570470_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.45789072046383277170393600308966109036_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.95380253040471768084221411882180922662_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.53947155422591684879953627516013605305_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.30885835981120543326208883457853128283_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.126578663772

Training:   1%|          | 15/1739 [01:49<2:56:00,  6.13s/it]

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.85212589540007626039427792519492210226_32x384x384_u8.npy (max_err=0.0020)


Training:   1%|          | 18/1739 [01:50<1:03:35,  2.22s/it]

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.11948255979244132827019816539294376988_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.46462519342058199786903141190024113863_32x384x384_u8.npy (max_err=0.0020)


Training:   1%|          | 21/1739 [01:53<36:51,  1.29s/it]  

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.71550538308484373791005892758892068095_32x384x384_u8.npy (max_err=0.0020)


Training:   1%|▏         | 24/1739 [01:54<18:27,  1.55it/s]

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.28681157493123082643438198449009757076_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.52524639115355387664045096288385299391_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.13156819427968405341920939838729113222_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.11731089624678785415420487370578919131_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.11062397380277678777080157173387177272_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.11208788596258922886794998326857227331_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.11038636852681039246443401046449812061_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.10603321067992496978932502160661673268_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.166198900230

Training:   2%|▏         | 30/1739 [02:30<58:24,  2.05s/it]  

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.17650328348608009816816941699740585437_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.85330951120080333123485292655736144682_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.12821030325057451794033542804285567094_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.19950322290309930502685963351749767205_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.89561322985962991141463885723229681301_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.13152457913019434787350387410585992407_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.53573089697219772119994230159413715556_32x384x384_u8.npy (max_err=0.0020)


Training:   2%|▏         | 34/1739 [02:35<31:21,  1.10s/it]  

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.11873351578622765241634317263552561587_32x384x384_u8.npy (max_err=0.0020)


Training:   2%|▏         | 36/1739 [02:36<19:02,  1.49it/s]

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.31016492921636257021969319428153307687_32x384x384_u8.npy (max_err=0.0020)


Training:   2%|▏         | 42/1739 [02:39<10:48,  2.62it/s]

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.55335626170859465016687259915777364744_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.71583117639965131882497910550331790290_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.72305112536340538867034966246953618485_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.29667087068052601737556059884413817393_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.26383584662098611508554214963067859078_32x384x384_u8.npy (max_err=0.0020)


Training:   3%|▎         | 47/1739 [02:47<19:53,  1.42it/s]  

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.11772545330652739508075303939268792529_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.99953513260518059135058337324142717073_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.77351519132509988103103734443501529160_32x384x384_u8.npy (max_err=0.0020)


Training:   3%|▎         | 48/1739 [02:50<42:01,  1.49s/it]

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.23421600482463782319293054087843086911_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.60806471972978535805998258895959371616_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.70048682207873090597326950007352492114_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.19764176435911045852235280876942035947_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.29789030841471927572567470683240960289_32x384x384_u8.npy (max_err=0.0020)


Training:   3%|▎         | 51/1739 [02:59<50:21,  1.79s/it]  

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.53713028846939619414399725027946568503_32x384x384_u8.npy (max_err=0.0020)


Training:   3%|▎         | 53/1739 [02:59<28:20,  1.01s/it]

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.38094808038974181102880321183103989801_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.90272546526306161811446757328579665073_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.12159152010278655162358172837938626290_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.35462271463152990781312639766446467244_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.11776450499172121144481170405958665580_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.10097649530131165889513682791963111629_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.11864645671097263388176300581289300776_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.61225069775477744578753822969424698697_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.559489360247

Training:   3%|▎         | 54/1739 [03:18<2:54:13,  6.20s/it]

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.28710895896233158724073271531642622364_32x384x384_u8.npy (max_err=0.0020)


Training:   3%|▎         | 59/1739 [03:18<33:05,  1.18s/it]  

[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.12813565901564977994662924864827111603_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.14234941301612013649573263693853357171_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.18108212083513041239064199663549795472_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.32738278165208105984060645831271331150_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.82754434126210061881442049561952688899_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.79053237532664154618488686227121698456_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.33254059742616938664293801285152925743_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.23657152176763679599021789757461301944_32x384x384_u8.npy (max_err=0.0020)
[CACHE] wrote u8: 1.2.826.0.1.3680043.8.498.120166120899