### Import Modules

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # %env CUDA_VISIBLE_DEVICES=0
import numpy as np
import pandas as pd
import nibabel as nib
import torch
import torch.nn as nn
from torch.utils.data import Subset
import time
import logging
import matplotlib.pyplot as plt
import sys
from sklearn.model_selection import train_test_split
from monai.data import Dataset, DataLoader
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    DivisiblePadd,
    Resized,
    ScaleIntensityd,
    RandAdjustContrastd,
    RandGaussianNoised
)
from monai.networks.nets import Regressor, ResNet, DenseNet121, SEResNet50, EfficientNetBN, ViT
from monai.metrics import MAEMetric
from monai.utils import first, set_determinism
import shap
from nilearn import plotting

### Define Functions and Classes

In [None]:
def calculate_required_divisibility(model_name, **divisibility_kwargs):
    """Return the factor the input spatial size must be divisible by, or None.

    Args:
        model_name: Name of the architecture (e.g. "Regressor", "ViT").
        **divisibility_kwargs: Optional model-specific settings. Only
            ``patch_size`` (int, list or tuple; default 16) is read, and only
            for "ViT".

    Returns:
        None for models treated as size-agnostic and for unknown names (a
        message is printed for the latter). For "ViT", the patch size; when
        ``patch_size`` is a list/tuple only its first element is returned.
    """
    size_agnostic = ("Regressor", "ResNet50", "DenseNet121", "SEResNet50", "EfficientNetB0", "SFCN")
    if model_name in size_agnostic:
        # No padding constraint is imposed for these architectures
        return None
    if model_name != "ViT":
        print(f"Unsupported model name: {model_name}")
        return None
    # ViT: input must be divisible by patch_size
    patch = divisibility_kwargs.get('patch_size', 16)
    return patch[0] if isinstance(patch, (list, tuple)) else patch

def _build_base_transforms(modalities, resize_dim, model_name, divisibility_params):
    """Return the deterministic transform list shared by train, validation and test."""
    transforms = [
        LoadImaged(keys=modalities, image_only=True),
        EnsureChannelFirstd(keys=modalities),
    ]
    # Padding or resizing goes before intensity normalization (ScaleIntensityd)
    if resize_dim is not None:
        transforms.append(Resized(keys=modalities, spatial_size=resize_dim, mode="trilinear"))
    else:
        required_k = calculate_required_divisibility(model_name, **divisibility_params)
        if required_k is not None and required_k > 1:
            transforms.append(DivisiblePadd(keys=modalities, k=required_k, method="symmetric"))
    transforms.append(ScaleIntensityd(keys=modalities, minv=0, maxv=1))
    return transforms

def load_data(data_dir, modalities, batch_size, resize_dim=None, test_size=0.2, inference=False, model_name=None, **divisibility_params):
    """Build MONAI data loaders from ``<data_dir>/Subjects.csv`` and per-modality images.

    Image paths are ``<data_dir>/<modality>/<ID as 3 digits>.nii.gz``.

    Args:
        data_dir: Directory containing ``Subjects.csv`` and one sub-directory per modality.
        modalities: List of modality names; used both as folder names and dictionary keys.
        batch_size: Batch size for the training/validation loaders (the test loader always uses 1).
        resize_dim: If given, images are resized to this spatial size; otherwise they are
            padded only when the model requires a divisibility factor.
        test_size: Validation fraction passed to ``train_test_split`` (fixed ``random_state=42``).
        inference: If True, build a single test loader and do not read the ``Age`` column.
        model_name: Model name forwarded to ``calculate_required_divisibility``.
        **divisibility_params: Extra settings forwarded to ``calculate_required_divisibility``.

    Returns:
        ``(train_loader, val_loader)`` when ``inference`` is False, otherwise
        ``(test_loader, subjects)`` where ``subjects`` is the array of zero-padded IDs.
    """
    df = pd.read_csv(os.path.join(data_dir, 'Subjects.csv'))
    subjects = df['ID'].apply(lambda x: f'{x:03d}').to_numpy()
    # Convert the label column once instead of on every loop iteration;
    # it is not touched at all in inference mode, where it may be absent.
    ages = None if inference else df['Age'].to_numpy()
    data_dicts = []
    for index, subject in enumerate(subjects):
        subject_dict = {
            modality: os.path.join(data_dir, modality, f"{subject}.nii.gz")
            for modality in modalities
        }
        if ages is not None:
            subject_dict['Age'] = ages[index]
        data_dicts.append(subject_dict)
    base_transforms = _build_base_transforms(modalities, resize_dim, model_name, divisibility_params)
    pin_memory = torch.cuda.is_available()
    if inference: # Test
        test_ds = Dataset(data=data_dicts, transform=Compose(base_transforms))
        test_loader = DataLoader(test_ds, batch_size=1, num_workers=0, pin_memory=pin_memory)
        return test_loader, subjects
    # Training/Validation
    # Training adds minimal intensity-based augmentation after normalization
    train_transforms = Compose(base_transforms + [
        RandAdjustContrastd(keys=modalities, prob=0.3, gamma=(0.9, 1.1)),
        RandGaussianNoised(keys=modalities, prob=0.3, std=0.01),
    ])
    # Validation uses the deterministic pipeline only
    val_transforms = Compose(base_transforms)
    # Split data into train and validation sets
    train_files, val_files = train_test_split(data_dicts, test_size=test_size, random_state=42)
    train_ds = Dataset(data=train_files, transform=train_transforms)
    val_ds = Dataset(data=val_files, transform=val_transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=0, pin_memory=pin_memory)
    return train_loader, val_loader

class SFCN(nn.Module): # https://github.com/ha-ha-ha-han/UKBiobank_deep_pretrain/blob/master/dp_model/model_files/sfcn.py
    """3D fully convolutional feature extractor followed by a small regression head.

    Every stage except the last is Conv3d(3x3x3, padding 1) -> BatchNorm3d ->
    MaxPool3d(2) -> ReLU; the last stage is Conv3d(1x1x1) -> BatchNorm3d -> ReLU.
    Features are globally average-pooled, projected to ``feature_dim`` with a
    1x1x1 convolution and passed through two linear layers.

    Args:
        input_channels: Number of channels of the input volume.
        output_dim: Size of the output vector.
        channel_number: Output channels of each convolutional stage (non-empty sequence).
        feature_dim: Width of the pooled feature vector and the hidden linear layer.
        dropout: If True, apply Dropout(0.5) after global pooling.

    Raises:
        ValueError: If ``channel_number`` is empty.
    """
    def __init__(self, input_channels, output_dim, channel_number=(32, 64, 128, 256, 256, 64), feature_dim=64, dropout=True):
        super().__init__()
        # Tuple default avoids a mutable default argument; lists are still accepted
        channel_number = list(channel_number)
        if not channel_number:
            raise ValueError("channel_number must contain at least one entry")
        n_layer = len(channel_number)
        self.feature_extractor = nn.Sequential()
        # Build convolutional layers
        for i, out_channel in enumerate(channel_number):
            in_channel = input_channels if i == 0 else channel_number[i-1]
            is_last = i == n_layer-1
            self.feature_extractor.add_module(
                f'conv_{i}',
                self.conv_layer(
                    in_channel, out_channel,
                    maxpool=not is_last,
                    kernel_size=1 if is_last else 3,
                    padding=0 if is_last else 1
                )
            )
        # Pooling and projection
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool3d((1, 1, 1)),
            nn.Dropout(0.5) if dropout else nn.Identity(),
            nn.Conv3d(channel_number[-1], feature_dim, kernel_size=1)
        )
        # Additional FC layers for refinement
        self.regressor = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(feature_dim, output_dim)
        )
    @staticmethod
    def conv_layer(in_channel, out_channel, maxpool=True, kernel_size=3, padding=0, maxpool_stride=2):
        """Return Conv3d -> BatchNorm3d -> [MaxPool3d(2)] -> ReLU as an ``nn.Sequential``."""
        modules = [
            nn.Conv3d(in_channel, out_channel, padding=padding, kernel_size=kernel_size),
            nn.BatchNorm3d(out_channel),
        ]
        if maxpool:
            modules.append(nn.MaxPool3d(2, stride=maxpool_stride))
        modules.append(nn.ReLU())
        return nn.Sequential(*modules)
    def forward(self, x):
        """Map a batch of volumes to a ``(batch, output_dim)`` tensor."""
        x = self.feature_extractor(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1) # Flatten the pooled 1x1x1 feature map
        x = self.regressor(x)
        return x

class _FirstOutput(nn.Module):
    """Wrap a network whose forward returns a tuple and expose only its first element."""
    def __init__(self, net):
        super().__init__()
        self.net = net
    def forward(self, x):
        return self.net(x)[0]

def get_model(model_name, num_input_channels, img_size):
    """Instantiate the requested single-output regression network.

    Args:
        model_name: One of "Regressor", "ResNet50", "DenseNet121", "SEResNet50",
            "EfficientNetB0", "ViT" or "SFCN".
        num_input_channels: Number of input channels (one per modality).
        img_size: Spatial size of the input volume; used by "Regressor" and "ViT".

    Returns:
        The constructed model. For "ViT" the network is wrapped so that calling
        it returns only the prediction tensor.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    if model_name == "Regressor":
        params = {
            'in_shape': [num_input_channels, *img_size], # Required: input shape
            'out_shape': 1, # Required: output dimension
            'channels': (16, 32, 64, 128, 256), # Required: sequence of feature channels
            'strides': (2, 2, 2, 2) # Required: sequence of convolution strides
        }
        model = Regressor(**params)
    elif model_name == "ResNet50":
        params = {
            'spatial_dims': 3, # Default: spatial dimensions
            'n_input_channels': num_input_channels, # Non-default: input channels
            'num_classes': 1, # Non-default: regression output
            'block': "bottleneck", # Required: bottleneck block for ResNet50
            'layers': [3, 4, 6, 3], # Required: layer configuration for ResNet50
            'block_inplanes': [64, 128, 256, 512] # Required: sequence of bottleneck internal channels; tunable with widen_factor (default=1.0)
            # Output channels = block_inplanes × widen_factor × expansion
            #                 = [64, 128, 256, 512] × 1.0 × 4
            #                 = [256, 512, 1024, 2048]
        }
        model = ResNet(**params)
    elif model_name == "DenseNet121":
        params = {
            'spatial_dims': 3, # Default: spatial dimensions
            'in_channels': num_input_channels, # Required: input channels
            'out_channels': 1 # Required: regression output
        }
        model = DenseNet121(**params)
    elif model_name == "SEResNet50":
        params = {
            'spatial_dims': 3, # Required: spatial dimensions
            'in_channels': num_input_channels, # Required: input channels
            'num_classes': 1 # Non-default: regression output
        }
        model = SEResNet50(**params)
    elif model_name == "EfficientNetB0":
        params = {
            'spatial_dims': 3, # Non-default: spatial dimensions
            'in_channels': num_input_channels, # Non-default: input channels
            'num_classes': 1, # Non-default: regression output
            'model_name': "efficientnet-b0" # Required: EfficientNet variant
        }
        model = EfficientNetBN(**params)
    elif model_name == "ViT":
        params = {
            'spatial_dims': 3, # Default: spatial dimensions
            'in_channels': num_input_channels, # Required: input channels
            'num_classes': 1, # Non-default: regression output
            'img_size': img_size, # Required: input image size
            'patch_size': 16, # Required: spatial patch size for tokenization
            'hidden_size': 768, # Default: transformer embedding dimension
            'mlp_dim': 3072, # Default: MLP dimension (typically 4 × hidden_size)
            'num_layers': 12, # Default: number of transformer blocks
            'num_heads': 12, # Default: number of attention heads
            'classification': True, # Non-default: add prediction head (if False, feature extraction only)
            # NOTE(review): per the MONAI ViT docs the head ends in Tanh by default, which
            # would bound predicted ages to [-1, 1]; any other value removes it. Confirm
            # against the installed MONAI version.
            'post_activation': None
        }
        # MONAI's ViT forward returns (output, hidden_states); the training, SHAP and
        # inference code expect a plain tensor, so keep only the first element.
        model = _FirstOutput(ViT(**params))
    elif model_name == "SFCN":
        params = {
            'input_channels': num_input_channels, # Required: input channels
            'output_dim': 1, # Required: regression output
            'channel_number': [32, 64, 128, 256, 256, 64], # Default: sequence of feature channels
            'feature_dim': 64, # Default: refined feature dimension
            'dropout': True # Default: use dropout for regularization
        }
        model = SFCN(**params)
    else:
        raise ValueError(f"Unsupported model name: {model_name}")
    return model

def get_grad_scaler(device):
    """Return a gradient scaler for mixed-precision training, or None off-GPU.

    The constructors are tried from the newest API location to the oldest so
    the function works across PyTorch releases; an ``AttributeError`` or
    ``TypeError`` from one candidate moves on to the next.

    Args:
        device: ``torch.device`` the model is trained on.

    Returns:
        A GradScaler instance when ``device.type`` is "cuda", otherwise None.
    """
    if device.type != "cuda":
        return None
    candidates = (
        lambda: torch.GradScaler("cuda"),      # top-level API
        lambda: torch.amp.GradScaler("cuda"),  # torch.amp API
    )
    for make_scaler in candidates:
        try:
            return make_scaler()
        except (AttributeError, TypeError):
            continue
    # Fall back to the CUDA-specific (old) API
    return torch.cuda.amp.GradScaler()

def get_autocast_context(device, enabled=True):
    """Return a context manager for the mixed-precision forward pass.

    Args:
        device: ``torch.device`` the computation runs on.
        enabled: If False, a no-op context is returned.

    Returns:
        An autocast context (float16 via ``torch.autocast`` when available,
        otherwise the older ``torch.amp`` / ``torch.cuda.amp`` variants), or a
        ``nullcontext`` when disabled or when no autocast API can be used on a
        non-CUDA device.
    """
    from contextlib import nullcontext

    if not enabled:
        return nullcontext()
    # Preferred: device-agnostic top-level API
    try:
        return torch.autocast(device_type=device.type, dtype=torch.float16)
    except (AttributeError, TypeError):
        pass
    # Next: torch.amp API
    try:
        return torch.amp.autocast(device.type)
    except (AttributeError, TypeError):
        pass
    # Last resort: CUDA-specific (old) API, or nothing at all off-GPU
    return torch.cuda.amp.autocast() if device.type == "cuda" else nullcontext()

def train_one_epoch(model, device, train_loader, modalities, optimizer, criterion, scaler, metric):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train (switched to training mode).
        device: Device the batches are moved to.
        train_loader: Iterable of batch dictionaries holding one entry per
            modality plus an ``'Age'`` entry.
        modalities: Keys whose tensors are concatenated along dim 1 to form the input.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        scaler: Gradient scaler, or None to train without mixed precision.
        metric: Metric object; it is reset, fed every batch and aggregated.

    Returns:
        Tuple ``(mean loss per batch, aggregated metric value)``.
    """
    model.train() # Set model to training mode
    metric.reset()
    use_amp = scaler is not None
    running_loss = 0.0
    for batch in train_loader:
        # Stack the modalities as channels and shape the labels as (batch, 1)
        inputs = torch.cat([batch[name].to(device) for name in modalities], dim=1)
        targets = batch['Age'].unsqueeze(1).to(device)
        optimizer.zero_grad()
        # Forward pass, under autocast only when a scaler is in use
        with get_autocast_context(device, enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        # Backward pass, with gradient scaling when available
        if not scaler:
            loss.backward()
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        # Accumulate loss and metric
        running_loss += loss.item()
        metric(y_pred=outputs, y=targets)
    mean_loss = running_loss / len(train_loader)
    return mean_loss, metric.aggregate().item()

def validate_one_epoch(model, device, val_loader, modalities, metric):
    """Evaluate ``model`` on ``val_loader`` and return the aggregated metric.

    Args:
        model: Network to evaluate (switched to evaluation mode).
        device: Device the batches are moved to.
        val_loader: Iterable of batch dictionaries holding one entry per
            modality plus an ``'Age'`` entry.
        modalities: Keys whose tensors are concatenated along dim 1 to form the input.
        metric: Metric object; it is reset, fed every batch and aggregated.

    Returns:
        The aggregated metric as a Python number.
    """
    model.eval()
    metric.reset()
    with torch.no_grad():
        for batch in val_loader:
            labels = batch['Age'].unsqueeze(1).to(device)
            stacked = torch.cat([batch[name].to(device) for name in modalities], dim=1)
            predictions = model(stacked)
            metric(y_pred=predictions, y=labels)
    aggregated = metric.aggregate()
    return aggregated.item()

class EarlyStopping:
    """Signal when a lower-is-better metric has stopped improving.

    Call the instance once per validation round with the latest metric value.
    ``early_stop`` becomes True after ``patience`` consecutive calls without an
    improvement of more than ``delta`` over the best value seen so far.

    Args:
        patience: Number of consecutive non-improving calls to tolerate.
        delta: Minimum decrease of the metric that counts as an improvement.
    """
    def __init__(self, patience=30, delta=0):
        self.patience = patience # Number of epochs to wait before stopping
        self.delta = delta # Minimum improvement threshold
        self.best_score = None # Lowest metric value seen so far
        self.early_stop = False
        self.counter = 0 # Consecutive calls without improvement
    def __call__(self, metric):
        score = metric
        if self.best_score is None: # First epoch
            self.best_score = score
        elif score < self.best_score - self.delta: # Metric improved (decreased by more than delta)
            self.best_score = score
            self.counter = 0
        else: # No sufficient improvement; a NaN score also lands here
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

def train_model(model_dir, model, device, train_loader, val_loader, modalities, logger,
        criterion, metric, max_epochs=100, learning_rate=1e-4, weight_decay=1e-5, val_interval=1, es_patience=30):
    """Train ``model`` with Adam and cosine annealing, keeping the best validation weights.

    Args:
        model_dir: Directory where ``BestMetricModel.pth`` is written.
        model: Network to train.
        device: Device used for training.
        train_loader: Loader of training batches.
        val_loader: Loader of validation batches.
        modalities: Keys whose tensors form the model input.
        logger: Logger receiving per-epoch and summary messages.
        criterion: Loss function.
        metric: Metric object shared by training and validation (lower is better).
        max_epochs: Maximum number of epochs.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        val_interval: Validate every ``val_interval`` epochs.
        es_patience: Early-stopping patience, counted in validation rounds.

    Returns:
        Tuple ``(model, epoch_loss_values, epoch_metric_values, metric_values)``;
        ``model`` holds the best validation weights when at least one validation ran.
    """
    # Setup optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
    scaler = get_grad_scaler(device)
    start_time = time.time()
    best_metric = float("inf")
    best_metric_epoch = -1
    best_model_state = None
    early_stopping = EarlyStopping(patience=es_patience, delta=0)
    epoch_loss_values, epoch_metric_values, metric_values = [], [], []
    for epoch in range(max_epochs):
        epoch_start_time = time.time()
        # Training phase
        epoch_loss, epoch_metric = train_one_epoch(model, device, train_loader, modalities, optimizer, criterion, scaler, metric)
        epoch_loss_values.append(epoch_loss)
        epoch_metric_values.append(epoch_metric)
        # Validation phase; stays None on epochs without validation so the log
        # never reports an unbound or stale value
        val_metric = None
        if (epoch + 1) % val_interval == 0:
            val_metric = validate_one_epoch(model, device, val_loader, modalities, metric)
            metric_values.append(val_metric)
            # Save best model
            if val_metric < best_metric:
                best_metric = val_metric
                best_metric_epoch = epoch + 1
                # state_dict() returns references to the live tensors, which later
                # optimizer steps overwrite; take an independent CPU snapshot instead
                best_model_state = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in model.state_dict().items()
                }
                torch.save(best_model_state, os.path.join(model_dir, "BestMetricModel.pth"))
                logger.info(f"Best MAE: {best_metric:.4f} at epoch {best_metric_epoch}")
            # Check early stopping
            early_stopping(val_metric)
            if early_stopping.early_stop:
                logger.info(f"Early stopping triggered at epoch {epoch + 1}")
                print(f"; Early stopping triggered at epoch {epoch + 1}", end="")
                break
        epoch_end_time = time.time()
        val_text = f"{val_metric:.4f}" if val_metric is not None else "n/a"
        logger.info(
            f"Epoch {epoch + 1} completed for {(epoch_end_time - epoch_start_time)/60:.2f} mins - "
            f"Training loss: {epoch_loss:.4f}, Training MAE: {epoch_metric:.4f}, Validation MAE: {val_text}"
        )
        # Update learning rate
        lr_scheduler.step()
        sys.stdout.write(f"\rEpoch {epoch + 1}/{max_epochs} completed")
        sys.stdout.flush()
    end_time = time.time()
    total_time = end_time - start_time
    logger.info(
        f"Best MAE: {best_metric:.3f} at epoch {best_metric_epoch}; "
        f"Total time consumed: {total_time/60:.2f} mins"
    )
    print(
        f"\nBest MAE: {best_metric:.3f} at epoch {best_metric_epoch}; "
        f"Total time consumed: {total_time/60:.2f} mins"
    )
    # Load best model weights
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, epoch_loss_values, epoch_metric_values, metric_values

def plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval=1):
    """Plot training loss and training/validation MAE, saving ``Performance.png``.

    Args:
        model_dir: Directory the figure is written to.
        epoch_loss_values: Training loss per epoch.
        epoch_metric_values: Training MAE per epoch.
        metric_values: Validation MAE per validation round.
        val_interval: Epochs between validation rounds; places the validation
            points on the epoch axis.
    """
    loss_epochs = list(range(1, len(epoch_loss_values) + 1))
    train_epochs = list(range(1, len(epoch_metric_values) + 1))
    val_epochs = [val_interval * k for k in range(1, len(metric_values) + 1)]
    _, (loss_ax, mae_ax) = plt.subplots(1, 2, figsize=(12, 5))
    # Left panel: training loss
    loss_ax.plot(loss_epochs, epoch_loss_values, label='Training Loss', color='red')
    loss_ax.set_title('Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    # Right panel: training vs. validation MAE
    mae_ax.plot(train_epochs, epoch_metric_values, label='Training MAE', color='red')
    mae_ax.plot(val_epochs, metric_values, label='Validation MAE', color='blue')
    mae_ax.set_title('Training MAE vs. Validation MAE')
    mae_ax.set_xlabel('Epoch')
    mae_ax.set_ylabel('MAE')
    mae_ax.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(model_dir, "Performance.png"), dpi=300)

def _collect_ages(dataset):
    """Return the ``'Age'`` of every sample, reading raw records when the dataset exposes them."""
    records = getattr(dataset, "data", None)
    if records is not None:
        try:
            if len(records) == len(dataset):
                # Reading the raw records avoids running the dataset's transforms
                # (image loading) just to obtain a scalar label
                return [record['Age'] for record in records]
        except (TypeError, KeyError, IndexError):
            pass # ``data`` is not a sequence of label dictionaries; index the dataset instead
    return [dataset[idx]['Age'] for idx in range(len(dataset))]

def select_representative_samples(dataset, age_thresholds=(40, 60)):
    """Pick, for each age group, the sample whose age is closest to the group median.

    Groups are Young (age < first threshold), Middle (between the thresholds,
    lower bound inclusive) and Old (age >= second threshold).

    Args:
        dataset: Indexable collection of dictionaries with an ``'Age'`` entry.
            If it has a ``data`` attribute holding the raw records, ages are
            read from there.
        age_thresholds: ``(young_threshold, old_threshold)``.

    Returns:
        Tuple ``(representative_samples, ages)``. ``representative_samples``
        maps each non-empty group name to a dictionary with ``index``, ``age``,
        ``group_median``, ``group_range`` and ``group_count``; ``ages`` is the
        array of all ages in dataset order. Empty groups are skipped with a
        printed warning.
    """
    ages = np.array(_collect_ages(dataset))
    indices = np.arange(len(ages))
    young_thresh, old_thresh = age_thresholds
    # Define groups
    groups = {
        'Young': (ages < young_thresh),
        'Middle': (ages >= young_thresh) & (ages < old_thresh),
        'Old': (ages >= old_thresh)
    }
    representative_samples = {}
    for group_name, mask in groups.items():
        if not mask.any():
            print(f"Warning: No samples in {group_name} group")
            continue
        group_ages = ages[mask]
        group_indices = indices[mask]
        # Find sample closest to median (first one on ties)
        group_median = np.median(group_ages)
        closest_idx = np.argmin(np.abs(group_ages - group_median))
        representative_samples[group_name] = {
            'index': group_indices[closest_idx],
            'age': group_ages[closest_idx],
            'group_median': group_median,
            'group_range': (group_ages.min(), group_ages.max()),
            'group_count': len(group_ages)
        }
    return representative_samples, ages

class SHAP3D:
    """Compute and visualise SHAP attributions for a model fed with stacked 3D modalities.

    Typical use: ``create_background`` -> ``compute_shap_values`` -> ``visualize_shap``.

    Args:
        model: Model passed to ``shap.GradientExplainer``.
        modalities: Keys whose tensors are concatenated along dim 1 to form the input.
        device: Device the input tensors are moved to.
        batch_size: Batch size of the loaders created for random subsets.
    """
    def __init__(self, model, modalities, device, batch_size=1):
        self.model = model
        self.modalities = modalities
        self.device = device
        self.batch_size = batch_size
        self.explainer = None # Set by create_background
        self.shap_values = None # Set by compute_shap_values
        self.images = None # Set by compute_shap_values
    def _subset_loader(self, dl, num_samples):
        """Return ``dl`` itself, or a loader over ``num_samples`` randomly drawn items of its dataset."""
        if num_samples is None:
            return dl
        dataset = dl.dataset
        available = len(dataset)
        if num_samples > available:
            # np.random.choice(..., replace=False) raises when asked for more items than exist
            print(f"Warning: requested {num_samples} samples but only {available} available; using all of them")
            num_samples = available
        subset_indices = np.random.choice(available, num_samples, replace=False)
        subset_ds = Subset(dataset, subset_indices)
        return DataLoader(subset_ds, batch_size=self.batch_size, num_workers=0, pin_memory=torch.cuda.is_available())
    def _stack_modalities(self, batch_data):
        """Move each modality tensor to the device and concatenate them along dim 1."""
        images = [batch_data[modality].to(self.device) for modality in self.modalities]
        return torch.cat(images, dim=1)
    def create_background(self, dl, num_samples=20): # Create background dataset for SHAP
        """Build the explainer from ``num_samples`` random items of ``dl`` (all items if None)."""
        background_images = [self._stack_modalities(batch_data) for batch_data in self._subset_loader(dl, num_samples)]
        background_images = torch.cat(background_images, dim=0)
        self.explainer = shap.GradientExplainer(self.model, background_images)
    def compute_shap_values(self, dl, num_samples=None): # Compute SHAP values for samples
        """Compute SHAP values for ``num_samples`` random items of ``dl`` (all items if None).

        Results are stored in ``self.shap_values`` and ``self.images``.

        Raises:
            RuntimeError: If ``create_background`` has not been called.
        """
        if self.explainer is None:
            raise RuntimeError("Call create_background() before compute_shap_values()")
        shap_values_list = []
        images_list = []
        for batch_data in self._subset_loader(dl, num_samples):
            inputs = self._stack_modalities(batch_data)
            shap_values_list.append(self.explainer.shap_values(inputs))
            images_list.append(inputs.cpu().numpy())
        self.shap_values = np.concatenate(shap_values_list, axis=0)
        self.images = np.concatenate(images_list, axis=0)
    def visualize_shap(self, sample_img_path, shap_dir, vmin=-0.0025, vmax=0.0025): # Visualize mean absolute SHAP values using glass brain plots
        """Save per-modality mean absolute SHAP maps as NIfTI files and a glass-brain figure.

        Args:
            sample_img_path: Image whose affine and header are reused for the saved maps.
            shap_dir: Output directory (created if missing).
            vmin: Lower colour limit of the glass-brain plot.
            vmax: Upper colour limit of the glass-brain plot.

        Raises:
            RuntimeError: If ``compute_shap_values`` has not been called.
        """
        if self.shap_values is None or self.images is None:
            raise RuntimeError("Call compute_shap_values() before visualize_shap()")
        reference_img = nib.load(sample_img_path)
        os.makedirs(shap_dir, exist_ok=True)
        _, axs = plt.subplots(len(self.modalities), 1, figsize=(12, len(self.modalities) * 4))
        if len(self.modalities) == 1:
            axs = [axs]
        for i, modality in enumerate(self.modalities):
            # Extract SHAP values for this modality; the indexing assumes a trailing
            # model-output axis, of which the first entry is used
            shap_values = self.shap_values[:, i, :, :, :, 0]
            feature_values = self.images[:, i, :, :, :]
            # Compute mean absolute SHAP values
            mean_abs_shap = np.mean(np.abs(shap_values), axis=0)
            # Mask non-brain regions: keep voxels that are non-zero in every sample
            common_mask = np.all(feature_values != 0, axis=0)
            masked_shap = np.zeros_like(mean_abs_shap)
            masked_shap[common_mask] = mean_abs_shap[common_mask]
            # Save as NIfTI
            shap_img = nib.Nifti1Image(
                masked_shap,
                affine=reference_img.affine,
                header=reference_img.header
            )
            shap_img.header['descrip'] = 'Mean absolute SHAP values'
            nib.save(shap_img, os.path.join(shap_dir, f"MeanAbsSHAPValues_{modality}.nii.gz"))
            # Plot glass brain
            plotting.plot_glass_brain(
                shap_img, threshold=None, annotate=False,
                plot_abs=False, black_bg='auto', axes=axs[i],
                colorbar=True, cmap='black_red', symmetric_cbar=False,
                alpha=0.3, vmin=vmin, vmax=vmax
            )
            axs[i].set_title(f"{modality}")
        plt.tight_layout()
        plt.savefig(os.path.join(shap_dir, "SHAP_GlassBrain.png"), dpi=300)
        plt.show()

def apply_best_model(model_dir, model, device, test_loader, modalities, pred_dir, subjects):
    """Load the best checkpoint, predict every test batch and write ``PredictedAge.csv``.

    Args:
        model_dir: Directory containing ``BestMetricModel.pth``.
        model: Network whose weights are replaced by the checkpoint.
        device: Device used for inference; the checkpoint is mapped onto it.
        test_loader: Iterable of batch dictionaries holding one entry per modality.
        modalities: Keys whose tensors are concatenated along dim 1 to form the input.
        pred_dir: Output directory (created if missing).
        subjects: Subject identifiers, one per prediction, in loader order.
    """
    # Load best model weights; map_location lets a checkpoint saved on GPU be
    # restored on a CPU-only machine
    checkpoint_path = os.path.join(model_dir, "BestMetricModel.pth")
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    os.makedirs(pred_dir, exist_ok=True)
    pred_values = []
    with torch.no_grad():
        for batch_data in test_loader:
            images = [batch_data[modality].to(device) for modality in modalities]
            inputs = torch.cat(images, dim=1)
            # Model inference
            outputs = model(inputs)
            pred_values.extend(outputs.cpu().numpy().flatten())
    # Save predictions
    pred_df = pd.DataFrame({
        'ID': subjects,
        'PredictedAge': pred_values
    })
    pred_df.to_csv(os.path.join(pred_dir, "PredictedAge.csv"), index=False)

### Prepare Inputs

In [None]:
# Dataset location and experiment configuration
data_dir = os.path.join("AgePrediction", "Datasets")
model_dir_prefix = "AgePrediction"
model_name = "Regressor" # any supported model name: Regressor, ResNet50, DenseNet121, SEResNet50, EfficientNetB0, ViT, SFCN
modalities = ["GM", "WM", "CSF"]
resize_dim = None # Use padding or specify tuple for resizing
test_size = 0.2
batch_size = 5
max_epochs = 100
learning_rate = 1e-4
weight_decay = 1e-5
val_interval = 1
es_patience = 30

# Setup output directory and logging
model_dir = f"{model_dir_prefix}_{model_name}_{'+'.join(modalities)}"
os.makedirs(model_dir, exist_ok=True)
log_file = os.path.join(model_dir, "Training.log")
# force=True replaces any handlers already attached to the root logger; without it
# basicConfig is a silent no-op when the cell is re-run (e.g. with another model_name)
# and messages keep going to the previous log file
logging.basicConfig(filename=log_file, level=logging.INFO, format="%(message)s", force=True)
logger = logging.getLogger()

### Read Data

In [None]:
# Fix the random seeds, then build the training and validation loaders
set_determinism(seed=0)
train_loader, val_loader = load_data(
    os.path.join(data_dir, "train"),
    modalities,
    batch_size,
    resize_dim=resize_dim,
    test_size=test_size,
    inference=False,
    model_name=model_name,
)

# Check data shape: print the shape of every entry of the first batch and the number of batches
tr = first(train_loader)
img_size = tuple(tr[modalities[0]].shape[-3:])
print('\nData shape for training:')
for key, value in tr.items():
    print(f'\u2022 {key}: {tuple(value.shape)} × {len(train_loader)}')
vl = first(val_loader)
print('\nData shape for validation:')
for key, value in vl.items():
    print(f'\u2022 {key}: {tuple(value.shape)} × {len(val_loader)}')

# Visualize data with adaptive sizing: fewer panels get more room each
slice_index = img_size[2] // 2  # Middle slice
num_modalities = len(modalities)
if num_modalities <= 2:
    panel_width, fig_height = 6, 6
elif num_modalities <= 4:
    panel_width, fig_height = 4, 5
else:
    panel_width, fig_height = 3, 4
fig_width = num_modalities * panel_width
_, axs = plt.subplots(1, num_modalities, figsize=(fig_width, fig_height))
if num_modalities == 1:
    axs = [axs]  # A single Axes is returned for one panel; wrap it so indexing works
for i, modality in enumerate(modalities):
    # First sample, first channel of the batch; rotate the slice before display
    image = tr[modality][0, 0, :, :, :].detach().cpu()
    img_slice = torch.rot90(image[:, :, slice_index], k=1, dims=(0, 1))
    panel = axs[i]
    panel.imshow(img_slice, cmap='gray')
    panel.set_title(modality)
    panel.axis('off')
plt.tight_layout()
plt.show()

### Train Model

In [None]:
# Build the model and move it to the training device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model(model_name, len(modalities), img_size)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(device)
print(f"Selected model: {model_name}")
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params:,}")
criterion = nn.L1Loss()
metric = MAEMetric(reduction="mean")
model, epoch_loss_values, epoch_metric_values, metric_values = train_model(
    model_dir, model, device, train_loader, val_loader, modalities, logger,
    criterion, metric, max_epochs, learning_rate, weight_decay, val_interval, es_patience
)
plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval)

# Visualize outcome: age histogram with one representative sample per age group
representative_samples, all_ages = select_representative_samples(val_loader.dataset)
group_colors = {'Young': 'green', 'Middle': 'orange', 'Old': 'red'}
plt.figure(figsize=(10, 4))
plt.hist(all_ages, bins=20, alpha=0.7, edgecolor='black')
for group_name, info in representative_samples.items():
    plt.axvline(
        info['age'], color=group_colors[group_name], linestyle='--', linewidth=2,
        label=f"{group_name}: {info['age']:.1f} yrs"
    )
plt.xlabel('Age (years)')
plt.ylabel('Count')
plt.title('Age Distribution with Representative Samples')
plt.legend()
plt.tight_layout()
plt.show()

# Predict each representative sample once and show its middle slice per modality
model.eval()
sample_info = sorted(
    ((info['index'], info['age'], group_name) for group_name, info in representative_samples.items()),
    key=lambda x: x[1]
)
num_samples = len(sample_info)
num_modalities = len(modalities)
slice_index = img_size[2] // 2  # Middle slice
if num_samples == 0:
    print("No representative samples available; skipping prediction figure")
else:
    # squeeze=False always yields a 2D array of Axes, including the 1x1 case
    fig, axs = plt.subplots(
        num_samples, num_modalities,
        figsize=(num_modalities * 4, num_samples * 4), squeeze=False
    )
    title_lines = []
    for row, (sample_idx, true_age, group_name) in enumerate(sample_info):
        sample = val_loader.dataset[sample_idx]
        with torch.no_grad():
            images = [sample[modality].unsqueeze(0).to(device) for modality in modalities]
            inputs = torch.cat(images, dim=1)
            pred_age = model(inputs).item()
        error = pred_age - true_age
        title_lines.append(
            f'{group_name}: True={true_age:.1f} yrs, Pred={pred_age:.1f} yrs, Error={error:+.1f} yrs'
        )
        for col, modality in enumerate(modalities):
            ax = axs[row, col]
            image = images[col][0, 0, :, :, :].detach().cpu()
            img_slice = torch.rot90(image[:, :, slice_index], k=1, dims=(0, 1))
            ax.imshow(img_slice, cmap='gray')
            ax.set_title(modality)
            # Hide ticks and frame instead of calling axis('off'), which would also
            # hide the row label set below
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if col == 0:
                ax.set_ylabel(f'{group_name}\n({true_age:.0f} yrs)', fontsize=12)
    plt.suptitle('\n'.join(title_lines), fontsize=11, y=1.02)
    plt.tight_layout()
    plt.show()

### SHAP

In [None]:
# Output location, reference image (supplies affine/header for the saved maps) and colour range
shap_dir = os.path.join(model_dir, "SHAP")
sample_img_path = os.path.join(data_dir, 'train', modalities[0], '001.nii.gz')
vmin, vmax = 0, 0.01

# Background from training data, attributions for validation data, then plot
shap_analyzer = SHAP3D(model, modalities, device, batch_size=1)
shap_analyzer.create_background(train_loader, num_samples=10)
shap_analyzer.compute_shap_values(val_loader, num_samples=10)
shap_analyzer.visualize_shap(sample_img_path, shap_dir, vmin=vmin, vmax=vmax)

### Inference

In [None]:
# Build the test loader (batch size and split fraction are unused in inference mode)
pred_dir = os.path.join(model_dir, "Prediction")
test_loader, subjects = load_data(
    data_dir=os.path.join(data_dir, "test"),
    modalities=modalities,
    batch_size=None,
    resize_dim=resize_dim,
    test_size=None,
    inference=True,
    model_name=model_name,
)
# Predict with the best checkpoint and write the predictions to pred_dir
apply_best_model(
    model_dir=model_dir,
    model=model,
    device=device,
    test_loader=test_loader,
    modalities=modalities,
    pred_dir=pred_dir,
    subjects=subjects,
)