In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import monai
from monai.transforms import LoadImage, AsDiscrete, MapLabelValue
from monai.metrics import HausdorffDistanceMetric, DiceMetric, compute_average_surface_distance

In [3]:
class SegmentationMetrics(object):
    def __init__(self, eps=1e-5, average=True, ignore_background=True, activation='0-1'):
        self.eps = eps
        self.average = average
        self.ignore = ignore_background
        self.activation = activation

    @staticmethod
    def _one_hot(gt, pred, class_num):
        # transform sparse mask into one-hot mask
        # shape: (H, W, D) -> (C, H, W, D)
        discrete_trfm = AsDiscrete(to_onehot=class_num)
        return discrete_trfm(gt.squeeze().unsqueeze(0)).unsqueeze(0)

    @staticmethod
    def _get_class_data(gt_onehot, pred, class_num):
        # perform calculation on a batch
        # for precise result in a single image, plz set batch size to 1
        matrix = np.zeros((4, class_num))

        # calculate tp, fp, fn per class
        for i in range(class_num):
            # pred shape: (N, H, W, D)
            class_pred = pred[:,i, :, :, :]
            # gt shape: (N, H, W), binary array where 0 denotes negative and 1 denotes positive
            class_gt = gt_onehot[:,i, :, :, :]

            pred_flat = class_pred.contiguous().view(-1, )  # shape: (N * H * W * D, )
            gt_flat = class_gt.contiguous().view(-1, )  # shape: (N * H * W * D, )

            tp = torch.sum(gt_flat * pred_flat)
            fp = torch.sum(pred_flat) - tp
            fn = torch.sum(gt_flat) - tp
            tn = torch.sum((gt_flat - 1).abs() * (pred_flat - 1).abs())
            
            matrix[:, i] = tp.item(), fp.item(), fn.item(), tn.item()

        return matrix

    def _calculate_multi_metrics(self, gt, pred, class_num):
        # calculate metrics in multi-class segmentation
        matrix = self._get_class_data(gt, pred, class_num)
        if self.ignore:
            matrix = matrix[:, 1:]

        # tp = np.sum(matrix[0, :])
        # fp = np.sum(matrix[1, :])
        # fn = np.sum(matrix[2, :])
        
        pixel_acc = (matrix[0] + self.eps) / (matrix[0] + matrix[1]+self.eps)
        dice = (2 * matrix[0] + self.eps) / (2 * matrix[0] + matrix[1] + matrix[2] + self.eps)
        precision = (matrix[0] + self.eps) / (matrix[0] + matrix[1] + self.eps)
        specificity = (matrix[3] + self.eps) / (matrix[3] + matrix[1] + self.eps)
        precision = specificity
        recall = (matrix[0] + self.eps) / (matrix[0] + matrix[2] + self.eps)
        
        if self.average:
            dice = np.average(dice)
            precision = np.average(precision)
            recall = np.average(recall)

        return pixel_acc, dice, precision, recall, matrix

    def __call__(self, y_true, y_pred):
        class_num = y_pred.size(1)

        if self.activation in [None, 'none']:
            activation_fn = lambda x: x
            activated_pred = activation_fn(y_pred)
        elif self.activation == "sigmoid":
            activation_fn = nn.Sigmoid()
            activated_pred = activation_fn(y_pred)
        elif self.activation == "softmax":
            activation_fn = nn.Softmax(dim=1)
            activated_pred = activation_fn(y_pred)
        elif self.activation == "0-1":
            pred_argmax = torch.argmax(y_pred, dim=1)
            activated_pred = self._one_hot(pred_argmax, y_pred, class_num)
        else:
            raise NotImplementedError("Not a supported activation!")

        gt_onehot = self._one_hot(y_true, y_pred, class_num)
        pixel_acc, dice, precision, recall, matrix = self._calculate_multi_metrics(gt_onehot, activated_pred, class_num)
        return pixel_acc, dice, precision, recall, matrix

In [10]:
common = SegmentationMetrics(activation='none',average=False)
haunsdorf = HausdorffDistanceMetric()
dice_mtr = DiceMetric()

dataloader = monai.transforms.Compose([LoadImage(), monai.transforms.EnsureType()])
discrete_trfm = AsDiscrete(to_onehot=6)
label_trfm = MapLabelValue(orig_labels=[0,1,2,3,4,5,6], target_labels=[0,1,2,3,4,4,5])


In [28]:
ROOT = 'data'
ROOT_LABEL = 'eval'
for CASE in sorted(os.listdir(ROOT)):
    data = dataloader(os.path.join(ROOT, CASE, 'merged.nii.gz'))[0].unsqueeze(0)
    data_label = dataloader(os.path.join(ROOT_LABEL, CASE, 'pred_12_trans.nii.gz'))[0].unsqueeze(0)
    data_label = data_label.unsqueeze(0)
    data = label_trfm(data)

    data = discrete_trfm(data).unsqueeze(0)
    pixel_acc, dice, spec, recall, matrix = common(data_label, data)
    hauns = haunsdorf(data, discrete_trfm(data_label.squeeze(0)).unsqueeze(0))
    dice = dice_mtr(data, discrete_trfm(data_label.squeeze(0)).unsqueeze(0))
    surf = compute_average_surface_distance(data, discrete_trfm(data_label.squeeze(0)).unsqueeze(0))
    
    header = ['Anatomy', 'Study', 'Accuracy', 'Specificity', 'Recall', 'Dice', 'Hausdorff', 'Avg.Surface']

    print('| {:^88} |'.format(f'{CASE}: Per class metric statistics'))
    print('='*92)
    print('| {:^11} | {:^6} | {:} | {:} | {:^7} | {:} | {:} | {:} |'.format(*header))
    print('='*92)
    for anatomy_idx,anatomy in enumerate(['Artery', 'Vein', 'Urethra', 'Neoplasm', 'Kidney']):
      
        print(
            '| {:^11} | {:^6} | {:>7.2f}% | {:>10.2f}% | {:>6.2f}% | {:^4.2f} | {:^9.2f} | {:^11.2f} |'.format(*[
            anatomy, 
            CASE,
            pixel_acc[anatomy_idx]*100, spec[anatomy_idx]*100, recall[anatomy_idx]*100,
            dice[0,anatomy_idx].item(), hauns[0,anatomy_idx].item(), surf[0,anatomy_idx].item()])
        )
    print('-'*92)

|                           case_1: Per class metric statistics                            |
|   Anatomy   | Study  | Accuracy | Specificity | Recall  | Dice | Hausdorff | Avg.Surface |
|   Артерия   | case_1 |    2.05% |      99.16% |   2.63% | 0.96 |   93.14   |    11.41    |
|    Вена     | case_1 |    1.72% |      98.67% |   3.45% | 0.02 |   95.89   |    13.74    |
| Мочеточник  | case_1 |    0.10% |      99.78% |   0.19% | 0.02 |   43.37   |    19.00    |
| Образование | case_1 |   12.64% |      99.34% |  23.82% | 0.00 |   92.67   |    18.90    |
|  Паренхима  | case_1 |    6.29% |      96.68% |  13.08% | 0.17 |   74.44   |    15.94    |
--------------------------------------------------------------------------------------------
|                           case_2: Per class metric statistics                            |
|   Anatomy   | Study  | Accuracy | Specificity | Recall  | Dice | Hausdorff | Avg.Surface |
|   Артерия   | case_2 |   31.12% |      99.26% |  37.40% | 0.96 |   6