Skip to content

Commit

Permalink
Implement survival and log-survival function for the Logistic distr…
Browse files Browse the repository at this point in the history
…ibution.

Also:
* Simplify the tests for the `Logistic`
PiperOrigin-RevId: 443160365
  • Loading branch information
DistraxDev authored and DistraxDev committed Apr 22, 2022
1 parent cc452c5 commit a2fbf4f
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 62 deletions.
33 changes: 33 additions & 0 deletions distrax/_src/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,39 @@ def cdf(self, value: Array) -> Array:
"""
return jnp.exp(self.log_cdf(value))

def survival_function(self, value: Array) -> Array:
"""Evaluates the survival function at `value`.
Args:
value: An event.
Returns:
The survival function evaluated at `value`, i.e. P[X > value]
"""
if not self.event_shape:
# Defined for univariate distributions only.
return 1. - self.cdf(value)
else:
raise NotImplementedError('`survival_function` is not defined for '
f'distribution {self.name}')

def log_survival_function(self, value: Array) -> Array:
"""Evaluates the log of the survival function at `value`.
Args:
value: An event.
Returns:
The log of the survival function evaluated at `value`, i.e.
log P[X > value]
"""
if not self.event_shape:
# Defined for univariate distributions only.
return jnp.log1p(-self.cdf(value))
else:
raise NotImplementedError('`log_survival_function` is not defined for '
f'distribution {self.name}')

def mean(self) -> Array:
"""Calculates the mean."""
raise NotImplementedError(
Expand Down
6 changes: 6 additions & 0 deletions distrax/_src/distributions/distribution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,12 @@ def test_to_batch_shape_index_raises(self, index):
distribution.to_batch_shape_index(
batch_shape=(2, 3, 4), index=index)

def test_multivariate_survival_function_raises(self):
mult_dist = DummyMultivariateDist(42)
with self.assertRaises(NotImplementedError):
mult_dist.survival_function(jnp.zeros(42))
with self.assertRaises(NotImplementedError):
mult_dist.log_survival_function(jnp.zeros(42))

if __name__ == '__main__':
absltest.main()
8 changes: 8 additions & 0 deletions distrax/_src/distributions/logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ def log_cdf(self, value: Array) -> Array:
"""See `Distribution.log_cdf`."""
return -jax.nn.softplus(-self._standardize(value))

def survival_function(self, value: Array) -> Array:
"""See `Distribution.survival_function`."""
return jax.nn.sigmoid(-self._standardize(value))

def log_survival_function(self, value: Array) -> Array:
"""See `Distribution.log_survival_function`."""
return -jax.nn.softplus(self._standardize(value))

def mean(self) -> Array:
"""Calculates the mean."""
return self.loc
Expand Down
72 changes: 10 additions & 62 deletions distrax/_src/distributions/logistic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,71 +109,19 @@ def test_sample_and_log_prob(self, distr_params, sample_shape):
('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (0, 1), 200),
)
def test_log_prob(self, distr_params, value):
def test_method_with_input(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_log_prob(distr_params, dict(), value)

@chex.all_variants
@parameterized.named_parameters(
('1d dist, 1d value', (0, 1), 1),
('1d dist, 2d value', (0., 1.), np.array([1, 2])),
('1d dist, 2d value as list', (0., 1.), [1, 2]),
('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1),
('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1),
('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (0, 1), 200),
)
def test_prob(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_attribute(
attribute_string='prob',
dist_args=distr_params,
call_args=(value,),
assertion_fn=self.assertion_fn)

@chex.all_variants
@parameterized.named_parameters(
('1d dist, 1d value', (0, 1), 1),
('1d dist, 2d value', (0., 1.), np.array([1, 2])),
('1d dist, 2d value as list', (0., 1.), [1, 2]),
('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1),
('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1),
('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (0, 1), 200),
)
def test_cdf(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_attribute(
attribute_string='cdf',
dist_args=distr_params,
call_args=(value,),
assertion_fn=self.assertion_fn)

@chex.all_variants
@parameterized.named_parameters(
('1d dist, 1d value', (0, 1), 1),
('1d dist, 2d value', (0., 1.), np.array([1, 2])),
('1d dist, 2d value as list', (0., 1.), [1, 2]),
('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1),
('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1),
('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (0, 1), 200),
)
def test_log_cdf(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_attribute(
attribute_string='log_cdf',
dist_args=distr_params,
call_args=(value,),
assertion_fn=self.assertion_fn)
for method in ['log_prob', 'prob', 'cdf', 'log_cdf', 'survival_function',
'log_survival_function']:
with self.subTest(method):
super()._test_attribute(
attribute_string=method,
dist_args=distr_params,
dist_kwargs={},
call_args=(value,),
assertion_fn=self.assertion_fn)

@chex.all_variants(with_pmap=False)
@parameterized.named_parameters(
Expand Down
2 changes: 2 additions & 0 deletions distrax/_src/distributions/uniform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def test_sample_and_log_prob(self, distr_params, sample_shape):
('log_prob', 'log_prob'),
('prob', 'prob'),
('cdf', 'cdf'),
('survival_function', 'survival_function'),
('log_survival_function', 'log_survival_function')
)
def test_method_with_inputs(self, function_string):
inputs = 10. * np.random.normal(size=(100,))
Expand Down

0 comments on commit a2fbf4f

Please sign in to comment.