In [0]:
import torch

# Toy example with six classes (labels 0..5) and 16 samples.
# Position i of `y_true` is the ground-truth label of sample i,
# position i of `y_pred` is the label predicted for that same sample.
y_true = torch.tensor([0, 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5])
y_pred = torch.tensor([1, 1, 3, 3, 1, 2, 3, 4, 4, 1, 1, 1, 2, 3, 4, 5])

In [0]:
def print_cm_and_class_stats(y_true, y_pred, verbose=False):
  """Print a confusion matrix, the overall accuracy and per-class statistics.

  The set of classes is the sorted union of the labels found in `y_true` and
  `y_pred`; each label is mapped to a row/column index, so labels neither have
  to start at 0 nor be contiguous, and a label that only ever shows up as a
  prediction still gets its own column.

  Args:
    y_true: tensor of ground-truth class labels, one per sample.
    y_pred: tensor of predicted class labels, aligned with `y_true`.
    verbose: if True, also print the intermediate per-class counts
      (true/false positives and negatives).

  Returns:
    None. Everything is written to stdout.

  Raises:
    ValueError: if `y_true` and `y_pred` hold a different number of labels.

  Note:
    Ratios whose denominator is zero (e.g. the precision of a class that was
    never predicted) are left as NaN, i.e. "undefined".
  """
  print(y_true)
  print(y_pred)

  # work on flat label sequences and refuse misaligned inputs up front;
  # zip() would silently truncate and `==` could silently broadcast
  true_flat = y_true.flatten()
  pred_flat = y_pred.flatten()
  if true_flat.numel() != pred_flat.numel():
    raise ValueError(
      f'y_true and y_pred must hold the same number of labels, '
      f'got {true_flat.numel()} and {pred_flat.numel()}')
  true_labels = true_flat.tolist()
  pred_labels = pred_flat.tolist()

  # class labels = every label seen in the ground truth OR the predictions
  labels = sorted(set(true_labels) | set(pred_labels))
  # label value -> row/column position in the confusion matrix
  index_of = {label: index for index, label in enumerate(labels)}
  classes = torch.tensor(labels)

  # number of classes
  num_classes = len(labels)

  # initialize confusion matrix see https://en.wikipedia.org/wiki/Confusion_matrix
  # rows = ground truth, columns = prediction
  conf_matrix = torch.zeros(num_classes, num_classes)

  # zip pairs each ground-truth label with the prediction for the same sample;
  # increment the corresponding confusion matrix cell for each truth / prediction pair
  for t, p in zip(true_labels, pred_labels):
    conf_matrix[index_of[t], index_of[p]] += 1

  # let's print the confusion matrix
  # the row title is written vertically: two characters of it per matrix row
  print('\nConfusion Matrix')
  dim_0_name = 'Ground Truth'.center(2 * num_classes, ' ')
  dim_1_name = 'Prediction'.center(4 * num_classes, ' ')
  print(dim_1_name)
  for i in range(num_classes):
    print(dim_0_name[2*i], end = '')
    for j in range(num_classes):
      print (f'{conf_matrix[i][j].item():4.0f}', end = '')
    print('\n'+dim_0_name[2*i+1])

  # the diagonal represents the correctly predicted samples for each class
  # these are the true positive samples
  true_positives = conf_matrix.diag()
  if verbose:
    print('\ntrue positives per class:',true_positives)

  # all predicted positives are calculated by summing up all predictions for each class
  # they include true positive and false positive samples
  predicted_positives = torch.sum(conf_matrix, dim=0)
  if verbose:
    print('all positives predictions per class:',predicted_positives)

  # false positives = samples where class was predicted but actually is another class
  # we get this by subtracting the correctly positive classified samples from all positive predictions
  false_positives = predicted_positives - true_positives
  if verbose:
    print('false positives per class:',false_positives)

  # false negatives = samples which were not classified correctly (as positive)
  # i.e. the actual positive samples minus the ones detected (true positives)
  # therefore we need the actual positive samples per class first
  actual_positives = torch.sum(conf_matrix, dim=1)
  if verbose:
    print('actual positives per class:',actual_positives)

  # false_negatives = all actual positive samples per class minus the ones that were predicted correctly
  false_negatives = actual_positives - true_positives
  if verbose:
    print('false negatives per class:',false_negatives)

  # true negatives = predicted_negatives - false negatives
  # therefore we first have to get all negatives by subtracting all positive samples per class from the total number of samples
  all_samples = conf_matrix.sum()
  if verbose:
    print('total number of samples:',all_samples)

  # note: all_samples is just a scalar, but it is broadcast to the size of predicted_positives (i.e. the number of classes)
  predicted_negatives = all_samples - predicted_positives
  if verbose:
    print('all negative predictions per class:',predicted_negatives)

  # now we can calculate the true negatives
  true_negatives = predicted_negatives - false_negatives
  if verbose:
    print('true negatives per class:',true_negatives)

  # for faster typing, let's rename those
  fn = false_negatives
  fp = false_positives
  tn = true_negatives
  tp = true_positives
  actual_negatives = tn + fp

  # every sample falls into exactly one of the four groups for every class;
  # .all() must be *called* - the bare bound method is always truthy
  assert (fn + fp + tn + tp == all_samples).all(), 'sanity check failed'

  #--------------------
  # OVERALL STATISTICS
  #--------------------
  print('\nGlobal statistics:')

  # calculate accuracy by dividing correct predictions by total number of predictions
  accuracy = (pred_flat == true_flat).sum().float()/len(pred_flat)
  print(f'- Accuracy: {accuracy:.2%}')


  #--------------------
  # CLASS STATISTICS
  #--------------------
  # finally we can calculate class wise metrics:
  class_stat={}

  # Sensitivity = TP / (TP+FN) = TP / P
  # https://en.wikipedia.org/wiki/Sensitivity_and_specificity
  # also called the true positive rate, the recall, or probability of detection
  # measures the proportion of actual positives that are correctly identified as such
  # e.g., the percentage of sick people who are correctly identified as having the condition
  class_stat['Sensitivity'] = tp / actual_positives

  # Specificity = TN / (TN+FP) = TN / N
  # https://en.wikipedia.org/wiki/Sensitivity_and_specificity
  # also called the true negative rate
  # measures the proportion of actual negatives that are correctly identified as such
  # e.g., the percentage of healthy people who are correctly identified as not having the condition
  class_stat['Specificity'] = tn / actual_negatives

  # Accuracy = (TP + TN) / ALL = (TP + TN) / (P + N) = (TP + TN) / (TP+FN+TN+FP)
  # https://en.wikipedia.org/wiki/Accuracy_and_precision
  # measure of a test's accuracy
  class_stat['Accuracy'] = (tp+tn) / (actual_positives + actual_negatives)

  # Precision = TP / (TP+FP)
  # https://en.wikipedia.org/wiki/Precision_and_recall
  # also called positive predictive value
  # fraction of relevant instances among the retrieved instances
  # (for recall = sensitivity see above)
  class_stat['Precision'] = tp / predicted_positives

  # F1-Score = 2 * (recall*precision)/(recall+precision) = 2*TP / (2*TP + FP + FN)
  # https://en.wikipedia.org/wiki/F1_score
  # the count-based form is used because the harmonic-mean form turns into 0/0 = NaN
  # whenever a class has no true positives, although the score is simply 0 in that case
  class_stat['F1-Score'] = 2 * tp / (2 * tp + fp + fn)

  print('\nStatistics per class:')
  print_row('Classes', classes, head=True)
  for key, stat in class_stat.items():
    print_row(key, stat)

def print_row(name, items, col_width = 12, head=False):
  """Print one row of a text table: a label cell followed by one cell per item.

  Args:
    name: row label, left-aligned in the first column.
    items: sized iterable of numbers, one per remaining column.
    col_width: width of every column in characters (separators not counted).
    head: if True, draw a top border first and format the items as
      integers ('.0f'); otherwise format them as percentages ('.2%').
  """
  columns = len(items) + 1
  number_format = '.0f' if head else '.2%'

  if head:
    # closed top border, only needed above the very first row of the table
    print(('_' * (col_width + 1)) * columns)

  # empty padding line above the content
  print((' ' * col_width + '|') * columns)

  # label cell followed by the centered, formatted value cells
  value_cells = ''.join(f'{value:^{col_width}{number_format}}|' for value in items)
  print(name.ljust(col_width) + '|' + value_cells)

  # bottom border of the row
  print(('_' * col_width + '|') * columns)

In [5]:
# Report for the sample labels; add verbose=True to also print the
# intermediate per-class counts (TP / FP / FN / TN).
print_cm_and_class_stats(y_true, y_pred)

tensor([0, 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5])
tensor([1, 1, 3, 3, 1, 2, 3, 4, 4, 1, 1, 1, 2, 3, 4, 5])

Confusion Matrix
       Prediction       
G   0   1   0   0   0   0
r
o   0   1   0   0   0   0
u
n   0   0   0   2   0   0
d
    0   1   1   1   0   0
T
r   0   2   0   0   2   0
u
t   0   1   1   1   1   1
h

Global statistics:
- Accuracy: 31.25%

Statistics per class:
___________________________________________________________________________________________
            |            |            |            |            |            |            |
Classes     |     0      |     1      |     2      |     3      |     4      |     5      |
____________|____________|____________|____________|____________|____________|____________|
            |            |            |            |            |            |            |
Sensitivity |   0.00%    |  100.00%   |   0.00%    |   33.33%   |   50.00%   |   20.00%   |
____________|____________|____________|____________|_________