### Import Modules

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Subset
import time
import logging
import matplotlib.pyplot as plt
import sys
from sklearn.model_selection import train_test_split
from monai.data import Dataset, DataLoader
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityd
)
from monai.networks.nets import Regressor, ResNet, DenseNet121, DenseNet169, SEResNet50, SENet154, EfficientNetBN, ViT
from monai.metrics import MAEMetric
from monai.utils import first, set_determinism
import shap
import nibabel as nib
from nilearn import plotting
%env CUDA_VISIBLE_DEVICES=0

### Define Functions and Classes

In [None]:
class CustomDataset(Dataset):
    """Dataset wrapper that applies an optional transform on item access.

    Args:
        data: indexable collection of per-subject items.
        transforms: optional callable applied to an item every time it is
            fetched; when falsy, items are returned untouched.
    """

    def __init__(self, data, transforms=None):
        # NOTE(review): the parent ``Dataset.__init__`` is not invoked; only
        # the two attributes below are set on the instance.
        self.data = data
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Without a transform the stored item is handed back as-is.
        if not self.transforms:
            return self.data[idx]
        return self.transforms(self.data[idx])

def load_data(data_dir, modalities, additional_variables, batch_size, test_size=0.2, inference=False):
    """Build data loaders from ``Subjects.csv`` and per-modality NIfTI files.

    ``Subjects.csv`` inside ``data_dir`` must provide a ``No`` column (subject
    number, zero-padded to three digits to form the file name) and one column
    per entry of ``additional_variables``. Images are expected at
    ``<data_dir>/<modality>/<No>.nii.gz``.

    Args:
        data_dir: directory holding ``Subjects.csv`` and the modality folders.
        modalities: names of the image modalities (also used as dict keys).
        additional_variables: CSV columns copied into every sample. ``'Age'``
            is skipped when ``inference`` is true.
        batch_size: batch size for every returned loader.
        test_size: validation fraction forwarded to ``train_test_split``
            (ignored when ``inference`` is true).
        inference: when true, return a single un-shuffled loader.

    Returns:
        ``(test_loader, subjects)`` when ``inference`` is true, otherwise
        ``(train_loader, val_loader)``.
    """
    df = pd.read_csv(os.path.join(data_dir, 'Subjects.csv'))
    subjects = df['No'].apply(lambda x: f'{x:03d}').to_numpy()

    # 'Age' is the regression target, so it is not read at inference time.
    wanted_variables = [v for v in additional_variables if not (inference and v == 'Age')]
    # Convert each column once instead of once per subject.
    columns = {variable: df[variable].to_numpy() for variable in wanted_variables}

    data_dicts = []
    for index, subject in enumerate(subjects):
        subject_dict = {
            modality: os.path.join(data_dir, modality, f"{subject}.nii.gz")
            for modality in modalities
        }
        for variable in wanted_variables:
            subject_dict[variable] = columns[variable][index]
        data_dicts.append(subject_dict)

    img_transforms = Compose([
        LoadImaged(keys=modalities, image_only=True),
        EnsureChannelFirstd(keys=modalities),
        ScaleIntensityd(keys=modalities, minv=0, maxv=1)
    ])
    ds = CustomDataset(data=data_dicts, transforms=img_transforms)
    pin_memory = torch.cuda.is_available()
    if inference:
        test_loader = DataLoader(ds, batch_size=batch_size, num_workers=0, pin_memory=pin_memory)
        return test_loader, subjects

    # NOTE(review): train_test_split indexes ``ds`` directly, so the split
    # presumably holds already-transformed samples in memory -- confirm that
    # this eager loading is intended.
    train_ds, val_ds = train_test_split(ds, test_size=test_size, random_state=42)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=0, pin_memory=pin_memory)
    return train_loader, val_loader

class ViTFeatureExtractor(nn.Module):
    """ViT backbone followed by a linear projection to ``img_features``.

    Args:
        input_channels: number of channels of the input volume.
        img_size: spatial size of the input volume.
        img_features: size of the returned feature vector.
    """

    def __init__(self, input_channels, img_size, img_features=64):
        super().__init__()
        # NOTE(review): 16^3 patches presumably require every dimension of
        # ``img_size`` to be divisible by 16 -- confirm against the ViT docs.
        self.vit = ViT(
            in_channels=input_channels,
            img_size=img_size,
            patch_size=(16, 16, 16),
            classification=False,
        )
        # 768 is assumed to match the hidden size of the ViT built above.
        self.fc = nn.Linear(768, img_features)

    def forward(self, x):
        # The first element of the ViT output is the token sequence.
        tokens = self.vit(x)[0]
        # NOTE(review): with ``classification=False`` there is presumably no
        # class token, so index 0 would be the first patch token -- confirm
        # that this (rather than pooling over all tokens) is intended.
        first_token = tokens[:, 0, :]
        return self.fc(first_token)

class SFCN(nn.Module):
    """Fully convolutional 3D network with an average-pool + 1x1x1 conv head.

    Args:
        input_channels: number of channels of the input volume.
        channel_number: output channels of each convolutional stage. All
            stages but the last use a 3x3x3 convolution followed by max
            pooling; the last one is a 1x1x1 convolution without pooling.
        output_dim: number of output values per sample.
        dropout: when exactly ``True``, a ``Dropout(0.5)`` is placed before
            the final convolution.
        avg_shape: kernel size of the average pooling in the head. The
            default ``(4, 5, 4)`` keeps the previously hard-coded value; pass
            another size for inputs with a different final feature-map shape.
    """

    def __init__(self, input_channels=1, channel_number=(32, 64, 128, 256, 256, 64), output_dim=1, dropout=True,
                 avg_shape=(4, 5, 4)):
        super().__init__()
        # Work on a private list so the (now immutable) default and any
        # caller-owned sequence are never shared or modified.
        channel_number = list(channel_number)
        last_index = len(channel_number) - 1
        in_channels = [input_channels] + channel_number[:-1]

        self.feature_extractor = nn.Sequential()
        for i, (in_channel, out_channel) in enumerate(zip(in_channels, channel_number)):
            is_last = i == last_index
            self.feature_extractor.add_module(
                'conv_%d' % i,
                self.conv_layer(in_channel,
                                out_channel,
                                maxpool=not is_last,
                                kernel_size=1 if is_last else 3,
                                padding=0 if is_last else 1))

        self.regressor = nn.Sequential()
        self.regressor.add_module('average_pool', nn.AvgPool3d(avg_shape))
        if dropout is True:
            self.regressor.add_module('dropout', nn.Dropout(0.5))
        self.regressor.add_module('final_conv', nn.Conv3d(channel_number[-1], output_dim, padding=0, kernel_size=1))

    @staticmethod
    def conv_layer(in_channel, out_channel, maxpool=True, kernel_size=3, padding=0, maxpool_stride=2):
        """Return a Conv3d -> BatchNorm3d -> [MaxPool3d] -> ReLU stage."""
        modules = [
            nn.Conv3d(in_channel, out_channel, padding=padding, kernel_size=kernel_size),
            nn.BatchNorm3d(out_channel),
        ]
        if maxpool is True:
            modules.append(nn.MaxPool3d(2, stride=maxpool_stride))
        modules.append(nn.ReLU())
        return nn.Sequential(*modules)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.regressor(x)
        # Flatten everything but the batch dimension.
        x = x.view(x.size(0), -1)
        return x

class CNN3D(nn.Module):
    """Four-stage 3D CNN with global average pooling and a two-layer head.

    Every stage applies convolution -> ReLU -> max pooling -> batch norm;
    stages two to four additionally apply channel dropout. The head is
    ``fc1 -> ReLU -> dropout -> fc2 -> ReLU``, so outputs are non-negative.

    Args:
        input_channels: number of channels of the input volume.
        output_dim: size of the output vector.
    """

    def __init__(self, input_channels=1, output_dim=1):
        super().__init__()
        # Layers shared by all stages
        self.relu = nn.ReLU()
        # Stage 1
        self.conv1 = nn.Conv3d(input_channels, out_channels=16, kernel_size=6, padding=2)
        self.pool = nn.MaxPool3d(2)
        self.bn1 = nn.BatchNorm3d(16)
        # Stage 2
        self.conv2 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=6, padding=2)
        self.bn2 = nn.BatchNorm3d(32)
        self.dropout1 = nn.Dropout3d(0.2)
        # Stage 3
        self.conv3 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=6, padding=2)
        self.bn3 = nn.BatchNorm3d(64)
        self.dropout2 = nn.Dropout3d(0.2)
        # Stage 4
        self.conv4 = nn.Conv3d(in_channels=64, out_channels=128, kernel_size=6, padding=2)
        self.bn4 = nn.BatchNorm3d(128)
        self.dropout3 = nn.Dropout3d(0.2)
        # Head
        self.global_avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Linear(128, 128)
        self.dropout4 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, output_dim)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, None),
            (self.conv2, self.bn2, self.dropout1),
            (self.conv3, self.bn3, self.dropout2),
            (self.conv4, self.bn4, self.dropout3),
        )
        for conv, norm, drop in stages:
            x = norm(self.pool(self.relu(conv(x))))
            if drop is not None:
                x = drop(x)
        pooled = self.global_avg_pool(x)
        flat = pooled.view(pooled.size(0), -1)
        hidden = self.dropout4(self.relu(self.fc1(flat)))
        return self.relu(self.fc2(hidden))

def get_model(model_name, input_channels, img_size, img_features):
    """Instantiate the image feature extractor selected by ``model_name``.

    Args:
        model_name: one of "Regressor", "ResNet18", "ResNet50", "DenseNet121",
            "DenseNet169", "SEResNet50", "SENet154", "EfficientNetB0",
            "EfficientNetB2", "ViT", "SFCN" or "CNN3D".
        input_channels: number of channels of the input volume.
        img_size: spatial size of the input volume (used by "Regressor" and
            "ViT" only).
        img_features: size of the feature vector the network outputs.

    Returns:
        The freshly constructed network.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """

    def resnet(block, layers):
        return ResNet(
            block=block,
            layers=layers,
            block_inplanes=[64, 128, 256, 512],
            spatial_dims=3,
            n_input_channels=input_channels,
            num_classes=img_features
        )

    def densenet(factory):
        return factory(spatial_dims=3, in_channels=input_channels, out_channels=img_features)

    def senet(factory):
        return factory(spatial_dims=3, in_channels=input_channels, num_classes=img_features)

    def efficientnet(variant):
        return EfficientNetBN(
            model_name=variant,
            spatial_dims=3,
            in_channels=input_channels,
            num_classes=img_features,
        )

    # Factories are wrapped in lambdas so that only the requested network is
    # ever constructed.
    factories = (
        ("Regressor", lambda: Regressor(
            in_shape=[input_channels, *img_size],
            out_shape=img_features,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2)
        )),
        ("ResNet18", lambda: resnet("basic", [2, 2, 2, 2])),
        ("ResNet50", lambda: resnet("bottleneck", [3, 4, 6, 3])),
        ("DenseNet121", lambda: densenet(DenseNet121)),
        ("DenseNet169", lambda: densenet(DenseNet169)),
        ("SEResNet50", lambda: senet(SEResNet50)),
        ("SENet154", lambda: senet(SENet154)),
        ("EfficientNetB0", lambda: efficientnet("efficientnet-b0")),
        ("EfficientNetB2", lambda: efficientnet("efficientnet-b2")),
        ("ViT", lambda: ViTFeatureExtractor(
            input_channels=input_channels,
            img_size=img_size,
            img_features=img_features
        )),
        ("SFCN", lambda: SFCN(
            input_channels=input_channels,
            channel_number=[32, 64, 128, 256, 256],
            output_dim=img_features
        )),
        ("CNN3D", lambda: CNN3D(
            input_channels=input_channels,
            output_dim=img_features
        )),
    )
    for name, build in factories:
        if model_name == name:
            return build()
    raise ValueError(f"Unsupported model name: {model_name}")

class AttentionLayer(nn.Module):
    """Channel gating for 5-D inputs (squeeze by global average, then gate).

    Each channel is averaged over its three spatial dimensions, the resulting
    vector is passed through a two-layer bottleneck ending in a sigmoid, and
    the input is multiplied channel-wise by those gates.

    Args:
        channels: number of channels of the input.
        reduction: bottleneck divisor; the hidden width is
            ``max(channels // reduction, 1)``.
    """

    def __init__(self, channels=3, reduction=16):
        super().__init__()
        hidden = max(channels // reduction, 1)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, chans, _depth, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, chans)
        gates = self.fc(squeezed).view(batch, chans, 1, 1, 1)
        return x * gates.expand_as(x)

class CustomRegressor(nn.Module):
    """Image feature extractor fused with tabular inputs for regression.

    Args:
        model_name: backbone name forwarded to ``get_model``.
        input_channels: number of image channels.
        img_size: spatial size of the image volume.
        img_features: size of the feature vector produced by the backbone.
        additional_features: number of tabular values concatenated to the
            image features.
        out_features: number of regression outputs.
    """

    def __init__(self, model_name, input_channels, img_size, img_features, additional_features=1, out_features=1):
        super().__init__()
        self.initial_attention = AttentionLayer(channels=input_channels)
        self.img_regressor = get_model(model_name, input_channels, img_size, img_features)
        head = [
            nn.Linear(img_features + additional_features, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, out_features),
        ]
        self.fc_layers = nn.Sequential(*head)

    def forward(self, x, y):
        """Return predictions for images ``x`` and tabular inputs ``y``."""
        # Channel attention is only applied to multi-channel inputs.
        multi_channel = x.size(1) > 1
        if multi_channel:
            x = self.initial_attention(x)
        img_feats = self.img_regressor(x)
        img_feats = img_feats.view(img_feats.size(0), -1)
        fused = torch.cat((img_feats, y), dim=1)
        return self.fc_layers(fused)

def train_one_epoch(model, device, train_loader, modalities, additional_variables, optimizer, criterion, scaler, metric):
    """Run one training pass over ``train_loader``.

    Args:
        model: network called as ``model(images, variables)``.
        device: device the batch tensors are moved to.
        train_loader: iterable of dict batches holding ``'Age'``, one entry
            per modality and one per additional variable.
        modalities: image keys, concatenated along the channel dimension.
        additional_variables: tabular keys; ``'Age'`` is excluded from the
            model inputs because it is the target.
        optimizer: optimizer updating ``model``.
        criterion: loss called as ``criterion(outputs, targets)``.
        scaler: gradient scaler enabling mixed precision, or ``None``.
        metric: object offering ``reset()``, ``__call__(y_pred=, y=)`` and
            ``aggregate()``.

    Returns:
        ``(mean batch loss, aggregated metric)`` as Python floats.
    """
    model.train()
    epoch_loss = 0.0
    metric.reset()
    use_amp = scaler is not None
    for batch_data in train_loader:
        targets = batch_data['Age'].unsqueeze(1).to(device)
        images = [batch_data[modality].to(device) for modality in modalities]
        img_inputs = torch.cat(images, dim=1)
        variables = [batch_data[variable].view(-1,1).to(device) for variable in additional_variables if variable != 'Age']
        variable_inputs = torch.cat(variables, dim=1)
        optimizer.zero_grad()
        # The forward pass must run inside autocast as well; wrapping only
        # the loss leaves the whole model in full precision.
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(img_inputs, variable_inputs)
            loss = criterion(outputs, targets)
        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        epoch_loss += loss.item()
        # Outputs may be half precision under autocast; score them in float32
        # and without the autograd graph attached.
        metric(y_pred=outputs.detach().float(), y=targets)
    epoch_metric = metric.aggregate().item()
    return epoch_loss / len(train_loader), epoch_metric

def validate_one_epoch(model, device, val_loader, modalities, additional_variables, metric):
    """Evaluate ``model`` on ``val_loader`` and return the aggregated metric.

    The model is switched to eval mode and gradients are disabled. Batches
    are dicts holding ``'Age'`` (the target), one entry per modality and one
    per additional variable; ``'Age'`` is never fed to the model.

    Returns:
        ``metric.aggregate()`` converted to a Python number.
    """
    model.eval()
    metric.reset()
    input_variables = [name for name in additional_variables if name != 'Age']
    with torch.no_grad():
        for batch in val_loader:
            targets = batch['Age'].unsqueeze(1).to(device)
            # Modalities are stacked along the channel dimension.
            img_inputs = torch.cat([batch[name].to(device) for name in modalities], dim=1)
            variable_inputs = torch.cat(
                [batch[name].view(-1,1).to(device) for name in input_variables], dim=1)
            metric(y_pred=model(img_inputs, variable_inputs), y=targets)
    return metric.aggregate().item()

class EarlyStopping:
    """Signal when a lower-is-better metric has stopped improving.

    A call counts as "no improvement" when the score exceeds the best score
    seen so far by more than ``delta``. After ``patience`` consecutive such
    calls ``early_stop`` becomes ``True``. Any other call resets the counter.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        delta: tolerance added to the best score before a call is counted as
            non-improving.
    """

    def __init__(self, patience=30, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0

    def __call__(self, metric):
        score = metric
        if self.best_score is None:
            self.best_score = score
        elif score > self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Only a genuinely lower score may replace the best one. A score
            # that is merely within the tolerance must not raise the
            # baseline, otherwise a slowly degrading metric never triggers
            # early stopping.
            if score < self.best_score:
                self.best_score = score
            self.counter = 0

def train_model(model_dir, model, device, train_loader, val_loader, modalities, additional_variables, logger,
        criterion, metric, max_epochs=100, learning_rate=1e-4, weight_decay=1e-5, val_interval=1, es_patience=30):
    """Train ``model`` with Adam, cosine LR annealing and early stopping.

    Validation runs every ``val_interval`` epochs. Whenever the validation
    metric reaches a new minimum the weights are written to
    ``<model_dir>/BestMetricModel.pth`` and a snapshot is kept in memory.

    Args:
        model_dir: directory receiving ``BestMetricModel.pth``.
        model: network to train.
        device: device used for training.
        train_loader, val_loader: training and validation batches.
        modalities, additional_variables: keys forwarded to the epoch helpers.
        logger: logger receiving per-epoch summaries.
        criterion: training loss.
        metric: lower-is-better metric used for training and validation.
        max_epochs: maximum number of epochs (also the cosine period).
        learning_rate, weight_decay: Adam settings.
        val_interval: validate every this many epochs.
        es_patience: patience forwarded to ``EarlyStopping``.

    Returns:
        ``(model, epoch_loss_values, epoch_metric_values, metric_values)``;
        ``model`` carries the best validated weights when at least one
        validation improved on the initial best.
    """
    import copy  # local import: only needed to snapshot the best weights

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
    # Gradient scaling only makes sense when training actually runs on CUDA.
    use_cuda = torch.cuda.is_available() and torch.device(device).type == "cuda"
    scaler = torch.cuda.amp.GradScaler() if use_cuda else None
    start_time = time.time()
    best_metric = float("inf")
    best_metric_epoch = -1
    best_model_state = None
    # Latest validation result; NaN until the first validation has run so the
    # per-epoch log line is always well defined.
    val_metric = float("nan")
    early_stopping = EarlyStopping(patience=es_patience, delta=0)
    epoch_loss_values, epoch_metric_values, metric_values = [], [], []
    for epoch in range(max_epochs):
        epoch_start_time = time.time()
        epoch_loss, epoch_metric = train_one_epoch(model, device, train_loader, modalities, additional_variables, optimizer, criterion, scaler, metric)
        epoch_loss_values.append(epoch_loss)
        epoch_metric_values.append(epoch_metric)
        if (epoch + 1) % val_interval == 0:
            val_metric = validate_one_epoch(model, device, val_loader, modalities, additional_variables, metric)
            metric_values.append(val_metric)
            if val_metric < best_metric:
                best_metric = val_metric
                best_metric_epoch = epoch + 1
                # state_dict() returns references to the live parameters, so a
                # deep copy is required to keep a real snapshot.
                best_model_state = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), os.path.join(model_dir, "BestMetricModel.pth"))
                logger.info(f"Best MAE: {best_metric:.4f} at epoch {best_metric_epoch}")
            early_stopping(val_metric)
            if early_stopping.early_stop:
                logger.info(f"Early stopping triggered at epoch {epoch + 1}")
                print(f"; Early stopping triggered at epoch {epoch + 1}", end="")
                break
        epoch_end_time = time.time()
        logger.info(f"Epoch {epoch + 1} computed for {(epoch_end_time - epoch_start_time)/60:.2f} mins - Training loss: {epoch_loss:.4f}, Training MAE: {epoch_metric:.4f}, Validation MAE: {val_metric:.4f}")
        lr_scheduler.step()
        sys.stdout.write(f"\rEpoch {epoch + 1}/{max_epochs} completed")
        sys.stdout.flush()
    end_time = time.time()
    total_time = end_time - start_time
    logger.info(f"Best MAE: {best_metric:.3f} at epoch {best_metric_epoch}; Total time consumed: {total_time/60:.2f} mins")
    print(f"\nBest MAE: {best_metric:.3f} at epoch {best_metric_epoch}; Total time consumed: {total_time/60:.2f} mins")
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, epoch_loss_values, epoch_metric_values, metric_values

def plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval=1):
    """Plot the training curves and save them as ``Performance.png``.

    The left panel shows the training loss per epoch; the right panel shows
    training MAE per epoch together with validation MAE, whose points are
    placed every ``val_interval`` epochs. The figure is written to
    ``model_dir`` at 300 dpi.
    """
    _, (loss_ax, mae_ax) = plt.subplots(1, 2, figsize=(12, 6))

    # Left panel: training loss
    loss_epochs = list(range(1, len(epoch_loss_values) + 1))
    loss_ax.plot(loss_epochs, epoch_loss_values, label='Training Loss', color='red')
    loss_ax.set_title('Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')

    # Right panel: training vs. validation MAE
    train_epochs = list(range(1, len(epoch_metric_values) + 1))
    val_epochs = [val_interval * step for step in range(1, len(metric_values) + 1)]
    mae_ax.plot(train_epochs, epoch_metric_values, label='Training MAE', color='red')
    mae_ax.plot(val_epochs, metric_values, label='Validation MAE', color='blue')
    mae_ax.set_title('Training MAE vs. Validation MAE')
    mae_ax.set_xlabel('Epoch')
    mae_ax.set_ylabel('MAE')
    mae_ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(model_dir, "Performance.png"), dpi=300)

class SHAP3D:
    """SHAP analysis helper for a model taking images plus tabular inputs.

    Typical use: ``create_background`` -> ``compute_shap_values`` ->
    ``visualize_shap``.

    Args:
        model: model called with ``[images, variables]``.
        modalities: image keys, concatenated along the channel dimension.
        additional_variables: tabular keys; ``'Age'`` is always excluded.
        device: device the batch tensors are moved to.
        batch_size: batch size of the loaders built for sub-sampling.
    """

    def __init__(self, model, modalities, additional_variables, device, batch_size):
        self.model = model
        self.modalities = modalities
        self.additional_variables = additional_variables
        self.device = device
        self.batch_size = batch_size
        self.explainer = None

    def _subset_loader(self, dl, num_samples):
        """Return ``dl`` itself, or a loader over a random subset of its dataset."""
        if num_samples is None:
            return dl
        dataset = dl.dataset
        # Sampling without replacement cannot draw more items than exist.
        num_samples = min(num_samples, len(dataset))
        subset_indices = np.random.choice(len(dataset), num_samples, replace=False)
        subset_ds = Subset(dataset, subset_indices)
        return DataLoader(subset_ds, batch_size=self.batch_size, num_workers=0, pin_memory=torch.cuda.is_available())

    def _batch_inputs(self, batch_data):
        """Assemble the (images, variables) model inputs from one batch dict."""
        images = [batch_data[modality].to(self.device) for modality in self.modalities]
        img_inputs = torch.cat(images, dim=1)
        variables = [batch_data[variable].view(-1,1).to(self.device) for variable in self.additional_variables if variable != 'Age']
        variable_inputs = torch.cat(variables, dim=1)
        return img_inputs, variable_inputs

    def create_background(self, dl, num_samples=20):
        """Build the gradient explainer from (a random subset of) ``dl``.

        Args:
            dl: loader providing the background samples.
            num_samples: number of random samples to use (capped at the
                dataset size); ``None`` uses every batch of ``dl``.
        """
        background_images = []
        background_variables = []
        for batch_data in self._subset_loader(dl, num_samples):
            img_inputs, variable_inputs = self._batch_inputs(batch_data)
            background_images.append(img_inputs)
            background_variables.append(variable_inputs)
        background_images = torch.cat(background_images, dim=0)
        background_variables = torch.cat(background_variables, dim=0)
        self.explainer = shap.GradientExplainer(self.model, [background_images, background_variables])

    def compute_shap_values(self, dl, num_samples=None):
        """Compute SHAP values for (a random subset of) ``dl``.

        Results are stored on the instance as ``img_shap_values``,
        ``variable_shap_values``, ``images`` and ``variables``.

        Raises:
            RuntimeError: if ``create_background`` has not been called.
        """
        if self.explainer is None:
            raise RuntimeError("create_background() must be called before compute_shap_values()")
        img_shap_values_list = []
        variable_shap_values_list = []
        images_list = []
        variables_list = []
        for batch_data in self._subset_loader(dl, num_samples):
            img_inputs, variable_inputs = self._batch_inputs(batch_data)
            shap_values = self.explainer.shap_values([img_inputs, variable_inputs])
            img_shap_values_list.append(shap_values[0])
            variable_shap_values_list.append(shap_values[1])
            images_list.append(img_inputs.cpu().numpy())
            variables_list.append(variable_inputs.cpu().numpy())
        self.img_shap_values = np.concatenate(img_shap_values_list, axis=0)
        self.variable_shap_values = np.concatenate(variable_shap_values_list, axis=0)
        self.images = np.concatenate(images_list, axis=0)
        self.variables = np.concatenate(variables_list, axis=0)

    def visualize_shap(self, sample_img_path, pred_dir, vmin=-0.0025, vmax=0.0025):
        """Save and plot mean absolute image SHAP maps, then the variable summary.

        For every modality the mean absolute SHAP value per voxel is masked
        to voxels that are non-zero in every analysed sample, saved as
        ``MeanAbsSHAPValues_<modality>.nii.gz`` in ``pred_dir`` and drawn as
        a glass brain.

        Args:
            sample_img_path: NIfTI file providing affine and header.
            pred_dir: output directory (created when missing).
            vmin, vmax: colour range of the glass-brain plots.
        """
        # This method may run before anything else has created pred_dir.
        os.makedirs(pred_dir, exist_ok=True)
        reference_img = nib.load(sample_img_path)
        _, axs = plt.subplots(len(self.modalities), 1, figsize=(12, len(self.modalities) * 4))
        if len(self.modalities) == 1:
            axs = [axs]
        for i, key in enumerate(self.modalities):
            # NOTE(review): the trailing index 0 assumes the SHAP arrays carry
            # a final model-output axis -- confirm for the shap version used.
            shap_values = self.img_shap_values[:, i, :, :, :, 0]
            feature_values = self.images[:, i, :, :, :]
            mean_abs_shap_values = np.mean(np.abs(shap_values), axis=0)
            common_mask = np.all(feature_values != 0, axis=0)
            masked_mean_abs_shap_values = np.zeros_like(mean_abs_shap_values)
            masked_mean_abs_shap_values[common_mask] = mean_abs_shap_values[common_mask]
            masked_mean_abs_shap_values_img = nib.Nifti1Image(masked_mean_abs_shap_values,
                                                              affine=reference_img.affine,
                                                              header=reference_img.header)
            masked_mean_abs_shap_values_img.header['descrip'] = 'Mean absolute SHAP values'
            nib.save(masked_mean_abs_shap_values_img, os.path.join(pred_dir, f"MeanAbsSHAPValues_{key}.nii.gz"))
            plotting.plot_glass_brain(masked_mean_abs_shap_values_img, threshold=None, annotate=False,
                                      plot_abs=False, black_bg='auto', axes=axs[i],
                                      colorbar=True, cmap='black_red', symmetric_cbar=False,
                                      alpha=0.3, vmin=vmin, vmax=vmax)
            axs[i].set_title(f"{key}")
        plt.show()
        shap.summary_plot(self.variable_shap_values[:, :, 0], self.variables)

def calculate_regression_parameters(model_dir, model, device, val_loader, modalities, additional_variables, metric, pred_dir):
    """Fit the linear age-bias correction on the validation set.

    Loads ``BestMetricModel.pth`` from ``model_dir``, predicts every
    validation sample, fits ``prediction = slope * age + intercept`` and
    writes slope, intercept and Pearson r to
    ``<pred_dir>/RegressionParameter.csv``. The MAE before and after applying
    ``(prediction - intercept) / slope`` is printed.

    Args:
        model_dir: directory holding ``BestMetricModel.pth``.
        model: network whose weights are replaced by the stored ones.
        device: device used for inference and for mapping the checkpoint.
        val_loader: validation batches (dicts including ``'Age'``).
        modalities, additional_variables: keys of the model inputs; ``'Age'``
            is excluded from the inputs.
        metric: metric object used for both MAE computations.
        pred_dir: output directory (created when missing).

    Returns:
        The ``np.polyfit`` result ``[slope, intercept]``.
    """
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    model.load_state_dict(torch.load(os.path.join(model_dir, "BestMetricModel.pth"), map_location=device))
    model.eval()
    os.makedirs(pred_dir, exist_ok=True)
    pred_values = []
    target_values = []
    metric.reset()
    with torch.no_grad():
        for batch_data in val_loader:
            targets = batch_data['Age'].unsqueeze(1).to(device)
            images = [batch_data[modality].to(device) for modality in modalities]
            img_inputs = torch.cat(images, dim=1)
            variables = [batch_data[variable].view(-1,1).to(device) for variable in additional_variables if variable != 'Age']
            # Models without tabular inputs are called with the images only.
            variable_inputs = torch.cat(variables, dim=1) if variables else None
            outputs = model(img_inputs, variable_inputs) if variable_inputs is not None else model(img_inputs)
            metric(y_pred=outputs, y=targets)
            pred_values.extend(outputs.cpu().numpy().flatten())
            target_values.extend(targets.cpu().numpy().flatten())
        metric_value = metric.aggregate().item()
    pred_values = np.array(pred_values)
    target_values = np.array(target_values)
    # Regress predictions on chronological age.
    regression_param = np.polyfit(target_values, pred_values, 1)
    slope, intercept = regression_param
    r_value = np.corrcoef(target_values, pred_values)[0, 1]
    df = pd.DataFrame({
        'Slope': [slope],
        'Intercept': [intercept],
        'r': [r_value]
    })
    df.to_csv(os.path.join(pred_dir, "RegressionParameter.csv"), index=False)
    # Invert the fitted line to obtain bias-corrected predictions.
    corrected_pred_values = (pred_values - intercept) / slope
    metric.reset()
    with torch.no_grad():
        y_pred = torch.tensor(corrected_pred_values, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        y = torch.tensor(target_values, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        metric(y_pred, y)
        corrected_metric_value = metric.aggregate().item()
    print(f'MAE on validation set: {metric_value:.3f}'
          f'\nCorrected MAE on validation set: {corrected_metric_value:.3f}')
    return regression_param

def apply_best_model(model_dir, model, device, test_loader, modalities, additional_variables,  pred_dir, subjects, regression_param=None):
    """Predict brain age for the test set and write ``BrainAge.csv``.

    Args:
        model_dir: directory holding ``BestMetricModel.pth``.
        model: network whose weights are replaced by the stored ones.
        device: device used for inference and for mapping the checkpoint.
        test_loader: test batches, in the same order as ``subjects``.
        modalities, additional_variables: keys of the model inputs; ``'Age'``
            is excluded from the inputs.
        pred_dir: output directory (created when missing).
        subjects: subject identifiers written to the ``No`` column.
        regression_param: optional ``(slope, intercept)``; when given, a
            ``CorrectedBrainAge`` column ``(BrainAge - intercept) / slope``
            is added.
    """
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    model.load_state_dict(torch.load(os.path.join(model_dir, "BestMetricModel.pth"), map_location=device))
    model.eval()
    os.makedirs(pred_dir, exist_ok=True)
    # Collect per-batch predictions and join them once; np.append inside the
    # loop would copy the whole array on every batch.
    pred_chunks = []
    with torch.no_grad():
        for batch_data in test_loader:
            images = [batch_data[modality].to(device) for modality in modalities]
            img_inputs = torch.cat(images, dim=1)
            variables = [batch_data[variable].view(-1,1).to(device) for variable in additional_variables if variable != 'Age']
            variable_inputs = torch.cat(variables, dim=1)
            outputs = model(img_inputs, variable_inputs)
            pred_chunks.append(outputs.cpu().numpy().ravel())
    if pred_chunks:
        pred_values = np.concatenate(pred_chunks).astype(np.float64)
    else:
        pred_values = np.array([])
    df = pd.DataFrame({
        'No': subjects,
        'BrainAge': pred_values
    })
    if regression_param is not None:
        a, b = regression_param
        df['CorrectedBrainAge'] = (df['BrainAge'] - b) / a
    df.to_csv(os.path.join(pred_dir, "BrainAge.csv"), index=False)

### Prepare Inputs

In [None]:
# --- Data and model selection ---
data_dir = "BrainAgeEstimation_2mm"
model_dir_prefix = "BrainAgeEstimation"
# Supported names: Regressor, ResNet18, ResNet50, DenseNet121, DenseNet169,
# SEResNet50, SENet154, EfficientNetB0, EfficientNetB2, ViT, SFCN, CNN3D
model_name = "SFCN"
modalities = ["GM", "FA"]
additional_variables = ["Age", "Sex"]

# --- Training hyper-parameters ---
test_size, batch_size = 0.2, 5
max_epochs = 100
learning_rate, weight_decay = 1e-4, 1e-5
val_interval, es_patience = 1, 30

# --- Output directory, device and logging ---
model_dir = "_".join((model_dir_prefix, model_name))
os.makedirs(model_dir, exist_ok=True)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print('Device:', device)
log_file = os.path.join(model_dir, "Prediction.log")
logging.basicConfig(filename=log_file, level=logging.INFO, format="%(message)s")
logger = logging.getLogger()

### Read Data

In [None]:
set_determinism(seed=0)
train_loader, val_loader = load_data(os.path.join(data_dir, "train"), modalities, additional_variables, batch_size, test_size)

# Check data shape
tr = first(train_loader)
# Spatial size is taken from the last three dimensions of the first modality
img_size = tuple(tr[modalities[0]].shape[-3:])
print('Data shape for training:')
for key, value in tr.items():
    print(f'\u2022 {key}: {tuple(value.shape)} \u00D7 {len(train_loader)}')
vl = first(val_loader)
print('\nData shape for validation:')
for key, value in vl.items():
    print(f'\u2022 {key}: {tuple(value.shape)} \u00D7 {len(val_loader)}')

# Visualize data
_, axs = plt.subplots(1, len(modalities), figsize=(len(modalities) * 4, 5))
# plt.subplots returns a bare Axes for a single modality; normalise to a 1-D
# array so the indexing below works for any number of modalities.
axs = np.atleast_1d(axs)
slice_index = 40
for i, key in enumerate(modalities):
    # First sample of the batch, first channel
    image = tr[key][0, 0, :, :, :].detach().cpu()
    img_slice = torch.rot90(image[:, :, slice_index], k=1, dims=(0, 1))
    ax = axs[i]
    ax.imshow(img_slice, cmap='gray')
    ax.set_title(key)
    ax.axis('off')
plt.tight_layout()
plt.show()

### Train Model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The model receives one tabular input per additional variable except the
# regression target 'Age'; derive the count instead of hard-coding it.
num_additional_features = len([variable for variable in additional_variables if variable != 'Age'])
model = CustomRegressor(
            model_name=model_name,
            input_channels=len(modalities),
            img_size=img_size,
            img_features=64,
            additional_features=num_additional_features,
            out_features=1
        ).to(device)
print(f"Selected model: {model_name}")
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_attentionalmodulator_params = sum(p.numel() for p in model.initial_attention.parameters() if p.requires_grad)
num_featureextractor_params = sum(p.numel() for p in model.img_regressor.parameters() if p.requires_grad)
num_regressor_params = sum(p.numel() for p in model.fc_layers.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params} = "
      f"{num_attentionalmodulator_params} (attentional modulator) + "
      f"{num_featureextractor_params} (feature extractor) + "
      f"{num_regressor_params} (regressor)")

criterion = nn.L1Loss()
metric = MAEMetric(reduction="mean")
model, epoch_loss_values, epoch_metric_values, metric_values = train_model(model_dir, model, device, train_loader, val_loader,
    modalities, additional_variables, logger, criterion, metric, max_epochs, learning_rate, weight_decay, val_interval, es_patience)
plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval)

# Visualize outcome
# NOTE: these indices refer to positions in the validation split and are
# specific to the data set at hand.
sample_indices = [76, 64] # 36-year-old female, 100-year-old female
slice_index = 40
model.eval()
# squeeze=False keeps axs two-dimensional even with a single row or column.
fig, axs = plt.subplots(len(sample_indices), len(modalities), squeeze=False,
                        figsize=(len(modalities) * 4, len(sample_indices) * 5))
titles = []
for row, sample_index in enumerate(sample_indices):
    # Fetch the sample once instead of re-indexing the dataset for every field.
    sample = val_loader.dataset[sample_index]
    with torch.no_grad():
        target = torch.tensor(sample['Age']).unsqueeze(0).unsqueeze(1).to(device)
        images = [sample[modality].unsqueeze(0).to(device) for modality in modalities]
        img_inputs = torch.cat(images, dim=1)
        variables = [torch.tensor(sample[variable]).view(1, -1).to(device) for variable in additional_variables if variable != 'Age']
        variable_inputs = torch.cat(variables, dim=1)
        output = model(img_inputs, variable_inputs)
    for col, key in enumerate(modalities):
        image = images[col][0, 0, :, :, :].detach().cpu()
        img_slice = torch.rot90(image[:, :, slice_index], k=1, dims=(0, 1))
        ax = axs[row, col]
        ax.imshow(img_slice, cmap='gray')
        ax.set_title(key)
        ax.axis('off')
    # BAG = predicted brain age minus chronological age
    title = (
        f'Sample {row + 1}: '
        f'Chronological age = {target.item():.1f} yrs, '
        f'Brain age = {output.item():.1f} yrs, '
        f'BAG = {output.item() - target.item():.1f} yrs'
    )
    titles.append(title)
combined_title = '\n'.join(titles)
plt.suptitle(combined_title, fontsize=12)
plt.tight_layout()
plt.show()

### SHAP

In [None]:
# Output location, reference image and colour range for the SHAP maps
pred_dir = os.path.join(model_dir, "Prediction")
sample_img_path = os.path.join(data_dir, 'train', 'GM', '001.nii.gz')
vmin, vmax = 0, 0.01

# Background from 20 random training samples, explanations for 20 random
# validation samples
shap_analyzer = SHAP3D(model, modalities, additional_variables, device, batch_size)
shap_analyzer.create_background(train_loader, num_samples=20)
shap_analyzer.compute_shap_values(val_loader, num_samples=20)
shap_analyzer.visualize_shap(sample_img_path, pred_dir, vmin, vmax)

### Inference

In [None]:
pred_dir = os.path.join(model_dir, "Prediction")
# Fit the age-bias correction on the validation set first ...
regression_param = calculate_regression_parameters(
    model_dir, model, device, val_loader, modalities, additional_variables, metric, pred_dir)
# ... then predict the test subjects and apply that correction.
test_dir = os.path.join(data_dir, "test")
test_loader, subjects = load_data(test_dir, modalities, additional_variables, batch_size,
                                  test_size=None, inference=True)
apply_best_model(
    model_dir, model, device, test_loader, modalities, additional_variables,
    pred_dir, subjects, regression_param)