# ImageCAS Training Pipeline

This notebook demonstrates the complete training pipeline for coronary artery segmentation using the ImageCAS dataset.

**Contents:**
1. Dataset exploration and loading
2. Data preprocessing and augmentation
3. Model setup (3D U-Net; transfer learning with pretrained encoders is planned — see Next Steps)
4. Training loop
5. Evaluation and visualization
6. Results export for analysis across different data regimes

In [None]:
import sys
sys.path.append('..')

# Core libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from pathlib import Path
import yaml

# MONAI imports
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd,
    ScaleIntensityRanged, RandCropByPosNegLabeld, RandFlipd,
    RandRotate90d, ToTensord
)
from monai.data import Dataset, DataLoader, CacheDataset
from monai.networks.nets import UNet
from monai.losses import DiceLoss, DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference

# Visualization
import seaborn as sns
from IPython.display import display
import pandas as pd

# Set plotting style
plt.style.use('seaborn-v0_8-darkgrid')
%matplotlib inline

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 1. Configuration Loading

Load the base configuration and any specific data regime settings.

In [None]:
# Load base configuration
config_path = Path('../configs/base_config.yaml')

with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print("Base Configuration:")
print(yaml.dump(config, default_flow_style=False))


def _deep_update(base, overrides):
    """Recursively merge ``overrides`` into ``base`` in place and return ``base``.

    Nested dicts are merged key by key instead of being replaced wholesale.
    For example, a regime file that only sets ``data: {num_samples: 20}``
    keeps the base ``data`` entries such as ``data_dir`` and ``pixdim``.

    Args:
        base: dict to update (modified in place).
        overrides: dict whose values take precedence over ``base``.

    Returns:
        The updated ``base`` dict.
    """
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            _deep_update(base[key], value)
        else:
            base[key] = value
    return base


# Optional: point this at a data-regime YAML to override parts of the base
# config, e.g. Path('../configs/data_regimes/regime_20.yaml')
regime_path = None
if regime_path is not None:
    with open(regime_path, 'r', encoding='utf-8') as f:
        regime_config = yaml.safe_load(f) or {}  # empty file -> no overrides
    # Merge rather than dict.update(). When the regime file nests its
    # settings under the same section names, a shallow update would replace
    # the whole section (e.g. all of config['data']) with the partial one.
    _deep_update(config, regime_config)
    print(f"Applied data regime overrides from {regime_path}")

## 2. Dataset Exploration

Explore the ImageCAS dataset structure and contents.

In [None]:
# Dataset root as configured in the YAML
data_dir = Path(config['data']['data_dir'])

if data_dir.exists():
    print(f"✓ Data directory found: {data_dir}")

    # Recursively collect image/label volumes; sorting gives a deterministic order
    image_files = sorted(data_dir.rglob('*image*.nii.gz'))
    label_files = sorted(data_dir.rglob('*label*.nii.gz'))

    print(f"\nFound {len(image_files)} image files")
    print(f"Found {len(label_files)} label files")

    if image_files:
        print("\nExample files:")
        print(f"  Image: {image_files[0].name}")
        if label_files:
            print(f"  Label: {label_files[0].name}")
else:
    print(f"⚠️  Data directory not found: {data_dir}")
    print("Please download the ImageCAS dataset from Kaggle and place it in the data/imagecas/ directory.")
    print("Dataset URL: https://www.kaggle.com/datasets/xiaoweixumedicalai/imagecas")

## 3. Data Preparation

Create data dictionaries and split into train/validation sets.

In [None]:
# Create data dictionaries
def _case_id(path, tag):
    """Return the file name of ``path`` without its NIfTI extension and ``tag``.

    ``1_image.nii.gz`` (tag ``'image'``) and ``1_label.nii.gz`` (tag
    ``'label'``) both map to ``'1_'``, so images and labels can be paired by
    exact ID. Adjust this logic if the actual ImageCAS naming convention
    differs.
    """
    return path.stem.replace('.nii', '').replace(tag, '')


if data_dir.exists() and len(image_files) > 0:
    # Match by exact case ID, not by substring. With a substring test, ID '1_'
    # is contained in '11_label.nii.gz', which sorts before '1_label.nii.gz',
    # so case 1 would be paired with case 11's label.
    labels_by_id = {_case_id(p, 'label'): p for p in label_files}

    data_dicts = []
    missing_labels = 0
    for img_path in image_files[:100]:  # Limit for quick testing
        label_path = labels_by_id.get(_case_id(img_path, 'image'))
        if label_path is None:
            missing_labels += 1
            continue
        data_dicts.append({
            'image': str(img_path),
            'label': str(label_path)
        })

    print(f"Created {len(data_dicts)} data pairs")
    if missing_labels:
        print(f"⚠️  Skipped {missing_labels} image(s) with no matching label")

    from sklearn.model_selection import train_test_split

    if len(data_dicts) < 2:
        # train_test_split cannot produce non-empty train and val sets from < 2 items
        print("Need at least 2 image/label pairs to create a train/val split")
        train_files = []
        val_files = []
    else:
        # Split into train/val (80/20); train_test_split shuffles by default
        train_files, val_files = train_test_split(
            data_dicts,
            test_size=0.2,
            random_state=config['seed']
        )

        # Apply data regime if specified (subset of the shuffled training split)
        num_samples = config['data'].get('num_samples')
        if num_samples and num_samples < len(train_files):
            train_files = train_files[:num_samples]
            print(f"Using data regime with {num_samples} training samples")

    print(f"\nTrain samples: {len(train_files)}")
    print(f"Validation samples: {len(val_files)}")
else:
    print("Cannot create data dictionaries - data not available")
    train_files = []
    val_files = []

## 4. Data Transforms

Define preprocessing and augmentation transforms.

In [None]:
def _preprocessing_transforms():
    """Build the deterministic load / resample / intensity-scaling steps.

    Used by both the training and validation pipelines. Fresh instances are
    returned on every call, so the two ``Compose`` objects never share
    transform objects.
    """
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Resample to the configured voxel spacing: bilinear for the image,
        # nearest-neighbour for the label so label values are not blended
        Spacingd(
            keys=["image", "label"],
            pixdim=config['data']['pixdim'],
            mode=("bilinear", "nearest")
        ),
        # Map intensities in [a_min, a_max] linearly onto [0, 1], clipping outliers
        ScaleIntensityRanged(
            keys=["image"],
            a_min=config['data']['a_min'],
            a_max=config['data']['a_max'],
            b_min=0.0,
            b_max=1.0,
            clip=True
        ),
    ]


# Training: shared preprocessing, then 4 random crops per volume with
# foreground/background centres equally likely (pos=1, neg=1), followed by
# random flips along each spatial axis and random 90-degree rotations
train_transforms = Compose(_preprocessing_transforms() + [
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=config['data']['spatial_size'],
        pos=1,
        neg=1,
        num_samples=4
    ),
    *(RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=axis) for axis in (0, 1, 2)),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
])

# Validation: shared preprocessing only (no augmentation, no cropping)
val_transforms = Compose(_preprocessing_transforms())

print("Transforms defined successfully")

## 5. Create DataLoaders

Create PyTorch DataLoaders for training and validation.

In [None]:
if not train_files:
    print("Cannot create DataLoaders - no training data available")
else:
    workers = config['data']['num_workers']

    # CacheDataset pre-computes the deterministic transform prefix:
    # for half of the training cases and for every validation case
    train_ds = CacheDataset(data=train_files, transform=train_transforms,
                            cache_rate=0.5, num_workers=workers)
    val_ds = CacheDataset(data=val_files, transform=val_transforms,
                          cache_rate=1.0, num_workers=workers)

    # Shuffled mini-batches for training; one volume at a time for validation
    train_loader = DataLoader(train_ds, batch_size=config['data']['batch_size'],
                              shuffle=True, num_workers=workers)
    val_loader = DataLoader(val_ds, batch_size=1,
                            shuffle=False, num_workers=workers)

    print("✓ DataLoaders created")
    print(f"  Training batches: {len(train_loader)}")
    print(f"  Validation batches: {len(val_loader)}")

## 6. Visualize Sample Data

Load and visualize a sample from the training set.

In [None]:
if train_files:
    # Pull one augmented training batch
    batch = next(iter(train_loader))
    images, labels = batch['image'], batch['label']

    print(f"Batch shape: {images.shape}")
    print(f"Label shape: {labels.shape}")

    # Middle slice along the last spatial axis of the first sample in the batch
    sample_idx = 0
    slice_idx = images.shape[-1] // 2
    img_slice = images[sample_idx, 0, :, :, slice_idx].cpu()
    lbl_slice = labels[sample_idx, 0, :, :, slice_idx].cpu()

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(img_slice, cmap='gray')
    axes[1].imshow(lbl_slice, cmap='jet')
    axes[2].imshow(img_slice, cmap='gray')
    # Label drawn semi-transparent only where it is non-zero
    axes[2].imshow(lbl_slice, cmap='jet', alpha=0.3 * (lbl_slice > 0))

    for ax, title in zip(axes, ('CT Image (Axial)', 'Segmentation Label', 'Overlay')):
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## 7. Model Setup

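Create the 3D U-Net model together with the loss function, optimizer, and Dice metric. The model is trained from scratch; loading pretrained encoder weights for transfer learning is not implemented yet.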

In [None]:
# Create U-Net model (3D, residual units). No pretrained weights are loaded
# here; the model starts from its default initialisation.
model = UNet(
    spatial_dims=3,
    in_channels=config['model']['in_channels'],
    out_channels=config['model']['out_channels'],
    channels=config['model']['channels'],
    strides=config['model']['strides'],
    num_res_units=2,
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

# NOTE: the architecture name is only reported; a UNet is always built above.
print(f"Model: {config['model']['architecture'].upper()}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Loss function: both options one-hot the labels and softmax the logits
loss_name = config['training']['loss']
if loss_name == 'dice':
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
else:
    # 'dice_ce' and any unrecognised name use Dice + cross-entropy; warn so the
    # configured name printed below is not mistaken for the loss in use.
    if loss_name != 'dice_ce':
        print(f"⚠️  Unknown loss '{loss_name}', falling back to Dice+CE")
    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)

print(f"Loss function: {config['training']['loss']}")

# Optimizer: honour config['training']['optimizer'] (case-insensitive);
# unknown names fall back to Adam with a warning.
optimizer_name = config['training']['optimizer']
optimizer_cls = {
    'adam': torch.optim.Adam,
    'adamw': torch.optim.AdamW,
    'sgd': torch.optim.SGD,
}.get(str(optimizer_name).lower())
if optimizer_cls is None:
    print(f"⚠️  Unknown optimizer '{optimizer_name}', falling back to Adam")
    optimizer_cls = torch.optim.Adam

optimizer = optimizer_cls(
    model.parameters(),
    lr=config['training']['learning_rate'],
    weight_decay=config['training']['weight_decay']
)

# Report the optimizer class actually constructed
print(f"Optimizer: {optimizer_cls.__name__}")
print(f"Learning rate: {config['training']['learning_rate']}")

# Metric: mean Dice over non-background channels
dice_metric = DiceMetric(include_background=False, reduction="mean")

print("\n✓ Model setup complete")

## 8. Training Loop

Train the model, validating every `val_interval` epochs and saving the checkpoint with the best validation Dice score.

In [None]:
from tqdm import tqdm

def train_epoch(model, loader, optimizer, loss_function, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: network to train (switched to training mode).
        loader: iterable of dicts with ``'image'`` and ``'label'`` tensors.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_function: callable ``(outputs, labels) -> scalar loss tensor``.
        device: device each batch is moved to.

    Returns:
        Mean loss per batch over the epoch.
    """
    model.train()
    running = 0

    for batch in tqdm(loader, desc="Training"):
        images = batch['image'].to(device)
        targets = batch['label'].to(device)

        optimizer.zero_grad()
        batch_loss = loss_function(model(images), targets)
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()

    return running / len(loader)

def validate(model, loader, metric, device, roi_size=None, sw_batch_size=4):
    """Return the aggregated Dice score of ``model`` over ``loader``.

    The validation transforms do not crop, while training uses
    ``spatial_size`` random crops. Volumes are therefore predicted with MONAI
    sliding-window inference instead of one forward pass over the whole volume.

    Args:
        model: segmentation network producing per-class logits ``(B, C, ...)``.
        loader: yields dicts with ``'image'`` and ``'label'`` tensors; labels
            are assumed to hold integer class indices in one channel
            (TODO confirm for non-binary label sets).
        metric: a MONAI ``DiceMetric``; it is reset before use.
        device: device each batch is moved to.
        roi_size: sliding-window patch size; ``None`` uses the notebook's
            ``config['data']['spatial_size']`` (the training crop size).
        sw_batch_size: number of windows evaluated per forward pass.

    Returns:
        The aggregated metric as a Python float.
    """
    if roi_size is None:
        roi_size = config['data']['spatial_size']

    model.eval()
    metric.reset()

    with torch.no_grad():
        for batch_data in tqdm(loader, desc="Validation"):
            inputs = batch_data['image'].to(device)
            labels = batch_data['label'].to(device)

            # Patch-wise inference over the full volume
            outputs = sliding_window_inference(inputs, roi_size, sw_batch_size, model)

            # DiceMetric requires one-hot y_pred and y of shape (B, C, ...).
            # Raw argmax indices only give the right answer for binary tasks.
            num_classes = outputs.shape[1]
            preds = nn.functional.one_hot(outputs.argmax(dim=1), num_classes)
            target = nn.functional.one_hot(labels[:, 0].long(), num_classes)
            metric(y_pred=preds.movedim(-1, 1), y=target.movedim(-1, 1))

    # Aggregate metric
    return metric.aggregate().item()

print("Training functions defined")

In [None]:
# Training configuration
num_epochs = 10  # Reduced for notebook testing; use config['training']['epochs'] for full training
best_dice = 0.0
train_losses = []
val_dices = []

# NOTE: re-running this cell resets the history but not the model weights,
# so training continues from the current model state.
if not train_files:
    print("Cannot train - no training data available")
else:
    print(f"Starting training for {num_epochs} epochs...\n")

    for epoch in range(num_epochs):
        epoch_no = epoch + 1
        print(f"Epoch {epoch_no}/{num_epochs}")
        print("-" * 50)

        # One pass over the training data
        train_loss = train_epoch(model, train_loader, optimizer, loss_function, device)
        train_losses.append(train_loss)
        print(f"Train Loss: {train_loss:.4f}")

        # Validate only every val_interval epochs
        if epoch_no % config['training']['val_interval'] == 0:
            val_dice = validate(model, val_loader, dice_metric, device)
            val_dices.append(val_dice)
            print(f"Validation Dice: {val_dice:.4f}")

            # Keep only the weights with the highest validation Dice so far
            if val_dice > best_dice:
                best_dice = val_dice
                checkpoint_dir = Path(config['logging']['checkpoint_dir'])
                checkpoint_dir.mkdir(parents=True, exist_ok=True)
                torch.save(model.state_dict(), checkpoint_dir / 'best_model.pth')
                print(f"✓ Saved new best model (Dice: {best_dice:.4f})")

        print()

    print("\nTraining complete!")
    print(f"Best validation Dice: {best_dice:.4f}")

## 9. Training Curves

Visualize training loss and validation metrics.

In [None]:
if len(train_losses) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Training loss, one point per epoch
    axes[0].plot(range(1, len(train_losses) + 1), train_losses, 'b-', marker='o')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training Loss')
    axes[0].grid(True, alpha=0.3)

    # Validation Dice
    if len(val_dices) > 0:
        # Validation runs at epochs val_interval, 2*val_interval, ...
        # Trim to the scores actually recorded: if a run is interrupted during
        # validation, the loss is logged but the Dice is not, and plot()
        # would otherwise fail on x/y of different lengths.
        val_epochs = list(range(config['training']['val_interval'],
                                len(train_losses) + 1,
                                config['training']['val_interval']))[:len(val_dices)]
        axes[1].plot(val_epochs, val_dices, 'g-', marker='s')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Dice Score')
        axes[1].set_title('Validation Dice Score')
        axes[1].grid(True, alpha=0.3)
        axes[1].axhline(y=best_dice, color='r', linestyle='--',
                        label=f'Best: {best_dice:.4f}')
        axes[1].legend()

    plt.tight_layout()
    plt.show()

## 10. Prediction Visualization

Visualize predictions on validation samples.

In [None]:
if len(val_files) > 0:
    model.eval()

    # Get a validation sample (full volume: the validation transforms do not crop)
    val_batch = next(iter(val_loader))
    val_image = val_batch['image'].to(device)
    val_label = val_batch['label'].to(device)

    # Predict patch-wise at the training crop size. A single forward pass over
    # the whole volume may exhaust GPU memory or mismatch the training input size.
    with torch.no_grad():
        val_output = sliding_window_inference(
            val_image, config['data']['spatial_size'], 4, model
        )
        val_pred = torch.argmax(val_output, dim=1, keepdim=True)

    # Middle slice along the last spatial axis
    slice_idx = val_image.shape[-1] // 2
    image_slice = val_image[0, 0, :, :, slice_idx].cpu()
    label_slice = val_label[0, 0, :, :, slice_idx].cpu()
    pred_overlay = val_pred[0, 0, :, :, slice_idx].cpu()

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    axes[0].imshow(image_slice, cmap='gray')
    axes[1].imshow(label_slice, cmap='jet')
    axes[2].imshow(pred_overlay, cmap='jet')
    axes[3].imshow(image_slice, cmap='gray')
    # Prediction drawn semi-transparent only where it is non-zero
    axes[3].imshow(pred_overlay, cmap='jet', alpha=0.3 * (pred_overlay > 0))

    for ax, title in zip(axes, ('Input Image', 'Ground Truth', 'Prediction', 'Overlay')):
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Dice for this sample. DiceMetric requires one-hot y_pred and y (B, C, ...).
    num_classes = val_output.shape[1]
    pred_onehot = nn.functional.one_hot(val_pred[:, 0], num_classes).movedim(-1, 1)
    label_onehot = nn.functional.one_hot(val_label[:, 0].long(), num_classes).movedim(-1, 1)
    dice_metric.reset()
    dice_metric(y_pred=pred_onehot, y=label_onehot)
    sample_dice = dice_metric.aggregate().item()
    print(f"Dice score for this sample: {sample_dice:.4f}")

## 11. Export Results

Save training history and metrics to CSV for further analysis.

In [None]:
if train_losses:
    n_epochs = len(train_losses)
    interval = config['training']['val_interval']

    # One row per epoch; val_dice stays None except on validation epochs
    results = pd.DataFrame({
        'epoch': range(1, n_epochs + 1),
        'train_loss': train_losses
    })
    results['val_dice'] = None

    # zip stops at the shorter sequence, so only recorded scores are written
    val_epochs = list(range(interval, n_epochs + 1, interval))
    for epoch, dice in zip(val_epochs, val_dices):
        results.loc[epoch - 1, 'val_dice'] = dice

    # Write the history as CSV
    results_dir = Path('../experiments/results')
    results_dir.mkdir(parents=True, exist_ok=True)
    results.to_csv(results_dir / 'training_history.csv', index=False)

    print("✓ Results saved to experiments/results/training_history.csv")
    print("\nTraining Summary:")
    display(results)

## 12. Next Steps

**Experiment with different configurations:**
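1. Test different data regimes (5, 10, 20, 30, 50, 100, 200 cases). First remove the 100-image testing cap in Section 3: with the 80/20 split it leaves at most 80 training cases.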
2. Try transfer learning with pretrained encoders (ImageNet, RadImageNet, CT-FM)
3. Evaluate on ASOCA dataset for benchmark comparison
4. Implement 5-fold cross-validation
5. Generate publication-ready figures and tables

**Model improvements:**
- Experiment with deeper architectures
- Try attention mechanisms
- Test different loss functions (Focal, Tversky)
- Implement post-processing (connected components, morphological operations)