Avoid NaN Losses (#2001)
Fixes a small error where NaN losses could appear if the whole input is masked.
sanagno committed Mar 7, 2023
1 parent e51c517 commit f2ed582
Showing 1 changed file with 7 additions and 2 deletions.
model/model_training/losses.py: 7 additions & 2 deletions
@@ -4,7 +4,7 @@
 
 
 class CrossEntropyLoss(nn.CrossEntropyLoss):
-    def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean"):
+    def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="none"):
         super().__init__(weight, size_average, ignore_index, reduce, reduction)
 
     def forward(self, input, target, mask=None):
@@ -14,7 +14,12 @@ def forward(self, input, target, mask=None):
         target = target.view(-1)
         input = input[mask]
         target = target[mask]
-        return super().forward(input, target)
+
+        size = target.numel()
+
+        loss = super().forward(input, target)
+
+        return loss.sum() / (size + 1e-8)
 
 
 class PolyLoss(nn.Module):
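For reference, here is a minimal sketch (not part of the commit) of the failure mode and of how the new reduction avoids it. The tensor shapes and the direct use of torch.nn.functional.cross_entropy are illustrative assumptions, not code from losses.py.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                  # 4 tokens, 10 classes
targets = torch.randint(0, 10, (4,))
mask = torch.zeros(4, dtype=torch.bool)      # whole input masked out

masked_logits = logits[mask]                 # shape (0, 10)
masked_targets = targets[mask]               # shape (0,)

# Old behaviour: reduction="mean" averages over zero elements -> NaN
old_loss = F.cross_entropy(masked_logits, masked_targets, reduction="mean")
print(old_loss)                              # tensor(nan)

# New behaviour: keep per-element losses, then sum and divide by a safe denominator
per_token = F.cross_entropy(masked_logits, masked_targets, reduction="none")
new_loss = per_token.sum() / (masked_targets.numel() + 1e-8)
print(new_loss)                              # tensor(0.)

The 1e-8 term keeps the denominator non-zero, so a fully masked batch contributes a loss of 0 rather than NaN, while non-empty batches still get essentially the usual mean.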
