Fix bug in gaussian_nll_loss (pytorch#56469)
Summary:
Fixes pytorch#53964. cc albanD almson

## Major changes:
- Overhauled the actual loss calculation so that the shapes are now correct (in functional.py)
- Added the missing doc entry in nn.functional.rst

## Minor changes (in functional.py):
- I removed the previous check that input and target have the same shape. This allows broadcasting, say when you have 10 predictions that all share the same target (see the sketch after this list).
- I added some comments to explain each shape check in detail. Let me know if these should be shortened/cut.
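For instance, a minimal sketch of the broadcasting this enables (the shapes here are illustrative, not taken from the test suite):

```python
import torch
import torch.nn.functional as F

# 10 predictions of dimension 3 that all share a single target.
input = torch.randn(10, 3, requires_grad=True)
target = torch.randn(1, 3)   # broadcasts across the 10 predictions
var = torch.ones(10, 3)

loss = F.gaussian_nll_loss(input, target, var, reduction='none')
print(loss.shape)  # torch.Size([10, 3])
```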

Screenshots of updated docs attached.
Let me know what you think, thanks!

## Edit: Description of change of behaviour (affecting BC):
Backwards compatibility is affected only for the `reduction='none'` mode, which was the source of the bug. For tensors of size (N, D), the returned loss previously had size (N) because an incorrect summation was happening. It now has size (N, D) as expected.

### Example
Define input tensors, all with size (2, 3).
`input = torch.tensor([[0., 1., 3.], [2., 4., 0.]], requires_grad=True)`
`target = torch.tensor([[1., 4., 2.], [-1., 2., 3.]])`
`var = 2*torch.ones(size=(2, 3), requires_grad=True)`

Initialise loss with reduction mode 'none'. We expect the returned loss to have the same size as the input tensors, (2, 3).
`loss = torch.nn.GaussianNLLLoss(reduction='none')`

Old behaviour:
`print(loss(input, target, var))`
`# Gives tensor([3.7897, 6.5397], grad_fn=<MulBackward0>). This has size (2).`

New behaviour:
`print(loss(input, target, var)) `
`# Gives tensor([[0.5966, 2.5966, 0.5966], [2.5966, 1.3466, 2.5966]], grad_fn=<MulBackward0>)`
`# This has the expected size, (2, 3).`

To recover the old behaviour, sum along all dimensions except for the 0th:
`print(loss(input, target, var).sum(dim=1))`
`# Gives tensor([3.7897, 6.5397], grad_fn=<SumBackward1>).`
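As a sanity check, the elementwise values above match the loss formula computed by hand (using the tensors defined earlier; the `eps` clamp is irrelevant here since `var = 2`):

```python
manual = 0.5 * (torch.log(var) + (input - target)**2 / var)
print(manual)
# tensor([[0.5966, 2.5966, 0.5966],
#         [2.5966, 1.3466, 2.5966]], grad_fn=<MulBackward0>)
```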

![doc1](https://user-images.githubusercontent.com/26558092/115391089-f7f47b00-a1d6-11eb-8726-e4da9057aee0.png)
![doc2](https://user-images.githubusercontent.com/26558092/115391094-f925a800-a1d6-11eb-954b-afd187f42bc7.png)

Pull Request resolved: pytorch#56469

Reviewed By: jbschlosser, agolynski

Differential Revision: D27894170

Pulled By: albanD

fbshipit-source-id: 197890189c97c22109491c47f469336b5b03a23f
M.L. Croci authored and Kushashwa Shrimali committed May 18, 2021
1 parent e1ba175 commit d51d161
Showing 5 changed files with 85 additions and 50 deletions.
1 change: 1 addition & 0 deletions docs/source/nn.functional.rst
@@ -158,6 +158,7 @@ Loss functions
cosine_embedding_loss
cross_entropy
ctc_loss
gaussian_nll_loss
hinge_embedding_loss
kl_div
l1_loss
25 changes: 20 additions & 5 deletions test/test_nn.py
@@ -5426,7 +5426,7 @@ def test_gaussian_nll_loss_reduction_modes(self):
input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
target = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
var = torch.tensor([[0.5, 1., 1.5], [1., 1.5, 2.]])
component_wise_loss = 0.5 * (torch.sum(torch.log(var) + (input - target)**2 / var, dim=1))
component_wise_loss = 0.5 * (torch.log(var) + (input - target)**2 / var)
self.assertEqual(component_wise_loss,
F.gaussian_nll_loss(input, target, var, reduction='none'))
self.assertEqual(torch.sum(component_wise_loss),
@@ -5436,12 +5436,27 @@ def test_gaussian_nll_loss_reduction_modes(self):
with self.assertRaisesRegex(ValueError, 'is not valid'):
F.gaussian_nll_loss(input, target, var, reduction='total')

def test_gaussian_nll_loss_broadcasting(self):
input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
target_full = torch.tensor([[1., 2., 3.], [1., 2., 3.]])
target_part = torch.tensor([[1., 2., 3.]])
var_full = torch.tensor([[0.5, 0.5, 0.5], [1.5, 1.5, 1.5]])
var_part1 = torch.tensor([[0.5], [1.5]])
var_part2 = torch.tensor([0.5, 1.5])
component_wise_loss = 0.5 * (torch.log(var_full) + (input - target_full)**2 / var_full)
self.assertEqual(component_wise_loss,
F.gaussian_nll_loss(input, target_part, var_full, reduction='none'))
self.assertEqual(component_wise_loss,
F.gaussian_nll_loss(input, target_full, var_part1, reduction='none'))
self.assertEqual(component_wise_loss,
F.gaussian_nll_loss(input, target_full, var_part2, reduction='none'))
self.assertEqual(component_wise_loss,
F.gaussian_nll_loss(input, target_part, var_part1, reduction='none'))
self.assertEqual(component_wise_loss,
F.gaussian_nll_loss(input, target_part, var_part2, reduction='none'))

def test_gaussian_nll_loss_args(self):
input = torch.randn(3, 5)
with self.assertRaisesRegex(ValueError, 'input and target must have same size'):
target = torch.randn(3, 6)
var = torch.ones(3, 5)
torch.nn.functional.gaussian_nll_loss(input, target, var)
with self.assertRaisesRegex(ValueError, 'var is of incorrect size'):
target = torch.randn(3, 5)
var = torch.ones(3, 3)
77 changes: 48 additions & 29 deletions torch/nn/functional.py
@@ -2563,7 +2563,14 @@ def poisson_nll_loss(
return ret


def gaussian_nll_loss(input, target, var, *, full=False, eps=1e-6, reduction='mean'):
def gaussian_nll_loss(
input: Tensor,
target: Tensor,
var: Tensor,
full: bool = False,
eps: float = 1e-6,
reduction: str = "mean",
) -> Tensor:
r"""Gaussian negative log likelihood loss.
See :class:`~torch.nn.GaussianNLLLoss` for details.
@@ -2573,31 +2580,47 @@ def gaussian_nll_loss(input, target, var, *, full=False, eps=1e-6, reduction='mean'):
target: sample from the Gaussian distribution.
var: tensor of positive variance(s), one for each of the expectations
in the input (heteroscedastic), or a single one (homoscedastic).
full: ``True``/``False`` (bool), include the constant term in the loss
calculation. Default: ``False``.
eps: value added to var, for stability. Default: 1e-6.
reduction: specifies the reduction to apply to the output:
`'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
full (bool, optional): include the constant term in the loss calculation. Default: ``False``.
eps (float, optional): value added to var, for stability. Default: 1e-6.
reduction (string, optional): specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the output is the average of all batch member losses,
``'sum'``: the output is the sum of all batch member losses.
Default: ``'mean'``.
"""
if not torch.jit.is_scripting():
tens_ops = (input, target, var)
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
return handle_torch_function(
gaussian_nll_loss, tens_ops, input, target, var, full=full, eps=eps, reduction=reduction)

# Inputs and targets much have same shape
input = input.view(input.size(0), -1)
target = target.view(target.size(0), -1)
if input.size() != target.size():
raise ValueError("input and target must have same size")

# Second dim of var must match that of input or be equal to 1
var = var.view(input.size(0), -1)
if var.size(1) != input.size(1) and var.size(1) != 1:
raise ValueError("var is of incorrect size")
if has_torch_function_variadic(input, target, var):
return handle_torch_function(
gaussian_nll_loss,
(input, target, var),
input,
target,
var,
full=full,
eps=eps,
reduction=reduction,
)

# Check var size
# If var.size == input.size, the case is heteroscedastic and no further checks are needed.
# Otherwise:
if var.size() != input.size():

# If var is one dimension short of input, but the sizes match otherwise, then this is a homoscedastic case.
# e.g. input.size = (10, 2, 3), var.size = (10, 2)
# -> unsqueeze var so that var.shape = (10, 2, 1)
# this is done so that broadcasting can happen in the loss calculation
if input.size()[:-1] == var.size():
var = torch.unsqueeze(var, -1)

# This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1.
# This is also a homoscedastic case.
# e.g. input.size = (10, 2, 3), var.size = (10, 2, 1)
elif input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1:  # Homoscedastic case
pass

# If none of the above pass, then the size of var is incorrect.
else:
raise ValueError("var is of incorrect size")

# Check validity of reduction mode
if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
@@ -2612,15 +2635,11 @@ def gaussian_nll_loss(input, target, var, *, full=False, eps=1e-6, reduction='mean'):
with torch.no_grad():
var.clamp_(min=eps)

# Calculate loss (without constant)
loss = 0.5 * (torch.log(var) + (input - target)**2 / var).view(input.size(0), -1).sum(dim=1)

# Add constant to loss term if required
# Calculate the loss
loss = 0.5 * (torch.log(var) + (input - target)**2 / var)
if full:
D = input.size(1)
loss = loss + 0.5 * D * math.log(2 * math.pi)
loss += 0.5 * math.log(2 * math.pi)

# Apply reduction
if reduction == 'mean':
return loss.mean()
elif reduction == 'sum':
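For reference, a hedged sketch of the three `var` shapes the new checks accept (shapes are illustrative, mirroring the comments in the diff above):

```python
import torch
import torch.nn.functional as F

input = torch.randn(10, 2, 3)
target = torch.randn(10, 2, 3)

# Heteroscedastic: var matches input exactly.
loss1 = F.gaussian_nll_loss(input, target, torch.ones(10, 2, 3), reduction='none')

# Homoscedastic: var is one dimension short of input; it is unsqueezed
# to (10, 2, 1) and broadcast over the final dimension.
loss2 = F.gaussian_nll_loss(input, target, torch.ones(10, 2), reduction='none')

# Homoscedastic: var already has a final dimension of size 1.
loss3 = F.gaussian_nll_loss(input, target, torch.ones(10, 2, 1), reduction='none')

print(loss1.shape, loss2.shape, loss3.shape)  # each torch.Size([10, 2, 3])
```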
30 changes: 15 additions & 15 deletions torch/nn/modules/loss.py
@@ -304,20 +304,18 @@ class GaussianNLLLoss(_Loss):
The targets are treated as samples from Gaussian distributions with
expectations and variances predicted by the neural network. For a
D-dimensional ``target`` tensor modelled as having heteroscedastic Gaussian
distributions with a D-dimensional tensor of expectations ``input`` and a
D-dimensional tensor of positive variances ``var`` the loss is:
``target`` tensor modelled as having Gaussian distribution with a tensor
of expectations ``input`` and a tensor of positive variances ``var`` the loss is:
.. math::
\text{loss} = \frac{1}{2}\sum_{i=1}^D \left(\log\left(\text{max}\left(\text{var}[i],
\ \text{eps}\right)\right) + \frac{\left(\text{input}[i] - \text{target}[i]\right)^2}
{\text{max}\left(\text{var}[i], \ \text{eps}\right)}\right) + \text{const.}
\text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
\ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2}
{\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
where :attr:`eps` is used for stability. By default, the constant term of
the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is
a scalar (implying ``target`` tensor has homoscedastic Gaussian
distributions) it is broadcasted to be the same size as the input.
the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same
size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.
Args:
full (bool, optional): include the constant term in the loss
@@ -333,21 +331,23 @@ class GaussianNLLLoss(_Loss):
Shape:
- Input: :math:`(N, *)` where :math:`*` means any number of additional
dimensions
- Target: :math:`(N, *)`, same shape as the input
- Var: :math:`(N, 1)` or :math:`(N, *)`, same shape as the input
- Target: :math:`(N, *)`, same shape as the input, or same shape as the input
but with one dimension equal to 1 (to allow for broadcasting)
- Var: :math:`(N, *)`, same shape as the input, or same shape as the input but
with one dimension equal to 1, or same shape as the input but with one fewer
dimension (to allow for broadcasting)
- Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N)`
``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
shape as the input
Examples::
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
>>> var = torch.ones(5, 2, requires_grad=True) #heteroscedastic
>>> output = loss(input, target, var)
>>> output.backward()
>>> loss = nn.GaussianNLLLoss()
>>> input = torch.randn(5, 2, requires_grad=True)
>>> target = torch.randn(5, 2)
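The diff truncates the docstring's second example; as an assumption of what the homoscedastic case looks like through the module interface (a usage sketch, not the exact docstring text):

```python
import torch
import torch.nn as nn

loss = nn.GaussianNLLLoss()
input = torch.randn(5, 2, requires_grad=True)
target = torch.randn(5, 2)
var = torch.ones(5, 1, requires_grad=True)  # homoscedastic: one variance per sample
output = loss(input, target, var)
output.backward()
```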
2 changes: 1 addition & 1 deletion torch/overrides.py
@@ -640,7 +640,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nn.functional.fractional_max_pool3d_with_indices: (
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None: -1),
torch.nn.functional.gaussian_nll_loss: (lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1),
torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
torch.nn.functional.gelu: lambda input: -1,
torch.nn.functional.glu: lambda input, dim=-1: -1,
torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
