In [1]:
import os
import torch
import pandas as pd
import torch.nn as nn
import numpy as np
from age_prediction.models.\
    efficientnet_pytorch_3d import EfficientNet3D as EfNetB0
from age_prediction.dataloader import MyDataLoader
from age_prediction.trainer import ModuleTrainer
from age_prediction.metrics import MSE, MAE

import scipy.stats
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def MAE_pred(y_pred, y_true):
    """Mean absolute error between two tensors, returned as a Python float."""
    abs_error = nn.functional.l1_loss(y_pred, y_true, reduction='mean')
    return abs_error.item()
def MSE_pred(y_pred, y_true):
    """Mean squared error between two tensors, returned as a Python float."""
    sq_error = nn.functional.mse_loss(y_pred, y_true)
    return sq_error.item()
def delta_pred(y_pred, y_true):
    """Mean signed error ``mean(y_pred - y_true)``.

    A positive value means the predictions are, on average, larger than the
    targets; a negative value means they are smaller.
    """
    return np.mean(y_pred - y_true)

In [3]:
# Inference parameters.
# The flags are the strings 'True' / 'False' (later cells parse them as text).
# Set gpu and dataParallel to 'True' to run on CUDA / multiple GPUs.
gpu, dataParallel = 'False', 'False'
side = '_L'          # side suffix used in image file names: '_L' or '_R'
batch_size = 512
data_aug = 'False'   # forwarded to MyDataLoader(data_aug=...) after parsing
age_range = [0, 70]  # not referenced by predict_pipeline (age is passed per call)

In [4]:
# Load effNet3D B0: single regression output, single-channel 3D input.
model = EfNetB0.from_name("efficientnet-b0",
                          override_params={'num_classes': 1},
                          in_channels=1)

# gpu / dataParallel are 'True'/'False' strings from the params cell.
# Compare instead of eval(): same result for those strings, also accepts bools,
# and never executes arbitrary text.
use_gpu = str(gpu) == 'True'
use_data_parallel = str(dataParallel) == 'True'

if use_gpu:
    # Fail here with a clear message instead of later inside model.to('cuda').
    if not torch.cuda.is_available():
        raise RuntimeError("gpu='True' but CUDA is not available")
    print("Using GPU")
    device = torch.device('cuda')
    cuda = True
    if use_data_parallel and torch.cuda.device_count() > 1:
        print(torch.cuda.device_count(), "GPUs!")
        # DataParallel wraps the model as `.module`, so its state_dict keys
        # gain a 'module.' prefix.
        model = nn.DataParallel(model)
else:
    print("Using cpu")
    cuda = False
    device = torch.device('cpu')

model = model.to(device)
print("Device", device, 'cuda', cuda)

Using cpu
Device cpu cuda False


In [5]:
def check_state_dict(ck, md):
    """Return the key prefix that should be stripped from checkpoint keys.

    Compares the keys of a checkpoint state_dict ``ck`` with those of the
    model state_dict ``md``. The typical mismatch is a checkpoint saved from an
    ``nn.DataParallel`` model, whose keys carry an extra ``'module.'`` prefix.

    Returns:
        The leading part of the first checkpoint key that precedes the first
        model key (e.g. ``'module.'``), when the first checkpoint key ends with
        the first model key. Otherwise ``''``. The empty string covers
        matching key sets, empty dicts, and mismatches that cannot be fixed by
        stripping a prefix (e.g. the *model* carries the extra prefix).
    """
    if ck.keys() == md.keys() or not ck or not md:
        return ''
    ck_first = next(iter(ck))
    md_first = next(iter(md))
    # Assumes both dicts list parameters in the same order (same architecture).
    if ck_first != md_first and ck_first.endswith(md_first):
        return ck_first[:-len(md_first)]
    return ''

def correct_state_dict(ck, split_key):
    """Return a copy of ``ck`` with the leading ``split_key`` removed from each key.

    Args:
        ck: checkpoint state_dict (parameter name -> value).
        split_key: prefix to strip, e.g. ``'module.'`` as returned by
            ``check_state_dict``.

    Returns:
        A new plain dict holding the same values.

    Notes:
        - An empty ``split_key`` yields an unchanged copy; ``str.split('')``
          would raise ``ValueError``.
        - Keys that do not start with ``split_key`` are kept as-is rather than
          raising ``IndexError``.
        - Only the leading occurrence is removed, so a key in which the prefix
          reappears later (``'module.b.module.w'``) is not truncated.
    """
    if not split_key:
        return dict(ck)
    cut = len(split_key)
    return {(key[cut:] if key.startswith(split_key) else key): value
            for key, value in ck.items()}

In [8]:
def _select_snapshot(age, side):
    """Return the checkpoint path: age 0 -> 0-70 models, else 70-100; '_L' -> left, else right."""
    if age == 0:
        if side == '_L':
            return 'outputs/ckpt_10-06-2021_age_[0-70]_RMS_wd_0_L_dp0.3_model_best_clr_[-5.2,-3.4].pth.tar'
        return 'outputs/ckpt_10-06-2021_age_[0-70]_RMS_wd_0_R_dp0.2_model_best_clr_[-5.2,-3.6].pth.tar'
    if side == '_L':
        return 'outputs/ckpt_27-05-2021_age_[70-100]_RMS_wd_0_L_dp0.2.pth.tar'
    return 'outputs/ckpt_27-05-2021_age_[70-100]_RMS_wd_0_R_dp0.2.pth.tar'


def _load_checkpoint(snapshot, optimizer):
    """Load model and optimizer state from ``snapshot`` into the global ``model``."""
    checkpoint = torch.load(snapshot, map_location=device)
    state_dict = checkpoint['state_dict']
    split_key = check_state_dict(state_dict, model.state_dict())
    # check_state_dict returns '' (a str, never None) when there is no prefix
    # to strip; only rename keys when it found one.
    if split_key:
        state_dict = correct_state_dict(state_dict, split_key)
    model.load_state_dict(state_dict)
    model.to(device)
    optimizer.load_state_dict(checkpoint['optimizer'])


def predict_pipeline(testfile, age=70, side='_L'):
    """Predict ages for every sample in ``testfile`` with one pretrained checkpoint.

    Args:
        testfile: CSV file name passed to ``MyDataLoader`` as ``test_file``.
        age: 0 selects the 0-70 checkpoint; any other value selects the
            70-100 checkpoint.
        side: '_L' selects the left checkpoint; any other value selects the
            right one. Also used to cut the side suffix off image file names.

    Returns:
        DataFrame with columns ['Name', 'True', 'Pred'], one row per sample.

    Side effects:
        Reads the notebook globals ``model``, ``device``, ``cuda``,
        ``batch_size`` and ``data_aug``. Loads the checkpoint weights into
        ``model`` in place.
    """
    dataloader = MyDataLoader(database='../datasets',
                              csv_data='database_split',
                              side=side,
                              batch=batch_size,
                              # data_aug is a 'True'/'False' string (params cell)
                              data_aug=str(data_aug) == 'True',
                              test_file=testfile)
    dataloader.prepare_data('test_label')
    dataloader.setup('test_label')

    # Needed to restore the checkpoint's optimizer state and for
    # trainer.compile(); presumably never stepped during prediction.
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=.256, alpha=0.9,
                                    eps=1e-08, momentum=0.9)
    snapshot = _select_snapshot(age, side)
    print(snapshot, age)
    _load_checkpoint(snapshot, optimizer)

    # Explicit inference mode (dropout off, BatchNorm running statistics).
    # Harmless if ModuleTrainer.predict_loader already switches modes itself.
    model.eval()
    trainer = ModuleTrainer(model.to(device))
    trainer.compile(loss=nn.L1Loss(reduction='mean'),
                    optimizer=optimizer,
                    metrics=[MSE()])
    pred = trainer.predict_loader(dataloader.testlabel_dataloader(),
                                  cuda_device=cuda)

    # Subject name = file name up to the side suffix.
    names = [path.split("/")[-1].split(side)[0] for path in dataloader.testpath]

    preds = pd.concat([pd.DataFrame(names),
                       pd.DataFrame(dataloader.testlabel),
                       pd.DataFrame(pred.detach().cpu().numpy())], axis=1)
    preds.columns = ['Name', 'True', 'Pred']
    return preds

def metrics_pipeline(preds):
    """Print MAE, MSE and mean signed error ("Delta") of preds['Pred'] vs preds['True']."""
    predicted, target = preds['Pred'], preds['True']
    pred_tensor = torch.from_numpy(predicted.values)
    true_tensor = torch.from_numpy(target.values)
    print("MAE", MAE_pred(pred_tensor, true_tensor))
    print("MSE", MSE_pred(pred_tensor, true_tensor))
    print("Delta", delta_pred(predicted, target))

In [9]:
# VAL and TEST metrics for the 0-70 checkpoints, left side then right side.
# Each prediction table is saved to predict_results/<split>_0-70<side>.csv.
latest_preds = {}
for split_header, split_name in (('VAL', 'val'), ('TEST', 'test')):
    print("#"*5, split_header, "#"*5)
    for side_header, side_suffix in (("# Left - 0-70", '_L'),
                                     ("# Right - 0-70", '_R')):
        print(side_header)
        side_preds = predict_pipeline(split_name + '_0-70.csv', age=0, side=side_suffix)
        metrics_pipeline(side_preds)
        side_preds.to_csv('predict_results/' + split_name + '_0-70' + side_suffix + '.csv',
                          index=False)
        latest_preds[side_suffix] = side_preds

# pred_left / pred_right end up holding the TEST predictions.
pred_left, pred_right = latest_preds['_L'], latest_preds['_R']


##### VAL #####
# Left - 0-70
outputs/ckpt_10-06-2021_age_[0-70]_RMS_wd_0_L_dp0.3_model_best_clr_[-5.2,-3.4].pth.tar 0
MAE 4.801818963831121
MSE 39.5635064774985
Delta 0.48927336085926315
# Right - 0-70
outputs/ckpt_10-06-2021_age_[0-70]_RMS_wd_0_R_dp0.2_model_best_clr_[-5.2,-3.6].pth.tar 0
MAE 4.2397582799738105
MSE 33.75790230962084
Delta -0.0384121079878381
##### TEST #####
# Left - 0-70
outputs/ckpt_10-06-2021_age_[0-70]_RMS_wd_0_L_dp0.3_model_best_clr_[-5.2,-3.4].pth.tar 0
MAE 4.799713423972905
MSE 54.37101925792592
Delta 0.8318147366545926
# Right - 0-70
outputs/ckpt_10-06-2021_age_[0-70]_RMS_wd_0_R_dp0.2_model_best_clr_[-5.2,-3.6].pth.tar 0
MAE 4.757565201160519
MSE 37.50495844376314
Delta -1.3441947298271706


In [10]:
# VAL and TEST metrics for the 70-100 checkpoints, left side then right side.
# Each prediction table is saved to predict_results/<split>_70-100<side>.csv.
latest_preds = {}
for split_header, split_name in (('VAL', 'val'), ('TEST', 'test')):
    print("#"*5, split_header, "#"*5)
    for side_header, side_suffix in (("# Left - 70-100", '_L'),
                                     ("# Right - 70-100", '_R')):
        print(side_header)
        side_preds = predict_pipeline(split_name + '_70-100.csv', age=70, side=side_suffix)
        metrics_pipeline(side_preds)
        side_preds.to_csv('predict_results/' + split_name + '_70-100' + side_suffix + '.csv',
                          index=False)
        latest_preds[side_suffix] = side_preds

# pred_left / pred_right end up holding the TEST predictions.
pred_left, pred_right = latest_preds['_L'], latest_preds['_R']


##### VAL #####
# Left - 70-100
outputs/ckpt_27-05-2021_age_[70-100]_RMS_wd_0_L_dp0.2.pth.tar 70
MAE 4.280558362543977
MSE 28.408468286749503
Delta 1.5427204056291388
# Right - 70-100
outputs/ckpt_27-05-2021_age_[70-100]_RMS_wd_0_R_dp0.2.pth.tar 70
MAE 4.780182672967974
MSE 36.942544893747694
Delta 3.2049762043731884
##### TEST #####
# Left - 70-100
outputs/ckpt_27-05-2021_age_[70-100]_RMS_wd_0_L_dp0.2.pth.tar 70
MAE 3.821388901464197
MSE 22.100976693503686
Delta 0.5262228858391974
# Right - 70-100
outputs/ckpt_27-05-2021_age_[70-100]_RMS_wd_0_R_dp0.2.pth.tar 70
MAE 4.3066763366295016
MSE 29.71222956930787
Delta 2.2448196916390737


In [14]:
# AD and MCI cohorts, scored with the 70-100 checkpoints (left, then right).
# Each prediction table is saved to predict_results/<cohort>_70-100<side>.csv
# immediately after it is computed.
latest_preds = {}
for cohort in ('ad', 'mci'):
    print("#"*5, cohort.upper(), "#"*5)
    for side_header, side_suffix in (("# Left - 70-100", '_L'),
                                     ("# Right - 70-100", '_R')):
        print(side_header)
        cohort_preds = predict_pipeline(cohort + '_70-100.csv', age=70, side=side_suffix)
        metrics_pipeline(cohort_preds)
        cohort_preds.to_csv('predict_results/' + cohort + '_70-100' + side_suffix + '.csv',
                            index=False)
        latest_preds[side_suffix] = cohort_preds

# pred_left / pred_right end up holding the MCI predictions.
pred_left, pred_right = latest_preds['_L'], latest_preds['_R']

# Left - 70-100
outputs/ckpt_27-05-2021_age_[70-100]_RMS_wd_0_L_dp0.2.pth.tar 70
MAE 6.910618394185481
MSE 68.35870272455229
Delta 6.142161085502953
# Right - 70-11
outputs/ckpt_27-05-2021_age_[70-100]_RMS_wd_0_R_dp0.2.pth.tar 70
MAE 7.617103145909651
MSE 80.87346241657352
Delta 6.949431336325321
# Left - 70-100
outputs/ckpt_27-05-2021_age_[70-100]_RMS_wd_0_L_dp0.2.pth.tar 70
MAE 6.434782841291086
MSE 61.38389611448999
Delta 4.6830616134096426
# Right - 70-11
outputs/ckpt_27-05-2021_age_[70-100]_RMS_wd_0_R_dp0.2.pth.tar 70
MAE 7.248179061099353
MSE 75.56198157082085
Delta 6.271172426991253
