In [1]:
import sys
import os
import torch
import glob
import pandas as pd
import numpy as np
import pingouin as pg
from age_prediction.models.\
    efficientnet_pytorch_3d import EfficientNet3D as EfNetB0
from age_prediction.metrics import MSE, MAE
from age_prediction.trainer import ModuleTrainer
from age_prediction.dataloader import MyDataLoader

  **kwargs


In [2]:
def check_state_dict(ck, md):
    """Find the key prefix that distinguishes checkpoint keys from model keys.

    Args:
        ck: state dict loaded from a checkpoint.
        md: state dict of the freshly built model.

    Returns:
        '' when both dicts have the same set of keys. Otherwise, the text that
        precedes the model's first key inside the checkpoint's first key.
        For example, checkpoint key 'module.conv.weight' against model key
        'conv.weight' yields 'module.'. If the model key does not occur, the
        whole checkpoint key is returned.
    """
    if ck.keys() != md.keys():
        first_ckpt_key = list(ck.keys())[0]
        first_model_key = list(md.keys())[0]
        # Text before the first occurrence of the model key is the prefix.
        prefix, _sep, _rest = first_ckpt_key.partition(first_model_key)
        return prefix
    return ''

def correct_state_dict(ck, split_key):
    """Strip a leading ``split_key`` prefix from every key of a state dict.

    Args:
        ck: state dict (mapping of parameter name -> value) from a checkpoint.
        split_key: prefix to remove, typically the value returned by
            ``check_state_dict``. An empty string means there is nothing to strip.

    Returns:
        A new dict with the same values. The prefix is removed from keys that
        start with it; keys without the prefix are kept unchanged.
    """
    if not split_key:
        # str.split('') raises ValueError, so an empty prefix must not reach
        # the key rewriting below; return a plain copy instead.
        return dict(ck)
    n = len(split_key)
    # Slice off a *leading* prefix only. `k.split(split_key)[1]` truncated keys
    # in which the prefix occurs again, and raised IndexError on unprefixed keys.
    return {(k[n:] if k.startswith(split_key) else k): v for k, v in ck.items()}

def MAE_pred(y_pred, y_true):
    """Mean absolute error of ``y_pred`` against ``y_true`` as a Python float.

    Both arguments are tensors. The loss is detached and moved to the CPU
    before conversion, so it also works on tensors that require grad.
    """
    l1 = torch.nn.functional.l1_loss(y_pred, y_true, reduction='mean')
    return l1.detach().cpu().numpy().item()

def MSE_pred(y_pred, y_true):
    """Mean squared error of ``y_pred`` against ``y_true`` as a Python float.

    Uses the default 'mean' reduction. The loss is detached and moved to the
    CPU before being converted to a Python number.
    """
    squared = torch.nn.functional.mse_loss(y_pred, y_true)
    return squared.detach().cpu().numpy().item()

def delta_pred(y_pred, y_true):
    """Mean signed error (prediction minus target).

    A positive value means predictions are larger than the targets on average.
    """
    return np.mean(y_pred - y_true)

def pearson(y_pred, y_true):
    """Correlation coefficient r between predictions and targets via ``pg.corr``.

    ``pg.corr`` is called with its default method; its result is a DataFrame
    whose 'r' column holds the coefficient.
    """
    # Select the single row by position with .iloc. The result's index is not
    # integer-based, so `Series[0]` relied on deprecated positional fallback.
    return pg.corr(y_pred, y_true)['r'].iloc[0]


In [3]:
def get_snapshot_results(side, snapshot, dropout_rate, testfile):
    """Restore an EfficientNet3D-B0 snapshot and predict on one evaluation split.

    Args:
        side: side tag (e.g. '_L' / '_R'). It is forwarded to MyDataLoader and
            used to cut the image file names down to a subject name.
        snapshot: path to a checkpoint dict containing at least
            'state_dict', 'optimizer' and 'epoch'.
        dropout_rate: value passed to the model's ``override_params``.
        testfile: CSV name passed to MyDataLoader as ``test_file``.

    Returns:
        A tuple ``(preds, epoch)``. ``preds`` is a DataFrame with columns
        'Name', 'True' and 'Pred', one row per test image. ``epoch`` is the
        value stored under the checkpoint's 'epoch' key.
    """
    device = torch.device('cpu')
    model = EfNetB0.from_name("efficientnet-b0",
                              override_params={
                                'num_classes': 1,
                                'dropout_rate': dropout_rate
                              },
                              in_channels=1,
                              ).to(device)

    # Recreate the optimizer so that its saved state can be restored below.
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=.256, alpha=0.9,
                                    eps=1e-08, momentum=0.9,
                                    weight_decay=0)

    checkpoint = torch.load(snapshot, map_location=device)
    state_dict = checkpoint['state_dict']
    split_key = check_state_dict(state_dict, model.state_dict())
    # check_state_dict returns '' (not None) when the keys already match.
    # Only strip a real prefix; splitting on '' would raise ValueError.
    if split_key:
        state_dict = correct_state_dict(state_dict, split_key)

    model.load_state_dict(state_dict)
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint['epoch']
    # Inference mode (disables dropout). Harmless if the trainer also does it.
    model.eval()

    trainer = ModuleTrainer(model)
    trainer.compile(loss=torch.nn.L1Loss(reduction='mean'),
                    optimizer=optimizer,
                    metrics=[MSE()])

    dataloader = MyDataLoader(database='../datasets',
                              csv_data='database_split',
                              side=side,
                              batch=512,
                              data_aug=False,
                              test_file=testfile
                              )
    dataloader.prepare_data('test_label')
    dataloader.setup('test_label')

    pred = trainer.predict_loader(dataloader.testlabel_dataloader(),
                                  cuda_device=False)

    # Keep only the file name, and the part of it before the side tag.
    names = [path.split("/")[-1].split(side)[0] for path in dataloader.testpath]

    preds = pd.concat([pd.DataFrame(names),
                       pd.DataFrame(dataloader.testlabel),
                       pd.DataFrame(pred.detach().cpu().numpy())], axis=1)
    preds.columns = ['Name', 'True', 'Pred']

    return preds, epoch

def metrics_pipeline(preds):
    """Compute error and correlation metrics for a predictions frame.

    Args:
        preds: DataFrame with 'Pred' and 'True' columns, as produced by
            ``get_snapshot_results``.

    Returns:
        Tuple ``(mae, mse, delta, r)``. MAE and MSE are computed on tensors
        built from the column values. The mean signed error and Pearson r are
        computed on the columns themselves.
    """
    predicted, target = preds['Pred'], preds['True']
    pred_t = torch.from_numpy(predicted.values)
    true_t = torch.from_numpy(target.values)
    return (MAE_pred(pred_t, true_t),
            MSE_pred(pred_t, true_t),
            delta_pred(predicted, target),
            pearson(predicted, target))


In [4]:
def train_results(logger, epoch):
    """Read the training-set MAE and MSE logged at a given epoch.

    Args:
        logger: path (or file-like buffer) of the CSV training log.
        epoch: epoch number. Row ``epoch - 1`` of the log is used, which
            assumes one row per epoch starting at epoch 1 (TODO: confirm).

    Returns:
        Tuple ``(mae, mse)``. Columns whose names contain 'val' or 'epoch'
        are ignored. If any remaining column name contains 'mae', MAE is read
        from 'mae_metric' and MSE from 'loss'. Otherwise MAE is read from
        'loss' and MSE from 'mse_metric'.
    """
    log = pd.read_csv(logger)
    train_cols = log.columns[~log.columns.str.contains('val|epoch')]
    row = log.loc[epoch - 1, train_cols]
    logged_mae = any(row.index.str.contains('mae'))
    if not logged_mae:
        return row['loss'], row['mse_metric']
    return row['mae_metric'], row['loss']

In [5]:
def organize_results(res, type, dp, ep=None, age=None, side=None):
    """Turn a metrics tuple into a one-row DataFrame indexed by ``dp``.

    Args:
        res: ``(mae, mse)`` when ``type == 'train'``; otherwise exactly
            ``(mae, mse, delta, r)``.
        type: split label used as the column-name suffix (e.g. 'train',
            'val', 'test'). It shadows the builtin; the name is kept so that
            callers stay compatible.
        dp: row index value (callers pass the dropout rate).
        ep: optional epoch, stored in an 'epoch' column only when ``age``
            is given.
        age: optional age-group label. When given, 'side' and 'age' columns
            (plus 'epoch' if ``ep`` is set) come first.
        side: value of the 'side' column (used only when ``age`` is given).

    Returns:
        pandas DataFrame with a single row.
    """
    suffix = type
    row = {}
    if age is not None:
        row.update(side=side, age=age)
        if ep is not None:
            row['epoch'] = ep
    row['mae_' + suffix] = res[0]
    row['mse_' + suffix] = res[1]
    if suffix != 'train':
        # Non-training splits carry four metrics; unpacking enforces that.
        _mae, _mse, delta, r = res
        row['delta_' + suffix] = delta
        row['r_' + suffix] = r
    return pd.DataFrame(row, index=[dp])


In [11]:
def get_results(side, age, date):
    """Evaluate every checkpoint matching a run date, side and age group.

    For each matching file under ``outputs/``, the snapshot is evaluated on
    ``val_<age>.csv`` and ``test_<age>.csv``. The training metrics at the
    checkpoint's epoch are read from the logger CSV derived from the
    checkpoint's file name.

    Args:
        side: side tag such as '_L' or '_R'. A checkpoint matches when its path
            contains ``side + '_'``.
        age: age-group label (e.g. '70-100'), matched as a substring of the path.
        date: date string the checkpoint file names start with
            (``outputs/ckpt_<date>*``).

    Returns:
        DataFrame with one row per checkpoint, indexed by dropout rate and
        sorted by it (ties keep file-name order). Empty if nothing matches.
    """
    frames = []
    # sorted() makes the processing order, and hence the row order, reproducible.
    for ckpt in sorted(glob.glob('outputs/ckpt_' + date + '*')):
        if age not in ckpt or side + "_" not in ckpt:
            continue
        # The dropout rate follows '_dp' in the file name, up to the next '_'.
        dropout_rate = float(ckpt.split("_dp")[-1].split("_")[0])
        # The '_model…' segment (up to '_clr') is stripped when deriving the
        # logger file name from the checkpoint file name.
        delim = '_model' + ckpt.split("_model")[-1].split("_clr")[0]
        logger = (ckpt.replace("ckpt", "logger")
                      .replace(delim, "")
                      .replace('pth.tar', "csv"))

        pred, epoch = get_snapshot_results(side, ckpt, dropout_rate, 'val_' + age + '.csv')
        val = organize_results(metrics_pipeline(pred), 'val', dropout_rate)
        pred, epoch = get_snapshot_results(side, ckpt, dropout_rate, 'test_' + age + '.csv')
        test = organize_results(metrics_pipeline(pred), 'test', dropout_rate)
        train = organize_results(train_results(logger, epoch), 'train', dropout_rate,
                                 epoch, age, side.split("_")[-1])
        frames.append(pd.concat([train, val, test], axis=1))

    if not frames:
        return pd.DataFrame()
    # Concatenate once instead of re-growing a DataFrame on every iteration.
    return pd.concat(frames).sort_index(kind='mergesort')


In [14]:
# Evaluate the 17-06-2021 runs for the 70-100 age group on both sides;
# left-side rows are stacked before right-side rows.
run_date, age_group = '17-06-2021', '70-100'
res_R = get_results('_R', age_group, run_date)
res_L = get_results('_L', age_group, run_date)

results = pd.concat([res_L, res_R])
results

Unnamed: 0,side,age,epoch,mae_train,mse_train,mae_val,mse_val,delta_val,r_val,mae_test,mse_test,delta_test,r_test
0.2,L,70-100,80,1.617285,4.468334,4.362982,32.031043,0.871396,0.211786,3.625911,22.420263,-0.023753,0.386211
0.2,L,70-100,100,1.293031,3.455635,5.785848,49.316305,3.817043,0.283984,5.556869,46.254741,3.022629,0.237073
0.2,R,70-100,100,1.309929,2.962685,5.084044,42.371272,3.815696,0.321425,4.370785,30.700501,2.828327,0.361822
0.2,R,70-100,61,0.853568,1.442674,4.305229,28.46907,2.028653,0.367246,3.70564,22.235257,0.936852,0.362152


In [9]:
# Derive the file name from the age group(s) present in `results`, so that
# the name cannot disagree with the data; create the output folder if needed.
age_tag = '_'.join(sorted(results['age'].astype(str).unique()))
os.makedirs('predict_results', exist_ok=True)
results.to_csv('predict_results/metrics_' + age_tag + '.csv')

In [8]:
# Evaluate the 27-05-2021 runs for the 70-100 age group on both sides;
# left-side rows are stacked before right-side rows.
run_date, age_group = '27-05-2021', '70-100'
res_R = get_results('_R', age_group, run_date)
res_L = get_results('_L', age_group, run_date)
results = pd.concat([res_L, res_R])
results

Test size 151
Test size 151
Test size 151
Test size 151
Test size 151
Test size 151
Test size 151
Test size 151


Unnamed: 0,side,age,epoch,mae_train,mse_train,mae_val,mse_val,delta_val,r_val,mae_test,mse_test,delta_test,r_test
0.2,L,70-100,22,5.412542,45.607456,3.339176,17.546778,-0.240355,0.434532,3.476039,19.819592,-1.019994,0.32341
0.3,L,70-100,12,6.9911,70.692521,3.600096,18.229512,-0.513932,0.194589,3.520462,17.580773,-0.578576,0.262749
0.2,R,70-100,33,5.314347,45.219057,3.889771,22.377662,2.363966,0.336537,3.871664,21.090473,1.73473,0.283575
0.3,R,70-100,6,48.832,3930.888534,4.056024,22.660267,1.85986,-0.068574,3.885455,20.190156,1.700866,0.096561
