In [1]:
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score, roc_auc_score, log_loss

def get_scores(y_true, y_pred, y_pred_hard):
    """Evaluate one model's predictions with six classification metrics.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Per-class predicted scores; used for the macro one-vs-rest
            AUC and for the log loss.
        y_pred_hard: Predicted class labels; used for accuracy, macro F1 and
            both Cohen's kappa variants.

    Returns:
        dict: Metric values keyed by ``Accuracy``, ``AUC``, ``F1``,
        ``LogLoss``, ``Kappa`` and ``QuadKappa``.
    """
    accuracy = accuracy_score(y_true, y_pred_hard)
    # AUC is computed one-vs-rest per class and macro-averaged.
    macro_auc = roc_auc_score(y_true, y_pred, multi_class="ovr", average="macro")
    macro_f1 = f1_score(y_true, y_pred_hard, average="macro")
    # Two agreement scores: unweighted kappa and quadratically weighted kappa.
    plain_kappa = cohen_kappa_score(y_true, y_pred_hard)
    quadratic_kappa = cohen_kappa_score(y_true, y_pred_hard, weights="quadratic")
    cross_entropy = log_loss(y_true, y_pred)

    return {
        "Accuracy": accuracy,
        "AUC": macro_auc,
        "F1": macro_f1,
        "LogLoss": cross_entropy,
        "Kappa": plain_kappa,
        "QuadKappa": quadratic_kappa,
    }

# The MCFNet predictions and the test labels are both fetched from the
# HzFu/EyeQ repository on GitHub; the shared URL prefix is kept in one place.
EYEQ_REPO_RAW = 'https://raw.githubusercontent.com/HzFu/EyeQ/master'

# QuickQual predictions come from a CSV in the current working directory.
# In all three files the first column is used as the DataFrame index.
quickqual_results = pd.read_csv('quickqual_results.csv', index_col=0)

mcfnet_results = pd.read_csv(
    f'{EYEQ_REPO_RAW}/MCF_Net/result/DenseNet121_v3_v1.csv',
    index_col=0,
)
ground_truth = pd.read_csv(
    f'{EYEQ_REPO_RAW}/data/Label_EyeQ_test.csv',
    index_col=0,
)

# Ground-truth quality labels as a plain array.
y_true = ground_truth['quality'].to_numpy()

# Per-class score columns, in the order used for argmax below.
# NOTE(review): this order presumably matches the integer coding of
# `quality` (0=Good, 1=Usable, 2=Reject) — confirm against the label file.
class_columns = ['Good', 'Usable', 'Reject']

quickqual_preds = quickqual_results[class_columns].to_numpy()
# Hard prediction = index of the highest-scoring class, i.e. a single class
# label instead of one probability per class.
quickqual_preds_hard = quickqual_preds.argmax(axis=1)

mcfnet_preds = mcfnet_results[class_columns].to_numpy()
mcfnet_hardpreds = mcfnet_preds.argmax(axis=1)
# MCFNet's predictions sometimes sum to more than 1, so each row is rescaled
# to sum to 1; otherwise the sklearn metrics will complain.
# NB: for the histograms of predictions, we use the original predictions,
# not the normalized ones.
mcfnet_preds = mcfnet_preds / mcfnet_preds.sum(axis=1, keepdims=True)

# Score both models against the same ground truth and assemble the comparison
# table in a single constructor call. Building it directly from the per-model
# score dicts (rather than concatenating rows onto an empty frame) keeps the
# metric columns numeric so that `.round` applies to them, and avoids pandas'
# deprecated handling of empty frames in `pd.concat`.
metric_columns = ['Accuracy', 'AUC', 'F1', 'LogLoss', 'Kappa', 'QuadKappa']
model_scores = {
    'MCFNet': get_scores(y_true, mcfnet_preds, mcfnet_hardpreds),
    'QuickQual': get_scores(y_true, quickqual_preds, quickqual_preds_hard),
}

# One row per model (index = model name), one column per metric.
results_table = pd.DataFrame(
    list(model_scores.values()),
    index=list(model_scores.keys()),
    columns=metric_columns,
)
results_table.round(4)

Unnamed: 0,Accuracy,AUC,F1,LogLoss,Kappa,QuadKappa
MCFNet,0.88,0.9588,0.8606,0.3632,0.8017,0.8955
QuickQual,0.8863,0.9687,0.8675,0.3049,0.8107,0.9019
