# TFP bijectors

__Objective:__ explore bijectors in TensorFlow Probability.

**Syntax hint:** the behaviour of a `bijector` object when **called** upon another object depends on the type of the latter, with three cases:
- When called on a **tensor of samples**, the resulting object is a tensor of transformed samples (equivalent to applying the bijector's `forward` method).
- When called on a **distribution**, the resulting object is a `TransformedDistribution` object, i.e. the distribution obtained by pushing the base (source) distribution through the bijector (equivalent to constructing a `TransformedDistribution` explicitly from the base distribution and the bijector).
- When called on another **bijector**, the resulting object is a bijector equivalent to the chain of the two bijectors (via a `Chain` object), with the inner-most being applied first.

Source: [here](https://github.com/tensorchiefs/dl_book/blob/master/chapter_06/nb_ch06_03.ipynb)

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

tfd = tfp.distributions

sns.set_theme()

## Using bijectors to transform samples

Generate uniformly distributed samples in the $[0, 10]$ interval.

In [None]:
uniform_distr = tfd.Uniform(low=0., high=10.)

samples = uniform_distr.sample(10000)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=samples.numpy(),
    stat='density'
)

plt.title('Uniform samples', fontsize=14)
plt.xlabel('x')

Instantiate a `Square` bijector and apply it to the samples. The original samples were obtained from a uniform distribution,
$$
p_x(x) = \mathcal{U}\left( x | 0, 10\right) =
\left\lbrace\begin{array}{l}
\frac{1}{10}\quad\text{if}\quad x\in[0, 10] \\
0\quad\text{otherwise}
\end{array}\right.
$$
and are mapped to a new space $z$ such that
$$
z = x^2.
$$

The probability density on $z$ is given by the transformation rule
$$
p_z(z) = p_x(x)\, \left| \frac{\mathrm{d}x}{\mathrm{d}z} \right| = p_x(x)\, \frac{1}{2\sqrt{z}} = \frac{1}{20 \sqrt{z}},
$$
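where we assumed that the inverse transformation is $x = \sqrt{z}$: since $x \in [0, 10]$ is non-negative, the positive root is the correct branch, and restricting to it is what makes the square invertible. The resulting density is supported on $z \in [0, 100]$.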
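
The density above can be sanity-checked without TFP. The sketch below is only an illustration in plain Python: it replaces random uniform samples with an evenly spaced grid on $[0, 10]$ (an assumption made to keep the check deterministic) and compares the fraction of squared values below a threshold $t$ with the integral of $1/(20\sqrt{z})$ from $0$ to $t$, which is $\sqrt{t}/10$. The helper names `empirical_cdf` and `analytical_cdf` are hypothetical.

```python
import math

# Deterministic stand-in for uniform samples on [0, 10]:
# midpoints of an evenly spaced grid.
n = 100_000
xs = [10.0 * (i + 0.5) / n for i in range(n)]

# Forward transformation z = x^2.
zs = [x * x for x in xs]


def empirical_cdf(values, t):
    """Fraction of the values that are <= t."""
    return sum(v <= t for v in values) / len(values)


def analytical_cdf(t):
    """Integral of 1 / (20 sqrt(z)) from 0 to t, valid for 0 <= t <= 100."""
    return math.sqrt(t) / 10.0


for t in (1.0, 25.0, 64.0):
    print(t, empirical_cdf(zs, t), analytical_cdf(t))
```

If the two columns agree up to the grid resolution, the closed-form density is consistent with squaring uniformly spread points.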

In [None]:
square_bij = tfp.bijectors.Square()

Transformations are applied with the `forward` and the `inverse` methods, which correspond to the direct and the inverse transformations respectively.

In [None]:
transformed_samples = square_bij.forward(samples)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=transformed_samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Transformed samples'
)

x_plot = tf.linspace(0.5, 100., 1000)
y_plot = 1. / (20. * tf.sqrt(x_plot))

sns.lineplot(
    x=x_plot,
    y=y_plot,
    color=sns.color_palette()[1],
    label='Analytical result'
)

plt.title('Transformed samples', fontsize=14)
plt.xlabel('z')
plt.legend()

The transformed samples can be brought back to the original space via the **inverse** transformation.

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=square_bij.inverse(transformed_samples).numpy(),
    stat='density',
    color=sns.color_palette()[0]
)

plt.title('Transformed samples mapped back to the original space', fontsize=14)
plt.xlabel('x')

## Using bijectors to transform distributions

Distributions can be transformed via bijectors as well, using the `TransformedDistribution` object, which accepts a "source" distribution and a bijector as its inputs, outputting another distribution corresponding to the mapping of the source one through the bijector.

In [None]:
# Sampling this distribution is equivalent to sampling
# the source one and then applying the transformation
# to the samples.
square_distr = tfd.TransformedDistribution(
    distribution=tfd.Uniform(low=0., high=10.),
    bijector=square_bij
)

In [None]:
transformed_distr_samples = square_distr.sample(10000)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=transformed_distr_samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Samples from the transformed distribution',
    alpha=0.5
)

sns.histplot(
    x=transformed_samples.numpy(),
    stat='density',
    color=sns.color_palette()[1],
    label='Transformed samples',
    alpha=0.5
)

plt.title('Samples from the transformed distribution', fontsize=14)
plt.xlabel('z')
plt.legend()

`TransformedDistribution` objects can also be created directly by calling the bijector on the source distribution.

In [None]:
square_distr_2 = square_bij(tfd.Uniform(low=0., high=10.))

square_distr_2

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=square_distr_2.sample(10000).numpy(),
    stat='density',
    color=sns.color_palette()[1],
    label='Transformed samples',
    alpha=0.5
)

plt.title('Samples from the transformed distribution', fontsize=14)
plt.xlabel('z')
plt.legend()

## Compositions of bijectors

Bijectors can be composed via the `Chain` object in order to obtain a composed transformation.

__Note:__ bijectors are applied in reverse order, **from the last in the list to the first one**.
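
The ordering rule in the note can be illustrated with ordinary functions. This is a plain-Python sketch, not the `Chain` implementation: the helper `chain_forward` is hypothetical and simply applies a list of functions from the last element to the first.

```python
import math


def chain_forward(transforms, x):
    """Apply the functions in `transforms` from the last one to the first."""
    for transform in reversed(transforms):
        x = transform(x)
    return x


def square(x):
    return x * x


# [square, log]: log is applied first, then square, i.e. (log x)^2.
log_then_square = chain_forward([square, math.log], math.e)

# [log, square]: square is applied first, then log, i.e. log(x^2).
square_then_log = chain_forward([math.log, square], math.e)

print(log_then_square, square_then_log)
```

The two results differ, which is why the order of the list matters when building a composed bijector.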

In the example below, `Log` is applied first and `Square` second, so the transformation is
$$
z = \log^2(x)\,,
$$
so if we start from a uniform distribution on $[1, 10]$ (we stay away from $0$ as we need to apply a log) we have:
$$
p_z(z) = \frac{e^{\sqrt{z}}}{18\sqrt{z}}\,.
$$
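
As a hedged check of this density, the plain-Python sketch below (no TFP) uses an even grid on $[1, 10]$ as a deterministic stand-in for uniform samples, maps it through $z = \log^2(x)$, and compares empirical fractions with the integral of $p_z$ from $0$ to $t$, which works out to $(e^{\sqrt{t}} - 1)/9$. The helper names are hypothetical.

```python
import math

# Deterministic stand-in for uniform samples on [1, 10].
n = 100_000
xs = [1.0 + 9.0 * (i + 0.5) / n for i in range(n)]

# Composed transformation: log first, then square.
zs = [math.log(x) ** 2 for x in xs]


def empirical_cdf(values, t):
    """Fraction of the values that are <= t."""
    return sum(v <= t for v in values) / len(values)


def analytical_cdf(t):
    """Integral of exp(sqrt(z)) / (18 sqrt(z)) from 0 to t."""
    return (math.exp(math.sqrt(t)) - 1.0) / 9.0


# Upper end of the support: z = log(10)^2.
z_max = math.log(10.0) ** 2

print(analytical_cdf(z_max))  # should be very close to 1 (normalization)
for t in (0.5, 1.0, 4.0):
    print(t, empirical_cdf(zs, t), analytical_cdf(t))
```

The support of $z$ is $[0, \log^2(10)]$, so the cumulative value at `z_max` being close to one confirms that the density is normalized.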

In [None]:
composed_bij = tfp.bijectors.Chain(
    bijectors=[
        tfp.bijectors.Square(),
        tfp.bijectors.Log()
    ]
)

transf_comp_distr = tfd.TransformedDistribution(
    distribution=tfd.Uniform(low=1., high=10.),
    bijector=composed_bij
)

In [None]:
comp_samples = transf_comp_distr.sample(10000)

In [None]:
comp_samples.numpy().min(), comp_samples.numpy().max()

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=comp_samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Samples'
)

x_plot = tf.linspace(.01, comp_samples.numpy().max(), 1000)
y_plot = tf.exp(tf.sqrt(x_plot)) / (18. * tf.sqrt(x_plot))

sns.lineplot(
    x=x_plot,
    y=y_plot,
    color=sns.color_palette()[1],
    label='Analytical result'
)

plt.title('Samples from the transformed distribution', fontsize=14)
plt.xlabel('z')
plt.legend()

Bijectors can also be composed by calling one on the other. Transformations are applied from the inner-most bijector to the outer-most one in the chain of calls. Let's implement an **affine transformation** (a scaling followed by a shift) this way.

In [None]:
scale_bij = tfp.bijectors.Scale(3.)
shift_bij = tfp.bijectors.Shift(-1.)

# This composition is equivalent to:
# tfp.bijectors.Chain([shift_bij, scale_bij])
affine_bij = shift_bij(scale_bij)

In [None]:
affine_transf_samples = affine_bij.forward(samples)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=samples.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Samples',
    alpha=0.5
)

sns.histplot(
    x=affine_transf_samples.numpy(),
    stat='density',
    color=sns.color_palette()[1],
    label='Affine-transformed samples',
    alpha=0.5
)

plt.title('Samples and their affine transformation', fontsize=14)
plt.xlabel('x')
plt.legend()

## Computing (log) probabilities

Let's go back to our uniform distribution and the square bijector. We can compute the log probability of the transformed samples in two ways:
- Take the transformed distribution and use its `log_prob` method on the transformed samples.
- Take the original samples and compute the difference between the `log_prob` yielded by the original uniform distribution they were drawn from and the log of the absolute value of the determinant of the Jacobian of the square transformation (which is obtained via the bijector's `forward_log_det_jacobian` method).

The second method works because of the change of variable formula for probability density: if $x, z \in \mathbb{R}^d$ we have
$$
p_x(x) = \left\vert \det(J) \right\vert\, p_z(z),
$$

where $J = \left[ \frac{\partial z_i}{\partial x_j} \right]$ is the Jacobian matrix of the transformation $x \to z(x)$ and in the formula $z$ takes the value $z=z(x)$. Inverting this we get
$$
p_z(z) = \left\vert \det(J) \right\vert^{-1}\, p_x(x),
$$

where now $x$ is evaluated at the value corresponding to $z$ via the inverse transformation ($x=x(z)$).

**Note:** had we started directly with the inverse transformation ($z \to x(z)$) we would have got the inverse Jacobian, but since the determinant of the inverse matrix is the inverse of the determinant of the matrix itself, we would have obtained exactly the same formula.
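
As a quick check of this remark on the one-dimensional square example used earlier (assuming $x > 0$, so that $x = \sqrt{z}$):
$$
\frac{\mathrm{d}z}{\mathrm{d}x} = 2x, \qquad
\frac{\mathrm{d}x}{\mathrm{d}z} = \frac{1}{2\sqrt{z}} = \frac{1}{2x} = \left( \frac{\mathrm{d}z}{\mathrm{d}x} \right)^{-1},
$$
so both routes give $p_z(z) = p_x(x) / (2x) = \frac{1}{20\sqrt{z}}$.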

Taking the log on both sides we get the formula for the second method mentioned above:
$$
\log(p_z(z)) = \log(p_x(x)) - \log\left(\left\vert \det(J) \right\vert \right)\,.
$$
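
The formula can be verified by hand at a single point. The sketch below is plain Python rather than TFP code, and its function names are hypothetical: it evaluates the right-hand side for the square transformation on the uniform distribution over $[0, 10]$ and compares it with the log of the closed-form density $1/(20\sqrt{z})$.

```python
import math


def log_prob_x(x):
    """Log density of the uniform distribution on [0, 10]."""
    return math.log(1.0 / 10.0) if 0.0 <= x <= 10.0 else -math.inf


def forward_log_det_jacobian(x):
    """log |dz/dx| for z = x^2, i.e. log(2 |x|)."""
    return math.log(2.0 * abs(x))


def log_prob_z_closed_form(z):
    """Log of the closed-form density 1 / (20 sqrt(z)) on (0, 100]."""
    return -math.log(20.0 * math.sqrt(z))


x = 3.0
z = x ** 2

# Second method: base log-density minus the forward log-det-Jacobian.
via_jacobian = log_prob_x(x) - forward_log_det_jacobian(x)

# Reference: log of the analytical density evaluated at z.
direct = log_prob_z_closed_form(z)

print(via_jacobian, direct)
```

Both values equal $-\log(60)$ at this point, up to floating-point rounding.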

In [None]:
# First method.
square_distr.log_prob(transformed_samples)

In [None]:
# Second method.
# Note: the `event_ndims` argument indicates how many of the right-most
#       dimensions of the tensor to interpret as event shape.
uniform_distr.log_prob(samples) - square_bij.forward_log_det_jacobian(samples, event_ndims=0)

In [None]:
# Check that the result is the same.
np.isclose(
    square_distr.log_prob(transformed_samples).numpy(),
    uniform_distr.log_prob(samples) - square_bij.forward_log_det_jacobian(samples, event_ndims=0)
).all()

## Bijectors and broadcasting

If we pass a list of parameter values to a bijector and then apply it to a tensor of samples, the shape of the tensor is broadcast against the shape of the parameters (if possible, otherwise an error is raised).

In [None]:
gaussian_samples = tfd.Normal(loc=4., scale=.5).sample((10000, 1))

In [None]:
softfloor_bij = tfp.bijectors.Softfloor(temperature=[0.4, 1.2])

# The shape of the samples tensor, (n_samples, 1), is broadcast against
# the shape of the bijector's parameter, (2,).
# Resulting shape: (n_samples, 2), one column per temperature value.
softfloor_transf_samples = softfloor_bij.forward(gaussian_samples)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=gaussian_samples[:, 0].numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Gaussian samples',
    alpha=0.5,
)

for i in range(softfloor_transf_samples.shape[1]):
    sns.histplot(
        x=softfloor_transf_samples[:, i].numpy(),
        stat='density',
        color=sns.color_palette()[i+1],
        label=f'Transformed samples (softfloor, parameter {i})',
        alpha=0.5
    )

plt.title('Samples and their transformation', fontsize=14)
plt.xlabel('x')
plt.legend()

## Linear operator bijectors

Linear operator bijectors implement linear operations on tensors or batches of distributions. They perform various linear algebra operations between the objects (tensors, batches of distributions) and some other given tensors of parameters.

In [None]:
# Batch shape: (4,).
base_gaussian = tfd.Normal(loc=[2.] * 4, scale=[.5] * 4)

base_gaussian

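Create a lower triangular matrix by wrapping a tensor of shape (4, 4) in a `LinearOperatorLowerTriangular` operator, which ignores the entries above the diagonal. The result is a linear operator object representing a (4, 4) matrix (its `to_dense` method returns the corresponding tensor) that will be used as the parameter for a bijector.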

In [None]:
lower_tr_mat = tf.linalg.LinearOperatorLowerTriangular(2. * tf.ones(shape=(4, 4)))

lower_tr_mat.to_dense()

Define a `ScaleMatvecLinearOperator` bijector performing matrix multiplication of the base distribution (batch) with the specified tensor.
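
To make the operation concrete, here is a plain-Python sketch of the arithmetic such a bijector performs, under the assumption that its forward pass is the matrix-vector product $y = Mx$ and its inverse solves the triangular system. It does not call TFP, and the helper names `matvec` and `inverse_matvec` are hypothetical.

```python
import math

# Lower-triangular part of a 4x4 matrix filled with 2s,
# mirroring the matrix built above.
size = 4
lower = [[2.0 if j <= i else 0.0 for j in range(size)] for i in range(size)]


def matvec(matrix, vector):
    """Forward transformation: y = M x."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def inverse_matvec(matrix, y):
    """Inverse transformation via forward substitution (M lower triangular)."""
    x = []
    for i, row in enumerate(matrix):
        partial = sum(row[j] * x[j] for j in range(i))
        x.append((y[i] - partial) / row[i])
    return x


x = [1.0, 2.0, 3.0, 4.0]
y = matvec(lower, x)
x_back = inverse_matvec(lower, y)

# For a triangular matrix, log |det M| is the sum of log |diagonal entries|.
log_abs_det = sum(math.log(abs(lower[i][i])) for i in range(size))

print(y, x_back, log_abs_det)
```

Because the matrix is triangular, both the inverse and the log-determinant are cheap to compute, which is one reason triangular scale matrices are a common choice.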

In [None]:
scale_lower_triangular_bij = tfp.bijectors.ScaleMatvecLinearOperator(lower_tr_mat)

In [None]:
scale_lower_triangular_bij(base_gaussian).sample()

## Bijector subclassing

If we need a bijector for which there's no TFP implementation, we can define a new one by subclassing the `Bijector` base class. As an example, we define the new `AffineBijector` class, implementing an affine coordinate transformation (a rescaling followed by a shift).

In [None]:
class AffineBijector(tfp.bijectors.Bijector):
    """
    Implementation of an affine bijector, i.e. an affine transformation
    of the form
        x -> y(x) = shift + scale * x
    """
    def __init__(self, scale, shift, validate_args=False, name='affine'):
        """
        Constructor method. The `validate_args` and `name` arguments
        follow the TFP convention and are forwarded to the parent constructor.
        """
        # The `forward_min_event_ndims` argument of the parent class'
        # constructor indicates the minimum event dimension of the
        # bijector that is being implemented, in the forward transformation
        # (it can be different between the forward and the inverse one).
        # The `is_constant_jacobian` argument is used to specify that the
        # Jacobian is independent from the point at which it's evaluated.
        # This allows for caching of the value of the Jacobian, bringing
        # performance gains.
        super().__init__(
            validate_args=validate_args,
            forward_min_event_ndims=0,
            is_constant_jacobian=True,
            name=name
        )

        # Bijector's parameters.
        self.scale = scale
        self.shift = shift

    def _forward(self, x):
        """
        Implements the forward (direct) transformation.
        """
        # Rescale first, then shift.
        return self.shift + self.scale * x

    def _inverse(self, y):
        """
        Implements the inverse transformation.
        """
        # Undo the shift first, then the rescaling.
        return (y - self.shift) / self.scale

    def _inverse_log_det_jacobian(self, y):
        """
        Implements the computation of the logarithm of the
        absolute determinant of the Jacobian of the inverse transformation
        (to be applied to transformed samples, or to samples
        in the transformed space).
        """
        return - tf.math.log(tf.abs(self.scale))

    def _forward_log_det_jacobian(self, x):
        """
        Implements the computation of the logarithm of the
        absolute determinant of the Jacobian of the forward (direct)
        transformation (to be applied to non-transformed samples,
        or to samples in the non-transformed space).
        """
        return - self._inverse_log_det_jacobian(self._forward(x))

In [None]:
affine_bij_class = AffineBijector(shift=-1., scale=2.)

affine_bij_class.forward(tf.constant([1., 2., 3., 4.]))

In [None]:
samples_plot = tfd.Normal(loc=2., scale=1.).sample(10000)
affine_class_transf_samples = affine_bij_class(samples_plot)

fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=samples_plot.numpy(),
    stat='density',
    color=sns.color_palette()[0],
    label='Gaussian samples',
    alpha=0.5,
)

sns.histplot(
    x=affine_class_transf_samples.numpy(),
    stat='density',
    color=sns.color_palette()[1],
    label='Transformed samples (affine)',
    alpha=0.5
)

plt.title('Samples and their transformation', fontsize=14)
plt.xlabel('Value')
plt.legend()