<a href="https://www.kaggle.com/code/nicholas33/02-aneurysmnet-cnn-intracranial-nb153?scriptVersionId=254278695" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
!pip install monai

# ====================================================
# RSNA INTRACRANIAL ANEURYSM DETECTION - TRAINING PIPELINE
# ====================================================

import os
import gc
import warnings
import json
import time
import numpy as np
import pandas as pd
from typing import Tuple, Dict, List
from collections import Counter
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import roc_auc_score
import albumentations as A

warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler

import pydicom
import pydicom.errors
from scipy import ndimage
import nibabel as nib
from monai.transforms import (
    Compose, RandRotate90d, RandFlipd, RandAffined,
    RandGaussianNoised, RandAdjustContrastd, ToTensord
)
from monai.networks.nets import BasicUNet
from monai.losses import DiceCELoss, FocalLoss
from tqdm import tqdm


Collecting monai
  Downloading monai-1.5.0-py3-none-any.whl.metadata (13 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch<2.7.0,

2025-08-05 05:35:29.489767: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754372129.729696      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754372129.801450      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# ====================================================
# CONFIGURATION
# ====================================================

class Config:
    """Central, class-level configuration for the training pipeline.

    All values are plain class attributes and are read as ``Config.NAME``
    throughout the file; the class is never meant to carry per-instance state.
    """

    # Paths
    TRAIN_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train.csv'
    LOCALIZER_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train_localizers.csv'
    SERIES_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/series/'
    SEGMENTATION_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/segmentations/'
    
    # Model parameters
    # Shape every volume is resampled to; axis 0 is the slice axis because the
    # loader stacks 2D slices along axis 0 before resizing.
    TARGET_SIZE = (32, 64, 64)  # Increased resolution
    EPOCHS = 10
    BATCH_SIZE = 8  # Reduced due to larger input size
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 1e-4
    N_FOLDS = 3
    
    # Training parameters
    # Resolved once, when the class body is executed.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    MIXED_PRECISION = True
    # Number of batches whose (scaled-down) gradients are accumulated before
    # the optimizer takes a step.
    GRADIENT_ACCUMULATION = 4
    
    # Competition constants
    ID_COL = 'SeriesInstanceUID'
    # Order matters: the loss and the AUC helper treat the LAST column
    # ('Aneurysm Present') specially.
    LABEL_COLS = [
        'Left Infraclinoid Internal Carotid Artery', 'Right Infraclinoid Internal Carotid Artery',
        'Left Supraclinoid Internal Carotid Artery', 'Right Supraclinoid Internal Carotid Artery',
        'Left Middle Cerebral Artery', 'Right Middle Cerebral Artery', 'Anterior Communicating Artery',
        'Left Anterior Cerebral Artery', 'Right Anterior Cerebral Artery',
        'Left Posterior Communicating Artery', 'Right Posterior Communicating Artery',
        'Basilar Tip', 'Other Posterior Circulation', 'Aneurysm Present',
    ]
    
    # Class weights for imbalanced data
    # NOTE(review): this constant is not referenced by the visible code; the
    # loss default and the AUC helper each hard-code 13.0 instead.
    ANEURYSM_PRESENT_WEIGHT = 13.0  # Match evaluation metric weighting

# ====================================================
# ENHANCED DATA PREPROCESSING
# ====================================================

class AdvancedDICOMProcessor:
    """Load DICOM series from disk and convert them to fixed-size volumes.

    Loading never raises: every failure path returns an all-zero float32
    volume of ``target_size`` and bumps a counter in ``self.stats``. Series
    that contain no usable DICOM file are additionally appended to
    ``CORRUPTED_LOG_PATH`` so they can be inspected later.
    """

    # Text file that collects the IDs of series with no usable DICOM files.
    CORRUPTED_LOG_PATH = 'corrupted_series.txt'
    # Per-series cap on "skipping this file" console messages.
    MAX_FILE_WARNINGS = 3

    def __init__(self, target_size: Tuple[int, int, int] = Config.TARGET_SIZE):
        """Store the output shape and reset the loading counters.

        Args:
            target_size: Shape of every volume returned by
                ``preprocess_volume`` and of every zero-filled fallback.
        """
        self.target_size = target_size
        self.stats = {
            'total_loaded': 0,
            'successful_loads': 0,
            'shape_errors': 0,
            'empty_volumes': 0,
            'preprocessing_errors': 0
        }

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _empty_volume(self) -> np.ndarray:
        """Return the all-zero placeholder used for every failed load."""
        return np.zeros(self.target_size, dtype=np.float32)

    def _log_bad_series(self, series_id: str, reason: str,
                        logged_label: str, failed_label: str) -> None:
        """Append ``"<series_id>: <reason>"`` to the corrupted-series log.

        Logging is best effort: an I/O failure is reported on stdout and
        otherwise ignored so that it cannot interrupt data loading.
        """
        try:
            with open(self.CORRUPTED_LOG_PATH, 'a') as f:
                f.write(f"{series_id}: {reason}\n")
                f.flush()  # Ensure immediate write
            print(f"📝 {logged_label}: {series_id}")
        except OSError as e:
            print(f"⚠️  {failed_label}: {e}")

    @staticmethod
    def _instance_number(ds) -> int:
        """Sort key for slices; 0 when InstanceNumber is absent or unusable.

        A present-but-empty tag reads back as ``None``; without this guard
        ``int(None)`` would abort the load of the whole series.
        """
        try:
            return int(getattr(ds, 'InstanceNumber', 0))
        except (TypeError, ValueError):
            return 0

    @staticmethod
    def _as_float(value, default: float) -> float:
        """Coerce a DICOM numeric attribute to ``float`` or fall back to ``default``."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _read_slices(self, dicom_files: List[str]) -> Tuple[list, int]:
        """Read every file and keep the datasets that expose pixel data.

        Returns:
            ``(datasets, corrupted_count)`` where ``corrupted_count`` is the
            number of files that were unreadable or had an empty pixel array.
        """
        dicoms = []
        corrupted_count = 0

        for path in dicom_files:
            problem = None
            try:
                ds = pydicom.dcmread(path, force=True)
                if hasattr(ds, 'pixel_array') and ds.pixel_array.size > 0:
                    dicoms.append(ds)
                    continue
                problem = f"  Skipping empty pixel array: {os.path.basename(path)}"
            except pydicom.errors.InvalidDicomError:
                problem = f"  Invalid DICOM file: {os.path.basename(path)}"
            except Exception as e:
                # Pixel decoding can fail in many library-specific ways; the
                # file is skipped rather than failing the whole series.
                problem = f"  Error reading {os.path.basename(path)}: {e}"

            corrupted_count += 1
            if corrupted_count <= self.MAX_FILE_WARNINGS:  # Only log first few
                print(problem)

        return dicoms, corrupted_count

    @staticmethod
    def _stack_slices(pixel_arrays: List[np.ndarray]) -> np.ndarray:
        """Stack 2D slices along axis 0, resizing outliers to the modal shape."""
        shapes = [arr.shape for arr in pixel_arrays]
        if len(set(shapes)) == 1:
            # All same shape - direct stacking
            return np.stack(pixel_arrays, axis=0).astype(np.float32)

        # Multiple shapes - resize to most common
        most_common_shape = Counter(shapes).most_common(1)[0][0]
        resized_arrays = []
        for arr in pixel_arrays:
            if arr.shape != most_common_shape:
                zoom_factors = (most_common_shape[0] / arr.shape[0],
                                most_common_shape[1] / arr.shape[1])
                arr = ndimage.zoom(arr, zoom_factors, order=1, prefilter=False)
            resized_arrays.append(arr.astype(np.float32))
        return np.stack(resized_arrays, axis=0).astype(np.float32)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def load_dicom_series(self, series_path: str) -> Tuple[np.ndarray, Dict]:
        """Load one series directory into a float32 volume.

        Args:
            series_path: Directory holding the ``.dcm`` files of one series;
                its basename is used as the series ID in the log file.

        Returns:
            ``(volume, metadata)``. On success ``volume`` has one slice per
            valid file along axis 0 (ordered by InstanceNumber) with the
            rescale slope/intercept applied. On any failure ``volume`` is an
            all-zero array of ``target_size``; ``metadata`` is then ``{}``
            unless the failure happened after the header was read.
        """
        self.stats['total_loaded'] += 1
        try:
            dicom_files = [
                os.path.join(series_path, f)
                for f in os.listdir(series_path)
                if f.endswith('.dcm')
            ]
            if not dicom_files:
                print(f"No DICOM files found in {series_path}")
                self.stats['empty_volumes'] += 1
                self._log_bad_series(
                    os.path.basename(series_path),
                    "No DICOM files found",
                    "Logged series with no DICOMs",
                    "Failed to log series",
                )
                return self._empty_volume(), {}

            dicoms, corrupted_count = self._read_slices(dicom_files)

            if not dicoms:
                print(f"❌ No valid DICOMs in series ({corrupted_count} corrupted files)")
                self.stats['empty_volumes'] += 1
                self._log_bad_series(
                    os.path.basename(series_path),
                    f"No valid DICOMs ({corrupted_count} corrupted files)",
                    "Logged corrupted series",
                    "Failed to log corrupted series",
                )
                return self._empty_volume(), {}

            if corrupted_count > 0:
                print(f"⚠️  Series loaded with {corrupted_count} corrupted files (kept {len(dicoms)} valid)")

            # Extract metadata from the first DICOM that was read (directory
            # listing order, i.e. before the slices are sorted).
            first_ds = dicoms[0]
            metadata = {
                'modality': getattr(first_ds, 'Modality', 'UNKNOWN'),
                'spacing': getattr(first_ds, 'PixelSpacing', [1.0, 1.0]),
                'slice_thickness': getattr(first_ds, 'SliceThickness', 1.0),
                'rescale_slope': getattr(first_ds, 'RescaleSlope', 1.0),
                'rescale_intercept': getattr(first_ds, 'RescaleIntercept', 0.0),
            }

            # Sort by instance number
            dicoms.sort(key=self._instance_number)

            # Keep only well-formed 2D slices.
            pixel_arrays = []
            for d in dicoms:
                try:
                    arr = d.pixel_array
                except Exception:
                    continue  # Skip corrupted slices
                if arr.ndim == 2 and arr.size > 0:  # Valid 2D slice
                    pixel_arrays.append(arr)

            if not pixel_arrays:
                print(f"❌ No valid pixel arrays in series (corrupted DICOM)")
                self.stats['shape_errors'] += 1
                return self._empty_volume(), metadata

            volume = self._stack_slices(pixel_arrays)

            # Log successful loads (first few only)
            if self.stats['total_loaded'] <= 10:
                print(f"✅ Loaded: {volume.shape} from {len(pixel_arrays)} slices")

            # Apply rescale if available. The tags are coerced defensively so a
            # missing/empty value degrades to the identity transform instead of
            # discarding the series.
            slope = self._as_float(metadata['rescale_slope'], 1.0)
            intercept = self._as_float(metadata['rescale_intercept'], 0.0)
            if slope != 1.0 or intercept != 0.0:
                volume = volume * slope + intercept

            self.stats['successful_loads'] += 1
            return volume, metadata

        except Exception as e:
            # Deliberate catch-all: a single bad series must not stop training.
            print(f"Error loading {series_path}: {e}")
            self.stats['shape_errors'] += 1
            return self._empty_volume(), {}

    def print_stats(self):
        """Print the loading counters; silent when nothing has been loaded yet."""
        total = self.stats['total_loaded']
        successful = self.stats['successful_loads']
        empty = self.stats['empty_volumes']
        shape_errors = self.stats['shape_errors']

        if total > 0:
            success_rate = (successful / total) * 100
            print(f"\n📊 === DICOM Loading Stats ===")
            print(f"✅ Successful loads: {successful}/{total} ({success_rate:.1f}%)")
            print(f"❌ Corrupted/empty: {empty} ({empty/total*100:.1f}%)")
            print(f"⚠️  Shape errors: {shape_errors} ({shape_errors/total*100:.1f}%)")

            if success_rate < 70:
                print(f"🚨 SUCCESS RATE TOO LOW ({success_rate:.1f}%)!")
                print(f"   Most volumes are corrupted - check dataset quality!")
            elif success_rate < 85:
                print(f"⚠️  Moderate success rate ({success_rate:.1f}%) - some data quality issues")
            else:
                print(f"✅ Good success rate ({success_rate:.1f}%)")
            print(f"===============================")

    def preprocess_volume(self, volume: np.ndarray, metadata: Dict) -> np.ndarray:
        """Clip, min-max normalise and resize a volume to ``target_size``.

        Args:
            volume: 3D array; anything else yields the zero placeholder.
            metadata: Accepted for interface symmetry; not used here.

        Returns:
            float32 array of shape ``target_size``. Values are scaled to
            [0, 1] unless the clipped volume is constant, in which case the
            clipped values are kept as they are.
        """
        if volume.ndim != 3 or volume.size == 0:
            print(f"Warning: Received a non-3D volume. Returning empty target volume.")
            return self._empty_volume()

        # Default windowing: clip to the 5th-95th percentile range.
        lower, upper = np.percentile(volume, [5, 95])
        volume = np.clip(volume, lower, upper)

        # Normalization
        vol_min, vol_max = volume.min(), volume.max()
        if vol_max > vol_min:
            volume = (volume - vol_min) / (vol_max - vol_min)

        # Resize to target size
        if volume.shape != self.target_size:
            zoom_factors = [self.target_size[i] / volume.shape[i] for i in range(3)]
            volume = ndimage.zoom(volume, zoom_factors, order=1, prefilter=False)

        return volume.astype(np.float32)

    def load_localization_mask(self, series_id: str, localizer_df: pd.DataFrame) -> np.ndarray:
        """Placeholder: always returns an all-zero mask of ``target_size``.

        Both arguments are currently ignored.
        """
        return self._empty_volume()


# ====================================================
# ENHANCED DATASET
# ====================================================

class EnhancedAneurysmDataset(Dataset):
    """Map-style dataset yielding one preprocessed series per row of ``df``.

    Each item is a dict with the volume tensor (channel axis prepended), a
    localization-mask tensor, the label vector, a one-hot modality vector and
    the series ID.
    """

    def __init__(self, df: pd.DataFrame, localizer_df: pd.DataFrame, 
                 series_dir: str, processor: AdvancedDICOMProcessor, 
                 mode: str = 'train', fold: int = None):
        self.df = df
        self.localizer_df = localizer_df
        self.series_dir = series_dir
        self.processor = processor
        self.mode = mode
        self.fold = fold

        # Only two random flips are enabled for training; every mode ends
        # with the tensor conversion.
        steps = []
        if mode == 'train':
            steps.extend(
                RandFlipd(keys=['volume'], prob=0.5, spatial_axis=axis)
                for axis in (0, 1)
            )
        steps.append(ToTensord(keys=['volume']))
        self.transform = Compose(steps)

    def __len__(self):
        """Number of rows (series) in the backing frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, preprocess and tensorise the series at positional index ``idx``."""
        t_load = time.time()
        record = self.df.iloc[idx]
        series_id = record[Config.ID_COL]

        # Raw volume straight from disk
        raw_volume, metadata = self.processor.load_dicom_series(
            os.path.join(self.series_dir, series_id)
        )
        dicom_time = time.time() - t_load

        t_prep = time.time()
        sample = {'volume': self.processor.preprocess_volume(raw_volume, metadata)}

        # Auxiliary localization mask
        mask = self.processor.load_localization_mask(series_id, self.localizer_df)

        # Targets in LABEL_COLS order
        targets = record[Config.LABEL_COLS].values.astype(np.float32)

        if self.transform:
            sample = self.transform(sample)

        item = {
            'volume': sample['volume'].unsqueeze(0),  # prepend channel axis
            'localization_mask': torch.from_numpy(mask).unsqueeze(0),
            'labels': torch.from_numpy(targets),
            'metadata': torch.tensor(
                self._encode_modality(metadata.get('modality', 'UNKNOWN')),
                dtype=torch.float32,
            ),
            'series_id': series_id
        }

        preprocess_time = time.time() - t_prep
        # Timing of the first few indices helps spot loading bottlenecks
        if idx < 5:
            print(f"Sample {idx}: DICOM load: {dicom_time:.2f}s, Preprocess: {preprocess_time:.2f}s")

        return item

    def _encode_modality(self, modality: str) -> List[float]:
        """Return a 5-slot one-hot list; unrecognised values map to the last slot."""
        known = ['CTA', 'MRA', 'MRI', 'MR', 'UNKNOWN']
        hot = modality if modality in known else known[-1]
        return [1.0 if name == hot else 0.0 for name in known]


# ====================================================
# ADVANCED MODEL ARCHITECTURE
# ====================================================

#class MultiModalAneurysmNet(nn.Module):
class SimplifiedAneurysmNet(nn.Module):
    """3D U-Net feature extractor followed by a small MLP classifier.

    The U-Net output map is average-pooled to one vector per sample and fed
    through two linear layers to produce one logit per label.
    """

    def __init__(self, num_classes: int = len(Config.LABEL_COLS), 
                 spatial_dims: int = 3, in_channels: int = 1, 
                 features: Tuple = (16, 32, 64, 128, 256, 51)):
        # NOTE(review): the trailing 51 in the default `features` looks like it
        # may be a truncated 512 - confirm before relying on it.
        super().__init__()

        # The backbone emits `features[0]` channels, which is also the width
        # of the vector handed to the classifier.
        embed_dim = features[0]

        # Main 3D U-Net backbone
        self.backbone = BasicUNet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=embed_dim,
            features=features,
            dropout=0.1,
        )

        # Collapse every spatial axis to size 1
        self.global_pool = nn.AdaptiveAvgPool3d(1)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, volume, metadata=None):
        """Return ``(logits, None)``; ``metadata`` is accepted but ignored."""
        pooled = self.global_pool(self.backbone(volume)).flatten(1)
        return self.classifier(pooled), None

# ====================================================
# WEIGHTED LOSS FUNCTION
# ====================================================

class WeightedMultiLabelLoss(nn.Module):
    """Multi-label BCE-with-logits loss with per-class positive weighting.

    Args:
        pos_weights: Optional 1-D tensor with one weight per class. It scales
            only the positive-target term of each class (the usual
            ``pos_weight`` semantics), so a rare class is up-weighted when it
            is present without inflating the loss of its many negatives.
        aneurysm_weight: Extra factor applied to every element of the last
            label column.
    """

    def __init__(self, pos_weights=None, aneurysm_weight=13.0):
        super().__init__()
        self.pos_weights = pos_weights
        self.aneurysm_weight = aneurysm_weight

    def forward(self, logits, targets):
        """Return the mean weighted loss over all batch elements and classes.

        Args:
            logits: Raw scores of shape ``(batch, num_classes)``.
            targets: 0/1 targets with the same shape as ``logits``.
        """
        # Compute in float32 so half-precision logits coming out of an
        # autocast region cannot cause a dtype mismatch with the weights.
        logits = logits.float()
        targets = targets.float()

        pos_weight = None
        if self.pos_weights is not None:
            pos_weight = torch.as_tensor(
                self.pos_weights, dtype=logits.dtype, device=logits.device
            )

        per_element = F.binary_cross_entropy_with_logits(
            logits, targets, pos_weight=pos_weight, reduction='none'
        )

        # Weight the "Aneurysm Present" class higher (last column)
        column_weights = torch.ones_like(per_element)
        column_weights[:, -1] = self.aneurysm_weight

        return (per_element * column_weights).mean()


# ====================================================
# TRAINING FUNCTIONS
# ====================================================

def compute_weighted_auc(y_true, y_pred):
    """Return ``(weighted_auc, per_label_aucs)`` over the label columns.

    The last label column carries weight 13, every other column weight 1.
    A column for which the AUC cannot be computed (``ValueError``) scores 0.5.
    """
    last = len(Config.LABEL_COLS) - 1
    column_scores = []
    column_weights = []

    for col in range(last + 1):
        try:
            score = roc_auc_score(y_true[:, col], y_pred[:, col])
        except ValueError:
            score = 0.5  # Default for no positive cases
        column_scores.append(score)
        # "Aneurysm Present" (last column) dominates the average
        column_weights.append(13.0 if col == last else 1.0)

    weighted_auc = sum(a * w for a, w in zip(column_scores, column_weights)) / sum(column_weights)
    return weighted_auc, column_scores

def train_epoch(model, train_loader, optimizer, criterion, scaler, device):
    """Run one training epoch with mixed precision and gradient accumulation.

    Args:
        model: Network called as ``model(volume, metadata)`` and returning
            ``(logits, aux)``.
        train_loader: Iterable of batches with keys ``'volume'``,
            ``'metadata'`` and ``'labels'``.
        optimizer: Optimizer stepped through ``scaler``.
        criterion: Loss called as ``criterion(logits, labels)``.
        scaler: ``GradScaler`` used for loss scaling.
        device: ``torch.device`` the batches are moved to.

    Returns:
        Mean unscaled loss over the batches that were actually trained on, or
        ``float('inf')`` when every batch was skipped.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    skipped_batches = 0
    # Backward passes accumulated since the last optimizer step. Counting
    # processed batches (rather than using batch_idx) keeps the accumulation
    # window correct when batches are skipped.
    pending_micro_batches = 0

    # Start from clean gradients so nothing carries over from a previous epoch.
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Training Epoch")):
        start_time = time.time()
        volume = batch['volume'].to(device)
        metadata = batch['metadata'].to(device)
        labels = batch['labels'].to(device)

        # Skip batches with zero-filled volumes (failed loads come back as
        # all-zero placeholders).
        if torch.all(volume == 0) or torch.var(volume) < 1e-6:
            skipped_batches += 1
            if batch_idx < 5:  # Log first few skips
                print(f"⚠️  Skipping batch {batch_idx}: zero-filled or low-variance volume")
            continue

        with autocast(device_type=device.type, enabled=Config.MIXED_PRECISION):
            class_logits, _ = model(volume, metadata)
            total_loss_batch = criterion(class_logits, labels)

        # Gradient accumulation: each micro-batch contributes 1/N of its loss.
        scaled_loss = total_loss_batch / Config.GRADIENT_ACCUMULATION
        scaler.scale(scaled_loss).backward()
        pending_micro_batches += 1

        if pending_micro_batches == Config.GRADIENT_ACCUMULATION:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            pending_micro_batches = 0

        total_loss += total_loss_batch.item()
        num_batches += 1

        # Print timing for first few batches to identify bottlenecks
        if batch_idx < 5:
            batch_time = time.time() - start_time
            print(f"Batch {batch_idx}: {batch_time:.2f}s")

    # Flush a partial accumulation window so its gradients are applied now
    # instead of being silently mixed into the next epoch.
    if pending_micro_batches > 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    if skipped_batches > 0:
        print(f"⚠️  Skipped {skipped_batches} batches with corrupted/zero volumes")

    return total_loss / num_batches if num_batches > 0 else float('inf')

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate the model on the validation loader.

    Args:
        model: Network called as ``model(volume, metadata)`` and returning
            ``(logits, aux)``.
        val_loader: Iterable of batches with keys ``'volume'``,
            ``'metadata'`` and ``'labels'``.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: ``torch.device`` the batches are moved to.

    Returns:
        ``(mean_loss, weighted_auc, per_label_aucs)``. When every batch is
        skipped the result is ``(inf, 0.5, [0.5, ...])``.
    """
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_labels = []
    num_batches = 0
    skipped_batches = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(val_loader, desc="Validating")):
            volume = batch['volume'].to(device)
            metadata = batch['metadata'].to(device)
            labels = batch['labels'].to(device)

            # Skip batches with zero-filled volumes (failed loads come back
            # as all-zero placeholders).
            if torch.all(volume == 0) or torch.var(volume) < 1e-6:
                skipped_batches += 1
                if batch_idx < 3:  # Log first few skips
                    print(f"⚠️  Skipping validation batch {batch_idx}: zero-filled or low-variance volume")
                continue

            with autocast(device_type=device.type, enabled=Config.MIXED_PRECISION):
                class_logits, _ = model(volume, metadata)
                loss = criterion(class_logits, labels)

            # DEBUG: Check for extreme loss values
            batch_loss = loss.item()
            if batch_loss > 1e7:
                print(f"🚨 Warning: Extreme loss in batch {batch_idx}: {batch_loss:.2e}")
                print(f"   Labels: {labels[0].cpu().numpy()}")  # Print first sample's labels

            total_loss += batch_loss
            num_batches += 1

            # Collect predictions for AUC calculation; cast to float32 first
            # because logits may be half precision under autocast.
            probs = torch.sigmoid(class_logits.float()).cpu().numpy()
            all_preds.append(probs)
            all_labels.append(labels.cpu().numpy())

    if skipped_batches > 0:
        print(f"⚠️  Skipped {skipped_batches} validation batches with corrupted/zero volumes")

    if not all_preds:
        print("🚨 WARNING: No valid validation batches - all were corrupted!")
        return float('inf'), 0.5, [0.5] * len(Config.LABEL_COLS)

    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)

    weighted_auc, individual_aucs = compute_weighted_auc(all_labels, all_preds)

    return total_loss / max(num_batches, 1), weighted_auc, individual_aucs



In [4]:
# ====================================================
# MAIN TRAINING EXECUTION
# ====================================================

def main():
    """Run K-fold cross-validated training and write the results to disk.

    Steps: read the CSVs (restricted to a 100-row debug subset), derive
    per-class positive weights, check that every series directory is present,
    assign folds, train one model per fold while keeping the checkpoint with
    the best validation AUC, and finally dump a JSON summary.

    Files written in the working directory: ``corrupted_series.txt``,
    ``best_model_fold_<k>.pth`` and ``training_results.json``.
    """
    print(f"Using device: {Config.DEVICE}")
    print(f"Mixed precision: {Config.MIXED_PRECISION}")
    
    # Load data
    train_df = pd.read_csv(Config.TRAIN_CSV_PATH)
    localizer_df = pd.read_csv(Config.LOCALIZER_CSV_PATH)

    print("---!!! RUNNING IN DEBUG MODE ON A SMALL SUBSET !!!---")
    # Limit to 100 samples for speed testing. The index is reset so that the
    # positional indices returned by the splitter below are valid `.loc` labels.
    train_df = train_df.head(100).reset_index(drop=True)
    print(f"Training samples: {len(train_df)} (limited for speed testing)")
    print(f"Positive aneurysm cases: {train_df['Aneurysm Present'].sum()}")

    # Calculate class weights for imbalanced data
    pos_counts = train_df[Config.LABEL_COLS].sum()
    neg_counts = len(train_df) - pos_counts
    pos_weights = neg_counts / (pos_counts + 1e-8)  # Add small epsilon
    pos_weights = np.minimum(pos_weights, 100.0)  # Cap weights at 100
    # Go through an explicit float32 NumPy array instead of handing a pandas
    # Series to torch.tensor.
    pos_weights = torch.tensor(pos_weights.to_numpy(dtype=np.float32), dtype=torch.float32)
    print("Class weights (capped at 100):", pos_weights)

    # DATASET INTEGRITY CHECK
    print("\n🔍 Checking dataset integrity...")
    valid_series = 0
    invalid_series = []
    
    for series_id in train_df[Config.ID_COL]:
        series_path = os.path.join(Config.SERIES_DIR, series_id)
        if os.path.exists(series_path):
            dicom_files = [f for f in os.listdir(series_path) if f.endswith('.dcm')]
            if dicom_files:
                valid_series += 1
            else:
                invalid_series.append(f"No DICOMs: {series_id}")
        else:
            invalid_series.append(f"Missing path: {series_id}")
    
    success_rate = valid_series / len(train_df)
    print(f"📊 Dataset check: {valid_series}/{len(train_df)} series valid ({success_rate:.1%})")
    
    if success_rate < 0.7:
        print(f"🚨 WARNING: Only {success_rate:.1%} of series are accessible!")
        print("First few issues:")
        for issue in invalid_series[:5]:
            print(f"  - {issue}")
        print(f"⚠️  Training will proceed, but expect many corrupted DICOM errors")
    else:
        print(f"✅ Good dataset integrity ({success_rate:.1%} valid)")
        
    print()

    # Initialize corrupted series log (truncates any log from a previous run)
    with open('corrupted_series.txt', 'w') as f:
        f.write("# Corrupted series log - check this file to identify problematic DICOM series\n")
    print("📝 Initialized 'corrupted_series.txt' for logging corrupted series")
    
    # Create stratified group k-fold split
    # Use patient-level grouping to prevent data leakage; without a PatientID
    # column every row becomes its own group.
    train_df['patient_group'] = train_df['PatientID'] if 'PatientID' in train_df.columns else range(len(train_df))
    
    skf = StratifiedGroupKFold(n_splits=Config.N_FOLDS, shuffle=True, random_state=42)
    train_df['fold'] = -1
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(
        train_df, train_df['Aneurysm Present'], groups=train_df['patient_group']
    )):
        train_df.loc[val_idx, 'fold'] = fold
    
    # Initialize processor (shared by all folds, so its counters are cumulative)
    processor = AdvancedDICOMProcessor()
    
    # Train models for each fold
    fold_scores = []
    
    for fold in range(Config.N_FOLDS):
        print(f"\n{'='*50}")
        print(f"FOLD {fold + 1}/{Config.N_FOLDS}")
        print(f"{'='*50}")
        
        # Split data
        train_fold_df = train_df[train_df['fold'] != fold].reset_index(drop=True)
        val_fold_df = train_df[train_df['fold'] == fold].reset_index(drop=True)
        
        print(f"Train: {len(train_fold_df)}, Validation: {len(val_fold_df)}")
        
        # Create datasets
        train_dataset = EnhancedAneurysmDataset(
            train_fold_df, localizer_df, Config.SERIES_DIR, processor, mode='train', fold=fold
        )
        val_dataset = EnhancedAneurysmDataset(
            val_fold_df, localizer_df, Config.SERIES_DIR, processor, mode='val', fold=fold
        )
        
        # Create data loaders
        # NOTE(review): with num_workers > 0 each worker process operates on
        # its own copy of the dataset (and therefore of `processor`), so the
        # `processor.stats` counters read further down in this process are
        # presumably never updated - confirm, and aggregate them differently
        # if the per-epoch stats / sanity check are meant to be effective.
        train_loader = DataLoader(
            train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, 
            num_workers=4, pin_memory=True, drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, 
            num_workers=4, pin_memory=True
        )
        
        # Initialize model
        model = SimplifiedAneurysmNet().to(Config.DEVICE)
        criterion = WeightedMultiLabelLoss(pos_weights=pos_weights)
        
        optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=Config.WEIGHT_DECAY)
        
        # Learning rate scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=1e-6
        )
        
        scaler = GradScaler(enabled=Config.MIXED_PRECISION)
        
        # Training loop
        best_auc = 0
        # NOTE(review): with patience equal to Config.EPOCHS (both 10) the
        # early-stopping branch below cannot trigger.
        patience = 10
        patience_counter = 0
        
        for epoch in range(Config.EPOCHS):
            # Train
            train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler, Config.DEVICE)
            
            # Validate
            val_loss, val_auc, individual_aucs = validate_epoch(model, val_loader, criterion, Config.DEVICE)
            
            # Step scheduler
            scheduler.step()
            
            print(f"Epoch {epoch+1:3d} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val AUC: {val_auc:.4f}")

            processor.print_stats() # Print DICOM loading stats after each epoch

            # Check corrupted series log (entries accumulate over the whole run)
            try:
                with open('corrupted_series.txt', 'r') as f:
                    lines = f.readlines()
                    corrupted_count = len([l for l in lines if not l.startswith('#')])
                    if corrupted_count > 0:
                        print(f"📄 Corrupted series logged: {corrupted_count} entries in 'corrupted_series.txt'")
                    else:
                        print(f"✅ No corrupted series logged this epoch")
            except FileNotFoundError:
                print(f"📄 No corrupted series log file found")

            # SANITY CHECK: Stop if data loading is fundamentally broken
            if processor.stats['total_loaded'] > 20:  # Only check after some attempts
                success_rate = processor.stats['successful_loads'] / processor.stats['total_loaded']
                if success_rate < 0.5:  # Less than 50% success rate
                    print(f"\n🚨 STOPPING TRAINING: Data loading success rate is {success_rate:.1%}")
                    print("Fix the DICOM loading issues before continuing training!")
                    print("Most volumes are returning empty - this is a waste of time!")
                    break
            
            # Save best model
            if val_auc > best_auc:
                best_auc = val_auc
                patience_counter = 0
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_auc': val_auc,
                    'epoch': epoch,
                    'fold': fold,
                    'individual_aucs': individual_aucs
                }, f'best_model_fold_{fold}.pth')
            else:
                patience_counter += 1
                
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
                
            # Memory cleanup
            if epoch % 5 == 0:
                gc.collect()
                torch.cuda.empty_cache()
        
        fold_scores.append(best_auc)
        print(f"Fold {fold + 1} best AUC: {best_auc:.4f}")
    
    # Final results
    mean_cv_score = np.mean(fold_scores)
    std_cv_score = np.std(fold_scores)
    
    print(f"\n{'='*50}")
    print(f"CROSS-VALIDATION RESULTS")
    print(f"{'='*50}")
    print(f"Mean CV AUC: {mean_cv_score:.4f} ± {std_cv_score:.4f}")
    print(f"Individual fold scores: {fold_scores}")
    
    # Save training summary. Config only has class-level attributes, so they
    # are read from the class itself; `vars()` of an instance would be empty.
    config_snapshot = {
        name: value for name, value in vars(Config).items()
        if not name.startswith('_')
    }
    results = {
        'cv_scores': fold_scores,
        'mean_cv_score': mean_cv_score,
        'std_cv_score': std_cv_score,
        'config': config_snapshot
    }
    
    # default=str covers values json cannot encode natively (e.g. the device).
    with open('training_results.json', 'w') as f:
        json.dump(results, f, indent=2, default=str)
    
    print("Training complete! Models saved as 'best_model_fold_X.pth'")

# Entry point: run the full cross-validation training when this code is
# executed as the top-level module.
if __name__ == "__main__":
    main()

Using device: cuda
Mixed precision: True
---!!! RUNNING IN DEBUG MODE ON A SMALL SUBSET !!!---
Training samples: 100 (limited for speed testing)
Positive aneurysm cases: 48

🔍 Checking dataset integrity...
📊 Dataset check: 100/100 series valid (100.0%)
✅ Good dataset integrity (100.0% valid)

📝 Initialized 'corrupted_series.txt' for logging corrupted series


UnboundLocalError: cannot access local variable 'pos_weights' where it is not associated with a value