Skip to content

Commit

Permalink
Fix gumbel cdf (#91698)
Browse files Browse the repository at this point in the history
Fix `Gumbel.cdf` function.

**Description**
When transformed parameters is outside of the support of underlying Uniform distribution. This makes behavior of `Gumbel.cdf` consistent with other `TransformedDistribution` that pass value of validate_args to the base distribution.

**Issue**
running `Gumbel(0.0,1.0,validate_args=False).cdf(20.0)` would cause `ValueError` exception from `_validate_sample`

**Testing**
Test was added to the `test_distributions.py` to check if `Gumbel(0.0,1.0,validate_args=False).cdf(20.0)` successfully returns `1.0`

This is a second attempt to push changes , after pytorch/pytorch#82488

Pull Request resolved: pytorch/pytorch#91698
Approved by: https://github.com/fritzo, https://github.com/zou3519
  • Loading branch information
vfonov authored and cyyever committed Mar 12, 2023
1 parent 115bac7 commit 0e59eba
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
12 changes: 12 additions & 0 deletions test/distributions/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2589,6 +2589,18 @@ def test_gumbel(self):
self.assertEqual(Gumbel(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
self.assertEqual(Gumbel(1.0, 1.0).sample().size(), ())
self.assertEqual(Gumbel(1.0, 1.0).sample((1,)).size(), (1,))
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float32),
torch.tensor(1.0, dtype=torch.float32),
validate_args=False).cdf(20.0), 1.0, atol=1e-4, rtol=0)
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float64),
torch.tensor(1.0, dtype=torch.float64),
validate_args=False).cdf(50.0), 1.0, atol=1e-4, rtol=0)
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float32),
torch.tensor(1.0, dtype=torch.float32),
validate_args=False).cdf(-5.0), 0.0, atol=1e-4, rtol=0)
self.assertEqual(Gumbel(torch.tensor(0.0, dtype=torch.float64),
torch.tensor(1.0, dtype=torch.float64),
validate_args=False).cdf(-10.0), 0.0, atol=1e-8, rtol=0)

def ref_log_prob(idx, x, log_prob):
l = loc.view(-1)[idx].detach()
Expand Down
5 changes: 3 additions & 2 deletions torch/distributions/gumbel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ def __init__(self, loc, scale, validate_args=None):
self.loc, self.scale = broadcast_all(loc, scale)
finfo = torch.finfo(self.loc.dtype)
if isinstance(loc, Number) and isinstance(scale, Number):
base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
else:
base_dist = Uniform(torch.full_like(self.loc, finfo.tiny),
torch.full_like(self.loc, 1 - finfo.eps))
torch.full_like(self.loc, 1 - finfo.eps),
validate_args=validate_args)
transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
super().__init__(base_dist, transforms, validate_args=validate_args)
Expand Down

0 comments on commit 0e59eba

Please sign in to comment.