
Commit

use logsigmoid at multilabel_soft_margin_loss, and change output from shape=(N, C) to (N,) (pytorch#9965)

Summary:
- fixes pytorch#9141, pytorch#9301
- use logsigmoid in multilabel_soft_margin_loss to make it numerically more stable (the legacy MultiLabelSoftMarginCriterion is NOT fixed here)
- return shape (N,) instead of (N, C) to match the behavior of MultiMarginLoss
- Note that with this PR, the following behavior is expected:
```
loss = F.multilabel_soft_margin_loss(outputs, labels, reduction='none')
loss_mean = F.multilabel_soft_margin_loss(outputs, labels, reduction='elementwise_mean')
loss_sum = F.multilabel_soft_margin_loss(outputs, labels, reduction='sum')

loss.sum() == loss_sum  # True
loss.mean() == loss_mean  # True
```
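For reference, a self-contained version of that check, as a minimal sketch with illustrative tensors (note that `'elementwise_mean'` was the reduction name at the time of this commit; it was later renamed `'mean'`):
```
import torch
import torch.nn.functional as F

outputs = torch.randn(5, 10)                # (N, C) logits
labels = torch.rand(5, 10).mul(2).floor()   # (N, C) multi-hot targets in {0, 1}

loss = F.multilabel_soft_margin_loss(outputs, labels, reduction='none')
loss_mean = F.multilabel_soft_margin_loss(outputs, labels, reduction='elementwise_mean')
loss_sum = F.multilabel_soft_margin_loss(outputs, labels, reduction='sum')

print(loss.shape)                             # torch.Size([5]), i.e. one loss per sample
print(torch.allclose(loss.sum(), loss_sum))   # True
print(torch.allclose(loss.mean(), loss_mean)) # True
```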
Pull Request resolved: pytorch#9965

Differential Revision: D9038402

Pulled By: weiyangfb

fbshipit-source-id: 0fa94c7b3cd370ea62bd6333f1a0e9bd0b8ccbb9
weiyangfb authored and Rob Kunkle committed Aug 15, 2018
1 parent f34a992 commit f6b1259
Showing 3 changed files with 22 additions and 6 deletions.
8 changes: 5 additions & 3 deletions test/test_nn.py
@@ -6205,7 +6205,7 @@ def forward(self, *args):
     input_fn=lambda: torch.randn(5, 10),
     target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
     reference_fn=lambda i, t, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * get_weight(m)).sum() /
-        (i.numel() if get_reduction(m) == 'elementwise_mean' else 1),
+        (i.numel() if get_reduction(m) == 'elementwise_mean' else i.size(1) if get_reduction(m) == 'sum' else 1),
     desc='weights',
     check_sum_reduction=True,
     check_gradgrad=False,
@@ -6712,7 +6712,8 @@ def multilabelsoftmarginloss_no_reduce_test():
     constructor=wrap_functional(
         lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')),
     input_fn=lambda: torch.randn(5, 10),
-    reference_fn=lambda i, m: -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()),
+    reference_fn=lambda i, m:
+        (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1),
     check_gradgrad=False,
     pickle=False)

@@ -6726,7 +6727,8 @@ def multilabelsoftmarginloss_weights_no_reduce_test():
         lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
                                                 weight=weights.type_as(i), reduction='none')),
     input_fn=lambda: torch.randn(5, 10),
-    reference_fn=lambda i, m: -((t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights),
+    reference_fn=lambda i, m:
+        (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1),
     check_sum_reduction=True,
     check_gradgrad=False,
     pickle=False)
18 changes: 16 additions & 2 deletions torch/nn/functional.py
@@ -1827,8 +1827,22 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
     """
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
-    input = torch.sigmoid(input)
-    return binary_cross_entropy(input, target, weight, None, None, reduction)
+
+    loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input))
+
+    if weight is not None:
+        loss = loss * weight
+
+    loss = loss.sum(dim=1) / input.size(1)  # only return N loss values
+
+    if reduction == 'none':
+        return loss
+    elif reduction == 'elementwise_mean':
+        return loss.mean()
+    elif reduction == 'sum':
+        return loss.sum()
+    else:
+        raise ValueError(reduction + " is not valid")


def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
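For context on the stability point in the Summary, a minimal sketch (values chosen purely for illustration) of why the new implementation goes through `logsigmoid` instead of passing `sigmoid(input)` into `binary_cross_entropy`:
```
import torch
import torch.nn.functional as F

x = torch.tensor([-200.0, 0.0, 200.0])

# Composing sigmoid with log: sigmoid(-200.) underflows to 0 in float32,
# so its log blows up to -inf.
print(torch.sigmoid(x).log())   # tensor([-inf, -0.6931, 0.])

# logsigmoid evaluates the same quantity in a numerically stable way.
print(F.logsigmoid(x))          # tensor([-200.0000, -0.6931, 0.0000])
```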
2 changes: 1 addition & 1 deletion torch/nn/modules/loss.py
@@ -872,7 +872,7 @@ class MultiLabelSoftMarginLoss(_WeightedLoss):
     For each sample in the minibatch:

     .. math::
-        loss(x, y) = - \sum_i y[i] * \log((1 + \exp(-x[i]))^{-1})
+        loss(x, y) = - \frac{1}{C} * \sum_i y[i] * \log((1 + \exp(-x[i]))^{-1})
                          + (1-y[i]) * \log\left(\frac{\exp(-x[i])}{(1 + \exp(-x[i]))}\right)

     where `i == 0` to `x.nElement()-1`, `y[i] in {0,1}`.
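To tie the updated docstring back to the code above, a minimal check (illustrative tensors, assuming the patched functional) that the per-sample loss equals the binary cross-entropy terms averaged over the C classes:
```
import torch
import torch.nn.functional as F

x = torch.randn(5, 10)                # (N, C) logits
y = torch.rand(5, 10).mul(2).floor()  # (N, C) targets in {0, 1}

manual = (-(y * F.logsigmoid(x) + (1 - y) * F.logsigmoid(-x))).sum(dim=1) / x.size(1)
loss = F.multilabel_soft_margin_loss(x, y, reduction='none')

print(loss.shape)                     # torch.Size([5])
print(torch.allclose(manual, loss))   # True
```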
