In [1]:
import os
import torch
import pandas as pd
import torch.nn as nn
import numpy as np
from age_prediction.models.\
    efficientnet_pytorch_3d import EfficientNet3D as EfNetB0
from age_prediction.dataloader import MyDataLoader
from age_prediction.trainer import ModuleTrainer
from age_prediction.metrics import MSE, MAE

import scipy.stats
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def MAE_pred(y_pred, y_true):
    """Return the mean absolute error between two tensors as a Python float."""
    criterion = nn.L1Loss(reduction='mean')
    loss_value = criterion(y_pred, y_true)
    return loss_value.detach().cpu().item()
def MSE_pred(y_pred, y_true):
    """Return the mean squared error between two tensors as a Python float."""
    criterion = nn.MSELoss()
    loss_value = criterion(y_pred, y_true)
    return loss_value.detach().cpu().item()
def delta_pred(y_pred, y_true):
    """Return the mean signed error, i.e. the mean of ``y_pred - y_true``."""
    return np.mean(y_pred - y_true)

In [3]:
# params
# `gpu`, `dataParallel` and `data_aug` are kept as the strings 'True'/'False'
# because later cells pass them through eval(). Set `gpu` and `dataParallel`
# to 'True' to run on (multi-)GPU; the previous 'True' assignment was dead
# code that was overwritten on the very next line.
gpu, dataParallel = 'False', 'False'
side = '_L'
batch_size = 512
data_aug = 'False'
age_range = [0, 70]

In [4]:
# Instantiate the 3D EfficientNet-B0 with one input channel and one output.
model = EfNetB0.from_name(
    "efficientnet-b0",
    override_params={'num_classes': 1},
    in_channels=1,
)

# `gpu` / `dataParallel` hold 'True'/'False' strings, hence the eval().
if not eval(gpu):
    print("Using cpu")
    cuda = False
    device = torch.device('cpu')
else:
    print("Using GPU")
    device = torch.device('cuda')
    cuda = True
    # Wrap in DataParallel only when requested and more than one GPU is visible.
    if eval(dataParallel) and torch.cuda.device_count() > 1:
        print(torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

model = model.to(device)
print("Device", device, 'cuda', cuda)

Dropout 0.2
Using cpu
Device cpu cuda False


In [5]:
def check_state_dict(ck, md):
    """Return the key prefix the checkpoint has in addition to the model.

    Parameters
    ----------
    ck : mapping
        Checkpoint state dict (parameter name -> value).
    md : mapping
        Model state dict (parameter name -> value).

    Returns
    -------
    str
        ``''`` when both mappings have the same keys, or when no prefix can
        be derived (one mapping is empty, or the first checkpoint key does
        not end with the first model key). Otherwise the leading part of the
        first checkpoint key that precedes the first model key, e.g.
        ``'module.'``.
    """
    if ck.keys() == md.keys():
        return ''
    # Nothing to compare against: avoid indexing into an empty key list.
    if not ck or not md:
        return ''
    ck_first = next(iter(ck))
    md_first = next(iter(md))
    # Only report a prefix when the checkpoint key really is
    # "<prefix><model key>"; otherwise the caller would strip garbage.
    if md_first and ck_first.endswith(md_first):
        return ck_first[:len(ck_first) - len(md_first)]
    return ''

def correct_state_dict(ck, split_key):
    """Return a copy of ``ck`` with the leading ``split_key`` removed from keys.

    Parameters
    ----------
    ck : mapping
        Checkpoint state dict (parameter name -> value).
    split_key : str
        Prefix to strip from the start of each key (e.g. ``'module.'``).
        An empty prefix returns an unchanged shallow copy.

    Returns
    -------
    dict
        New dict with the same values; keys that start with ``split_key``
        lose that prefix, other keys are kept as they are.
    """
    # str.split('') raises ValueError, so an empty prefix must be a no-op.
    if not split_key:
        return dict(ck)
    prefix_len = len(split_key)
    # Strip only the leading occurrence: splitting on the prefix would cut
    # keys that contain the same text again further in.
    return {
        (key[prefix_len:] if key.startswith(split_key) else key): value
        for key, value in ck.items()
    }

In [11]:
def predict_pipeline(testfile, age=70, side='_L'):
    """Load a checkpoint for the given age range and side and predict on a CSV.

    Relies on the notebook-level ``model``, ``device``, ``cuda``,
    ``batch_size`` and ``data_aug``. The selected checkpoint's weights are
    loaded into the shared ``model`` in place.

    Parameters
    ----------
    testfile : str
        Passed to ``MyDataLoader`` as ``test_file``.
    age : int
        ``0`` selects the ``[0-70]`` checkpoints; any other value selects the
        ``[70-100]`` checkpoints.
    side : str
        ``'_L'`` selects the left checkpoint, anything else the right one.
        Also passed to ``MyDataLoader`` and used to trim image file names.

    Returns
    -------
    pandas.DataFrame
        Columns ``'Name'``, ``'True'`` and ``'Pred'``.
    """
    dataloader = MyDataLoader(database='../datasets',
                              csv_data='database_split',
                              side=side,
                              batch=batch_size,
                              # data_aug is a notebook constant ('True'/'False'),
                              # never external input, so eval() is tolerated here.
                              data_aug=eval(data_aug),
                              test_file=testfile
                              )
    dataloader.prepare_data('test_label')
    dataloader.setup('test_label')

    print("Test size", len(dataloader.test.inputs[0]))

    # The optimizer is only rebuilt so its state can be restored and handed
    # to trainer.compile(); no training step is taken in this function.
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=.256, alpha=0.9,
                                    eps=1e-08, momentum=0.9,
                                    )
    if age == 0:
        if side == '_L':
            snapshot = 'outputs/ckpt_27-04-2021_age_[0-70]_RMS_wd_0_L_dp0.6_model_best.pth.tar'
        else:
            snapshot = 'outputs/ckpt_16-04-2021_age_[0-70]_RMS_wd_0_R_dp0.5.pth.tar'
    else:
        if side == '_L':
            snapshot = 'outputs/ckpt_18-04-2021_age_[70-100]_RMS_wd_0_L_dp0.3.pth.tar'
        else:
            snapshot = 'outputs/ckpt_19-04-2021_age_[70-100]_RMS_wd_0_R_dp0.4.pth.tar'
    print(snapshot)

    checkpoint = torch.load(snapshot, map_location=device)
    split_key = check_state_dict(checkpoint['state_dict'], model.state_dict())
    print(list(checkpoint['state_dict'])[0])
    # check_state_dict returns '' (not None) when the keys already match, so
    # test truthiness; the old `is not None` check always ran the correction
    # and crashed on an empty prefix.
    if split_key:
        checkpoint['state_dict'] = correct_state_dict(checkpoint['state_dict'], split_key)
    print(list(checkpoint['state_dict'])[0])
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    optimizer.load_state_dict(checkpoint['optimizer'])

    print(side, age, 'epoch', checkpoint['epoch'],
          'loss', checkpoint['loss'],
          'val_loss', checkpoint['val_loss'])

    loss = nn.L1Loss(reduction='mean')
    metrics = [MSE()]

    # Predict
    trainer = ModuleTrainer(model.to(device))

    trainer.compile(loss=loss,
                    optimizer=optimizer,
                    metrics=metrics)

    print('cuda', cuda)

    pred = trainer.predict_loader(dataloader.testlabel_dataloader(),
                                  cuda_device=cuda)

    # Keep only the file name up to the side marker as the image name.
    imgs = dataloader.testpath
    imgs = [img.split("/")[-1].split(side)[0] for img in imgs]

    preds = pd.concat([pd.DataFrame(imgs),
                       pd.DataFrame(dataloader.testlabel),
                       pd.DataFrame(pred.detach().cpu().numpy())], axis=1)
    preds.columns = ['Name', 'True', 'Pred']

    return preds

def metrics_pipeline(preds):
    """Print MAE, MSE and mean signed error (Pred - True) for a prediction table.

    ``preds`` must provide the columns ``'Pred'`` and ``'True'``.
    """
    predicted = preds['Pred']
    actual = preds['True']
    print("MAE", MAE_pred(torch.from_numpy(predicted.values), torch.from_numpy(actual.values)))
    print("MSE", MSE_pred(torch.from_numpy(predicted.values), torch.from_numpy(actual.values)))
    print("Delta", delta_pred(predicted, actual))

In [12]:
# Just checking if loss is right!
# Runs the left 0-70 model on the validation CSV, prints the metrics and
# saves the prediction table. Presumably the printed MSE is compared by eye
# with the `val_loss` stored in the checkpoint.
# The commented-out lines below are the right-side and 70-100 variants.
print("# Left - 0-70")
pred_left = predict_pipeline('val_0-70.csv', age=0, side='_L')
metrics_pipeline(pred_left)
pred_left.to_csv('predict_results/val_0-70_L.csv', index=False)
# print("# Right - 0-70")
# pred_right = predict_pipeline('val_0-70.csv', age=0, side='_R')
# metrics_pipeline(pred_right)
# pred_right.to_csv('predict_results/val_0-70_R.csv', index=False)

# print("# Left - 70-100")
# pred_left = predict_pipeline('val_70-100.csv', age=70, side='_L')
# metrics_pipeline(pred_left)
# print("# Right - 70-100")
# pred_right = predict_pipeline('val_70-100.csv', age=70, side='_R')
# metrics_pipeline(pred_right)

# pred_left.to_csv('predict_results/val_70-100_L.csv', index=False)
# pred_right.to_csv('predict_results/val_70-100_R.csv', index=False)

# Left - 0-70
Preparing data
Setup data
Test size 67
outputs/ckpt_27-04-2021_age_[0-70]_RMS_wd_0_L_dp0.6_model_best.pth.tar
module._conv_stem.weight
_conv_stem.weight
_L 0 epoch 203 loss 30.867122573401094 val_loss 31.042200088500977
cuda False
MAE 4.337871380706332
MSE 31.042237827168353
Delta -0.2838427301663072


In [13]:
# Test results
# Runs the left 0-70 model on the test CSV, prints the metrics and saves the
# prediction table. The commented-out lines are the right-side and 70-100
# variants.
print("# Left - 0-70")
pred_left = predict_pipeline('test_0-70.csv', age=0, side='_L')
metrics_pipeline(pred_left)
pred_left.to_csv('predict_results/test_0-70_L.csv', index=False)
# print("# Right - 0-70")
# NOTE(review): this line used 'val_0-70.csv' while saving to a test_ file;
# changed to the test CSV on the assumption it was a copy-paste slip.
# pred_right = predict_pipeline('test_0-70.csv', age=0, side='_R')
# metrics_pipeline(pred_right)
# pred_right.to_csv('predict_results/test_0-70_R.csv', index=False)

# print("# Left - 70-100")
# pred_left = predict_pipeline('test_70-100.csv', age=70, side='_L')
# metrics_pipeline(pred_left)
# print("# Right - 70-100")
# pred_right = predict_pipeline('test_70-100.csv', age=70, side='_R')
# metrics_pipeline(pred_right)

# pred_left.to_csv('predict_results/test_70-100_L.csv', index=False)
# pred_right.to_csv('predict_results/test_70-100_R.csv', index=False)

# Left - 0-70
Preparing data
Setup data
Test size 67
outputs/ckpt_27-04-2021_age_[0-70]_RMS_wd_0_L_dp0.6_model_best.pth.tar
module._conv_stem.weight
_conv_stem.weight
_L 0 epoch 203 loss 30.867122573401094 val_loss 31.042200088500977
cuda False
MAE 5.33826094214596
MSE 48.72745083675871
Delta -1.0490870324889225


In [17]:
# Test
# Runs the 70-100 left and right models on test_exp.csv, prints the metrics
# and saves both prediction tables.
print("# Left - 70-100")
pred_left = predict_pipeline('test_exp.csv', age=70, side='_L')
metrics_pipeline(pred_left)
# Label fixed: it previously read "70-11", a typo for the 70-100 range.
print("# Right - 70-100")
pred_right = predict_pipeline('test_exp.csv', age=70, side='_R')
metrics_pipeline(pred_right)

pred_left.to_csv('predict_results/test_exp_L.csv', index=False)
pred_right.to_csv('predict_results/test_exp_R.csv', index=False)

# Left - 70-100
Preparing data
Setup data
Test size 151
outputs/ckpt_18-04-2021_age_[70-100]_RMS_wd_0_L_dp0.3.pth.tar
module._conv_stem.weight
_conv_stem.weight
_L 70 epoch 150 loss 1.9140096906470672 val_loss 3.928427219390869
cuda False
MAE 4.477334736198778
MSE 30.201666103676082
Delta 1.0984418856387113
# Right - 70-11
Preparing data
Setup data
Test size 151
outputs/ckpt_19-04-2021_age_[70-100]_RMS_wd_0_R_dp0.4.pth.tar
module._conv_stem.weight
_conv_stem.weight
_R 70 epoch 150 loss 2.959100668674303 val_loss 3.5481879711151123
cuda False
MAE 3.6924684941373913
MSE 22.789665610614335
Delta -1.5773571999657228


In [18]:
# AD
# Runs the 70-100 left and right models on ad_images.csv, prints the metrics
# and saves both prediction tables.
print("# Left - 70-100")
pred_left = predict_pipeline('ad_images.csv', age=70, side='_L')
metrics_pipeline(pred_left)
print("# Right - 70-100")
pred_right = predict_pipeline('ad_images.csv', age=70, side='_R')
metrics_pipeline(pred_right)
pred_left.to_csv('predict_results/ad_images_L.csv', index=False)
pred_right.to_csv('predict_results/ad_images_R.csv', index=False)

# Left - 70-100
Preparing data
Setup data
Test size 209
outputs/ckpt_18-04-2021_age_[70-100]_RMS_wd_0_L_dp0.3.pth.tar
module._conv_stem.weight
_conv_stem.weight
_L 70 epoch 150 loss 1.9140096906470672 val_loss 3.928427219390869
cuda False
MAE 5.699224357057416
MSE 48.2593426091945
Delta 4.396512322676809
# Right - 70-100
Preparing data
Setup data
Test size 209
outputs/ckpt_19-04-2021_age_[70-100]_RMS_wd_0_R_dp0.4.pth.tar
module._conv_stem.weight
_conv_stem.weight
_R 70 epoch 150 loss 2.959100668674303 val_loss 3.5481879711151123
cuda False
MAE 4.559458645907315
MSE 31.092469865173967
Delta 2.319664862965853


In [19]:
# MCI
# Runs the 70-100 left and right models on mci_images.csv, prints the metrics
# and saves both prediction tables.
print("# Left - 70-100")
pred_left = predict_pipeline('mci_images.csv', age=70, side='_L')
metrics_pipeline(pred_left)
# Label fixed: it previously read "70-11", a typo for the 70-100 range.
print("# Right - 70-100")
pred_right = predict_pipeline('mci_images.csv', age=70, side='_R')
metrics_pipeline(pred_right)
pred_left.to_csv('predict_results/mci_images_L.csv', index=False)
pred_right.to_csv('predict_results/mci_images_R.csv', index=False)

# Left - 70-100
Preparing data
Setup data
Test size 251
outputs/ckpt_18-04-2021_age_[70-100]_RMS_wd_0_L_dp0.3.pth.tar
module._conv_stem.weight
_conv_stem.weight
_L 70 epoch 150 loss 1.9140096906470672 val_loss 3.928427219390869
cuda False
MAE 5.243145447993183
MSE 43.625055629645374
Delta 3.58503745029647
# Right - 70-11
Preparing data
Setup data
Test size 251
outputs/ckpt_19-04-2021_age_[70-100]_RMS_wd_0_R_dp0.4.pth.tar
module._conv_stem.weight
_conv_stem.weight
_R 70 epoch 150 loss 2.959100668674303 val_loss 3.5481879711151123
cuda False
MAE 4.732869997062529
MSE 35.33636928538651
Delta 1.4572912938091382


In [13]:
# FIXME: `predict` and `metrics` are not defined in any visible cell of this
# notebook (the recorded run of this cell ended in a NameError). They must be
# restored before this cell can run on a fresh kernel.
preds_left, preds_right, ens = predict('val_0-70.csv', age=0)
metrics(preds_left, preds_right, ens)
# sns.regplot(x="PredLR", y="TrueR", data=ens)
# Fit estimated age (PredR) against chronological age (TrueR), right side.
x, y = preds_right['TrueR'], preds_right['PredR']
slope, intercept, r, p, stderr = scipy.stats.linregress(x, y)
# intercept = 0
# slope = 1
# The label was previously assigned twice with identical text; once is enough.
line = f'Regression line: y={intercept:.2f}+{slope:.2f}x, r={r:.2f}'
fig, ax = plt.subplots()
ax.plot(x, y, linewidth=0, marker='s', label='Data points')
ax.plot(x, intercept + slope * x, label=line)
ax.set_xlabel('Chronological age')
ax.set_ylabel('Estimated age')
ax.legend(facecolor='white')
plt.show()

NameError: name 'predict' is not defined

In [None]:
# Per-sample absolute error of the right-side predictions, plotted against
# the true value. `preds_right` comes from an earlier `predict(...)` cell.
maeval = abs(preds_right['PredR'].values - preds_right['TrueR'].values)
plott = pd.concat([pd.DataFrame(maeval), preds_right['TrueR']], axis=1)
plott.columns = ['MAE', 'TrueR']
sns.scatterplot(x="TrueR", y="MAE", data=plott)

In [None]:
# FIXME: `predict` and `metrics` are not defined in any visible cell of this
# notebook; this cell cannot run on a fresh kernel until they are restored.
preds_left, preds_right, ens = predict('val_exp.csv')
metrics(preds_left, preds_right, ens)
# Regression plot of the 'PredLR' column against 'TrueR'.
sns.regplot(x="PredLR", y="TrueR", data=ens)

In [None]:
# Fit the 'PredLR' predictions against 'TrueR' and plot points plus fit line.
# `ens` comes from an earlier `predict(...)` cell.
x, y = ens['TrueR'], ens['PredLR']
slope, intercept, r, p, stderr = scipy.stats.linregress(x, y)
# intercept = 0
# slope = 1
# The label was previously assigned twice with identical text; once is enough.
line = f'Regression line: y={intercept:.2f}+{slope:.2f}x, r={r:.2f}'
fig, ax = plt.subplots()
ax.plot(x, y, linewidth=0, marker='s', label='Data points')
ax.plot(x, intercept + slope * x, label=line)
ax.set_xlabel('Chronological age')
ax.set_ylabel('Estimated age')
ax.legend(facecolor='white')
plt.show()

In [None]:
def bias_corr(pred, a=0.41, b=48.4):
    """Apply a linear bias correction: ``(pred - b) / a``.

    Parameters
    ----------
    pred : float or array-like supporting arithmetic
        Raw prediction(s).
    a : float, optional
        Divisor of the correction; defaults to the notebook's original
        hard-coded value 0.41.
    b : float, optional
        Offset subtracted before dividing; defaults to the notebook's
        original hard-coded value 48.4.

    Returns
    -------
    Same kind as ``pred``
        The corrected prediction(s).
    """
    # NOTE(review): a and b look like the slope and intercept of a
    # prediction-vs-age regression fit; confirm where they were estimated.
    return (pred - b) / a

In [None]:
# Append the bias-corrected prediction (bias_corr applied to 'PredLR') as
# column 'BC'. `ens` comes from the preceding `predict(...)` cell.
enss = pd.concat([ens, pd.DataFrame(bias_corr(ens['PredLR'].values), columns=['BC'])], axis=1)

# NOTE(review): this is the signed error (BC - TrueR), not an absolute error,
# although the column is labelled 'MAE' below. Confirm that is intended.
maeval = enss['BC'].values - enss['TrueR'].values
plott = pd.concat([pd.DataFrame(maeval), enss['TrueR']], axis=1)
plott.columns = ['MAE', 'TrueR']
sns.scatterplot(x="TrueR", y="MAE", data=plott)


# Metrics of the bias-corrected predictions against the true values.
print("MAE", MAE_pred(torch.from_numpy(enss['BC'].values), torch.from_numpy(enss['TrueR'].values)))
print("MSE", MSE_pred(torch.from_numpy(enss['BC'].values), torch.from_numpy(enss['TrueR'].values)))
print("Delta", delta_pred(enss['BC'], enss['TrueR']))

In [None]:
# FIXME: `predict` and `metrics` are not defined in any visible cell of this
# notebook; this cell cannot run on a fresh kernel until they are restored.
preds_left, preds_right, ens = predict('test_exp.csv')
metrics(preds_left, preds_right, ens)
# Regression plot of the 'PredLR' column against 'TrueR'.
sns.regplot(x="PredLR", y="TrueR", data=ens)

In [None]:
# Append the bias-corrected prediction (bias_corr applied to 'PredLR') as
# column 'BC'. `ens` comes from the preceding `predict(...)` cell.
enss = pd.concat([ens, pd.DataFrame(bias_corr(ens['PredLR'].values), columns=['BC'])], axis=1)

# NOTE(review): this is the signed error (BC - TrueR), not an absolute error,
# although the column is labelled 'MAE' below. Confirm that is intended.
maeval = enss['BC'].values - enss['TrueR'].values
plott = pd.concat([pd.DataFrame(maeval), enss['TrueR']], axis=1)
plott.columns = ['MAE', 'TrueR']
sns.scatterplot(x="TrueR", y="MAE", data=plott)


# Metrics of the bias-corrected predictions against the true values.
print("MAE", MAE_pred(torch.from_numpy(enss['BC'].values), torch.from_numpy(enss['TrueR'].values)))
print("MSE", MSE_pred(torch.from_numpy(enss['BC'].values), torch.from_numpy(enss['TrueR'].values)))
print("Delta", delta_pred(enss['BC'], enss['TrueR']))

In [None]:
# FIXME: `predict` and `metrics` are not defined in any visible cell of this
# notebook; this cell cannot run on a fresh kernel until they are restored.
preds_left, preds_right, ens = predict('ad_images.csv')
metrics(preds_left, preds_right, ens)
# Regression plot of the 'PredLR' column against 'TrueR'.
sns.regplot(x="PredLR", y="TrueR", data=ens)

In [None]:
# Append the bias-corrected prediction (bias_corr applied to 'PredLR') as
# column 'BC'. `ens` comes from the preceding `predict(...)` cell.
enss = pd.concat([ens, pd.DataFrame(bias_corr(ens['PredLR'].values), columns=['BC'])], axis=1)

# NOTE(review): this is the signed error (BC - TrueR), not an absolute error,
# although the column is labelled 'MAE' below. Confirm that is intended.
maeval = enss['BC'].values - enss['TrueR'].values
plott = pd.concat([pd.DataFrame(maeval), enss['TrueR']], axis=1)
plott.columns = ['MAE', 'TrueR']
sns.scatterplot(x="TrueR", y="MAE", data=plott)

# Metrics of the bias-corrected predictions against the true values.
print("MAE", MAE_pred(torch.from_numpy(enss['BC'].values), torch.from_numpy(enss['TrueR'].values)))
print("MSE", MSE_pred(torch.from_numpy(enss['BC'].values), torch.from_numpy(enss['TrueR'].values)))
print("Delta", delta_pred(enss['BC'], enss['TrueR']))

In [None]:
# FIXME: `predict` and `metrics` are not defined in any visible cell of this
# notebook; this cell cannot run on a fresh kernel until they are restored.
preds_left, preds_right, ens = predict('mci_images.csv')
metrics(preds_left, preds_right, ens)
# Regression plot of the 'PredLR' column against 'TrueR'.
sns.regplot(x="PredLR", y="TrueR", data=ens)

In [None]:
# Append the bias-corrected prediction (bias_corr applied to 'PredLR') as
# column 'BC'. `ens` comes from the preceding `predict(...)` cell.
enss = pd.concat([ens, pd.DataFrame(bias_corr(ens['PredLR'].values), columns=['BC'])], axis=1)

# NOTE(review): this is the signed error (BC - TrueR), not an absolute error,
# although the column is labelled 'MAE' below. Confirm that is intended.
maeval = enss['BC'].values - enss['TrueR'].values
plott = pd.concat([pd.DataFrame(maeval), enss['TrueR']], axis=1)
plott.columns = ['MAE', 'TrueR']
sns.scatterplot(x="TrueR", y="MAE", data=plott)

# Metrics of the bias-corrected predictions against the true values.
print("MAE", MAE_pred(torch.from_numpy(enss['BC'].values), torch.from_numpy(enss['TrueR'].values)))
print("MSE", MSE_pred(torch.from_numpy(enss['BC'].values), torch.from_numpy(enss['TrueR'].values)))
print("Delta", delta_pred(enss['BC'], enss['TrueR']))

In [None]:
# `scipy.stats` is already imported in the notebook's first cell, so the
# redundant mid-notebook import was dropped.
x, y = ens['TrueR'], ens['PredLR']
slope, intercept, r, p, stderr = scipy.stats.linregress(x, y)
# Override the fit so the plotted line is the identity y = x; only `r` from
# the regression is kept.
# NOTE(review): the legend still says "Regression line" for this identity
# line. Confirm that wording is intended.
intercept = 0
slope = 1
# The label was previously assigned twice with identical text; once is enough.
line = f'Regression line: y={intercept:.2f}+{slope:.2f}x, r={r:.2f}'
fig, ax = plt.subplots()
ax.plot(x, y, linewidth=0, marker='s', label='Data points')
ax.plot(x, intercept + slope * x, label=line)
ax.set_xlabel('True')
ax.set_ylabel('Pred')
ax.legend(facecolor='white')
plt.show()

In [19]:

# Load the AD image list; later cells read its 'Image Filename' column.
ad_inp = pd.read_csv('../csv_data/ad_images.csv')

# ad_eval

In [22]:
# For each side, print every image listed in `ad_inp` that is absent from the
# matching 'ad<side>.csv' evaluation file.
# NOTE: the loop reuses the notebook-level names `side` and `ad_eval`; after
# this cell they hold the values from the '_R' iteration.
for side in ['_L', '_R']:
    ad_eval = pd.read_csv('ad' + side + '.csv')
    # Column '0' holds a path; keep the file name up to the side marker.
    ad_eval['Image Filename'] = ad_eval['0'].apply(
        lambda x: x.split('/')[-1].split(side)[0]
    )
    evaluated_names = ad_eval['Image Filename'].values
    for img in ad_inp['Image Filename']:
        if not (img in evaluated_names):
            print(img)

016_S_5032_ADNI2_Month_6-New_Pt
016_S_5032_ADNI2_Month_6-New_Pt


In [23]:
# For each side, print the files in the ADNI folder that belong to a listed
# AD image and match the one image name reported missing above.
for side in ['_L', '_R']:
    for x in os.listdir('../datasets/ADNI/'):
        if side not in x:
            continue
        if x.split(side)[0] not in ad_inp['Image Filename'].values:
            continue
        if '016_S_5032_ADNI2_Month_6-New_Pt' in x.split(side)[0]:
            print(x)

In [11]:
# Show the rows of `ad_eval` for one specific image name. `ad_eval` is
# whatever the last run of the missing-image loop left behind.
ad_eval[ad_eval['Image Filename']=='024_S_4223_ADNI2_Month_6-New_Pt']

Unnamed: 0,0,Image Filename
