Skip to content

Commit

Permalink
Implement survival and log-survival function for the Logistic distr…
Browse files Browse the repository at this point in the history
…ibution.

Also:
* Simplify the tests for the `Logistic`
PiperOrigin-RevId: 443160365
  • Loading branch information
DistraxDev authored and DistraxDev committed Apr 22, 2022
1 parent 906ecb5 commit 97ef0de
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 62 deletions.
8 changes: 8 additions & 0 deletions distrax/_src/distributions/logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ def log_cdf(self, value: Array) -> Array:
"""See `Distribution.log_cdf`."""
return -jax.nn.softplus(-self._standardize(value))

def survival_function(self, value: Array) -> Array:
"""See `Distribution.survival_function`."""
return jax.nn.sigmoid(-self._standardize(value))

def log_survival_function(self, value: Array) -> Array:
"""See `Distribution.log_survival_function`."""
return -jax.nn.softplus(self._standardize(value))

def mean(self) -> Array:
"""Calculates the mean."""
return self.loc
Expand Down
72 changes: 10 additions & 62 deletions distrax/_src/distributions/logistic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,71 +109,19 @@ def test_sample_and_log_prob(self, distr_params, sample_shape):
('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (0, 1), 200),
)
def test_log_prob(self, distr_params, value):
def test_method_with_input(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_log_prob(distr_params, dict(), value)

@chex.all_variants
@parameterized.named_parameters(
('1d dist, 1d value', (0, 1), 1),
('1d dist, 2d value', (0., 1.), np.array([1, 2])),
('1d dist, 2d value as list', (0., 1.), [1, 2]),
('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1),
('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1),
('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (0, 1), 200),
)
def test_prob(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_attribute(
attribute_string='prob',
dist_args=distr_params,
call_args=(value,),
assertion_fn=self.assertion_fn)

@chex.all_variants
@parameterized.named_parameters(
('1d dist, 1d value', (0, 1), 1),
('1d dist, 2d value', (0., 1.), np.array([1, 2])),
('1d dist, 2d value as list', (0., 1.), [1, 2]),
('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1),
('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1),
('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (0, 1), 200),
)
def test_cdf(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_attribute(
attribute_string='cdf',
dist_args=distr_params,
call_args=(value,),
assertion_fn=self.assertion_fn)

@chex.all_variants
@parameterized.named_parameters(
('1d dist, 1d value', (0, 1), 1),
('1d dist, 2d value', (0., 1.), np.array([1, 2])),
('1d dist, 2d value as list', (0., 1.), [1, 2]),
('2d dist, 1d value', (np.zeros(2), np.ones(2)), 1),
('2d broadcasted dist, 1d value', (np.zeros(2), 1), 1),
('2d dist, 2d value', (np.zeros(2), np.ones(2)), np.array([1, 2])),
('1d dist, 1d value, edge case', (0, 1), 200),
)
def test_log_cdf(self, distr_params, value):
distr_params = (np.asarray(distr_params[0], dtype=np.float32),
np.asarray(distr_params[1], dtype=np.float32))
value = np.asarray(value, dtype=np.float32)
super()._test_attribute(
attribute_string='log_cdf',
dist_args=distr_params,
call_args=(value,),
assertion_fn=self.assertion_fn)
for method in ['log_prob', 'prob', 'cdf', 'log_cdf', 'survival_function',
'log_survival_function']:
with self.subTest(method):
super()._test_attribute(
attribute_string=method,
dist_args=distr_params,
dist_kwargs={},
call_args=(value,),
assertion_fn=self.assertion_fn)

@chex.all_variants(with_pmap=False)
@parameterized.named_parameters(
Expand Down

0 comments on commit 97ef0de

Please sign in to comment.