Skip to content

Commit 127f2d6

Browse files
ahalevThe Meridian Authors
authored andcommitted
Add lognormal_dist_from_mean_std helper function.
PiperOrigin-RevId: 803606606
1 parent 1e88d1b commit 127f2d6

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
* Add `lognormal_dist_from_mean_std` helper function.
27+
2628
## [1.2.0] - 2025-09-04
2729

2830
* Fix channel data misalignment in `Analyzer.hill_curves` when input channels

meridian/model/prior_distribution.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
'IndependentMultivariateDistribution',
3535
'PriorDistribution',
3636
'distributions_are_equal',
37+
'lognormal_dist_from_mean_std',
3738
]
3839

3940

@@ -1173,6 +1174,32 @@ def distributions_are_equal(
11731174
return True
11741175

11751176

1177+
def lognormal_dist_from_mean_std(
1178+
mean: float | Sequence[float], std: float | Sequence[float]
1179+
) -> backend.tfd.LogNormal:
1180+
"""Define a lognormal distribution from its mean and standard deviation.
1181+
1182+
This function parameterizes lognormal distributions by their mean and
1183+
standard deviation.
1184+
1185+
Args:
1186+
mean: A positive float or array-like object defining the distribution mean.
1187+
std: A non-negative float or array-like object defining the distribution
1188+
standard deviation.
1189+
1190+
Returns:
1191+
A `backend.tfd.LogNormal` object with the input mean and standard deviation.
1192+
"""
1193+
1194+
mean = np.asarray(mean)
1195+
std = np.asarray(std)
1196+
1197+
mu = np.log(mean) - 0.5 * np.log((std / mean) ** 2 + 1)
1198+
sigma = np.sqrt(np.log((std / mean) ** 2 + 1))
1199+
1200+
return backend.tfd.LogNormal(mu, sigma)
1201+
1202+
11761203
def _convert_to_deterministic_0_distribution(
11771204
distribution: backend.tfd.Distribution,
11781205
) -> backend.tfd.Distribution:

meridian/model/prior_distribution_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,5 +1898,44 @@ def test_independent_distributions_variance(
18981898
np.testing.assert_allclose(variance, expected_variance, rtol=1e-5)
18991899

19001900

1901+
class TestLognormalDistFromMeanStd(parameterized.TestCase):
1902+
1903+
@parameterized.product(
1904+
mean=(1.0, 2.0, 3.0),
1905+
std=(1.0, 2.0, 3.0),
1906+
input_type=(float, int, np.float32, backend.to_tensor)
1907+
)
1908+
def test_correct_mean_std_scalar(
1909+
self, mean, std, input_type
1910+
):
1911+
mean = input_type(mean)
1912+
std = input_type(std)
1913+
1914+
dist = prior_distribution.lognormal_dist_from_mean_std(mean, std)
1915+
1916+
np.testing.assert_allclose(dist.mean(), mean, rtol=1e-5)
1917+
np.testing.assert_allclose(dist.stddev(), std, rtol=1e-5)
1918+
1919+
@parameterized.product(
1920+
mean=((1.0,), (1.0, 2.0,)),
1921+
std=((2.0,), (2.0, 3.0,)),
1922+
input_type=(tuple, list, np.array, backend.to_tensor)
1923+
)
1924+
def test_correct_mean_std_array(
1925+
self, mean, std, input_type
1926+
):
1927+
mean = input_type(mean)
1928+
std = input_type(std)
1929+
1930+
dist = prior_distribution.lognormal_dist_from_mean_std(mean, std)
1931+
1932+
expected_len = max(len(mean), len(std))
1933+
expected_mean = np.broadcast_to(mean, expected_len)
1934+
expected_std = np.broadcast_to(std, expected_len)
1935+
1936+
np.testing.assert_allclose(dist.mean(), expected_mean, rtol=1e-5)
1937+
np.testing.assert_allclose(dist.stddev(), expected_std, rtol=1e-5)
1938+
1939+
19011940
if __name__ == '__main__':
19021941
absltest.main()

0 commit comments

Comments
 (0)