In [2]:
import sys
import os
import torch
import glob
import pandas as pd
import numpy as np
import pingouin as pg
from age_prediction.models.\
    efficientnet_pytorch_3d import EfficientNet3D as EfNetB0
from age_prediction.metrics import MSE, MAE
from age_prediction.trainer import ModuleTrainer
from age_prediction.dataloader import MyDataLoader

In [4]:
def check_state_dict(ck, md):
    """Find the extra key prefix a checkpoint has compared to the model.

    Args:
        ck: Mapping of parameter names to values taken from the checkpoint.
        md: Mapping of parameter names to values taken from the model
            (e.g. ``model.state_dict()``).

    Returns:
        ``''`` when both mappings expose the same set of keys. Otherwise the
        text that precedes the first model key inside the first checkpoint
        key (``'module.'`` for ``'module.w'`` versus ``'w'``).
    """
    if ck.keys() != md.keys():
        first_ckpt_key = list(ck.keys())[0]
        first_model_key = list(md.keys())[0]
        # Whatever sits in front of the model key is the surplus prefix.
        return first_ckpt_key.split(first_model_key)[0]
    return ''

def correct_state_dict(ck, split_key):
    """Return a copy of ``ck`` with ``split_key`` removed from its keys.

    For every key, everything up to and including the first occurrence of
    ``split_key`` is dropped. Keys that do not contain ``split_key`` are
    kept unchanged.

    Args:
        ck: Mapping of parameter names to values (a checkpoint state dict).
        split_key: Text to strip, typically the prefix reported by
            ``check_state_dict``. An empty string means "nothing to strip".

    Returns:
        A new ``dict`` with the rewritten keys and the original values.
    """
    if not split_key:
        # ``str.split('')`` raises ValueError; an empty prefix is a no-op.
        return dict(ck)

    renamed = {}
    for key, value in ck.items():
        # partition() only cuts at the FIRST occurrence, so a key that
        # happens to contain the prefix again later is not truncated.
        _, found, remainder = key.partition(split_key)
        renamed[remainder if found else key] = value
    return renamed

def MAE_pred(y_pred, y_true):
    """Mean absolute error between two tensors, as a plain Python number.

    Args:
        y_pred: Tensor of predictions.
        y_true: Tensor of targets.

    Returns:
        The mean of ``|y_pred - y_true|`` extracted with ``.item()``.
    """
    criterion = torch.nn.L1Loss(reduction='mean')
    loss_value = criterion(y_pred, y_true)
    return loss_value.cpu().detach().numpy().item()

def MSE_pred(y_pred, y_true):
    """Mean squared error between two tensors, as a plain Python number.

    Args:
        y_pred: Tensor of predictions.
        y_true: Tensor of targets.

    Returns:
        The value of ``torch.nn.MSELoss()`` (default reduction) extracted
        with ``.item()``.
    """
    criterion = torch.nn.MSELoss()
    loss_value = criterion(y_pred, y_true)
    return loss_value.cpu().detach().numpy().item()

def delta_pred(y_pred, y_true):
    """Mean signed error ``y_pred - y_true``.

    A positive result means the predictions are, on average, larger than
    the targets; a negative result means they are smaller.
    """
    return np.mean(y_pred - y_true)

def pearson(y_pred, y_true):
    """Correlation coefficient ``r`` between predictions and targets.

    Delegates to ``pingouin.corr`` with its default settings and returns
    the value in the first row of the ``'r'`` column of its result.
    """
    # Use positional access: ``[0]`` on a Series whose index is not
    # integer-based relies on a positional fallback that pandas deprecated.
    return pg.corr(y_pred, y_true)['r'].iloc[0]


In [1]:
def get_snapshot_results(side, snapshot, dropout_rate, testfile):
    """Run a saved EfficientNet3D-B0 snapshot on a labelled file list.

    Args:
        side: Suffix forwarded to ``MyDataLoader`` and used to trim the image
            file names (this notebook passes ``'_L'`` or ``'_R'``).
        snapshot: Path of the checkpoint file read with ``torch.load``. It
            must contain the keys ``'state_dict'``, ``'optimizer'`` and
            ``'epoch'``.
        dropout_rate: Dropout rate the network is instantiated with.
        testfile: File name forwarded to ``MyDataLoader`` as ``test_file``.

    Returns:
        ``(preds, epoch)`` where ``preds`` is a DataFrame with the columns
        ``'Name'``, ``'True'`` and ``'Pred'`` and ``epoch`` is the value
        stored under ``'epoch'`` in the checkpoint.
    """
    # Build EfficientNet3D-B0 with one input channel and one output unit.
    model = EfNetB0.from_name("efficientnet-b0",
                              override_params={
                                  'num_classes': 1,
                                  'dropout_rate': dropout_rate
                              },
                              in_channels=1,
                              )
    device = torch.device('cpu')
    model = model.to(device)

    # The optimizer is only needed because ``trainer.compile`` asks for one;
    # its state is restored from the checkpoint below.
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=.256, alpha=0.9,
                                    eps=1e-08, momentum=0.9,
                                    weight_decay=0)

    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(snapshot, map_location=device)
    split_key = check_state_dict(checkpoint['state_dict'], model.state_dict())
    # check_state_dict returns '' (never None) when the keys already match,
    # so test truthiness; stripping an empty prefix would raise ValueError.
    if split_key:
        checkpoint['state_dict'] = correct_state_dict(checkpoint['state_dict'], split_key)

    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint['epoch']

    loss = torch.nn.L1Loss(reduction='mean')
    metrics = [MSE()]

    # Predict
    # NOTE(review): assumes predict_loader puts the model in eval mode so
    # dropout is disabled during inference - confirm in ModuleTrainer.
    trainer = ModuleTrainer(model)

    trainer.compile(loss=loss,
                    optimizer=optimizer,
                    metrics=metrics)

    dataloader = MyDataLoader(database='../datasets',
                              csv_data='database_split',
                              side=side,
                              batch=512,
                              data_aug=False,
                              test_file=testfile
                              )
    dataloader.prepare_data('test_label')
    dataloader.setup('test_label')

    pred = trainer.predict_loader(dataloader.testlabel_dataloader(),
                                  cuda_device=False)

    # Keep only the part of each file name that precedes the side suffix.
    imgs = [os.path.basename(img).split(side)[0] for img in dataloader.testpath]

    preds = pd.concat([pd.DataFrame(imgs),
                       pd.DataFrame(dataloader.testlabel),
                       pd.DataFrame(pred.detach().cpu().numpy())], axis=1)
    preds.columns = ['Name', 'True', 'Pred']

    return preds, epoch

def metrics_pipeline(preds):
    """Compute all evaluation metrics for one prediction table.

    Args:
        preds: DataFrame with a ``'Pred'`` and a ``'True'`` column.

    Returns:
        ``(mae, mse, delta, r)`` as produced by ``MAE_pred``, ``MSE_pred``,
        ``delta_pred`` and ``pearson`` respectively.
    """
    predicted, actual = preds['Pred'], preds['True']
    # The torch-based losses need tensors; the other two take the Series.
    predicted_t = torch.from_numpy(predicted.values)
    actual_t = torch.from_numpy(actual.values)
    return (
        MAE_pred(predicted_t, actual_t),
        MSE_pred(predicted_t, actual_t),
        delta_pred(predicted, actual),
        pearson(predicted, actual),
    )


In [3]:
def train_results(logger, epoch):
    """Read the training MAE and MSE of one epoch from a CSV log.

    Columns whose name contains ``'val'`` or ``'epoch'`` are ignored. The
    row with index label ``epoch - 1`` is used. Which column holds which
    metric depends on the log: if a remaining column name contains
    ``'mae'``, MAE is ``'mae_metric'`` and MSE is ``'loss'``; otherwise MAE
    is ``'loss'`` and MSE is ``'mse_metric'``.

    Args:
        logger: Path of the CSV log file.
        epoch: 1-based epoch number.

    Returns:
        ``(mae, mse)`` for that epoch.
    """
    log = pd.read_csv(logger)
    is_train_col = ~log.columns.str.contains('val|epoch')
    row = log.loc[epoch - 1, log.columns[is_train_col]]
    # order mae, mse
    logged_mae = any(row.index.str.contains('mae'))
    mae_key, mse_key = ('mae_metric', 'loss') if logged_mae else ('loss', 'mse_metric')
    return row[mae_key], row[mse_key]

In [5]:
def organize_results(res, type, dp, age=None, side=None):
    """Pack one set of metrics into a single-row DataFrame.

    Args:
        res: Sequence of metrics. ``res[0]`` is MAE and ``res[1]`` is MSE;
            unless ``type`` is ``'train'`` it must hold exactly four items
            ``(mae, mse, delta, r)``.
        type: Suffix appended to every metric column name
            (e.g. ``'train'``, ``'val'``, ``'test'``).
        dp: Value used as the index label of the row.
        age: When not ``None``, ``'side'`` and ``'age'`` columns are added
            in front of the metrics.
        side: Value stored in the ``'side'`` column (only used with ``age``).

    Returns:
        A one-row ``pd.DataFrame`` indexed by ``dp``.
    """
    row = {'side': side, 'age': age} if age is not None else {}
    row.update({'mae_' + type: res[0], 'mse_' + type: res[1]})
    if type != 'train':
        _, _, delta, r = res
        row.update({'delta_' + type: delta, 'r_' + type: r})

    return pd.DataFrame(row, index=[dp])


In [8]:
def get_results(side, age, date):
    """Evaluate every best-model checkpoint of one run and tabulate metrics.

    Checkpoints are discovered with the glob pattern
    ``outputs/ckpt_<date>*_best*``. For each one the validation and test
    files (``val_<age>.csv`` / ``test_<age>.csv``) are predicted and scored,
    and the training metrics of the checkpoint's epoch are read from the
    matching logger CSV.

    Args:
        side: Side suffix forwarded to ``get_snapshot_results``
            (this notebook passes ``'_L'`` or ``'_R'``).
        age: Age-range tag used in the val/test file names, e.g. ``'0-70'``.
        date: Date tag that follows ``ckpt_`` in the checkpoint file names.

    Returns:
        DataFrame with one row per checkpoint, indexed by dropout rate and
        sorted by that index. Empty when no checkpoint matches.
    """
    frames = []
    for ckpt in glob.glob('outputs/ckpt_'+date+'*_best*'):
        # Paths containing "e-5" are skipped (presumably a learning-rate
        # tag in the file name - confirm against the training script).
        if "e-5" in ckpt:
            continue
        # The dropout rate is encoded in the file name as "_dp<rate>_".
        dropout_rate = float(ckpt.split("_dp")[-1].split("_")[0])

        pred, epoch = get_snapshot_results(side, ckpt, dropout_rate, 'val_'+age+'.csv')
        val = organize_results(metrics_pipeline(pred), 'val', dropout_rate)

        pred, epoch = get_snapshot_results(side, ckpt, dropout_rate, 'test_'+age+'.csv')
        test = organize_results(metrics_pipeline(pred), 'test', dropout_rate)

        # The logger CSV shares the checkpoint's name with "ckpt" -> "logger".
        logger = ckpt.replace("ckpt", "logger").replace("_model_best.pth.tar", ".csv")
        train = organize_results(train_results(logger, epoch),
                                 'train', dropout_rate, age, side.split("_")[-1])

        frames.append(pd.concat([train, val, test], axis=1))

    if not frames:
        # pd.concat([]) raises; keep the "empty DataFrame" contract.
        return pd.DataFrame()
    # Concatenate once instead of growing a DataFrame inside the loop.
    return pd.concat(frames).sort_index()


In [10]:
# Evaluate the right- and left-side runs (identified by their date tags)
# for the 0-70 age range and stack them into one table, left side first.
res_R = get_results('_R', '0-70', '30-04-2021')
res_L = get_results('_L', '0-70', '26-04-2021')
results = pd.concat([res_L, res_R])
# Create the output folder first so saving cannot fail on a fresh checkout
# after all the evaluation work has already been done.
os.makedirs('predict_results', exist_ok=True)
results.to_csv('predict_results/metrics_0-70.csv')
results

Unnamed: 0,side,age,mae_train,mse_train,mae_val,mse_val,delta_val,r_val,mae_test,mse_test,delta_test,r_test
0.2,L,0-70,3.532041,22.587549,3.563695,26.085099,0.217828,0.94712,5.088891,56.416066,-0.537932,0.881968
0.3,L,0-70,3.048431,16.19577,3.749202,25.368216,0.352153,0.948762,5.489556,50.760203,-0.757924,0.892862
0.4,L,0-70,1.627977,4.194952,3.206993,31.542526,0.35003,0.935815,5.41279,53.233595,-1.148087,0.889608
0.5,L,0-70,3.674227,23.185775,3.755647,30.183296,0.420059,0.938986,5.714946,67.928661,-0.887353,0.867021
0.6,L,0-70,4.166114,30.867123,4.337871,31.042238,-0.283843,0.9365,5.338261,48.727451,-1.049087,0.900675
0.2,R,0-70,2.001938,6.555099,2.8192,17.932193,0.601597,0.965186,5.333038,48.321063,-0.379081,0.900897
0.3,R,0-70,1.347234,2.850212,3.151737,24.064317,-0.088164,0.952101,6.09233,64.90936,-0.982236,0.86511
0.4,R,0-70,3.550543,22.436674,4.398792,34.288252,0.790195,0.930736,6.0112,62.53787,-1.209486,0.869262
0.5,R,0-70,4.000049,28.301527,4.845988,41.022643,1.124578,0.918074,6.194015,66.464243,-0.868537,0.859197
