diff --git a/test/test_nn.py b/test/test_nn.py
index fe06f0f4db4a..c8d54c58b8e5 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -6205,7 +6205,7 @@ def forward(self, *args):
         input_fn=lambda: torch.randn(5, 10),
         target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
         reference_fn=lambda i, t, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * get_weight(m)).sum() /
-            (i.numel() if get_reduction(m) == 'elementwise_mean' else 1),
+            (i.numel() if get_reduction(m) == 'elementwise_mean' else i.size(1) if get_reduction(m) == 'sum' else 1),
         desc='weights',
         check_sum_reduction=True,
         check_gradgrad=False,
@@ -6712,7 +6712,8 @@ def multilabelsoftmarginloss_no_reduce_test():
         constructor=wrap_functional(
             lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')),
         input_fn=lambda: torch.randn(5, 10),
-        reference_fn=lambda i, m: -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()),
+        reference_fn=lambda i, m:
+            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1),
         check_gradgrad=False,
         pickle=False)

@@ -6726,7 +6727,8 @@ def multilabelsoftmarginloss_weights_no_reduce_test():
             lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
                                                     weight=weights.type_as(i), reduction='none')),
         input_fn=lambda: torch.randn(5, 10),
-        reference_fn=lambda i, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights),
+        reference_fn=lambda i, m:
+            (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1),
         check_sum_reduction=True,
         check_gradgrad=False,
         pickle=False)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 21c09c412af6..b47cc39bbb17 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1827,8 +1827,22 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
     """
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
-    input = torch.sigmoid(input)
-    return binary_cross_entropy(input, target, weight, None, None, reduction)
+
+    loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input))
+
+    if weight is not None:
+        loss = loss * weight
+
+    loss = loss.sum(dim=1) / input.size(1)  # only return N loss values
+
+    if reduction == 'none':
+        return loss
+    elif reduction == 'elementwise_mean':
+        return loss.mean()
+    elif reduction == 'sum':
+        return loss.sum()
+    else:
+        raise ValueError(reduction + " is not valid")


 def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index 2a1714930e35..ef6c89716919 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -872,7 +872,7 @@ class MultiLabelSoftMarginLoss(_WeightedLoss):
     For each sample in the minibatch:

     .. math::
-        loss(x, y) = - \sum_i y[i] * \log((1 + \exp(-x[i]))^{-1})
+        loss(x, y) = - \frac{1}{C} * \sum_i y[i] * \log((1 + \exp(-x[i]))^{-1})
                          + (1-y[i]) * \log\left(\frac{\exp(-x[i])}{(1 + \exp(-x[i]))}\right)

     where `i == 0` to `x.nElement()-1`, `y[i] in {0,1}`.
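
For context, here is a minimal standalone sketch (not part of the patch) of what the reworked reduction computes: the per-element BCE-with-logits terms are summed over the C classes and divided by C, so reduction='none' now yields one loss value per sample rather than an N x C tensor. The helper name manual_multilabel_soft_margin is hypothetical and only mirrors the logic added in torch/nn/functional.py above.

import torch
import torch.nn.functional as F

def manual_multilabel_soft_margin(input, target, weight=None):
    # elementwise binary cross entropy with logits, written via log-sigmoid
    loss = -(target * F.logsigmoid(input) + (1 - target) * F.logsigmoid(-input))
    if weight is not None:
        loss = loss * weight
    # sum over the C classes and normalize by C -> one loss value per sample
    return loss.sum(dim=1) / input.size(1)

x = torch.randn(5, 10)
y = torch.rand(5, 10).mul(2).floor()
expected = manual_multilabel_soft_margin(x, y)
actual = F.multilabel_soft_margin_loss(x, y, reduction='none')
print(torch.allclose(expected, actual))  # True on a build that includes this change
print(actual.shape)                      # torch.Size([5]): N values, not N x C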