Skip to content

Commit

Permalink
Clarify behaviour of the Transformed distribution in the docstring.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 430918549
  • Loading branch information
franrruiz authored and DistraxDev committed Feb 25, 2022
1 parent 9139341 commit 0827af5
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions distrax/_src/distributions/transformed.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ class Transformed(dist_base.Distribution):
where `p(x)` is the probability density of `X` (the "base density") and
`J(f)(x)` is the Jacobian matrix of `f`, both evaluated at `x = f^{-1}(y)`.
Sampling from a Transformed distribution involves two steps: sampling from the
base distribution `x ~ p(x)` and then evaluating `y = f(x)`. The first step
is agnostic to the possible batch dimensions of the bijector `f(x)`. For
example:
```
dist = distrax.Normal(loc=0., scale=1.)
bij = distrax.ScalarAffine(shift=jnp.asarray([3., 3., 3.]))
transformed_dist = distrax.Transformed(distribution=dist, bijector=bij)
samples = transformed_dist.sample(seed=0, sample_shape=())
print(samples) # [2.7941577, 2.7941577, 2.7941577]
```
Note: the `batch_shape`, `event_shape`, and `dtype` properties of the
transformed distribution, as well as the `kl_divergence` method, are computed
on-demand via JAX tracing when requested. This assumes that the `forward`
Expand Down

0 comments on commit 0827af5

Please sign in to comment.