From b327819cfca2088323392d04d714208574727856 Mon Sep 17 00:00:00 2001 From: DistraxDev Date: Wed, 20 Apr 2022 12:44:13 -0700 Subject: [PATCH] Implement survival and log-survival function for the `Normal` distribution. PiperOrigin-RevId: 443166567 --- distrax/_src/distributions/distribution.py | 33 +++++++++++++++++++ .../_src/distributions/distribution_test.py | 6 ++++ distrax/_src/distributions/normal.py | 8 +++++ distrax/_src/distributions/normal_test.py | 3 +- distrax/_src/distributions/uniform_test.py | 2 ++ 5 files changed, 51 insertions(+), 1 deletion(-) diff --git a/distrax/_src/distributions/distribution.py b/distrax/_src/distributions/distribution.py index d81d80b3..2ec47085 100644 --- a/distrax/_src/distributions/distribution.py +++ b/distrax/_src/distributions/distribution.py @@ -188,6 +188,39 @@ def cdf(self, value: Array) -> Array: """ return jnp.exp(self.log_cdf(value)) + def survival_function(self, value: Array) -> Array: + """Evaluates the survival function at `value`. + + Args: + value: An event. + + Returns: + The survival function evaluated at `value`, i.e. P[X > value] + """ + if not self.event_shape: + # Defined for univariate distributions only. + return 1. - self.cdf(value) + else: + raise NotImplementedError('`survival_function` is not defined for ' + f'distribution {self.name}') + + def log_survival_function(self, value: Array) -> Array: + """Evaluates the log of the survival function at `value`. + + Args: + value: An event. + + Returns: + The log of the survival function evaluated at `value`, i.e. + log P[X > value] + """ + if not self.event_shape: + # Defined for univariate distributions only. + return jnp.log1p(-self.cdf(value)) + else: + raise NotImplementedError('`log_survival_function` is not defined for ' + f'distribution {self.name}') + def mean(self) -> Array: """Calculates the mean.""" raise NotImplementedError( diff --git a/distrax/_src/distributions/distribution_test.py b/distrax/_src/distributions/distribution_test.py index e850eb4b..b2d2c503 100644 --- a/distrax/_src/distributions/distribution_test.py +++ b/distrax/_src/distributions/distribution_test.py @@ -205,6 +205,12 @@ def test_to_batch_shape_index_raises(self, index): distribution.to_batch_shape_index( batch_shape=(2, 3, 4), index=index) + def test_multivariate_survival_function_raises(self): + mult_dist = DummyMultivariateDist(42) + with self.assertRaises(NotImplementedError): + mult_dist.survival_function(jnp.zeros(42)) + with self.assertRaises(NotImplementedError): + mult_dist.log_survival_function(jnp.zeros(42)) if __name__ == '__main__': absltest.main() diff --git a/distrax/_src/distributions/normal.py b/distrax/_src/distributions/normal.py index 608d7a97..21e316f9 100644 --- a/distrax/_src/distributions/normal.py +++ b/distrax/_src/distributions/normal.py @@ -104,6 +104,14 @@ def log_cdf(self, value: Array) -> Array: """See `Distribution.log_cdf`.""" return jax.scipy.special.log_ndtr(self._standardize(value)) + def survival_function(self, value: Array) -> Array: + """See `Distribution.survival_function`.""" + return jax.scipy.special.ndtr(-self._standardize(value)) + + def log_survival_function(self, value: Array) -> Array: + """See `Distribution.log_survival_function`.""" + return jax.scipy.special.log_ndtr(-self._standardize(value)) + def _standardize(self, value: Array) -> Array: return (value - self._loc) / self._scale diff --git a/distrax/_src/distributions/normal_test.py b/distrax/_src/distributions/normal_test.py index 57884160..7953ba9d 100644 --- a/distrax/_src/distributions/normal_test.py +++ b/distrax/_src/distributions/normal_test.py @@ -103,7 +103,8 @@ def test_method_with_input(self, distr_params, value): distr_params = (np.asarray(distr_params[0], dtype=np.float32), np.asarray(distr_params[1], dtype=np.float32)) value = np.asarray(value, dtype=np.float32) - for method in ['log_prob', 'prob', 'cdf', 'log_cdf']: + for method in ['log_prob', 'prob', 'cdf', 'log_cdf', 'survival_function', + 'log_survival_function']: with self.subTest(method): super()._test_attribute( attribute_string=method, diff --git a/distrax/_src/distributions/uniform_test.py b/distrax/_src/distributions/uniform_test.py index 0fdac471..91ff1ffc 100644 --- a/distrax/_src/distributions/uniform_test.py +++ b/distrax/_src/distributions/uniform_test.py @@ -89,6 +89,8 @@ def test_sample_and_log_prob(self, distr_params, sample_shape): ('log_prob', 'log_prob'), ('prob', 'prob'), ('cdf', 'cdf'), + ('survival_function', 'survival_function'), + ('log_survival_function', 'log_survival_function') ) def test_method_with_inputs(self, function_string): inputs = 10. * np.random.normal(size=(100,))