In [1]:
# Switch to the results directory: later cells glob and read '*_preds.csv' via relative paths.
# NOTE(review): the path is specific to this user's home directory — confirm it before running elsewhere.
%cd ~/RATER-C/results

/home/daved/RATER-C/results


In [2]:
from glob import glob

import pandas as pd
import numpy as np

from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix
from netcal.metrics import ECE

In [3]:
# Suffix shared by every per-model prediction file in the working directory.
preds_suffix = '_preds.csv'

# glob() returns matches in arbitrary (filesystem-dependent) order, so sort them to
# keep the model order -- and the row index of any table built from it -- reproducible.
files = sorted(glob('*' + preds_suffix))

# Drop only the trailing suffix. The glob pattern guarantees every name ends with it,
# and slicing (unlike splitting on the first occurrence) stays correct even if the
# suffix text also appears earlier in a file name.
model_names = [file[:-len(preds_suffix)] for file in files]

In [4]:
# Build a one-row metrics frame per model, then show all models ranked by AUC
# (ties broken by macro-F1), best first.

# Evaluation settings; loop-invariant, so defined once up front.
threshold = 0.5  # probability cut-off for turning `prob` into a hard 0/1 prediction
n_bins = 7       # number of bins passed to netcal's ECE

res = []

for model_name in model_names:

    test = pd.read_csv(model_name + '_preds.csv')

    y_true = test['target']
    probs = test['prob']
    preds = np.where(probs >= threshold, 1, 0)

    # Pin labels to [0, 1] so the matrix is always 2x2. Without it, a model whose
    # predictions and targets contain a single class yields a 1x1 matrix and the
    # four-way unpacking below raises.
    tn, fp, fn, tp = confusion_matrix(y_true, preds, labels = [0, 1]).ravel()
    auc = np.round(roc_auc_score(y_true, probs), 3)
    f1 = np.round(f1_score(y_true, preds), 3)
    macro_f1 = np.round(f1_score(y_true, preds, average = 'macro'), 3)

    ece = np.round(ECE(bins = n_bins).measure(np.array(probs), np.array(y_true)), 3)

    metrics = {
        # NOTE(review): every '_' becomes '/'; presumably file names encode
        # "org/model" as "org_model" -- confirm no model name contains its own '_'.
        'model': model_name.replace('_', '/'),
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'tp': tp,
        'f1': f1,
        'auc': auc,
        'macro_f1': macro_f1,
        'ece': ece
    }

    # A one-record frame keeps numeric dtypes per column; the previous
    # from_dict(orient='index').transpose() round-trip turned every column into object.
    out = pd.DataFrame([metrics])
    res.append(out)

# ignore_index gives each model a unique row label (its position in model_names)
# before sorting, so the displayed index shows the pre-sort order.
pd.concat(res, ignore_index = True).sort_values(['auc', 'macro_f1'], ascending = False)

Unnamed: 0,model,tn,fp,fn,tp,f1,auc,macro_f1,ece
10,microsoft/deberta-v3-large,27557,1632,1540,2575,0.619,0.911,0.782,0.076
3,microsoft/deberta-v3-base,28061,1128,1964,2151,0.582,0.904,0.765,0.018
8,microsoft/deberta-large,28207,982,2188,1927,0.549,0.893,0.748,0.079
12,microsoft/deberta-base,27595,1594,1830,2285,0.572,0.887,0.757,0.033
6,xlnet/xlnet-large-cased,27453,1736,1818,2297,0.564,0.882,0.752,0.077
13,microsoft/deberta-v3-small,27834,1355,1976,2139,0.562,0.877,0.753,0.031
5,google-bert/bert-large-uncased,28022,1167,2129,1986,0.547,0.877,0.745,0.027
0,FacebookAI/roberta-base,27036,2153,1645,2470,0.565,0.873,0.75,0.083
2,microsoft/deberta-v3-xsmall,27596,1593,1960,2155,0.548,0.872,0.744,0.036
1,albert/albert-base-v2,28354,835,2476,1639,0.497,0.871,0.721,0.01
