# TFP bijectors

__Objective:__ explore bijectors in TensorFlow Probability.

Source: [here](https://github.com/tensorchiefs/dl_book/blob/master/chapter_06/nb_ch06_03.ipynb)

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

tfd = tfp.distributions

sns.set_theme()

## Using bijectors to transform samples

Generate uniformly distributed samples in the $[0, 10]$ interval.

In [None]:
samples = tfd.Uniform(low=0., high=10.).sample(10000)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=samples.numpy(),
    stat='density'
)

plt.title('Uniform samples', fontsize=14)
plt.xlabel('x')

Instantiate a `Square` bijector and apply it to the samples. The original samples were obtained from a uniform distribution,
$$
p_x(x) = \mathcal{U}\left( x | 0, 10\right) =
\left\lbrace\begin{array}{l}
\frac{1}{10}\quad\text{if}\quad x\in[0, 10] \\
0\quad\text{otherwise}
\end{array}\right.
$$
and are mapped to a new space $z$ such that
$$
z = x^2.
$$

The probability density on $z$ is given by the transformation rule
$$
p_z(z) = p_x(x)\, \left| \frac{\mathrm{d}x}{\mathrm{d}z} \right| = p_x(x)\, \frac{1}{2\sqrt{z}} = \frac{1}{20 \sqrt{z}},
$$
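where we used the inverse transformation $x = \sqrt{z}$. The square is invertible here only because $x \geq 0$ on the support, so the negative root can be discarded. The resulting density is defined for $z \in (0, 100]$.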
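The derivation can be checked numerically before touching TFP. Since an antiderivative of $1/(20\sqrt{z})$ is $\sqrt{z}/10$, the probability that $z$ falls in $[a, b]$ is $(\sqrt{b} - \sqrt{a})/10$. The sketch below uses only NumPy (an assumption; the notebook itself relies on TFP) and compares that expression with the fraction of squared uniform samples landing in a few arbitrarily chosen intervals; `prob_interval` is a hypothetical helper introduced for illustration.

```python
import numpy as np

# Sample x ~ U(0, 10) and map it through z = x**2, as in the text.
rng = np.random.default_rng(seed=0)
x = rng.uniform(0.0, 10.0, size=200_000)
z = x ** 2


def prob_interval(a, b):
    """Integral of p_z(z) = 1 / (20 sqrt(z)) over [a, b], i.e. (sqrt(b) - sqrt(a)) / 10."""
    return (np.sqrt(b) - np.sqrt(a)) / 10.0


# Compare empirical and analytical probabilities on a few intervals of z.
for a, b in [(0.0, 1.0), (1.0, 25.0), (25.0, 100.0)]:
    empirical = np.mean((z >= a) & (z < b))
    print(f"[{a:5.1f}, {b:5.1f}): empirical={empirical:.4f}, analytical={prob_interval(a, b):.4f}")
```

The three intervals cover the whole support $(0, 100]$, so their analytical probabilities (0.1, 0.4 and 0.5) add up to 1.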

In [None]:
square_bij = tfp.bijectors.Square()

Transformations are applied with the `forward` and the `inverse` methods, which correspond to the direct and the inverse transformations respectively.

In [None]:
transformed_samples = square_bij.forward(samples)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=transformed_samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Transformed samples'
)

x_plot = tf.linspace(0.5, 100., 1000)
y_plot = 1. / (20. * tf.sqrt(x_plot))

sns.lineplot(
    x=x_plot,
    y=y_plot,
    color=sns.color_palette()[1],
    label='Analytical result'
)

plt.title('Transformed samples', fontsize=14)
plt.xlabel('z')
plt.legend()

The transformed samples can be brought back to the original space via the inverse transformation.

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=square_bij.inverse(transformed_samples).numpy(),
    stat='density',
    color=sns.color_palette()[0]
)

plt.title('Transformed samples mapped back to the original space', fontsize=14)
plt.xlabel('x')

## Using bijectors to transform distributions

Bijectors can also transform whole distributions. The `TransformedDistribution` object takes a "source" distribution and a bijector as inputs and returns the distribution of the source samples mapped through the bijector. Besides sampling, it exposes the transformed density (e.g. via `prob` and `log_prob`).

In [None]:
# Sampling this distribution is equivalent to sampling
# the source one and then applying the transformation
# to the samples.
square_distr = tfd.TransformedDistribution(
    distribution=tfd.Uniform(low=0., high=10.),
    bijector=square_bij
)

In [None]:
transformed_distr_samples = square_distr.sample(10000)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=transformed_distr_samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Samples from the transformed distribution',
    alpha=0.5
)

sns.histplot(
    x=transformed_samples.numpy(),
    stat='density',
    color=sns.color_palette()[1],
    label='Transformed samples',
    alpha=0.5
)

plt.title('Samples from the transformed distribution', fontsize=14)
plt.xlabel('z')
plt.legend()

## Compositions of bijectors

Bijectors can be composed via the `Chain` object in order to obtain a composed transformation.

__Note:__ bijectors are applied in reverse order, **from the last in the list to the first one**.

In this case, the transformation is
$$
z = \log^2(x)\,,
$$
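so, starting from a uniform distribution on $[1, 10]$ (we start at $1$ rather than $0$ so that $\log(x)$ is finite and non-negative, which also keeps the square invertible with inverse $x = e^{\sqrt{z}}$), we have, for $z \in (0, \log^2 10]$ with $\log^2 10 \approx 5.30$:
$$
p_z(z) = \frac{e^{\sqrt{z}}}{18\sqrt{z}}\,.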
$$
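Integrating this density gives the CDF in closed form,
$$
P(Z \leq t) = P\left(X \leq e^{\sqrt{t}}\right) = \frac{e^{\sqrt{t}} - 1}{9}\,,
$$
whose derivative is $e^{\sqrt{t}}/(18\sqrt{t})$ and which equals $1$ at $t = \log^2 10$, confirming the normalization. The NumPy-only sketch below (an assumption, independent of TFP; `cdf_z` is a hypothetical helper) compares this CDF with samples of $\log^2(x)$.

```python
import numpy as np

# Sample x ~ U(1, 10) and map it through z = log(x)**2.
rng = np.random.default_rng(seed=0)
x = rng.uniform(1.0, 10.0, size=200_000)
z = np.log(x) ** 2

# Upper end of the support of z (about 5.30); the lower end is log(1)**2 = 0.
z_max = np.log(10.0) ** 2


def cdf_z(t):
    """P(Z <= t) = (exp(sqrt(t)) - 1) / 9 for t in [0, log(10)**2]."""
    return (np.exp(np.sqrt(t)) - 1.0) / 9.0


for t in [0.5, 2.0, 4.0, z_max]:
    print(f"t={t:.3f}: empirical={np.mean(z <= t):.4f}, analytical={cdf_z(t):.4f}")
```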

In [None]:
composed_bij = tfp.bijectors.Chain(
    bijectors=[
        tfp.bijectors.Square(),
        tfp.bijectors.Log()
    ]
)

transf_comp_distr = tfd.TransformedDistribution(
    distribution=tfd.Uniform(low=1., high=10.),
    bijector=composed_bij
)

In [None]:
comp_samples = transf_comp_distr.sample(10000)

In [None]:
comp_samples.numpy().min(), comp_samples.numpy().max()

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=comp_samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Samples'
)

x_plot = tf.linspace(.01, comp_samples.numpy().max(), 1000)
y_plot = tf.exp(tf.sqrt(x_plot)) / (18. * tf.sqrt(x_plot))

sns.lineplot(
    x=x_plot,
    y=y_plot,
    color=sns.color_palette()[1],
    label='Analytical result'
)

plt.title('Samples from the transformed distribution', fontsize=14)
plt.xlabel('z')
plt.legend()