diff --git a/losses.py b/losses.py index dd95860..d406aac 100644 --- a/losses.py +++ b/losses.py @@ -18,7 +18,7 @@ def __init__(self, weight=None, gamma=0.): self.weight = weight def forward(self, input, target): - return focal_loss(F.cross_entropy(input, target, weight=self.weight), self.gamma) + return focal_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), self.gamma) class LDAMLoss(nn.Module):