In [2]:
import numpy as np
import glob
import os
import cv2
from PIL import Image
from pprint import pprint
from collections import defaultdict
import pandas as pd
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
class Evaluator(object):
    """
    Accumulate a confusion matrix over batches and derive segmentation metrics.

    Adapted from
    https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py

    ``confusion_matrix[i, j]`` counts pixels whose ground-truth label is ``i``
    and whose predicted label is ``j``. Ground-truth pixels whose value lies
    outside ``[0, num_class)`` are ignored (e.g. an "ignore" label such as 255).
    """
    def __init__(self, num_class):
        """
        Args:
            num_class: number of classes; labels are expected in ``[0, num_class)``.
        """
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def Pixel_Accuracy(self):
        """Overall accuracy: correctly classified pixels / all counted pixels.

        Returns NaN when nothing has been accumulated yet.
        """
        # 0/0 on an empty matrix is expected to give NaN; silence the warning.
        with np.errstate(divide='ignore', invalid='ignore'):
            Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return Acc

    def Pixel_Accuracy_Class(self):
        """Mean over classes of per-class accuracy (TP / ground-truth pixels of that class).

        Classes with no ground-truth pixels produce NaN and are skipped by the mean.
        """
        with np.errstate(divide='ignore', invalid='ignore'):
            Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        return np.nanmean(Acc)

    def _per_class_iou(self):
        """Per-class IoU = TP / (TP + FP + FN); NaN for classes absent from both gt and prediction."""
        tp = np.diag(self.confusion_matrix)
        union = (self.confusion_matrix.sum(axis=1)
                 + self.confusion_matrix.sum(axis=0)
                 - tp)
        with np.errstate(divide='ignore', invalid='ignore'):
            return tp / union

    def Mean_Intersection_over_Union(self):
        """Mean of the per-class IoUs, ignoring classes whose IoU is undefined (NaN)."""
        return np.nanmean(self._per_class_iou())

    def Frequency_Weighted_Intersection_over_Union(self):
        """Frequency-weighted IoU.

        Each class's IoU is weighted by the fraction of ground-truth pixels
        belonging to that class; classes with zero frequency are excluded.

        Returns:
            (FWIoU, iu): the weighted scalar and the per-class IoU array.
        """
        iu = self._per_class_iou()
        with np.errstate(divide='ignore', invalid='ignore'):
            freq = self.confusion_matrix.sum(axis=1) / self.confusion_matrix.sum()
        # NaN (empty matrix) compares False, so nothing is selected in that case.
        present = freq > 0
        FWIoU = (freq[present] * iu[present]).sum()
        return FWIoU, iu

    def _generate_matrix(self, gt_image, pre_image):
        """Build the confusion matrix contribution of one gt/prediction pair.

        Raises:
            ValueError: if a prediction (at a counted gt pixel) is outside
                ``[0, num_class)``. Without this check such values are either
                silently counted in the wrong cell or fail later in ``reshape``.
        """
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        gt = gt_image[mask].astype('int')
        pred = pre_image[mask]
        if pred.size and (pred.min() < 0 or pred.max() >= self.num_class):
            raise ValueError(
                'prediction labels must lie in [0, {}), got min={} max={}'.format(
                    self.num_class, pred.min(), pred.max()))
        # Encode each (gt, pred) pair as a single index into a flattened matrix.
        label = self.num_class * gt + pred
        count = np.bincount(label, minlength=self.num_class ** 2)
        return count.reshape(self.num_class, self.num_class)

    def add_batch(self, gt_image, pre_image):
        """Accumulate one ground-truth / prediction pair (arrays of identical shape)."""
        assert gt_image.shape == pre_image.shape, \
            'shape mismatch: gt {} vs pred {}'.format(gt_image.shape, pre_image.shape)
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        """Clear the accumulated confusion matrix."""
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

def get_msk(msk_pth):
    """Load a mask image and binarize it so every non-zero pixel becomes 1.

    Args:
        msk_pth: path to an image file readable by PIL.

    Returns:
        A writable numpy array of the image's pixels, with non-zero values set
        to 1. NOTE(review): assumes a single-channel mask; an RGB file would
        yield a 3-D array - confirm all masks are single-channel.
    """
    # The context manager closes the file handle immediately instead of
    # leaving it open until garbage collection (one leak per mask otherwise).
    with Image.open(msk_pth) as img:
        msk = np.array(img)  # np.array copies -> writable, and forces the load
    msk[msk > 0] = 1
    return msk

# Validation-split CSV files keyed by split name (presumably one per CV fold).
csv_pths = dict([
    ('val0', r'../data/val_0.csv'),
    ('val1', r'../data/val_1.csv'),
    ('val2', r'../data/val_2.csv'),
])


In [11]:
# Ground-truth mask directory. Overridable through the SEAICE_GT_DIR
# environment variable so the notebook can run on other machines; the default
# is the original absolute location on the training host.
gt_dir = os.environ.get('SEAICE_GT_DIR', r'/boot/data1/kang_data/zkxt21_seaice/trainData/gt')

# Models to evaluate: display name -> directory of predicted *.png masks whose
# file names match the ground-truth file names in `gt_dir`.
# NOTE(review): the saved output of the evaluation cell also reports
# 'timm-tf_efficientnet_lite3val2', which is not listed here, so that output is
# stale; add the val2 entry to reproduce it on a fresh kernel.
model_dict = {
    'timm-tf_efficientnet_lite3val1': '../EffUNet/20230313_142125/output_path',
}


In [12]:
# For each model, accumulate one confusion matrix over every predicted mask and
# its same-named ground-truth mask, then derive the summary metrics.
results = {}
evaluator = Evaluator(2)  # two classes: get_msk maps every mask to {0, 1}
for method, pred_dir in model_dict.items():
    evaluator.reset()
    print(method)
    # sorted() gives a deterministic visiting order (glob order is arbitrary).
    pred_pths = sorted(glob.glob(os.path.join(pred_dir, '*.png')))
    if not pred_pths:
        # Without this guard an empty or mistyped directory silently reports
        # FwIoU=0 and OA=nan instead of failing loudly.
        raise FileNotFoundError(
            'no *.png predictions found in {!r} for {}'.format(pred_dir, method))
    for pred_pth in pred_pths:
        gt_pth = os.path.join(gt_dir, os.path.basename(pred_pth))
        evaluator.add_batch(get_msk(gt_pth), get_msk(pred_pth))

    fwiou, ious = evaluator.Frequency_Weighted_Intersection_over_Union()
    results[method] = {
        'FwIoU': fwiou,
        'IoUs': ious,
        'OA': evaluator.Pixel_Accuracy(),
    }

pprint(results)

timm-tf_efficientnet_lite3val1
timm-tf_efficientnet_lite3val2
defaultdict(None,
            {'timm-tf_efficientnet_lite3val1': defaultdict(None,
                                                           {'FwIoU': 0.9864927753876305,
                                                            'IoUs': array([0.99229092, 0.94071814]),
                                                            'OA': 0.9931311959703567}),
             'timm-tf_efficientnet_lite3val2': defaultdict(None,
                                                           {'FwIoU': 0.9857390663676159,
                                                            'IoUs': array([0.99186587, 0.93736981]),
                                                            'OA': 0.9927486551219019})})
