### metrics

In [1]:
import json
import math
import os
import shutil

import numpy as np
import torch
from sklearn.metrics import f1_score, average_precision_score, classification_report
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve

In [2]:
class History:
    """Training/validation metric history that can be dumped to JSON.

    Every ``train_*`` / ``valid_*`` attribute is a list that the caller appends
    to (presumably once per epoch -- confirm against the training loop).
    """

    def __init__(self, save_results: str):
        self.train_loss, self.train_f1, self.train_aps, self.train_map = [], [], [], []
        self.valid_loss, self.valid_f1, self.valid_aps, self.valid_map = [], [], [], []
        # Directory that save() writes history_<mode>.json into.
        self.save_results = save_results

    def __repr__(self):
        """Show the latest value of each scalar metric, rounded to 4 places (NaN if empty)."""
        latest = {}
        # The *_aps lists are excluded: they are not single scalars per entry.
        for key in ('train_loss', 'train_f1', 'train_map', 'valid_loss', 'valid_f1', 'valid_map'):
            values = getattr(self, key)
            latest[key] = round(values[-1], 4) if len(values) else np.nan
        return str(latest)

    def save(self, mode='train'):
        """Write every ``train_*`` / ``valid_*`` attribute to ``<save_results>/history_<mode>.json``."""
        results = {key: getattr(self, key) for key in dir(self) if 'train_' in key or 'valid_' in key}
        # Use a context manager so the file is flushed and closed deterministically.
        with open(os.path.join(self.save_results, f'history_{mode}.json'), 'w') as f:
            json.dump(results, f)


class ComputeMetrics:
    """Classification metrics and error-analysis exports for one set of predictions.

    label: np.array[int], shape=(N,)
    pred_probs: np.array[float], shape=(N, cls)
    # for single score -> concat 1-p and p first
    # for unbounded score -> normalize first

    With ``threshold_optimization=True`` the last class (index cls-1) acts as a
    fallback: a sample gets the best of classes 0..cls-2 only if that best score
    reaches the tuned threshold, otherwise it is assigned class cls-1.
    """

    def __init__(self, label, pred_probs, threshold_optimization=False):
        self.label = label
        self.pred_probs = pred_probs
        self.classes = pred_probs.shape[-1]
        if not threshold_optimization:
            # Plain argmax; the printed 1/cls is only the uniform-probability
            # reference value, it is not applied as a threshold.
            self.pred_cls = pred_probs.argmax(axis=1)
            print(f"\ndefault_threshold={1/self.classes:.4f}")
        else:
            best_threshold = self.threshold_optimization()
            print(f"\nbest_threshold={best_threshold:.4f}")
            self.pred_cls = self._predict_with_threshold(best_threshold)

    def _predict_with_threshold(self, threshold):
        """Return per-sample class indices, falling back to class cls-1 below ``threshold``."""
        head = np.asarray(self.pred_probs)[:, :-1]
        # Compare only the non-fallback columns: the threshold was tuned on them,
        # and a confident fallback score must not yield a non-fallback label.
        return np.where(head.max(axis=1) >= threshold, head.argmax(axis=1), self.classes - 1)

    def threshold_optimization(self, strategy='f1'):
        """Return the mean, over classes 0..cls-2, of each class's best one-vs-rest threshold.

        strategy: only 'f1' is supported (threshold maximizing F1 on the PR curve).
        Raises ValueError for any other strategy.
        """
        if strategy != 'f1':
            raise ValueError(f"unsupported threshold strategy: {strategy!r}")
        best_threshold_cls = []
        for i in range(self.classes - 1):
            precision, recall, thresholds = precision_recall_curve(self.label == i, self.pred_probs[:, i])
            denom = precision + recall
            f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(denom), where=denom > 0)
            # precision/recall carry one extra final point (1, 0) with no threshold;
            # its F1 is 0, so argmax never selects it and indexing stays in range.
            best_threshold_cls.append(thresholds[f1.argmax()])
        return sum(best_threshold_cls) / (self.classes - 1)

    def get_f1(self):
        """Return the macro-averaged F1 of ``pred_cls`` against ``label``."""
        return f1_score(self.label, self.pred_cls, average='macro')

    def get_aps(self):
        """Return the one-vs-rest average precision of every class, as a list."""
        aps = [average_precision_score(self.label == i, self.pred_probs[:, i]) for i in range(self.classes)]
        return aps

    def get_cls_report(self):
        """Return sklearn's text classification report (undefined ratios reported as 0)."""
        return classification_report(self.label, self.pred_cls, zero_division=0)

    def get_aucs_specificities(self):
        """Return per-class one-vs-rest ROC AUCs and "specificities" as two lists.

        NOTE(review): the specificity here is 1 - mean(FPR) over the points that
        roc_curve returns, so it depends on how many thresholds are emitted; it is
        not TN/(TN+FP) at the predicted classes. Confirm this is the intended metric.
        """
        aucs, specificities = [], []
        for i in range(self.classes):
            aucs.append(roc_auc_score(self.label == i, self.pred_probs[:, i]))
            fpr, tpr, thresholds = roc_curve(self.label == i, self.pred_probs[:, i])
            specificities.append(1 - fpr.mean())
        return aucs, specificities

    def get_confusion(self, path_list=None, losses=None):
        """Group samples by (ground truth, prediction).

        Returns ``(confusion, confusion_cnt)``: ``confusion[gt][pred]`` is a list of
        ``(loss, path)`` tuples and ``confusion_cnt[gt][pred]`` is its length.
        A missing or empty ``path_list`` defaults to '' per sample, ``losses`` to -1.
        """
        n = len(self.label)
        # None defaults (not []) avoid shared mutable defaults; len() also works
        # for numpy arrays, whose truthiness is ambiguous.
        if path_list is None or len(path_list) == 0:
            path_list = [''] * n
        if losses is None or len(losses) == 0:
            losses = [-1] * n
        confusion = [[[] for _ in range(self.classes)] for _ in range(self.classes)]
        for gt, pdc, path, loss in zip(self.label, self.pred_cls, path_list, losses):
            confusion[gt][pdc].append((loss, path))
        confusion_cnt = [[len(cell) for cell in row] for row in confusion]
        return confusion, confusion_cnt

    def export_confusion(self, confusion, output_path, top_n=5):
        """Copy up to ``top_n`` files of each off-diagonal confusion cell.

        Files go to ``<output_path>/confusion/gt_<i>_pd_<j>/``; entries are taken in
        ascending ``(loss, path)`` order, i.e. lowest loss first.
        NOTE(review): confirm ascending order is intended -- highest-loss entries may
        be the more useful examples. Expects real file paths, not the '' placeholders
        get_confusion() uses when no path_list is given.
        """
        for i in range(self.classes):
            for j in range(self.classes):
                if i == j:
                    continue
                grid_path = os.path.join(output_path, 'confusion', f"gt_{i}_pd_{j}")
                for _, path in sorted(confusion[i][j])[:top_n]:
                    # Created lazily so empty cells leave no directory behind.
                    os.makedirs(grid_path, exist_ok=True)
                    shutil.copy(path, grid_path)

    def export_lowest_conf(self, path_list, output_path, top_n=5):
        """Copy the ``top_n`` files with the lowest top-class score to ``<output_path>/worst_imgs``."""
        prob_path_list = sorted(zip(self.pred_probs.max(axis=1), path_list))
        worst_path = os.path.join(output_path, "worst_imgs")
        os.makedirs(worst_path, exist_ok=True)
        for _, path in prob_path_list[:top_n]:
            shutil.copy(path, worst_path)