<a href="https://www.kaggle.com/code/nicholas33/02-aneurysmnet-cnn-intracranial-training-nb153?scriptVersionId=254347897" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
!pip install monai

# ====================================================
# RSNA INTRACRANIAL ANEURYSM DETECTION - TRAINING PIPELINE
# ====================================================

import os
import gc
import warnings
import json
import time
import numpy as np
import pandas as pd
from typing import Tuple, Dict, List
from collections import Counter
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import roc_auc_score
import albumentations as A

warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler

import pydicom
import pydicom.errors
from scipy import ndimage
import nibabel as nib
from monai.transforms import (
    Compose, RandRotate90d, RandFlipd, RandAffined,
    RandGaussianNoised, RandAdjustContrastd, ToTensord
)
from monai.networks.nets import BasicUNet
from monai.losses import DiceCELoss, FocalLoss
from tqdm import tqdm


Collecting monai
  Downloading monai-1.5.0-py3-none-any.whl.metadata (13 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<2.7.0,>=2.4.1->monai)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch<2.7.0,

2025-08-05 12:17:40.631931: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754396260.854853      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754396260.921674      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# ====================================================
# CONFIGURATION
# ====================================================

class Config:
    # Paths
    TRAIN_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train.csv'
    LOCALIZER_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train_localizers.csv'
    SERIES_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/series/'
    SEGMENTATION_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/segmentations/'
    
    # Model parameters
    TARGET_SIZE = (32, 64, 64)  # Increased resolution
    EPOCHS = 2
    BATCH_SIZE = 16  # Reduced due to larger input size
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 1e-4
    N_FOLDS = 3
    
    # Training parameters
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    MIXED_PRECISION = True
    GRADIENT_ACCUMULATION = 4
    
    # Competition constants
    ID_COL = 'SeriesInstanceUID'
    LABEL_COLS = [
        'Left Infraclinoid Internal Carotid Artery', 'Right Infraclinoid Internal Carotid Artery',
        'Left Supraclinoid Internal Carotid Artery', 'Right Supraclinoid Internal Carotid Artery',
        'Left Middle Cerebral Artery', 'Right Middle Cerebral Artery', 'Anterior Communicating Artery',
        'Left Anterior Cerebral Artery', 'Right Anterior Cerebral Artery',
        'Left Posterior Communicating Artery', 'Right Posterior Communicating Artery',
        'Basilar Tip', 'Other Posterior Circulation', 'Aneurysm Present',
    ]
    
    # Class weights for imbalanced data
    ANEURYSM_PRESENT_WEIGHT = 13.0  # Match evaluation metric weighting

# ====================================================
# ENHANCED DATA PREPROCESSING
# ====================================================

class AdvancedDICOMProcessor:
    def __init__(self, target_size: Tuple[int, int, int] = Config.TARGET_SIZE):
        self.target_size = target_size
        self.stats = {
            'total_loaded': 0,
            'successful_loads': 0,
            'shape_errors': 0,
            'empty_volumes': 0,
            'preprocessing_errors': 0,
            'invalid_dicom_files': 0,
            'empty_pixel_arrays': 0,
            'corrupted_pixel_data': 0
        }
        
        # Initialize detailed corruption log
        self.corruption_log_file = 'detailed_corruption_log.txt'
        self._init_corruption_log()
    
    def _init_corruption_log(self):
        """Initialize detailed corruption logging"""
        try:
            with open(self.corruption_log_file, 'w') as f:
                f.write("# DETAILED DICOM CORRUPTION LOG\n")
                f.write("# This file tracks specific reasons for DICOM corruption\n")
                f.write("# Format: [TIMESTAMP] SeriesID | Error Type | Details\n")
                f.write("# Error Types: NO_FILES, INVALID_DICOM, EMPTY_PIXELS, CORRUPTED_DATA, SHAPE_MISMATCH\n\n")
        except Exception as e:
            print(f"⚠️  Could not initialize corruption log: {e}")
    
    def _log_corruption(self, series_id: str, error_type: str, details: str):
        """Log corruption details with timestamp"""
        try:
            timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
            log_entry = f"[{timestamp}] {series_id} | {error_type} | {details}\n"
            with open(self.corruption_log_file, 'a') as f:
                f.write(log_entry)
                f.flush()
        except Exception as e:
            print(f"⚠️  Could not log corruption: {e}")

    def _detect_localizer(self, ds, pixel_data, filename):
        """Detect if a DICOM is likely a localizer/scout image"""
        try:
            # Check various indicators that suggest this is a localizer
            indicators = {
                'small_slice_count': pixel_data.shape[0] < 50,  # Localizers usually have few slices
                'filename_indicator': any(term in filename.lower() for term in ['scout', 'localizer', 'topo']),
                'series_description': False,
                'image_type': False
            }
            
            # Check series description if available
            if hasattr(ds, 'SeriesDescription'):
                desc = str(ds.SeriesDescription).lower()
                indicators['series_description'] = any(term in desc for term in 
                    ['scout', 'localizer', 'topo', 'survey', 'plan'])
            
            # Check image type if available
            if hasattr(ds, 'ImageType'):
                img_type = str(ds.ImageType).lower() if isinstance(ds.ImageType, str) else str(ds.ImageType)
                indicators['image_type'] = any(term in img_type for term in ['localizer', 'scout'])
            
            # Consider it a localizer if any indicator is True
            is_localizer = any(indicators.values())
            
            if is_localizer:
                indicator_details = [k for k, v in indicators.items() if v]
                print(f"    🎯 Localizer indicators: {', '.join(indicator_details)}")
            
            return is_localizer
            
        except Exception as e:
            print(f"    ⚠️  Error detecting localizer: {e}")
            return False  # Default to treating as regular volume
        
    def load_dicom_series(self, series_path: str) -> Tuple[np.ndarray, Dict]:
        """Load DICOM series with comprehensive error tracking and logging"""
        self.stats['total_loaded'] += 1
        series_id = os.path.basename(series_path)
        
        try:
            # Check if directory exists
            if not os.path.exists(series_path):
                self._log_corruption(series_id, "NO_PATH", f"Directory does not exist: {series_path}")
                self.stats['empty_volumes'] += 1
                return self._get_fallback_volume(), {}
            
            # Get DICOM files
            all_files = os.listdir(series_path)
            dicom_files = [os.path.join(series_path, f) for f in all_files if f.endswith('.dcm')]
            
            if not dicom_files:
                self._log_corruption(series_id, "NO_FILES", f"No .dcm files found. Directory contains: {len(all_files)} files")
                print(f"❌ No DICOM files found in {series_path}, using mean volume fallback")
                self.stats['empty_volumes'] += 1
                return self._get_fallback_volume(), {}
            
            print(f"🔍 Loading series {series_id}: Found {len(dicom_files)} DICOM files")
            
            # Track detailed corruption stats for this series
            series_stats = {
                'total_files': len(dicom_files),
                'invalid_dicom': 0,
                'empty_pixels': 0,
                'corrupted_pixels': 0,
                'valid_dicoms': 0,
                'pixel_read_errors': 0,
                'localizer_converted': 0,
                'slices_extracted': 0
            }
            
            dicoms = []
            detailed_errors = []
            
            for i, dicom_file in enumerate(dicom_files):
                filename = os.path.basename(dicom_file)
                try:
                    # Try to read DICOM
                    ds = pydicom.dcmread(dicom_file, force=True)
                    
                    # Check if pixel array exists
                    if not hasattr(ds, 'pixel_array'):
                        series_stats['empty_pixels'] += 1
                        detailed_errors.append(f"  {filename}: No pixel_array attribute")
                        continue
                    
                    # Try to access pixel array
                    try:
                        pixel_data = ds.pixel_array
                        if pixel_data is None:
                            series_stats['empty_pixels'] += 1
                            detailed_errors.append(f"  {filename}: pixel_array is None")
                            continue
                        elif pixel_data.size == 0:
                            series_stats['empty_pixels'] += 1
                            detailed_errors.append(f"  {filename}: pixel_array is empty (size=0)")
                            continue
                        elif pixel_data.ndim == 2:
                            # Standard 2D slice - this is what we want
                            dicoms.append(ds)
                            series_stats['valid_dicoms'] += 1
                        elif pixel_data.ndim == 3:
                            # 3D volume in single DICOM (localizer/reconstructed volume)
                            # Check if this might be a localizer image
                            is_localizer = self._detect_localizer(ds, pixel_data, filename)
                            
                            if is_localizer:
                                print(f"  🔍 Detected localizer image in {filename}, extracting representative slices")
                                # For localizers, extract fewer slices (e.g., every 4th slice)
                                slice_step = max(1, pixel_data.shape[0] // 8)  # Get ~8 representative slices
                                selected_indices = range(0, pixel_data.shape[0], slice_step)
                            else:
                                print(f"  🔄 Converting 3D volume {pixel_data.shape} to 2D slices in {filename}")
                                # For regular 3D volumes, extract all slices
                                selected_indices = range(pixel_data.shape[0])
                            
                            # Create pseudo-DICOM objects for selected slices
                            slices_added = 0
                            for slice_idx in selected_indices:
                                slice_data = pixel_data[slice_idx]
                                if slice_data.size > 0:
                                    # Create a copy of the DICOM metadata for this slice
                                    slice_ds = pydicom.Dataset()
                                    slice_ds.pixel_array = slice_data
                                    # Copy important metadata
                                    for attr in ['Modality', 'PixelSpacing', 'SliceThickness', 
                                               'RescaleSlope', 'RescaleIntercept']:
                                        if hasattr(ds, attr):
                                            setattr(slice_ds, attr, getattr(ds, attr))
                                    # Set unique instance number for each slice
                                    slice_ds.InstanceNumber = slice_idx + 1
                                    dicoms.append(slice_ds)
                                    series_stats['valid_dicoms'] += 1
                                    slices_added += 1
                            
                            series_stats['localizer_converted'] += 1 if is_localizer else 0
                            series_stats['slices_extracted'] += slices_added
                            
                            if is_localizer:
                                print(f"  ✅ Extracted {slices_added} representative slices from localizer ({pixel_data.shape[0]} total)")
                            else:
                                print(f"  ✅ Extracted {slices_added} slices from 3D volume")
                        else:
                            series_stats['corrupted_pixels'] += 1
                            detailed_errors.append(f"  {filename}: Unsupported dimensions {pixel_data.shape} (ndim={pixel_data.ndim})")
                            continue
                            
                    except Exception as pixel_error:
                        series_stats['pixel_read_errors'] += 1
                        detailed_errors.append(f"  {filename}: Pixel read error - {str(pixel_error)}")
                        continue
                        
                except pydicom.errors.InvalidDicomError as dicom_error:
                    series_stats['invalid_dicom'] += 1
                    detailed_errors.append(f"  {filename}: Invalid DICOM - {str(dicom_error)}")
                    continue
                except Exception as e:
                    series_stats['invalid_dicom'] += 1
                    detailed_errors.append(f"  {filename}: Read error - {str(e)}")
                    continue
            
            # Log detailed results
            corruption_details = (f"Files: {series_stats['total_files']}, "
                                f"Valid: {series_stats['valid_dicoms']}, "
                                f"Invalid: {series_stats['invalid_dicom']}, "
                                f"Empty pixels: {series_stats['empty_pixels']}, "
                                f"Corrupted pixels: {series_stats['corrupted_pixels']}, "
                                f"Pixel errors: {series_stats['pixel_read_errors']}, "
                                f"Localizers converted: {series_stats['localizer_converted']}, "
                                f"Slices extracted: {series_stats['slices_extracted']}")
            
            if series_stats['valid_dicoms'] == 0:
                # Complete failure - log everything
                error_summary = f"TOTAL_FAILURE: {corruption_details}"
                self._log_corruption(series_id, "TOTAL_FAILURE", error_summary)
                
                print(f"❌ Complete failure for series {series_id}:")
                print(f"   📊 {corruption_details}")
                if detailed_errors:
                    print("   🔍 First 5 errors:")
                    for error in detailed_errors[:5]:
                        print(error)
                    
                self.stats['empty_volumes'] += 1
                return self._get_fallback_volume(), {}
            
            elif series_stats['valid_dicoms'] < series_stats['total_files'] * 0.5:
                # Partial failure - log but continue
                self._log_corruption(series_id, "PARTIAL_FAILURE", corruption_details)
                print(f"⚠️  Partial corruption in series {series_id}: {corruption_details}")
            elif series_stats['localizer_converted'] > 0:
                # Successfully converted localizer images
                self._log_corruption(series_id, "LOCALIZER_CONVERTED", corruption_details)
                print(f"🎯 Successfully converted localizer series {series_id}: {corruption_details}")
                
            # Continue with valid DICOMs
            print(f"✅ Successfully loaded {len(dicoms)} valid DICOMs from series {series_id}")
            
            # Extract metadata from first valid DICOM
            first_ds = dicoms[0]
            metadata = {
                'modality': getattr(first_ds, 'Modality', 'UNKNOWN'),
                'spacing': getattr(first_ds, 'PixelSpacing', [1.0, 1.0]),
                'slice_thickness': getattr(first_ds, 'SliceThickness', 1.0),
                'rescale_slope': getattr(first_ds, 'RescaleSlope', 1.0),
                'rescale_intercept': getattr(first_ds, 'RescaleIntercept', 0.0),
            }
            
            # Sort by instance number 
            dicoms.sort(key=lambda x: int(getattr(x, 'InstanceNumber', 0)))
            
            # Process pixel arrays
            pixel_arrays = []
            shapes = []
            pixel_processing_errors = 0
            
            for d in dicoms:
                try:
                    arr = d.pixel_array
                    if arr.ndim == 2 and arr.size > 0:
                        pixel_arrays.append(arr)
                        shapes.append(arr.shape)
                    else:
                        pixel_processing_errors += 1
                except Exception as e:
                    pixel_processing_errors += 1
                    continue
            
            if len(pixel_arrays) == 0:
                self._log_corruption(series_id, "PIXEL_PROCESSING_FAILURE", 
                                   f"All {len(dicoms)} DICOMs failed pixel processing")
                print(f"❌ No valid pixel arrays in series (corrupted DICOM), using mean volume fallback")
                self.stats['shape_errors'] += 1
                return self._get_fallback_volume(), metadata
            
            if pixel_processing_errors > 0:
                self._log_corruption(series_id, "PIXEL_PROCESSING_PARTIAL", 
                                   f"{pixel_processing_errors} DICOMs failed pixel processing out of {len(dicoms)}")
            
            # Handle shape consistency
            unique_shapes = list(set(shapes))
            if len(unique_shapes) == 1:
                # All same shape - direct stacking
                volume = np.stack(pixel_arrays, axis=0).astype(np.float32)
                print(f"✅ Consistent shapes: {unique_shapes[0]} across {len(pixel_arrays)} slices")
            else:
                # Multiple shapes - resize to most common
                most_common_shape = Counter(shapes).most_common(1)[0][0]
                shape_counter = Counter(shapes)
                shape_breakdown = ", ".join([f"{shape}: {count}" for shape, count in shape_counter.most_common(3)])
                shape_details = f"Found {len(unique_shapes)} different shapes. Breakdown: {shape_breakdown}"
                self._log_corruption(series_id, "SHAPE_MISMATCH", shape_details)
                print(f"⚠️  Shape inconsistency in {series_id}: {shape_details}")
                print(f"   📐 Resizing all slices to most common shape: {most_common_shape}")
                
                resized_arrays = []
                for arr in pixel_arrays:
                    if arr.shape == most_common_shape:
                        resized_arrays.append(arr.astype(np.float32))
                    else:
                        zoom_factors = (most_common_shape[0] / arr.shape[0], 
                                      most_common_shape[1] / arr.shape[1])
                        resized_arr = ndimage.zoom(arr, zoom_factors, order=1, prefilter=False)
                        resized_arrays.append(resized_arr.astype(np.float32))
                volume = np.stack(resized_arrays, axis=0).astype(np.float32)
            
            # Log successful loads with details
            if self.stats['total_loaded'] <= 10:
                print(f"✅ Final volume: {volume.shape} from {len(pixel_arrays)} slices")
            
            # Apply rescale if available
            if metadata['rescale_slope'] != 1.0 or metadata['rescale_intercept'] != 0.0:
                volume = volume * metadata['rescale_slope'] + metadata['rescale_intercept']

            self.stats['successful_loads'] += 1
            return volume, metadata
            
        except Exception as e:
            # Log unexpected errors
            error_details = f"Unexpected error: {str(e)}"
            self._log_corruption(series_id, "UNEXPECTED_ERROR", error_details)
            print(f"❌ Unexpected error loading {series_path}: {e}, using mean volume fallback")
            self.stats['shape_errors'] += 1
            return self._get_fallback_volume(), {}

    def _get_fallback_volume(self):
        """Get mean volume fallback or zeros if no dataset reference"""
        if hasattr(self, 'dataset') and hasattr(self.dataset, 'mean_volume'):
            return self.dataset.mean_volume.copy()
        return np.zeros(self.target_size, dtype=np.float32)

    def print_stats(self):
        """Print comprehensive loading statistics"""
        total = self.stats['total_loaded']
        successful = self.stats['successful_loads']
        empty = self.stats['empty_volumes']
        shape_errors = self.stats['shape_errors']
        invalid_dicom = self.stats['invalid_dicom_files']
        empty_pixels = self.stats['empty_pixel_arrays']
        corrupted_pixels = self.stats['corrupted_pixel_data']
        
        if total > 0:
            success_rate = (successful / total) * 100
            print(f"\n📊 === COMPREHENSIVE DICOM LOADING STATS ===")
            print(f"📈 Total attempts: {total}")
            print(f"✅ Successful loads: {successful}/{total} ({success_rate:.1f}%)")
            print(f"❌ Failed loads breakdown:")
            print(f"   🚫 Empty volumes: {empty} ({empty/total*100:.1f}%)")
            print(f"   📄 Invalid DICOM files: {invalid_dicom} ({invalid_dicom/total*100:.1f}%)")
            print(f"   🗂️  Empty pixel arrays: {empty_pixels} ({empty_pixels/total*100:.1f}%)")
            print(f"   💥 Corrupted pixel data: {corrupted_pixels} ({corrupted_pixels/total*100:.1f}%)")
            print(f"   📐 Shape errors: {shape_errors} ({shape_errors/total*100:.1f}%)")
            
            # Overall assessment
            total_failed = empty + invalid_dicom + empty_pixels + corrupted_pixels + shape_errors
            print(f"📊 Total corruption rate: {total_failed/total*100:.1f}%")
            
            if success_rate < 70:
                print(f"🚨 SUCCESS RATE TOO LOW ({success_rate:.1f}%)!")
                print(f"   🔍 Check detailed_corruption_log.txt for specific issues")
                print(f"   📋 Most common issues likely in the corruption log")
            elif success_rate < 85:
                print(f"⚠️  Moderate success rate ({success_rate:.1f}%) - some data quality issues")
                print(f"   💡 Review detailed_corruption_log.txt for improvement opportunities")
            else:
                print(f"✅ Good success rate ({success_rate:.1f}%)")
            
            print(f"📄 Detailed analysis available in: {self.corruption_log_file}")
            print(f"===============================")
    
    def analyze_corruption_patterns(self):
        """Analyze and summarize corruption patterns from the log"""
        try:
            if not os.path.exists(self.corruption_log_file):
                print("📄 No corruption log file found")
                return
            
            error_types = {}
            series_with_errors = set()
            
            with open(self.corruption_log_file, 'r') as f:
                for line in f:
                    if line.startswith('#') or not line.strip():
                        continue
                    
                    try:
                        # Parse: [TIMESTAMP] SeriesID | Error Type | Details
                        parts = line.strip().split(' | ')
                        if len(parts) >= 2:
                            series_id = parts[0].split('] ')[1] if '] ' in parts[0] else parts[0]
                            error_type = parts[1]
                            
                            series_with_errors.add(series_id)
                            error_types[error_type] = error_types.get(error_type, 0) + 1
                    except:
                        continue
            
            if error_types:
                print(f"\n🔍 === CORRUPTION PATTERN ANALYSIS ===")
                print(f"📊 Unique series with errors: {len(series_with_errors)}")
                print(f"🏷️  Error type breakdown:")
                
                sorted_errors = sorted(error_types.items(), key=lambda x: x[1], reverse=True)
                for error_type, count in sorted_errors:
                    print(f"   {error_type}: {count} occurrences")
                
                print(f"💡 Most common issue: {sorted_errors[0][0]} ({sorted_errors[0][1]} cases)")
                print(f"===============================")
        
        except Exception as e:
            print(f"⚠️  Could not analyze corruption patterns: {e}")

    def preprocess_volume(self, volume: np.ndarray, metadata: Dict) -> np.ndarray:
        """Enhanced preprocessing with modality-specific handling"""
        if volume.ndim != 3 or volume.size == 0:
            print(f"Warning: Received a non-3D volume. Returning empty target volume.")
            return np.zeros(self.target_size, dtype=np.float32)
        
        # Default windowing
        p1, p99 = np.percentile(volume, [5, 95])
        volume = np.clip(volume, p1, p99)
        
        # Normalization
        vol_min, vol_max = volume.min(), volume.max()
        if vol_max > vol_min:
            volume = (volume - vol_min) / (vol_max - vol_min)
        
        # Resize to target size
        if volume.shape != self.target_size:
            zoom_factors = [self.target_size[i] / volume.shape[i] for i in range(3)]
            volume = ndimage.zoom(volume, zoom_factors, order=1, prefilter=False)
        
        return volume.astype(np.float32)

    def load_localization_mask(self, series_id: str, localizer_df: pd.DataFrame) -> np.ndarray:
        return np.zeros(self.target_size, dtype=np.float32)


# ====================================================
# ENHANCED DATASET
# ====================================================

class EnhancedAneurysmDataset(Dataset):
    def __init__(self, df: pd.DataFrame, localizer_df: pd.DataFrame, 
                 series_dir: str, processor: AdvancedDICOMProcessor, 
                 mode: str = 'train', fold: int = None, shared_mean_volume: np.ndarray = None):
        self.df = df
        self.localizer_df = localizer_df
        self.series_dir = series_dir
        self.processor = processor
        self.mode = mode
        self.fold = fold
        
        # Data augmentation for training
        if mode == 'train':
            self.transform = Compose([
                # Lightweight augmentations to improve generalization
                RandRotate90d(keys=['volume'], prob=0.2, spatial_axes=(0, 1)),  # 90-degree rotations only
                RandFlipd(keys=['volume'], prob=0.5, spatial_axis=0),           # Axial flip
                RandFlipd(keys=['volume'], prob=0.5, spatial_axis=1),           # Sagittal flip
                RandGaussianNoised(keys=['volume'], prob=0.2, std=0.03),        # Very low noise level
                ToTensord(keys=['volume'])
            ])
        else:
            self.transform = Compose([ToTensord(keys=['volume'])])

        # Set dataset reference in processor for mean volume access
        self.processor.dataset = self

         # Use shared mean volume or compute new one
        if shared_mean_volume is not None:
            self.mean_volume = shared_mean_volume.copy()
            print(f"🔄 Using shared mean volume fallback (shape: {self.mean_volume.shape})")
        else:
            # Precompute mean volume for fallback on corrupted DICOM files
            print("🔄 Computing mean volume fallback from valid series...")
            valid_volumes = []
            sample_size = min(10, len(df))  # Sample 10 series for speed
            
            for i, series_id in enumerate(df[Config.ID_COL][:sample_size]):
                series_path = os.path.join(series_dir, series_id)
                try:
                    volume, _ = processor.load_dicom_series(series_path)
                    if not np.all(volume == 0) and volume.size > 0:
                        volume = processor.preprocess_volume(volume, {})
                        valid_volumes.append(volume)
                        print(f"  ✅ Valid volume {i+1}/{sample_size}: {volume.shape}")
                except Exception as e:
                    print(f"  ❌ Skipped corrupted volume {i+1}/{sample_size}: {e}")
                    continue
            
            if valid_volumes:
                self.mean_volume = np.mean(valid_volumes, axis=0).astype(np.float32)
                print(f"📊 Mean volume computed from {len(valid_volumes)} valid series: {self.mean_volume.shape}")
            else:
                self.mean_volume = np.zeros(processor.target_size, dtype=np.float32)
                print("⚠️  No valid volumes found, using zero fallback")
            
            print(f"🎯 Mean volume fallback ready (shape: {self.mean_volume.shape})")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        load_start = time.time()
        row = self.df.iloc[idx]
        series_id = row[Config.ID_COL]
        series_path = os.path.join(self.series_dir, series_id)
        
        # Load and process volume
        volume, metadata = self.processor.load_dicom_series(series_path)
        dicom_time = time.time() - load_start
        preprocess_start = time.time()
        volume = self.processor.preprocess_volume(volume, metadata)
        
        # Create localization mask (for auxiliary loss)
        loc_mask = self.processor.load_localization_mask(series_id, self.localizer_df)
        
        # Get labels
        labels = row[Config.LABEL_COLS].values.astype(np.float32)
        
        # Apply transforms
        data_dict = {'volume': volume}
        if self.transform:
            data_dict = self.transform(data_dict)
        
        volume_tensor = data_dict['volume'].unsqueeze(0)  # Add channel dimension
        loc_mask_tensor = torch.from_numpy(loc_mask).unsqueeze(0)
        labels_tensor = torch.from_numpy(labels)
        
        # Add metadata features
        modality_encoding = self._encode_modality(metadata.get('modality', 'UNKNOWN'))
        metadata_tensor = torch.tensor(modality_encoding, dtype=torch.float32)
        
        preprocess_time = time.time() - preprocess_start
        # Print timing for first few samples to debug
        if idx < 5:
            print(f"Sample {idx}: DICOM load: {dicom_time:.2f}s, Preprocess: {preprocess_time:.2f}s")
        
        return {
            'volume': volume_tensor,
            'localization_mask': loc_mask_tensor,
            'labels': labels_tensor,
            'metadata': metadata_tensor,
            'series_id': series_id
        }
    
    def _encode_modality(self, modality: str) -> List[float]:
        """One-hot encode modality"""
        modalities = ['CTA', 'MRA', 'MRI', 'MR', 'UNKNOWN']
        encoding = [0.0] * len(modalities)
        if modality in modalities:
            encoding[modalities.index(modality)] = 1.0
        else:
            encoding[-1] = 1.0  # UNKNOWN
        return encoding


# ====================================================
# ADVANCED MODEL ARCHITECTURE
# ====================================================

#class MultiModalAneurysmNet(nn.Module):
class SimplifiedAneurysmNet(nn.Module):
    def __init__(self, num_classes: int = len(Config.LABEL_COLS), 
                 spatial_dims: int = 3, in_channels: int = 1, 
        #          features: Tuple = (32, 64, 128, 256, 512, 1024)):
        # super(MultiModalAneurysmNet, self).__init__()
                 features: Tuple = (16, 32, 64, 128, 256, 51)):
        super(SimplifiedAneurysmNet, self).__init__()
        
        # Main 3D U-Net backbone
        self.backbone = BasicUNet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=features[0],
            features=features,
            dropout=0.1 #Reduced dropout 
        )
        
        # Global average pooling
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        
        # # Metadata processing
        # self.metadata_mlp = nn.Sequential(
        #     nn.Linear(5, 32),  # 5 modality categories
        #     nn.ReLU(),
        #     nn.Dropout(0.3),
        #     nn.Linear(32, 64),
        #     nn.ReLU()
        # )
        
        # Classification head
        #feature_size = features[0] + 64  # backbone features + metadata features
        self.classifier = nn.Sequential(
            # nn.Linear(feature_size, 512),
            # nn.ReLU(),
            # nn.Dropout(0.5),
            # nn.Linear(512, 256),
            nn.Linear(features[0], 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, volume, metadata=None):
        # Extract features from 3D volume
        features = self.backbone(volume)
        # Global features for classification
        global_features = self.global_pool(features).flatten(1)
        classification_logits = self.classifier(global_features)
        return classification_logits, None

# ====================================================
# WEIGHTED LOSS FUNCTION
# ====================================================

class WeightedMultiLabelLoss(nn.Module):
    def __init__(self, pos_weights=None, aneurysm_weight=13.0):
        super().__init__()
        self.pos_weights = pos_weights
        self.aneurysm_weight = aneurysm_weight
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
        
    def forward(self, logits, targets):
        bce_loss = self.bce(logits, targets)
        
        # Apply position weights if provided
        if self.pos_weights is not None:
            bce_loss = bce_loss * self.pos_weights.to(logits.device)
        
        # Weight the "Aneurysm Present" class higher (last column)
        weights = torch.ones_like(bce_loss)
        weights[:, -1] = self.aneurysm_weight
        
        weighted_loss = bce_loss * weights
        return weighted_loss.mean()


# ====================================================
# TRAINING FUNCTIONS
# ====================================================

def compute_weighted_auc(y_true, y_pred):
    """Compute weighted AUC matching competition metric"""
    aucs = []
    weights = []
    
    for i in range(len(Config.LABEL_COLS)):
        try:
            auc = roc_auc_score(y_true[:, i], y_pred[:, i])
            aucs.append(auc)
            # Weight "Aneurysm Present" (last column) higher
            weights.append(13.0 if i == len(Config.LABEL_COLS) - 1 else 1.0)
        except ValueError:
            aucs.append(0.5)  # Default for no positive cases
            weights.append(13.0 if i == len(Config.LABEL_COLS) - 1 else 1.0)
    
    weighted_auc = sum(a * w for a, w in zip(aucs, weights)) / sum(weights)
    return weighted_auc, aucs

def train_epoch(model, train_loader, optimizer, criterion, scaler, device):
    model.train()
    total_loss = 0
    num_batches = 0
    skipped_batches = 0
    
    # Device verification for debugging
    print(f"🔧 Training on device: {device}")
    if hasattr(model, 'module'):  # DataParallel wrapped
        print(f"🔧 Model device (DataParallel): {next(model.module.parameters()).device}")
    else:
        print(f"🔧 Model device: {next(model.parameters()).device}")
    print(f"🔧 Criterion on GPU: {hasattr(criterion, 'pos_weights') and criterion.pos_weights.device if hasattr(criterion, 'pos_weights') else 'N/A'}")
    
    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Training Epoch")):
        start_time = time.time()
        # Transfer data to GPU with non-blocking for better performance
        volume = batch['volume'].to(device, non_blocking=True)
        metadata = batch['metadata'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)
        loc_mask = batch['localization_mask'].to(device, non_blocking=True)
        
        # Add GPU monitoring for first few batches
        if batch_idx < 3:
            gpu_mem_before = torch.cuda.memory_allocated(0) / 1e9
            print(f"  🔍 Batch {batch_idx}: GPU memory before forward: {gpu_mem_before:.2f}GB")

        # CRITICAL FIX: Skip batches with zero-filled volumes
        if torch.all(volume == 0) or torch.var(volume) < 1e-6:
            skipped_batches += 1
            if batch_idx < 5:  # Log first few skips
                print(f"⚠️  Skipping batch {batch_idx}: zero-filled or low-variance volume")
            continue
            
        # Forward pass timing
        forward_start = time.time()
        with autocast(device_type=device.type, enabled=Config.MIXED_PRECISION):
            class_logits, _ = model(volume, metadata)
            total_loss_batch = criterion(class_logits, labels)
        forward_time = time.time() - forward_start
        
        # Add detailed timing for first few batches
        if batch_idx < 3:
            gpu_mem_after = torch.cuda.memory_allocated(0) / 1e9
            print(f"  ⚡ Batch {batch_idx}: Forward pass: {forward_time:.3f}s, GPU memory after: {gpu_mem_after:.2f}GB")
            
        # DEBUG: Check for extreme loss values
        if total_loss_batch.item() > 1e7:
            print(f"🚨 Warning: Extreme loss in training batch {batch_idx}: {total_loss_batch.item():.2e}")
            print(f"   Labels: {labels[0].cpu().numpy()}")  # Print first sample's labels 
        
        # Gradient accumulation and backward pass timing
        backward_start = time.time()
        scaled_loss = total_loss_batch / Config.GRADIENT_ACCUMULATION
        scaler.scale(scaled_loss).backward()
        backward_time = time.time() - backward_start
        
        if batch_idx < 3:
            print(f"  🔄 Batch {batch_idx}: Backward pass: {backward_time:.3f}s")
        
        if (batch_idx + 1) % Config.GRADIENT_ACCUMULATION == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        total_loss += total_loss_batch.item()
        num_batches += 1

        # Print timing for first few batches to identify bottlenecks
        if batch_idx < 5:
            batch_time = time.time() - start_time
            print(f"Batch {batch_idx}: {batch_time:.2f}s")
    
    if skipped_batches > 0:
        print(f"⚠️  Skipped {skipped_batches} batches with corrupted/zero volumes")
    
    return total_loss / max(num_batches, 1) if num_batches > 0 else float('inf')

def validate_epoch(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    num_batches = 0
    skipped_batches = 0
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(val_loader, desc="Validating")):
            volume = batch['volume'].to(device)
            metadata = batch['metadata'].to(device)
            labels = batch['labels'].to(device)
            
            # CRITICAL FIX: Skip batches with zero-filled volumes
            if torch.all(volume == 0) or torch.var(volume) < 1e-6:
                skipped_batches += 1
                if batch_idx < 3:  # Log first few skips
                    print(f"⚠️  Skipping validation batch {batch_idx}: zero-filled or low-variance volume")
                continue
            
            with autocast(device_type=device.type, enabled=Config.MIXED_PRECISION):
                class_logits, _ = model(volume, metadata)
                loss = criterion(class_logits, labels)

                # DEBUG: Check for extreme loss values
                if loss.item() > 1e7:
                    print(f"🚨 Warning: Extreme loss in validation batch {batch_idx}: {loss.item():.2e}")
                    print(f"   Labels: {labels[0].cpu().numpy()}")  # Print first sample's labels
            
            total_loss += loss.item()
            num_batches += 1
            
            # Collect predictions for AUC calculation
            probs = torch.sigmoid(class_logits).cpu().numpy()
            all_preds.append(probs)
            all_labels.append(labels.cpu().numpy())

    if skipped_batches > 0:
        print(f"⚠️  Skipped {skipped_batches} validation batches with corrupted/zero volumes")
    
    if len(all_preds) == 0:
        print("🚨 WARNING: No valid validation batches - all were corrupted!")
        return float('inf'), 0.5, [0.5] * len(Config.LABEL_COLS)
    
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    
    weighted_auc, individual_aucs = compute_weighted_auc(all_labels, all_preds)
    
    return total_loss / max(num_batches, 1), weighted_auc, individual_aucs



In [None]:
# ====================================================
# MAIN TRAINING EXECUTION
# ====================================================

def main():
    print(f"Using device: {Config.DEVICE}")
    print(f"Mixed precision: {Config.MIXED_PRECISION}")
    if torch.cuda.is_available():
        print(f"GPU devices available: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            gpu_mem = torch.cuda.get_device_properties(i).total_memory / 1e9
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)} ({gpu_mem:.1f}GB)")
    print(f"Optimized settings - Batch size: {Config.BATCH_SIZE}, Workers: 6/4 (Conservative)")
    
    # Load data
    train_df = pd.read_csv(Config.TRAIN_CSV_PATH)
    localizer_df = pd.read_csv(Config.LOCALIZER_CSV_PATH)

    print("---!!! RUNNING IN DEBUG MODE ON A SMALL SUBSET !!!---")
    #print(f"Training samples: {len(train_df)}")
    train_df = train_df.head(100)  # Limit to 100 samples for speed testing
    print(f"Training samples: {len(train_df)} (limited for speed testing)")
    print(f"Positive aneurysm cases: {train_df['Aneurysm Present'].sum()}")

    # Calculate class weights for imbalanced data
    pos_counts = train_df[Config.LABEL_COLS].sum()
    neg_counts = len(train_df) - pos_counts
    pos_weights = neg_counts / (pos_counts + 1e-8)  # Add small epsilon
    pos_weights = np.minimum(pos_weights, 100.0)  # Cap weights at 100
    pos_weights = torch.tensor(pos_weights, dtype=torch.float32)  # Convert to tensor
    print("Class weights (capped at 100):", pos_weights)

    # DATASET INTEGRITY CHECK
    print("\n🔍 Checking dataset integrity...")
    valid_series = 0
    invalid_series = []
    
    for series_id in train_df[Config.ID_COL]:
        series_path = os.path.join(Config.SERIES_DIR, series_id)
        if os.path.exists(series_path):
            dicom_files = [f for f in os.listdir(series_path) if f.endswith('.dcm')]
            if dicom_files:
                valid_series += 1
            else:
                invalid_series.append(f"No DICOMs: {series_id}")
        else:
            invalid_series.append(f"Missing path: {series_id}")
    
    success_rate = valid_series / len(train_df)
    print(f"📊 Dataset check: {valid_series}/{len(train_df)} series valid ({success_rate:.1%})")
    
    if success_rate < 0.7:
        print(f"🚨 WARNING: Only {success_rate:.1%} of series are accessible!")
        print("First few issues:")
        for issue in invalid_series[:5]:
            print(f"  - {issue}")
        print(f"⚠️  Training will proceed, but expect many corrupted DICOM errors")
    else:
        print(f"✅ Good dataset integrity ({success_rate:.1%} valid)")
        
    print()

    # Initialize corrupted series log (legacy compatibility)
    with open('corrupted_series.txt', 'w') as f:
        f.write("# Corrupted series log - check this file to identify problematic DICOM series\n")
    print("📝 Initialized 'corrupted_series.txt' for logging corrupted series")
    print("🔍 Enhanced corruption tracking enabled with detailed analysis")
    
    # Create stratified group k-fold split
    # Use patient-level grouping to prevent data leakage
    train_df['patient_group'] = train_df['PatientID'] if 'PatientID' in train_df.columns else range(len(train_df))
    
    skf = StratifiedGroupKFold(n_splits=Config.N_FOLDS, shuffle=True, random_state=42)
    train_df['fold'] = -1
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(
        train_df, train_df['Aneurysm Present'], groups=train_df['patient_group']
    )):
        train_df.loc[val_idx, 'fold'] = fold
    
    # Initialize processor
    processor = AdvancedDICOMProcessor()
    
    # Train models for each fold
    fold_scores = []
    
    for fold in range(Config.N_FOLDS):
        print(f"\n{'='*50}")
        print(f"FOLD {fold + 1}/{Config.N_FOLDS}")
        print(f"{'='*50}")
        
        # Split data
        train_fold_df = train_df[train_df['fold'] != fold].reset_index(drop=True)
        val_fold_df = train_df[train_df['fold'] == fold].reset_index(drop=True)
        
        print(f"Train: {len(train_fold_df)}, Validation: {len(val_fold_df)}")
        
        # Create datasets
        train_dataset = EnhancedAneurysmDataset(
            train_fold_df, localizer_df, Config.SERIES_DIR, processor, mode='train', fold=fold
        )
        val_dataset = EnhancedAneurysmDataset(
            val_fold_df, localizer_df, Config.SERIES_DIR, processor, mode='val', fold=fold,
            shared_mean_volume=train_dataset.mean_volume
        )
        
        # Create data loaders with conservative settings for stability
        train_loader = DataLoader(
            train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, 
            num_workers=6, pin_memory=True, drop_last=True, 
            prefetch_factor=2, persistent_workers=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, 
            num_workers=4, pin_memory=True, prefetch_factor=2, persistent_workers=True
        )
        
        # Initialize model
        model = SimplifiedAneurysmNet().to(Config.DEVICE)
        # Enable multi-GPU training if available
        if torch.cuda.device_count() > 1:
            print(f"🚀 Using {torch.cuda.device_count()} GPUs for training")
            model = nn.DataParallel(model)
        else:
            print(f"📱 Using single GPU: {Config.DEVICE}")
        criterion = WeightedMultiLabelLoss(pos_weights=pos_weights).to(Config.DEVICE)
        
        # Optimizer with different learning rates for different parts
        # optimizer = optim.AdamW([
        #     {'params': model.backbone.parameters(), 'lr': Config.LEARNING_RATE},
        #     {'params': model.classifier.parameters(), 'lr': Config.LEARNING_RATE * 2},
        #     {'params': model.metadata_mlp.parameters(), 'lr': Config.LEARNING_RATE * 2}
        # ], weight_decay=Config.WEIGHT_DECAY)
        optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=Config.WEIGHT_DECAY)
        
        # Learning rate scheduler
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=10, T_mult=2, eta_min=1e-6
        )
        
        scaler = GradScaler(enabled=Config.MIXED_PRECISION)
        
        # Training loop
        best_auc = 0
        patience = 10
        patience_counter = 0
        
        for epoch in range(Config.EPOCHS):
            # Train
            train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler, Config.DEVICE)
            
            # Validate
            val_loss, val_auc, individual_aucs = validate_epoch(model, val_loader, criterion, Config.DEVICE)
            
            # Step scheduler
            scheduler.step()
            
            print(f"Epoch {epoch+1:3d} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"Val AUC: {val_auc:.4f}")

            # Print comprehensive DICOM stats and analyze corruption patterns
            processor.print_stats()
            
            # Analyze corruption patterns (every 2 epochs to avoid spam)
            if epoch % 2 == 0:
                processor.analyze_corruption_patterns()

            # Check legacy corrupted series log (for compatibility)
            try:
                with open('corrupted_series.txt', 'r') as f:
                    lines = f.readlines()
                    corrupted_count = len([l for l in lines if not l.startswith('#')])
                    if corrupted_count > 0:
                        print(f"📄 Legacy corruption log: {corrupted_count} entries in 'corrupted_series.txt'")
            except FileNotFoundError:
                pass  # Expected if no legacy logging occurred

            # SANITY CHECK: Stop if data loading is fundamentally broken
            if processor.stats['total_loaded'] > 20:  # Only check after some attempts
                success_rate = processor.stats['successful_loads'] / processor.stats['total_loaded']
                if success_rate < 0.5:  # Less than 50% success rate
                    print(f"\n🚨 STOPPING TRAINING: Data loading success rate is {success_rate:.1%}")
                    print("Fix the DICOM loading issues before continuing training!")
                    print("Most volumes are returning empty - this is a waste of time!")
                    break
            
            # Save best model
            if val_auc > best_auc:
                best_auc = val_auc
                patience_counter = 0
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_auc': val_auc,
                    'epoch': epoch,
                    'fold': fold,
                    'individual_aucs': individual_aucs
                }, f'best_model_fold_{fold}.pth')
            else:
                patience_counter += 1
                
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
                
            # Memory cleanup
            if epoch % 5 == 0:
                gc.collect()
                torch.cuda.empty_cache()
        
        fold_scores.append(best_auc)
        print(f"Fold {fold + 1} best AUC: {best_auc:.4f}")
    
    # Final results
    mean_cv_score = np.mean(fold_scores)
    std_cv_score = np.std(fold_scores)
    
    print(f"\n{'='*50}")
    print(f"CROSS-VALIDATION RESULTS")
    print(f"{'='*50}")
    print(f"Mean CV AUC: {mean_cv_score:.4f} ± {std_cv_score:.4f}")
    print(f"Individual fold scores: {fold_scores}")
    
    # Save training summary
    results = {
        'cv_scores': fold_scores,
        'mean_cv_score': mean_cv_score,
        'std_cv_score': std_cv_score,
        'config': vars(Config())
    }
    
    with open('training_results.json', 'w') as f:
        json.dump(results, f, indent=2, default=str)
    
    print("Training complete! Models saved as 'best_model_fold_X.pth'")

if __name__ == "__main__":
    main()

Using device: cuda
Mixed precision: True
GPU devices available: 2
  GPU 0: Tesla T4 (15.8GB)
  GPU 1: Tesla T4 (15.8GB)
Optimized settings - Batch size: 16, Workers: 6/4 (Conservative)
---!!! RUNNING IN DEBUG MODE ON A SMALL SUBSET !!!---
Training samples: 100 (limited for speed testing)
Positive aneurysm cases: 48
Class weights (capped at 100): tensor([ 32.3333,  49.0000,  13.2857,   9.0000,  49.0000,  11.5000,  10.1111,
         99.0000,  49.0000,  49.0000, 100.0000,  32.3333,  49.0000,   1.0833])

🔍 Checking dataset integrity...
📊 Dataset check: 100/100 series valid (100.0%)
✅ Good dataset integrity (100.0% valid)

📝 Initialized 'corrupted_series.txt' for logging corrupted series
🔍 Enhanced corruption tracking enabled with detailed analysis

FOLD 1/3
Train: 66, Validation: 34
🔄 Computing mean volume fallback from valid series...
🔍 Loading series 1.2.826.0.1.3680043.8.498.10004684224894397679901841656954650085: Found 147 DICOM files
✅ Successfully loaded 147 valid DICOMs from series 

Training Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

🔍 Loading series 1.2.826.0.1.3680043.8.498.10058383541003792190302541266378919328: Found 88 DICOM files🔍 Loading series 1.2.826.0.1.3680043.8.498.10221223003274066645389576091413528073: Found 191 DICOM files🔍 Loading series 1.2.826.0.1.3680043.8.498.10168980078157176521154364692096920137: Found 52 DICOM files

🔍 Loading series 1.2.826.0.1.3680043.8.498.10022796280698534221758473208024838831: Found 671 DICOM files

✅ Successfully loaded 52 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10168980078157176521154364692096920137
✅ Consistent shapes: (320, 256) across 52 slices
🔍 Loading series 1.2.826.0.1.3680043.8.498.10022688097731894079510930966432818105: Found 178 DICOM files
✅ Successfully loaded 88 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10058383541003792190302541266378919328
✅ Consistent shapes: (512, 512) across 88 slices
🔍 Loading series 1.2.826.0.1.3680043.8.498.10232731436838657115800303234983509594: Found 30 DICOM files
✅ Successfully loaded 30 valid DICOMs from se

Training Epoch:  25%|██▌       | 1/4 [02:03<06:10, 123.50s/it]

  🔄 Batch 0: Backward pass: 0.286s
Batch 0: 1.74s
  🔍 Batch 1: GPU memory before forward: 0.07GB
  ⚡ Batch 1: Forward pass: 0.091s, GPU memory after: 1.14GB
  🔄 Batch 1: Backward pass: 0.013s


Training Epoch:  50%|█████     | 2/4 [02:03<01:41, 50.98s/it] 

Batch 1: 0.21s
  🔍 Batch 2: GPU memory before forward: 0.08GB
  ⚡ Batch 2: Forward pass: 0.085s, GPU memory after: 1.14GB
  🔄 Batch 2: Backward pass: 0.012s


Training Epoch:  75%|███████▌  | 3/4 [02:03<00:27, 27.79s/it]

Batch 2: 0.20s


Training Epoch: 100%|██████████| 4/4 [02:04<00:00, 31.06s/it]


Batch 3: 0.33s


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

🔍 Loading series 1.2.826.0.1.3680043.8.498.10129580404994628606227497184499173213: Found 313 DICOM files🔍 Loading series 1.2.826.0.1.3680043.8.498.10256018119694768427929632156620347034: Found 180 DICOM files🔍 Loading series 1.2.826.0.1.3680043.8.498.10004044428023505108375152878107656647: Found 188 DICOM files


✅ Successfully loaded 188 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10004044428023505108375152878107656647
✅ Consistent shapes: (512, 512) across 188 slices
Sample 0: DICOM load: 5.77s, Preprocess: 0.79s
🔍 Loading series 1.2.826.0.1.3680043.8.498.10009383108068795488741533244914370182: Found 224 DICOM files
✅ Successfully loaded 313 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10129580404994628606227497184499173213
✅ Consistent shapes: (512, 512) across 313 slices
✅ Successfully loaded 180 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10256018119694768427929632156620347034
✅ Consistent shapes: (1008, 1008) across 180 slices
🔍 Loading series 1.2.826.0.1.3680

Validating: 100%|██████████| 3/3 [01:54<00:00, 38.13s/it] 


Epoch   1 | Train Loss: 27.8750 | Val Loss: 27.3719 | Val AUC: 0.4746

📊 === COMPREHENSIVE DICOM LOADING STATS ===
📈 Total attempts: 10
✅ Successful loads: 7/10 (70.0%)
❌ Failed loads breakdown:
   🚫 Empty volumes: 3 (30.0%)
   📄 Invalid DICOM files: 0 (0.0%)
   🗂️  Empty pixel arrays: 0 (0.0%)
   💥 Corrupted pixel data: 0 (0.0%)
   📐 Shape errors: 0 (0.0%)
📊 Total corruption rate: 30.0%
⚠️  Moderate success rate (70.0%) - some data quality issues
   💡 Review detailed_corruption_log.txt for improvement opportunities
📄 Detailed analysis available in: detailed_corruption_log.txt

🔍 === CORRUPTION PATTERN ANALYSIS ===
📊 Unique series with errors: 9
🏷️  Error type breakdown:
   TOTAL_FAILURE: 10 occurrences
   SHAPE_MISMATCH: 2 occurrences
💡 Most common issue: TOTAL_FAILURE (10 cases)
🔧 Training on device: cuda
🔧 Model device (DataParallel): cuda:0
🔧 Criterion on GPU: cpu


Training Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

🔍 Loading series 1.2.826.0.1.3680043.8.498.10145340168188681268595785827168799711: Found 23 DICOM files🔍 Loading series 1.2.826.0.1.3680043.8.498.10240701911188793595728082556212433173: Found 140 DICOM files🔍 Loading series 1.2.826.0.1.3680043.8.498.10048925006598672000564912882060003872: Found 594 DICOM files


🔍 Loading series 1.2.826.0.1.3680043.8.498.10126487256624050201543415947047895825: Found 898 DICOM files
✅ Successfully loaded 140 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10240701911188793595728082556212433173
✅ Consistent shapes: (352, 352) across 140 slices
🔍 Loading series 1.2.826.0.1.3680043.8.498.10116626135148932224643146695383345963: Found 24 DICOM files
✅ Successfully loaded 23 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10145340168188681268595785827168799711
✅ Consistent shapes: (1024, 1024) across 23 slices
✅ Successfully loaded 24 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10116626135148932224643146695383345963
✅ Consistent shapes: (512, 512

Training Epoch:  25%|██▌       | 1/4 [01:14<03:42, 74.04s/it]

✅ Successfully loaded 74 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10052893794239333131781802642788307307
✅ Consistent shapes: (512, 512) across 74 slices
Batch 0: 0.32s
✅ Successfully loaded 240 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10095912539619532839962135126795591815
✅ Consistent shapes: (512, 512) across 240 slices
🔍 Loading series 1.2.826.0.1.3680043.8.498.10177117050965285724806213067235546942: Found 47 DICOM files
✅ Successfully loaded 47 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10177117050965285724806213067235546942
✅ Consistent shapes: (512, 512) across 47 slices
🔍 Loading series 1.2.826.0.1.3680043.8.498.10058383541003792190302541266378919328: Found 88 DICOM files
🔍 Loading series 1.2.826.0.1.3680043.8.498.10144083517869641752799954597390552857: Found 194 DICOM files
✅ Successfully loaded 136 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10046318991957083423208748012349179640
✅ Consistent shapes: (512, 512) across 136 slices
✅ Successful

Training Epoch:  50%|█████     | 2/4 [01:32<01:22, 41.23s/it]

Batch 1: 0.28s
  🔍 Batch 2: GPU memory before forward: 0.13GB
  ⚡ Batch 2: Forward pass: 0.110s, GPU memory after: 1.19GB
  🔄 Batch 2: Backward pass: 0.019s


Training Epoch:  75%|███████▌  | 3/4 [01:32<00:22, 22.52s/it]

Batch 2: 0.24s
🔍 Loading series 1.2.826.0.1.3680043.8.498.10098743283291956051221530305664415374: Found 44 DICOM files
✅ Successfully loaded 44 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10098743283291956051221530305664415374
✅ Consistent shapes: (512, 512) across 44 slices
🔍 Loading series 1.2.826.0.1.3680043.8.498.10229915682372012073055285556885310225: Found 204 DICOM files
✅ Successfully loaded 204 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10229915682372012073055285556885310225
✅ Consistent shapes: (512, 512) across 204 slices


Training Epoch: 100%|██████████| 4/4 [01:37<00:00, 24.33s/it]


Batch 3: 0.27s


Validating:   0%|          | 0/3 [00:00<?, ?it/s]

🔍 Loading series 1.2.826.0.1.3680043.8.498.10256018119694768427929632156620347034: Found 180 DICOM files🔍 Loading series 1.2.826.0.1.3680043.8.498.10004044428023505108375152878107656647: Found 188 DICOM files🔍 Loading series 1.2.826.0.1.3680043.8.498.10129580404994628606227497184499173213: Found 313 DICOM files


✅ Successfully loaded 188 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10004044428023505108375152878107656647
✅ Consistent shapes: (512, 512) across 188 slices
Sample 0: DICOM load: 4.10s, Preprocess: 0.75s
🔍 Loading series 1.2.826.0.1.3680043.8.498.10009383108068795488741533244914370182: Found 224 DICOM files
✅ Successfully loaded 313 valid DICOMs from series 1.2.826.0.1.3680043.8.498.10129580404994628606227497184499173213
✅ Consistent shapes: (512, 512) across 313 slices
🔍 Loading series 1.2.826.0.1.3680043.8.498.10133805409448598100180344093077653742: Found 554 DICOM files
✅ Successfully loaded 180 valid DICOMs from series 1.2.826.0.1.3680043.8.498.1025601811969476842