# Brain Age Estimation with SFCN

## Setup imports

In [None]:
import sys
import torch
import torch.nn as nn
import numpy as np
import glob
import os
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

from monai.config import print_config
from monai.data import ArrayDataset, DataLoader
from monai.metrics import MAEMetric
from monai.transforms import (
    Transform,
    Compose,
    LoadImage,
    EnsureChannelFirst,
    SpatialCrop,
)
from monai.utils import first

sys.path.append("C:\\UKBiobank_deep_pretrain") # https://github.com/ha-ha-ha-han/UKBiobank_deep_pretrain/
from dp_model.model_files.sfcn import SFCN

print_config()

## Setup data directory and data

In [None]:
# Top-level data folder. The code below reads a "train" and a "test"
# sub-folder and the file "age_train.csv" from it.
root_dir = "C:\\BrainAgeEstimation\\Brain"
print(root_dir)

# Output folder used by later cells; created if it does not exist yet.
model_dir = os.path.join(root_dir, "SFCN_Pretrained")
os.makedirs(model_dir, exist_ok=True)

# Sorted lists of the compressed NIfTI files found in each split folder.
images, test_images = (
    sorted(glob.glob(os.path.join(root_dir, split, "*.nii.gz")))
    for split in ("train", "test")
)

# Table for the training split; its "Age" column supplies the labels when the
# training dataset is built.
df = pd.read_csv(os.path.join(root_dir, "age_train.csv"))

## Setup transforms and dataset

In [None]:
# Number of samples per batch, shared by the training and test data loaders.
batch_size = 1

class DivideByMean(Transform):
    """Intensity normalisation that rescales an image by its own mean.

    NOTE(review): there is no guard for an image whose mean is zero; the
    division is performed as-is.
    """

    def __call__(self, data):
        """Return ``data`` divided by the mean of all of its elements."""
        mean_intensity = data.mean()
        return data / mean_intensity

# Image preprocessing pipeline, applied to every file path in the order listed.
# The crop is centred on voxel (84, 102, 84) and keeps 160 x 192 x 160 voxels;
# presumably this is the input size the pretrained SFCN expects -- TODO confirm.
spatial_crop = SpatialCrop(roi_center=(84, 102, 84), roi_size=(160, 192, 160))
imtrans = Compose(
    [
        LoadImage(image_only=True),  # read the volume from disk
        EnsureChannelFirst(),  # put the channel axis first
        DivideByMean(),  # normalise intensities by the image mean
        spatial_crop,  # cut out the fixed-size window defined above
    ]
)

# Datasets and loaders for the training and test splits.
# Both loaders use identical settings: original sample order, no worker
# processes, and pinned memory only when CUDA is available.
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

# Training samples pair each image with an age label by position.
# NOTE(review): this assumes the rows of df are in the same order as the
# sorted image file names -- confirm against the CSV.
train_ds = ArrayDataset(img=images, img_transform=imtrans, labels=df["Age"].values)
train_loader = DataLoader(train_ds, **loader_kwargs)

# The test split has images only (no labels are passed).
test_ds = ArrayDataset(img=test_images, img_transform=imtrans)
test_loader = DataLoader(test_ds, **loader_kwargs)

# Check data shape
# A training batch is an (images, labels) pair, so both shapes are reported.
tr = first(train_loader)
print(f"training: ({list(tr[0].shape)}, {list(tr[1].shape)}) \u00D7 {len(train_loader)}")
# The test dataset is built from images only, so a test batch carries no label
# entry: indexing ts[1] would raise IndexError. Report the image shape alone,
# accepting either a bare image batch or a one-element sequence.
ts = first(test_loader)
ts_img = ts[0] if isinstance(ts, (list, tuple)) else ts
print(f"test: {list(ts_img.shape)} \u00D7 {len(test_loader)}")

## Visualize an example training image

In [None]:
# Plot one slice (index 80 along the last axis) of the first image in the
# training batch `tr`, titled with that sample's age label.
fig = plt.figure("Example image for training", (12, 6))
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
ax.set_title(f"Age = {tr[1][0]} years")
ax.imshow(np.rot90(tr[0][0,0,:, :, 80].detach().cpu()), cmap="gray")
ax.axis('off')
# Write the figure to disk before displaying it.
plt.savefig(os.path.join(model_dir, "image_age.tif"), dpi=300)
# plt.show must be called; the bare name only evaluates to the function object.
plt.show()

## Apply pretrained model

In [None]:
# Bin centres 42.5, 43.5, ..., 81.5 (40 values) used to turn the model's
# per-bin probabilities into an expected age. Assumed to match the age bins of
# the pretrained checkpoint -- TODO confirm against the SFCN repository.
bc = np.arange(42.5,82)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mae_metric = MAEMetric(reduction="mean")


def _predict_age(net, image, bin_centers, target_device):
    """Return the expected age predicted by ``net`` for a single image.

    Args:
        net: model whose first output is treated as one log-probability per
            age bin (it is exponentiated before weighting).
        image: one preprocessed volume without a batch axis.
        bin_centers: 1-D array holding the age at the centre of each bin.
        target_device: device the input is moved to before the forward pass.

    Returns:
        The probability-weighted mean of ``bin_centers`` as a NumPy scalar.
    """
    # Prepend the batch axis: (C, ...) -> (1, C, ...).
    batch = image.reshape((1,) + tuple(image.shape))
    batch = batch.clone().detach().type(torch.float32).to(target_device)
    log_prob = net(batch)[0].cpu().detach().numpy().reshape(-1)
    return np.exp(log_prob) @ bin_centers


# Load the pretrained model. Predictions are collected in memory only
# (`predictions` for the training split, `test_predictions` for the test split).
model = SFCN()
# Wrapped in DataParallel before loading; presumably the checkpoint was saved
# from a DataParallel model and its parameter names require it -- TODO confirm.
model = torch.nn.DataParallel(model)
fp_ = 'C:\\UKBiobank_deep_pretrain\\brain_age\\run_20190719_00_epoch_best_mae.p'
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
# map_location lets the checkpoint be loaded on a machine without a GPU.
model.load_state_dict(torch.load(fp_, map_location=device))
# The inputs are moved to `device`, so the parameters must live there as well.
model.to(device)
model.eval()
predictions = []
test_predictions = []
labels = []

with torch.no_grad():
    # Training split: each sample is an (image, label) pair. Labels are
    # gathered in the same pass so every volume is loaded only once.
    for idx in range(len(train_ds)):
        image, label = train_ds[idx]
        labels.append(label)
        predictions.append(_predict_age(model, image, bc, device))

    # Test split: the dataset was built from images only, so a sample may be
    # the image itself rather than a tuple. Indexing it with [0] would then
    # strip the channel axis instead of selecting the image.
    for idx in range(len(test_ds)):
        sample = test_ds[idx]
        image = sample[0] if isinstance(sample, (tuple, list)) else sample
        test_predictions.append(_predict_age(model, image, bc, device))

labels = np.array(labels)
predictions_arr = np.asarray(predictions)

# MAE over all training subjects.
mae = mae_metric(
    y_pred=torch.as_tensor(predictions_arr, dtype=torch.float32).reshape(1, -1),
    y=torch.as_tensor(labels, dtype=torch.float32).reshape(1, -1),
)
print(f"MAE = {mae.item():.1f} years")

# MAE restricted to subjects whose age lies strictly between 42 and 82 years.
withinagerange = (labels > 42) & (labels < 82)
mae_withinagerange = mae_metric(
    y_pred=torch.as_tensor(predictions_arr[withinagerange], dtype=torch.float32).reshape(1, -1),
    y=torch.as_tensor(labels[withinagerange], dtype=torch.float32).reshape(1, -1),
)
print(f"MAE for individuals aged between 42 and 82 years = {mae_withinagerange.item():.1f} years")