# Logistic regression and EF
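Adapted from the BlackJAX sampling book examples on [logistic regression](https://blackjax-devs.github.io/sampling-book/models/logistic_regression.html) and [logistic regression with the latent Gaussian sampler](https://blackjax-devs.github.io/sampling-book/models/LogisticRegressionWithLatentGaussianSampler.html).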

## Prepare the data and constants

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from datetime import date
from sklearn.datasets import make_blobs

import blackjax

rnd_state = int(date.today().strftime("%Y%m%d"))
rng_key = jax.random.key(rnd_state)

# Standard deviation of the Gaussian prior on the weights.
# Smaller values mean stronger regularisation, i.e. more importance
# is given to the prior and less to the likelihood (the data).
SIGMA = 2

## MCMC
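# Standard deviation (step size) of the Gaussian proposal in the
# random-walk Metropolis algorithm. Larger steps travel further per
# iteration, but if TAU is too large most proposals are rejected and the
# chain gets "sticky"; if it is too small the chain mixes slowly.
TAU = 0.005
CHAINS = 5_000  # number of MCMC steps of the single chain
BURNIN = 500  # initial steps discarded as burn-in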

## DIMENSIONS
# Datapoints
N = 50
# Dimensions of the problem (dimension of w, not of X)
D = 2 + 1  # need to add the bias term

# initialise the weights
rng_key, init_key = jax.random.split(rng_key)
w0 = jax.random.multivariate_normal(init_key, -jnp.ones(D), jnp.eye(D))

In [None]:
## PREDICTIONS
def construct_prediction_points_mesh(x_train):
    """Prediction points constructed using a mesh grid."""
    xmin, ymin = x_train.min(axis=0) - 0.1
    xmax, ymax = x_train.max(axis=0) + 0.1
    step = 0.05
    return np.mgrid[xmin:xmax:step, ymin:ymax:step]


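def push_predictions_forward(x_predict, weights_posterior):
    """P(y = +1 | x, w) = sigmoid(Phi(x) w) on the grid, for every posterior sample w."""
    _, nx, ny = x_predict.shape
    # prepend a constant feature (the bias) to the two grid coordinates: (D, nx, ny)
    phi_predict = jnp.concatenate([jnp.ones((1, nx, ny)), x_predict])
    # contract the feature axis d: the output has shape (num_samples, nx, ny)
    return jax.nn.sigmoid(jnp.einsum("dxy,sd->sxy", phi_predict, weights_posterior))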


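def plot_prediction_axis(x_predict, x_train, probs_2nd_class, ax):
    ax.set_xlabel(r"$X_0$")
    ax.set_ylabel(r"$X_1$")
    cf = ax.contourf(*x_predict, probs_2nd_class)
    # use the figure that owns `ax` instead of relying on a global `fig`
    cbar = ax.figure.colorbar(cf, ax=ax)
    cbar.set_label("P(y = +1 | x)")
    # `colors` is the global list of per-point colours defined with the data
    ax.scatter(*x_train.T, c=colors)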


def plot_predictions_2nd_class(x_train, weights_posterior, ax):
    """Plot the avg probability to belong to the 2nd cluster at each point in a grid."""
    x_predict = construct_prediction_points_mesh(x_train)
    probs_2nd_class = push_predictions_forward(x_predict, weights_posterior)
    # avg all the MCMC samples (only one chain here for now)
    plot_prediction_axis(x_predict, x_train, probs_2nd_class.mean(axis=0), ax)


def plot_one_prediction_2nd_class(x_train, weights_posterior, sample2plot, ax):
    """Plot P(y = +1 | x) for a single posterior sample of the weights."""
    x_predict = construct_prediction_points_mesh(x_train)
    weights = weights_posterior[sample2plot]
    # push forward only the selected sample (shape (1, D)), then drop the sample axis
    probs_2nd_class = push_predictions_forward(x_predict, weights[None, :])[0]
    plot_prediction_axis(x_predict, x_train, probs_2nd_class, ax)
    ax.set_title(f"Weights {weights}")
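
`push_predictions_forward` relies on `jnp.einsum("dxy,sd->sxy", ...)`, which sums over the feature axis `d` for every sample `s` and grid point `(x, y)`. A small check with random arrays of hypothetical sizes (not the notebook's data) suggests it is the same contraction as `jnp.tensordot` over that axis:

```python
import jax.numpy as jnp
import numpy as np

S, D, nx, ny = 4, 3, 5, 6  # hypothetical sizes: samples, features, grid points
rng = np.random.default_rng(0)
phi = jnp.asarray(rng.normal(size=(D, nx, ny)))  # features on the grid
w = jnp.asarray(rng.normal(size=(S, D)))  # samples of the weights

# out[s, x, y] = sum_d phi[d, x, y] * w[s, d]
out_einsum = jnp.einsum("dxy,sd->sxy", phi, w)
out_tensordot = jnp.tensordot(w, phi, axes=([1], [0]))

print(out_einsum.shape)  # (4, 5, 6)
```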

In [None]:
## DATA
X, y = make_blobs(
    N,
    2,
    centers=((-3, -3), (3, 3)),
    cluster_std=1.5,
    random_state=rnd_state,
)
# map the labels {0, 1} to {-1, +1}, convenient for the likelihood formulation
# (jnp.where also turns the NumPy array into a JAX array)
y = jnp.where(y, y, -1)
colors = ["tab:red" if el > 0 else "tab:blue" for el in y]
plt.scatter(*X.T, edgecolors=colors, c="none")  # transpose to unpack last axis (2 dim)
plt.xlabel(r"$X_0$")
plt.ylabel(r"$X_1$")
plt.show()
assert np.nonzero(y)[0].shape == y.shape  # every label is now -1 or +1 (no zeros left)

## The model

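The labels $y$ are drawn from a Bernoulli distribution with probability $p$
$$y \sim \mathrm{Bern}(p) $$

and $p$ is computed with a logistic regression model, i.e. the logistic sigmoid of a linear function of the features (the link to exponential families is derived below):
$$p = \sigma\left( \Phi(X) w \right) = \frac{1}{1 + \mathrm{exp}\left(-\Phi(X) w\right)}$$
with $w$ being the weights and $\Phi(X)$ a function of the data $X$ (here a column of ones for the bias followed by the two coordinates of $X$).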

**Dimensions.** There is one label for each 2D datapoint.
1. $X$ has dimensions $N \times 2$
2. $\Phi(X)$ has dimensions $N \times D$
3. $w$ has dimensions $D$
4. $p$ has dimensions $N$

with $D = 3$ (two features plus the bias term).
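A quick shape check of the list above, using random data of the same sizes as in this notebook ($N = 50$, $D = 3$). The inputs and weights here are hypothetical placeholders, not the blobs used later:

```python
import jax
import jax.numpy as jnp

N, D = 50, 2 + 1  # same sizes as in this notebook: 2 features + 1 bias
key_x, key_w = jax.random.split(jax.random.key(0))

X = jax.random.normal(key_x, (N, 2))  # raw inputs, N x 2
Phi = jnp.c_[jnp.ones((N, 1)), X]  # features with a bias column, N x D
w = jax.random.normal(key_w, (D,))  # one weight per feature, D
p = jax.nn.sigmoid(Phi @ w)  # one Bernoulli probability per point, N

print(X.shape, Phi.shape, w.shape, p.shape)  # (50, 2) (50, 3) (3,) (50,)
```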

### Exponential family and NN
We want to predict the class $c$ of $X$ using Bayes' theorem, assuming that the class-conditional likelihood belongs to the exponential family.

Assumptions:
1. **Bayes' theorem.** Using Bayes' theorem, we assign a finite amount of mass (probability) to both classes $c_1$ and $c_2$ and write down the posterior distribution for class $c_1$
$$ p(c = c_1 | X) = \frac{p(X | c_1) p(c_1)}{p(X | c_1) p(c_1) + p(X | c_2) p(c_2) }.$$
If we define $a_1 = \mathrm{log}\,\frac{p(X | c_1) p(c_1)}{p(X | c_2) p(c_2)}$ then the posterior becomes
$$ p(c = c_1 | X) = \frac{1}{1 + \mathrm{exp}(-a_1)} = \sigma(a_1)$$
with $\sigma(\cdot)$ the logistic sigmoid function and $a_1$ the logit (also called activation) for class $c_1$.
Note that $a_1$ is a log ratio of the two classes, i.e. the log odds $\mathrm{log}\,\frac{p(c_1 | X)}{p(c_2 | X)}$.

2. **EF.** Assume now that the class-conditional density $p(X | c = c_1)$ belongs to an exponential family (EF), meaning:
$$p(X | c = c_1) = \text{exp}\left[ \Phi(X) w_1 - g(w_1) + h(X) \right]$$
with $\Phi(X)$ the sufficient statistics, $w_1$ the natural parameters, $g(w_1)$ the log-normaliser (log-partition function) and $h(X)$ the base measure, the latter assumed to be zero (if it is the same for both classes it cancels in $a_1$ anyway).

Given 1 and 2, what is the log odds $a_1$?
Recall that $a_1 = \mathrm{log}\,\frac{p(X | c_1) p(c_1)}{p(X | c_2) p(c_2)}$ from assumption 1. Given assumption 2 (EF), $a_1$ becomes
$$ a_1 = \Phi(X) \left( w_1 - w_2 \right ) - g(w_1) + g(w_2) + \text{log}\,p(c_1) - \text{log}\,p(c_2)$$
and thus the input $a_1$ of the logistic function is linear in the sufficient statistics $\Phi(X)$: a **linear classifier**, with weights $w_1 - w_2$ (the difference of the natural parameters) and a bias collecting the remaining terms.
Therefore, the posterior is $\sigma(a_1)$ with $a_1$ linear in the features $\Phi$ of the exponential family we are considering.
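A minimal numerical check of the derivation, assuming a concrete EF: two Gaussian class-conditional densities with identity covariance, hypothetical means `mu1`, `mu2` and a hypothetical prior `prior_c1`. For this family $\Phi(X) = X$, $w_k = \mu_k$ and $g(w_k) = \tfrac{1}{2}\lVert w_k \rVert^2$, while $h(X) = -\tfrac{1}{2}\lVert X \rVert^2 - \log 2\pi$ is not zero but is the same for both classes and cancels in $a_1$:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Hypothetical class means and prior p(c_1); identity covariance for both classes.
mu1 = jnp.array([3.0, 3.0])
mu2 = jnp.array([-3.0, -3.0])
prior_c1 = 0.3
x = jnp.array([[0.5, -0.2], [1.0, 2.0], [-4.0, 0.0]])  # a few test points

# Posterior via Bayes' theorem, computed directly from the class densities.
lik1 = multivariate_normal.pdf(x, mu1, jnp.eye(2))
lik2 = multivariate_normal.pdf(x, mu2, jnp.eye(2))
post_bayes = lik1 * prior_c1 / (lik1 * prior_c1 + lik2 * (1 - prior_c1))


# Same posterior via the EF form: Phi(x) = x, w_k = mu_k, g(w) = |w|^2 / 2.
def g(w):
    return 0.5 * w @ w


a1 = x @ (mu1 - mu2) - g(mu1) + g(mu2) + jnp.log(prior_c1) - jnp.log(1 - prior_c1)
post_ef = jax.nn.sigmoid(a1)

print(jnp.max(jnp.abs(post_bayes - post_ef)))  # close to 0
```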

Note that this is the same idea as the last layer of a neural network trained with a cross-entropy loss! Indeed,
$$ a_1 = \Phi(X) \left( w_1 - w_2 \right ) - g(w_1) + g(w_2) + \text{log}\,p(c_1) - \text{log}\,p(c_2)$$
becomes
$$\Phi(X)\theta + b$$
with $\theta = w_1 - w_2$ and the bias $b = -g(w_1) + g(w_2) + \text{log}\,p(c_1) - \text{log}\,p(c_2)$ collecting the log-normalisers and the class priors.
The big difference is in the choice of $\Phi$: in the EF derivation, $\Phi$ are the sufficient statistics of a properly normalised class-conditional distribution $p(X | c)$. In deep learning this is not required: $\Phi$ is whatever representation the lower layers produce, and nothing forces it to define a normalised distribution.

This also holds for more than two classes, with the logistic function replaced by a softmax and $a_k = \Phi(X) w_k - g(w_k) + \text{log}\, p(c_k)$.
There is one linear classifier for each class.
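For $K$ classes the posterior is $p(c_k | X) = \mathrm{softmax}(a)_k = \exp(a_k) / \sum_j \exp(a_j)$. A short check, with hypothetical activations, that for $K = 2$ the softmax reduces to the logistic sigmoid of $a_1 - a_2$, which is exactly the two-class log odds:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-class activations a_k for 4 data points and K = 2 classes.
a = jnp.array([[2.0, -1.0], [0.0, 0.0], [-3.0, 1.5], [10.0, 9.0]])

probs_softmax = jax.nn.softmax(a, axis=-1)  # multi-class posterior p(c_k | X)
probs_sigmoid = jax.nn.sigmoid(a[:, 0] - a[:, 1])  # two-class form with logit a_1 - a_2

print(probs_softmax[:, 0])
print(probs_sigmoid)  # same values as the first column of the softmax
```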

In [None]:
# design matrix: a column of ones (bias) followed by the two coordinates of X
Phi = jnp.c_[jnp.ones((N, 1)), X]
assert Phi.shape == (N, D)

In [None]:
def logdensity(w, alpha=1.0, sigma=SIGMA):
    """Log-density of the posterior of the model, up to a normalising constant."""
    log_prior = -alpha * w @ w / (2 * sigma**2)  # log Gaussian prior N(0, sigma^2 / alpha I)
    logits = Phi @ w  # activations a_n = Phi_n w (Gaussian under the prior, being linear in w)
    # with y in {-1, +1}: log p(y_n | x_n, w) = log sigmoid(y_n a_n)
    log_likelihood = jax.nn.log_sigmoid(y * logits)

    return log_prior + log_likelihood.sum()
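
`logdensity` uses the $\pm 1$ coding of the labels. Since $1 - \sigma(a) = \sigma(-a)$, the Bernoulli log-likelihood $t \log p + (1 - t) \log(1 - p)$ with 0/1 targets $t$ equals $\log \sigma(y a)$ with $y \in \{-1, +1\}$. A quick numerical check with arbitrary logits:

```python
import jax
import jax.numpy as jnp

logits = jnp.array([-2.0, 0.3, 4.0])
p_pos = jax.nn.sigmoid(logits)  # P(y = +1 | x)

# Bernoulli log-probabilities in the usual 0/1 form ...
log_p_pos = jnp.log(p_pos)  # label +1 (t = 1)
log_p_neg = jnp.log1p(-p_pos)  # label -1 (t = 0)

# ... and the compact +/-1 form used in `logdensity`: log sigmoid(y * a).
log_p_pos_compact = jax.nn.log_sigmoid(+1 * logits)
log_p_neg_compact = jax.nn.log_sigmoid(-1 * logits)
```

Besides being compact, `jax.nn.log_sigmoid` stays finite for large $|a|$, where `jnp.log(1 - jax.nn.sigmoid(a))` would underflow to `-inf`.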

Next, different algorithms are presented to find the hyperplane separating the datapoints into two classes.
Unlike linear regression with a Gaussian likelihood and a Gaussian prior, logistic regression has no closed-form solution (neither for the maximum-likelihood weights nor for the posterior), so iterative methods or approximations are required.

However, it has fewer parameters (TODO from p X in Bishop): TODO.

The following are considered here:
1. Iteratively reweighted least squares (IRLS), i.e. Newton-Raphson (this is what `R`'s `glm` uses)
2. Laplace approximation
3. MacKay's approximation
4. Random-walk Metropolis-Hastings (RMH) MCMC

This should also work for Poisson regression, not only for classification (Bernoulli likelihood).

## Algo 1: Iterative reweighted least squares
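A minimal sketch, assuming we target the MAP of the same posterior as `logdensity` (Gaussian prior with standard deviation `sigma`). Each Newton-Raphson step on the negative log posterior can be rewritten as a weighted least-squares solve with weights $r_n = p_n (1 - p_n)$, hence the name IRLS. The Hessian at the mode is the inverse covariance of the Laplace approximation, and `mackay_predictive` uses it for MacKay's approximation $P(y = +1 | x) \approx \sigma(\kappa(s^2)\,\mu_a)$ with $\kappa(s^2) = (1 + \pi s^2 / 8)^{-1/2}$. The names `irls_map` and `mackay_predictive` are hypothetical, not from this notebook or any library:

```python
import jax
import jax.numpy as jnp


def irls_map(Phi, y, sigma=2.0, num_iters=25):
    """Newton-Raphson (IRLS) for the MAP weights under a N(0, sigma^2 I) prior.

    Phi: (N, D) design matrix, y: (N,) labels in {-1, +1}.
    Returns the MAP weights and the Hessian of the negative log posterior there.
    """
    t = (y + 1) / 2  # map labels {-1, +1} to targets {0, 1}
    D = Phi.shape[1]
    prior_precision = jnp.eye(D) / sigma**2

    def grad_and_hessian(w):
        p = jax.nn.sigmoid(Phi @ w)
        r = p * (1 - p)  # IRLS weights: diagonal of R
        grad = Phi.T @ (p - t) + prior_precision @ w
        hess = Phi.T @ (r[:, None] * Phi) + prior_precision
        return grad, hess

    w = jnp.zeros(D)
    for _ in range(num_iters):
        grad, hess = grad_and_hessian(w)
        w = w - jnp.linalg.solve(hess, grad)  # Newton step
    _, hess = grad_and_hessian(w)
    return w, hess


def mackay_predictive(phi_new, w_map, hess):
    """MacKay's approximation of P(y = +1 | x) under the Laplace posterior.

    phi_new: (M, D) features of the new points.
    """
    cov = jnp.linalg.inv(hess)  # Laplace posterior covariance
    mu_a = phi_new @ w_map  # mean of the activation a = phi^T w
    var_a = jnp.einsum("md,de,me->m", phi_new, cov, phi_new)  # variance of a
    kappa = 1.0 / jnp.sqrt(1.0 + jnp.pi * var_a / 8.0)
    return jax.nn.sigmoid(kappa * mu_a)
```

With this notebook's globals, `w_map, hess = irls_map(Phi, y, sigma=SIGMA)` would give the mode, and since $\kappa \le 1$ the MacKay predictive is always at least as close to 0.5 as the plug-in $\sigma(\Phi(x) w_{\mathrm{MAP}})$.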

## Algo 4: Random-walk Metropolis MCMC
Propose a new position by adding a Gaussian step with covariance $\tau^2 I_D$ to the current position, where $\tau$ is then a "step" size.
The typical length of a proposed move grows with $\tau$ and with the number of dimensions, roughly as $\tau\sqrt{D}$.
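To make the accept/reject mechanics explicit, here is a minimal pure-JAX sketch of random-walk Metropolis, independent of BlackJAX. The helper names `rmh_step` and `run_chain` are hypothetical. Because the Gaussian step is symmetric, the Hastings correction cancels and a proposal is accepted with probability $\min(1, p(w') / p(w))$:

```python
import jax
import jax.numpy as jnp


def rmh_step(rng_key, position, log_prob, logdensity_fn, tau):
    """One random-walk Metropolis step with a N(position, tau^2 I) proposal."""
    key_step, key_accept = jax.random.split(rng_key)
    proposal = position + tau * jax.random.normal(key_step, position.shape)
    proposal_log_prob = logdensity_fn(proposal)
    # Symmetric proposal: accept with probability min(1, p(proposal) / p(position)).
    accept = jnp.log(jax.random.uniform(key_accept)) < proposal_log_prob - log_prob
    position = jnp.where(accept, proposal, position)
    log_prob = jnp.where(accept, proposal_log_prob, log_prob)
    return position, log_prob, accept


def run_chain(rng_key, logdensity_fn, w0, num_steps, tau):
    """Run a single chain with jax.lax.scan; returns positions and accept flags."""

    def one_step(carry, key):
        position, log_prob = carry
        position, log_prob, accept = rmh_step(key, position, log_prob, logdensity_fn, tau)
        return (position, log_prob), (position, accept)

    keys = jax.random.split(rng_key, num_steps)
    init = (w0, logdensity_fn(w0))
    _, (positions, accepts) = jax.lax.scan(one_step, init, keys)
    return positions, accepts
```

With this notebook's globals, `run_chain(sample_key, logdensity, w0, CHAINS, TAU)` would run the same kind of chain; the acceptance rate is `accepts.mean()`.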

In [None]:
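# random walk: new position = current position + Gaussian step with std TAU
# (`blackjax.rmh` would use the output of `normal` as the new position itself)
rmh = blackjax.additive_step_random_walk(
    logdensity, blackjax.mcmc.random_walk.normal(jnp.ones(D) * TAU)
)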
initial_state = rmh.init(w0)

In [None]:
# scan :: (f, (c, [a])) -> (c, [b])
# f :: (c, a) -> (c, b)
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit  # this is not necessary as scan compiles the fn
    def one_step(state, rng_key):  # f :: (c, a) -> (c, b)
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    # (f, (c, [a])) -> (c, [b])
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    # note that b is a PyTree (per-step output) so JAX will stack each leaf
    # across time, which looks like a transpose from “sequence of pytrees”
    # to “pytree of sequences”. So instead of
    # [(state0, info0), (state1, info1), ...] (a “list/array of tuples”) JAX
    # returns (after stacking): ([state0, state1, ...], [info0, info1, ...])

    return states, infos
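
The stacking described in the comments of `inference_loop` can be seen on a toy scan (a running sum, unrelated to the model): every leaf of the per-step output pytree gains a leading time axis.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    carry = carry + x  # running sum
    # per-step output is a pytree: a tuple holding an array and a dict
    return carry, (carry, {"doubled": 2 * carry})


xs = jnp.arange(1.0, 5.0)  # 1, 2, 3, 4
final, (running, extras) = jax.lax.scan(step, jnp.float32(0.0), xs)
# final is the last carry (10); every output leaf is stacked along axis 0:
# running -> 1, 3, 6, 10 and extras["doubled"] -> 2, 6, 12, 20
print(final, running, extras["doubled"])
```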

In [None]:
rng_key, sample_key = jax.random.split(rng_key)
states, infos = inference_loop(sample_key, rmh.step, initial_state, CHAINS)
rmh_weights = states.position[BURNIN:, :]
print(infos.acceptance_rate.mean())
infos.acceptance_rate

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(12, 2))
for i, axi in enumerate(ax):
    # axi.plot(np.c_[np.ones_like(states.position[:, i]) * w0[i], states.position[:, i]])
    axi.plot(states.position[:, i])
    # as we concat first ones and then data, the first weight is the bias
    axi.set_title(f"$w_{i}$" if i else "b")
    axi.axvline(x=BURNIN, c="tab:red")
plt.show()

fig, ax = plt.subplots(1, 1, layout="tight")
plot_predictions_2nd_class(X, rmh_weights, ax)
plt.show()

fig, ax = plt.subplots(1, 1, layout="tight")
sample2plot = 20
plot_one_prediction_2nd_class(X, rmh_weights, sample2plot, ax)
print(states.logdensity[BURNIN + sample2plot])
print(states.logdensity[BURNIN:].mean())
plt.show()
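
`plot_predictions_2nd_class` averages $\sigma(\Phi(x) w_s)$ over the posterior samples $w_s$, a Monte Carlo estimate of the posterior predictive $P(y = +1 | x)$. This is not the same as plugging the posterior mean of the weights into the sigmoid: averaging pulls the probability towards 0.5 where the weights are uncertain. A sketch with hypothetical Gaussian "samples" (not the chain above) for a single feature vector:

```python
import jax
import jax.numpy as jnp

# Hypothetical posterior samples (not from the chain above) and one feature vector.
key = jax.random.key(0)
w_samples = jnp.array([1.0, 2.0, 2.0]) + 1.5 * jax.random.normal(key, (4000, 3))
phi_x = jnp.array([1.0, 0.5, 0.5])  # (bias, x0, x1)

# Monte Carlo posterior predictive: average the probabilities over the samples.
p_predictive = jax.nn.sigmoid(w_samples @ phi_x).mean()
# Plug-in estimate: a single probability from the posterior mean of the weights.
p_plugin = jax.nn.sigmoid(w_samples.mean(axis=0) @ phi_x)

print(p_predictive, p_plugin)  # the predictive is closer to 0.5
```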