From 57d47f44bf62e5d989513cab1475b02b265f1b29 Mon Sep 17 00:00:00 2001 From: e1y4r Date: Thu, 13 Jan 2022 14:32:08 +0800 Subject: [PATCH] Added kappa and dice score to metrics. --- change_detection_pytorch/utils/functional.py | 53 ++++++++++++++++++++ change_detection_pytorch/utils/metrics.py | 36 +++++++++++++ 2 files changed, 89 insertions(+) diff --git a/change_detection_pytorch/utils/functional.py b/change_detection_pytorch/utils/functional.py index cef707d..282aa4b 100644 --- a/change_detection_pytorch/utils/functional.py +++ b/change_detection_pytorch/utils/functional.py @@ -126,3 +126,56 @@ def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): score = (tp + eps) / (tp + fn + eps) return score + + +def kappa(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): + """Calculate kappa score between ground truth and prediction + Args: + pr (torch.Tensor): A list of predicted elements + gt (torch.Tensor): A list of elements that are to be predicted + eps (float): epsilon to avoid zero division + threshold: threshold for outputs binarization + Returns: + float: kappa score + """ + + pr = _threshold(pr, threshold=threshold) + pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) + + tp = torch.sum(gt * pr) + fp = torch.sum(pr) - tp + fn = torch.sum(gt) - tp + tn = torch.sum((1 - gt)*(1 - pr)) + + N = tp + tn + fp + fn + p0 = (tp + tn) / N + pe = ((tp + fp) * (tp + fn) + (tn + fp) * (tn + fn)) / (N * N) + + score = (p0 - pe) / (1 - pe) + + return score + + +def dice(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): + """Calculate dice score between ground truth and prediction + Args: + pr (torch.Tensor): A list of predicted elements + gt (torch.Tensor): A list of elements that are to be predicted + eps (float): epsilon to avoid zero division + threshold: threshold for outputs binarization + Returns: + float: dice score + """ + pr = _threshold(pr, threshold=threshold) + pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) + + tp = torch.sum(gt * pr) + fp = torch.sum(pr) - tp + fn = torch.sum(gt) - tp + + _precision = precision(pr, gt, eps=eps, threshold=threshold, ignore_channels=ignore_channels) + _recall = recall(pr, gt, eps=eps, threshold=threshold, ignore_channels=ignore_channels) + + score = 2 * _precision * _recall / (_precision + _recall) + + return score \ No newline at end of file diff --git a/change_detection_pytorch/utils/metrics.py b/change_detection_pytorch/utils/metrics.py index e947e95..f87828d 100644 --- a/change_detection_pytorch/utils/metrics.py +++ b/change_detection_pytorch/utils/metrics.py @@ -99,4 +99,40 @@ def forward(self, y_pr, y_gt): ) +class Dice(base.Metric): + + def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): + super().__init__(**kwargs) + self.eps = eps + self.threshold = threshold + self.activation = Activation(activation) + self.ignore_channels = ignore_channels + + def forward(self, y_pr, y_gt): + y_pr = self.activation(y_pr) + return F.dice( + y_pr, y_gt, + eps=self.eps, + threshold=self.threshold, + ignore_channels=self.ignore_channels, + ) + + +class Kappa(base.Metric): + + def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs): + super().__init__(**kwargs) + self.eps = eps + self.threshold = threshold + self.activation = Activation(activation) + self.ignore_channels = ignore_channels + + def forward(self, y_pr, y_gt): + y_pr = self.activation(y_pr) + return F.kappa( + y_pr, y_gt, + eps=self.eps, + threshold=self.threshold, + ignore_channels=self.ignore_channels, + )