# Training distributions (Gamma mixtures)

__WARNING__

This notebook contains experiments with TFP's distributions which are __not__ guaranteed to be the right way to approach the problem. There can be bugs in the code and mistakes in the logic. The solution is built sequentially with multiple attempts and some attempts do contain mistakes that are addressed in the following attempts: they are marked and left there for instructional purposes.

__Problem__

Suppose there's a system that checks whether you forgot to lock the door of your house when you left it; if you left it unlocked, it sends a message to your phone. You may or may not see the message, and if you see it you may or may not read it. If you haven't read it after five minutes, the system sends you a reminder. We have data about the people who read the message, recording how much time (in minutes) it took each of them to read it, and we want to model the probability distribution of that time.

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

tfd = tfp.distributions

## A first look at the data

In [None]:
time_deltas = tf.constant(np.load('../data/gamma_mixture/time_deltas.npy'))

time_deltas.shape

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

sns.histplot(
    x=time_deltas.numpy(),
    stat='density')

## Modelling the distribution of the data

### Fit a Gamma distribution directly ("by hand")

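Suppose the data comes from a single Gamma distribution. Closed-form estimators exist for its parameters (they are listed in the maximum likelihood section of the Wikipedia article linked in the code, although strictly speaking the maximum likelihood equation for the shape has no closed-form solution), so we can compute a "maximum likelihood" fit to the data directly, without any iterative optimization.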
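As a framework-free illustration, the closed-form estimators from the linked Wikipedia section can be sketched in plain Python and checked on synthetic data. The "true" parameters, the sample size and the name `closed_form_gamma_estimates` are made up for this sketch; it is only a sanity check of the formulas, not part of the analysis of the real data.

```python
import math
import random


def closed_form_gamma_estimates(samples):
    """Closed-form estimates of a Gamma's shape (alpha) and rate (beta)."""
    n = len(samples)
    sum_x = sum(samples)
    sum_log_x = sum(math.log(x) for x in samples)
    sum_x_log_x = sum(x * math.log(x) for x in samples)

    # Shared denominator: N * sum(x log x) - sum(log x) * sum(x).
    denom = n * sum_x_log_x - sum_log_x * sum_x

    alpha_hat = n * sum_x / denom   # shape (concentration)
    beta_hat = n ** 2 / denom       # rate = 1 / scale
    return alpha_hat, beta_hat


# Synthetic check with made-up "true" parameters (shape 2, rate 3).
rng = random.Random(0)
true_alpha, true_beta = 2.0, 3.0
# random.gammavariate takes (shape, scale), so the scale is 1 / rate.
samples = [rng.gammavariate(true_alpha, 1.0 / true_beta) for _ in range(20000)]

alpha_hat, beta_hat = closed_form_gamma_estimates(samples)
print(alpha_hat, beta_hat)
```

On data that really comes from one Gamma, the estimates should land close to the true values; a poor fit on the real data therefore points at the single-Gamma assumption rather than at the formulas.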

In [None]:
def get_gamma_estimates(data):
    """
    Given some data assumed to be distributed according to a Gamma, computes
    closed-form estimates of the concentration (alpha) and rate (beta) of the
    distribution.
    Formulas are taken from: https://en.wikipedia.org/wiki/Gamma_distribution#Maximum_likelihood_estimation
    """
    n_samples = tf.cast(data.shape[0], tf.float64)
    
    sum_x = tf.reduce_sum(data)
    sum_log_x = tf.reduce_sum(tf.math.log(data))
    sum_x_log_x = tf.reduce_sum(tf.math.log(data) * data)
    
    denom = (n_samples * sum_x_log_x - sum_log_x * sum_x)
    
    alpha_hat = n_samples * sum_x / denom
    beta_hat = (tf.math.square(n_samples)) / denom
    
    return alpha_hat, beta_hat

In [None]:
alpha_hat, beta_hat = get_gamma_estimates(time_deltas)

gamma_fit = tfd.Gamma(
    concentration=alpha_hat,
    rate=beta_hat)

gamma_fit.parameters

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

sns.histplot(
    time_deltas.numpy(),
    stat='density',
    label='Data')

x_values = tf.linspace(1e-6, 10., 1000).numpy()

plt.plot(
    x_values,
    gamma_fit.prob(x_values),
    color='orange',
    label='Gamma distribution fit')

plt.legend()

Observation: this is definitely not the way to go! A single Gamma fitted to all the data cannot reproduce the second bump in the histogram.

### Training TFP distributions

TFP distributions are trainable: their parameters can be optimized (e.g. via gradient descent) as if they were parameters of a model. This can be used to fit distributions to our data.

#### A single Gamma distribution

The distribution of the data appears to have a bump right after 5 minutes, consistent with the fact that some people read the message after they receive the reminder. This tells us that if we want to try to fit a single Gamma distribution to the data we should drop this bump, so we drop all the time deltas greater than 5 (minutes).

Note: in doing so we also drop the tail of the Gamma distribution we're trying to fit!

In [None]:
def get_loss(distr, data):
    """
    Negative log likelihood loss.
    """
    return - tf.reduce_sum(distr.log_prob(data))


def train_distr(distr, data, epochs, loss_fn, lr=0.05):
    """
    Explicit training loop: minimizes loss_fn(distr, data) with the Adam
    optimizer, using learning rate lr.
    """
    optimizer = tf.optimizers.Adam(learning_rate=lr)
    
    loss_history = []
    params_history = []
    
    for i in range(epochs):
        with tf.GradientTape() as g:
            g.watch(distr.trainable_variables)
            
            loss = loss_fn(distr, data)
        
        gradient = g.gradient(loss, distr.trainable_variables)
        
        optimizer.apply_gradients(zip(gradient, distr.trainable_variables))
        
        loss_history.append(loss_fn(distr, data))
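        # Store a snapshot of the values: appending the variables themselves
        # would only store references that all show the final values.
        params_history.append([v.numpy() for v in distr.trainable_variables])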
        
        if (i % 100) == 0:
            print(f'Epoch: {i+1} of {epochs} | Loss: {loss_history[-1]}')
        
    print(f'Epoch: {i+1} of {epochs} | Loss: {loss_history[-1]}')
        
    return loss_history, params_history, distr

In [None]:
data = tf.cast(time_deltas[time_deltas < 5.], tf.float32)

# Rescale data so it starts from 0.
data = data - tf.reduce_min(data) + 10e-12

In [None]:
# Initial values for the gradient descent.
one_gamma_concentration = tf.Variable(3., name='one_gamma_concentration', dtype=tf.float32)
one_gamma_rate = tf.Variable(3., name='one_gamma_rate', dtype=tf.float32)

one_gamma = tfd.Gamma(
    concentration=one_gamma_concentration,
    rate=one_gamma_rate)

epochs = 2000

loss_history, params_history, one_gamma_trained = train_distr(one_gamma, data, epochs, get_loss)

In [None]:
one_gamma_trained.parameters

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    data.numpy(),
    stat='density',
    label='data')

x_values = tf.linspace(1e-6, 10., 1000).numpy()

plt.plot(
    x_values,
    one_gamma_trained.prob(x_values),
    color='orange')

sns.histplot(
    one_gamma_trained.sample(10000),
    stat='density',
    color='orange',
    label='Fit')

plt.legend()

Observation: not a very good fit, but after all we cut away the tail of the distribution.

#### A mixture of Gamma distributions

Let's now try to keep all the data (including the second peak) and fit a mixture of Gamma distributions to it, one representing the first peak and one representing the second. Now we have two pairs of parameters for the two distributions, plus the parameters for the mixture.

Note: __mistake!__ The parameters of the mixture are probabilities themselves, so each should be in $[0, 1]$ and they should add up to one. This is a constraint that should be imposed during the optimization, otherwise it's not guaranteed that we end up with sensible results (and by the way this is also true for the parameters of the Gammas, which may end up outside of their allowed domain!). We're keeping this error here because we'll correct it later in more refined attempts.
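One common way to respect such constraints is to optimize unconstrained real numbers and map them into the allowed domain: a softmax turns arbitrary values into weights that sum to one, and a softplus turns arbitrary values into positive ones. TFP offers tools along these lines (for example `tfd.Categorical(logits=...)`, or `tfp.util.TransformedVariable` with a `tfp.bijectors.Softplus()` bijector), but the sketch below uses only the standard library, and all the numbers in it are made up for illustration.

```python
import math


def softplus(u):
    """Map any real number to a strictly positive one (numerically stable)."""
    return max(u, 0.0) + math.log1p(math.exp(-abs(u)))


def softmax(logits):
    """Map any real vector to weights in (0, 1) that sum to one."""
    shift = max(logits)
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Unconstrained values, as an optimizer might leave them after some steps
# (made-up numbers, including negative ones).
raw_logits = [1.3, -0.4]
raw_concentrations = [-2.0, 5.0]
raw_rates = [0.0, -7.5]

weights = softmax(raw_logits)
concentrations = [softplus(u) for u in raw_concentrations]
rates = [softplus(u) for u in raw_rates]

print(weights, sum(weights))
print(concentrations, rates)
```

Whatever values the optimizer reaches in the unconstrained space, the mapped parameters stay valid, so no clipping is needed after each step.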

In [None]:
data = tf.cast(time_deltas, tf.float64)

# Rescale data so it starts from 0 (modulo a small offset to avoid log(0) when
# computing the log likelihood).
data = data - tf.reduce_min(data) + 10e-12

In [None]:
# Initial values for the gradient descent.
mixture_probs = tf.Variable([0.7, 0.3], dtype=tf.float64)
mixture_concentrations = tf.Variable([3., 12.], dtype=tf.float64)
mixture_rates = tf.Variable([3., 3.], dtype=tf.float64)

# MixtureSameFamily builds a mixture of distributions, provided the
# components belong to the same family (same functional form) and differ
# only in the values of their parameters.
mixture_distr = tfd.MixtureSameFamily(
    # A categorical distribution holds the mixture weights of the component
    # distributions.
    mixture_distribution=tfd.Categorical(
        probs=mixture_probs),
    components_distribution=tfd.Gamma(
      concentration=mixture_concentrations,
      rate=mixture_rates)) 

In [None]:
loss_history, params_history, mixture_distr_trained = train_distr(mixture_distr, data, 5000, get_loss)

In [None]:
print(mixture_distr_trained.parameters['mixture_distribution'].parameters)
print(mixture_distr_trained.parameters['components_distribution'].parameters)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

sns.histplot(
    data.numpy(),
    stat='density',
    label='Data')

sns.histplot(
    mixture_distr_trained.sample(100000),
    stat='density',
    color='orange',
    label='Samples from the trained distribution')

plt.legend(loc='upper right')


# Plot each component.
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

sns.histplot(
    data.numpy(),
    stat='density',
    label='Data',
    color=sns.color_palette()[0])

sns.histplot(
    mixture_distr_trained.parameters['components_distribution'][0].sample(100000),
    stat='density',
    color='orange',
    label='Mixture first component')

sns.histplot(
    mixture_distr_trained.parameters['components_distribution'][1].sample(100000),
    stat='density',
    color=sns.color_palette()[2],
    label='Mixture second component')

plt.legend(loc='upper right')

In [None]:
mixture_distr_trained.parameters['mixture_distribution'].parameters['probs'].numpy().sum()

Observation: we evidently made some mistakes here:
- As mentioned above, we didn't impose any constraints on the parameters we're optimizing, and indeed the sum computed in the cell above shows exactly that.
- Because we used a mixture of two "copies" of the same distribution, the two components ended up overlapping a lot.

#### A mixture of a Gamma distribution and a shifted Gamma distribution

Consider an alternative system that doesn't send any reminder. In this case we'd expect the bump visible after 5 minutes to disappear and we'd model the distribution of the data with a single Gamma distribution.

In the real case the bump at 5 minutes could be modeled as another Gamma "starting" at 5 minutes, so we end up with a mixture of a Gamma (with domain $[0, +\infty)$) and a Gamma shifted forward by 5 (with domain $[5, +\infty)$). In this attempt, the shift is taken as another parameter to optimize (even though we know it should be 5 minutes).
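To make the idea of a Gamma "starting" at 5 minutes concrete, here is a small standard-library sketch of a shifted Gamma density: it equals the ordinary Gamma density evaluated at `x - shift` to the right of the shift and is zero to the left of it. The function names and parameter values are hypothetical, chosen only for illustration.

```python
import math


def gamma_log_pdf(x, concentration, rate):
    """Log density of a Gamma(concentration, rate) at x > 0."""
    return (concentration * math.log(rate) - math.lgamma(concentration)
            + (concentration - 1.0) * math.log(x) - rate * x)


def shifted_gamma_pdf(x, concentration, rate, shift):
    """Density of Y = X + shift with X ~ Gamma; zero to the left of shift."""
    if x <= shift:
        # The boundary point is treated as outside the support, which
        # sidesteps evaluating log(0).
        return 0.0
    return math.exp(gamma_log_pdf(x - shift, concentration, rate))


# Made-up parameters: the shifted component "starts" at 5 minutes.
for x in (2.0, 5.0, 5.3, 7.0):
    print(x, shifted_gamma_pdf(x, concentration=3.0, rate=10.0, shift=5.0))
```

Points before the shift get exactly zero density, which is harmless for plotting but matters as soon as a logarithm of the density is taken.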

Note: __mistake!__ We're still not addressing the issue of imposing the constraints on the parameters during the optimization.

In [None]:
# Side note: test shifting a distribution with the Shift bijector.
gamma_test = tfd.Gamma(
      concentration=tf.constant(3., dtype=tf.float64),
      rate=tf.constant(3., dtype=tf.float64))

# TFP implements transformations of distributions as bijector objects.
shift = tfp.bijectors.Shift(tf.constant(-3, dtype=tf.float64))

# TransformedDistribution gives the desired result by combining the
# original distribution with the appropriate bijector.
gamma_test_shifted = tfd.TransformedDistribution(
    gamma_test,
    shift)

fig = plt.figure(figsize=(14, 6))

sns.set_theme()

sns.histplot(
    gamma_test.sample(100000),
    stat='density',
    label='Original distribution')

sns.histplot(
    gamma_test_shifted.sample(100000),
    stat='density',
    color='orange',
    label='Shifted distribution')

plt.legend(loc='upper right')

In [None]:
data = tf.cast(time_deltas, tf.float64)

# Rescale data so it starts from 0 (modulo a small offset to avoid log(0) when
# computing the log likelihood).
data = data - tf.reduce_min(data) + 10e-12

In [None]:
mixture_probs = tf.Variable([0.9, 0.1], dtype=tf.float64)

gamma_1_concentration = tf.Variable(1.5, dtype=tf.float64)
gamma_1_rate = tf.Variable(1., dtype=tf.float64)
gamma_2_concentration = tf.Variable(3., dtype=tf.float64)
gamma_2_rate = tf.Variable(10., dtype=tf.float64)
gamma_2_shift = tf.Variable(5. - 10e-10, dtype=tf.float64)
# gamma_1_concentration = tf.Variable(3., dtype=tf.float64)
# gamma_1_rate = tf.Variable(3., dtype=tf.float64)
# gamma_2_concentration = tf.Variable(30., dtype=tf.float64)
# gamma_2_rate = tf.Variable(3., dtype=tf.float64)
# gamma_2_shift = tf.Variable(5. - 10e-10, dtype=tf.float64)


mixture_distr = tfd.Mixture(
    cat=tfd.Categorical(
        probs=mixture_probs),
    components=[
        tfd.Gamma(
            concentration=gamma_1_concentration,
            rate=gamma_1_rate,
            name='Gamma'),
        tfd.TransformedDistribution(
            tfd.Gamma(
                concentration=gamma_2_concentration,
                rate=gamma_2_rate),
            tfp.bijectors.Shift(gamma_2_shift),
            name='shiftGamma')
    ]
) 


# Plot samples from the mixture distribution with initial values.
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

sns.histplot(
    data.numpy(),
    stat='density',
    label='Data')

sns.histplot(
    mixture_distr.sample(100000),
    stat='density',
    color='orange',
    label='Samples from the distribution (initial values for the parameters)')

plt.legend(loc='upper right')

In [None]:
# This way, training is hopeless.
loss_history, params_history, mixture_distr_trained = train_distr(mixture_distr, data, 1, get_loss)

Observation: we immediately got a null value for the loss function!

__Explanation:__ we now have a mixture of two distributions, one defined on $[0, +\infty)$ and the other on $[5, +\infty)$. The loss function we're using is the negative log likelihood, and that is where the problem lies: in a mixture model each data point is generated by any of the components with a certain probability, so its contribution to the log likelihood is (the logarithm of) a linear combination of its densities under each component. But points lying between 0 and the shift parameter (the "beginning" of the second Gamma) are outside the domain of the second Gamma, so its density cannot be evaluated there.

__Solution:__ if we want to use a mixture with a shifted Gamma we also need to correct the loss function so that the negative log likelihood is computed correctly.
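The piecewise structure of the corrected log likelihood can be sketched without TFP. In the sketch below (standard library only, with made-up parameters and data points) the shifted component contributes nothing below the shift, so only the first component, weighted by its mixture probability, is evaluated there; above the shift both components are combined. Keeping the weight in the below-shift branch follows the mixture density formula; treat this as one possible formulation rather than a definitive one.

```python
import math


def gamma_log_pdf(x, concentration, rate):
    """Log density of a Gamma(concentration, rate); only defined for x > 0."""
    return (concentration * math.log(rate) - math.lgamma(concentration)
            + (concentration - 1.0) * math.log(x) - rate * x)


def mixture_log_pdf(x, weights, gamma_1, gamma_2, shift):
    """Log density of w1 * Gamma_1(x) + w2 * Gamma_2(x - shift)."""
    log_first = math.log(weights[0]) + gamma_log_pdf(x, *gamma_1)
    if x <= shift:
        # The shifted component has zero density here, so only the first
        # (weighted) component contributes.
        return log_first
    log_second = math.log(weights[1]) + gamma_log_pdf(x - shift, *gamma_2)
    # log(exp(a) + exp(b)), computed stably.
    top = max(log_first, log_second)
    return top + math.log(math.exp(log_first - top) + math.exp(log_second - top))


# Made-up parameters and data points on both sides of the shift.
weights = (0.9, 0.1)
gamma_1 = (1.5, 1.0)    # (concentration, rate)
gamma_2 = (3.0, 10.0)   # (concentration, rate)
shift = 5.0
data = [0.5, 2.0, 4.9, 5.2, 6.0]

# Naive evaluation of the second component below the shift fails outright.
try:
    gamma_log_pdf(data[0] - shift, *gamma_2)
    naive_ok = True
except ValueError:
    naive_ok = False

negative_log_likelihood = -sum(
    mixture_log_pdf(x, weights, gamma_1, gamma_2, shift) for x in data)
print(naive_ok, negative_log_likelihood)
```

In plain Python the naive evaluation raises an error, while the piecewise version yields a finite loss for every data point.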

In [None]:
# Demonstration of the above issue.
# Datapoints x s.t. x >= shift (about 5) generate a sensible log likelihood.
print(mixture_distr.log_prob(data[data >= gamma_2_shift]))

# For datapoints x s.t. x < shift the log likelihood can't be computed.
print(mixture_distr.log_prob(data[data < gamma_2_shift]))

#### A mixture of a Gamma distribution and a shifted Gamma distribution, fixing the loss function

Not only will we fix the loss function, but we'll also experiment with constraining the optimization of the mixture parameters and with making the shift parameter either variable or fixed. Swap the various variables/constants in and out for different experiments.

In [None]:
@tf.function
def get_loss_shifted_mixture(mixture_distr, data):
    """
    Negative log likelihood loss, fixed for the mixture distribution of a
    Gamma and a shifted Gamma.
    """
    # Get the shift parameter.
    shift_param = mixture_distr.parameters['components'][1].parameters['bijector'].shift
    
    # For all the datapoints in [0, shift_param], the contribution to the log
    # likelihood is given only by the first Gamma in the distribution.
    ll_below_shift = tf.reduce_sum(mixture_distr.parameters['components'][0].log_prob(data[data < shift_param]))
    
    # For all the datapoints in [shift_param, +oo], the contribution to the
    # log likelihood is given by the full mixture.
    ll_above_shift = tf.reduce_sum(mixture_distr.log_prob(data[data >= shift_param]))

    return - (ll_below_shift + ll_above_shift)

In [None]:
data = tf.cast(time_deltas, tf.float64)

# Rescale data so it starts from 0 (modulo a small offset to avoid log(0) when
# computing the log likelihood).
data = data - tf.reduce_min(data) + 10e-12

# Ignore the last part of the tail (probably outliers).
data = data[data < 10.]

In [None]:
# Define parameters (constant and/or variable, experimenting).
cat_probs = tf.Variable(
    [0.8, 0.2],
    name='cat_probs',
    dtype=tf.float64,
    # Constrain cat_probs[0] to stay in [0.8, 1] (bounds nudged inwards to
    # avoid extreme values) and set cat_probs[1] to 1 - cat_probs[0].
    # Return a plain tensor: the constraint shouldn't create a new Variable.
    constraint=(lambda p: tf.stack([
        tf.clip_by_value(p[0], 0.8 + 10e-10, 1. - 10e-10),
        1. - tf.clip_by_value(p[0], 0.8 + 10e-10, 1. - 10e-10)]))
)
cat_probs_const = tf.constant(
    [0.9, 0.1],
    name='cat_probs',
    dtype=tf.float64)

gamma_1_concentration = tf.Variable(1.5, dtype=tf.float64, name='gamma_1_concentration')
gamma_1_rate = tf.Variable(1., dtype=tf.float64, name='gamma_1_rate')

gamma_2_concentration = tf.Variable(3., dtype=tf.float64, name='gamma_2_concentration')
gamma_2_rate = tf.Variable(10., dtype=tf.float64, name='gamma_2_rate')

gamma_2_shift = tf.Variable(
    5. - 10e-10,
    dtype=tf.float64,
    name='gamma_2_shift',
    # What if we constrain the shift parameter so it doesn't move towards 0?
    # constraint=lambda s: tf.clip_by_value(s, 4.7, 99.)
)
gamma_2_shift_const = tf.constant(
    5. - 10e-10,
    dtype=tf.float64,
    name='gamma_2_shift')


# Define the mixture model.
mixture_distr = tfd.Mixture(
    cat=tfd.Categorical(
        probs=cat_probs),
    components=[
        tfd.Gamma(
            concentration=gamma_1_concentration,
            rate=gamma_1_rate,
            name='Gamma'),
        tfd.TransformedDistribution(
            tfd.Gamma(
                concentration=gamma_2_concentration,
                rate=gamma_2_rate),
            tfp.bijectors.Shift(gamma_2_shift_const),
            name='shiftGamma')
    ]
)


# Plot samples from the mixture distribution with initial values.
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

sns.histplot(
    data.numpy(),
    stat='density',
    label='Data')

sns.histplot(
    mixture_distr.sample(100000),
    stat='density',
    color='orange',
    label='Samples from the distribution (initial values for the parameters)')

plt.legend(loc='upper right')

In [None]:
mixture_distr.trainable_variables

In [None]:
# Train!
loss_history, params_history, mixture_distr_trained = train_distr(
    mixture_distr,
    data,
    1000,
    get_loss_shifted_mixture,
    lr=0.9)

# Plot!
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

sns.histplot(
    data.numpy(),
    stat='density',
    label='Data')

sns.histplot(
    mixture_distr.sample(100000),
    stat='density',
    color='orange',
    label='Samples from the trained distribution')

plt.legend(loc='upper right')


# Plot each component separately (without the weights from the mixture).
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

sns.histplot(
    data.numpy(),
    stat='density',
    label='Data',
    color=sns.color_palette()[0])

sns.histplot(
    mixture_distr.parameters['components'][0].sample(100000),
    stat='density',
    color='orange',
    label='Mixture first component')

sns.histplot(
    mixture_distr.parameters['components'][1].sample(100000),
    stat='density',
    color=sns.color_palette()[2],
    label='Mixture second (shifted) component')

plt.legend(loc='upper right')

In [None]:
mixture_distr_trained.parameters['cat'].parameters['probs']

Observation: with the mixture weights kept as trainable variables and the shift of the second Gamma held constant, we get something that, if not perfect, is at least sensible! Notice, however, that the mixture weights haven't moved from their initial values, which signals that there might be something wrong with how the constraints were enforced.
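One candidate explanation, which this notebook does not verify, is visible in the constraint itself: the lower clipping bound (0.8) coincides with the initial value of the first weight. If the optimizer keeps trying to decrease that weight, every update is clipped straight back to the bound. The sketch below mimics this with plain Python; the update size and the assumption that the optimizer always pushes downwards are hypothetical.

```python
def clip(value, low, high):
    """Clamp value to the interval [low, high]."""
    return max(low, min(value, high))


low, high = 0.8 + 10e-10, 1.0 - 10e-10   # bounds used in the constraint
weight = 0.8                              # initial value of the first weight
step = 0.05                               # hypothetical update size

history = []
for _ in range(5):
    # Pretend the optimizer always wants to decrease the first weight.
    weight = clip(weight - step, low, high)
    history.append(weight)

print(history)  # every entry equals the lower bound
```

If this is what happens during training, lowering the clipping bound (or replacing the clipping with a reparameterization of the weights) would be a natural next experiment.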

#### Further ideas

- Try different component distributions (log-normal, exponential, ...).
- Try different combinations of trainable and fixed parameters.
- Try different initial values for the trainable parameters.
- Try more epochs.