# A/B testing with Bernoulli trials

Let's say that we want to try two different versions of an e-commerce website to see which one has a higher chance of making customers convert (place an order). This is the typical example of an A/B test, with a control group that gets served one version of the website and a treatment group that gets served the other. The data consists of binary labels, one per customer, indicating whether the customer converted (value = 1) or not (value = 0).

Idea: we can model our conversion data as a set of Bernoulli trials. The test will then be about whether the probability parameter of the Bernoulli distribution is the same for the control and treatment groups or not. In the classical statistical framework of hypothesis testing, the null hypothesis would be that the probability is the same in the two groups.
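
For reference, the classical test mentioned above can be sketched as a two-proportion z-test, where the pooled conversion rate plays the role of the common probability under the null hypothesis. This is only a minimal sketch: the function name `two_proportion_z_test` and the counts are hypothetical and are not taken from this notebook's data.

```python
import math


def two_proportion_z_test(conversions_a, n_a, conversions_b, n_b):
    """Two-sided z-test for the equality of two Bernoulli probabilities."""
    p_a = conversions_a / n_a
    p_b = conversions_b / n_b
    # Under the null hypothesis both groups share one probability,
    # estimated by pooling the two groups.
    p_pooled = (conversions_a + conversions_b) / (n_a + n_b)
    std_err = math.sqrt(p_pooled * (1.0 - p_pooled) * (1.0 / n_a + 1.0 / n_b))
    z = (p_b - p_a) / std_err
    # Two-sided p-value for a standard normal: P(|Z| > |z|) = erfc(|z| / sqrt(2)).
    p_value = math.erfc(abs(z) / math.sqrt(2.0))
    return z, p_value


# Hypothetical counts (assumption): 120/1000 conversions in control,
# 150/1000 in treatment.
z, p_value = two_proportion_z_test(120, 1000, 150, 1000)
print(f"z = {z:.3f}, p-value = {p_value:.4f}")
```

The Bayesian analysis below answers a related but different question: instead of a p-value, it produces a full posterior distribution for each probability.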

Source: this notebook is explicitly "inspired" by the [corresponding section of the Bayesian methods for hackers book](https://github.com/CamDavidsonPilon/Probabilistic-Programming-and-Bayesian-Methods-for-Hackers/blob/master/Chapter2_MorePyMC/Ch2_MorePyMC_TFP.ipynb).

In [None]:
import itertools
import warnings
import numpy as np
from scipy.fftpack import next_fast_len
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_probability as tfp
import arviz as az

az.style.use("arviz-darkgrid")

tfd = tfp.distributions

## Data loading and exploration

In [None]:
data_control = tf.constant(np.load('../data/ab_testing_bernoulli/data_control.npy'))
data_treatment = tf.constant(np.load('../data/ab_testing_bernoulli/data_treatment.npy'))

data_control.shape, data_treatment.shape

The maximum likelihood estimate of the probability parameter $p$ in a Bernoulli distribution, given the data, is the sample mean.
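
To see why: with $k$ conversions out of $n$ customers the log-likelihood is $k \log p + (n - k) \log(1 - p)$, and setting its derivative $k/p - (n - k)/(1 - p)$ to zero gives $p = k/n$. The sketch below checks this numerically on synthetic data; the true probability of 0.3, the sample size and the grid are assumptions made only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic binary data standing in for the conversion labels (assumed p = 0.3).
data = rng.binomial(n=1, p=0.3, size=5_000)


def bernoulli_log_likelihood(p, data):
    """Log-likelihood of i.i.d. Bernoulli data for candidate probabilities p."""
    k = data.sum()
    n = data.size
    return k * np.log(p) + (n - k) * np.log1p(-p)


# Evaluate the log-likelihood on a grid with spacing 0.001 and locate its maximum.
grid = np.linspace(0.001, 0.999, 999)
log_lik = bernoulli_log_likelihood(grid, data)
p_grid_mle = grid[np.argmax(log_lik)]
p_sample_mean = data.mean()

print("Grid maximizer:", p_grid_mle)
print("Sample mean:   ", p_sample_mean)
```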

In [None]:
p_est_control = tf.reduce_mean(data_control)
p_est_treatment = tf.reduce_mean(data_treatment)
p_est_pooled = tf.reduce_mean(tf.concat([data_control, data_treatment], axis=0))

print('Maximum likelihood estimate for p for control group:', p_est_control.numpy())
print('Maximum likelihood estimate for p for treatment group:', p_est_treatment.numpy())
print('Maximum likelihood estimate for p for the pooling of the groups:', p_est_pooled.numpy())

Let's write a batch of 3 Bernoulli distributions, corresponding to the 3 cases above.

In [None]:
bernoullis = tfd.Bernoulli(probs=[p_est_control, p_est_treatment, p_est_pooled])

bernoullis

In [None]:
bernoullis.sample(10)

As we increase the number of samples, taking the mean for each distribution should return a better and better approximation of the probabilities we started with.

In [None]:
tf.reduce_mean(tf.cast(bernoullis.sample(100000), tf.float32), axis=0)

## Bayesian inference

Notes:
- Not much tuning of the parameters for MCMC has been done, but it should have been!
- The method used to update the sampling step size is now deprecated and should be replaced with its current equivalent.
- In general, it's a better idea to sample multiple chains in parallel for the same parameter, to check that they converge to the same distribution (robustness). This hasn't been done here, mostly because starting multiple parallel chains entails reworking the joint log prob so it transparently uses batches of values by broadcasting variables - and getting it right is a bit tedious.

### Inference on probabilities, separately

Let's perform Bayesian inference on the parameter $p$ of the two groups separately.

There are multiple ways to compute the joint log prob; we explore two here:
- We can define a function that, given the data and a value of $p$, instantiates the corresponding prior and likelihood and returns the joint log probability. This way there's one function that works for all the groups.
- We can define a joint probability distribution directly, specifying a (potentially different) prior and likelihood for each group, at the cost of having a different distribution object for each group.

In [None]:
def joint_log_prob(occurrences, prob):
    """
    Joint log probability of the prior on `prob` and the observed data.

    Args:
      occurrences: An array of binary values (0 & 1), representing
                   the observed outcomes.
      prob: Scalar value of the probability of a 1 appearing.
    Returns:
      Sum of the log probabilities from the prior and the likelihood.
    """
    # Prior.
    rv_prob = tfd.Uniform(low=0., high=1.)
    
    # Likelihood.
    rv_occurrences = tfd.Bernoulli(probs=prob)

    return (
        rv_prob.log_prob(prob)
        + tf.reduce_sum(rv_occurrences.log_prob(occurrences))
    )


def trace_stuff(states, previous_kernel_results):
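    """
    Trace function for `sample_chain`: prints the current state every 100
    steps and returns the kernel results so that they get traced.
    """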
    # I couldn't find a way not to make the counter global.
    step = next(counter)
    
    if (step % 100) == 0:
        print(f"Step {step}, state: {states}")
    
    return previous_kernel_results


def autocov(ary, axis=-1):
    """Compute autocovariance estimates for every lag for the input array.
    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    Returns
    -------
    acov: Numpy array same size as the input array
    """
    # Normalize negative axes; `>= 0` so that axis=0 is left unchanged.
    axis = axis if axis >= 0 else len(ary.shape) + axis
    n = ary.shape[axis]
    m = next_fast_len(2 * n)

    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ifft_ary = np.fft.rfft(ary, n=m, axis=axis)
        ifft_ary *= np.conjugate(ifft_ary)

        shape = tuple(
            slice(None) if dim_len != axis else slice(0, n) for dim_len, _ in enumerate(ary.shape)
        )
        cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[shape]
        cov /= n

    return cov


def autocorr(ary, axis=-1):
    """Compute autocorrelation using FFT for every lag for the input array.
    See https://en.wikipedia.org/wiki/autocorrelation#Efficient_computation
    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    Returns
    -------
    acorr: Numpy array same size as the input array
    """
    corr = autocov(ary, axis=axis)
    axis = axis if axis >= 0 else len(corr.shape) + axis
    norm = tuple(
        slice(None, None) if dim != axis else slice(None, 1) for dim, _ in enumerate(corr.shape)
    )
    with np.errstate(invalid="ignore"):
        corr /= corr[norm]
    return corr
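
As a cross-check for the FFT-based functions above, the same normalized autocorrelation can be computed directly from its definition. The sketch below is self-contained and uses a synthetic AR(1) series, whose theoretical autocorrelation at lag $k$ is $\phi^k$; the helper name `autocorr_direct` and the value of `phi` are assumptions for illustration.

```python
import numpy as np


def autocorr_direct(samples, max_lag):
    """Sample autocorrelation for lags 0..max_lag, computed without FFT."""
    x = np.asarray(samples, dtype=np.float64)
    x = x - x.mean()
    n = x.size
    var = np.dot(x, x) / n
    # Same (biased) normalization by n as the FFT version above.
    return np.array(
        [np.dot(x[: n - k], x[k:]) / (n * var) for k in range(max_lag + 1)]
    )


# Synthetic AR(1) series: x_t = phi * x_{t-1} + noise.
rng = np.random.default_rng(42)
phi = 0.8
noise = rng.normal(size=20_000)
series = np.empty_like(noise)
series[0] = noise[0]
for t in range(1, series.size):
    series[t] = phi * series[t - 1] + noise[t]

acf = autocorr_direct(series, max_lag=5)
print(np.round(acf, 2))
```

A slowly decaying autocorrelation like this one is what a poorly mixing MCMC chain looks like; for a well-tuned sampler the values should drop towards zero within a few lags.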

In [None]:
joint_distr_control = tfd.JointDistributionSequential([
    tfd.Uniform(low=0., high=1.),
    lambda p: tfd.Independent(
        tfd.Bernoulli(probs=p * tf.ones_like(data_control)),
        reinterpreted_batch_ndims=1
    )
])

joint_distr_treatment = tfd.JointDistributionSequential([
    tfd.Uniform(low=0., high=1.),
    lambda p: tfd.Independent(
        tfd.Bernoulli(probs=p * tf.ones_like(data_treatment)),
        reinterpreted_batch_ndims=1
    )
])

In [None]:
# Check that the two methods return the same joint log probability.
print(
    'Control:',
    joint_log_prob(data_control, p_est_control).numpy(),
    joint_distr_control.log_prob(p_est_control, data_control).numpy()
)

print(
    'Treatment:',
    joint_log_prob(data_treatment, p_est_treatment).numpy(),
    joint_distr_treatment.log_prob(p_est_treatment, data_treatment).numpy()
)

Note: there might be a small discrepancy between the two methods for computing the joint log probability in the treatment group. This is due to a rounding error introduced when computing the log in the log prob.
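
To get a feeling for the size of such rounding effects, the sketch below computes the same joint log probability in two ways with NumPy: once in closed form from the number of conversions (double precision), and once as an element-wise sum in single precision. The data is synthetic and the function name is hypothetical; this illustrates floating-point discrepancies in general and does not reproduce TFP's internals.

```python
import numpy as np


def joint_log_prob_closed_form(data, p):
    """Uniform(0, 1) prior plus Bernoulli likelihood, via sufficient statistics."""
    k = np.sum(data)     # number of conversions
    n = np.size(data)    # number of customers
    log_prior = 0.0      # log density of Uniform(0, 1) for p inside (0, 1)
    return log_prior + k * np.log(p) + (n - k) * np.log1p(-p)


rng = np.random.default_rng(1)
data = rng.binomial(n=1, p=0.1, size=10_000)  # synthetic stand-in data
p = data.mean()

closed_form = joint_log_prob_closed_form(data, p)

# Element-wise sum carried out entirely in float32.
one = np.float32(1.0)
d32 = data.astype(np.float32)
p32 = np.float32(p)
elementwise_f32 = np.sum(
    d32 * np.log(p32) + (one - d32) * np.log(one - p32), dtype=np.float32
)

print("Closed form (float64):", closed_form)
print("Element-wise (float32):", elementwise_f32)
print("Absolute difference:   ", abs(closed_form - float(elementwise_f32)))
```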

#### Inference on the control group

In [None]:
# Define a closure over our joint_log_prob.
# The closure makes it so the HMC doesn't try to change the `occurrences` but
# instead determines the distributions of other parameters that might generate
# the `occurrences` we observed.
unnormalized_posterior_log_prob_control = lambda p: joint_distr_control.log_prob(p, data_control)

In [None]:
number_of_steps = 2000
burnin = 500
leapfrog_steps=2

# Set the chain's start state.
initial_chain_state = [
    tf.reduce_mean(tf.cast(data_control, tf.float32)),
]

# Since HMC operates over unconstrained space, we need to transform the
# samples so they live in real-space.
unconstraining_bijectors = [
    tfp.bijectors.Sigmoid()   # Maps R to (0, 1).
]

step_size = tf.Variable(0.5, dtype=tf.float32)

# Defining the HMC
hmc = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=unnormalized_posterior_log_prob_control,
        num_leapfrog_steps=leapfrog_steps,
        step_size=step_size,
        # The step size adaptation prevents stationarity from being
        # reached, so the number of adaptation steps should be smaller than
        # the number of burnin steps, so that stationarity can be reached
        # in the remaining part of the burnin phase.
        step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(num_adaptation_steps=int(burnin * 0.8)),
        state_gradients_are_stopped=True),
    bijector=unconstraining_bijectors)

# Sampling from the chain.
print('Sampling started')

counter = itertools.count(1)

[
    posterior_prob_control
], kernel_results = tfp.mcmc.sample_chain(
    num_results=number_of_steps,
    num_burnin_steps=burnin,
    current_state=initial_chain_state,
    kernel=hmc,
    trace_fn=trace_stuff)

print('Sampling finished')

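# Note: `sample_chain` already discards the `num_burnin_steps` initial steps,
# so this slice drops a further `burnin` samples from the returned trace.
burned_posterior_prob_control_trace = posterior_prob_control[burnin:]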
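
The comment about the unconstraining bijector in the cell above can be made concrete with a small NumPy sketch. The sampler moves in an unconstrained variable $x \in \mathbb{R}$ and maps it to $p = \mathrm{sigmoid}(x) \in (0, 1)$; the density in $x$ picks up the Jacobian term $\log |dp/dx| = \log p + \log(1 - p)$. The function names and the counts (12 conversions out of 100) are assumptions for illustration, not what TFP does line by line.

```python
import numpy as np


def sigmoid(x):
    """Maps R to (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


def logit(p):
    """Inverse of the sigmoid: maps (0, 1) to R."""
    return np.log(p) - np.log1p(-p)


def unconstrained_log_prob(x, log_prob_p):
    """Log density of x = logit(p), given the log density of p on (0, 1)."""
    p = sigmoid(x)
    # Change of variables: |dp/dx| = p * (1 - p).
    log_det_jacobian = np.log(p) + np.log1p(-p)
    return log_prob_p(p) + log_det_jacobian


# Hypothetical unnormalized posterior: uniform prior, 12 conversions out of 100.
def log_post(p):
    return 12 * np.log(p) + 88 * np.log1p(-p)


x = np.array([-5.0, -2.0, 0.0, 3.0])
p = sigmoid(x)
print("Constrained values:", p)
print("Log prob in unconstrained space:", unconstrained_log_prob(x, log_post))
```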

In [None]:
kernel_results.inner_results.is_accepted.numpy().mean()

In [None]:
burned_posterior_prob_control_trace.numpy().mean(), p_est_control.numpy()
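
For this particular model the posterior is also available in closed form, which gives an independent check on the sampler: a Uniform(0, 1) prior is a Beta(1, 1) distribution, and combined with a Bernoulli likelihood with $k$ conversions out of $n$ customers it yields a Beta($1 + k$, $1 + n - k$) posterior with mean $(k + 1)/(n + 2)$. The counts in the sketch below are hypothetical, not the notebook's data.

```python
import numpy as np

# Hypothetical counts (assumption), not taken from the notebook's data.
n_customers = 1_000
n_conversions = 120

# Conjugate update: Beta(1, 1) prior -> Beta(1 + k, 1 + n - k) posterior.
alpha_post = 1 + n_conversions
beta_post = 1 + n_customers - n_conversions

posterior_mean = alpha_post / (alpha_post + beta_post)
mle = n_conversions / n_customers

# Exact posterior draws, to compare against an MCMC trace.
rng = np.random.default_rng(0)
exact_draws = rng.beta(alpha_post, beta_post, size=100_000)

print("Analytic posterior mean:", posterior_mean)
print("Maximum likelihood estimate:", mle)
print("Mean of exact draws:", exact_draws.mean())
```

With this much data the posterior mean and the maximum likelihood estimate nearly coincide, which is why the MCMC posterior mean is expected to land close to the sample mean.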

In [None]:
# Plotting using ArviZ.
az.plot_trace(
    burned_posterior_prob_control_trace.numpy(),
    divergences='bottom',
    figsize=(16, 6)
)

az.plot_autocorr(
    burned_posterior_prob_control_trace.numpy(),
    figsize=(16, 6)
)

az.plot_posterior(
    burned_posterior_prob_control_trace.numpy(),
    kind='hist',
    figsize=(16, 6)
)

In [None]:
# Custom plotting.
# Trace.
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

plt.scatter(
    x=tf.range(1, burned_posterior_prob_control_trace.shape[0] + 1, dtype=tf.int32).numpy(),
    y=burned_posterior_prob_control_trace.numpy(),
    alpha=0.5,
)

plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.title("Traceplot", fontsize=18)

plt.xlabel("Iteration", fontsize=14)
plt.ylabel("Sample", fontsize=14)

# Autocorrelations.
fig = plt.figure(figsize=(14, 6))

x_autocorr = np.arange(1, burned_posterior_prob_control_trace.shape[0])

plt.bar(
    x_autocorr,
    autocorr(burned_posterior_prob_control_trace.numpy())[1:],
    width=1,
    label="$m$",
    edgecolor=sns.color_palette()[0],
    color=sns.color_palette()[0]
)

plt.title("Autocorrelation plot of traces for differing $k$ lags")
plt.ylabel("Correlation \nbetween $x_t$ and $x_{t-k}$")
plt.xlabel("k (lag)")
plt.legend(loc='upper right')

# Posterior distribution.
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

sns.histplot(
    burned_posterior_prob_control_trace,
    bins=40,
    stat='density',
    label='Posterior samples'
)

plt.vlines(
    x=burned_posterior_prob_control_trace.numpy().mean(),
    ymin=0.,
    ymax=30.,
    color='r',
    label='Posterior mean')

plt.legend()
plt.xlabel('p (sampled)', fontsize=12)
plt.title('Distribution of samples from the posterior', fontsize=14)

#### Inference on the treatment group

In [None]:
unnormalized_posterior_log_prob_treatment = lambda p: joint_distr_treatment.log_prob(p, data_treatment)

In [None]:
number_of_steps = 2000
burnin = 500
leapfrog_steps=2

initial_chain_state = [
    tf.reduce_mean(tf.cast(data_treatment, tf.float32))
]

unconstraining_bijectors = [
    tfp.bijectors.Sigmoid()   # Maps R to (0, 1).
]

step_size = tf.Variable(0.5, dtype=tf.float32)

hmc = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=unnormalized_posterior_log_prob_treatment,
        num_leapfrog_steps=leapfrog_steps,
        step_size=step_size,
        step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(num_adaptation_steps=int(burnin * 0.8)),
        state_gradients_are_stopped=True),
    bijector=unconstraining_bijectors)

# Sampling from the chain.
print('Sampling started')

counter = itertools.count(1)

[
    posterior_prob_treatment
], kernel_results = tfp.mcmc.sample_chain(
    num_results=number_of_steps,
    num_burnin_steps=burnin,
    current_state=initial_chain_state,
    kernel=hmc,
    trace_fn=trace_stuff)

print('Sampling finished')

burned_posterior_prob_treatment_trace = posterior_prob_treatment[burnin:]

In [None]:
kernel_results.inner_results.is_accepted.numpy().mean()

In [None]:
burned_posterior_prob_treatment_trace.numpy().mean(), p_est_treatment.numpy()

In [None]:
# Plotting using ArviZ.
az.plot_trace(
    burned_posterior_prob_treatment_trace.numpy(),
    divergences='bottom',
    figsize=(16, 6)
)

az.plot_autocorr(
    burned_posterior_prob_treatment_trace.numpy(),
    figsize=(16, 6)
)

az.plot_posterior(
    burned_posterior_prob_treatment_trace.numpy(),
    kind='hist',
    figsize=(16, 6)
)

Plot of the posterior sampled distribution for both groups together.

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

# Treatment.

sns.histplot(
    burned_posterior_prob_treatment_trace,
    bins=40,
    stat='density',
    label='Posterior samples (treatment)'
)

plt.vlines(
    x=burned_posterior_prob_treatment_trace.numpy().mean(),
    ymin=0.,
    ymax=30.,
    color='r',
    label='Posterior mean (treatment)')

# Control.
sns.histplot(
    burned_posterior_prob_control_trace,
    bins=40,
    stat='density',
    label='Posterior samples (control)',
    color='orange'
)

plt.vlines(
    x=burned_posterior_prob_control_trace.numpy().mean(),
    ymin=0.,
    ymax=30.,
    color='purple',
    label='Posterior mean (control)')

plt.legend()
plt.xlabel('p (sampled)', fontsize=12)
plt.title('Distribution of samples from the posterior', fontsize=14)

### Inference on probabilities, both groups together

Note: this is conceptually the same as what was done above, but with both parameters sampled together in the same chain. The results should match because the model assumes no interaction between the two distributions, so the sampling of one does not affect the other.

In [None]:
# Because the samples are independent, the combined joint log prob is just the
# sum of the log probs for each group.
unnormalized_posterior_log_prob_combined = lambda p_control, p_treatment: (
    joint_distr_control.log_prob(p_control, data_control)
    + joint_distr_treatment.log_prob(p_treatment, data_treatment)
)

In [None]:
# Set the chain's start state.
initial_chain_state = [    
    tf.reduce_mean(tf.cast(data_control, tf.float32)),
    tf.reduce_mean(tf.cast(data_treatment, tf.float32))
]

In [None]:
# Test that the target log prob function behaves as expected.
# Note: in case of more than one parameter, because the initial state is
#       defined as a list, we need to pass it to the function with a `*`
#       (that's probably what TFP does internally).
unnormalized_posterior_log_prob_combined(*initial_chain_state)

In [None]:
number_of_steps = 2000
burnin = 500
leapfrog_steps=3

# Since HMC operates over unconstrained space, we need to transform the
# samples so they live in real-space.
unconstraining_bijectors = [
    tfp.bijectors.Sigmoid(),   # Maps R to (0, 1).
    tfp.bijectors.Sigmoid()    # Maps R to (0, 1).
]

# Initialize the step_size. (It will be automatically adapted.)
step_size = tf.Variable(0.5, dtype=tf.float32)

# Defining the HMC
hmc=tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=unnormalized_posterior_log_prob_combined,
        num_leapfrog_steps=leapfrog_steps,
        step_size=step_size,
        step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(num_adaptation_steps=int(burnin * 0.8)),
        state_gradients_are_stopped=True),
    bijector=unconstraining_bijectors)

# Sample from the chain.
print('Sampling started')

counter = itertools.count(1)

[
    posterior_prob_control_combined,
    posterior_prob_treatment_combined
], kernel_results = tfp.mcmc.sample_chain(
    num_results=number_of_steps,
    num_burnin_steps=burnin,
    current_state=initial_chain_state,
    kernel=hmc,
    trace_fn=trace_stuff
)

print('Sampling finished')

burned_posterior_prob_control_combined_trace = posterior_prob_control_combined[burnin:]
burned_posterior_prob_treatment_combined_trace = posterior_prob_treatment_combined[burnin:]
# Both traces are already sliced above, so the difference is not sliced again.
burned_delta_trace = burned_posterior_prob_control_combined_trace - burned_posterior_prob_treatment_combined_trace

In [None]:
kernel_results.inner_results.is_accepted.numpy().mean()

In [None]:
print(
    'Control:',
    burned_posterior_prob_control_combined_trace.numpy().mean(),
    p_est_control.numpy()
)

print(
    'Treatment:',
    burned_posterior_prob_treatment_combined_trace.numpy().mean(),
    p_est_treatment.numpy()
)

In [None]:
# Plotting using ArviZ.
az.plot_trace(
    np.array([
        burned_posterior_prob_control_combined_trace.numpy(),
        burned_posterior_prob_treatment_combined_trace.numpy()
    ]),
    divergences='bottom',
    figsize=(16, 6)
)

az.plot_autocorr(
    np.array([
        burned_posterior_prob_control_combined_trace.numpy(),
        burned_posterior_prob_treatment_combined_trace.numpy()
    ]),
    figsize=(16, 6)
)

az.plot_posterior(
    np.array([
        burned_posterior_prob_control_combined_trace.numpy(),
        burned_posterior_prob_treatment_combined_trace.numpy()
    ]),
    kind='hist',
    figsize=(16, 6)
)

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

# Control.

sns.histplot(
    burned_posterior_prob_control_combined_trace,
    bins=40,
    stat='density',
    label='Posterior samples (control)'
)

plt.vlines(
    x=burned_posterior_prob_control_combined_trace.numpy().mean(),
    ymin=0.,
    ymax=30.,
    color='r',
    label='Posterior mean (control)')

# Treatment.
sns.histplot(
    burned_posterior_prob_treatment_combined_trace,
    bins=40,
    stat='density',
    label='Posterior samples (treatment)',
    color='orange'
)

plt.vlines(
    x=burned_posterior_prob_treatment_combined_trace.numpy().mean(),
    ymin=0.,
    ymax=30.,
    color='purple',
    label='Posterior mean (treatment)')

plt.legend()
plt.xlabel('p (sampled)', fontsize=12)
plt.title('Distribution of samples from the posterior', fontsize=14)


# Delta.
fig = plt.figure(figsize=(14, 6))

sns.histplot(
    burned_delta_trace,
    bins=40,
    stat='density',
    label='Posterior samples (delta)'
)

plt.vlines(
    x=0.,
    ymin=0.,
    ymax=20.,
    color='r',
    label='"Null hypothesis"')

plt.legend()
plt.xlabel('Delta', fontsize=12)
plt.title('Distribution of samples from the posterior for delta', fontsize=14)
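
Beyond the plot, the delta samples can be summarized numerically, for example with the posterior probability that the treatment converts better than the control and a 95% credible interval for delta. The sketch below uses exact Beta draws for hypothetical counts as stand-ins for the MCMC traces (an assumption made so that it runs on its own) and keeps the same sign convention as above, delta = control - treatment.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for the posterior traces (assumption): exact Beta posterior draws
# for 120/1000 conversions in control and 150/1000 in treatment.
control_samples = rng.beta(1 + 120, 1 + 880, size=100_000)
treatment_samples = rng.beta(1 + 150, 1 + 850, size=100_000)

# Same sign convention as the delta trace: control minus treatment.
delta = control_samples - treatment_samples

# Fraction of posterior samples in which the treatment has the higher probability.
prob_treatment_better = np.mean(delta < 0)

# Central 95% credible interval for delta.
interval_low, interval_high = np.percentile(delta, [2.5, 97.5])

print("P(p_treatment > p_control):", prob_treatment_better)
print("95% credible interval for delta:", (interval_low, interval_high))
```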

### Sampling from multiple chains

Let's try to sample from multiple chains in parallel in order to do diagnostics on the convergence (assess robustness).

#### Mini-guide: how to set up multichain sampling with TFP

Sources:
- `tfp.mcmc.sample_chain` function [documentation](https://www.tensorflow.org/probability/api_docs/python/tfp/mcmc/sample_chain).
- [Issue on GitHub](https://github.com/tensorflow/probability/issues/1093) asking the same question.

Sampling can be performed on a batch of chains. The batch shape must be the same for all the parameters, because the `target_log_prob_fn` function (the unnormalized log prob, i.e. the closure over the function that returns the log prob given the data and value(s) for the parameters to estimate) needs to return values with a shape equal to the batch shape. In practice, we can restrict ourselves to a single batch dimension, the number of chains, in which case this reduces to saying that the number of chains must be the same for all the parameters on which MCMC is performed.

From the documentation we're told that the number of chains is inferred from the shape of the state (i.e. the initial state), but it's also true that the target log prob function must return results with the corresponding shape. Also, the bijectors (in case HMC is used) need to be able to handle the shape of the states. The ingredients are:
- States (the initial state in particular, as it's the only one under our control).
- Target log prob function.
- Bijectors.

With one chain,
- The initial state is a list of tensors with scalar shape (`()`), one for each parameter.
- The target log prob function returns a scalar.
- The bijectors are put in a list, with one bijector for each parameter.

With `n_chains` chains:
- The initial state is still a list of tensors, one for each parameter, but this time each tensor must have shape `(n_chains,)`, which is the "batch shape".
- The target log prob function must return a tensor with the same batch shape, `(n_chains,)`.
- The bijectors are put in a list with one bijector for each parameter, exactly as before (bijectors handle the batch shape automatically).

A good way to make sure everything is working as expected is to define the initial state as a list of tensors with the desired shape, then pass it as input to the target log prob function and check that the result has the same shape as the tensors. There's only **one tricky point** to remember, irrespective of whether we're running a single chain or multiple chains: if there's more than one parameter on which to perform MCMC, we need to unpack the initial state with a `*` when passing it to the target log prob function.
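
The shape rules above can be illustrated without TFP. The NumPy sketch below evaluates the unnormalized log posterior (uniform prior, Bernoulli likelihood) for a scalar `p` and for a batch of `n_chains` values, using a trailing axis so that `p` broadcasts against the data. The function name `batched_log_prob` and the synthetic data are assumptions for illustration.

```python
import numpy as np


def batched_log_prob(p, data):
    """Unnormalized log posterior for a batch of p values, one per chain."""
    p = np.asarray(p, dtype=np.float64)
    # Add a trailing axis so that p broadcasts against the data axis:
    # p[..., None] has shape (..., 1), data has shape (n_obs,).
    p_expanded = p[..., np.newaxis]
    per_obs = data * np.log(p_expanded) + (1 - data) * np.log1p(-p_expanded)
    # Summing over the data axis leaves exactly the batch shape.
    return per_obs.sum(axis=-1)


rng = np.random.default_rng(0)
data = rng.binomial(n=1, p=0.2, size=500)  # synthetic stand-in data
n_chains = 4

single = batched_log_prob(0.2, data)
batch = batched_log_prob(np.full(n_chains, 0.2), data)

print(single.shape, batch.shape)  # prints: () (4,)
```

The same idea is what `tf.expand_dims(p, -1)` achieves in the TFP code: the parameter gets a right-most dimension of size 1 that broadcasts against the observations.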

In [None]:
n_chains = 4

In [None]:
# First possibility: define a joint distribution object that behaves well with
# varying parameter shapes.
joint_distr_control = tfd.JointDistributionSequential([
    tfd.Uniform(low=0., high=1.),
    lambda p: tfd.Independent(
        tfd.Bernoulli(
            probs=tf.expand_dims(p, -1) * tf.ones_like(data_control)
        ),
        reinterpreted_batch_ndims=1
    )
])

joint_distr_treatment = tfd.JointDistributionSequential([
    tfd.Uniform(low=0., high=1.),
    lambda p: tfd.Independent(
        tfd.Bernoulli(
            probs=tf.expand_dims(p, -1) * tf.ones_like(data_treatment)
        ),
        reinterpreted_batch_ndims=1
    )
])

In [None]:
# Second possibility: use a target log prob function instead of the log prob of
# a joint distribution object.
# Idea: whatever the shape of the probabilities, interpret all dimensions except
# the right-most one as the batch shape, so that the joint log prob function
# returns a batch of log probs upon evaluation.
def joint_log_prob_combined(p_control, p_treatment, data_control, data_treatment):
    prior_p_control = tfd.Uniform(low=0., high=1.)
    prior_p_treatment = tfd.Uniform(low=0., high=1.)
    
    # The total likelihood is the product of the likelihoods, as control and
    # treatment samples are independent.
    likelihood_control = tfd.Bernoulli(probs=tf.expand_dims(p_control, -1))
    likelihood_treatment = tfd.Bernoulli(probs=tf.expand_dims(p_treatment, -1))

    # The prior log probs already have the batch shape (one value per chain),
    # so only the likelihood terms are summed over the data axis.
    return (
        prior_p_control.log_prob(p_control)
        + prior_p_treatment.log_prob(p_treatment)
        + tf.reduce_sum(likelihood_control.log_prob(data_control), axis=-1)
        + tf.reduce_sum(likelihood_treatment.log_prob(data_treatment), axis=-1)
    )

In [None]:
# Target log prob function in the two cases. Comment/uncomment to select which
# one to use.
# Using a joint distribution object.
unnormalized_posterior_log_prob_combined = lambda p_control, p_treatment: (
    joint_distr_control.log_prob(p_control, data_control)
    + joint_distr_treatment.log_prob(p_treatment, data_treatment)
)
# Defining a function to return the log prob.
# unnormalized_posterior_log_prob_combined = (
#     lambda p_control, p_treatment: joint_log_prob_combined(p_control, p_treatment, data_control, data_treatment)
# )

# Test whether the unnormalized posterior log prob behaves as expected with a
# possible initial state as the input.
state_batch = [
    tf.stack([
        tf.reduce_mean(tf.cast(data_control, tf.float32)),
    ] * n_chains),
    tf.stack([
        tf.reduce_mean(tf.cast(data_treatment, tf.float32)),
    ] * n_chains)
]

unnormalized_posterior_log_prob_combined(*state_batch)

In [None]:
number_of_steps = 2000
burnin = 500
leapfrog_steps=2

# Set the chain's start state.
initial_chain_state = [
    tf.stack([
        tf.reduce_mean(tf.cast(data_control, tf.float32)),
    ] * n_chains),
    tf.stack([
        tf.reduce_mean(tf.cast(data_treatment, tf.float32)),
    ] * n_chains)
]

# Since HMC operates over unconstrained space, we need to transform the
# samples so they live in real-space.
unconstraining_bijectors = [
    tfp.bijectors.Sigmoid(),  # Maps R to (0, 1).
    tfp.bijectors.Sigmoid()   # Maps R to (0, 1).
]

step_size = tf.Variable(0.5, dtype=tf.float32)

# Defining the HMC
hmc = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
        # target_log_prob_fn=unnormalized_posterior_log_prob_control,
        target_log_prob_fn=unnormalized_posterior_log_prob_combined,
        num_leapfrog_steps=leapfrog_steps,
        step_size=step_size,
        # The step size adaptation prevents stationarity from being
        # reached, so the number of adaptation steps should be smaller than
        # the number of burnin steps, so that stationarity can be reached
        # in the remaining part of the burnin phase.
        step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(num_adaptation_steps=int(burnin * 0.8)),
        state_gradients_are_stopped=True),
    bijector=unconstraining_bijectors)

# Sampling from the chain.
print('Sampling started')

counter = itertools.count(1)

[
    posterior_prob_control_combined,
    posterior_prob_treatment_combined
], kernel_results = tfp.mcmc.sample_chain(
    num_results=number_of_steps + burnin,
    num_burnin_steps=burnin,
    current_state=initial_chain_state,
    kernel=hmc,
    trace_fn=trace_stuff)

print('Sampling finished')

trace_control_combined_burned = posterior_prob_control_combined[burnin:]
trace_treatment_combined_burned = posterior_prob_treatment_combined[burnin:]

inference_data = az.convert_to_inference_data({
    'p_control': tf.transpose(trace_control_combined_burned),
    'p_treatment': tf.transpose(trace_treatment_combined_burned)
})

In [None]:
inference_data

In [None]:
az.summary(inference_data)

In [None]:
az.plot_trace(inference_data)

az.plot_autocorr(inference_data)

az.plot_posterior(inference_data)

az.plot_forest(inference_data)
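
The main convergence diagnostic that multiple chains make possible is the potential scale reduction factor, reported by ArviZ as `r_hat`. The sketch below implements the classic Gelman-Rubin version in NumPy to show the idea: compare the variance between chain means with the variance within chains. ArviZ uses a refined (rank-normalized, split-chain) variant, so its values will not match this sketch exactly; the synthetic chains and the function name are assumptions for illustration.

```python
import numpy as np


def gelman_rubin_rhat(chains):
    """Classic (non-split) potential scale reduction factor.

    chains: array of shape (n_chains, n_draws).
    """
    chains = np.asarray(chains, dtype=np.float64)
    n_chains, n_draws = chains.shape
    chain_means = chains.mean(axis=1)
    within = chains.var(axis=1, ddof=1).mean()      # W: mean within-chain variance
    between = n_draws * chain_means.var(ddof=1)     # B: between-chain variance
    var_estimate = (n_draws - 1) / n_draws * within + between / n_draws
    return np.sqrt(var_estimate / within)


rng = np.random.default_rng(0)

# Four synthetic chains that all sample the same distribution.
well_mixed = rng.normal(loc=0.12, scale=0.01, size=(4, 2_000))

# Same chains, but the last one is shifted, as if stuck in a different region.
stuck = well_mixed + np.array([[0.0], [0.0], [0.0], [0.05]])

rhat_good = gelman_rubin_rhat(well_mixed)
rhat_bad = gelman_rubin_rhat(stuck)

print("R-hat, well-mixed chains:", rhat_good)
print("R-hat, one shifted chain:", rhat_bad)
```

Values close to 1 indicate that the chains agree with each other; values clearly above 1 signal that at least one chain is exploring a different region and that the samples should not be trusted yet.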