forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix bug in gaussian_nll_loss (pytorch#56469)
Summary: Fixes pytorch#53964. cc albanD almson ## Major changes: - Overhauled the actual loss calculation so that the shapes are now correct (in functional.py) - added the missing doc in nn.functional.rst ## Minor changes (in functional.py): - I removed the previous check on whether input and target were the same shape. This is to allow for broadcasting, say when you have 10 predictions that all have the same target. - I added some comments to explain each shape check in detail. Let me know if these should be shortened/cut. Screenshots of updated docs attached. Let me know what you think, thanks! ## Edit: Description of change of behaviour (affecting BC): The backwards-compatibility is only affected for the `reduction='none'` mode. This was the source of the bug. For tensors with size (N, D), the old returned loss had size (N), as incorrect summation was happening. It will now have size (N, D) as expected. ### Example Define input tensors, all with size (2, 3). `input = torch.tensor([[0., 1., 3.], [2., 4., 0.]], requires_grad=True)` `target = torch.tensor([[1., 4., 2.], [-1., 2., 3.]])` `var = 2*torch.ones(size=(2, 3), requires_grad=True)` Initialise loss with reduction mode 'none'. We expect the returned loss to have the same size as the input tensors, (2, 3). `loss = torch.nn.GaussianNLLLoss(reduction='none')` Old behaviour: `print(loss(input, target, var)) ` `# Gives tensor([3.7897, 6.5397], grad_fn=<MulBackward0>. This has size (2).` New behaviour: `print(loss(input, target, var)) ` `# Gives tensor([[0.5966, 2.5966, 0.5966], [2.5966, 1.3466, 2.5966]], grad_fn=<MulBackward0>)` `# This has the expected size, (2, 3).` To recover the old behaviour, sum along all dimensions except for the 0th: `print(loss(input, target, var).sum(dim=1))` `# Gives tensor([3.7897, 6.5397], grad_fn=<SumBackward1>.` ![doc1](https://user-images.githubusercontent.com/26558092/115391089-f7f47b00-a1d6-11eb-8726-e4da9057aee0.png) ![doc2](https://user-images.githubusercontent.com/26558092/115391094-f925a800-a1d6-11eb-954b-afd187f42bc7.png) Pull Request resolved: pytorch#56469 Reviewed By: jbschlosser, agolynski Differential Revision: D27894170 Pulled By: albanD fbshipit-source-id: 197890189c97c22109491c47f469336b5b03a23f
- Loading branch information
M.L. Croci
authored and
Kushashwa Shrimali
committed
May 18, 2021
1 parent
e1ba175
commit d51d161
Showing
5 changed files
with
85 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters