[Fix] Make accuracy take into account ignore_index (open-mmlab#1259)
* make accuracy take into account ignore_index

* add UT for accuracy
HJoonKwon authored Feb 14, 2022 · 1 parent 85c5eeb · commit 4d451a0
Showing 4 changed files with 37 additions and 6 deletions.
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/decode_head.py
@@ -261,5 +261,6 @@ def losses(self, seg_logit, seg_label):
             weight=seg_weight,
             ignore_index=self.ignore_index)
-        loss['acc_seg'] = accuracy(seg_logit, seg_label)
+        loss['acc_seg'] = accuracy(
+            seg_logit, seg_label, ignore_index=self.ignore_index)
         return loss
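For intuition, the sketch below hand-rolls the masking that `accuracy` now applies (mirroring the `correct[:, target != ignore_index]` line added to accuracy.py further down). The tensors and the 255 padding label are illustrative assumptions, not mmseg API calls:

import torch

ignore_index = 255                        # assumed padding label
pred_label = torch.tensor([0, 1, 1, 2])   # argmax of some logits
target = torch.tensor([0, 1, 255, 255])   # two padded positions

correct = pred_label.eq(target)           # [True, True, False, False]
keep = target != ignore_index             # mask: [True, True, False, False]
acc = correct[keep].float().sum() * 100.0 / keep.sum()
print(acc)  # tensor(100.) -- unmasked, the padding would drag this to 2/4 = 50%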
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/point_head.py
@@ -264,7 +264,8 @@ def losses(self, point_logits, point_label):
             loss['point' + loss_module.loss_name] = loss_module(
                 point_logits, point_label, ignore_index=self.ignore_index)
 
-        loss['acc_point'] = accuracy(point_logits, point_label)
+        loss['acc_point'] = accuracy(
+            point_logits, point_label, ignore_index=self.ignore_index)
         return loss
 
     def get_points_train(self, seg_logits, uncertainty_func, cfg):
13 changes: 9 additions & 4 deletions mmseg/models/losses/accuracy.py
@@ -2,12 +2,13 @@
 import torch.nn as nn
 
 
-def accuracy(pred, target, topk=1, thresh=None):
+def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
     """Calculate accuracy according to the prediction and target.
 
     Args:
         pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
         target (torch.Tensor): The target of each prediction, shape (N, ...)
+        ignore_index (int | None): The label index to be ignored. Default: None
         topk (int | tuple[int], optional): If the predictions in ``topk``
             matches the target, the predictions will be regarded as
             correct ones. Defaults to 1.
@@ -43,17 +44,19 @@ def accuracy(pred, target, topk=1, thresh=None):
     if thresh is not None:
         # Only prediction values larger than thresh are counted as correct
         correct = correct & (pred_value > thresh).t()
+    correct = correct[:, target != ignore_index]
     res = []
     for k in topk:
         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
-        res.append(correct_k.mul_(100.0 / target.numel()))
+        res.append(
+            correct_k.mul_(100.0 / target[target != ignore_index].numel()))
     return res[0] if return_single else res


 class Accuracy(nn.Module):
     """Accuracy calculation module."""
 
-    def __init__(self, topk=(1, ), thresh=None):
+    def __init__(self, topk=(1, ), thresh=None, ignore_index=None):
         """Module to calculate the accuracy.
 
         Args:
@@ -65,6 +68,7 @@ def __init__(self, topk=(1, ), thresh=None):
         super().__init__()
         self.topk = topk
         self.thresh = thresh
+        self.ignore_index = ignore_index

     def forward(self, pred, target):
         """Forward function to calculate accuracy.
@@ -76,4 +80,5 @@ def forward(self, pred, target):
         Returns:
             tuple[float]: The accuracies under different topk criterions.
         """
-        return accuracy(pred, target, self.topk, self.thresh)
+        return accuracy(pred, target, self.topk, self.thresh,
+                        self.ignore_index)
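As a usage sketch (assuming a mmsegmentation checkout with this patch installed; the tensors reuse the unit-test values shown below), masking removes ignored positions from both the hit count and the denominator:

import torch
from mmseg.models.losses import Accuracy

# Five predictions over four classes; argmax gives [2, 3, 0, 1, 2].
pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
                     [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
                     [0.0, 0.0, 0.99, 0]])
true_label = torch.Tensor([2, 0, 0, 1, 2]).long()

# No label equals 4, so all five positions count: 4/5 correct.
print(Accuracy(topk=1, ignore_index=4)(pred, true_label))  # tensor([80.])
# Label 1 (position 3, a correct prediction) is dropped: 3/4 correct.
print(Accuracy(topk=1, ignore_index=1)(pred, true_label))  # tensor([75.])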
24 changes: 24 additions & 0 deletions tests/test_models/test_losses/test_utils.py
@@ -52,6 +52,30 @@ def test_accuracy():
     pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
                          [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
                          [0.0, 0.0, 0.99, 0]])
+    # test for ignore_index
+    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=None)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 100
+
+    # test for ignore_index with a wrong prediction of that index
+    true_label = torch.Tensor([2, 3, 1, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=1)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 100
+
+    # test for ignore_index 1 with a wrong prediction of other index
+    true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=1)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 75
+
+    # test for ignore_index 4 with a wrong prediction of other index
+    true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=4)
+    acc = accuracy(pred, true_label)
+    assert acc.item() == 80
+
     # test for top1
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1)
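To exercise the new assertions locally, one option is to invoke pytest on the touched test from Python (the file path and test name are as in the hunk header above; assumes pytest is installed):

import pytest

# Run only test_accuracy from the file this commit extends.
pytest.main(['-q', 'tests/test_models/test_losses/test_utils.py::test_accuracy'])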
