# Waste Free SMC comparison

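In this notebook we fit a simple Bayesian logistic regression with three samplers — a random walk Rosenbluth-Metropolis-Hastings (RMH) chain, a classic adaptive tempered SMC sampler, and a waste-free SMC sampler — and compare the posteriors they produce.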

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import sklearn

# Plot styling shared by every figure in the notebook.
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
plt.rcParams["figure.figsize"] = (12, 8)
import jax

from datetime import date
# NOTE(review): the seed is today's date (YYYYMMDD), so the RMH results change
# from one day to the next; use a fixed integer here for reproducible runs.
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
import jax.numpy as jnp
from sklearn.datasets import make_biclusters
import blackjax

## The Data

In [None]:
num_points = 50
# Two clusters in 2-D; rows[0] flags the points that belong to the first cluster.
X, rows, cols = make_biclusters(
    shape=(num_points, 2),
    n_clusters=2,
    noise=0.6,
    minval=-3,
    maxval=3,
    random_state=314,
)
y = 1.0 * rows[0]  # y[i] = 1.0 if point i belongs to cluster 1, else 0.0

In [None]:
# Red outline for points of the first cluster, blue for the others.
colors = [("tab:red" if in_first else "tab:blue") for in_first in rows[0]]
plt.scatter(X[:, 0], X[:, 1], c="none", edgecolors=colors)
plt.xlabel(r"$X_0$")
plt.ylabel(r"$X_1$");

## The Model

We use a simple logistic regression model to infer which cluster each point belongs to. We denote by $y$ the binary variable that indicates whether a point belongs to the first cluster:

$$
y \sim \operatorname{Bernoulli}(p)
$$

The probability $p$ of belonging to the first cluster comes from a logistic regression:

$$
p = \operatorname{logistic}(\Phi\,\boldsymbol{w})
$$

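where $\boldsymbol{w}$ is a vector of weights with a zero-centered normal prior (in the code, `logprior` is parameterized by the precision $\alpha = 1/\sigma^2$, with $\alpha = 1$ by default):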

$$
\boldsymbol{w} \sim \operatorname{Normal}(0, \sigma)
$$

And $\Phi$ is the matrix that contains the data, so each row $\Phi_{i,:}$ is the vector $\left[1, X_0^i, X_1^i\right]$

In [None]:
# Design matrix: a leading column of ones (intercept) followed by the two features.
Phi = jnp.column_stack([jnp.ones(num_points), X])
N, M = Phi.shape


def sigmoid(z):
    """Numerically stable logistic function ``1 / (1 + exp(-z))``.

    The naive form ``exp(z) / (1 + exp(z))`` overflows to ``inf / inf = nan``
    for large positive ``z`` (about z > 88 in float32); ``jax.nn.sigmoid``
    stays finite over the whole real line.
    """
    return jax.nn.sigmoid(z)


def log_sigmoid(z):
    """Numerically stable ``log(sigmoid(z)) = -log(1 + exp(-z))``.

    The naive ``z - log(1 + exp(z))`` returns ``-inf`` once ``exp(z)``
    overflows (large positive ``z``); ``jax.nn.log_sigmoid`` does not.
    """
    return jax.nn.log_sigmoid(z)

def logprior(w, alpha=1.0):
    """Unnormalized log-density of the isotropic Gaussian prior on the weights.

    Returns ``-alpha * (w . w) / 2``: a zero-mean normal with precision ``alpha``
    (variance ``1 / alpha``), up to an additive constant.
    """
    scaled_w = alpha * w
    quadratic = scaled_w @ w
    return -(quadratic / 2)
    
def loglikelihood(w, alpha=1.0, features=None, labels=None):
    """Bernoulli log-likelihood of the labels under the logistic regression model.

    Computes ``sum_i y_i * log(sigmoid(a_i)) + (1 - y_i) * log(1 - sigmoid(a_i))``
    with logits ``a = features @ w``. It uses the identity
    ``log(1 - sigmoid(a)) = log_sigmoid(-a)``, so neither term produces ``nan``
    or ``log(0)`` for large ``|a|``.

    Parameters
    ----------
    w : weight vector, one entry per column of the design matrix.
    alpha : unused; kept so the signature matches ``logprior`` / ``logdensity_fn``.
    features : design matrix; defaults to the global ``Phi``.
    labels : 0/1 targets; defaults to the global ``y``.

    Returns
    -------
    Scalar log-likelihood (summed over data points).
    """
    if features is None:
        features = Phi
    if labels is None:
        labels = y
    logits = features @ w  # computed once, reused for both classes
    log_terms = labels * jax.nn.log_sigmoid(logits) + (1 - labels) * jax.nn.log_sigmoid(-logits)
    return log_terms.sum()
    
def logdensity_fn(w, alpha=1.0):
    """Unnormalized log-posterior: log-prior plus log-likelihood at weights ``w``."""
    log_prior = logprior(w, alpha)
    log_lik = loglikelihood(w, alpha)
    return log_prior + log_lik

In [None]:
from sklearn.linear_model import LogisticRegression

## Posterior Sampling

We use `blackjax`'s Random Walk RMH kernel to sample from the posterior distribution.

In [None]:
rng_key, init_key = jax.random.split(rng_key)

# Starting point for the chain: a single draw from N(0.1 * 1, I).
w0 = jax.random.multivariate_normal(init_key, jnp.zeros(M) + 0.1, jnp.eye(M))
# Gaussian random-walk proposal with step scale 0.05 in every dimension.
rmh = blackjax.rmh(
    logdensity_fn,
    blackjax.mcmc.random_walk.normal(0.05 * jnp.ones(M)),
)
initial_state = rmh.init(w0)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run ``kernel`` for ``num_samples`` steps and return every visited state.

    ``kernel(key, state)`` must return ``(new_state, info)``; ``info`` is
    discarded. The result is the stacked states, with a leading axis of length
    ``num_samples``.
    """
    step_keys = jax.random.split(rng_key, num_samples)

    @jax.jit
    def _advance(carry_state, step_key):
        next_state, _info = kernel(step_key, carry_state)
        # Carry the new state forward and also record it as the scan output.
        return next_state, next_state

    _final, trajectory = jax.lax.scan(_advance, initial_state, step_keys)
    return trajectory

# Draw 5 000 RMH samples (burn-in included; it is discarded further below).
rng_key, sample_key = jax.random.split(rng_key)
states = inference_loop(sample_key, rmh.step, initial_state, num_samples=5_000)

Trace plots of the three weights (the red line marks the end of the burn-in period):

In [None]:
burnin = 300

# One trace panel per weight; the red vertical line marks the end of burn-in.
fig, ax = plt.subplots(1, 3, figsize=(12, 2))
for dim in range(len(ax)):
    panel = ax[dim]
    panel.plot(states.position[:, dim])
    panel.set_title(f"$w_{dim}$")
    panel.axvline(x=burnin, c="tab:red")
plt.show()

In [None]:
burnin = 300
# Drop the burn-in portion of the RMH chain, keeping every weight dimension.
chains = states.position[burnin:]
nsamp = chains.shape[0]

# Classic SMC

In [None]:
import jax.numpy as jnp
import numpy as np

from blackjax import adaptive_tempered_smc
from blackjax.smc import resampling, extend_params
from blackjax.smc.inner_kernel_tuning import StateWithParameterOverride
from blackjax.smc.tempered import TemperedSMCState
import jax
from jax import numpy as jnp
from datetime import date
import numpy as np
import pandas as pd
import functools
from jax.scipy.stats import multivariate_normal
from blackjax import additive_step_random_walk, inner_kernel_tuning
from blackjax.mcmc.random_walk import normal
from blackjax.smc.tuning.from_particles import (
    particles_covariance_matrix
)

# One weight per column of the design matrix (intercept + two features).
n_predictors = M


def initial_particles_multivariate_normal(key, n_samples):
    """Draw ``n_samples`` initial SMC particles from the prior N(0, I).

    Tempered SMC starts at lmbda = 0, where the target is the prior alone, so the
    initial particles must be prior draws. ``logprior`` with its default
    ``alpha=1.0`` is the log-density of N(0, I) up to a constant. Drawing from a
    shifted N(0.1, I) would introduce a mismatch that the importance weights
    never correct.

    Returns an array of shape ``(n_samples, n_predictors)``.
    """
    return jax.random.multivariate_normal(
        key, jnp.zeros(n_predictors), jnp.eye(n_predictors), (n_samples,)
    )

In [None]:
n_particles = 20000
key = jax.random.PRNGKey(10)
key, initial_particles_key, iterations_key = jax.random.split(key, num=3)
initial_particles = initial_particles_multivariate_normal(
    initial_particles_key, n_particles
)
# Seed the proposal tuning with the empirical covariance of the initial cloud.
initial_parameter_value = extend_params(
    {"cov": particles_covariance_matrix(initial_particles)}
)


def mcmc_parameter_update_fn(state: TemperedSMCState, info):
    """Re-tune the random-walk proposal after each SMC step.

    Returns the covariance of the current particles scaled by 0.75, wrapped by
    ``extend_params`` under the ``"cov"`` key that ``step_fn`` receives.
    ``info`` is unused.
    """
    shrink = 0.75
    cov = particles_covariance_matrix(state.particles)
    return extend_params({"cov": shrink * cov})

def step_fn(key, state, logdensity, cov):
    """One random-walk RMH step whose Gaussian proposal has covariance ``cov``.

    ``blackjax.mcmc.random_walk.normal(sigma)`` treats a matrix ``sigma`` as a
    scale matrix: the proposal noise is ``sigma @ z`` with ``z ~ N(0, I)``.
    Passing the covariance itself would therefore give a proposal covariance of
    ``cov @ cov.T``. Passing its Cholesky factor ``L`` (``L @ L.T == cov``)
    gives a proposal whose covariance is exactly ``cov``.

    ``cov`` must be symmetric positive definite; otherwise the Cholesky factor
    contains NaNs.
    """
    scale = jnp.linalg.cholesky(cov)
    return blackjax.rmh(logdensity, normal(scale)).step(key, state)


# Adaptive tempered SMC whose RMH proposal is re-tuned from the particle cloud
# after every step by `mcmc_parameter_update_fn`.
kernel_tuned_proposal = inner_kernel_tuning(
    smc_algorithm=adaptive_tempered_smc,
    logprior_fn=logprior,
    loglikelihood_fn=loglikelihood,
    # Inner MCMC kernel, its initialisation and its tuning.
    mcmc_init_fn=blackjax.rmh.init,
    mcmc_step_fn=step_fn,
    mcmc_parameter_update_fn=mcmc_parameter_update_fn,
    initial_parameter_value=initial_parameter_value,
    num_mcmc_steps=5,
    # Resampling scheme and adaptive-tempering ESS target.
    resampling_fn=resampling.systematic,
    target_ess=0.5,
)

from blackjax.smc.base import SMCInfo
def loop(kernel, rng_key, initial_state, max_steps=1000):
    """Iterate an SMC ``kernel`` until the tempering parameter reaches 1.

    Parameters
    ----------
    kernel : callable ``(key, state) -> (state, info)``; ``info`` must expose
        ``log_likelihood_increment``.
    rng_key : PRNG key, split once per step.
    initial_state : state whose ``sampler_state.lmbda`` is the tempering parameter.
    max_steps : length of the buffer that records the per-step increments.

    Returns
    -------
    ``(total_iter, final_state, normalizing_constant)``. ``normalizing_constant``
    has shape ``(max_steps,)``: entry ``i`` holds step ``i``'s
    ``log_likelihood_increment``, and entries ``>= total_iter`` stay 0.

    JAX arrays are immutable, so the buffer must travel through the
    ``while_loop`` carry. Calling ``buffer.at[i].set(...)`` without keeping its
    result records nothing. Steps beyond ``max_steps`` still run, but JAX drops
    out-of-bounds updates, so their increments are not recorded.
    """

    def cond(carry):
        _, state, _, _ = carry
        return state.sampler_state.lmbda < 1

    def body(carry):
        i, state, op_key, log_increments = carry
        op_key, subkey = jax.random.split(op_key, 2)
        state, info = kernel(subkey, state)
        log_increments = log_increments.at[i].set(info.log_likelihood_increment)
        return i + 1, state, op_key, log_increments

    total_iter, final_state, _, normalizing_constant = jax.lax.while_loop(
        cond, body, (0, initial_state, rng_key, jnp.zeros(max_steps))
    )
    return total_iter, final_state, normalizing_constant

In [None]:
total_steps, final_state, normalizing_constant = loop(
    kernel=kernel_tuned_proposal.step,
    rng_key=iterations_key,
    initial_state=kernel_tuned_proposal.init(initial_particles),
)

In [None]:
# Per-step log increments recorded by `loop`. Assuming each entry is the log
# incremental normalizing-constant estimate reported by blackjax, their sum is
# the SMC estimate of the log marginal likelihood. Summing in log space avoids
# the under/overflow of multiplying the exponentiated per-step ratios.
log_increments = normalizing_constant[:total_steps]
np.exp(log_increments), log_increments.sum()

In [None]:
# Final particle cloud of the classic adaptive tempered SMC run (loop stopped once lmbda >= 1).
particles = final_state.sampler_state.particles

In [None]:
# Importance weights of the final particles. NOTE(review): when these are not
# uniform, summaries of `particles` (means, histograms) should be weighted by them.
final_state.sampler_state.weights

In [None]:
# Compare the post-burn-in RMH samples (`chains`) with the final classic-SMC
# particles. `density=True` puts both on the same scale despite their different
# sample sizes, and the SMC histogram uses the particles' importance weights.
smc_weights = np.asarray(final_state.sampler_state.weights)
fig, ax = plt.subplots(1, 3, figsize=(12, 2))
for i, axi in enumerate(ax):
    axi.hist(chains[:, i], density=True, alpha=0.5, label="RMH")
    axi.hist(particles[:, i], weights=smc_weights, density=True, alpha=0.5, label="SMC")
    axi.set_title(f"$w_{i}$")
ax[0].legend()
plt.show()

In [None]:
# Marginals of the final classic-SMC particles, weighted by their importance weights.
smc_weights = np.asarray(final_state.sampler_state.weights)
fig, ax = plt.subplots(1, 3, figsize=(12, 2))
for i, axi in enumerate(ax):
    axi.hist(particles[:, i], weights=smc_weights, density=True)
    axi.set_title(f"$w_{i}$")
plt.show()

In [None]:
# Marginals of the initial particle cloud passed to the SMC `init`.
fig, ax = plt.subplots(1, 3, figsize=(12, 2))
for col, axi in zip(range(3), ax):
    axi.hist(initial_particles[:, col])
    axi.set_title(f"$w_{col}$")
plt.show()

In [None]:
def predict(x, w):
    """Probability of belonging to the first cluster for each row of design matrix ``x``."""
    logits = x @ w
    return sigmoid(logits)
    

In [None]:
# Plug-in prediction at the posterior mean of the SMC particles; the mean is
# importance-weighted by the final particle weights.
smc_posterior_mean = jnp.average(
    particles, axis=0, weights=final_state.sampler_state.weights
)
pred = (predict(Phi, smc_posterior_mean) > 0.5).astype(int)

In [None]:
# Plug-in prediction at the RMH posterior mean, using only the post-burn-in
# samples (`chains`) so the transient start of the chain does not bias the mean.
pred2 = (predict(Phi, np.mean(chains, axis=0)) > 0.5).astype(int)

In [None]:
import sklearn
# Import the submodule explicitly: `import sklearn` alone is not guaranteed to
# load `sklearn.metrics`.
import sklearn.metrics

sklearn.metrics.confusion_matrix(y, pred)

In [None]:
# Import the submodule explicitly: `import sklearn` alone is not guaranteed to
# load `sklearn.metrics`.
import sklearn.metrics

sklearn.metrics.confusion_matrix(y, pred2)

In [None]:
def posterior_predictive_plot(samples, weights=None, step=0.1, margin=0.1, ax=None, title=None):
    """Plot the posterior-mean probability of belonging to the first cluster.

    For each point of a regular grid covering the data, the predicted probability
    ``sigmoid([1, x0, x1] @ w)`` is averaged over the weight samples. The result
    is drawn as filled contours, with the data points overlaid.

    Parameters
    ----------
    samples : array with one weight vector ``[w_0, w_1, w_2]`` per row, e.g.
        post-burn-in MCMC draws or SMC particles.
    weights : optional per-sample weights (e.g. SMC importance weights); by
        default every sample counts equally.
    step : grid spacing.
    margin : padding added around the data range on each axis.
    ax : matplotlib Axes to draw on; defaults to the current Axes.
    title : optional Axes title, useful to tell successive plots apart.
    """
    if ax is None:
        ax = plt.gca()
    xmin, ymin = X.min(axis=0) - margin
    xmax, ymax = X.max(axis=0) + margin
    Xspace = jnp.mgrid[xmin:xmax:step, ymin:ymax:step]
    _, nx, ny = Xspace.shape

    # Prepend a plane of ones so each grid point becomes the feature vector [1, x0, x1].
    Phispace = jnp.concatenate([jnp.ones((1, nx, ny)), Xspace])
    # Probability at every grid point for every sample, then averaged over samples.
    probs = sigmoid(jnp.einsum("mij,sm->sij", Phispace, samples))
    mean_prob = jnp.average(probs, axis=0, weights=weights)

    ax.contourf(*Xspace, mean_prob)
    ax.scatter(*X.T, c=colors)
    ax.set_xlabel(r"$X_0$")
    ax.set_ylabel(r"$X_1$")
    if title is not None:
        ax.set_title(title)

In [None]:
# Posterior predictive map averaged over the post-burn-in RMH samples.
posterior_predictive_plot(chains)

In [None]:
# Same map averaged over the final classic-SMC particles.
# NOTE(review): each particle counts equally here; the final importance weights
# (`final_state.sampler_state.weights`) are not applied.
posterior_predictive_plot(particles)

# Waste-Free SMC

In [None]:
import importlib
from blackjax.smc.waste_free import waste_free_smc

# Same tuned adaptive tempered SMC as above, except that the MCMC phase follows
# the waste-free strategy: `update_strategy=waste_free_smc(n_particles, 10)`
# replaces the fixed number of MCMC steps (hence `num_mcmc_steps=None`).
# NOTE(review): the second argument is presumably the per-chain length p from
# Dau & Chopin's waste-free SMC; confirm against the blackjax docs.
waste_free_smc_kernel = inner_kernel_tuning(
    logprior_fn=logprior,
    loglikelihood_fn=loglikelihood,
    mcmc_step_fn=step_fn,
    mcmc_init_fn=blackjax.rmh.init,
    resampling_fn=resampling.systematic,
    smc_algorithm=adaptive_tempered_smc,
    mcmc_parameter_update_fn=mcmc_parameter_update_fn,
    initial_parameter_value=initial_parameter_value,
    target_ess=0.5,
    num_mcmc_steps=None,
    update_strategy=waste_free_smc(n_particles, 10),
)

In [None]:
(
    total_steps_waste_free,
    final_state_waste_free,
    normalizing_constant_waste_free,
) = loop(
    kernel=waste_free_smc_kernel.step,
    rng_key=iterations_key,
    initial_state=waste_free_smc_kernel.init(initial_particles),
)

In [None]:
# Posterior predictive map averaged over the final waste-free SMC particles (equal weight per particle).
posterior_predictive_plot(final_state_waste_free.sampler_state.particles)

In [None]:
# Final particle cloud of the waste-free SMC run, used in the comparison below.
particles_waste_free = final_state_waste_free.sampler_state.particles

In [None]:

# Side-by-side marginals from the three samplers. `density=True` makes the
# histograms comparable despite different sample sizes, and both SMC clouds are
# weighted by their final importance weights.
smc_weights = np.asarray(final_state.sampler_state.weights)
waste_free_weights = np.asarray(final_state_waste_free.sampler_state.weights)
fig, ax = plt.subplots(1, 3, figsize=(12, 2))
for i, axi in enumerate(ax):
    axi.hist(chains[:, i], density=True, alpha=0.5, label="RMH")
    axi.hist(particles[:, i], weights=smc_weights, density=True, alpha=0.5, label="SMC")
    axi.hist(
        particles_waste_free[:, i],
        weights=waste_free_weights,
        density=True,
        alpha=0.5,
        label="Waste-free SMC",
    )
    axi.set_title(f"$w_{i}$")
ax[0].legend()
plt.show()

In [None]:
 final_state_waste_free.sampler_state

In [None]:
# Full final state of the waste-free run: the SMC state plus, presumably, the
# tuned MCMC parameters kept by the inner-kernel-tuning wrapper.
final_state_waste_free