# Notebook to test [SFCN model](https://github.com/ha-ha-ha-han/UKBiobank_deep_pretrain) for brain age prediction

## Currently using a sample UK Biobank (UKBB) subject

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
from models.sfcn import *
from models import dp_loss as dpl
from models import dp_utils as dpu

import nibabel as nib

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

## Paths

In [None]:
import os

# Project root, given relative to the notebook's working directory.
project_dir = "../"
# Trailing separator kept so later f-strings can append file names directly.
models_dir = os.path.join(project_dir, "models", "")

# Root of the UKBB data. Defaults to the original author's location; set the
# UKBB_DATA_DIR environment variable to point elsewhere. A trailing separator
# is enforced because later cells build paths as f"{data_dir}imaging/...".
data_dir = os.path.join(
    os.environ.get("UKBB_DATA_DIR", "/home/nikhil/projects/brain_changes/data/ukbb/"),
    "",
)

# This is to be modified with the path of saved weights
p_ = os.path.join(models_dir, "run_20190719_00_epoch_best_mae.p")

## Load Model

In [None]:
model = SFCN()
# Wrapped in DataParallel so parameter names carry the `module.` prefix.
# NOTE(review): presumably the checkpoint was saved from a DataParallel-wrapped
# model, so its keys need this prefix — confirm against the checkpoint.
model = torch.nn.DataParallel(model)

# Inference in this notebook runs on the CPU: the input tensors are created on
# the CPU and the model output is converted with `.numpy()`. The checkpoint is
# therefore always mapped onto the CPU. (The previous CUDA/CPU branch computed a
# map_location that was never passed to torch.load, so it was dead code.)
map_location = torch.device('cpu')
model.load_state_dict(torch.load(p_, map_location=map_location))

## Load sample data
### Either random or UKBB sample subject

| eid | sex | birth_year | T1-ses2 | T1-ses3 | ethnicity | age_at_ses2 | age_at_ses3 | age_at_recruitment |
|---|---|---|---|---|---|---|---|---|
| 1004084 | 1.0 | 1947.0 | 20252_2_0 | 20252_3_0 | 1001.0 | 70.0 | 72.0 | 60.0 |
| 1010063 | 0.0 | 1964.0 | 20252_2_0 | 20252_3_0 | 1001.0 | 53.0 | 55.0 | 45.0 |



In [None]:
# Set to True to run the pipeline on a synthetic, constant-valued volume
# instead of a real UKBB scan.
use_random_scan = False

# Age (in years) at the scanning session for the sample subjects, hard-coded
# rather than read from the metadata table below.
subject_age_dict = {"sub-1010063": 53,
                    "sub-1004084": 70}


if use_random_scan:
    print("Generating a random scan...")
    # `np.float` was removed in NumPy 1.24; it was an alias of the builtin
    # `float`, so this produces the same float64 array.
    data = np.ones([182, 218, 182]).astype(float)
    # data = np.random.rand(182, 218, 182)
    label = np.array([71.3,])  # Assuming the random subject is 71.3-year-old.

else:
    # Sample subject needs to be in the MNI space
    subject_id = "sub-1004084"  # alternative sample subject: "sub-1010063"
    scan_session = "ses-2"

    print(f"Using a sample scan from ukbb: {subject_id}")
    subject_dir = f"{data_dir}imaging/ukbb_test_subject/{subject_id}/{scan_session}/non-bids/T1/"
    T1_mni = f"{subject_dir}T1_brain_to_MNI.nii.gz"

    # Follow-up metadata table. It is loaded here, but the age used below
    # comes from subject_age_dict.
    ukbb_follow_up_csv = f"{data_dir}tabular/tab_follow_up.csv"
    ukbb_metadata = pd.read_csv(ukbb_follow_up_csv)

    data = nib.load(T1_mni).get_fdata()
    print(f"image shape: {data.shape}")
    print(f"image mean: {np.mean(data.ravel())}")

    # Age at scanning
    age = subject_age_dict[subject_id]
    label = np.array([age,])


# Transform the age into a soft label (a probability distribution over age bins).
# Prediction is treated as a classification problem with n_classes = n_bins, and
# the predicted age is the probability-weighted sum of the bin centres `bc`, so
# changing this range shifts the predicted age.
# NOTE(review): the number of bins produced here must match the model's output
# size (presumably 40 bins for this checkpoint) — confirm before changing.
bin_range = [42,82]

bin_step = 1
sigma = 1
y, bc = dpu.num2vect(label, bin_range, bin_step, sigma)
y = torch.tensor(y, dtype=torch.float32)
print(f'Label shape: {y.shape}')

In [None]:
def get_brain_age(input_data, model, bc):
    """Predict brain age from a preprocessed T1w scan with an SFCN model.

    Args:
        input_data: Input tensor for the network (as built by
            ``preproc_images``; a batch of one scan is assumed, since the
            output is flattened into a single probability vector).
        model: The SFCN network, optionally wrapped in
            ``torch.nn.DataParallel``. When wrapped, the inner module is
            called directly.
        bc: Bin centres (ages) matching the model's output bins.

    Returns:
        ``(prob, pred)``: ``prob`` is the flattened per-bin probability vector
        (``exp`` of the model's first output) and ``pred`` is ``prob @ bc``,
        the probability-weighted age.
    """
    # Puts the wrapper and (recursively) the wrapped network in eval mode.
    model.eval()
    # Call the underlying network directly when it is wrapped in DataParallel;
    # fall back to the model itself otherwise.
    net = getattr(model, "module", model)
    with torch.no_grad():
        output = net(input_data)

    # The first output is exponentiated into probabilities, so it presumably
    # holds log-probabilities. Move it to the CPU (a no-op for CPU tensors)
    # before the numpy conversion, which fails on GPU tensors.
    log_prob = output[0].detach().cpu().numpy().reshape(-1)
    prob = np.exp(log_prob)
    pred = prob @ bc

    return prob, pred

def preproc_images(img, crop_shape=(160, 192, 160)):
    """Preprocess a T1w scan into the input tensor expected by SFCN.

    The volume is intensity-normalised by dividing by its mean, centre-cropped
    with ``dpu.crop_center``, and given two leading axes of size 1 (batch and
    channel).

    Args:
        img: numpy array holding the scan (in this notebook, a 3D volume
            registered to MNI space).
        crop_shape: Target shape passed to ``dpu.crop_center``.

    Returns:
        A float32 torch tensor of shape ``(1, 1) + cropped.shape``, where
        ``cropped`` is the array returned by ``dpu.crop_center``.

    Raises:
        ValueError: If the scan's mean intensity is zero (e.g. an all-zero
            volume). Without this check the normalisation would turn every
            voxel into NaN/inf and silently yield a NaN prediction.
    """
    img_mean = img.mean()
    if img_mean == 0:
        raise ValueError("Cannot normalise scan: mean intensity is zero.")
    img = img / img_mean
    img = dpu.crop_center(img, crop_shape)

    # Add batch and channel axes, then move the array into a torch tensor.
    img = img.reshape((1, 1) + img.shape)
    input_data = torch.tensor(img, dtype=torch.float32)

    return input_data
        

In [None]:
# Normalise and crop the loaded volume, then run SFCN on it: `prob` is the
# per-bin probability vector and `pred` the probability-weighted age.
input_data = preproc_images(data, crop_shape=(160, 192, 160))
prob, pred = get_brain_age(input_data=input_data, model=model, bc=bc)
print(f"pred: {pred}, label = {label}")

In [None]:
# This cell used to repeat the previous cell's preprocessing and inference
# verbatim. get_brain_age runs the model in eval mode under torch.no_grad, so a
# second forward pass on the same input gives the same result. Reuse it instead.
print(f"pred: {pred}, label = {label}")

## Plots

In [None]:
# Flatten both distributions to 1D so they line up with the bin centres `bc`
# on the x-axis. The soft label `y` is a torch tensor that may carry a leading
# batch axis (its shape is printed when it is built).
soft_label = np.asarray(y, dtype=np.float64).reshape(-1)
pred_prob = np.asarray(prob, dtype=np.float64).reshape(-1)

# KL divergence KL(soft_label || pred_prob), summed over bins. `loss` was
# previously used in the title without ever being defined (NameError).
# The epsilon guards against log(0) in empty bins.
eps = 1e-16
loss = float(np.sum(soft_label * (np.log(soft_label + eps) - np.log(pred_prob + eps))))

plt.bar(bc, soft_label)
plt.title('Soft label')
plt.show()

plt.bar(bc, pred_prob)
plt.title(f'Prediction: age={pred:.2f}\nloss={loss:.4f}')
plt.show()