Skip to content

Commit

Permalink
Added log_cdf method for the Gamma distribution and modified some t…
Browse files Browse the repository at this point in the history
…ests in `gamma_test.py`.

PiperOrigin-RevId: 440864009
  • Loading branch information
franrruiz authored and DistraxDev committed Apr 11, 2022
1 parent 60ce578 commit 8882e93
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 76 deletions.
4 changes: 4 additions & 0 deletions distrax/_src/distributions/gamma.py
Expand Up @@ -103,6 +103,10 @@ def cdf(self, value: Array) -> Array:
"""See `Distribution.cdf`."""
return jax.lax.igamma(self._concentration, self._rate * value)

def log_cdf(self, value: Array) -> Array:
"""See `Distribution.log_cdf`."""
return jnp.log(self.cdf(value))

def mean(self) -> Array:
"""Calculates the mean."""
return self._concentration / self._rate
Expand Down
117 changes: 41 additions & 76 deletions distrax/_src/distributions/gamma_test.py
Expand Up @@ -34,14 +34,22 @@ def setUp(self):
self.assertion_fn = lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL)

@parameterized.named_parameters(
('1d std gamma', (1, 1)),
('2d std gamma', (np.ones(2), np.ones(2))),
('rank 2 std gamma', (np.ones((3, 2)), np.ones((3, 2)))),
('broadcasted concentration', (1, np.ones(3))),
('broadcasted rate', (np.ones(3), 1)),
('0d params', (), (), ()),
('1d params', (2,), (2,), (2,)),
('2d params, no broadcast', (3, 2), (3, 2), (3, 2)),
('2d params, broadcasted concentration', (2,), (3, 2), (3, 2)),
('2d params, broadcasted rate', (3, 2), (2,), (3, 2)),
)
def test_event_shape(self, distr_params):
super()._test_event_shape(distr_params, dict())
def test_properties(self, concentration_shape, rate_shape, batch_shape):
rng = np.random.default_rng(42)
concentration = 0.1 + rng.uniform(size=concentration_shape)
rate = 0.1 + rng.uniform(size=rate_shape)
dist = gamma.Gamma(concentration, rate)
self.assertEqual(dist.event_shape, ())
self.assertEqual(dist.batch_shape, batch_shape)
self.assertion_fn(
dist.concentration, np.broadcast_to(concentration, batch_shape))
self.assertion_fn(dist.rate, np.broadcast_to(rate, batch_shape))

@chex.all_variants
@parameterized.named_parameters(
Expand Down Expand Up @@ -104,82 +112,36 @@ def test_sample_and_log_prob(self, distr_params, sample_shape):
('1d dist, 2d value', (0.5, 0.1), np.array([1, 2])),
('2d dist, 1d value', (0.5 + np.zeros(2), 0.3 * np.ones(2)), 1),
('2d broadcasted dist, 1d value', (0.4 + np.zeros(2), 0.8), 1),
('2d dist, 2d value', ([0.1, -0.5], 0.9 * np.ones(2)), np.array([1, 2])),
('2d dist, 2d value', ([0.1, 0.5], 0.9 * np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (2.1, 1), 200),
)
def test_log_prob(self, distr_params, value):
def test_method_with_value(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_attribute(
attribute_string='log_prob',
dist_args=distr_params,
call_args=(value,),
assertion_fn=self.assertion_fn)

@chex.all_variants
@parameterized.named_parameters(
('1d dist, 1d value', (3.1, 1), 1),
('1d dist, 2d value', (0.5, 0.1), np.array([1, 2])),
('2d dist, 1d value', (0.5 + np.zeros(2), 0.3 * np.ones(2)), 1),
('2d broadcasted dist, 1d value', (0.4 + np.zeros(2), 0.8), 1),
('2d dist, 2d value', ([0.1, -0.5], 0.9 * np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (2.1, 1), 200),
)
def test_prob(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_attribute(
attribute_string='prob',
dist_args=distr_params,
call_args=(value,),
assertion_fn=self.assertion_fn)

@chex.all_variants
@parameterized.named_parameters(
('1d dist, 1d value', (3.1, 1), 1),
('1d dist, 2d value', (0.5, 0.1), np.array([1, 2])),
('2d dist, 1d value', (0.5 + np.zeros(2), 0.3 * np.ones(2)), 1),
('2d broadcasted dist, 1d value', (0.4 + np.zeros(2), 0.8), 1),
('2d dist, 2d value', ([0.1, -0.5], 0.9 * np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (2.1, 1), 200),
)
def test_cdf(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_attribute(
attribute_string='cdf',
dist_args=distr_params,
call_args=(value,),
assertion_fn=self.assertion_fn)
for method in ['log_prob', 'prob', 'cdf', 'log_cdf']:
with self.subTest(method=method):
super()._test_attribute(
attribute_string=method,
dist_args=distr_params,
call_args=(value,),
assertion_fn=self.assertion_fn)

@chex.all_variants(with_pmap=False)
@parameterized.named_parameters(
('entropy', ([0., 1.3, -0.5], [0.5, 1.3, 1.5]), 'entropy'),
('entropy broadcasted concentration', (0.5, [0.5, 1.3, 1.5]), 'entropy'),
('entropy broadcasted rate', ([0.1, 1.3, -0.5], 0.8), 'entropy'),
('mean', ([0.1, 1.3, -0.5], [0.5, 1.3, 1.5]), 'mean'),
('mean broadcasted concentration', (0.5, [0.5, 1.3, 1.5]), 'mean'),
('mean broadcasted rate', ([0.1, 1.3, -0.5], 0.8), 'mean'),
('variance', ([0.1, 1.3, -0.5], [0.5, 1.3, 1.5]), 'variance'),
('variance broadcasted concentration', (0.5, [0.5, 1., 1.]), 'variance'),
('variance broadcasted rate', ([0.1, 1.3, -0.5], 0.8), 'variance'),
('stddev', ([0.1, 1.3, -0.5], [0.5, 1.3, 1.5]), 'stddev'),
('stddev broadcasted concentration', (0.5, [0.5, 1.3, 1.5]), 'stddev'),
('stddev broadcasted rate', ([0.1, 1.3, -0.5], 0.8), 'stddev'),
('mode', ([0.1, 1.3, -0.5], [0.5, 1.3, 1.5]), 'mode'),
('mode broadcasted concentration', (0.5, [0.5, 1.3, 1.5]), 'mode'),
('mode broadcasted rate', ([0.1, 1.3, -0.5], 0.8), 'mode'),
('no broadcast', ([0.1, 1.3, 0.5], [0.5, 1.3, 1.5])),
('broadcasted concentration', (0.5, [0.5, 1.3, 1.5])),
('broadcasted rate', ([0.1, 1.3, 0.5], 0.8)),
)
def test_method(self, distr_params, function_string):
def test_method(self, distr_params):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
super()._test_attribute(
attribute_string=function_string,
dist_args=distr_params,
assertion_fn=self.assertion_fn)
for method in ['entropy', 'mean', 'variance', 'stddev', 'mode']:
with self.subTest(method=method):
super()._test_attribute(
attribute_string=method,
dist_args=distr_params,
assertion_fn=self.assertion_fn)

@chex.all_variants(with_pmap=False)
@parameterized.named_parameters(
Expand All @@ -193,16 +155,19 @@ def test_method(self, distr_params, function_string):
def test_with_two_distributions(self, function_string, mode_string):
rtol = 1e-3
atol = 1e-4
rng = np.random.default_rng(42)
super()._test_with_two_distributions(
attribute_string=function_string,
mode_string=mode_string,
dist1_kwargs={
'concentration': np.random.rand(4, 1, 2),
'rate': np.array([[0.8, 0.2], [0.1, 1.2], [1.4, 3.1]]),
'concentration': np.abs(
rng.normal(size=(4, 1, 2))).astype(np.float32),
'rate': np.array(
[[0.8, 0.2], [0.1, 1.2], [1.4, 3.1]], dtype=np.float32),
},
dist2_kwargs={
'concentration': np.random.rand(3, 2),
'rate': 0.1 + np.random.rand(4, 1, 2),
'concentration': np.abs(rng.normal(size=(3, 2))).astype(np.float32),
'rate': 0.1 + rng.uniform(size=(4, 1, 2)).astype(np.float32),
},
assertion_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol, atol))

Expand Down

0 comments on commit 8882e93

Please sign in to comment.