<a href="https://www.kaggle.com/code/nicholas33/02-aneurysmnet-cnn-intracranial-nb153?scriptVersionId=254261514" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
!pip install monai

# ====================================================
# RSNA INTRACRANIAL ANEURYSM DETECTION - TRAINING PIPELINE
# ====================================================

import os
import gc
import warnings
import json
import time
import numpy as np
import pandas as pd
from typing import Tuple, Dict, List
from collections import counter
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import roc_auc_score
import albumentations as A

warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler

import pydicom
from scipy import ndimage
import nibabel as nib
from monai.transforms import (
    Compose, RandRotate90d, RandFlipd, RandAffined,
    RandGaussianNoised, RandAdjustContrastd, ToTensord
)
from monai.networks.nets import BasicUNet
from monai.losses import DiceCELoss, FocalLoss
from tqdm import tqdm


Collecting monai
  Downloading monai-1.5.0-py3-none-any.whl.metadata (13 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch<2.7.0,

2025-08-04 07:50:42.799840: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754293843.149656      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754293843.245457      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# ====================================================
# CONFIGURATION
# ====================================================

class Config:
    # Paths
    TRAIN_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train.csv'
    LOCALIZER_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train_localizers.csv'
    SERIES_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/series/'
    SEGMENTATION_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/segmentations/'
    
    # Model parameters
    TARGET_SIZE = (32, 64, 64)  # Increased resolution
    EPOCHS = 10
    BATCH_SIZE = 8  # Reduced due to larger input size
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 1e-4
    N_FOLDS = 3
    
    # Training parameters
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    MIXED_PRECISION = True
    GRADIENT_ACCUMULATION = 4
    
    # Competition constants
    ID_COL = 'SeriesInstanceUID'
    LABEL_COLS = [
        'Left Infraclinoid Internal Carotid Artery', 'Right Infraclinoid Internal Carotid Artery',
        'Left Supraclinoid Internal Carotid Artery', 'Right Supraclinoid Internal Carotid Artery',
        'Left Middle Cerebral Artery', 'Right Middle Cerebral Artery', 'Anterior Communicating Artery',
        'Left Anterior Cerebral Artery', 'Right Anterior Cerebral Artery',
        'Left Posterior Communicating Artery', 'Right Posterior Communicating Artery',
        'Basilar Tip', 'Other Posterior Circulation', 'Aneurysm Present',
    ]
    
    # Class weights for imbalanced data
    ANEURYSM_PRESENT_WEIGHT = 13.0  # Match evaluation metric weighting

In [5]:
# ====================================================
# ENHANCED DATA PREPROCESSING
# ====================================================

class AdvancedDICOMProcessor:
    def __init__(self, target_size: Tuple[int, int, int] = Config.TARGET_SIZE):
        self.target_size = target_size
        self.stats = {
            'total_loaded': 0,
            'successful_loads': 0,
            'shape_errors': 0,
            'empty_volumes': 0,
            'preprocessing_errors': 0
        }
        
    def load_dicom_series(self, series_path: str) -> Tuple[np.ndarray, Dict]:
        """Load DICOM series with FIXED shape handling"""
        self.stats['total_loaded'] += 1
        try:
            dicom_files = [os.path.join(series_path, f) for f in os.listdir(series_path) if f.endswith('.dcm')]
            if not dicom_files:
                print(f"No DICOM files found in {series_path}")
                return np.zeros(self.target_size, dtype=np.float32), {}
                
            dicoms = []
            
            for f in dicom_files:
                ds = pydicom.dcmread(f, force=True)
                dicoms.append(ds)
                
            # Extract metadata from first DICOM
            first_ds = dicoms[0]
            metadata = {
                'modality': getattr(first_ds, 'Modality', 'UNKNOWN'),
                'spacing': getattr(first_ds, 'PixelSpacing', [1.0, 1.0]),
                'slice_thickness': getattr(first_ds, 'SliceThickness', 1.0),
                'rescale_slope': getattr(first_ds, 'RescaleSlope', 1.0),
                'rescale_intercept': getattr(first_ds, 'RescaleIntercept', 0.0),
            }
            
            # Sort by instance number 
            dicoms.sort(key=lambda x: int(getattr(x, 'InstanceNumber', 0)))

            # MINIMAL FIX: Get all pixel arrays and find common shape
            pixel_arrays = []
            shapes = []
            for d in dicoms:
                if hasattr(d, 'pixel_array'):
                    pixel_arrays.append(d.pixel_array)
                    shapes.append(d.pixel_array.shape)
            
            if len(pixel_arrays) == 0:
                print(f"No pixel arrays found in {series_path}")
                self.stats['shape_errors'] += 1
                return np.zeros(self.target_size, dtype=np.float32), metadata
            
            # Check if all shapes are the same
            unique_shapes = list(set(shapes))
            if len(unique_shapes) == 1:
                # All shapes are same - original code works
                volume = np.stack(pixel_arrays, axis=0).astype(np.float32)
            else:
                # Shapes differ - resize to most common shape
                most_common_shape = Counter(shapes).most_common(1)[0][0]
                
                resized_arrays = []
                for arr in pixel_arrays:
                    if arr.shape == most_common_shape:
                        resized_arrays.append(arr)
                    else:
                        # Resize to most common shape
                        zoom_factors = (most_common_shape[0] / arr.shape[0], 
                                      most_common_shape[1] / arr.shape[1])
                        resized_arr = ndimage.zoom(arr, zoom_factors, order=1, prefilter=False)
                        resized_arrays.append(resized_arr)
                
                volume = np.stack(resized_arrays, axis=0).astype(np.float32)
            
            # Apply rescale if available
            if metadata['rescale_slope'] != 1.0 or metadata['rescale_intercept'] != 0.0:
                volume = volume * metadata['rescale_slope'] + metadata['rescale_intercept']

            self.stats['successful_loads'] += 1
            return volume, metadata
            
        except Exception as e:
            print(f"Error loading {series_path}: {e}")
            self.stats['shape_errors'] += 1
            return np.zeros(self.target_size, dtype=np.float32), {}

    def print_stats(self):
        """Print loading statistics"""
        total = self.stats['total_loaded']
        successful = self.stats['successful_loads']
        if total > 0:
            success_rate = (successful / total) * 100
            print(f"\n=== DICOM Loading Stats ===")
            print(f"Total attempts: {total}")
            print(f"Successful loads: {successful} ({success_rate:.1f}%)")
            print(f"Shape errors: {self.stats['shape_errors']}")
            print(f"Empty volumes: {self.stats['empty_volumes']}")
            print(f"Preprocessing errors: {self.stats['preprocessing_errors']}")
            print(f"===========================")

    def preprocess_volume(self, volume: np.ndarray, metadata: Dict) -> np.ndarray:
        """Enhanced preprocessing with modality-specific handling"""
        if volume.ndim != 3 or volume.size == 0:
            print(f"Warning: Received a non-3D volume. Returning empty target volume.")
            return np.zeros(self.target_size, dtype=np.float32)
        
        # Default windowing
        p1, p99 = np.percentile(volume, [5, 95])
        volume = np.clip(volume, p1, p99)
        
        # Normalization
        vol_min, vol_max = volume.min(), volume.max()
        if vol_max > vol_min:
            volume = (volume - vol_min) / (vol_max - vol_min)
        
        # Resize to target size
        if volume.shape != self.target_size:
            zoom_factors = [self.target_size[i] / volume.shape[i] for i in range(3)]
            volume = ndimage.zoom(volume, zoom_factors, order=1, prefilter=False)
        
        return volume.astype(np.float32)

    def load_localization_mask(self, series_id: str, localizer_df: pd.DataFrame) -> np.ndarray:
        return np.zeros(self.target_size, dtype=np.float32)


# ====================================================
# ENHANCED DATASET
# ====================================================

class EnhancedAneurysmDataset(Dataset):
    def __init__(self, df: pd.DataFrame, localizer_df: pd.DataFrame, 
                 series_dir: str, processor: AdvancedDICOMProcessor, 
                 mode: str = 'train', fold: int = None):
        self.df = df
        self.localizer_df = localizer_df
        self.series_dir = series_dir
        self.processor = processor
        self.mode = mode
        self.fold = fold
        
        # Data augmentation for training
        if mode == 'train':
            self.transform = Compose([
                # RandRotate90d(keys=['volume'], prob=0.3, spatial_axes=(0, 1)),
                # RandFlipd(keys=['volume'], prob=0.3, spatial_axis=0),
                # RandFlipd(keys=['volume'], prob=0.3, spatial_axis=1),
                # RandFlipd(keys=['volume'], prob=0.3, spatial_axis=2),
                # RandAffined(keys=['volume'], prob=0.3, rotate_range=0.1, scale_range=0.1),
                # RandGaussianNoised(keys=['volume'], prob=0.3, std=0.05),
                # RandAdjustContrastd(keys=['volume'], prob=0.3, gamma=(0.8, 1.2)),
                RandFlipd(keys=['volume'], prob=0.5, spatial_axis=0),  # Only basic flips
                RandFlipd(keys=['volume'], prob=0.5, spatial_axis=1),
                ToTensord(keys=['volume'])
            ])
        else:
            self.transform = Compose([ToTensord(keys=['volume'])])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        load_start = time.time()
        row = self.df.iloc[idx]
        series_id = row[Config.ID_COL]
        series_path = os.path.join(self.series_dir, series_id)
        
        # Load and process volume
        volume, metadata = self.processor.load_dicom_series(series_path)
        dicom_time = time.time() - load_start
        preprocess_start = time.time()
        volume = self.processor.preprocess_volume(volume, metadata)
        
        # Create localization mask (for auxiliary loss)
        loc_mask = self.processor.load_localization_mask(series_id, self.localizer_df)
        
        # Get labels
        labels = row[Config.LABEL_COLS].values.astype(np.float32)
        
        # Apply transforms
        data_dict = {'volume': volume}
        if self.transform:
            data_dict = self.transform(data_dict)
        
        volume_tensor = data_dict['volume'].unsqueeze(0)  # Add channel dimension
        loc_mask_tensor = torch.from_numpy(loc_mask).unsqueeze(0)
        labels_tensor = torch.from_numpy(labels)
        
        # Add metadata features
        modality_encoding = self._encode_modality(metadata.get('modality', 'UNKNOWN'))
        metadata_tensor = torch.tensor(modality_encoding, dtype=torch.float32)
        
        preprocess_time = time.time() - preprocess_start
        # Print timing for first few samples to debug
        if idx < 5:
            print(f"Sample {idx}: DICOM load: {dicom_time:.2f}s, Preprocess: {preprocess_time:.2f}s")
        
        return {
            'volume': volume_tensor,
            'localization_mask': loc_mask_tensor,
            'labels': labels_tensor,
            'metadata': metadata_tensor,
            'series_id': series_id
        }
    
    def _encode_modality(self, modality: str) -> List[float]:
        """One-hot encode modality"""
        modalities = ['CTA', 'MRA', 'MRI', 'MR', 'UNKNOWN']
        encoding = [0.0] * len(modalities)
        if modality in modalities:
            encoding[modalities.index(modality)] = 1.0
        else:
            encoding[-1] = 1.0  # UNKNOWN
        return encoding


# ====================================================
# ADVANCED MODEL ARCHITECTURE
# ====================================================

#class MultiModalAneurysmNet(nn.Module):
class SimplifiedAneurysmNet(nn.Module):
    def __init__(self, num_classes: int = len(Config.LABEL_COLS), 
                 spatial_dims: int = 3, in_channels: int = 1, 
        #          features: Tuple = (32, 64, 128, 256, 512, 1024)):
        # super(MultiModalAneurysmNet, self).__init__()
                 features: Tuple = (16, 32, 64, 128, 256, 51)):
        super(SimplifiedAneurysmNet, self).__init__()
        
        # Main 3D U-Net backbone
        self.backbone = BasicUNet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=features[0],
            features=features,
            dropout=0.1 #Reduced dropout 
        )
        
        # Global average pooling
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        
        # # Metadata processing
        # self.metadata_mlp = nn.Sequential(
        #     nn.Linear(5, 32),  # 5 modality categories
        #     nn.ReLU(),
        #     nn.Dropout(0.3),
        #     nn.Linear(32, 64),
        #     nn.ReLU()
        # )
        
        # Classification head
        #feature_size = features[0] + 64  # backbone features + metadata features
        self.classifier = nn.Sequential(
            # nn.Linear(feature_size, 512),
            # nn.ReLU(),
            # nn.Dropout(0.5),
            # nn.Linear(512, 256),
            nn.Linear(features[0], 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, volume, metadata=None):
        # Extract features from 3D volume
        features = self.backbone(volume)
        # Global features for classification
        global_features = self.global_pool(features).flatten(1)
        classification_logits = self.classifier(global_features)
        return classification_logits, None

# ====================================================
# WEIGHTED LOSS FUNCTION
# ====================================================

class WeightedMultiLabelLoss(nn.Module):
    def __init__(self, pos_weights=None, aneurysm_weight=13.0):
        super().__init__()
        self.pos_weights = pos_weights
        self.aneurysm_weight = aneurysm_weight
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
        
    def forward(self, logits, targets):
        bce_loss = self.bce(logits, targets)
        
        # Apply position weights if provided
        if self.pos_weights is not None:
            bce_loss = bce_loss * self.pos_weights.to(logits.device)
        
        # Weight the "Aneurysm Present" class higher (last column)
        weights = torch.ones_like(bce_loss)
        weights[:, -1] = self.aneurysm_weight
        
        weighted_loss = bce_loss * weights
        return weighted_loss.mean()


# ====================================================
# TRAINING FUNCTIONS
# ====================================================

def compute_weighted_auc(y_true, y_pred):
    """Compute weighted AUC matching competition metric"""
    aucs = []
    weights = []
    
    for i in range(len(Config.LABEL_COLS)):
        try:
            auc = roc_auc_score(y_true[:, i], y_pred[:, i])
            aucs.append(auc)
            # Weight "Aneurysm Present" (last column) higher
            weights.append(13.0 if i == len(Config.LABEL_COLS) - 1 else 1.0)
        except ValueError:
            aucs.append(0.5)  # Default for no positive cases
            weights.append(13.0 if i == len(Config.LABEL_COLS) - 1 else 1.0)
    
    weighted_auc = sum(a * w for a, w in zip(aucs, weights)) / sum(weights)
    return weighted_auc, aucs

def train_epoch(model, train_loader, optimizer, criterion, scaler, device):
    model.train()
    total_loss = 0
    num_batches = 0
    
    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Training Epoch")):
        start_time = time.time()
        volume = batch['volume'].to(device)
        metadata = batch['metadata'].to(device)
        labels = batch['labels'].to(device)
        loc_mask = batch['localization_mask'].to(device)
        
        with autocast(device_type=device.type, enabled=Config.MIXED_PRECISION):
            class_logits, _ = model(volume, metadata)
            total_loss_batch = criterion(class_logits, labels)
        
        # Gradient accumulation
        scaled_loss = total_loss_batch / Config.GRADIENT_ACCUMULATION
        scaler.scale(scaled_loss).backward()
        
        if (batch_idx + 1) % Config.GRADIENT_ACCUMULATION == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        total_loss += total_loss_batch.item()
        num_batches += 1

        # Print timing for first few batches to identify bottlenecks
        if batch_idx < 5:
            batch_time = time.time() - start_time
            print(f"Batch {batch_idx}: {batch_time:.2f}s")
    
    return total_loss / num_batches

def validate_epoch(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            volume = batch['volume'].to(device)
            metadata = batch['metadata'].to(device)
            labels = batch['labels'].to(device)
            
            with autocast(device_type=device.type, enabled=Config.MIXED_PRECISION):
                class_logits, _ = model(volume, metadata)
                loss = criterion(class_logits, labels)
            
            total_loss += loss.item()
            
            # Collect predictions for AUC calculation
            probs = torch.sigmoid(class_logits).cpu().numpy()
            all_preds.append(probs)
            all_labels.append(labels.cpu().numpy())
    
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    
    weighted_auc, individual_aucs = compute_weighted_auc(all_labels, all_preds)
    
    return total_loss / len(val_loader), weighted_auc, individual_aucs



In [None]:
# ====================================================
# MAIN TRAINING EXECUTION
# ====================================================

def main():
    print(f"Using device: {Config.DEVICE}")
    print(f"Mixed precision: {Config.MIXED_PRECISION}")
    
    # Load data
    train_df = pd.read_csv(Config.TRAIN_CSV_PATH)
    localizer_df = pd.read_csv(Config.LOCALIZER_CSV_PATH)

    print("---!!! RUNNING IN DEBUG MODE ON A SMALL SUBSET !!!---")
    #print(f"Training samples: {len(train_df)}")
    train_df = train_df.head(100)  # Limit to 100 samples for speed testing
    print(f"Training samples: {len(train_df)} (limited for speed testing)")
    print(f"Positive aneurysm cases: {train_df['Aneurysm Present'].sum()}")
    
    # Create stratified group k-fold split
    # Use patient-level grouping to prevent data leakage
    train_df['patient_group'] = train_df['PatientID'] if 'PatientID' in train_df.columns else range(len(train_df))
    
    skf = StratifiedGroupKFold(n_splits=Config.N_FOLDS, shuffle=True, random_state=42)
    train_df['fold'] = -1
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(
        train_df, train_df['Aneurysm Present'], groups=train_df['patient_group']
    )):
        train_df.loc[val_idx, 'fold'] = fold
    
    # Calculate class weights for imbalanced data
    pos_counts = train_df[Config.LABEL_COLS].sum()
    neg_counts = len(train_df) - pos_counts
    pos_weights = neg_counts / (pos_counts + 1e-8)  # Add small epsilon
    pos_weights = torch.tensor(pos_weights.values, dtype=torch.float32)
    
    print("Class weights:", pos_weights)
    
    # Initialize processor
    processor = AdvancedDICOMProcessor()
    
    # Train models for each fold
    fold_scores = []
    
    for fold in range(Config.N_FOLDS):
        print(f"\n{'='*50}")
        print(f"FOLD {fold + 1}/{Config.N_FOLDS}")
        print(f"{'='*50}")
        
        # Split data
        train_fold_df = train_df[train_df['fold'] != fold].reset_index(drop=True)
        val_fold_df = train_df[train_df['fold'] == fold].reset_index(drop=True)
        
        print(f"Train: {len(train_fold_df)}, Validation: {len(val_fold_df)}")
        
        # Create datasets
        train_dataset = EnhancedAneurysmDataset(
            train_fold_df, localizer_df, Config.SERIES_DIR, processor, mode='train', fold=fold
        )
        val_dataset = EnhancedAneurysmDataset(
            val_fold_df, localizer_df, Config.SERIES_DIR, processor, mode='val', fold=fold
        )
        
        # Create data loaders
        train_loader = DataLoader(
            train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, 
            num_workers=4, pin_memory=True, drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, 
            num_workers=4, pin_memory=True
        )
        
        # Initialize model
        model = SimplifiedAneurysmNet().to(Config.DEVICE)
        criterion = WeightedMultiLabelLoss(pos_weights=pos_weights)
        
        # Optimizer with different learning rates for different parts
        # optimizer = optim.AdamW([
        #     {'params': model.backbone.parameters(), 'lr': Config.LEARNING_RATE},
        #     {'params': model.classifier.parameters(), 'lr': Config.LEARNING_RATE * 2},
        #     {'params': model.metadata_mlp.parameters(), 'lr': Config.LEARNING_RATE * 2}
        # ], weight_decay=Config.WEIGHT_DECAY)
        optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=Config.WEIGHT_DECAY)
        
        # Learning rate scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=1e-6
        )
        
        scaler = GradScaler(enabled=Config.MIXED_PRECISION)
        
        # Training loop
        best_auc = 0
        patience = 10
        patience_counter = 0
        
        for epoch in range(Config.EPOCHS):
            # Train
            train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler, Config.DEVICE)
            
            # Validate
            val_loss, val_auc, individual_aucs = validate_epoch(model, val_loader, criterion, Config.DEVICE)
            
            # Step scheduler
            scheduler.step()
            
            print(f"Epoch {epoch+1:3d} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val AUC: {val_auc:.4f}")

            processor.print_stats() # Print DICOM loading stats after each epoch
            
            # Save best model
            if val_auc > best_auc:
                best_auc = val_auc
                patience_counter = 0
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_auc': val_auc,
                    'epoch': epoch,
                    'fold': fold,
                    'individual_aucs': individual_aucs
                }, f'best_model_fold_{fold}.pth')
            else:
                patience_counter += 1
                
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
                
            # Memory cleanup
            if epoch % 5 == 0:
                gc.collect()
                torch.cuda.empty_cache()
        
        fold_scores.append(best_auc)
        print(f"Fold {fold + 1} best AUC: {best_auc:.4f}")
    
    # Final results
    mean_cv_score = np.mean(fold_scores)
    std_cv_score = np.std(fold_scores)
    
    print(f"\n{'='*50}")
    print(f"CROSS-VALIDATION RESULTS")
    print(f"{'='*50}")
    print(f"Mean CV AUC: {mean_cv_score:.4f} ± {std_cv_score:.4f}")
    print(f"Individual fold scores: {fold_scores}")
    
    # Save training summary
    results = {
        'cv_scores': fold_scores,
        'mean_cv_score': mean_cv_score,
        'std_cv_score': std_cv_score,
        'config': vars(Config())
    }
    
    with open('training_results.json', 'w') as f:
        json.dump(results, f, indent=2, default=str)
    
    print("Training complete! Models saved as 'best_model_fold_X.pth'")

if __name__ == "__main__":
    main()

Using device: cuda
Mixed precision: True
---!!! RUNNING IN DEBUG MODE ON A SMALL SUBSET !!!---
Training samples: 100 (limited for speed testing)
Positive aneurysm cases: 48
Class weights: tensor([3.2333e+01, 4.9000e+01, 1.3286e+01, 9.0000e+00, 4.9000e+01, 1.1500e+01,
        1.0111e+01, 9.9000e+01, 4.9000e+01, 4.9000e+01, 1.0000e+10, 3.2333e+01,
        4.9000e+01, 1.0833e+00])

FOLD 1/3
Train: 66, Validation: 34
BasicUNet features: (16, 32, 64, 128, 256, 51).


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Sample 1: DICOM load: 7.52s, Preprocess: 1.18s
Sample 0: DICOM load: 2.25s, Preprocess: 0.95s


Training Epoch:  12%|█▎        | 1/8 [00:31<03:42, 31.79s/it]

Batch 0: 1.24s
Sample 4: DICOM load: 9.56s, Preprocess: 2.23s


Training Epoch:  38%|███▊      | 3/8 [00:55<01:13, 14.76s/it]

Batch 1: 0.28s
Batch 2: 0.19s


Training Epoch:  50%|█████     | 4/8 [00:55<00:36,  9.01s/it]

Batch 3: 0.20s
Sample 2: DICOM load: 3.84s, Preprocess: 0.01s


Training Epoch:  62%|██████▎   | 5/8 [01:12<00:35, 11.94s/it]

Batch 4: 0.24s
Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10192011262895867728128531292507199782: all input arrays must have the same shape
Sample 3: DICOM load: 23.49s, Preprocess: 4.65s


Training Epoch: 100%|██████████| 8/8 [01:53<00:00, 14.25s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10077108087009955586144859725246456654: all input arrays must have the same shape
Sample 0: DICOM load: 5.18s, Preprocess: 0.84s
Sample 1: DICOM load: 11.83s, Preprocess: 2.63s
Sample 2: DICOM load: 7.20s, Preprocess: 1.20s
Sample 3: DICOM load: 10.45s, Preprocess: 1.93s
Sample 4: DICOM load: 5.03s, Preprocess: 1.13s


Validating: 100%|██████████| 5/5 [01:44<00:00, 20.92s/it]


Epoch   1 | Train Loss: 487743064.0000 | Val Loss: 488626124.8000 | Val AUC: 0.3028


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10192011262895867728128531292507199782: all input arrays must have the same shape
Sample 1: DICOM load: 5.85s, Preprocess: 1.49s
Sample 3: DICOM load: 19.67s, Preprocess: 5.09s
Sample 4: DICOM load: 8.94s, Preprocess: 2.58s


Training Epoch:  25%|██▌       | 2/8 [01:05<02:41, 26.88s/it]

Batch 0: 0.29s
Batch 1: 0.19s


Training Epoch:  38%|███▊      | 3/8 [01:05<01:13, 14.69s/it]

Batch 2: 0.19s


Training Epoch:  50%|█████     | 4/8 [01:05<00:35,  8.97s/it]

Batch 3: 0.20s
Sample 0: DICOM load: 2.07s, Preprocess: 0.94s


Training Epoch:  75%|███████▌  | 6/8 [01:33<00:20, 10.49s/it]

Batch 4: 0.27s


Training Epoch: 100%|██████████| 8/8 [01:34<00:00, 11.78s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10077108087009955586144859725246456654: all input arrays must have the same shape
Sample 0: DICOM load: 4.51s, Preprocess: 1.01s
Sample 1: DICOM load: 11.68s, Preprocess: 2.60s
Sample 2: DICOM load: 4.08s, Preprocess: 1.14s
Sample 3: DICOM load: 7.46s, Preprocess: 1.96s
Sample 4: DICOM load: 3.01s, Preprocess: 1.11s


Validating: 100%|██████████| 5/5 [01:17<00:00, 15.49s/it]


Epoch   2 | Train Loss: 488249596.0000 | Val Loss: 488626124.8000 | Val AUC: 0.3028


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Sample 2: DICOM load: 3.38s, Preprocess: 0.01s
Sample 3: DICOM load: 19.62s, Preprocess: 5.15s
Sample 4: DICOM load: 9.53s, Preprocess: 2.53s
Sample 0: DICOM load: 2.22s, Preprocess: 0.91s
Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10192011262895867728128531292507199782: all input arrays must have the same shape


Training Epoch:  12%|█▎        | 1/8 [01:09<08:07, 69.70s/it]

Batch 0: 0.24s


Training Epoch:  38%|███▊      | 3/8 [01:10<01:18, 15.75s/it]

Batch 1: 0.20s
Batch 2: 0.20s


Training Epoch:  50%|█████     | 4/8 [01:10<00:38,  9.61s/it]

Batch 3: 0.20s
Sample 1: DICOM load: 5.78s, Preprocess: 1.52s


Training Epoch:  75%|███████▌  | 6/8 [01:46<00:25, 12.69s/it]

Batch 4: 0.27s


Training Epoch: 100%|██████████| 8/8 [01:47<00:00, 13.38s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10077108087009955586144859725246456654: all input arrays must have the same shape
Sample 0: DICOM load: 4.92s, Preprocess: 1.06s
Sample 1: DICOM load: 11.95s, Preprocess: 2.74s
Sample 2: DICOM load: 4.09s, Preprocess: 1.16s
Sample 3: DICOM load: 6.98s, Preprocess: 1.95s
Sample 4: DICOM load: 3.17s, Preprocess: 1.12s


Validating: 100%|██████████| 5/5 [01:17<00:00, 15.58s/it]


Epoch   3 | Train Loss: 487865052.0000 | Val Loss: 488626124.8000 | Val AUC: 0.3028


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Sample 2: DICOM load: 3.37s, Preprocess: 0.01s


Training Epoch:  12%|█▎        | 1/8 [00:41<04:47, 41.13s/it]

Batch 0: 0.27s
Sample 3: DICOM load: 19.00s, Preprocess: 5.15s
Sample 1: DICOM load: 6.16s, Preprocess: 1.20s
Sample 0: DICOM load: 2.73s, Preprocess: 0.99s
Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10192011262895867728128531292507199782: all input arrays must have the same shape
Sample 4: DICOM load: 4.95s, Preprocess: 1.82s


Training Epoch:  25%|██▌       | 2/8 [01:23<04:10, 41.79s/it]

Batch 1: 0.28s


Training Epoch:  50%|█████     | 4/8 [01:23<00:55, 13.88s/it]

Batch 2: 0.22s
Batch 3: 0.20s


Training Epoch:  62%|██████▎   | 5/8 [01:24<00:26,  8.95s/it]

Batch 4: 0.20s


Training Epoch: 100%|██████████| 8/8 [01:53<00:00, 14.18s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10077108087009955586144859725246456654: all input arrays must have the same shape
Sample 0: DICOM load: 4.68s, Preprocess: 1.01s
Sample 1: DICOM load: 11.70s, Preprocess: 2.72s
Sample 2: DICOM load: 4.69s, Preprocess: 1.19s
Sample 3: DICOM load: 7.15s, Preprocess: 1.94s
Sample 4: DICOM load: 2.88s, Preprocess: 1.12s


Validating: 100%|██████████| 5/5 [01:17<00:00, 15.47s/it]


Epoch   4 | Train Loss: 488775332.0000 | Val Loss: 488626124.8000 | Val AUC: 0.3028


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Sample 2: DICOM load: 3.49s, Preprocess: 0.01s
Sample 4: DICOM load: 9.35s, Preprocess: 2.71s
Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10192011262895867728128531292507199782: all input arrays must have the same shape


Training Epoch:  25%|██▌       | 2/8 [00:52<02:09, 21.65s/it]

Batch 0: 0.26s
Batch 1: 0.19s


Training Epoch:  38%|███▊      | 3/8 [00:52<00:59, 11.85s/it]

Batch 2: 0.20s


Training Epoch:  50%|█████     | 4/8 [00:52<00:29,  7.25s/it]

Batch 3: 0.20s
Sample 1: DICOM load: 6.05s, Preprocess: 1.50s
Sample 0: DICOM load: 2.21s, Preprocess: 0.91s


Training Epoch:  62%|██████▎   | 5/8 [01:15<00:38, 12.73s/it]

Batch 4: 0.27s
Sample 3: DICOM load: 18.43s, Preprocess: 3.78s


Training Epoch: 100%|██████████| 8/8 [01:34<00:00, 11.80s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10077108087009955586144859725246456654: all input arrays must have the same shape
Sample 0: DICOM load: 4.84s, Preprocess: 1.04s
Sample 1: DICOM load: 11.64s, Preprocess: 2.69s
Sample 2: DICOM load: 3.64s, Preprocess: 1.17s
Sample 3: DICOM load: 7.27s, Preprocess: 1.95s
Sample 4: DICOM load: 2.83s, Preprocess: 1.12s


Validating: 100%|██████████| 5/5 [01:16<00:00, 15.33s/it]


Epoch   5 | Train Loss: 488635768.0000 | Val Loss: 488626124.8000 | Val AUC: 0.3028


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Sample 2: DICOM load: 3.61s, Preprocess: 0.01s
Sample 1: DICOM load: 6.22s, Preprocess: 1.51s
Sample 4: DICOM load: 9.92s, Preprocess: 2.65s
Sample 0: DICOM load: 2.24s, Preprocess: 0.92s


Training Epoch:  25%|██▌       | 2/8 [00:50<02:05, 20.96s/it]

Batch 0: 0.27s
Batch 1: 0.19s


Training Epoch:  38%|███▊      | 3/8 [00:51<00:57, 11.48s/it]

Batch 2: 0.20s


Training Epoch:  50%|█████     | 4/8 [00:54<00:33,  8.32s/it]

Batch 3: 0.25s
Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10192011262895867728128531292507199782: all input arrays must have the same shape
Sample 3: DICOM load: 10.15s, Preprocess: 3.69s


Training Epoch:  75%|███████▌  | 6/8 [01:43<00:30, 15.12s/it]

Batch 4: 0.25s


Training Epoch: 100%|██████████| 8/8 [01:43<00:00, 12.97s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10077108087009955586144859725246456654: all input arrays must have the same shape
Sample 0: DICOM load: 4.93s, Preprocess: 0.99s
Sample 1: DICOM load: 11.43s, Preprocess: 2.76s
Sample 2: DICOM load: 4.48s, Preprocess: 1.17s
Sample 3: DICOM load: 6.88s, Preprocess: 1.97s
Sample 4: DICOM load: 2.96s, Preprocess: 1.10s


Validating: 100%|██████████| 5/5 [01:17<00:00, 15.41s/it]


Epoch   6 | Train Loss: 487841196.0000 | Val Loss: 488626124.8000 | Val AUC: 0.3028


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Sample 2: DICOM load: 3.44s, Preprocess: 0.01s
Sample 0: DICOM load: 3.22s, Preprocess: 0.96s


Training Epoch:  12%|█▎        | 1/8 [00:42<04:55, 42.29s/it]

Batch 0: 0.29s
Sample 4: DICOM load: 10.28s, Preprocess: 2.70s


Training Epoch:  25%|██▌       | 2/8 [01:05<03:05, 30.92s/it]

Batch 1: 0.26s


Training Epoch:  50%|█████     | 4/8 [01:05<00:41, 10.30s/it]

Batch 2: 0.20s
Batch 3: 0.19s
Sample 3: DICOM load: 19.33s, Preprocess: 5.19s
Sample 1: DICOM load: 5.77s, Preprocess: 1.54s
Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10192011262895867728128531292507199782: all input arrays must have the same shape


Training Epoch:  75%|███████▌  | 6/8 [01:29<00:20, 10.06s/it]

Batch 4: 0.27s


Training Epoch: 100%|██████████| 8/8 [01:30<00:00, 11.27s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10077108087009955586144859725246456654: all input arrays must have the same shape
Sample 0: DICOM load: 4.83s, Preprocess: 1.05s
Sample 1: DICOM load: 11.75s, Preprocess: 2.66s
Sample 2: DICOM load: 4.00s, Preprocess: 1.17s
Sample 3: DICOM load: 7.40s, Preprocess: 2.02s
Sample 4: DICOM load: 2.79s, Preprocess: 1.09s


Validating: 100%|██████████| 5/5 [01:16<00:00, 15.34s/it]


Epoch   7 | Train Loss: 489422272.0000 | Val Loss: 488626124.8000 | Val AUC: 0.3028


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10192011262895867728128531292507199782: all input arrays must have the same shape


Training Epoch:  12%|█▎        | 1/8 [00:34<03:58, 34.00s/it]

Batch 0: 0.26s


Training Epoch:  38%|███▊      | 3/8 [00:34<00:38,  7.76s/it]

Batch 1: 0.20s
Batch 2: 0.19s
Sample 2: DICOM load: 3.51s, Preprocess: 0.01s
Sample 1: DICOM load: 3.54s, Preprocess: 1.14s


Training Epoch:  62%|██████▎   | 5/8 [01:04<00:31, 10.54s/it]

Batch 3: 0.26s
Batch 4: 0.19s


Training Epoch:  75%|███████▌  | 6/8 [01:04<00:14,  7.02s/it]

Sample 0: DICOM load: 1.27s, Preprocess: 0.67s


Training Epoch:  88%|████████▊ | 7/8 [01:10<00:06,  6.69s/it]

Sample 3: DICOM load: 10.30s, Preprocess: 3.67s


Training Epoch: 100%|██████████| 8/8 [01:50<00:00, 13.81s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10077108087009955586144859725246456654: all input arrays must have the same shape
Sample 0: DICOM load: 4.73s, Preprocess: 1.05s
Sample 1: DICOM load: 11.66s, Preprocess: 2.63s
Sample 2: DICOM load: 4.16s, Preprocess: 1.37s
Sample 3: DICOM load: 7.04s, Preprocess: 1.99s
Sample 4: DICOM load: 2.89s, Preprocess: 1.15s


Validating: 100%|██████████| 5/5 [01:16<00:00, 15.37s/it]


Epoch   8 | Train Loss: 491971076.0000 | Val Loss: 488626124.8000 | Val AUC: 0.3028


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Sample 1: DICOM load: 5.94s, Preprocess: 1.45s
Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10192011262895867728128531292507199782: all input arrays must have the same shape
Sample 4: DICOM load: 9.51s, Preprocess: 2.64s
Sample 0: DICOM load: 2.32s, Preprocess: 0.91s
Sample 2: DICOM load: 3.52s, Preprocess: 0.01s


Training Epoch:  12%|█▎        | 1/8 [00:58<06:48, 58.38s/it]

Batch 0: 0.25s
Batch 1: 0.20s


Training Epoch:  38%|███▊      | 3/8 [00:58<01:06, 13.22s/it]

Batch 2: 0.20s


Training Epoch:  50%|█████     | 4/8 [00:59<00:32,  8.09s/it]

Batch 3: 0.22s
Sample 3: DICOM load: 12.71s, Preprocess: 3.77s


Training Epoch:  75%|███████▌  | 6/8 [01:31<00:22, 11.08s/it]

Batch 4: 0.25s


Training Epoch: 100%|██████████| 8/8 [01:31<00:00, 11.46s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10077108087009955586144859725246456654: all input arrays must have the same shape
Sample 0: DICOM load: 4.53s, Preprocess: 0.98s
Sample 1: DICOM load: 11.52s, Preprocess: 2.60s
Sample 2: DICOM load: 3.95s, Preprocess: 1.40s
Sample 3: DICOM load: 7.15s, Preprocess: 1.96s
Sample 4: DICOM load: 2.92s, Preprocess: 1.09s


Validating: 100%|██████████| 5/5 [01:16<00:00, 15.21s/it]


Epoch   9 | Train Loss: 487589904.0000 | Val Loss: 488626124.8000 | Val AUC: 0.3028


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Sample 0: DICOM load: 2.27s, Preprocess: 0.97s
Sample 2: DICOM load: 3.28s, Preprocess: 0.01s
Sample 3: DICOM load: 19.02s, Preprocess: 5.18s


Training Epoch:  25%|██▌       | 2/8 [00:41<01:43, 17.23s/it]

Batch 0: 0.27s
Batch 1: 0.19s
Error loading /kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10192011262895867728128531292507199782: all input arrays must have the same shape


Training Epoch:  50%|█████     | 4/8 [01:17<01:02, 15.66s/it]

Batch 2: 0.27s
Batch 3: 0.19s


Training Epoch:  75%|███████▌  | 6/8 [01:18<00:13,  6.71s/it]

Batch 4: 0.18s
