From 50e40a1bf2cb7df3886b24f8274845da7bc0d8dd Mon Sep 17 00:00:00 2001
From: XiaoYulun
Date: Thu, 28 Mar 2019 00:42:59 +0800
Subject: [PATCH] Update metrics.py

using sklearn to calculate confusion matrix
---
 ptsemseg/metrics.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/ptsemseg/metrics.py b/ptsemseg/metrics.py
index 36fcb08e..d46cc6de 100644
--- a/ptsemseg/metrics.py
+++ b/ptsemseg/metrics.py
@@ -2,23 +2,15 @@
 # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
 
 import numpy as np
-
+from sklearn.metrics import confusion_matrix
 
 class runningScore(object):
     def __init__(self, n_classes):
         self.n_classes = n_classes
         self.confusion_matrix = np.zeros((n_classes, n_classes))
 
-    def _fast_hist(self, label_true, label_pred, n_class):
-        mask = (label_true >= 0) & (label_true < n_class)
-        hist = np.bincount(
-            n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2
-        ).reshape(n_class, n_class)
-        return hist
-
     def update(self, label_trues, label_preds):
-        for lt, lp in zip(label_trues, label_preds):
-            self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
+        self.confusion_matrix += confusion_matrix(label_trues.flatten(), label_preds.flatten(), list(range(self.n_classes)))
 
     def get_scores(self):
         """Returns accuracy score evaluation result.
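
A minimal usage sketch of the sklearn-based update introduced by this patch, assuming a small 3-class example with random integer label maps (the arrays, sizes, and class count below are illustrative only). Note that recent scikit-learn releases require the class list to be passed as the labels= keyword argument rather than positionally.

    import numpy as np
    from sklearn.metrics import confusion_matrix

    n_classes = 3
    gt   = np.random.randint(0, n_classes, size=(2, 4, 4))   # ground-truth label maps (illustrative)
    pred = np.random.randint(0, n_classes, size=(2, 4, 4))   # predicted label maps (illustrative)

    # Equivalent of runningScore.update(gt, pred) after this patch:
    # flatten both label maps and accumulate a per-call confusion matrix.
    running = np.zeros((n_classes, n_classes))
    running += confusion_matrix(gt.flatten(), pred.flatten(),
                                labels=list(range(n_classes)))
    print(running)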