In [1]:
from dp_model.model_files.sfcn import SFCN
from dp_model import dp_loss as dpl
from dp_model import dp_utils as dpu
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import nibabel as nib
import pandas as pd
from tqdm.autonotebook import tqdm
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import r2_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import roc_auc_score
import shutil
import zipfile
import os

In [2]:
print(torch.__version__)

1.11.0+cu102


In [3]:
print("First we will evaluate the Peng et al., Neuroimage brain age and sex models")
print("This is available here: https://github.com/ha-ha-ha-han/UKBiobank_deep_pretrain")

First we will evaluate the Peng et al., Neuroimage brain age and sex models
This is available here: https://github.com/ha-ha-ha-han/UKBiobank_deep_pretrain


In [4]:
model = SFCN()
model = torch.nn.DataParallel(model)
fp_ = './brain_age/run_20190719_00_epoch_best_mae.p'
model.load_state_dict(torch.load(fp_))
model.cuda()

DataParallel(
  (module): SFCN(
    (feature_extractor): Sequential(
      (conv_0): Sequential(
        (0): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): ReLU()
      )
      (conv_1): Sequential(
        (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): ReLU()
      )
      (conv_2): Sequential(
        (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, 

In [5]:
path = '/media/Data1/*/BIOBANK/'
metadata_test = pd.read_csv(path+'/TEST/metadata_test.csv',index_col=0)
metadata_test=metadata_test[::2].reset_index(drop=True) 
metadata_test['SCFN_age_prediction']=0
metadata_test['SCFN_age_prediction_KL']=0

In [6]:
for i, row in tqdm(metadata_test.iterrows(),total=metadata_test.shape[0]):
    zip_t1 = sorted(glob.glob('/home/*/Desktop/Biobank_T1/'+str(int(row['biobank_id']))+'*.zip'))[0]
    
    with zipfile.ZipFile(zip_t1) as z:
        with z.open('T1/T1_brain_to_MNI.nii.gz') as zf, open('/home/*/Desktop/T1_temp.nii.gz', 'wb') as f:
            shutil.copyfileobj(zf, f)
            
    data = np.asanyarray(nib.load('/home/*/Desktop/T1_temp.nii.gz').dataobj)
    
    label = row['age']

    # Transforming the age to soft label (probability distribution)
    bin_range = [42,82]
    bin_step = 1
    sigma = 1
    y, bc = dpu.num2vect(label, bin_range, bin_step, sigma)
    y = torch.tensor(y, dtype=torch.float32)

    # Preprocessing
    data = data/data.mean()
    data = dpu.crop_center(data, (160, 192, 160))

    # Move the data from numpy to torch tensor on GPU
    sp = (1,1)+data.shape
    data = data.reshape(sp)
    input_data = torch.tensor(data, dtype=torch.float32).cuda()
    
    # Evaluation
    model.eval() # Don't forget this. BatchNorm will be affected if not in eval mode.
    with torch.no_grad():
        output = model(input_data)

    # Output, loss, visualisation
    x = output[0].cpu().reshape([1, -1])
    loss = dpl.my_KLDivLoss(x, y).numpy()

    # Prediction, Visualisation and Summary
    x = x.numpy().reshape(-1)
    y = y.numpy().reshape(-1)

    prob = np.exp(x)
    pred = prob@bc

    metadata_test.loc[i,'SCFN_age_prediction']=pred
    metadata_test.loc[i,'SCFN_age_prediction_KL']=loss
    os.remove('/home/*/Desktop/T1_temp.nii.gz')

  0%|          | 0/2381 [00:00<?, ?it/s]

In [7]:
print("HaHaHaHan Brain Age Prediction")
print(mean_absolute_error(metadata_test['age_at_scan'], metadata_test['SCFN_age_prediction']))
print(r2_score(y_true=metadata_test['age_at_scan'].values, y_pred = metadata_test['SCFN_age_prediction'].values))

HaHaHaHan Brain Age Prediction
5.282125401294747
0.3131510890874354


In [8]:
##sex prediction
# Example
model = SFCN(output_dim=2, channel_number=[28, 58, 128, 256, 256, 64])
model = torch.nn.DataParallel(model)
fp_ = './sex_prediction/run_20191008_00_epoch_last.p'
model.load_state_dict(torch.load(fp_))
model.cuda()

DataParallel(
  (module): SFCN(
    (feature_extractor): Sequential(
      (conv_0): Sequential(
        (0): Conv3d(1, 28, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): ReLU()
      )
      (conv_1): Sequential(
        (0): Conv3d(28, 58, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): ReLU()
      )
      (conv_2): Sequential(
        (0): Conv3d(58, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, 

In [9]:
for i, row in tqdm(metadata_test.iterrows(),total=metadata_test.shape[0]):
    zip_t1 = sorted(glob.glob('/home/*/Desktop/Biobank_T1/'+str(int(row['biobank_id']))+'*.zip'))[0]
    
    with zipfile.ZipFile(zip_t1) as z:
        with z.open('T1/T1_brain_to_MNI.nii.gz') as zf, open('/home/*/Desktop/T1_temp.nii.gz', 'wb') as f:
            shutil.copyfileobj(zf, f)
    
    data = np.asanyarray(nib.load('/home/*/Desktop/T1_temp.nii.gz').dataobj)
    
    y = torch.tensor([int(row['sex'])]) # Assuming Sex is Male (0=Female, 1=Male)

    # Preprocessing
    data = data/data.mean()
    data = dpu.crop_center(data, (160, 192, 160))

    # Move the data from numpy to torch tensor on GPU
    sp = (1,1)+data.shape
    data = data.reshape(sp)
    input_data = torch.tensor(data, dtype=torch.float32).cuda()

    # Evaluation
    model.eval() 
    with torch.no_grad():
        output = model(input_data)

    # Output, loss, visualisation
    x = output[0].cpu().reshape([1, -1])
    loss = F.nll_loss(x, y)

    # Prediction, Visualisation and Summary
    x = np.exp(x.numpy().reshape(-1))

    metadata_test.loc[i,'SCFN_sex_prediction']=x.argmax()
    os.remove('/home/*/Desktop/T1_temp.nii.gz')

  0%|          | 0/2381 [00:00<?, ?it/s]

In [10]:
print("HaHaHaHan Brain Sex Prediction")
print(roc_auc_score(y_true=metadata_test['sex'].values, y_score = metadata_test['SCFN_sex_prediction'].values))
print(balanced_accuracy_score(y_true=metadata_test['sex'], y_pred=metadata_test['SCFN_sex_prediction']))

HaHaHaHan Brain Sex Prediction
0.9658960321936367
0.9658960321936367


In [11]:
print("Now we will evaluate brainageR")
print("This is available here: https://github.com/james-cole/brainageR")

Now we will evaluate brainageR
This is available here: https://github.com/james-cole/brainageR


In [12]:
# os.system('brainage_run_parallel.sh')
print("Run the above if not done elsewhere")

Run the above if not done elsewhere


In [13]:
metadata_test['brainageR']=0
for i, row in metadata_test.iterrows():
    file = glob.glob('/media/Data1/brainage_prediction/prediction/*'+str(int(row['biobank_id']))+'*')
    if len(file)>0:
        brainage = pd.read_csv(file[0])
        metadata_test.loc[i,'brainageR']= brainage['brain.predicted_age'].values[0]
    else:
        continue

In [14]:
print("brainageR Prediction")
print(mean_absolute_error(metadata_test.dropna()['age_at_scan'], metadata_test.dropna()['brainageR']))
print(r2_score(y_true=metadata_test.dropna()['age_at_scan'].values, y_pred = metadata_test.dropna()['brainageR'].values))

brainageR Prediction
5.155460842974922
0.10002060791804346
