In [2]:
from typing import Any
import numpy as np

In [5]:
class ConfMatrix:
    """Accumulate a confusion matrix and derive per-class classification metrics.

    Rows are indexed by ground-truth label and columns by estimated label, both
    in the order of ``self.label``. Example (12 samples, labels [0, 1, 2])::

                 ests
        gts   [[3. 0. 0.]     row sums (gts per class):  [3. 3. 6.]
               [0. 1. 2.]
               [2. 1. 3.]]    col sums (ests per class): [5. 2. 5.]
    """

    def __init__(self, label: list[Any]):
        """Create an empty confusion matrix.

        Args:
            label: the class labels, in the order used for matrix rows/columns.
                Any iterable of unique hashable values is accepted; it is stored
                as a list so the layout is fixed and indexable. If a ``set`` is
                passed, the order is whatever the set iterates in, so pass a
                sorted list for a deterministic layout.

        Raises:
            AssertionError: if ``label`` contains duplicates.
        """
        label = list(label)
        assert len(label) == len(set(label)), "label contains duplicate data!"
        self.label = label
        self.label2index = {k: i for i, k in enumerate(self.label)}
        # np.zeros defaults to float64, so counts are stored as floats
        self.cm = np.zeros((len(self.label), len(self.label)))

    def add(self, gts: list[Any], ests: list[Any]):
        """Count paired (ground truth, estimation) samples into the matrix.

        Every label is validated before anything is counted, so an invalid
        label leaves the matrix unchanged instead of partially updated.

        Args:
            gts: ground-truth labels.
            ests: estimated labels, paired element-wise with ``gts``.

        Raises:
            AssertionError: if the sequences differ in length or contain a
                label that was not passed to the constructor.
        """
        assert len(gts) == len(ests), \
            "gts and ests have different lengths: {} vs {}".format(len(gts), len(ests))
        index_pairs = []
        for gt, est in zip(gts, ests):
            # dict membership is O(1); scanning the label list would be O(#labels)
            assert gt in self.label2index, "invalid gt label: {}".format(gt)
            assert est in self.label2index, "invalid est label: {}".format(est)
            index_pairs.append((self.label2index[gt], self.label2index[est]))
        # only mutate once the whole batch is known to be valid
        for gt_index, est_index in index_pairs:
            self.cm[gt_index, est_index] += 1

    def get_cm(self):
        """Return the confusion matrix (rows = gts, columns = ests).

        This is the live internal array, not a copy: mutating it changes every
        metric computed afterwards.
        """
        return self.cm

    def calc_accuracy(self):
        """Return the fraction of samples whose estimation equals the ground truth.

        Returns 0 when no samples have been added (avoids 0/0).
        """
        total = np.sum(self.cm)
        if total == 0:
            total = 1
        return np.trace(self.cm) / total

    def calc_precision(self):
        """Return per-class precision: true positives / all estimations of that class.

        The estimation counts are the column sums of the matrix. A class that
        was never estimated gets precision 0 instead of a divide-by-zero.
        """
        totals = self.cm.sum(axis=0)
        # avoid divide-by-zero; the matching diagonal entry is 0 anyway
        totals[totals == 0] = 1
        return np.diag(self.cm) / totals

    def calc_recall(self):
        """Return per-class recall: true positives / all ground truths of that class.

        The ground-truth counts are the row sums of the matrix. A class with
        no ground-truth samples gets recall 0 instead of a divide-by-zero.
        """
        totals = self.cm.sum(axis=1)
        # avoid divide-by-zero; the matching diagonal entry is 0 anyway
        totals[totals == 0] = 1
        return np.diag(self.cm) / totals

    def calc_f1(self):
        """Return per-class F1 = 2 * precision * recall / (precision + recall).

        Classes whose precision and recall are both 0 get F1 = 0.
        """
        precisions = self.calc_precision()
        recalls = self.calc_recall()
        assert precisions.shape == recalls.shape, "precisions and recalls have different shapes"
        denominator = precisions + recalls
        # avoid divide-by-zero; the numerator is 0 wherever this is 0
        denominator[denominator == 0] = 1
        return 2 * np.multiply(precisions, recalls) / denominator

In [6]:
gts = [2, 0, 2, 2, 0, 1, 1, 2, 2, 0, 1, 2]
ests = [0, 0, 2, 1, 0, 2, 1, 0, 2, 0, 2, 2]

np.set_printoptions(precision=2)
# Collect labels from both gts and ests (an estimation may name a class that
# never occurs in gts, which ConfMatrix.add would reject) and sort them so the
# row/column layout does not depend on set iteration order.
label = sorted(set(gts) | set(ests))
conf_matrix = ConfMatrix(label=label)
conf_matrix.add(gts, ests)
cm = conf_matrix.get_cm()
print("cm: \n{}".format(cm))
acc = conf_matrix.calc_accuracy()
print("acc: {:.2f}".format(acc))
prec = conf_matrix.calc_precision()
print("prec: {}".format(prec))
recall = conf_matrix.calc_recall()
print("recall: {}".format(recall))
f1s = conf_matrix.calc_f1()
print("f1s: {}".format(f1s))

cm: 
[[3. 0. 0.]
 [0. 1. 2.]
 [2. 1. 3.]]
acc: 0.58
prec: [0.6 0.5 0.6]
recall: [1.   0.33 0.5 ]
f1s: [0.75 0.4  0.55]
