# Representing joint probabilities in TensorFlow Probability

**Objective:** in this notebook we explore various methods to write a joint probability with TensorFlow Probability.

Two of the main ingredients of Bayesian methods are the likelihood function and the prior distribution. In particular:
- Likelihood: expresses the probability of having observed the data given certain values of the parameters of the corresponding probability density function.
- Prior: expresses the probability density for the values of the parameters of the probability density function (and is thus defined over the space of the parameters).

Given some data $\mathcal{D}$, the likelihood is a function $p(\mathcal{D} | \alpha)$, where $\alpha$ is the set of parameters, while the prior is $p(\alpha)$. The fundamental equation of Bayesian modeling is Bayes' theorem,

$$
p(\alpha | \mathcal{D}) = \frac{p(\mathcal{D} | \alpha)\, p(\alpha)}{p(\mathcal{D})}\,.
$$

The numerator on the RHS is the product of the likelihood and the prior, which, by the chain rule of probabilities, also expresses the joint probability for the data and the parameters, $p(\mathcal{D} | \alpha)\, p(\alpha) = p(\mathcal{D}, \alpha)$.

The denominator $p(\mathcal{D})$ is a normalization constant that does not depend on $\alpha$. Numerical methods like MCMC do not require its explicit computation (which can usually be done only with numerical approximations anyway), so we can ignore it here.
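
As a short illustration of why the denominator can be dropped, consider the ratio of the posterior densities at two parameter values $\alpha$ and $\alpha'$ (the symbol $\alpha'$ is introduced here only for this example). Ratios of this kind are what Metropolis-type MCMC samplers typically evaluate, and $p(\mathcal{D})$ cancels out:

$$
\frac{p(\alpha' | \mathcal{D})}{p(\alpha | \mathcal{D})} = \frac{p(\mathcal{D} | \alpha')\, p(\alpha')}{p(\mathcal{D} | \alpha)\, p(\alpha)} = \frac{p(\mathcal{D}, \alpha')}{p(\mathcal{D}, \alpha)}\,.
$$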

Numerical methods do require the evaluation of the numerator though, so it's important to be able to write and evaluate likelihoods when using a probabilistic programming language like TensorFlow Probability (TFP).

## The model

Specifying a model amounts to specifying its joint probability distribution. As a working example, let's consider `N=100` random values $\lbrace x_1, \ldots, x_N\rbrace$ sampled from a Gaussian distribution with parameters $(\mu, \sigma)$, with a Gaussian prior on $\mu$ and a half-normal prior on $\sigma$:

$$
\begin{eqnarray}
\mu &\sim& \mathcal{N}(\mu_\mu, \sigma_\mu), \\
\sigma &\sim& \mathcal{HN}(0, \sigma_\sigma), \\
x_i &\sim& \mathcal{N}(\mu, \sigma)\quad \forall i = 1, \ldots, N\,.
\end{eqnarray}
$$

The priors themselves depend on the choice of some parameters that we need to specify (we could put a prior on those as well, but at some point some parameters will need to be fixed). In particular we can set

$$
\begin{eqnarray}
(\mu_\mu, \sigma_\mu) &=& (15.0, 2.0), \\
\sigma_\sigma &=& 1.0\,.
\end{eqnarray}
$$

We thus have

$$
\begin{eqnarray}
p(\mathcal{D} | \alpha) &=&  \prod_{i=1}^{N} \mathcal{N}(x_i | \mu, \sigma), \\
p(\alpha) &=& \mathcal{N}(\mu | \mu_\mu, \sigma_\mu)\, \mathcal{HN}(\sigma | 0, \sigma_\sigma),
\end{eqnarray}
$$

where now $\mathcal{D}$ denotes the whole dataset, $\alpha$ denotes the set of parameters and we assumed that the datapoints are independent and identically distributed and that the parameters are independent as well.
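
To make the factorization above concrete, here is a minimal sketch that evaluates the log joint probability by hand with NumPy, as the sum of the log prior and the log likelihood. The helper names (`normal_log_pdf`, `half_normal_log_pdf`, `log_joint`) and the three data values are hypothetical and only serve the illustration; TFP provides all of this out of the box.

```python
import math
import numpy as np


def normal_log_pdf(x, loc, scale):
    """Log density of a Gaussian with mean `loc` and standard deviation `scale`."""
    z = (np.asarray(x, dtype=float) - loc) / scale
    return -0.5 * z**2 - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def half_normal_log_pdf(x, scale):
    """Log density of a half-normal with the given scale (-inf for x < 0)."""
    x = np.asarray(x, dtype=float)
    log_pdf = 0.5 * math.log(2.0 / math.pi) - math.log(scale) - 0.5 * (x / scale)**2
    return np.where(x >= 0, log_pdf, -np.inf)


def log_joint(x, mu, sigma, mu_mu=15.0, sigma_mu=2.0, sigma_sigma=1.0):
    """log p(D, alpha) = log p(alpha) + log p(D | alpha), for scalar mu and sigma > 0."""
    log_prior = normal_log_pdf(mu, mu_mu, sigma_mu) + half_normal_log_pdf(sigma, sigma_sigma)
    # i.i.d. datapoints: the product over i becomes a sum of log densities.
    log_likelihood = np.sum(normal_log_pdf(x, mu, sigma))
    return float(log_prior + log_likelihood)


x_example = np.array([14.2, 15.9, 16.1])  # hypothetical datapoints
value = log_joint(x_example, mu=15.5, sigma=1.2)
print(value)
```

Working with logarithms turns the product over datapoints into a sum, which is numerically much better behaved when `N` is large.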

The sampling process generating a value $x$ is now to be understood as follows: first the priors are sampled and values for $(\mu, \sigma)$ are generated. Then those values are used as the parameters for the Gaussian distribution that is sampled to generate $x$. We now want to generate `N` such points with TFP. Then, we'll write down various ways to code up the joint probability distribution so that, given a set of points and values for the parameters, it returns a number (their joint probability). That's what numerical methods require.
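
The two-step sampling process described above can be sketched in plain NumPy before turning to TFP. This is only an illustration under the prior parameters chosen earlier; the variable names and the seed are arbitrary choices made here.

```python
import numpy as np

rng = np.random.default_rng(seed=0)  # seed chosen arbitrarily for reproducibility

n_points = 100                # plays the role of N in the text
mu_mu, sigma_mu = 15.0, 2.0   # parameters of the Gaussian prior on μ
sigma_sigma = 1.0             # scale of the half-normal prior on σ

# Step 1: sample the priors once.
mu = rng.normal(loc=mu_mu, scale=sigma_mu)
# A half-normal draw is the absolute value of a zero-mean Gaussian draw.
sigma = abs(rng.normal(loc=0.0, scale=sigma_sigma))

# Step 2: use those values as the parameters of the Gaussian generating x.
x = rng.normal(loc=mu, scale=sigma, size=n_points)

print(f'mu = {mu:.3f}, sigma = {sigma:.3f}, x.shape = {x.shape}')
```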

## Generating the data

First, we need to generate some datapoints. This entails some work that already goes in the direction of writing the joint probability distribution, but we'll also revisit that later.

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

tfd = tfp.distributions

sns.set_theme()

Structure of the joint probability distribution: a single draw of $\mu$ and a single draw of $\sigma$ parametrize `n_samples` independent copies of the distribution of $x$.

In [None]:
n_samples = 100

μ_μ = 15. 
σ_μ = 2.

σ_σ = 1.

distr_dict = {
    'μ': tfd.Normal(loc=μ_μ, scale=σ_μ),
    'σ': tfd.HalfNormal(scale=σ_σ),
    # n_samples independent Gaussian distributions.
    'x': lambda μ, σ: tfd.Independent(
        # The following dimension expansion allows for broadcasting the values of
        # μ and σ (if not done, there would be different behaviours in sampling
        # when calling for a single sample vs multiple samples).
        tfd.Normal(loc=tf.expand_dims(μ, -1) * tf.ones(n_samples), scale=tf.expand_dims(σ, -1) * tf.ones(n_samples)),
        reinterpreted_batch_ndims=1
    )
}

generating_distr = tfd.JointDistributionNamed(
    distr_dict
)

In [None]:
generating_distr

Note: because of this structure, each sample consists of a single sample of $\mu$, a single sample of $\sigma$ and `n_samples` (here 100) samples of $x$.
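
The role of the dimension expansion in the definition of `'x'` can be checked in isolation. The sketch below uses NumPy, whose `expand_dims` and broadcasting rules mirror the TensorFlow ones used above; the array names are hypothetical.

```python
import numpy as np

n_samples = 100

# One joint sample: μ is a scalar, the location of x gets shape (n_samples,).
mu_single = np.float32(15.0)
loc_single = np.expand_dims(mu_single, -1) * np.ones(n_samples)
print(loc_single.shape)  # (100,)

# A batch of 1000 joint samples: μ has shape (1000,), the sample axis comes first.
mu_batch = np.full(1000, 15.0)
loc_batch = np.expand_dims(mu_batch, -1) * np.ones(n_samples)
print(loc_batch.shape)  # (1000, 100)

# Without the extra axis, shapes (1000,) and (100,) cannot be broadcast together.
try:
    mu_batch * np.ones(n_samples)
    broadcast_failed = False
except ValueError as err:
    broadcast_failed = True
    print('Broadcast error:', err)
```

The trailing axis added by `expand_dims` is what lets the same expression work for a single draw and for a batch of draws.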

Example: if we sample the joint probability distribution 1000 times we'll get two tensors of 1000 samples each for $\mu$ and $\sigma$ respectively, and a tensor of shape `(1000, 100)` of samples of $x$ (the sample axis comes first).

In [None]:
samples = generating_distr.sample(sample_shape=1000)

In [None]:
samples['x']

In [None]:
# Mean of the n_samples values of x in the first joint sample vs the sampled μ.
samples['x'][0, :].numpy().mean(), samples['μ'][0].numpy()

In [None]:
# Same comparison for the second joint sample.
samples['x'][1, :].numpy().mean(), samples['μ'][1].numpy()

In [None]:
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(14, 6))

plt.subplots_adjust(wspace=.4)

for i, rv in enumerate(samples.keys()):
    sns.histplot(
        x=samples[rv].numpy().ravel(),
        ax=axs[i],
        stat='density',
        kde=True
    )
    plt.sca(axs[i])
    plt.title(f'{rv} samples')

In [None]:
single_samples = {
    'μ': samples['μ'][0],  # 1 value.
    'σ': samples['σ'][0],  # 1 value.
    'x': samples['x'][0, :]  # n_samples values.
}

When a "single sample" (i.e. one value for each distribution in the joint distribution) is passed to the `log_prob` method of the joint probability distribution object, the method returns the log (joint) probability of that sample. This corresponds to the sum of the log probabilities obtained from each "component" distribution.

**Note:** 1 "sample" actually corresponds to 1 value for $\mu$ and $\sigma$ respectively and 100 values for $x$. The result will be the log joint probability of that sample (1 value).

In [None]:
generating_distr.log_prob(single_samples)

In [None]:
sum([
    tfd.Normal(loc=μ_μ, scale=σ_μ).log_prob(single_samples['μ']).numpy(),
    tfd.HalfNormal(scale=σ_σ).log_prob(single_samples['σ']).numpy(),
    tfd.Independent(
        tfd.Normal(loc=[single_samples['μ']] * n_samples, scale=[single_samples['σ']] * n_samples),
        reinterpreted_batch_ndims=1
    ).log_prob(single_samples['x']).numpy()
])

We can also compute the log prob of multiple samples at the same time just by passing tensors with the appropriate shape (notice that the sample axis must be the **first axis**). In this case we'll get one value of the log prob for each sample.

In [None]:
generating_distr.log_prob(samples)

## Sampling the hierarchical model vs sampling with a single value for each parameter

Sampling with a single value for each parameter is equivalent to having Dirac deltas as priors for the parameters of the Gaussian generating $x$.

In [None]:
# Sample the distributions.
N = 10000

samples = generating_distr.sample(N)

In [None]:
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(14, 6))

plt.subplots_adjust(wspace=.4)

for i, rv in enumerate(samples.keys()):
    sns.histplot(
        x=samples[rv].numpy().ravel(),
        ax=axs[i],
        stat='density',
        kde=True
    )
    plt.sca(axs[i])
    plt.title(f'{rv} samples')

In the data generation process, the distributions were sampled hierarchically: first the distributions for $\mu$ and $\sigma$ were sampled, then these values were used as the parameters of the Gaussian distribution from which `n_samples` values of $x$ were drawn, and this was repeated for each of the `N` joint samples.

What if we didn't sample $\mu$ and $\sigma$ but just used a single value for each? We can have a look at what would have happened by sampling a Gaussian distribution for $x$ with values for the parameters fixed at the mean value of the sampled $\mu$ and $\sigma$.

In [None]:
μ_mean = tf.reduce_mean(samples['μ'])
σ_mean = tf.reduce_mean(samples['σ'])

print(f'Mean μ: {μ_mean} | Mean σ: {σ_mean}')

x_samples_single_parameters = tfd.Normal(loc=μ_mean, scale=σ_mean).sample(N)

fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=samples['x'].numpy().ravel(),
    stat='density',
    kde=True,
    label='Samples from hierarchical model',
    color=sns.color_palette()[0]
)

sns.histplot(
    x=x_samples_single_parameters,
    stat='density',
    kde=True,
    label='Samples with mean values for the parameters',
    color=sns.color_palette()[1]
)

plt.legend(loc='upper right')
plt.title('Comparison between sampling the full model and using only the mean values for the parameters', fontsize=14);

Is the result from the hierarchical model still a Gaussian? Let's try to fit one.

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    x=samples['x'].numpy().ravel(),
    stat='density',
    kde=True,
    label='Samples from hierarchical model',
    color=sns.color_palette()[0]
)

x_range = tf.linspace(samples['x'].numpy().min(), samples['x'].numpy().max(), 1000)

sns.lineplot(
    x=x_range,
    y=tfd.Normal(loc=samples['x'].numpy().mean(), scale=samples['x'].numpy().std()).prob(x_range),
    label='"Fitted" Gaussian',
    color=sns.color_palette()[1]
)

plt.legend(loc='upper right')
plt.title('Comparison with a Gaussian with parameters estimated from the samples', fontsize=14)
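
As a complementary numerical check on the plot above, we can look at the excess kurtosis, which is zero for a Gaussian. The sketch below resamples the marginal distribution of $x$ with NumPy, using one value of $x$ per $(\mu, \sigma)$ draw; the helper `excess_kurtosis`, the seed and the number of draws are choices made here. Under the priors chosen earlier, a moment calculation gives an excess kurtosis of about 0.24 for the marginal of $x$, i.e. slightly heavier tails than a Gaussian.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
n_draws = 200_000  # number of (μ, σ, x) draws, chosen arbitrarily

# Hierarchical sampling: one x for each draw of (μ, σ).
mu = rng.normal(15.0, 2.0, size=n_draws)
sigma = np.abs(rng.normal(0.0, 1.0, size=n_draws))  # half-normal draws
x = rng.normal(mu, sigma)


def excess_kurtosis(values):
    """Fourth standardized moment minus 3 (zero for a Gaussian)."""
    centred = values - values.mean()
    return np.mean(centred**4) / np.mean(centred**2)**2 - 3.0


# Reference: Gaussian samples with the same mean and standard deviation.
x_gauss = rng.normal(x.mean(), x.std(), size=n_draws)

print(f'Hierarchical model: {excess_kurtosis(x):.3f}')
print(f'Matched Gaussian:   {excess_kurtosis(x_gauss):.3f}')
```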

## Writing joint probabilities with `JointDistributionCoroutine`

TFP offers different ways to define joint probability distributions. One of them is `JointDistributionNamed` (seen above); another one is `JointDistributionCoroutine`, which can be used as a function decorator, turning the function into a distribution object.

Inside the function, the component distributions must be specified using the `yield` keyword, with the roots of the Bayesian graphs identified with `Root`. The result is equivalent to the one seen previously.

In [None]:
@tfd.JointDistributionCoroutine
def generating_distr_jdc():
    μ = yield tfd.JointDistributionCoroutine.Root(tfd.Normal(loc=μ_μ, scale=σ_μ))
    σ = yield tfd.JointDistributionCoroutine.Root(tfd.HalfNormal(scale=σ_σ))
    x = yield tfd.Independent(
        # The following dimension expansion allows for broadcasting the values of
        # μ and σ (if not done, there would be different behaviours in sampling
        # when calling for a single sample vs multiple samples).
        tfd.Normal(loc=tf.expand_dims(μ, -1) * tf.ones(n_samples), scale=tf.expand_dims(σ, -1) * tf.ones(n_samples)),
        reinterpreted_batch_ndims=1
    )
    
generating_distr_jdc