diff --git a/distrax/_src/distributions/normal.py b/distrax/_src/distributions/normal.py index 608d7a9..21e316f 100644 --- a/distrax/_src/distributions/normal.py +++ b/distrax/_src/distributions/normal.py @@ -104,6 +104,14 @@ def log_cdf(self, value: Array) -> Array: """See `Distribution.log_cdf`.""" return jax.scipy.special.log_ndtr(self._standardize(value)) + def survival_function(self, value: Array) -> Array: + """See `Distribution.survival_function`.""" + return jax.scipy.special.ndtr(-self._standardize(value)) + + def log_survival_function(self, value: Array) -> Array: + """See `Distribution.log_survival_function`.""" + return jax.scipy.special.log_ndtr(-self._standardize(value)) + def _standardize(self, value: Array) -> Array: return (value - self._loc) / self._scale diff --git a/distrax/_src/distributions/normal_test.py b/distrax/_src/distributions/normal_test.py index 5788416..7953ba9 100644 --- a/distrax/_src/distributions/normal_test.py +++ b/distrax/_src/distributions/normal_test.py @@ -103,7 +103,8 @@ def test_method_with_input(self, distr_params, value): distr_params = (np.asarray(distr_params[0], dtype=np.float32), np.asarray(distr_params[1], dtype=np.float32)) value = np.asarray(value, dtype=np.float32) - for method in ['log_prob', 'prob', 'cdf', 'log_cdf']: + for method in ['log_prob', 'prob', 'cdf', 'log_cdf', 'survival_function', + 'log_survival_function']: with self.subTest(method): super()._test_attribute( attribute_string=method,