<a href="https://www.kaggle.com/code/nicholas33/02-aneurysmnet-cnn-intracranial-nb153?scriptVersionId=254264443" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [3]:
!pip install monai

# ====================================================
# RSNA INTRACRANIAL ANEURYSM DETECTION - TRAINING PIPELINE
# ====================================================

import os
import gc
import warnings
import json
import time
import numpy as np
import pandas as pd
from typing import Tuple, Dict, List
from collections import Counter
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import roc_auc_score
import albumentations as A

warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler

import pydicom
from scipy import ndimage
import nibabel as nib
from monai.transforms import (
    Compose, RandRotate90d, RandFlipd, RandAffined,
    RandGaussianNoised, RandAdjustContrastd, ToTensord
)
from monai.networks.nets import BasicUNet
from monai.losses import DiceCELoss, FocalLoss
from tqdm import tqdm




2025-08-05 03:42:35.471244: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754365355.662727      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754365355.715688      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [4]:
# ====================================================
# CONFIGURATION
# ====================================================

class Config:
    # Paths
    TRAIN_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train.csv'
    LOCALIZER_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train_localizers.csv'
    SERIES_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/series/'
    SEGMENTATION_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/segmentations/'
    
    # Model parameters
    TARGET_SIZE = (32, 64, 64)  # Increased resolution
    EPOCHS = 10
    BATCH_SIZE = 8  # Reduced due to larger input size
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 1e-4
    N_FOLDS = 3
    
    # Training parameters
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    MIXED_PRECISION = True
    GRADIENT_ACCUMULATION = 4
    
    # Competition constants
    ID_COL = 'SeriesInstanceUID'
    LABEL_COLS = [
        'Left Infraclinoid Internal Carotid Artery', 'Right Infraclinoid Internal Carotid Artery',
        'Left Supraclinoid Internal Carotid Artery', 'Right Supraclinoid Internal Carotid Artery',
        'Left Middle Cerebral Artery', 'Right Middle Cerebral Artery', 'Anterior Communicating Artery',
        'Left Anterior Cerebral Artery', 'Right Anterior Cerebral Artery',
        'Left Posterior Communicating Artery', 'Right Posterior Communicating Artery',
        'Basilar Tip', 'Other Posterior Circulation', 'Aneurysm Present',
    ]
    
    # Class weights for imbalanced data
    ANEURYSM_PRESENT_WEIGHT = 13.0  # Match evaluation metric weighting

In [5]:
# ====================================================
# ENHANCED DATA PREPROCESSING
# ====================================================

class AdvancedDICOMProcessor:
    def __init__(self, target_size: Tuple[int, int, int] = Config.TARGET_SIZE):
        self.target_size = target_size
        self.stats = {
            'total_loaded': 0,
            'successful_loads': 0,
            'shape_errors': 0,
            'empty_volumes': 0,
            'preprocessing_errors': 0
        }
        
    def load_dicom_series(self, series_path: str) -> Tuple[np.ndarray, Dict]:
        """Load DICOM series with FIXED shape handling"""
        self.stats['total_loaded'] += 1
        try:
            dicom_files = [os.path.join(series_path, f) for f in os.listdir(series_path) if f.endswith('.dcm')]
            if not dicom_files:
                print(f"No DICOM files found in {series_path}")
                return np.zeros(self.target_size, dtype=np.float32), {}
                
            dicoms = []
            
            for f in dicom_files:
                ds = pydicom.dcmread(f, force=True)
                dicoms.append(ds)
                
            # Extract metadata from first DICOM
            first_ds = dicoms[0]
            metadata = {
                'modality': getattr(first_ds, 'Modality', 'UNKNOWN'),
                'spacing': getattr(first_ds, 'PixelSpacing', [1.0, 1.0]),
                'slice_thickness': getattr(first_ds, 'SliceThickness', 1.0),
                'rescale_slope': getattr(first_ds, 'RescaleSlope', 1.0),
                'rescale_intercept': getattr(first_ds, 'RescaleIntercept', 0.0),
            }
            
            # Sort by instance number 
            dicoms.sort(key=lambda x: int(getattr(x, 'InstanceNumber', 0)))

            
            # DEBUG: Let's see what we're actually getting (only for first few volumes)
            pixel_arrays = []
            shapes = []
            debug_this_volume = self.stats['total_loaded'] <= 3  # Only debug first 3 volumes
            
            for i, d in enumerate(dicoms):
                if hasattr(d, 'pixel_array'):
                    arr = d.pixel_array
                    if debug_this_volume:
                        print(f"  Slice {i}: shape={arr.shape}, dtype={arr.dtype}, ndim={arr.ndim}")
                    
                    # Only accept 2D arrays
                    if arr.ndim == 2:
                        pixel_arrays.append(arr)
                        shapes.append(arr.shape)
                    else:
                        if debug_this_volume:
                            print(f"  SKIPPING slice {i}: not 2D (ndim={arr.ndim})")
                else:
                    if debug_this_volume:
                        print(f"  SKIPPING slice {i}: no pixel_array attribute")
            
            if len(pixel_arrays) == 0:
                print(f"CRITICAL: No valid 2D pixel arrays found in {series_path}")
                self.stats['shape_errors'] += 1
                return np.zeros(self.target_size, dtype=np.float32), metadata

            if debug_this_volume:
                print(f"  Found {len(pixel_arrays)} valid slices, unique shapes: {set(shapes)}")
            
            # Check if all shapes are the same
            unique_shapes = list(set(shapes))
            if len(unique_shapes) == 1:
                if debug_this_volume:
                    print(f"  All slices have same shape {unique_shapes[0]} - using original stacking")
                volume = np.stack(pixel_arrays, axis=0).astype(np.float32)
            else:
                if debug_this_volume:
                    print(f"  Multiple shapes found: {unique_shapes}")
                # Find most common shape
                most_common_shape = Counter(shapes).most_common(1)[0][0]
                if debug_this_volume:
                    print(f"  Resizing all to most common shape: {most_common_shape}")
                
                resized_arrays = []
                for i, arr in enumerate(pixel_arrays):
                    if arr.shape == most_common_shape:
                        resized_arrays.append(arr.astype(np.float32))
                    else:
                        if debug_this_volume:
                            print(f"    Resizing slice {i} from {arr.shape} to {most_common_shape}")
                        zoom_factors = (most_common_shape[0] / arr.shape[0], 
                                      most_common_shape[1] / arr.shape[1])
                        resized_arr = ndimage.zoom(arr, zoom_factors, order=1, prefilter=False)
                        resized_arrays.append(resized_arr.astype(np.float32))
                
                volume = np.stack(resized_arrays, axis=0).astype(np.float32)

            if debug_this_volume:
                print(f"  ✅ Final volume shape: {volume.shape}, dtype: {volume.dtype}")
            
            # Apply rescale if available
            if metadata['rescale_slope'] != 1.0 or metadata['rescale_intercept'] != 0.0:
                volume = volume * metadata['rescale_slope'] + metadata['rescale_intercept']

            self.stats['successful_loads'] += 1
            return volume, metadata
            
        except Exception as e:
            print(f"Error loading {series_path}: {e}")
            self.stats['shape_errors'] += 1
            return np.zeros(self.target_size, dtype=np.float32), {}

    def print_stats(self):
        """Print loading statistics"""
        total = self.stats['total_loaded']
        successful = self.stats['successful_loads']
        if total > 0:
            success_rate = (successful / total) * 100
            print(f"\n=== DICOM Loading Stats ===")
            print(f"Total attempts: {total}")
            print(f"Successful loads: {successful} ({success_rate:.1f}%)")
            print(f"Shape errors: {self.stats['shape_errors']}")
            print(f"Empty volumes: {self.stats['empty_volumes']}")
            print(f"Preprocessing errors: {self.stats['preprocessing_errors']}")
            print(f"===========================")

    def preprocess_volume(self, volume: np.ndarray, metadata: Dict) -> np.ndarray:
        """Enhanced preprocessing with modality-specific handling"""
        if volume.ndim != 3 or volume.size == 0:
            print(f"Warning: Received a non-3D volume. Returning empty target volume.")
            return np.zeros(self.target_size, dtype=np.float32)
        
        # Default windowing
        p1, p99 = np.percentile(volume, [5, 95])
        volume = np.clip(volume, p1, p99)
        
        # Normalization
        vol_min, vol_max = volume.min(), volume.max()
        if vol_max > vol_min:
            volume = (volume - vol_min) / (vol_max - vol_min)
        
        # Resize to target size
        if volume.shape != self.target_size:
            zoom_factors = [self.target_size[i] / volume.shape[i] for i in range(3)]
            volume = ndimage.zoom(volume, zoom_factors, order=1, prefilter=False)
        
        return volume.astype(np.float32)

    def load_localization_mask(self, series_id: str, localizer_df: pd.DataFrame) -> np.ndarray:
        return np.zeros(self.target_size, dtype=np.float32)


# ====================================================
# ENHANCED DATASET
# ====================================================

class EnhancedAneurysmDataset(Dataset):
    def __init__(self, df: pd.DataFrame, localizer_df: pd.DataFrame, 
                 series_dir: str, processor: AdvancedDICOMProcessor, 
                 mode: str = 'train', fold: int = None):
        self.df = df
        self.localizer_df = localizer_df
        self.series_dir = series_dir
        self.processor = processor
        self.mode = mode
        self.fold = fold
        
        # Data augmentation for training
        if mode == 'train':
            self.transform = Compose([
                # RandRotate90d(keys=['volume'], prob=0.3, spatial_axes=(0, 1)),
                # RandFlipd(keys=['volume'], prob=0.3, spatial_axis=0),
                # RandFlipd(keys=['volume'], prob=0.3, spatial_axis=1),
                # RandFlipd(keys=['volume'], prob=0.3, spatial_axis=2),
                # RandAffined(keys=['volume'], prob=0.3, rotate_range=0.1, scale_range=0.1),
                # RandGaussianNoised(keys=['volume'], prob=0.3, std=0.05),
                # RandAdjustContrastd(keys=['volume'], prob=0.3, gamma=(0.8, 1.2)),
                RandFlipd(keys=['volume'], prob=0.5, spatial_axis=0),  # Only basic flips
                RandFlipd(keys=['volume'], prob=0.5, spatial_axis=1),
                ToTensord(keys=['volume'])
            ])
        else:
            self.transform = Compose([ToTensord(keys=['volume'])])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        load_start = time.time()
        row = self.df.iloc[idx]
        series_id = row[Config.ID_COL]
        series_path = os.path.join(self.series_dir, series_id)
        
        # Load and process volume
        volume, metadata = self.processor.load_dicom_series(series_path)
        dicom_time = time.time() - load_start
        preprocess_start = time.time()
        volume = self.processor.preprocess_volume(volume, metadata)
        
        # Create localization mask (for auxiliary loss)
        loc_mask = self.processor.load_localization_mask(series_id, self.localizer_df)
        
        # Get labels
        labels = row[Config.LABEL_COLS].values.astype(np.float32)
        
        # Apply transforms
        data_dict = {'volume': volume}
        if self.transform:
            data_dict = self.transform(data_dict)
        
        volume_tensor = data_dict['volume'].unsqueeze(0)  # Add channel dimension
        loc_mask_tensor = torch.from_numpy(loc_mask).unsqueeze(0)
        labels_tensor = torch.from_numpy(labels)
        
        # Add metadata features
        modality_encoding = self._encode_modality(metadata.get('modality', 'UNKNOWN'))
        metadata_tensor = torch.tensor(modality_encoding, dtype=torch.float32)
        
        preprocess_time = time.time() - preprocess_start
        # Print timing for first few samples to debug
        if idx < 5:
            print(f"Sample {idx}: DICOM load: {dicom_time:.2f}s, Preprocess: {preprocess_time:.2f}s")
        
        return {
            'volume': volume_tensor,
            'localization_mask': loc_mask_tensor,
            'labels': labels_tensor,
            'metadata': metadata_tensor,
            'series_id': series_id
        }
    
    def _encode_modality(self, modality: str) -> List[float]:
        """One-hot encode modality"""
        modalities = ['CTA', 'MRA', 'MRI', 'MR', 'UNKNOWN']
        encoding = [0.0] * len(modalities)
        if modality in modalities:
            encoding[modalities.index(modality)] = 1.0
        else:
            encoding[-1] = 1.0  # UNKNOWN
        return encoding


# ====================================================
# ADVANCED MODEL ARCHITECTURE
# ====================================================

#class MultiModalAneurysmNet(nn.Module):
class SimplifiedAneurysmNet(nn.Module):
    def __init__(self, num_classes: int = len(Config.LABEL_COLS), 
                 spatial_dims: int = 3, in_channels: int = 1, 
        #          features: Tuple = (32, 64, 128, 256, 512, 1024)):
        # super(MultiModalAneurysmNet, self).__init__()
                 features: Tuple = (16, 32, 64, 128, 256, 51)):
        super(SimplifiedAneurysmNet, self).__init__()
        
        # Main 3D U-Net backbone
        self.backbone = BasicUNet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=features[0],
            features=features,
            dropout=0.1 #Reduced dropout 
        )
        
        # Global average pooling
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        
        # # Metadata processing
        # self.metadata_mlp = nn.Sequential(
        #     nn.Linear(5, 32),  # 5 modality categories
        #     nn.ReLU(),
        #     nn.Dropout(0.3),
        #     nn.Linear(32, 64),
        #     nn.ReLU()
        # )
        
        # Classification head
        #feature_size = features[0] + 64  # backbone features + metadata features
        self.classifier = nn.Sequential(
            # nn.Linear(feature_size, 512),
            # nn.ReLU(),
            # nn.Dropout(0.5),
            # nn.Linear(512, 256),
            nn.Linear(features[0], 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, volume, metadata=None):
        # Extract features from 3D volume
        features = self.backbone(volume)
        # Global features for classification
        global_features = self.global_pool(features).flatten(1)
        classification_logits = self.classifier(global_features)
        return classification_logits, None

# ====================================================
# WEIGHTED LOSS FUNCTION
# ====================================================

class WeightedMultiLabelLoss(nn.Module):
    def __init__(self, pos_weights=None, aneurysm_weight=13.0):
        super().__init__()
        self.pos_weights = pos_weights
        self.aneurysm_weight = aneurysm_weight
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
        
    def forward(self, logits, targets):
        bce_loss = self.bce(logits, targets)
        
        # Apply position weights if provided
        if self.pos_weights is not None:
            bce_loss = bce_loss * self.pos_weights.to(logits.device)
        
        # Weight the "Aneurysm Present" class higher (last column)
        weights = torch.ones_like(bce_loss)
        weights[:, -1] = self.aneurysm_weight
        
        weighted_loss = bce_loss * weights
        return weighted_loss.mean()


# ====================================================
# TRAINING FUNCTIONS
# ====================================================

def compute_weighted_auc(y_true, y_pred):
    """Compute weighted AUC matching competition metric"""
    aucs = []
    weights = []
    
    for i in range(len(Config.LABEL_COLS)):
        try:
            auc = roc_auc_score(y_true[:, i], y_pred[:, i])
            aucs.append(auc)
            # Weight "Aneurysm Present" (last column) higher
            weights.append(13.0 if i == len(Config.LABEL_COLS) - 1 else 1.0)
        except ValueError:
            aucs.append(0.5)  # Default for no positive cases
            weights.append(13.0 if i == len(Config.LABEL_COLS) - 1 else 1.0)
    
    weighted_auc = sum(a * w for a, w in zip(aucs, weights)) / sum(weights)
    return weighted_auc, aucs

def train_epoch(model, train_loader, optimizer, criterion, scaler, device):
    model.train()
    total_loss = 0
    num_batches = 0
    
    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Training Epoch")):
        start_time = time.time()
        volume = batch['volume'].to(device)
        metadata = batch['metadata'].to(device)
        labels = batch['labels'].to(device)
        loc_mask = batch['localization_mask'].to(device)
        
        with autocast(device_type=device.type, enabled=Config.MIXED_PRECISION):
            class_logits, _ = model(volume, metadata)
            total_loss_batch = criterion(class_logits, labels)
        
        # Gradient accumulation
        scaled_loss = total_loss_batch / Config.GRADIENT_ACCUMULATION
        scaler.scale(scaled_loss).backward()
        
        if (batch_idx + 1) % Config.GRADIENT_ACCUMULATION == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        total_loss += total_loss_batch.item()
        num_batches += 1

        # Print timing for first few batches to identify bottlenecks
        if batch_idx < 5:
            batch_time = time.time() - start_time
            print(f"Batch {batch_idx}: {batch_time:.2f}s")
    
    return total_loss / num_batches

def validate_epoch(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            volume = batch['volume'].to(device)
            metadata = batch['metadata'].to(device)
            labels = batch['labels'].to(device)
            
            with autocast(device_type=device.type, enabled=Config.MIXED_PRECISION):
                class_logits, _ = model(volume, metadata)
                loss = criterion(class_logits, labels)
            
            total_loss += loss.item()
            
            # Collect predictions for AUC calculation
            probs = torch.sigmoid(class_logits).cpu().numpy()
            all_preds.append(probs)
            all_labels.append(labels.cpu().numpy())
    
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    
    weighted_auc, individual_aucs = compute_weighted_auc(all_labels, all_preds)
    
    return total_loss / len(val_loader), weighted_auc, individual_aucs



In [None]:
# ====================================================
# MAIN TRAINING EXECUTION
# ====================================================

def main():
    print(f"Using device: {Config.DEVICE}")
    print(f"Mixed precision: {Config.MIXED_PRECISION}")
    
    # Load data
    train_df = pd.read_csv(Config.TRAIN_CSV_PATH)
    localizer_df = pd.read_csv(Config.LOCALIZER_CSV_PATH)

    print("---!!! RUNNING IN DEBUG MODE ON A SMALL SUBSET !!!---")
    #print(f"Training samples: {len(train_df)}")
    train_df = train_df.head(100)  # Limit to 100 samples for speed testing
    print(f"Training samples: {len(train_df)} (limited for speed testing)")
    print(f"Positive aneurysm cases: {train_df['Aneurysm Present'].sum()}")
    
    # Create stratified group k-fold split
    # Use patient-level grouping to prevent data leakage
    train_df['patient_group'] = train_df['PatientID'] if 'PatientID' in train_df.columns else range(len(train_df))
    
    skf = StratifiedGroupKFold(n_splits=Config.N_FOLDS, shuffle=True, random_state=42)
    train_df['fold'] = -1
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(
        train_df, train_df['Aneurysm Present'], groups=train_df['patient_group']
    )):
        train_df.loc[val_idx, 'fold'] = fold
    
    # Calculate class weights for imbalanced data
    pos_counts = train_df[Config.LABEL_COLS].sum()
    neg_counts = len(train_df) - pos_counts
    pos_weights = neg_counts / (pos_counts + 1e-8)  # Add small epsilon
    pos_weights = torch.tensor(pos_weights.values, dtype=torch.float32)
    
    print("Class weights:", pos_weights)
    
    # Initialize processor
    processor = AdvancedDICOMProcessor()
    
    # Train models for each fold
    fold_scores = []
    
    for fold in range(Config.N_FOLDS):
        print(f"\n{'='*50}")
        print(f"FOLD {fold + 1}/{Config.N_FOLDS}")
        print(f"{'='*50}")
        
        # Split data
        train_fold_df = train_df[train_df['fold'] != fold].reset_index(drop=True)
        val_fold_df = train_df[train_df['fold'] == fold].reset_index(drop=True)
        
        print(f"Train: {len(train_fold_df)}, Validation: {len(val_fold_df)}")
        
        # Create datasets
        train_dataset = EnhancedAneurysmDataset(
            train_fold_df, localizer_df, Config.SERIES_DIR, processor, mode='train', fold=fold
        )
        val_dataset = EnhancedAneurysmDataset(
            val_fold_df, localizer_df, Config.SERIES_DIR, processor, mode='val', fold=fold
        )
        
        # Create data loaders
        train_loader = DataLoader(
            train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, 
            num_workers=4, pin_memory=True, drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, 
            num_workers=4, pin_memory=True
        )
        
        # Initialize model
        model = SimplifiedAneurysmNet().to(Config.DEVICE)
        criterion = WeightedMultiLabelLoss(pos_weights=pos_weights)
        
        # Optimizer with different learning rates for different parts
        # optimizer = optim.AdamW([
        #     {'params': model.backbone.parameters(), 'lr': Config.LEARNING_RATE},
        #     {'params': model.classifier.parameters(), 'lr': Config.LEARNING_RATE * 2},
        #     {'params': model.metadata_mlp.parameters(), 'lr': Config.LEARNING_RATE * 2}
        # ], weight_decay=Config.WEIGHT_DECAY)
        optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=Config.WEIGHT_DECAY)
        
        # Learning rate scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=1e-6
        )
        
        scaler = GradScaler(enabled=Config.MIXED_PRECISION)
        
        # Training loop
        best_auc = 0
        patience = 10
        patience_counter = 0
        
        for epoch in range(Config.EPOCHS):
            # Train
            train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler, Config.DEVICE)
            
            # Validate
            val_loss, val_auc, individual_aucs = validate_epoch(model, val_loader, criterion, Config.DEVICE)
            
            # Step scheduler
            scheduler.step()
            
            print(f"Epoch {epoch+1:3d} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val AUC: {val_auc:.4f}")

            processor.print_stats() # Print DICOM loading stats after each epoch

            # SANITY CHECK: Stop if data loading is fundamentally broken
            if processor.stats['total_loaded'] > 20:  # Only check after some attempts
                success_rate = processor.stats['successful_loads'] / processor.stats['total_loaded']
                if success_rate < 0.5:  # Less than 50% success rate
                    print(f"\n🚨 STOPPING TRAINING: Data loading success rate is {success_rate:.1%}")
                    print("Fix the DICOM loading issues before continuing training!")
                    print("Most volumes are returning empty - this is a waste of time!")
                    break
            
            # Save best model
            if val_auc > best_auc:
                best_auc = val_auc
                patience_counter = 0
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_auc': val_auc,
                    'epoch': epoch,
                    'fold': fold,
                    'individual_aucs': individual_aucs
                }, f'best_model_fold_{fold}.pth')
            else:
                patience_counter += 1
                
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
                
            # Memory cleanup
            if epoch % 5 == 0:
                gc.collect()
                torch.cuda.empty_cache()
        
        fold_scores.append(best_auc)
        print(f"Fold {fold + 1} best AUC: {best_auc:.4f}")
    
    # Final results
    mean_cv_score = np.mean(fold_scores)
    std_cv_score = np.std(fold_scores)
    
    print(f"\n{'='*50}")
    print(f"CROSS-VALIDATION RESULTS")
    print(f"{'='*50}")
    print(f"Mean CV AUC: {mean_cv_score:.4f} ± {std_cv_score:.4f}")
    print(f"Individual fold scores: {fold_scores}")
    
    # Save training summary
    results = {
        'cv_scores': fold_scores,
        'mean_cv_score': mean_cv_score,
        'std_cv_score': std_cv_score,
        'config': vars(Config())
    }
    
    with open('training_results.json', 'w') as f:
        json.dump(results, f, indent=2, default=str)
    
    print("Training complete! Models saved as 'best_model_fold_X.pth'")

if __name__ == "__main__":
    main()

Using device: cuda
Mixed precision: True
---!!! RUNNING IN DEBUG MODE ON A SMALL SUBSET !!!---
Training samples: 100 (limited for speed testing)
Positive aneurysm cases: 48
Class weights: tensor([3.2333e+01, 4.9000e+01, 1.3286e+01, 9.0000e+00, 4.9000e+01, 1.1500e+01,
        1.0111e+01, 9.9000e+01, 4.9000e+01, 4.9000e+01, 1.0000e+10, 3.2333e+01,
        4.9000e+01, 1.0833e+00])

FOLD 1/3
Train: 66, Validation: 34
BasicUNet features: (16, 32, 64, 128, 256, 51).


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Sample 0: DICOM load: 3.27s, Preprocess: 0.98s


Training Epoch:  12%|█▎        | 1/8 [00:32<03:44, 32.04s/it]

Batch 0: 1.68s


Training Epoch:  25%|██▌       | 2/8 [00:32<01:19, 13.32s/it]

Batch 1: 0.21s
Sample 4: DICOM load: 10.50s, Preprocess: 2.01s
Sample 3: DICOM load: 24.33s, Preprocess: 4.81s
Sample 1: DICOM load: 7.14s, Preprocess: 1.21s


Training Epoch:  38%|███▊      | 3/8 [01:09<02:01, 24.40s/it]

Batch 2: 0.26s


Training Epoch:  50%|█████     | 4/8 [01:10<00:59, 14.85s/it]

Batch 3: 0.22s
Sample 2: DICOM load: 2.88s, Preprocess: 0.01s


Training Epoch:  75%|███████▌  | 6/8 [01:31<00:22, 11.39s/it]

Batch 4: 0.26s


Training Epoch: 100%|██████████| 8/8 [02:15<00:00, 16.98s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Sample 0: DICOM load: 6.06s, Preprocess: 0.79s
Sample 1: DICOM load: 10.46s, Preprocess: 2.64s
Sample 2: DICOM load: 7.79s, Preprocess: 1.20s
Sample 3: DICOM load: 10.30s, Preprocess: 1.94s
Sample 4: DICOM load: 4.90s, Preprocess: 1.13s


Validating: 100%|██████████| 5/5 [01:42<00:00, 20.52s/it]


Epoch   1 | Train Loss: 452429748.0000 | Val Loss: 451950188.8000 | Val AUC: 0.4726


Training Epoch:   0%|          | 0/8 [00:00<?, ?it/s]

Sample 4: DICOM load: 9.80s, Preprocess: 3.30s
Sample 3: DICOM load: 20.40s, Preprocess: 4.98s
Sample 0: DICOM load: 2.33s, Preprocess: 0.91s
Sample 1: DICOM load: 6.97s, Preprocess: 1.56s


Training Epoch:  12%|█▎        | 1/8 [01:03<07:24, 63.57s/it]

Batch 0: 0.29s
Sample 2: DICOM load: 2.91s, Preprocess: 0.01s


Training Epoch:  25%|██▌       | 2/8 [01:22<03:43, 37.32s/it]

Batch 1: 0.28s


Training Epoch:  50%|█████     | 4/8 [01:22<00:49, 12.41s/it]

Batch 2: 0.21s
Batch 3: 0.19s


Training Epoch:  62%|██████▎   | 5/8 [01:27<00:28,  9.42s/it]

Batch 4: 0.24s


Training Epoch: 100%|██████████| 8/8 [01:50<00:00, 13.85s/it]
Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Sample 0: DICOM load: 4.82s, Preprocess: 0.99s
Sample 1: DICOM load: 11.21s, Preprocess: 2.70s
Sample 2: DICOM load: 4.52s, Preprocess: 1.15s
Sample 3: DICOM load: 8.60s, Preprocess: 1.95s
Sample 4: DICOM load: 2.88s, Preprocess: 1.10s
