<a href="https://www.kaggle.com/code/nicholas33/02-aneurysmnet-cnn-intracranial-training-nb153?scriptVersionId=255257425" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
# ====================================================
# RSNA INTRACRANIAL ANEURYSM DETECTION - TRAINING PIPELINE
# ====================================================

import os
import gc
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pydicom
import nibabel as nib
import cv2
from scipy import ndimage
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [2]:
# ====================================================
# CELL 2: CONFIGURATION
# ====================================================

class Config:
    # Paths
    TRAIN_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train.csv'
    LOCALIZER_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train_localizers.csv'
    SERIES_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/series/'
    SEGMENTATION_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/segmentations/'
    
    # Stage 1: 3D Segmentation
    STAGE1_TARGET_SIZE = (64, 128, 128)  # Smaller for speed
    STAGE1_BATCH_SIZE = 4
    STAGE1_EPOCHS = 7
    STAGE1_LR = 3e-4
    
    # General
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    MIXED_PRECISION = True
    N_FOLDS = 3
    
    # Competition constants
    ID_COL = 'SeriesInstanceUID'
    LABEL_COLS = [
        'Left Infraclinoid Internal Carotid Artery', 'Right Infraclinoid Internal Carotid Artery',
        'Left Supraclinoid Internal Carotid Artery', 'Right Supraclinoid Internal Carotid Artery',
        'Left Middle Cerebral Artery', 'Right Middle Cerebral Artery', 'Anterior Communicating Artery',
        'Left Anterior Cerebral Artery', 'Right Anterior Cerebral Artery',
        'Left Posterior Communicating Artery', 'Right Posterior Communicating Artery',
        'Basilar Tip', 'Other Posterior Circulation', 'Aneurysm Present',
    ]
    TARGET_COL = 'Aneurysm Present'
    
    # Debug settings
    DEBUG_MODE = False
    DEBUG_SAMPLES = 200  # Use small subset for testing

print(f"✅ Configuration loaded - Device: {Config.DEVICE}")

# ====================================================
# CELL 2.5: CUSTOM 3D UNET (REPLACES MONAI BASICUNET)
# ====================================================

class Custom3DUNet(nn.Module):
    """Pure PyTorch 3D UNet implementation to replace MONAI BasicUNet"""
    
    def __init__(self, spatial_dims=3, in_channels=1, out_channels=32, 
                 features=(32, 64, 128, 256, 512, 32), dropout=0.1):
        super().__init__()
        
        self.features = features
        self.dropout = dropout
        
        # Encoder (downsampling path)
        self.encoder_blocks = nn.ModuleList()
        prev_channels = in_channels
        
        for i, feature_count in enumerate(features[:-1]):  # Exclude last feature (decoder output)
            # Each encoder block: Conv3D -> BatchNorm -> ReLU -> Conv3D -> BatchNorm -> ReLU
            block = nn.Sequential(
                nn.Conv3d(prev_channels, feature_count, kernel_size=3, padding=1),
                nn.BatchNorm3d(feature_count),
                nn.ReLU(inplace=True),
                nn.Conv3d(feature_count, feature_count, kernel_size=3, padding=1),
                nn.BatchNorm3d(feature_count),
                nn.ReLU(inplace=True),
                nn.Dropout3d(dropout) if dropout > 0 else nn.Identity()
            )
            self.encoder_blocks.append(block)
            prev_channels = feature_count
        
        # Downsampling layers (MaxPool)
        self.downsample_layers = nn.ModuleList([
            nn.MaxPool3d(kernel_size=2, stride=2) 
            for _ in range(len(features) - 2)  # No downsampling after last encoder block
        ])
        
        # Decoder (upsampling path)
        self.decoder_blocks = nn.ModuleList()
        self.upsample_layers = nn.ModuleList()
        
        # Reverse the features for decoder (skip the input feature count)
        decoder_features = list(reversed(features[:-1]))  # [512, 256, 128, 64, 32]
        
        for i in range(len(decoder_features) - 1):
            current_features = decoder_features[i]
            next_features = decoder_features[i + 1]
            
            # Upsampling layer
            upsample = nn.ConvTranspose3d(
                current_features, next_features, 
                kernel_size=2, stride=2
            )
            self.upsample_layers.append(upsample)
            
            # Decoder block (concatenation + convolutions)
            # Input: upsampled features + skip connection = next_features * 2
            decoder_block = nn.Sequential(
                nn.Conv3d(next_features * 2, next_features, kernel_size=3, padding=1),
                nn.BatchNorm3d(next_features),
                nn.ReLU(inplace=True),
                nn.Conv3d(next_features, next_features, kernel_size=3, padding=1),
                nn.BatchNorm3d(next_features),
                nn.ReLU(inplace=True),
                nn.Dropout3d(dropout) if dropout > 0 else nn.Identity()
            )
            self.decoder_blocks.append(decoder_block)
        
        # Final output convolution
        self.final_conv = nn.Conv3d(features[0], out_channels, kernel_size=1)
        
    def forward(self, x):
        # Store skip connections
        skip_connections = []
        
        # Encoder path
        for i, encoder_block in enumerate(self.encoder_blocks):
            x = encoder_block(x)
            skip_connections.append(x)
            
            # Downsample (except for the last encoder block)
            if i < len(self.downsample_layers):
                x = self.downsample_layers[i](x)
        
        # Decoder path
        skip_connections = skip_connections[:-1]  # Remove the deepest layer (no skip for bottleneck)
        skip_connections.reverse()  # Reverse to match decoder order
        
        for i, (upsample_layer, decoder_block) in enumerate(zip(self.upsample_layers, self.decoder_blocks)):
            # Upsample
            x = upsample_layer(x)
            
            # Get corresponding skip connection
            skip = skip_connections[i]
            
            # Ensure spatial dimensions match (handle odd-sized inputs)
            if x.shape[2:] != skip.shape[2:]:
                x = nn.functional.interpolate(x, size=skip.shape[2:], mode='trilinear', align_corners=False)
            
            # Concatenate skip connection
            x = torch.cat([x, skip], dim=1)
            
            # Apply decoder block
            x = decoder_block(x)
        
        # Final output
        x = self.final_conv(x)
        
        return x

class Enhanced3DAugmentation:
    """Intensive 3D augmentations for medical imaging using scipy/numpy"""
    
    def __init__(self, mode='train'):
        self.mode = mode
        self.apply_augmentation = (mode == 'train')
        
    def random_rotation_3d(self, volume, max_angle=15):
        """Random 3D rotation"""
        if not self.apply_augmentation or np.random.random() > 0.5:
            return volume
            
        angle = np.random.uniform(-max_angle, max_angle)
        # Rotate around z-axis (axial plane)
        rotated = ndimage.rotate(volume, angle, axes=(1, 2), reshape=False, order=1)
        return rotated
    
    def random_elastic_deformation(self, volume, sigma=4, points=3):
        """Random elastic deformation for vessel-like structures"""
        if not self.apply_augmentation or np.random.random() > 0.3:
            return volume
            
        shape = volume.shape
        dx = ndimage.gaussian_filter((np.random.random(shape) - 0.5), sigma) * points
        dy = ndimage.gaussian_filter((np.random.random(shape) - 0.5), sigma) * points
        dz = ndimage.gaussian_filter((np.random.random(shape) - 0.5), sigma) * points
        
        x, y, z = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]), indexing='ij')
        indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1)), np.reshape(z + dz, (-1, 1))
        
        deformed = ndimage.map_coordinates(volume, indices, order=1, mode='reflect')
        return deformed.reshape(shape)
    
    def random_brightness_contrast(self, volume, brightness=0.2, contrast=0.2):
        """Random brightness and contrast for aneurysm visibility"""
        if not self.apply_augmentation or np.random.random() > 0.7:
            return volume
            
        # Brightness adjustment
        brightness_factor = 1 + np.random.uniform(-brightness, brightness)
        volume = volume * brightness_factor
        
        # Contrast adjustment
        contrast_factor = 1 + np.random.uniform(-contrast, contrast)
        mean = volume.mean()
        volume = (volume - mean) * contrast_factor + mean
        
        return np.clip(volume, 0, 1)
    
    def random_gaussian_noise(self, volume, std_range=(0, 0.05)):
        """Add Gaussian noise to improve robustness"""
        if not self.apply_augmentation or np.random.random() > 0.4:
            return volume
            
        std = np.random.uniform(std_range[0], std_range[1])
        noise = np.random.normal(0, std, volume.shape)
        return np.clip(volume + noise, 0, 1)
    
    def random_gamma_correction(self, volume, gamma_range=(0.8, 1.2)):
        """Gamma correction for intensity variations"""
        if not self.apply_augmentation or np.random.random() > 0.5:
            return volume
            
        gamma = np.random.uniform(gamma_range[0], gamma_range[1])
        return np.power(volume, gamma)
    
    def __call__(self, data_dict):
        """Apply all augmentations"""
        result = {}
        
        for key in data_dict:
            if key == 'volume' and isinstance(data_dict[key], np.ndarray):
                volume = data_dict[key].copy()
                
                # Apply augmentations sequentially
                volume = self.random_rotation_3d(volume)
                volume = self.random_elastic_deformation(volume)
                volume = self.random_brightness_contrast(volume)
                volume = self.random_gaussian_noise(volume)
                volume = self.random_gamma_correction(volume)
                
                # Convert to tensor
                result[key] = torch.from_numpy(volume).float()
            elif isinstance(data_dict[key], np.ndarray):
                result[key] = torch.from_numpy(data_dict[key]).float()
            else:
                result[key] = data_dict[key]
        
        return result

class CustomTransforms:
    """Simple transforms for validation (no augmentation)"""
    
    def __init__(self, keys=['volume']):
        self.keys = keys
        
    def __call__(self, data_dict):
        """Apply transforms to data dictionary"""
        result = {}
        
        for key in data_dict:
            if key in self.keys:
                # Convert numpy array to tensor if needed
                if isinstance(data_dict[key], np.ndarray):
                    result[key] = torch.from_numpy(data_dict[key]).float()
                else:
                    result[key] = data_dict[key]
            else:
                result[key] = data_dict[key]
        
        return result

print("✅ Enhanced 3D UNet with medical augmentations loaded (MONAI-free!)")

# ====================================================
# CELL 3: SIMPLE DICOM PROCESSOR
# ====================================================

class SimpleDICOMProcessor:
    def __init__(self, target_size=None):
        self.target_size = target_size or Config.STAGE1_TARGET_SIZE
        
    def load_dicom_series(self, series_path):
        """Simple DICOM loading - no complex error handling"""
        try:
            dicom_files = [f for f in os.listdir(series_path) if f.endswith('.dcm')]
            if not dicom_files:
                return np.zeros(self.target_size, dtype=np.float32)
            
            # Load all DICOMs
            pixel_arrays = []
            for f in dicom_files[:50]:  # Limit to 50 files max for speed
                try:
                    ds = pydicom.dcmread(os.path.join(series_path, f), force=True)
                    if hasattr(ds, 'pixel_array'):
                        arr = ds.pixel_array
                        if arr.ndim == 2:  # Standard 2D slice
                            pixel_arrays.append(arr)
                        elif arr.ndim == 3:  # 3D volume - take middle slices
                            mid_start = arr.shape[0] // 4
                            mid_end = 3 * arr.shape[0] // 4
                            for slice_idx in range(mid_start, mid_end, 2):  # Every 2nd slice
                                pixel_arrays.append(arr[slice_idx])
                except:
                    continue
            
            if not pixel_arrays:
                return np.zeros(self.target_size, dtype=np.float32)
            
            # Resize all slices to same shape before stacking
            if len(pixel_arrays) > 0:
                # Use first slice shape as reference, or use a standard size
                target_slice_shape = (256, 256)  # Standard size for all slices
                
                resized_arrays = []
                for arr in pixel_arrays:
                    if arr.shape != target_slice_shape:
                        # Resize slice to target shape
                        resized_arr = ndimage.zoom(arr, 
                                                 (target_slice_shape[0] / arr.shape[0], 
                                                  target_slice_shape[1] / arr.shape[1]), 
                                                 order=1)
                        resized_arrays.append(resized_arr)
                    else:
                        resized_arrays.append(arr)
                
                # Now stack - all arrays have same shape
                volume = np.stack(resized_arrays, axis=0).astype(np.float32)
            else:
                return np.zeros(self.target_size, dtype=np.float32)
            
            # Simple preprocessing
            volume = self.preprocess_volume(volume)
            return volume
            
        except Exception as e:
            print(f"Failed to load {series_path}: {e}")
            return np.zeros(self.target_size, dtype=np.float32)
    
    def preprocess_volume(self, volume):
        """Simple preprocessing"""
        # Normalize
        p1, p99 = np.percentile(volume, [1, 99])
        volume = np.clip(volume, p1, p99)
        volume = (volume - p1) / (p99 - p1 + 1e-8)
        
        # Resize to target
        if volume.shape != self.target_size:
            zoom_factors = [self.target_size[i] / volume.shape[i] for i in range(3)]
            volume = ndimage.zoom(volume, zoom_factors, order=1)
        
        return volume.astype(np.float32)

print("✅ DICOM Processor loaded")

# ====================================================
# CELL 4: DATASET CLASS
# ====================================================

class SimpleSegmentationDataset(Dataset):
    def __init__(self, df, series_dir, processor, mode='train'):
        self.df = df
        self.series_dir = series_dir
        self.processor = processor
        self.mode = mode
        
        # Enhanced augmentation for training, simple transforms for validation
        if mode == 'train':
            self.transform = Enhanced3DAugmentation(mode='train')
        else:
            self.transform = CustomTransforms(keys=['volume'])
        
    def __len__(self):
        return len(self.df)

    def validate_segmentation_mask(self, series_id, mask):
        """Validate segmentation mask quality"""
        # Check if mask is empty
        if mask.max() == 0:
            return False
            
        # Check mask connectivity and size
        mask_binary = (mask > 0.5).astype(np.uint8)
        labeled_mask, num_components = ndimage.label(mask_binary)
        
        if num_components == 0:
            return False
            
        # Check component sizes (aneurysms should be small but not tiny)
        component_sizes = []
        for i in range(1, num_components + 1):
            component_size = np.sum(labeled_mask == i)
            component_sizes.append(component_size)
        
        # Valid if has reasonably sized components
        valid_components = [size for size in component_sizes if 10 < size < 10000]
        return len(valid_components) > 0
    
    
    def load_segmentation_mask(self, series_id, volume_shape):
        """Load real segmentation mask from competition data with validation"""
        seg_path = os.path.join(Config.SEGMENTATION_DIR, f"{series_id}.nii")
        
        try:
            if os.path.exists(seg_path):
                # Load NIfTI segmentation mask
                import nibabel as nib
                nii_img = nib.load(seg_path)
                mask = nii_img.get_fdata().astype(np.float32)
                
                # Resize mask to match volume shape
                if mask.shape != volume_shape:
                    zoom_factors = [volume_shape[i] / mask.shape[i] for i in range(3)]
                    mask = ndimage.zoom(mask, zoom_factors, order=0)  # Nearest neighbor for masks
                
                # Normalize mask values to 0-1
                mask = (mask > 0).astype(np.float32)
                
                # Validate mask quality
                if self.validate_segmentation_mask(series_id, mask):
                    return mask
                else:
                    # Mask failed validation - use fallback for aneurysm cases
                    has_aneurysm = int(self.df[self.df[Config.ID_COL] == series_id][Config.TARGET_COL].iloc[0])
                    if has_aneurysm:
                        # Create enhanced central region mask for aneurysm cases
                        mask = np.zeros(volume_shape, dtype=np.float32)
                        h, w, d = volume_shape
                        # Multiple small regions to simulate potential aneurysm locations
                        mask[h//3:2*h//3, w//3:2*w//3, d//3:2*d//3] = 0.7
                        mask[h//4:3*h//4, w//4:3*w//4, d//2:d//2+d//8] = 1.0  # Central strong region
                        return mask
                    else:
                        return np.zeros(volume_shape, dtype=np.float32)
            else:
                # No segmentation available - create empty mask
                return np.zeros(volume_shape, dtype=np.float32)
                
        except Exception as e:
            print(f"Error loading segmentation for {series_id}: {e}")
            # Fallback: create simple mask if aneurysm present
            has_aneurysm = int(self.df[self.df[Config.ID_COL] == series_id][Config.TARGET_COL].iloc[0])
            if has_aneurysm:
                # Create a rough central region mask as fallback
                mask = np.zeros(volume_shape, dtype=np.float32)
                h, w, d = volume_shape
                mask[h//4:3*h//4, w//4:3*w//4, d//4:3*d//4] = 1.0
                return mask
            else:
                return np.zeros(volume_shape, dtype=np.float32)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        series_id = row[Config.ID_COL]
        series_path = os.path.join(self.series_dir, series_id)
        
        # Load volume
        volume = self.processor.load_dicom_series(series_path)
        
        # Load REAL segmentation mask from competition data
        mask = self.load_segmentation_mask(series_id, volume.shape)
        
        # Get aneurysm presence label
        has_aneurysm = int(row[Config.TARGET_COL])
        
        # Transform
        data_dict = {'volume': volume}
        if self.transform:
            data_dict = self.transform(data_dict)
        
        volume_tensor = data_dict['volume'].unsqueeze(0)  # Add channel dim
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)
        
        return {
            'volume': volume_tensor,
            'mask': mask_tensor,
            'has_aneurysm': torch.tensor(has_aneurysm, dtype=torch.float32),
            'series_id': series_id
        }

print("✅ Dataset class loaded")

# ====================================================
# CELL 5: 3D U-NET MODEL
# ====================================================

class Simple3DSegmentationNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        
        # Use our Custom3DUNet - pure PyTorch implementation!
        self.backbone = Custom3DUNet(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=32,
            features=(32, 64, 128, 256, 512, 32),
            dropout=0.1
        )
        
        # Segmentation head
        self.seg_head = nn.Conv3d(32, out_channels, kernel_size=1)
        
        # Classification head (aneurysm presence)
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.classifier = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1)
        )
        
    def forward(self, x):
        # Extract features
        features = self.backbone(x)
        
        # Segmentation output
        seg_logits = self.seg_head(features)
        
        # Classification output
        pooled_features = self.global_pool(features).flatten(1)
        cls_logits = self.classifier(pooled_features)
        
        return seg_logits, cls_logits

print("✅ Model architecture loaded")

# ====================================================
# CELL 6: ENHANCED LOSS FUNCTIONS
# ====================================================

class DiceLoss(nn.Module):
    """Dice Loss for better segmentation of small objects"""
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
        
    def forward(self, predictions, targets):
        # Apply sigmoid to logits
        predictions = torch.sigmoid(predictions)
        
        # Flatten tensors
        predictions = predictions.view(-1)
        targets = targets.view(-1)
        
        # Calculate intersection and union
        intersection = (predictions * targets).sum()
        dice = (2. * intersection + self.smooth) / (predictions.sum() + targets.sum() + self.smooth)
        
        return 1 - dice

class FocalLoss(nn.Module):
    """Focal Loss for handling class imbalance"""
    def __init__(self, alpha=0.25, gamma=2.0, smooth=1e-6):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
        
    def forward(self, predictions, targets):
        # Apply sigmoid to get probabilities
        probs = torch.sigmoid(predictions)
        
        # Calculate focal loss components
        pt = torch.where(targets == 1, probs, 1 - probs)
        ce_loss = nn.functional.binary_cross_entropy_with_logits(predictions, targets, reduction='none')
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        
        return focal_loss.mean()
        
class EnhancedCombinedLoss(nn.Module):
    """Enhanced loss combining Dice + BCE + Focal for medical segmentation"""
    def __init__(self):
        super().__init__()
        self.dice_loss = DiceLoss()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.focal_loss = FocalLoss(alpha=0.25, gamma=2)
        
    def forward(self, seg_logits, cls_logits, seg_targets, cls_targets):
        # Multi-component segmentation loss
        dice_loss = self.dice_loss(seg_logits, seg_targets)
        bce_seg_loss = self.bce_loss(seg_logits, seg_targets)
        focal_seg_loss = self.focal_loss(seg_logits, seg_targets)
        
        # Weighted combination for class imbalance
        seg_loss = 0.5 * dice_loss + 0.3 * bce_seg_loss + 0.2 * focal_seg_loss
        
        # Classification loss
        # Ensure batch dimension is preserved even when batch size == 1
        cls_loss = self.bce_loss(cls_logits.view(-1), cls_targets)
        
        # Higher weight on segmentation for Stage 1
        total_loss = 2.0 * seg_loss + 0.5 * cls_loss
        return total_loss, seg_loss, cls_loss

print("✅ Enhanced loss functions loaded (Dice + BCE + Focal)")

# ====================================================
# CELL 7: TRAINING FUNCTIONS
# ====================================================

def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    total_seg_loss = 0
    total_cls_loss = 0
    num_batches = 0
    
    for batch in tqdm(loader, desc="Training"):
        volume = batch['volume'].to(device)
        mask = batch['mask'].to(device)
        has_aneurysm = batch['has_aneurysm'].to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        seg_logits, cls_logits = model(volume)
        
        # Calculate loss
        loss, seg_loss, cls_loss = criterion(seg_logits, cls_logits, mask, has_aneurysm)
        
        # Backward pass with gradient clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        total_loss += loss.item()
        total_seg_loss += seg_loss.item()
        total_cls_loss += cls_loss.item()
        num_batches += 1
    
    return (total_loss / num_batches, 
            total_seg_loss / num_batches, 
            total_cls_loss / num_batches)

def validate_epoch(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    total_seg_loss = 0
    total_cls_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating"):
            volume = batch['volume'].to(device)
            mask = batch['mask'].to(device)
            has_aneurysm = batch['has_aneurysm'].to(device)
            
            # Forward pass
            seg_logits, cls_logits = model(volume)
            
            # Calculate loss
            loss, seg_loss, cls_loss = criterion(seg_logits, cls_logits, mask, has_aneurysm)
            
            total_loss += loss.item()
            total_seg_loss += seg_loss.item()
            total_cls_loss += cls_loss.item()
            num_batches += 1

            # Release per-batch tensors to help GPU memory
            del volume, mask, has_aneurysm, seg_logits, cls_logits, loss, seg_loss, cls_loss
    
    return (total_loss / num_batches, 
            total_seg_loss / num_batches, 
            total_cls_loss / num_batches)

print("✅ Training functions loaded")

# ====================================================
# CELL 8: MAIN TRAINING LOOP
# ====================================================

def main():
    print(f"🚀 STAGE 1: 3D SEGMENTATION FOR REGION LOCALIZATION")
    print(f"Using device: {Config.DEVICE}")
    print(f"Target size: {Config.STAGE1_TARGET_SIZE}")
    
    # Load data
    train_df = pd.read_csv(Config.TRAIN_CSV_PATH)
    
    # Load localizer data (for future use)
    try:
        localizer_df = pd.read_csv(Config.LOCALIZER_CSV_PATH)
        print(f"Loaded localizer data: {len(localizer_df)} entries")
    except:
        localizer_df = None
        print("No localizer data found - continuing without it")
    
    # Debug mode - small subset
    if Config.DEBUG_MODE:
        train_df = train_df.head(Config.DEBUG_SAMPLES)
    print(f"Training samples: {len(train_df)}")
    print(f"Aneurysm cases: {train_df[Config.TARGET_COL].sum()}")
    
    # Simple train/val split
    val_size = len(train_df) // 5
    val_df = train_df[:val_size].reset_index(drop=True)
    train_df = train_df[val_size:].reset_index(drop=True)
    
    print(f"Train: {len(train_df)}, Val: {len(val_df)}")
    
    # Create datasets
    processor = SimpleDICOMProcessor()
    train_dataset = SimpleSegmentationDataset(train_df, Config.SERIES_DIR, processor, 'train')
    val_dataset = SimpleSegmentationDataset(val_df, Config.SERIES_DIR, processor, 'val')
    
    # Create loaders
    train_loader = DataLoader(train_dataset, batch_size=Config.STAGE1_BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=Config.STAGE1_BATCH_SIZE, shuffle=False, num_workers=2)
    
    # Create model
    model = Simple3DSegmentationNet().to(Config.DEVICE)
    
    # Multi-GPU if available
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs")
        model = nn.DataParallel(model)
    
    # Enhanced optimizer and loss - proven optimization from winning solutions
    optimizer = optim.AdamW(model.parameters(), lr=Config.STAGE1_LR, weight_decay=1e-4)
    criterion = EnhancedCombinedLoss()  # Use enhanced loss function
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.STAGE1_EPOCHS, eta_min=1e-6)
    
    # Training loop
    best_loss = float('inf')
    
    for epoch in range(Config.STAGE1_EPOCHS):
        print(f"\nEpoch {epoch+1}/{Config.STAGE1_EPOCHS}")
        
        # Train
        train_loss, train_seg_loss, train_cls_loss = train_epoch(
            model, train_loader, optimizer, criterion, Config.DEVICE
        )
        
        # Validate
        val_loss, val_seg_loss, val_cls_loss = validate_epoch(
            model, val_loader, criterion, Config.DEVICE
        )
        
        # Step scheduler
        scheduler.step()
        
        print(f"Train - Total: {train_loss:.4f}, Seg: {train_seg_loss:.4f}, Cls: {train_cls_loss:.4f}")
        print(f"Val   - Total: {val_loss:.4f}, Seg: {val_seg_loss:.4f}, Cls: {val_cls_loss:.4f}")
        
        # Save best model
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'val_loss': val_loss
            }, 'stage1_segmentation_best.pth')
            print(f"💾 Saved best model (val_loss: {val_loss:.4f})")

        # Proactive memory cleanup after each epoch
        try:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
        except Exception:
            pass
        gc.collect()
    
    print(f"\n✅ Stage 1 complete! Best val loss: {best_loss:.4f}")
    print("📁 Model saved as 'stage1_segmentation_best.pth'")
    
    return model

# ====================================================
# CELL 9: ROI EXTRACTOR FOR STAGE 2 (FUTURE USE)
# ====================================================

class ROIExtractor:
    def __init__(self, roi_size=(224, 224), confidence_threshold=0.5):
        self.roi_size = roi_size
        self.confidence_threshold = confidence_threshold
    
    def extract_rois(self, volume, segmentation_mask):
        """Extract 2D ROI slices from 3D volume using segmentation mask"""
        rois = []
        
        # Find slices with high confidence regions
        for slice_idx in range(volume.shape[0]):
            slice_volume = volume[slice_idx]
            slice_mask = segmentation_mask[slice_idx]
            
            # Check if this slice has potential aneurysm regions
            if np.max(slice_mask) > self.confidence_threshold:
                # Find connected components
                binary_mask = (slice_mask > self.confidence_threshold).astype(np.uint8)
                
                # Find contours
                contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                
                for contour in contours:
                    # Get bounding box
                    x, y, w, h = cv2.boundingRect(contour)
                    
                    # Expand bounding box
                    margin = max(w, h) // 4
                    x = max(0, x - margin)
                    y = max(0, y - margin)
                    w = min(slice_volume.shape[1] - x, w + 2*margin)
                    h = min(slice_volume.shape[0] - y, h + 2*margin)
                    
                    # Extract ROI
                    roi = slice_volume[y:y+h, x:x+w]
                    
                    # Resize to standard size
                    roi_resized = cv2.resize(roi, self.roi_size)
                    
                    rois.append({
                        'roi': roi_resized,
                        'slice_idx': slice_idx,
                        'bbox': (x, y, w, h),
                        'confidence': np.max(slice_mask[y:y+h, x:x+w])
                    })
        
        return rois

print("✅ ROI Extractor loaded (for Stage 2)")

✅ Configuration loaded - Device: cuda
✅ Enhanced 3D UNet with medical augmentations loaded (MONAI-free!)
✅ DICOM Processor loaded
✅ Dataset class loaded
✅ Model architecture loaded
✅ Enhanced loss functions loaded (Dice + BCE + Focal)
✅ Training functions loaded
✅ ROI Extractor loaded (for Stage 2)


In [3]:
# ====================================================
# CELL 10: RUN TRAINING
# ====================================================

# Start Training
model = main()

print("Expected training time: ??? hours")
print("Output: stage1_segmentation_best.pth")

🚀 STAGE 1: 3D SEGMENTATION FOR REGION LOCALIZATION
Using device: cuda
Target size: (64, 128, 128)
Loaded localizer data: 2286 entries
Training samples: 4405
Aneurysm cases: 1893
Train: 3524, Val: 881
Using 2 GPUs

Epoch 1/7


Training: 100%|██████████| 881/881 [45:17<00:00,  3.08s/it]
Validating: 100%|██████████| 221/221 [09:37<00:00,  2.61s/it]


Train - Total: 1.3773, Seg: 0.5175, Cls: 0.6848
Val   - Total: 1.3459, Seg: 0.5007, Cls: 0.6889
💾 Saved best model (val_loss: 1.3459)

Epoch 2/7


Training: 100%|██████████| 881/881 [44:38<00:00,  3.04s/it]
Validating: 100%|██████████| 221/221 [09:20<00:00,  2.53s/it]


Train - Total: 1.3433, Seg: 0.5007, Cls: 0.6838
Val   - Total: 1.3469, Seg: 0.5002, Cls: 0.6927

Epoch 3/7


Training: 100%|██████████| 881/881 [46:07<00:00,  3.14s/it]
Validating: 100%|██████████| 221/221 [09:36<00:00,  2.61s/it]


Train - Total: 1.3397, Seg: 0.5002, Cls: 0.6784
Val   - Total: 1.3455, Seg: 0.5001, Cls: 0.6906
💾 Saved best model (val_loss: 1.3455)

Epoch 4/7


Training: 100%|██████████| 881/881 [44:50<00:00,  3.05s/it]
Validating: 100%|██████████| 221/221 [09:27<00:00,  2.57s/it]


Train - Total: 1.3407, Seg: 0.5001, Cls: 0.6808
Val   - Total: 1.3418, Seg: 0.5001, Cls: 0.6833
💾 Saved best model (val_loss: 1.3418)

Epoch 5/7


Training: 100%|██████████| 881/881 [45:47<00:00,  3.12s/it]
Validating: 100%|██████████| 221/221 [09:23<00:00,  2.55s/it]


Train - Total: 1.3396, Seg: 0.5001, Cls: 0.6789
Val   - Total: 1.3417, Seg: 0.5001, Cls: 0.6831
💾 Saved best model (val_loss: 1.3417)

Epoch 6/7


Training: 100%|██████████| 881/881 [45:47<00:00,  3.12s/it]
Validating: 100%|██████████| 221/221 [09:35<00:00,  2.60s/it]


Train - Total: 1.3389, Seg: 0.5001, Cls: 0.6774
Val   - Total: 1.3424, Seg: 0.5001, Cls: 0.6845

Epoch 7/7


Training: 100%|██████████| 881/881 [44:15<00:00,  3.01s/it]
Validating: 100%|██████████| 221/221 [09:31<00:00,  2.59s/it]


Train - Total: 1.3392, Seg: 0.5001, Cls: 0.6781
Val   - Total: 1.3411, Seg: 0.5000, Cls: 0.6820
💾 Saved best model (val_loss: 1.3411)

✅ Stage 1 complete! Best val loss: 1.3411
📁 Model saved as 'stage1_segmentation_best.pth'
Expected training time: ??? hours
Output: stage1_segmentation_best.pth
