Skip to content

Commit

Permalink
Merge pull request #12 from e1y4r/main
Browse files Browse the repository at this point in the history
Added kappa and dice score to metrics.
  • Loading branch information
likyoo committed Jan 13, 2022
2 parents 6bb55fd + 57d47f4 commit 0a86d51
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
53 changes: 53 additions & 0 deletions change_detection_pytorch/utils/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,56 @@ def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
score = (tp + eps) / (tp + fn + eps)

return score


def kappa(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
"""Calculate kappa score between ground truth and prediction
Args:
pr (torch.Tensor): A list of predicted elements
gt (torch.Tensor): A list of elements that are to be predicted
eps (float): epsilon to avoid zero division
threshold: threshold for outputs binarization
Returns:
float: kappa score
"""

pr = _threshold(pr, threshold=threshold)
pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

tp = torch.sum(gt * pr)
fp = torch.sum(pr) - tp
fn = torch.sum(gt) - tp
tn = torch.sum((1 - gt)*(1 - pr))

N = tp + tn + fp + fn
p0 = (tp + tn) / N
pe = ((tp + fp) * (tp + fn) + (tn + fp) * (tn + fn)) / (N * N)

score = (p0 - pe) / (1 - pe)

return score


def dice(pr, gt, eps=1e-7, threshold=None, ignore_channels=None):
"""Calculate dice score between ground truth and prediction
Args:
pr (torch.Tensor): A list of predicted elements
gt (torch.Tensor): A list of elements that are to be predicted
eps (float): epsilon to avoid zero division
threshold: threshold for outputs binarization
Returns:
float: dice score
"""
pr = _threshold(pr, threshold=threshold)
pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels)

tp = torch.sum(gt * pr)
fp = torch.sum(pr) - tp
fn = torch.sum(gt) - tp

_precision = precision(pr, gt, eps=eps, threshold=threshold, ignore_channels=ignore_channels)
_recall = recall(pr, gt, eps=eps, threshold=threshold, ignore_channels=ignore_channels)

score = 2 * _precision * _recall / (_precision + _recall)

return score
36 changes: 36 additions & 0 deletions change_detection_pytorch/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,40 @@ def forward(self, y_pr, y_gt):
)


class Dice(base.Metric):

def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
super().__init__(**kwargs)
self.eps = eps
self.threshold = threshold
self.activation = Activation(activation)
self.ignore_channels = ignore_channels

def forward(self, y_pr, y_gt):
y_pr = self.activation(y_pr)
return F.dice(
y_pr, y_gt,
eps=self.eps,
threshold=self.threshold,
ignore_channels=self.ignore_channels,
)


class Kappa(base.Metric):

def __init__(self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs):
super().__init__(**kwargs)
self.eps = eps
self.threshold = threshold
self.activation = Activation(activation)
self.ignore_channels = ignore_channels

def forward(self, y_pr, y_gt):
y_pr = self.activation(y_pr)
return F.kappa(
y_pr, y_gt,
eps=self.eps,
threshold=self.threshold,
ignore_channels=self.ignore_channels,
)

0 comments on commit 0a86d51

Please sign in to comment.