From 97ef0de620b8c40a02f387e2df3cdf3bd0ebbe3d Mon Sep 17 00:00:00 2001 From: DistraxDev Date: Wed, 20 Apr 2022 12:17:49 -0700 Subject: [PATCH] Implement survival and log-survival function for the `Logistic` distribution. Also: * Simplify the tests for the `Logistic` PiperOrigin-RevId: 443160365 --- distrax/_src/distributions/logistic.py | 8 +++ distrax/_src/distributions/logistic_test.py | 72 +++------------------ 2 files changed, 18 insertions(+), 62 deletions(-) diff --git a/distrax/_src/distributions/logistic.py b/distrax/_src/distributions/logistic.py index c3793d3b..80c04589 100644 --- a/distrax/_src/distributions/logistic.py +++ b/distrax/_src/distributions/logistic.py @@ -101,6 +101,14 @@ def log_cdf(self, value: Array) -> Array: """See `Distribution.log_cdf`.""" return -jax.nn.softplus(-self._standardize(value)) + def survival_function(self, value: Array) -> Array: + """See `Distribution.survival_function`.""" + return jax.nn.sigmoid(-self._standardize(value)) + + def log_survival_function(self, value: Array) -> Array: + """See `Distribution.log_survival_function`.""" + return -jax.nn.softplus(self._standardize(value)) + def mean(self) -> Array: """Calculates the mean.""" return self.loc diff --git a/distrax/_src/distributions/logistic_test.py b/distrax/_src/distributions/logistic_test.py index e1dd2c0c..3b977dde 100644 --- a/distrax/_src/distributions/logistic_test.py +++ b/distrax/_src/distributions/logistic_test.py @@ -109,71 +109,19 @@ def test_sample_and_log_prob(self, distr_params, sample_shape): ('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])), ('1d dist, 1d value, edge case', (0, 1), 200), ) - def test_log_prob(self, distr_params, value): + def test_method_with_input(self, distr_params, value): distr_params = (np.asarray(distr_params[0], dtype=np.float32), np.asarray(distr_params[1], dtype=np.float32)) value = np.asarray(value, dtype=np.float32) - super()._test_log_prob(distr_params, dict(), value) - - @chex.all_variants - @parameterized.named_parameters( - ('1d dist, 1d value', (0, 1), 1), - ('1d dist, 2d value', (0., 1.), np.array([1, 2])), - ('1d dist, 2d value as list', (0., 1.), [1, 2]), - ('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1), - ('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1), - ('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])), - ('1d dist, 1d value, edge case', (0, 1), 200), - ) - def test_prob(self, distr_params, value): - distr_params = (np.asarray(distr_params[0], dtype=np.float32), - np.asarray(distr_params[1], dtype=np.float32)) - value = np.asarray(value, dtype=np.float32) - super()._test_attribute( - attribute_string='prob', - dist_args=distr_params, - call_args=(value,), - assertion_fn=self.assertion_fn) - - @chex.all_variants - @parameterized.named_parameters( - ('1d dist, 1d value', (0, 1), 1), - ('1d dist, 2d value', (0., 1.), np.array([1, 2])), - ('1d dist, 2d value as list', (0., 1.), [1, 2]), - ('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1), - ('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1), - ('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])), - ('1d dist, 1d value, edge case', (0, 1), 200), - ) - def test_cdf(self, distr_params, value): - distr_params = (np.asarray(distr_params[0], dtype=np.float32), - np.asarray(distr_params[1], dtype=np.float32)) - value = np.asarray(value, dtype=np.float32) - super()._test_attribute( - attribute_string='cdf', - dist_args=distr_params, - call_args=(value,), - assertion_fn=self.assertion_fn) - - @chex.all_variants - @parameterized.named_parameters( - ('1d dist, 1d value', (0, 1), 1), - ('1d dist, 2d value', (0., 1.), np.array([1, 2])), - ('1d dist, 2d value as list', (0., 1.), [1, 2]), - ('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1), - ('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1), - ('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])), - ('1d dist, 1d value, edge case', (0, 1), 200), - ) - def test_log_cdf(self, distr_params, value): - distr_params = (np.asarray(distr_params[0], dtype=np.float32), - np.asarray(distr_params[1], dtype=np.float32)) - value = np.asarray(value, dtype=np.float32) - super()._test_attribute( - attribute_string='log_cdf', - dist_args=distr_params, - call_args=(value,), - assertion_fn=self.assertion_fn) + for method in ['log_prob', 'prob', 'cdf', 'log_cdf', 'survival_function', + 'log_survival_function']: + with self.subTest(method): + super()._test_attribute( + attribute_string=method, + dist_args=distr_params, + dist_kwargs={}, + call_args=(value,), + assertion_fn=self.assertion_fn) @chex.all_variants(with_pmap=False) @parameterized.named_parameters(