# Brain Age Estimation with Regressor

## Setup imports

In [None]:
import torch
import torch.nn as nn
import numpy as np
import glob
import os
import pandas as pd
import logging
import time
import matplotlib.pyplot as plt
%matplotlib inline

from monai.config import print_config
from monai.data import ArrayDataset, decollate_batch, DataLoader
from monai.metrics import MAEMetric
from monai.networks.nets import Regressor
from monai.transforms import (
    Compose,
    LoadImage,
    EnsureChannelFirst,
    SpatialCrop,
    Resize,
    NormalizeIntensity,
    Activations,
)
from monai.utils import first, set_determinism

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

import torchinfo

# Print MONAI's configuration report (presumably library/package versions -- confirm in MONAI docs).
print_config()

## Set deterministic training for reproducibility

In [2]:
# Fix the seed to 0 so that repeated runs of this notebook are reproducible.
set_determinism(seed=0)

## Setup directories and data

In [None]:
# Root folder holding the "train"/"test" image folders and the age CSV.
root_dir = "C:\\BrainAgeEstimation\\Brain"
print(root_dir)

# All outputs of this run (weights, log, figures, predictions) go in here.
model_dir = os.path.join(root_dir, "Regressor_Batch5_LR4")
os.makedirs(model_dir, exist_ok=True)

# Training volumes, in lexicographic path order.
images = glob.glob(os.path.join(root_dir, "train", "*.nii.gz"))
images.sort()

# Table with the training labels; the "Age" column is read when the dataset is built.
# NOTE(review): rows are presumably in the same order as the sorted image paths -- confirm.
df = pd.read_csv(os.path.join(root_dir, "age_train.csv"))

## Setup logging

In [4]:
# Send INFO-level (and above) records to a log file next to the model outputs.
log_file = os.path.join(model_dir, "brain_age_estimation.log")
logging.basicConfig(
    filename=log_file,
    level=logging.INFO,
    format="%(asctime)s -  %(message)s",
)

# Root logger; the training and plotting cells write through it.
logger = logging.getLogger()

## Setup transforms and dataset

In [None]:
batch_size = 5

# Image pipeline: load the volume, move channels first, crop a fixed region,
# downsample to 64x64x64 and normalize the non-zero voxels per channel.
imtrans = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    SpatialCrop(roi_center=(84, 102, 84), roi_size=(160, 192, 160)),
    Resize((64, 64, 64), mode="trilinear"),
    NormalizeIntensity(nonzero=True, channel_wise=True),
])

# Pair every image path with its age label, then hold out 20% for validation.
ds = ArrayDataset(img=images, img_transform=imtrans, labels=df["Age"].values)
train_ds, val_ds = train_test_split(ds, test_size=0.2)

# Only the training loader shuffles; memory is pinned when CUDA is available.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

# Pull one batch from each loader and report (image shape, label shape) x number of batches.
tr = first(train_loader)
print(f"training: ({list(tr[0].shape)}, {list(tr[1].shape)}) \u00D7 {len(train_loader)}")
vl = first(val_loader)
print(f"validation: ({list(vl[0].shape)}, {list(vl[1].shape)}) \u00D7 {len(val_loader)}")

## Visualize an example training image

In [None]:
# Display one slice of the first image in the sampled training batch `tr`,
# titled with its age label, and save the figure to the model directory.
fig = plt.figure("Example image for training", (12, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
ax.set_title(f"Age = {tr[1][0]} years")
# First sample, first channel, index 30 along the last spatial axis; rotated 90 degrees for display.
ax.imshow(np.rot90(tr[0][0, 0, :, :, 30].detach().cpu()), cmap="gray")
ax.axis("off")
# Save before showing so the written file contains the rendered figure.
plt.savefig(os.path.join(model_dir, "image_age.tif"), dpi=300)
# Fix: `plt.show` was only referenced, never called.
plt.show()

## Create model

In [7]:
# Training hyper-parameters.
max_epochs = 100
val_interval = 1  # run validation every `val_interval` epochs
lr = 1e-4

# Pick the GPU when one is available and fall back to the CPU otherwise.
# The original hard-coded "cuda", which fails on machines without CUDA even
# though the rest of the notebook guards on torch.cuda.is_available().
# To force another backend, override manually, e.g. torch.device("mps").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create Regressor, L1Loss, and Adam optimizer.
# in_shape matches the transform output: 1 channel, 64 x 64 x 64 voxels.
model = Regressor(
    in_shape=[1, 64, 64, 64],
    out_shape=1,
    channels=(16, 32, 64, 128, 256),  # alternative noted by the author: (2, 4, 8)
    strides=(2, 2, 2, 2),  # alternative noted by the author: (2, 2, 2)
    kernel_size=3,
    num_res_units=2,
).to(device)

# L1 loss is the mean absolute error, the same quantity tracked by mae_metric.
loss_function = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=1e-5)
# Cosine learning-rate schedule spanning the whole run (stepped once per epoch by the training cell).
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

mae_metric = MAEMetric(reduction="mean")

# Post-processing applied to predictions before computing metrics; Activations()
# is constructed with no options here.
post_pred = Compose([Activations()])

# Gradient scaler for CUDA training. The training cell uses `scaler` under the
# same torch.cuda.is_available() condition, so keep the two in sync.
if torch.cuda.is_available():
    scaler = torch.cuda.amp.GradScaler()

## Print model

In [None]:
# Summarize the network for one batch of inputs. Use batch_size instead of a
# hard-coded 5 so the summary stays in step with the data loaders.
torchinfo.summary(model, input_size=(batch_size, 1, 64, 64, 64))

## Execute training process

In [None]:
# Train for max_epochs, validate every val_interval epochs, and keep the
# weights with the lowest validation MAE.
#
# Fixes relative to the original cell:
#   * metric inputs are reshaped with reshape(-1, 1) instead of
#     reshape(batch_size, -1), which raised on a partial last batch;
#   * the training MAE is accumulated over every batch of the epoch (it used
#     to be computed from the last batch only);
#   * mae_metric is reset after the training aggregate so training samples no
#     longer leak into the validation MAE;
#   * the step counter is printed against len(train_loader), which also counts
#     a partial last batch.
best_metric = float("inf")  # lower MAE is better
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]  # [best MAE, epoch, seconds since start]
epoch_loss_values = []  # mean training loss per epoch
epoch_metric_values = []  # training MAE per epoch
metric_values = []  # validation MAE per validated epoch

total_start = time.time()
for epoch in range(max_epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step_start = time.time()
        step += 1
        im, val = batch_data
        inputs, labels = (
            im.to(device),
            val.to(device),
        )
        optimizer.zero_grad()
        # Flatten the model output to one value per sample so it lines up with the labels.
        outputs = model(inputs).flatten()
        loss = loss_function(outputs, labels)
        # NOTE(review): the scaler is used without an autocast context in this
        # cell, so it only scales the loss -- confirm whether mixed precision
        # was intended.
        if torch.cuda.is_available():
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        epoch_loss += loss.item()

        # Accumulate the training MAE for this batch; reshape(-1, 1) works for
        # any batch length, including a shorter final batch.
        outputs = torch.tensor([post_pred(i) for i in decollate_batch(outputs)]).to(labels.device)
        mae_metric(y_pred=outputs.reshape(-1, 1), y=labels.reshape(-1, 1))

        print(
            f"{step}/{len(train_loader)}"
            f", train_loss: {loss.item():.4f}"
            f", step time: {(time.time() - step_start):.4f}"
        )
    lr_scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # Training MAE over the whole epoch; reset so the buffer is empty before validation.
    metric = mae_metric.aggregate().item()
    epoch_metric_values.append(metric)
    mae_metric.reset()

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                im, val = val_data
                val_inputs, val_labels = (
                    im.to(device),
                    val.to(device),
                )
                val_outputs = model(val_inputs)
                val_outputs = torch.tensor([post_pred(i) for i in decollate_batch(val_outputs)]).to(val_labels.device)
                mae_metric(y_pred=val_outputs.reshape(-1, 1), y=val_labels.reshape(-1, 1))

            # Validation MAE over all validation batches of this epoch.
            metric = mae_metric.aggregate().item()
            metric_values.append(metric)
            mae_metric.reset()

            if metric < best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_metrics_epochs_and_time[0].append(best_metric)
                best_metrics_epochs_and_time[1].append(best_metric_epoch)
                best_metrics_epochs_and_time[2].append(time.time() - total_start)
                torch.save(
                    model.state_dict(),
                    os.path.join(model_dir, "best_metric_model.pth"),
                )
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mae: {metric:.4f}"
                f"\nbest mae: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )
            logger.info(f"epoch {epoch + 1} mae: {metric:.4f}")
    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

## Plot loss and metric

In [None]:
# Report the best result, then plot the training loss and the train/validation
# MAE curves and save the figure to the model directory.
# NOTE: total_time is measured when this cell runs, so it includes any time
# spent between the end of training and the execution of this cell.
total_time = time.time() - total_start
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}, total time: {total_time}.")
logger.info(
    f"best_metric: {best_metric:.4f} at epoch {best_metric_epoch}, "
    f"total time to train: {total_time}"
)

fig = plt.figure("Performance in training", (12, 6))

# Left panel: mean training loss per epoch (epochs numbered from 1).
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title("Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
ax1.plot(x, y, color="red")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")

# Right panel: training MAE (every epoch) and validation MAE (every val_interval epochs).
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title("MAE")
# Build the x-axis from the series that is actually plotted against it.
x1 = [i + 1 for i in range(len(epoch_metric_values))]
x2 = [val_interval * (i + 1) for i in range(len(metric_values))]
y1 = epoch_metric_values
y2 = metric_values
ax2.plot(x1, y1, color="red")
ax2.plot(x2, y2, color="blue")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("MAE")
ax2.legend(["Train", "Validation"])
plt.savefig(os.path.join(model_dir, "performance.tif"), dpi=300)
# Fix: `plt.show` was only referenced, never called.
plt.show()

## Check best model output

In [None]:
# Reload the best checkpoint and compare its prediction with the true age for
# one validation sample.
testidx = 3
# map_location keeps the load working when the checkpoint was saved on a
# different device than the one used in this session.
model.load_state_dict(
    torch.load(os.path.join(model_dir, "best_metric_model.pth"), map_location=device)
)
model.eval()
with torch.no_grad():
    # Select one image to evaluate and visualize the model output.
    # val_ds[testidx] is an (image, age) pair; unsqueeze(0) adds the batch dimension.
    val_input = val_ds[testidx][0].unsqueeze(0).to(device)
    val_output = model(val_input)
    val_output = post_pred(val_output)

fig = plt.figure("Actual vs. Predicted", (12, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
# The brain age gap is the estimated age minus the actual age.
ax.set_title(f"Actual age = {val_ds[testidx][1]} years"
             f"\nEstimated age = {val_output.item():.1f} years"
             f"\nBrain age gap = {val_output.item() - val_ds[testidx][1]:.1f} years")
# First channel, index 30 along the last spatial axis; rotated 90 degrees for display.
ax.imshow(np.rot90(val_ds[testidx][0][0, :, :, 30].detach().cpu()), cmap="gray")
ax.axis("off")
plt.savefig(os.path.join(model_dir, "actual_predicted.tif"), dpi=300)
# Fix: `plt.show` was only referenced, never called.
plt.show()

## Apply best model

In [12]:
# Run the best checkpoint over every test image and write one predicted age
# per line to BrainAge.txt, in sorted file-path order.

# Define dataset: test images only (no labels), same preprocessing as training.
test_images = sorted(glob.glob(os.path.join(root_dir, "test", "*.nii.gz")))
test_ds = ArrayDataset(test_images, imtrans)

# Apply the best model and save predictions.
# map_location keeps the load working when the checkpoint was saved on a
# different device than the one used in this session.
model.load_state_dict(
    torch.load(os.path.join(model_dir, "best_metric_model.pth"), map_location=device)
)
model.eval()
test_predictions = []

with torch.no_grad():
    for idx in range(len(test_ds)):
        # Each item is a single image here; unsqueeze(0) adds the batch dimension.
        test_input = test_ds[idx].unsqueeze(0).to(device)
        test_output = model(test_input)
        test_output = post_pred(test_output)
        test_predictions.append(test_output.item())

np.savetxt(os.path.join(model_dir, "BrainAge.txt"), test_predictions)