In [None]:
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import timm
from typing import Optional, Tuple, List
import warnings
# Blanket filter: from this point on every Python warning is ignored, which
# also hides deprecation notices raised by the libraries imported above.
warnings.filterwarnings('ignore')

# Processing

## DICOM Processing Functions

In [None]:
def load_dicom_array(dcm_path: str) -> np.ndarray:
    """Read one DICOM file and return its processed pixel data as float32.

    The stored pixel values are run through the grayscale pipeline in the
    order DICOM defines it: the modality rescale
    (``stored * RescaleSlope + RescaleIntercept``) first, then the VOI LUT /
    window carried by the file, if it defines one.

    Args:
        dcm_path: Path of the DICOM file to read.

    Returns:
        The processed pixel data as an ``np.float32`` array.

    Note:
        When the file carries a VOI LUT or window, the returned values have
        been windowed by ``apply_voi_lut`` and are no longer plain rescaled
        units.
    """
    dicom = pydicom.dcmread(dcm_path)
    data = dicom.pixel_array

    # The rescale has to happen BEFORE the VOI step: the window stored in the
    # file is expressed in rescaled units, and apply_voi_lut assumes its input
    # has already been through the modality rescale.  Doing it the other way
    # round windows the raw stored values and then shifts them a second time.
    if hasattr(dicom, 'RescaleSlope') and hasattr(dicom, 'RescaleIntercept'):
        slope = float(dicom.RescaleSlope)
        intercept = float(dicom.RescaleIntercept)
        if (np.issubdtype(data.dtype, np.integer)
                and slope.is_integer() and intercept.is_integer()):
            # Keep whole-number rescales in an integer dtype: a VOI LUT
            # sequence is a lookup table indexed by pixel value, so that
            # path needs integer input.
            data = data.astype(np.int64) * int(slope) + int(intercept)
        else:
            data = data.astype(np.float64) * slope + intercept

    # Apply the VOI LUT / window stored in the file, if it defines one.
    data = apply_voi_lut(data, dicom)

    return data.astype(np.float32)

def apply_windowing(img: np.ndarray,
                   window_center: float,
                   window_width: float) -> np.ndarray:
    """Clip an image to the intensity window ``center +/- width / 2``.

    Args:
        img: Input intensities.
        window_center: Centre of the window.
        window_width: Full width of the window.

    Returns:
        A new array with every value limited to
        ``[window_center - window_width / 2, window_center + window_width / 2]``.
        The values are only clipped, not rescaled.
    """
    # True division: floor division would shrink the window for odd or
    # fractional widths (e.g. width 5 would clip to a 4-unit range).
    half_width = window_width / 2
    return np.clip(img, window_center - half_width, window_center + half_width)

def get_brain_windows(img: np.ndarray) -> np.ndarray:
    """Render an image through three CT windows and stack them as channels.

    Each window is a ``(center, width)`` pair. The image is clipped to
    ``center +/- width / 2`` and mapped linearly so that the lower window
    bound becomes 0 and the upper bound becomes 255. Scaling against the
    window bounds (rather than against whatever min/max the slice happens to
    contain) keeps a given input intensity at the same grey level in every
    slice, and keeps a slice that lies entirely inside a window from
    collapsing to zeros.

    Args:
        img: Input intensities; the window values below are written as CT
            Hounsfield-unit windows.

    Returns:
        ``uint8`` array with one extra trailing axis of length 3, holding the
        'brain', 'subdural' and 'stroke' windows in that order.
    """
    # Common brain CT windows as (center, width)
    windows = {
        'brain': (40, 80),      # Brain tissue
        'subdural': (80, 200),  # Subdural window
        'stroke': (40, 40),     # Stroke window
        'aneurysm': (50, 150),  # Aneurysm visualization
    }

    # The result is a 3-channel image, so only the first three windows are
    # rendered; the fourth entry is kept in the table for reference.
    selected = list(windows.values())[:3]

    channels = []
    for center, width in selected:
        lower = center - width / 2
        upper = center + width / 2
        clipped = np.clip(img, lower, upper)
        # Map [lower, upper] -> [0, 255]; astype truncates toward zero.
        scaled = (clipped - lower) / (upper - lower) * 255
        channels.append(scaled.astype(np.uint8))

    return np.stack(channels, axis=-1)

## Image Preprocessing Functions

In [None]:
def normalize_hounsfield(img: np.ndarray) -> np.ndarray:
    """Clip intensities to [-100, 200] and rescale that range linearly to [0, 1].

    Values at or below -100 map to 0, values at or above 200 map to 1.
    """
    # Fixed intensity range used for the rescale
    hu_floor, hu_ceiling = -100, 200

    bounded = np.clip(img, hu_floor, hu_ceiling)
    return (bounded - hu_floor) / (hu_ceiling - hu_floor)

def remove_skull_artifacts(img: np.ndarray) -> np.ndarray:
    """Keep only the largest bright region of an image, zeroing everything else.

    Steps: convert to ``uint8``, threshold at 30, tidy the binary mask with a
    morphological close (2 iterations) and open (1 iteration) using a 5x5
    elliptical kernel, then fill the external contour with the largest area
    and use it as a mask on the ``uint8`` image.

    Args:
        img: Input image. If its maximum is <= 1 it is multiplied by 255
            before the ``uint8`` cast; otherwise it is cast directly.

    Returns:
        The masked ``uint8`` image, or the unmasked ``uint8`` image when the
        thresholded mask contains no contour at all.
    """
    # OpenCV works on uint8 here; inputs whose maximum is <= 1 are scaled up first
    if img.max() <= 1:
        gray = (img * 255).astype(np.uint8)
    else:
        gray = img.astype(np.uint8)

    # Binary foreground: pixels above 30 become 255
    foreground = cv2.threshold(gray, 30, 255, cv2.THRESH_BINARY)[1]

    # Close then open with an elliptical kernel to clean up the mask
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    foreground = cv2.morphologyEx(foreground, cv2.MORPH_CLOSE, ellipse, iterations=2)
    foreground = cv2.morphologyEx(foreground, cv2.MORPH_OPEN, ellipse, iterations=1)

    outlines, _hierarchy = cv2.findContours(
        foreground, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not outlines:
        # Nothing survived the threshold: hand back the image unmasked
        return gray

    # Filled mask of the contour with the largest area
    keep = np.zeros_like(foreground)
    biggest = max(outlines, key=cv2.contourArea)
    cv2.drawContours(keep, [biggest], -1, 255, -1)

    return cv2.bitwise_and(gray, gray, mask=keep)

def resize_with_padding(img: np.ndarray,
                       target_size: Tuple[int, int],
                       pad_value: int = 0) -> np.ndarray:
    """Resize an image to fit inside ``target_size`` and pad it to that size.

    The image is scaled uniformly so that it fits within the target, then
    pasted centred onto a canvas filled with ``pad_value``.

    Args:
        img: Image of shape ``(H, W)`` or ``(H, W, C)``.
        target_size: ``(height, width)`` of the output.
        pad_value: Fill value for the area not covered by the image.

    Returns:
        Array of shape ``(target_h, target_w)`` or ``(target_h, target_w, C)``
        with the same dtype as ``img``.
    """
    h, w = img.shape[:2]
    target_h, target_w = target_size

    # Largest uniform scale at which the image still fits in the target
    scale = min(target_w / w, target_h / h)

    # Round rather than truncate so float error cannot leave the fitted side
    # one pixel short, and clamp to [1, target] so an extreme aspect ratio
    # never yields a 0-pixel side (which cv2.resize rejects).
    new_w = min(target_w, max(1, round(w * scale)))
    new_h = min(target_h, max(1, round(h * scale)))

    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    # cv2.resize returns a 2-D array for an (H, W, 1) input; put the channel
    # axis back so the paste below lines up with the canvas.
    if resized.ndim < img.ndim:
        resized = resized[..., np.newaxis]

    # Canvas keeps any trailing channel axis of the input
    padded = np.full((target_h, target_w) + img.shape[2:], pad_value, dtype=img.dtype)

    # Centre the resized image on the canvas
    y_offset = (target_h - new_h) // 2
    x_offset = (target_w - new_w) // 2
    padded[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized

    return padded

## Augmentation Pipeline

In [None]:
def get_train_transforms(img_size: int = 384):
    """Build the albumentations pipeline used for training images.

    Stages, in order: geometric augmentations, one optional intensity
    adjustment, one optional noise/blur step, an optional elastic
    deformation, then normalization with ImageNet statistics and conversion
    to a tensor.

    Args:
        img_size: Side length passed to ``RandomResizedCrop``.

    Returns:
        An ``A.Compose`` instance.
    """
    # Geometry: crop/zoom, mirror, small shift/scale/rotation
    geometric = [
        A.RandomResizedCrop(img_size, img_size, scale=(0.8, 1.0), p=0.5),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.1,
            scale_limit=0.15,
            rotate_limit=15,
            border_mode=cv2.BORDER_CONSTANT,
            p=0.5,
        ),
    ]

    # Intensity: at most one mild adjustment (kept gentle for medical images)
    intensity = A.OneOf(
        [
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0),
            A.RandomGamma(gamma_limit=(90, 110), p=1.0),
            A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=1.0),
        ],
        p=0.5,
    )

    # Degradation: at most one of light noise or light blur
    degradation = A.OneOf(
        [
            A.GaussNoise(var_limit=(5.0, 15.0), p=1.0),
            A.GaussianBlur(blur_limit=(3, 5), p=1.0),
        ],
        p=0.3,
    )

    # Small elastic deformation
    elastic = A.ElasticTransform(alpha=10, sigma=5, alpha_affine=0, p=0.3)

    # ImageNet statistics, then conversion to a tensor
    finalize = [
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]

    return A.Compose([*geometric, intensity, degradation, elastic, *finalize])

def get_val_transforms(img_size: int = 384):
    """Build the deterministic pipeline for validation/test images.

    Resizes to ``img_size x img_size``, normalizes with ImageNet statistics
    and converts to a tensor; no random augmentation is applied.
    """
    resize = A.Resize(img_size, img_size)
    normalize = A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
    )
    return A.Compose([resize, normalize, ToTensorV2()])

## Dataset Class

In [None]:
class AneurysmDataset(Dataset):
    """Dataset that loads one DICOM image per DataFrame row.

    Each item is read from ``{img_dir}/{image_id}.dcm``, turned into a
    3-channel ``uint8`` image, letterboxed to ``img_size x img_size`` and, if
    transforms were given, passed through them.
    """

    def __init__(self,
                 df: pd.DataFrame,
                 img_dir: str,
                 transforms=None,
                 use_windowing: bool = True,
                 img_size: int = 384):
        """Store the configuration; nothing is read from disk here.

        Args:
            df: One row per image; must provide an ``image_id`` column and
                may provide a ``label`` column.
            img_dir: Directory containing the ``.dcm`` files.
            transforms: Optional callable invoked as ``transforms(image=img)``
                and expected to return a mapping with an ``'image'`` entry.
            use_windowing: If true, channels come from ``get_brain_windows``;
                otherwise one normalized channel is repeated three times.
            img_size: Side length of the square image handed to the transforms.
        """
        self.df, self.img_dir = df, img_dir
        self.transforms = transforms
        self.use_windowing = use_windowing
        self.img_size = img_size

    def __len__(self):
        """Number of rows in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``."""
        row = self.df.iloc[idx]

        # Read the DICOM belonging to this row
        raw = load_dicom_array(f"{self.img_dir}/{row['image_id']}.dcm")

        if not self.use_windowing:
            # Single normalized channel, replicated for models expecting RGB
            gray = (normalize_hounsfield(raw) * 255).astype(np.uint8)
            img = np.stack([gray, gray, gray], axis=-1)
        else:
            # One CT window per channel
            img = get_brain_windows(raw)

        # Skull stripping via remove_skull_artifacts is intentionally not
        # applied here (it can be slow).

        # Letterbox to a square so the aspect ratio is preserved
        side = self.img_size
        img = resize_with_padding(img, (side, side))

        # Augment / normalize / tensorize when a pipeline was supplied
        if self.transforms:
            img = self.transforms(image=img)['image']

        # Rows without a 'label' entry fall back to 0
        if 'label' in row:
            label = row['label']
        else:
            label = 0

        return img, torch.tensor(label, dtype=torch.float32)

## Model-Specific Preprocessing

In [None]:
def get_model_specific_preprocessing(model_name: str, img_size: int = 384):
    """Return the preprocessing configuration for a model family.

    The lookup is case-insensitive. An exact family name ('efficientnetv2',
    'convnext', 'swin', 'maxvit', 'coatnet') is tried first; otherwise the
    first family whose name occurs inside ``model_name`` is used, so full
    architecture names such as 'maxvit_tiny_tf_384' resolve to their family
    instead of silently falling back to the defaults.

    Args:
        model_name: Family name or a longer model name containing it.
        img_size: Image size used ONLY for the fallback configuration of an
            unrecognised model; every known family reports 384.

    Returns:
        Dict with ``'img_size'`` (int) and ``'normalize'`` (dict holding
        ``'mean'`` and ``'std'`` lists of three floats).
    """
    def imagenet_stats():
        # Fresh lists on every call so no two configs share mutable state
        return {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
        }

    preprocessing_configs = {
        'efficientnetv2': {'img_size': 384, 'normalize': imagenet_stats()},
        'convnext': {'img_size': 384, 'normalize': imagenet_stats()},
        'swin': {'img_size': 384, 'normalize': imagenet_stats()},
        # The only family here that does not use ImageNet statistics
        'maxvit': {
            'img_size': 384,
            'normalize': {
                'mean': [0.5, 0.5, 0.5],
                'std': [0.5, 0.5, 0.5],
            },
        },
        'coatnet': {'img_size': 384, 'normalize': imagenet_stats()},
    }

    key = model_name.lower()

    # Exact family name: same result as a plain dictionary lookup
    if key in preprocessing_configs:
        return preprocessing_configs[key]

    # Longer names: match on the family they contain
    for family, config in preprocessing_configs.items():
        if family in key:
            return config

    # Unknown model: ImageNet statistics at the caller's image size
    return {'img_size': img_size, 'normalize': imagenet_stats()}



## Test Time Augmentation

In [None]:
def apply_tta(model, img_tensor, device):
    """Average a model's predictions over five test-time views of the input.

    The views are: the input itself, a flip along dim 3, and rotations by
    90, 180 and 270 degrees in the plane of dims 2 and 3.

    Args:
        model: Module to evaluate. It is switched to eval mode and left
            there; callers that continue training must call ``model.train()``.
        img_tensor: Input with at least four dimensions (the flip and
            rotations index dims 2 and 3) — presumably ``(N, C, H, W)``;
            confirm against the caller.
        device: Device the input is moved to before inference.

    Returns:
        The element-wise mean of the five model outputs. Predictions are
        averaged as-is, without undoing the view transform, so this is only
        suitable when the model output does not depend on spatial layout
        (e.g. classification).
    """
    tta_transforms = [
        lambda x: x,  # Original
        lambda x: torch.flip(x, dims=[3]),  # Horizontal flip
        lambda x: torch.rot90(x, k=1, dims=[2, 3]),  # 90 degree rotation
        lambda x: torch.rot90(x, k=2, dims=[2, 3]),  # 180 degree rotation
        lambda x: torch.rot90(x, k=3, dims=[2, 3]),  # 270 degree rotation
    ]

    model.eval()

    # Move the input once; the transfer does not depend on the view, so doing
    # it inside the loop only repeated the same copy for every transform.
    batch = img_tensor.to(device)

    with torch.no_grad():
        predictions = [model(transform(batch)) for transform in tta_transforms]

    # Average predictions across views
    final_pred = torch.stack(predictions).mean(dim=0)
    return final_pred