In [3]:
import os
import torch
import pandas as pd
import torch.nn as nn
import numpy as np
from age_prediction.models.\
    efficientnet_pytorch_3d import EfficientNet3D as EfNetB0
from age_prediction.dataloader import MyDataLoader
from age_prediction.trainer import ModuleTrainer
from age_prediction.metrics import MSE, MAE

import scipy.stats
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
def MAE_pred(y_pred, y_true):
    """Mean absolute error between two tensors, returned as a Python float.

    When both tensors hold the same number of elements they are flattened
    first, so an (N, 1) prediction is compared element-wise with an (N,)
    target instead of being broadcast to an (N, N) grid. Otherwise normal
    broadcasting applies (e.g. a single-element target).
    """
    if y_pred.numel() == y_true.numel():
        y_pred, y_true = y_pred.reshape(-1), y_true.reshape(-1)
    mae = nn.L1Loss(reduction='mean')(y_pred, y_true)
    return mae.detach().cpu().item()
def MSE_pred(y_pred, y_true):
    """Mean squared error between two tensors, returned as a Python float.

    When both tensors hold the same number of elements they are flattened
    first, so an (N, 1) prediction is compared element-wise with an (N,)
    target instead of being broadcast to an (N, N) grid. Otherwise normal
    broadcasting applies (e.g. a single-element target).
    """
    if y_pred.numel() == y_true.numel():
        y_pred, y_true = y_pred.reshape(-1), y_true.reshape(-1)
    mse = nn.MSELoss()(y_pred, y_true)
    return mse.detach().cpu().item()
def delta_pred(y_pred, y_true):
    """Mean signed error, mean(y_pred - y_true).

    A positive value means the predictions overshoot the targets on average,
    a negative value means they undershoot.
    """
    return np.mean(y_pred - y_true)

In [5]:
# Run configuration.
# gpu / dataParallel / data_aug are kept as the strings 'True'/'False' because
# the later cells read them as strings. Set gpu (and optionally dataParallel)
# to 'True' to run on CUDA.
gpu, dataParallel = 'False', 'False'
side = '_L'  # NOTE(review): predict_pipeline takes `side` as a parameter; this global looks unused
batch_size = 512
data_aug = 'False'
age_range = [0, 70]  # NOTE(review): not referenced by the visible cells

In [6]:
# Load effNet3D B0 with a single output (num_classes=1) and a single input
# channel, then move it to the configured device.
model = EfNetB0.from_name("efficientnet-b0",
                          override_params={'num_classes': 1},
                          in_channels=1)

# The flags are 'True'/'False' strings: compare their text instead of eval()-ing
# them (str() also lets real booleans through unchanged).
use_gpu = str(gpu) == 'True'
if use_gpu and not torch.cuda.is_available():
    print("GPU requested but CUDA is not available; falling back to cpu")
    use_gpu = False

if use_gpu:
    print("Using GPU")
    device = torch.device('cuda')
    cuda = True
    # Only wrap in DataParallel when it was requested AND there is more than one GPU.
    if str(dataParallel) == 'True' and torch.cuda.device_count() > 1:
        print(torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
else:
    print("Using cpu")
    cuda = False
    device = torch.device('cpu')

model = model.to(device)
print("Device", device, 'cuda', cuda)

Using cpu
Device cpu cuda False


In [7]:
def check_state_dict(ck, md):
    """Return the key prefix that the checkpoint state dict adds relative to the model's.

    If ``ck`` (checkpoint state dict) and ``md`` (model state dict) have the same
    key set, returns ''. Otherwise returns the text that precedes the first model
    key inside the first checkpoint key (presumably 'module.' when the checkpoint
    was saved from an nn.DataParallel-wrapped model).
    Assumes both dicts list their parameters in the same order.
    """
    if ck.keys() != md.keys():
        first_checkpoint_key = list(ck)[0]
        first_model_key = list(md)[0]
        prefix, *_ = first_checkpoint_key.split(first_model_key)
        return prefix
    return ''

def correct_state_dict(ck, split_key):
    """Return a copy of state dict ``ck`` with the leading ``split_key`` removed from each key.

    Only a leading occurrence is stripped, so a key that contains ``split_key``
    again further in keeps that part intact. Keys that do not start with
    ``split_key`` are kept unchanged, and an empty ``split_key`` (what
    ``check_state_dict`` returns when keys already match) yields an unchanged
    copy instead of raising. ``ck`` itself is not modified.
    """
    prefix_len = len(split_key)
    return {
        (key[prefix_len:] if key.startswith(split_key) else key): value
        for key, value in ck.items()
    }

In [8]:
def _snapshot_path(age, side, hold):
    """Pick the checkpoint file for an age group / side / hold-out combination.

    ``age == 0`` selects the 0-70 checkpoints (``hold`` is ignored there); any
    other value selects the 70-100 ones. ``side == '_L'`` selects the left-side
    checkpoint, anything else the right-side one. ``hold`` equal to '1' (or 1)
    selects hold-out 1, anything else hold-out 2.
    """
    if age == 0:
        if side == '_L':
            return 'outputs/results_0-70/ckpt_10-06-2021_age_[0-70]_RMS_wd_0_L_dp0.3_model_best_clr_[-5.2,-3.4].pth.tar'
        return 'outputs/results_0-70/ckpt_10-06-2021_age_[0-70]_RMS_wd_0_R_dp0.2_model_best_clr_[-5.2,-3.6].pth.tar'

    first_hold = str(hold) == '1'
    if side == '_L':
        if first_hold:
            return 'outputs/results_70-100/ckpt_17-06-2021_age_[70-100]_RMS_wd_0_L_dp0.2_model_best_clr_[-4.7,-3.3].pth.tar'
        return 'outputs/results_70-100/ckpt_17-06-2021_age_[70-100]_RMS_wd_0_L_hold2_dp0.2_model_best_clr_[-4.7,-3.3].pth.tar'
    if first_hold:
        return 'outputs/results_70-100/ckpt_15-06-2021_age_[70-100]_RMS_wd_0_R_dp0.2_model_best_clr_[-4.7,-3.3].pth.tar'
    return 'outputs/results_70-100/ckpt_17-06-2021_age_[70-100]_RMS_wd_0_R_hold2_dp0.2_model_best_clr_[-4.7,-3.3].pth.tar'


def predict_pipeline(testfile, age=70, side='_L', hold='1'):
    """Load the matching checkpoint into the global ``model`` and predict ``testfile``.

    Relies on notebook globals: ``model``, ``device``, ``cuda``, ``batch_size``
    and ``data_aug`` (a 'True'/'False' string).

    Args:
        testfile: CSV name passed to ``MyDataLoader`` as ``test_file``.
        age: 0 selects the 0-70 checkpoints; any other value the 70-100 ones.
        side: '_L' selects the left-side checkpoint, anything else the right.
            Also used to cut the image name out of each file path.
        hold: '1' selects hold-out 1, anything else hold-out 2 (70-100 only).

    Returns:
        DataFrame with columns ``Name`` (file name up to the ``side`` suffix),
        ``True`` (the loader's labels) and ``Pred`` (model outputs), one row
        per test image.

    Side effects:
        Overwrites the weights of the global ``model`` with the checkpoint's
        and prints checkpoint info and the prediction time.
    """
    import time

    dataloader = MyDataLoader(database='../datasets',
                              csv_data='database_split',
                              side=side,
                              batch=batch_size,
                              data_aug=str(data_aug) == 'True',
                              test_file=testfile)
    dataloader.prepare_data('test_label')
    dataloader.setup('test_label')

    # The optimizer is never stepped here; it is built only so its saved state
    # can be restored and handed to trainer.compile().
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=.256, alpha=0.9,
                                    eps=1e-08, momentum=0.9)

    # NOTE: torch.load unpickles the file; only load trusted local checkpoints.
    checkpoint = torch.load(_snapshot_path(age, side, hold), map_location=device)

    # check_state_dict returns '' when the keys already match. Only strip a
    # prefix when one was actually found (splitting on '' raises ValueError).
    split_key = check_state_dict(checkpoint['state_dict'], model.state_dict())
    if split_key:
        checkpoint['state_dict'] = correct_state_dict(checkpoint['state_dict'], split_key)

    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    optimizer.load_state_dict(checkpoint['optimizer'])

    print("Snapshot trained for {} epochs. Loss: {} and Val loss {}".format(
        checkpoint['epoch'], checkpoint['loss'], checkpoint['val_loss']))

    trainer = ModuleTrainer(model)
    trainer.compile(loss=nn.L1Loss(reduction='mean'),
                    optimizer=optimizer,
                    metrics=[MSE()])

    start_time = time.time()
    pred = trainer.predict_loader(dataloader.testlabel_dataloader(),
                                  cuda_device=cuda)
    print("Prediction time (s):", time.time() - start_time)

    names = [os.path.basename(path).split(side)[0] for path in dataloader.testpath]
    # Build the frame positionally: pd.concat(axis=1) aligns on index, so a
    # non-default index on testlabel could silently misalign rows.
    # Assumes one label and one prediction per image -- TODO confirm.
    return pd.DataFrame({
        'Name': names,
        'True': np.asarray(dataloader.testlabel).reshape(-1),
        'Pred': pred.detach().cpu().numpy().reshape(-1),
    })

def metrics_pipeline(preds):
    """Print MAE, MSE and the mean signed error (Pred - True) of a prediction table.

    ``preds`` is a DataFrame with numeric ``Pred`` and ``True`` columns, such as
    the one returned by ``predict_pipeline``.
    """
    predicted = torch.from_numpy(preds['Pred'].values)
    target = torch.from_numpy(preds['True'].values)
    print("MAE", MAE_pred(predicted, target))
    print("MSE", MSE_pred(predicted, target))
    print("Delta", delta_pred(preds['Pred'], preds['True']))


In [9]:
# 0-70 models: predict the val and test splits for both sides, print the
# metrics and save each prediction table.
latest_preds = {}
for split in ('val', 'test'):
    print("#" * 5, split.upper(), "#" * 5)
    for side_label, side_suffix in (('Left', '_L'), ('Right', '_R')):
        print("# {} - 0-70".format(side_label))
        preds = predict_pipeline('{}_0-70.csv'.format(split), age=0, side=side_suffix)
        metrics_pipeline(preds)
        preds.to_csv('predict_results/{}_0-70{}.csv'.format(split, side_suffix), index=False)
        latest_preds[side_suffix] = preds
# Expose the last (test-split) results under the usual names.
pred_left, pred_right = latest_preds['_L'], latest_preds['_R']


##### VAL #####
# Left - 0-70
Snapshot trained for 117 epochs.                Loss: 2.572100747562287 and Val loss 4.801818370819092
2.9620437622070312
MAE 4.801818963831121
MSE 39.5635064774985
Delta 0.48927336085926315
# Right - 0-70
Snapshot trained for 142 epochs.                Loss: 1.0101312005833207 and Val loss 4.239758014678955
2.9154791831970215
MAE 4.2397582799738105
MSE 33.75790230962084
Delta -0.0384121079878381
##### TEST #####
# Left - 0-70
Snapshot trained for 117 epochs.                Loss: 2.572100747562287 and Val loss 4.801818370819092
2.3737547397613525
MAE 4.235512146722703
MSE 36.33146664751724
Delta 0.17313959030878012
# Right - 0-70
Snapshot trained for 142 epochs.                Loss: 1.0101312005833207 and Val loss 4.239758014678955
1.6373393535614014
MAE 4.710462586539132
MSE 37.31764790949443
Delta -1.5365772519792826


In [10]:
# Hold-1 70-100 models: predict the val and test splits for both sides,
# print the metrics and save each prediction table.
print("Hold1")
latest_preds = {}
for split in ('val', 'test'):
    print("#" * 5, split.upper(), "#" * 5)
    for side_label, side_suffix in (('Left', '_L'), ('Right', '_R')):
        print("# {} - 70-100".format(side_label))
        preds = predict_pipeline('{}_70-100.csv'.format(split), age=70,
                                 side=side_suffix, hold='1')
        metrics_pipeline(preds)
        preds.to_csv('predict_results/{}_70-100{}.csv'.format(split, side_suffix), index=False)
        latest_preds[side_suffix] = preds
# Expose the last (test-split) results under the usual names.
pred_left, pred_right = latest_preds['_L'], latest_preds['_R']


Hold1
##### VAL #####
# Left - 70-100
Snapshot trained for 83 epochs.                Loss: 1.650392166317153 and Val loss 3.8597484455992843
9.441497087478638
MAE 3.8597486887546566
MSE 21.792410368585944
Delta 0.745377941005277
# Right - 70-100
Snapshot trained for 95 epochs.                Loss: 0.9459606564607452 and Val loss 3.9991477088422966
9.39162826538086
MAE 3.9991480442072382
MSE 26.271734447147598
Delta 1.63959557741683
##### TEST #####
# Left - 70-100
Snapshot trained for 83 epochs.                Loss: 1.650392166317153 and Val loss 3.8597484455992843
8.305693626403809
MAE 3.7301775496527054
MSE 20.468335884194406
Delta 0.769520993264306
# Right - 70-100
Snapshot trained for 95 epochs.                Loss: 0.9459606564607452 and Val loss 3.9991477088422966
9.213366746902466
MAE 3.6130778836888195
MSE 21.064230867210814
Delta 1.2800750227163962


In [11]:
# Hold-2 70-100 models. Hold-out 2 swaps the roles of the two split files:
# the hold-2 checkpoints' recorded val loss equals the MAE on test_70-100.csv
# (see the recorded output), so that file is this hold's VAL set and
# val_70-100.csv is its TEST set. Output files are named by role + '_h2'.
print("Hold2")
hold2_splits = (
    ('VAL', 'test_70-100.csv', 'val'),
    ('TEST', 'val_70-100.csv', 'test'),
)
latest_preds = {}
for header, input_csv, output_prefix in hold2_splits:
    print("#" * 5, header, "#" * 5)
    for side_label, side_suffix in (('Left', '_L'), ('Right', '_R')):
        print("# {} - 70-100".format(side_label))
        preds = predict_pipeline(input_csv, age=70, side=side_suffix, hold='2')
        metrics_pipeline(preds)
        preds.to_csv('predict_results/{}_70-100{}_h2.csv'.format(output_prefix, side_suffix),
                     index=False)
        latest_preds[side_suffix] = preds
# Expose the last (TEST-role) results under the usual names.
pred_left, pred_right = latest_preds['_L'], latest_preds['_R']

Hold2
##### VAL #####
# Left - 70-100
Snapshot trained for 80 epochs.                Loss: 1.617284610571002 and Val loss 3.625910752656444
4.613950490951538
MAE 3.6259114549649474
MSE 22.420263440176114
Delta -0.023752932516943954
# Right - 70-100
Snapshot trained for 61 epochs.                Loss: 0.8535681883674782 and Val loss 3.7056415949436214
4.559263229370117
MAE 3.705640153695417
MSE 22.23525691020929
Delta 0.9368515519906357
##### TEST #####
# Left - 70-100
Snapshot trained for 80 epochs.                Loss: 1.617284610571002 and Val loss 3.625910752656444
4.716772556304932
MAE 4.3629817230022505
MSE 32.03104338816646
Delta 0.871395682025429
# Right - 70-100
Snapshot trained for 61 epochs.                Loss: 0.8535681883674782 and Val loss 3.7056415949436214
4.4589011669158936
MAE 4.305228974803395
MSE 28.469069531672165
Delta 2.028652610526179


In [14]:
# Hold-1 70-100 models applied to the AD and MCI cohorts. A cohort header is
# printed so the two sets of results can be told apart in the output.
latest_preds = {}
for cohort in ('ad', 'mci'):
    print("# {}".format(cohort.upper()))
    for side_label, side_suffix in (('Left', '_L'), ('Right', '_R')):
        print("# {} - 70-100".format(side_label))
        preds = predict_pipeline('{}_70-100.csv'.format(cohort), age=70, side=side_suffix)
        metrics_pipeline(preds)
        preds.to_csv('predict_results/{}_70-100{}.csv'.format(cohort, side_suffix), index=False)
        latest_preds[side_suffix] = preds
# Expose the last (MCI) results under the usual names.
pred_left, pred_right = latest_preds['_L'], latest_preds['_R']

# Left - 70-100
Snapshot trained for 83 epochs.                Loss: 1.650392166317153 and Val loss 3.8597484455992843
MAE 5.720929962596254
MSE 45.80675199025366
Delta 3.692867939780203
# Right - 70-100
Snapshot trained for 95 epochs.                Loss: 0.9459606564607452 and Val loss 3.9991477088422966
MAE 6.815621627004523
MSE 68.31286158034939
Delta 5.956902302281137
# Left - 70-100
Snapshot trained for 83 epochs.                Loss: 1.650392166317153 and Val loss 3.8597484455992843
MAE 4.9495386374424175
MSE 38.10408053489799
Delta 2.187438271815083
# Right - 70-100
Snapshot trained for 95 epochs.                Loss: 0.9459606564607452 and Val loss 3.9991477088422966
MAE 6.6928059308177446
MSE 62.10426971233497
Delta 4.382946917165323


In [15]:
# Hold-2 70-100 models applied to the AD and MCI cohorts; results are saved
# with an '_h2' suffix. A cohort header is printed so the two sets of results
# can be told apart in the output.
latest_preds = {}
for cohort in ('ad', 'mci'):
    print("# {}".format(cohort.upper()))
    for side_label, side_suffix in (('Left', '_L'), ('Right', '_R')):
        print("# {} - 70-100".format(side_label))
        preds = predict_pipeline('{}_70-100.csv'.format(cohort), age=70,
                                 side=side_suffix, hold='2')
        metrics_pipeline(preds)
        preds.to_csv('predict_results/{}_70-100{}_h2.csv'.format(cohort, side_suffix),
                     index=False)
        latest_preds[side_suffix] = preds
# Expose the last (MCI) results under the usual names.
pred_left, pred_right = latest_preds['_L'], latest_preds['_R']

# Left - 70-100
Snapshot trained for 80 epochs.                Loss: 1.617284610571002 and Val loss 3.625910752656444
MAE 5.771290011611281
MSE 51.282814319498385
Delta 4.072410167347301
# Right - 70-100
Snapshot trained for 61 epochs.                Loss: 0.8535681883674782 and Val loss 3.7056415949436214
MAE 7.084966040798352
MSE 71.89335141858317
Delta 6.30101322009803
# Left - 70-100
Snapshot trained for 80 epochs.                Loss: 1.617284610571002 and Val loss 3.625910752656444
MAE 5.316323050084818
MSE 43.93353797816829
Delta 2.2249704550936875
# Right - 70-100
Snapshot trained for 61 epochs.                Loss: 0.8535681883674782 and Val loss 3.7056415949436214
MAE 6.5372819026627855
MSE 59.915671703704646
Delta 4.819177306885738
