In [None]:
# Install tinygp only when it cannot be imported in the current environment.
try:
    import tinygp
except ImportError:
    %pip install -q tinygp

# Same on-demand pattern for numpyro. Any existing jax/jaxlib is uninstalled
# first and then installed again together with numpyro.
# NOTE(review): presumably this keeps jax/jaxlib at versions compatible with
# numpyro -- confirm.
try:
    import numpyro
except ImportError:
    %pip uninstall -y jax jaxlib
    %pip install -q numpyro jax jaxlib

# arviz is installed on demand as well (this install is not run with -q).
try:
    import arviz
except ImportError:
    %pip install arviz

(classification)=

# Classification

In this tutorial, we demonstrate a classification task using a GP. For classification, the test predictions are class probabilities.

Instead of lying in a real-valued space, the target values are discrete labels corresponding to the respective classes. In GP classification, we use a "link function" to map the real-valued output of the GP to a probability distribution over the classes. We place a GP prior on the latent function, which is then squashed through the link function.

In [None]:
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np

In [None]:
import numpyro
import numpyro.distributions as dist

We start with binary classification on a simple XOR dataset.  
As the classification is binary, the discrete target values are in $\{0, 1\}$, for Class 1 and Class 2 respectively.

We model the binary class probabilities with a Bernoulli distribution with parameter `p`, the probability of Class 2.  
The link function we use to squash the GP prior into the range $[0, 1]$ for `p` is the sigmoid function.

For binary classification, we build on our model from GP regression and use:

$$ f \sim \mathcal{G P}\left(0, \mathbf{K}_{f}\left(x, x^{\prime}\right)\right) $$

with a sigmoid likelihood

$$ p(y=1 \mid f)=\operatorname{Sigmoid}(f) $$

or

$$ y \sim \text { Bernoulli(Sigmoid(f)) } $$

In [None]:
# Toy XOR data: 200 points drawn from a 2D standard normal. A point's label is
# True exactly when its two coordinates have opposite signs.
X = jax.random.normal(jax.random.PRNGKey(1234), (200, 2))
y = jnp.logical_xor(X[:, 0] > 0, X[:, 1] > 0)

# One RGBA colour per point, looked up in the "Paired" colormap by label.
# `plt.get_cmap` is used because `plt.cm.get_cmap` is deprecated in recent
# Matplotlib releases.
c = plt.get_cmap("Paired")(y)

fig, ax = plt.subplots()

# One scatter call per class so that each class gets its own legend entry.
# The colours are passed explicitly through `c`, so no `cmap` is supplied
# (Matplotlib ignores `cmap` when `c` already holds colours).
for class_value, class_label in ((0, "Class 1"), (1, "Class 2")):
    members = y == class_value
    ax.scatter(
        X[:, 0][members],
        X[:, 1][members],
        s=30,
        c=c[members],
        edgecolors=(0, 0, 0),
        label=class_label,
    )

# Draw the two axes through the origin; they are the true class boundaries.
ax.set_aspect("equal")
ax.axhline(0, color="k")
ax.axvline(0, color="k")
ax.set_xlabel(r"$x_{1}$")
ax.set_ylabel(r"$x_{2}$")

# Shrink the axes to 80% of their width to leave room for the legend on the
# right-hand side.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

_ = ax.set_title("XOR Dataset")

In [None]:
# Regular 100 x 100 evaluation grid over the square [-2, 2] x [-2, 2].
xs = jnp.linspace(-2, 2, num=100)
ys = jnp.linspace(-2, 2, num=100)

# With "ij" indexing the first axis of `xx`/`yy` follows `xs` and the second
# axis follows `ys`.
xx, yy = jnp.meshgrid(xs, ys, indexing="ij")

# Flatten the grid into an (n_points, 2) array of (x1, x2) coordinates and
# label every grid point with the XOR rule on the signs of its coordinates.
true_X = jnp.vstack([xx.reshape(-1), yy.reshape(-1)]).T
true_y = jnp.logical_xor(true_X[:, 0] > 0, true_X[:, 1] > 0)

In [None]:
def sigmoid(x):
    """Logistic link function mapping real values into the unit interval.

    Delegates to ``jax.scipy.special.expit`` instead of spelling out
    ``1 / (1 + exp(-x))``. In the hand-written form ``exp(-x)`` overflows to
    ``inf`` for strongly negative inputs, and differentiating through it
    produces NaN gradients; the library routine keeps both the value and its
    gradient finite.

    Args:
        x: Scalar or array of real values (for example latent GP outputs).

    Returns:
        Array with the same shape as ``x`` holding ``1 / (1 + exp(-x))``.
    """
    # Imported locally so the function stays self-contained within its cell.
    from jax.scipy.special import expit

    return expit(x)

As the likelihood is non-Gaussian we need to use Markov Chain Monte Carlo (MCMC) or Variational Inference (VI) to marginalize numerically.  
This follows from the example in {ref}`likelihoods-mcmc` and {ref}`modeling-numpyro`.

In [None]:
import jax
import jax.numpy as jnp
from flax.linen.initializers import zeros
import numpyro
import numpyro.distributions as dist
from tinygp import kernels, GaussianProcess

# Turn on JAX's `jax_enable_x64` option so that 64-bit types are available.
# NOTE(review): presumably wanted for numerical stability of the GP linear
# algebra below -- confirm.
jax.config.update("jax_enable_x64", True)


def model(x, y=None):
    """NumPyro model: a latent GP squashed by a sigmoid into a Bernoulli.

    Args:
        x: Input locations at which the latent GP is evaluated.
        y: Labels passed as ``obs`` to the ``"obs"`` Bernoulli site. When it
            is not ``None``, the ``loc`` of the GP conditioned on the latent
            draw and evaluated at the module-level grid ``true_X`` is also
            recorded as the deterministic site ``"pred"``.
    """
    # GP hyperparameters, registered as NumPyro `param` sites that start at
    # 0 (mean), 1 (sigma) and 1 (ell).
    # NOTE(review): `param` sites are presumably not sampled by MCMC and stay
    # at these initial values -- confirm that this is intended.
    gp_mean = numpyro.param("mean", jnp.zeros(()))
    amplitude = numpyro.param("sigma", jnp.ones(()))
    length_scale = numpyro.param("ell", jnp.ones(()))

    # Scaled squared-exponential kernel; `diag` adds a small value to the
    # diagonal of the covariance.
    gp = GaussianProcess(
        (amplitude**2) * kernels.ExpSquared(scale=length_scale),
        x,
        diag=1e-5,
        mean=gp_mean,
    )

    # Latent real-valued function values at `x`.
    latent = numpyro.sample("gp_out", gp.numpyro_dist())

    # Observation model: the sigmoid link turns the latent values into the
    # probability of Class 2, which parameterises a Bernoulli likelihood.
    numpyro.sample("obs", dist.Bernoulli(probs=sigmoid(latent)), obs=y)

    if y is None:
        return

    # Latent predictions on the `true_X` grid, given the sampled latent values.
    numpyro.deterministic("pred", gp.condition(latent, true_X).gp.loc)


# Seed for the sampling run.
rng_key = jax.random.PRNGKey(55873)

# NUTS sampler for `model` with a target acceptance probability of 0.8.
nuts_kernel = numpyro.infer.NUTS(model, target_accept_prob=0.8)

# Two chains, each with 1000 warm-up iterations followed by 1000 retained
# samples; the progress bar is disabled.
mcmc_settings = dict(
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=False,
)
mcmc = numpyro.infer.MCMC(nuts_kernel, **mcmc_settings)

In [None]:
%%time
# Run the sampler on the training inputs and labels.
mcmc.run(rng_key, X, y=y)

# Collect the draws; blocking on "pred" makes the cell timing cover the
# whole computation rather than just its dispatch.
samples = mcmc.get_samples()
pred = samples["pred"].block_until_ready()

As MCMC is an iterative method, we need to check the convergence.

In [None]:
import arviz as az

# Convert the sampler output for ArviZ and summarise every posterior variable
# except the grid predictions stored under "pred".
data = az.from_numpyro(mcmc)
summary_vars = [name for name in data.posterior.data_vars if name != "pred"]
az.summary(data, var_names=summary_vars)

In the above diagnostic report, `R-hat` is less than $1.05$ for all parameters, so the sampler has converged and we can proceed with the samples.

Now we look at the train accuracy.  
From the samples, we get `gp_out`, which is the output of the GP regression model at the training points. To convert it to class probabilities, we use the sigmoid as the link function.
Finally we end up with the probabilities of Class 2. To assign classes deterministically, $p > 0.5$ is assigned to Class 2 and $p \leq 0.5$ to Class 1.

In [None]:
# Pointwise percentiles of the latent draws at the training inputs; row 2 of
# `q` is the 50th percentile (the median).
q = np.percentile(samples["gp_out"], [5, 25, 50, 75, 95], axis=0)
median_latent = q[2]

# Squash the median through the link and threshold at 0.5 for hard labels.
y_hat = sigmoid(median_latent) > 0.5

n_correct = (y_hat == y).sum()
print(f"Train Accuracy: {n_correct * 100 / len(y):0.2f}%")

We see that our model did a reasonable job and achieved good accuracy on the training data.  
We now visualize the predictions on the 2D grid points `true_X`. `pred` contains the samples of the GP regression output at those grid points.  

In [None]:
# Pointwise percentiles of the latent predictions on the grid; row 2 of `q`
# is the 50th percentile (the median).
q = np.percentile(samples["pred"], [5, 25, 50, 75, 95], axis=0)
median_grid_latent = q[2]

# Hard labels on the grid: squashed median thresholded at 0.5.
true_y_hat = sigmoid(median_grid_latent) > 0.5

n_grid_correct = (true_y_hat == true_y).sum()
print(f"Test Accuracy: {n_grid_correct * 100 / len(true_y):0.2f}%")

In [None]:
def plot_pred_2d(arr, xx, yy, contour=False, ax=None, title=None):
    """Show a 2D array as an image with a colorbar and an optional contour.

    Args:
        arr: 2D array of values to display.
        xx: 2D grid of first-coordinate values; its min/max set the
            horizontal image extent and it positions the contour.
        yy: 2D grid of second-coordinate values; its min/max set the
            vertical image extent and it positions the contour.
        contour: If True, overlay the 0.5 level line of ``arr`` in black.
            ``arr`` must then have the same shape as ``xx`` and ``yy``.
        ax: Axes to draw on. A new figure and axes are created when ``None``.
        title: Optional title for the axes.
    """
    if ax is None:
        _, ax = plt.subplots()

    # NOTE(review): imshow puts axis 0 of `arr` on the vertical axis, whereas
    # the contour below is positioned by `xx`/`yy`. If the grids vary the
    # first coordinate along axis 0, the image is transposed relative to the
    # contour -- verify and pass `arr.T` to imshow if so.
    image = ax.imshow(
        arr,
        interpolation="nearest",
        extent=(xx.min(), xx.max(), yy.min(), yy.max()),
        aspect="equal",
        origin="lower",
        cmap=plt.cm.PuOr_r,
    )
    if contour:
        # Contour the array that is being displayed. The previous version
        # read the global `q` here, so the line ignored the `arr` argument.
        ax.contour(
            xx,
            yy,
            arr,
            levels=[0.5],
            linewidths=2,
            colors=["k"],
        )

    # Attach a narrow colorbar axes to the right of the image.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)

    ax.get_figure().colorbar(image, cax=cax)
    if title:
        ax.set_title(title)

In [None]:
fig, ax = plt.subplots(ncols=3, figsize=(12, 4))

# One entry per panel, left to right: (flat grid values, title, draw the 0.5
# contour?). The panels show the median latent value, its sigmoid, and the
# thresholded class prediction.
panels = [
    (
        q[2],
        r"$f \sim \mathcal{G P}\left(0, \mathbf{K}_{f}\left(x, x^{\prime}\right)\right)$",
        False,
    ),
    (sigmoid(q[2]), "p(y=1|f) = Sigmoid(f)", True),
    (true_y_hat, "Predictions (y) ~ Bernoulli(p(y=1|f))", False),
]

for panel_ax, (values, panel_title, with_contour) in zip(ax, panels):
    plot_pred_2d(
        values.reshape(xx.shape),
        xx,
        yy,
        contour=with_contour,
        ax=panel_ax,
        title=panel_title,
    )

fig.tight_layout()

Similarly, multi-class classification can be done using one-vs-one or one-vs-rest binary classifiers.