diff --git a/distrax/_src/distributions/distribution.py b/distrax/_src/distributions/distribution.py index d81d80b3..2ec47085 100644 --- a/distrax/_src/distributions/distribution.py +++ b/distrax/_src/distributions/distribution.py @@ -188,6 +188,39 @@ def cdf(self, value: Array) -> Array: """ return jnp.exp(self.log_cdf(value)) + def survival_function(self, value: Array) -> Array: + """Evaluates the survival function at `value`. + + Args: + value: An event. + + Returns: + The survival function evaluated at `value`, i.e. P[X > value] + """ + if not self.event_shape: + # Defined for univariate distributions only. + return 1. - self.cdf(value) + else: + raise NotImplementedError('`survival_function` is not defined for ' + f'distribution {self.name}') + + def log_survival_function(self, value: Array) -> Array: + """Evaluates the log of the survival function at `value`. + + Args: + value: An event. + + Returns: + The log of the survival function evaluated at `value`, i.e. + log P[X > value] + """ + if not self.event_shape: + # Defined for univariate distributions only. + return jnp.log1p(-self.cdf(value)) + else: + raise NotImplementedError('`log_survival_function` is not defined for ' + f'distribution {self.name}') + def mean(self) -> Array: """Calculates the mean.""" raise NotImplementedError( diff --git a/distrax/_src/distributions/distribution_test.py b/distrax/_src/distributions/distribution_test.py index e850eb4b..b2d2c503 100644 --- a/distrax/_src/distributions/distribution_test.py +++ b/distrax/_src/distributions/distribution_test.py @@ -205,6 +205,12 @@ def test_to_batch_shape_index_raises(self, index): distribution.to_batch_shape_index( batch_shape=(2, 3, 4), index=index) + def test_multivariate_survival_function_raises(self): + mult_dist = DummyMultivariateDist(42) + with self.assertRaises(NotImplementedError): + mult_dist.survival_function(jnp.zeros(42)) + with self.assertRaises(NotImplementedError): + mult_dist.log_survival_function(jnp.zeros(42)) if __name__ == '__main__': absltest.main() diff --git a/distrax/_src/distributions/logistic.py b/distrax/_src/distributions/logistic.py index c3793d3b..80c04589 100644 --- a/distrax/_src/distributions/logistic.py +++ b/distrax/_src/distributions/logistic.py @@ -101,6 +101,14 @@ def log_cdf(self, value: Array) -> Array: """See `Distribution.log_cdf`.""" return -jax.nn.softplus(-self._standardize(value)) + def survival_function(self, value: Array) -> Array: + """See `Distribution.survival_function`.""" + return jax.nn.sigmoid(-self._standardize(value)) + + def log_survival_function(self, value: Array) -> Array: + """See `Distribution.log_survival_function`.""" + return -jax.nn.softplus(self._standardize(value)) + def mean(self) -> Array: """Calculates the mean.""" return self.loc diff --git a/distrax/_src/distributions/logistic_test.py b/distrax/_src/distributions/logistic_test.py index e1dd2c0c..3b977dde 100644 --- a/distrax/_src/distributions/logistic_test.py +++ b/distrax/_src/distributions/logistic_test.py @@ -109,71 +109,19 @@ def test_sample_and_log_prob(self, distr_params, sample_shape): ('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])), ('1d dist, 1d value, edge case', (0, 1), 200), ) - def test_log_prob(self, distr_params, value): + def test_method_with_input(self, distr_params, value): distr_params = (np.asarray(distr_params[0], dtype=np.float32), np.asarray(distr_params[1], dtype=np.float32)) value = np.asarray(value, dtype=np.float32) - super()._test_log_prob(distr_params, dict(), value) - - @chex.all_variants - @parameterized.named_parameters( - ('1d dist, 1d value', (0, 1), 1), - ('1d dist, 2d value', (0., 1.), np.array([1, 2])), - ('1d dist, 2d value as list', (0., 1.), [1, 2]), - ('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1), - ('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1), - ('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])), - ('1d dist, 1d value, edge case', (0, 1), 200), - ) - def test_prob(self, distr_params, value): - distr_params = (np.asarray(distr_params[0], dtype=np.float32), - np.asarray(distr_params[1], dtype=np.float32)) - value = np.asarray(value, dtype=np.float32) - super()._test_attribute( - attribute_string='prob', - dist_args=distr_params, - call_args=(value,), - assertion_fn=self.assertion_fn) - - @chex.all_variants - @parameterized.named_parameters( - ('1d dist, 1d value', (0, 1), 1), - ('1d dist, 2d value', (0., 1.), np.array([1, 2])), - ('1d dist, 2d value as list', (0., 1.), [1, 2]), - ('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1), - ('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1), - ('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])), - ('1d dist, 1d value, edge case', (0, 1), 200), - ) - def test_cdf(self, distr_params, value): - distr_params = (np.asarray(distr_params[0], dtype=np.float32), - np.asarray(distr_params[1], dtype=np.float32)) - value = np.asarray(value, dtype=np.float32) - super()._test_attribute( - attribute_string='cdf', - dist_args=distr_params, - call_args=(value,), - assertion_fn=self.assertion_fn) - - @chex.all_variants - @parameterized.named_parameters( - ('1d dist, 1d value', (0, 1), 1), - ('1d dist, 2d value', (0., 1.), np.array([1, 2])), - ('1d dist, 2d value as list', (0., 1.), [1, 2]), - ('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1), - ('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1), - ('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])), - ('1d dist, 1d value, edge case', (0, 1), 200), - ) - def test_log_cdf(self, distr_params, value): - distr_params = (np.asarray(distr_params[0], dtype=np.float32), - np.asarray(distr_params[1], dtype=np.float32)) - value = np.asarray(value, dtype=np.float32) - super()._test_attribute( - attribute_string='log_cdf', - dist_args=distr_params, - call_args=(value,), - assertion_fn=self.assertion_fn) + for method in ['log_prob', 'prob', 'cdf', 'log_cdf', 'survival_function', + 'log_survival_function']: + with self.subTest(method): + super()._test_attribute( + attribute_string=method, + dist_args=distr_params, + dist_kwargs={}, + call_args=(value,), + assertion_fn=self.assertion_fn) @chex.all_variants(with_pmap=False) @parameterized.named_parameters( diff --git a/distrax/_src/distributions/uniform_test.py b/distrax/_src/distributions/uniform_test.py index 0fdac471..91ff1ffc 100644 --- a/distrax/_src/distributions/uniform_test.py +++ b/distrax/_src/distributions/uniform_test.py @@ -89,6 +89,8 @@ def test_sample_and_log_prob(self, distr_params, sample_shape): ('log_prob', 'log_prob'), ('prob', 'prob'), ('cdf', 'cdf'), + ('survival_function', 'survival_function'), + ('log_survival_function', 'log_survival_function') ) def test_method_with_inputs(self, function_string): inputs = 10. * np.random.normal(size=(100,))