In [1]:
import os
from pathlib import Path
from sklearn.metrics import ConfusionMatrixDisplay

from ba_dev.eval_helpers import *

set_custom_plot_style()

In [6]:
path_to_models = Path('/cfs/earth/scratch/kraftjul/BA/output/complete')

model_paths = [p for p in path_to_models.iterdir() if p.is_dir()]

In [7]:
model_paths[0]

PosixPath('/cfs/earth/scratch/kraftjul/BA/output/complete/resNet50_v1_no_pretrained_cross_val')

In [4]:
run = LoadRun(
        log_path=model_paths[0]
        )

In [16]:
run.info['output']['folds']

{0: {'class_weights': [0.40064895153045654,
   6.8093485832214355,
   1.0411982536315918,
   2.5204122066497803],
  'epochs_trained': 18,
  'test_fold': 0,
  'test_metrics': {'test_acc': 0.9892040491104126,
   'test_bal_acc': 0.9254578351974487,
   'test_loss': 0.07995453476905823},
  'training_time': 21294.0},
 1: {'class_weights': [0.40515443682670593,
   6.6237473487854,
   1.0000756978988647,
   2.625297784805298],
  'epochs_trained': 23,
  'test_fold': 1,
  'test_metrics': {'test_acc': 0.9912736415863037,
   'test_bal_acc': 0.9220613241195679,
   'test_loss': 0.07774166017770767},
  'training_time': 27161.0},
 2: {'class_weights': [0.4018717110157013,
   7.012234210968018,
   1.012830376625061,
   2.619832992553711],
  'epochs_trained': 20,
  'test_fold': 2,
  'test_metrics': {'test_acc': 0.9900743961334229,
   'test_bal_acc': 0.905320405960083,
   'test_loss': 0.06662610173225403},
  'training_time': 20677.0},
 3: {'class_weights': [0.3957626223564148,
   7.019892692565918,
   1.

In [None]:
bal_accuracy_testing = []

for fold in range(5):
    value = run.info['output']['folds'][fold]['test_metrics']['test_bal_acc']
    bal_accuracy_testing.append(value)

In [18]:
bal_accuracy_testing

[0.9254578351974487,
 0.9220613241195679,
 0.905320405960083,
 0.938396155834198,
 0.9035095572471619]

In [9]:
metric = run.calculate_metrics(
    metric='balanced_accuracy_score',
    set_selection='test',
    scope='img'
    )

In [19]:
predictions = run.get_predictions(fold=0, set_selection='test')


In [27]:
predictions

Unnamed: 0,idx,class_id,set,pred_id,probs,seq_id,correct,probs_max
24,24,1,test,1,"[0.0, 1.0, 0.0, 0.0]",3000003,True,1.0000
25,25,0,test,0,"[0.9993, 0.0006, 0.0, 0.0001]",3000004,True,0.9993
26,26,0,test,0,"[1.0, 0.0, 0.0, 0.0]",3000004,True,1.0000
27,27,0,test,0,"[1.0, 0.0, 0.0, 0.0]",3000004,True,1.0000
28,28,0,test,0,"[0.9998, 0.0002, 0.0, 0.0]",3000004,True,0.9998
...,...,...,...,...,...,...,...,...
430398,430398,0,test,0,"[1.0, 0.0, 0.0, 0.0]",4018683,True,1.0000
430399,430399,0,test,0,"[0.9962, 0.0, 0.0038, 0.0]",4018683,True,0.9962
430429,430429,3,test,3,"[0.0, 0.0, 0.0, 1.0]",4018688,True,1.0000
430430,430430,3,test,3,"[0.0, 0.0, 0.0, 1.0]",4018688,True,1.0000


In [21]:
y_true = predictions['class_id'].tolist()
y_pred = predictions['pred_id'].tolist()

In [23]:
from sklearn.metrics import balanced_accuracy_score
balanced_accuracy_score(y_true, y_pred)

0.9841004892822163

In [None]:
def agg_probs(ps):
    summed = [sum(col) for col in zip(*ps)]
    total = sum(summed)
    return [v/total for v in summed]

In [50]:
aggregated = (
    ds
    .groupby('seq_id')
    .agg(
        class_id = ('class_id', 'first'),
        set = ('set', 'first'),
        count = ('pred_id', 'size'),
        pred_id_majority = ('pred_id', lambda x: x.mode()),
        probs   = ('probs',   agg_probs)
    )
    .reset_index()
)

aggregated['prob_max'] = aggregated['probs'].apply(max)
aggregated['pred_id_max'] = aggregated['probs'].apply(lambda p: p.index(max(p)))


In [51]:
aggregated

Unnamed: 0,seq_id,class_id,set,count,pred_id_majority,probs,prob_max,pred_id_max
0,1000001,0,train,6,0,"[1.0, 0.0, 0.0, 0.0]",1.000000,0
1,1000002,1,train,3,1,"[3.333333333333333e-05, 0.9999666666666666, 0....",0.999967,1
2,1000003,1,test,3,1,"[0.3169438981299376, 0.6687222907430247, 0.014...",0.668722,1
3,1000004,1,test,2,"[0, 1]","[0.51025, 0.4813, 0.0078, 0.00065]",0.510250,0
4,1000005,1,train,2,1,"[0.0, 1.0, 0.0, 0.0]",1.000000,1
...,...,...,...,...,...,...,...,...
21781,7000005,2,train,2,2,"[0.0, 0.0, 1.0, 0.0]",1.000000,2
21782,7000006,2,val,16,2,"[0.0, 0.0, 1.0, 0.0]",1.000000,2
21783,7000007,2,train,5,2,"[0.0, 0.0, 1.0, 0.0]",1.000000,2
21784,7000008,2,train,4,2,"[0.0, 0.0, 1.0, 0.0]",1.000000,2
