# Bayesian linear regressions

Objectives:
- Use MCMC sampling to perform a Bayesian version of standard linear regression.
- Repeat the same, this time fitting a piecewise linear function to some data.

In [None]:
import itertools
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
import arviz as az

# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions
# Plot style used by ArviZ figures.
az.style.use("arviz-darkgrid")

## Linear regression

We'll generate synthetic data for the fit. The data is normally distributed around a straight line. To make things more interesting, we'll write a single joint distribution for everything: the x coordinates of the datapoints, the parameters of the distribution of the data (slope, intercept and standard deviation of the Normal distribution) and the data itself.

### Generate synthetic data

In [None]:
n_points = 100

# The whole generative process as one joint distribution. The final callable
# receives the values of the preceding components most-recent-first
# (sigma, q, m, x_data).
synthetic_components = [
    # x coordinates of the datapoints: one Uniform per point.
    tfd.Uniform(
        low=-10.5 * tf.ones(shape=n_points),
        high=23. * tf.ones(shape=n_points),
    ),
    # m (slope).
    tfd.Normal(loc=2.5, scale=3.),
    # q (intercept).
    tfd.Uniform(low=-5., high=12.),
    # sigma (noise scale): a HalfNormal(0.5) draw shifted up by 10.
    tfd.TransformedDistribution(
        tfd.HalfNormal(scale=.5),
        tfp.bijectors.Shift(shift=10.),
    ),
    # y coordinates: Normal noise around the line; Independent makes the whole
    # vector of n_points values a single event.
    lambda sigma, q, m, x_data: tfd.Independent(
        tfd.Normal(loc=x_data * m + q, scale=sigma),
        reinterpreted_batch_ndims=1,
    ),
]
joint_distr_synthetic = tfd.JointDistributionSequential(synthetic_components)

# One joint draw, plus the distributions it was drawn from.
distr, samples = joint_distr_synthetic.sample_distributions()
x_data, m_sampled, q_sampled, sigma_sampled, y_data = samples

# Scatter plot of the synthetic dataset.
fig = plt.figure(figsize=(14, 6))
sns.set_theme()
plt.scatter(x=x_data, y=y_data)

### Bayesian linear regression

We start by writing our Bayesian model, i.e. a distribution that describes how we think the data was generated, including the priors for the parameters (and pretending we never saw the distribution that generated the data to begin with!). This is our modelling hypothesis.

In [None]:
def trace_stuff(states, previous_kernel_results, print_every=100):
    """Print sampling progress and pass the kernel results through unchanged.

    Used below as the ``trace_fn`` of ``tfp.mcmc.sample_chain``. Whatever it
    returns is what gets stored in the traced kernel results. Returning the
    kernel results in full keeps e.g. ``inner_results.is_accepted``
    available after sampling.

    Args:
        states: Current state of the chain(s); only printed.
        previous_kernel_results: Kernel results for the current step; returned
            unchanged.
        print_every: Positive integer. A progress line is printed whenever the
            step counter is a multiple of it (default 100).

    Returns:
        ``previous_kernel_results``, unchanged.

    Note:
        Reads a module-level iterator named ``counter`` (e.g.
        ``itertools.count(1)``). It must be (re)created before each sampling
        run, otherwise this raises ``NameError`` or continues an old count.
    """
    # The counter stays global: sample_chain calls trace_fn with only
    # (state, kernel_results), so there is no argument to carry it through.
    step = next(counter)

    if step % print_every == 0:
        print(f"Step {step}, state: {states}")

    return previous_kernel_results

In [None]:
def _y_likelihood(sigma, q, m):
    """Likelihood of the observed y coordinates given (batches of) m, q, sigma.

    JointDistributionSequential calls this with the values of the preceding
    components most-recent-first: sigma, then q, then m.
    """
    # m, q and sigma may carry a batch axis (one entry per parallel chain).
    # Appending a trailing axis lets them broadcast against the datapoints
    # (x_data as a single row), so every batch entry gets its own line.
    x_row = tf.expand_dims(x_data, 0)
    # Independent folds the datapoint axis into the event, so one sample of
    # this distribution is a whole dataset. log_prob then gives a single value
    # per batch entry, given m, q, sigma and the x and y coordinates of the
    # datapoints.
    return tfd.Independent(
        tfd.Normal(
            loc=x_row * tf.expand_dims(m, -1) + tf.expand_dims(q, -1),
            scale=tf.expand_dims(sigma, -1),
        ),
        reinterpreted_batch_ndims=1,
    )


joint_prob = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=20.),  # Prior for m.
    tfd.Normal(loc=0., scale=30.),  # Prior for q.
    tfd.Uniform(low=0., high=50.),  # Prior for sigma.
    _y_likelihood,  # Likelihood of the observed y coordinates.
])


def joint_log_prob_closure(m, q, sigma):
    """Joint log-probability of the parameters and the observed ``y_data``."""
    return joint_prob.log_prob(m, q, sigma, y_data)

In [None]:
# Sanity check: evaluate the joint log-probability closure on a batch of
# n_chains identical parameter values (m = q = sigma = 1), one per chain.
n_chains = 4

test_state = [tf.ones(shape=n_chains) for _ in range(3)]

joint_log_prob_closure(*test_state)

In [None]:
# Start the chains from the frequentist (least-squares) estimates.
lr = LinearRegression(fit_intercept=True)
lr.fit(x_data.numpy().reshape(-1, 1), y_data.numpy())

# Cast explicitly to float32. Depending on the scikit-learn version, the fitted
# coefficients may come back as float64, while the model and the HMC state are
# float32; mixing the two makes TensorFlow arithmetic fail.
max_lik_est_m = tf.constant(lr.coef_[0], dtype=tf.float32)
max_lik_est_q = tf.constant(lr.intercept_, dtype=tf.float32)

residuals = y_data - (x_data * max_lik_est_m + max_lik_est_q)

# Residual standard error: sum of squared residuals over n - 2 degrees of
# freedom (two fitted parameters). It divides by n - 2 rather than n, so
# strictly it is not the maximum-likelihood estimate of sigma.
max_likest_sigma = tf.sqrt(tf.reduce_sum(tf.square(residuals)) / (n_points - 2))

# The same starting point for every chain.
initial_chain_state = [
    max_lik_est_m * tf.ones(shape=n_chains),
    max_lik_est_q * tf.ones(shape=n_chains),
    max_likest_sigma * tf.ones(shape=n_chains),
]

initial_chain_state

In [None]:
# Plot the frequentist linear regression: data, least-squares line and a band
# of +/- 2 residual standard errors around it.
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

# Order the points by x so the line and the band are drawn left to right.
# Sorting the x and y columns independently would mismatch the pairs whenever
# the fitted slope is negative, and the line would be drawn with the wrong slope.
x_order = tf.argsort(x_data)
x_fit_sorted = tf.gather(x_data, x_order)
data_max_lik_fit = tf.stack(
    [x_fit_sorted, x_fit_sorted * max_lik_est_m + max_lik_est_q], axis=1)

plt.scatter(
    x=x_data,
    y=y_data,
    color='b',
    label='Data')

plt.fill_between(
    x=data_max_lik_fit[:, 0].numpy(),
    y2=(data_max_lik_fit[:, 1] - 2. * max_likest_sigma).numpy(),
    y1=(data_max_lik_fit[:, 1] + 2. * max_likest_sigma).numpy(),
    alpha=0.3,
    color='g',
    # Raw string: '\s' is not a valid escape sequence in a normal literal.
    label=r'2-$\sigma$ band'
)

plt.plot(
    data_max_lik_fit[:, 0].numpy(),
    data_max_lik_fit[:, 1].numpy(),
    color='r',
    label='Maximum likelihood fit')

plt.legend(loc='upper left')

In [None]:
number_of_steps = 5000
burnin = 1000
leapfrog_steps = 3

# HMC operates over an unconstrained space, so each parameter gets a bijector
# mapping the real line onto its support.
unconstraining_bijectors = [
    tfp.bijectors.Identity(),  # Maps R to R (m).
    tfp.bijectors.Identity(),  # Maps R to R (q).
    tfp.bijectors.Exp(),  # Maps R to (0, +oo) (sigma).
]

# Initial step size; the update policy below adapts it during the first
# int(0.8 * burnin) steps.
step_size = tf.Variable(0.5, dtype=tf.float32)

# NOTE(review): `step_size_update_fn` / `make_simple_step_size_update_policy`
# are deprecated in TensorFlow Probability in favour of
# `tfp.mcmc.SimpleStepSizeAdaptation`. Migrating would also change the nesting
# of `kernel_results` used further down, so confirm the installed TFP version.
hmc = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=joint_log_prob_closure,
        num_leapfrog_steps=leapfrog_steps,
        step_size=step_size,
        step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(
            num_adaptation_steps=int(burnin * 0.8)),
        state_gradients_are_stopped=True),
    bijector=unconstraining_bijectors)

print('Sampling started')

# Step counter read (as a global) by trace_stuff; reset for every run.
counter = itertools.count(1)

# sample_chain already runs and discards `burnin` steps before recording, so
# request exactly the number of samples to keep and do not slice again.
[
    trace_m,
    trace_q,
    trace_sigma
], kernel_results = tfp.mcmc.sample_chain(
    num_results=number_of_steps,
    num_burnin_steps=burnin,
    current_state=initial_chain_state,
    kernel=hmc,
    trace_fn=trace_stuff
)

print('Sampling finished')

# Post-burn-in traces, shape (number_of_steps, n_chains). The *_burned names
# are kept for the cells below.
trace_m_burned = trace_m
trace_q_burned = trace_q
trace_sigma_burned = trace_sigma

# Posterior mean of each parameter, one value per chain.
posterior_means = {
    'm': tf.reduce_mean(trace_m_burned, axis=0),
    'q': tf.reduce_mean(trace_q_burned, axis=0),
    'sigma': tf.reduce_mean(trace_sigma_burned, axis=0)}

# ArviZ expects (chain, draw) arrays, hence the transpose; pass plain numpy
# arrays rather than TensorFlow tensors.
inference_data = az.convert_to_inference_data({
    'm': tf.transpose(trace_m_burned).numpy(),
    'q': tf.transpose(trace_q_burned).numpy(),
    'sigma': tf.transpose(trace_sigma_burned).numpy()
})

In [None]:
# Fraction of accepted HMC proposals over all traced steps and chains.
np.mean(kernel_results.inner_results.is_accepted.numpy())

In [None]:
# Posterior mean of each parameter, one value per chain.
posterior_means

In [None]:
# True parameter values drawn when generating the synthetic data, for
# comparison with the posterior means.
m_sampled, q_sampled, sigma_sampled

In [None]:
# ArviZ InferenceData holding the post-burn-in traces of m, q and sigma.
inference_data

In [None]:
# Per-parameter summary table computed by ArviZ from the traces.
az.summary(inference_data)

In [None]:
# ArviZ diagnostic plots of the traces: trace plot, autocorrelation, marginal
# posteriors and, last (so its return value is the cell output), a forest plot.
for plot_fn in (az.plot_trace, az.plot_autocorr, az.plot_posterior):
    plot_fn(inference_data)

az.plot_forest(inference_data)

An example of how thinning affects autocorrelation: the autocorrelation of the $m$ trace of the first chain, with and without thinning.

In [None]:
# Trace of m for the first chain as a 1-D array (ArviZ treats it as one chain).
m_first_chain = trace_m_burned[:, 0].numpy()

az.plot_autocorr(m_first_chain)
plt.title('Autocorrelation without thinning')

# Thinning: keep only every third draw.
thinning_step = 3
az.plot_autocorr(m_first_chain[::thinning_step])
plt.title('Autocorrelation with thinning (keeping 1 sample every 3)')

Plot the line corresponding to the means of the posterior samples for each parameter, together with lines corresponding to other posterior samples drawn at random from the traces.

In [None]:
def compute_pred(x, m, q):
    """Evaluate the straight line ``m * x + q`` at ``x``.

    Works element-wise on anything supporting ``*`` and ``+`` (scalars,
    numpy arrays, tensors).
    """
    slope_term = x * m
    return slope_term + q

In [None]:
# Shape of the post-burn-in trace of m: (number of draws, number of chains).
trace_m_burned.shape

In [None]:
# Line through the posterior means of m and q, pooling the draws of every
# chain (no hard-coded number of chains).
x_plot = np.linspace(x_data.numpy().min(), x_data.numpy().max(), 100)
y_plot = compute_pred(
    x_plot,
    tf.reduce_mean(trace_m_burned),
    tf.reduce_mean(trace_q_burned),
)

fig = plt.figure(figsize=(14, 6))

sns.set_theme()

plt.scatter(
    x=x_data,
    y=y_data,
    label='Data')

# Lines from randomly chosen posterior draws (with replacement) to show the
# uncertainty on the fit. m and q use the same (step, chain) index so each
# line is a single joint posterior draw.
n_param_samples = 50

chain_indices = np.random.choice(trace_m_burned.shape[1], n_param_samples)
sample_indices = np.random.choice(trace_m_burned.shape[0], n_param_samples)

for i, (si, ci) in enumerate(zip(sample_indices, chain_indices)):
    plt.plot(
        x_plot,
        compute_pred(x_plot, trace_m_burned[si, ci], trace_q_burned[si, ci]),
        color='g',
        alpha=.5,
        # Label only the first line so the legend gets a single entry.
        label='Posterior samples' if i == 0 else None,
    )

plt.plot(
    x_plot,
    y_plot,
    color='r',
    label='Fit with posterior sample means')

plt.legend()

Generate synthetic datasets using, as values for the parameters ($m$, $q$ and $\sigma$), first the means of the posterior distributions and then a random sample from the posterior distribution.

In [None]:
def generate_synthetic_dataset(x_data, m, q, sigma):
    """Draw one synthetic y value per x coordinate from the linear model.

    Each y is sampled from Normal(x * m + q, sigma). The Normal is wrapped in
    Independent so that the whole vector of y values counts as one event,
    i.e. one draw is one dataset.

    Args:
        x_data: x coordinates of the datapoints (a vector).
        m: Slope.
        q: Intercept.
        sigma: Standard deviation of the Normal noise.

    Returns:
        The sampled y coordinates.
    """
    pointwise_normal = tfd.Normal(loc=x_data * m + q, scale=sigma)
    dataset_distr = tfd.Independent(
        pointwise_normal, reinterpreted_batch_ndims=1)
    return dataset_distr.sample()

In [None]:
fig = plt.figure(figsize=(14, 6))

sns.set_theme()

plt.scatter(
    x=x_data,
    y=y_data,
    label='Data')

# Posterior mean of each parameter, pooling the draws of every chain. Each
# mean gets its own assignment so that q is not overwritten by sigma's mean.
m_synthetic_data = tf.reduce_mean(trace_m_burned)
q_synthetic_data = tf.reduce_mean(trace_q_burned)
sigma_synthetic_data = tf.reduce_mean(trace_sigma_burned)

plt.scatter(
    x=x_data,
    y=generate_synthetic_dataset(
        x_data, m_synthetic_data, q_synthetic_data, sigma_synthetic_data),
    label='Synthetic data (posterior means)')

# A single joint posterior draw: the same (step, chain) index for m, q and
# sigma.
chain_index = np.random.choice(trace_m_burned.shape[1])
sample_index = np.random.choice(trace_m_burned.shape[0])

plt.scatter(
    x=x_data,
    y=generate_synthetic_dataset(
        x_data,
        trace_m_burned[sample_index, chain_index],
        trace_q_burned[sample_index, chain_index],
        trace_sigma_burned[sample_index, chain_index]),
    label='Synthetic data (random posterior sample)')

plt.legend(loc='upper right')