# Probabilistic Machine Learning: Lecture 15 - Gaussian Process Classification (Laplace Approximation) with JAX and Plotly

#### Introduction

Welcome to Lecture 15 on Gaussian Process Classification! This notebook builds directly on our previous discussion (Lecture 14) about Logistic Regression and the initial introduction to the Laplace Approximation. While Lecture 14 focused on understanding the problem of non-Gaussian posteriors and finding the Maximum A Posteriori (MAP) estimate using Newton's method, this lecture dives deeper into the *full* Laplace Approximation.

We will explore the Hessian matrix in more detail, understand its role in quantifying uncertainty, and see how to compute predictive probabilities for new data points. All model computations use **JAX** for efficiency (NumPy is used only to generate the synthetic data), and visualizations are interactive using **Plotly**.

#### 1. Recap: Gaussian Process Classification & Non-Gaussian Posterior

As established in Lecture 14, Gaussian Process Classification models binary outputs (e.g., $y \in \{-1, +1\}$) by transforming a latent function $f(x)$ through a sigmoid link function. We place a Gaussian Process (GP) prior over this latent function:

$$p(f) = \mathcal{GP}(f; m, k)$$

where $m$ is the mean function and $k$ is the covariance (kernel) function. The likelihood of an observation $y$ given the latent function value $f_x$ is given by the logistic sigmoid:

$$p(y | f_x) = \sigma(y f_x) = \begin{cases} \sigma(f_x) & \text{if } y = 1 \\ 1 - \sigma(f_x) & \text{if } y = -1 \end{cases}$$

The core challenge arises when we try to compute the posterior distribution $p(f_X | Y)$ over the latent function values at the training points $X$. Because the likelihood function (sigmoid) is non-Gaussian, the resulting posterior is also non-Gaussian, making exact analytical inference intractable. The log posterior is given by:

$$\log p(f_X | Y) = \sum_{i=1}^n \log \sigma(y_i f_{x_i}) - \frac{1}{2}(f_X - m_X)^T K_{XX}^{-1} (f_X - m_X) + \text{const.}$$

The non-elliptical contours of this log posterior (as seen in Lecture 14, Slides 21-23, and Lecture 15, Slides 4-6) visually represent this intractability.

Let's start by setting up our imports and common utility functions.

In [None]:
import jax.numpy as jnp
import jax
import numpy as np  # Still using numpy for data generation
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Set JAX to use 64-bit floats for numerical stability, matching numpy's default
jax.config.update("jax_enable_x64", True)


# --- Utility Functions ---
def sigmoid(f):
    """Logistic sigmoid function."""
    return 1 / (1 + jnp.exp(-f))


def rbf_kernel(X1, X2, length_scale=1.0):
    """Radial Basis Function (RBF) kernel."""
    sqdist = jnp.sum(X1**2, 1)[:, None] + jnp.sum(X2**2, 1) - 2 * jnp.dot(X1, X2.T)
    return jnp.exp(-0.5 * (1 / length_scale**2) * sqdist)


def generate_data(type="separable", n_samples=100):
    """Generates synthetic 2D classification data (using numpy for random generation)."""
    np.random.seed(42)  # Ensure reproducibility for data generation
    if type == "separable":
        mean1 = [-1, 0.5]
        cov1 = [[0.5, 0.2], [0.2, 0.5]]
        data1 = np.random.multivariate_normal(mean1, cov1, n_samples // 2)
        labels1 = np.ones(n_samples // 2) * -1
        mean2 = [1, -0.5]
        cov2 = [[0.5, 0.2], [0.2, 0.5]]
        data2 = np.random.multivariate_normal(mean2, cov2, n_samples // 2)
        labels2 = np.ones(n_samples // 2)
    elif type == "overlapping":
        mean1 = [-0.5, 0.5]
        cov1 = [[1.0, 0.5], [0.5, 1.0]]
        data1 = np.random.multivariate_normal(mean1, cov1, n_samples // 2)
        labels1 = np.ones(n_samples // 2) * -1
        mean2 = [0.5, -0.5]
        cov2 = [[1.0, 0.5], [0.5, 1.0]]
        data2 = np.random.multivariate_normal(mean2, cov2, n_samples // 2)
        labels2 = np.ones(n_samples // 2)
    elif type == "intermingled":
        r1 = np.random.rand(n_samples // 2) * 2
        theta1 = np.random.rand(n_samples // 2) * 2 * np.pi
        data1 = np.array(
            [
                r1 * np.cos(theta1) + np.random.randn(n_samples // 2) * 0.2,
                r1 * np.sin(theta1) + np.random.randn(n_samples // 2) * 0.2,
            ]
        ).T
        labels1 = np.ones(n_samples // 2) * -1
        r2 = np.random.rand(n_samples // 2) * 2 + 1.5
        theta2 = np.random.rand(n_samples // 2) * 2 * np.pi
        data2 = np.array(
            [
                r2 * np.cos(theta2) + np.random.randn(n_samples // 2) * 0.2,
                r2 * np.sin(theta2) + np.random.randn(n_samples // 2) * 0.2,
            ]
        ).T
        labels2 = np.ones(n_samples // 2)
    else:
        raise ValueError(f"Unknown data type: {type!r}")
    X = np.vstack((data1, data2))
    y = np.hstack((labels1, labels2))
    return X, y


def plot_data_plotly(X, y, title="", fig=None, row=None, col=None):
    """Plots 2D classification data using Plotly."""
    if fig is None:
        fig = go.Figure()

    # Convert JAX arrays to NumPy for Plotly
    X_np = np.asarray(X)
    y_np = np.asarray(y)

    fig.add_trace(
        go.Scatter(
            x=X_np[y_np == -1, 0],
            y=X_np[y_np == -1, 1],
            mode="markers",
            marker=dict(color="maroon", symbol="circle"),
            name="Class -1",
            showlegend=True,
        ),
        row=row,
        col=col,
    )

    fig.add_trace(
        go.Scatter(
            x=X_np[y_np == 1, 0],
            y=X_np[y_np == 1, 1],
            mode="markers",
            marker=dict(
                color="skyblue", symbol="circle", line=dict(width=1, color="skyblue")
            ),
            name="Class +1",
            showlegend=True,
        ),
        row=row,
        col=col,
    )

    fig.update_layout(title_text=title, title_x=0.5)
    fig.update_xaxes(title_text="$x_1$", range=[-4, 4], row=row, col=col)
    fig.update_yaxes(
        title_text="$x_2$",
        range=[-4, 4],
        scaleanchor="x",
        scaleratio=1,
        row=row,
        col=col,
    )

    return fig


#### 2. The Laplace Approximation Revisited

To overcome the intractability, we use approximation methods. The **Laplace Approximation** provides a local Gaussian approximation to a non-Gaussian distribution $p(\theta)$. The steps are:

1.  **Find the MAP estimate (mode)**: Identify $\hat{\theta} = \arg \max_{\theta} \log p(\theta)$. At this point, the gradient is zero: $\nabla \log p(\hat{\theta}) = 0$.
2.  **Second-order Taylor Expansion**: Approximate $\log p(\theta)$ around $\hat{\theta}$ using a Taylor series up to the second order:
    $$\log p(\theta) \approx \log p(\hat{\theta}) + \frac{1}{2}(\theta - \hat{\theta})^T \Psi (\theta - \hat{\theta})$$
    where $\Psi = \nabla\nabla^T \log p(\hat{\theta})$ is the Hessian matrix evaluated at $\hat{\theta}$.
3.  **Gaussian Approximation**: Define the Laplace approximation $q(\theta)$ as a Gaussian distribution with mean $\hat{\theta}$ and covariance matrix $-\Psi^{-1}$:
    $$q(\theta) = \mathcal{N}(\theta; \hat{\theta}, -\Psi^{-1})$$

For GP Classification, we apply this to the posterior $p(f_X | Y)$, approximating it as $q(f_X) = \mathcal{N}(f_X; \hat{f}, \hat{\Sigma})$, where $\hat{f}$ is the MAP estimate of $f_X$ and $\hat{\Sigma} = -(\nabla\nabla^T \log p(f_X | Y)|_{f_X=\hat{f}})^{-1}$.

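The Laplace approximation is computationally cheap, since it only needs the mode and the curvature at the mode. It is a natural fit for logistic regression and GP classification because their log posteriors are strictly concave, which implies a single global maximum to expand around. Note that concavity makes the mode unique and easy to find; it does not by itself guarantee that a Gaussian describes the posterior well away from the mode.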
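The three steps can be made concrete on a toy one-dimensional problem. The sketch below is illustrative only: it assumes a scalar parameter $\theta$ with a standard normal prior and two hypothetical logistic observations, and the names `log_p` and `laplace_1d` are not part of the lecture code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical toy posterior: N(0, 1) prior on theta and two logistic
# observations with inputs x = (1.0, -2.0) and labels y = (+1, -1).
x_obs = jnp.array([1.0, -2.0])
y_obs = jnp.array([1.0, -1.0])


def log_p(theta):
    """Unnormalized log posterior of a scalar theta."""
    log_lik = jnp.sum(jax.nn.log_sigmoid(y_obs * x_obs * theta))
    log_prior = -0.5 * theta**2
    return log_lik + log_prior


def laplace_1d(log_p, theta0=0.0, n_iter=20):
    """Laplace approximation: Newton to the mode, then curvature -> variance."""
    grad = jax.grad(log_p)
    hess = jax.grad(grad)
    theta = theta0
    for _ in range(n_iter):  # Step 1: find the mode with Newton's method
        theta = theta - grad(theta) / hess(theta)
    psi = hess(theta)  # Step 2: second derivative Psi at the mode (negative)
    return theta, -1.0 / psi  # Step 3: mean and variance of q(theta)


theta_hat, var_hat = laplace_1d(log_p)
print(f"mode = {float(theta_hat):.4f}, variance = {float(var_hat):.4f}")
```

Because the prior alone has curvature $-1$ and each likelihood term adds more negative curvature, the Laplace variance here is necessarily smaller than the prior variance of 1.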

Let's define the `log_posterior` function that we will maximize.

In [None]:
def log_posterior(f_X, K_XX, y_labels):
    """
    Calculates the log posterior of f_X given observations y_labels.
    Assumes mean_X = 0.
    """
    # log sigma(y f) via jax.nn.log_sigmoid, which avoids overflow in exp(-y f)
    log_likelihood_term = jnp.sum(jax.nn.log_sigmoid(y_labels * f_X))
    log_prior_term = -0.5 * f_X.T @ jnp.linalg.solve(K_XX, f_X)
    return log_likelihood_term + log_prior_term


#### 3. Deriving the Hessian for GP Classification

A key component of the Laplace Approximation is the Hessian matrix. Let's derive it for our GP classification log posterior. Recall the log posterior (assuming $m_X = 0$ for simplicity):

$$L(f_X) = \log p(f_X | Y) = \sum_{i=1}^n \log \sigma(y_i f_{x_i}) - \frac{1}{2}f_X^T K_{XX}^{-1} f_X$$

First, let's re-state the gradient (first derivative) with respect to $f_X$. The $j$-th element of the gradient vector is:

$$\frac{\partial L}{\partial f_{x_j}} = \left( \frac{y_j + 1}{2} - \sigma(f_{x_j}) \right) - (K_{XX}^{-1} f_X)_j$$

Now, we need the Hessian (second derivative matrix) $\mathbf{H} = \nabla\nabla^T L(f_X)$. The elements of the Hessian are $H_{jk} = \frac{\partial^2 L}{\partial f_{x_j} \partial f_{x_k}}$.

Let's look at the two terms separately:

1.  **Hessian of the log-likelihood term**: $\sum_{i=1}^n \log \sigma(y_i f_{x_i})$
    Each summand depends on a single $f_{x_i}$, so all mixed partial derivatives vanish and this part of the Hessian is diagonal. The second derivative of $\log \sigma(z)$ with respect to $z$ is $-\sigma(z)(1-\sigma(z))$.
    So, for the $j$-th diagonal element, considering $z = y_j f_{x_j}$ and $y_j^2=1$:
    $$\frac{\partial^2 \log \sigma(y_j f_{x_j})}{\partial f_{x_j}^2} = y_j^2 \left( -\sigma(y_j f_{x_j})(1 - \sigma(y_j f_{x_j})) \right) = -\sigma(y_j f_{x_j})(1 - \sigma(y_j f_{x_j}))$$
    Because $\sigma(-z) = 1 - \sigma(z)$, we have $\sigma(y_j f_{x_j})(1 - \sigma(y_j f_{x_j})) = \sigma(f_{x_j})(1 - \sigma(f_{x_j}))$ for both labels, so the curvature does not depend on $y_j$. Let $w_i = \sigma(f_{x_i})(1 - \sigma(f_{x_i}))$. We can represent this part of the Hessian as the diagonal matrix $-W$, where $W = \text{diag}(w_1, ..., w_n)$. Since $0 < \sigma(f) < 1$, we have $0 < w_i \le 1/4$ (with the maximum at $f_{x_i} = 0$), so every diagonal entry of $-W$ is strictly negative.

2.  **Hessian of the log-prior term**: $-\frac{1}{2}f_X^T K_{XX}^{-1} f_X$
    The second derivative of this quadratic form with respect to $f_X$ is simply $-K_{XX}^{-1}$.

Combining these, the full Hessian matrix is:

$$\mathbf{H} = \nabla\nabla^T \log p(f_X | Y) = -W - K_{XX}^{-1} = -(W + K_{XX}^{-1})$$

Since $W$ is a diagonal matrix with positive entries and $K_{XX}^{-1}$ is positive definite (as $K_{XX}$ is an invertible covariance matrix), the matrix $W + K_{XX}^{-1}$ is positive definite. Therefore, $-(W + K_{XX}^{-1})$ is negative definite, which confirms that the log posterior is a **strictly concave function**. Strict concavity guarantees a unique global maximum (the MAP estimate) and rules out spurious local optima. Plain Newton steps are not formally guaranteed to converge without step-size control, but on this smooth, concave objective they typically converge in a handful of iterations.
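A quick way to gain confidence in the derivation is to compare the analytic gradient and Hessian with JAX autodiff on a small random problem. The sketch below assumes hypothetical inputs, labels, and an arbitrary (non-MAP) latent vector `f`; the helper `rbf` mirrors the notebook's `rbf_kernel`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf(X1, X2, length_scale=1.0):
    """RBF kernel (same form as rbf_kernel in the notebook)."""
    sqdist = jnp.sum(X1**2, 1)[:, None] + jnp.sum(X2**2, 1) - 2 * X1 @ X2.T
    return jnp.exp(-0.5 * sqdist / length_scale**2)


# Small hypothetical problem: 5 random inputs, labels in {-1, +1}
key_x, key_f = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (5, 2))
y = jnp.array([1.0, -1.0, 1.0, 1.0, -1.0])
K = rbf(X, X) + 1e-6 * jnp.eye(5)
f = jax.random.normal(key_f, (5,))  # arbitrary point, not necessarily the mode


def L(f):
    """Log posterior with zero prior mean (up to an additive constant)."""
    return jnp.sum(jax.nn.log_sigmoid(y * f)) - 0.5 * f @ jnp.linalg.solve(K, f)


# Analytic expressions derived above
pi = jax.nn.sigmoid(f)
grad_analytic = (y + 1) / 2 - pi - jnp.linalg.solve(K, f)
W = jnp.diag(pi * (1 - pi))
hess_analytic = -(W + jnp.linalg.inv(K))

grad_err = jnp.max(jnp.abs(jax.grad(L)(f) - grad_analytic))
hess_err = jnp.max(jnp.abs(jax.hessian(L)(f) - hess_analytic))
print(f"max gradient error: {float(grad_err):.2e}")
print(f"max Hessian error:  {float(hess_err):.2e}")

# Negative definiteness: all eigenvalues of the Hessian should be negative
print("largest Hessian eigenvalue:", float(jnp.max(jnp.linalg.eigvalsh(hess_analytic))))
```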

#### 4. Newton Optimization for Mode-Finding (Revisited) with JAX

Given the explicit form of the Hessian, we can now fully implement Newton's method to find the MAP estimate $\hat{f}$. Newton-Raphson is an efficient optimization algorithm that uses both the first and second derivatives of the objective function. The update rule is:

$$f^{new} = f^{old} - \mathbf{H}^{-1} \mathbf{g}$$

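where $\mathbf{g}$ is the gradient and $\mathbf{H}$ is the Hessian. Writing $\tilde{\mathbf{y}} = (\mathbf{y} + 1)/2 \in \{0, 1\}^n$ for the labels recoded to $\{0, 1\}$ and $\mathbf{\pi}(f) = (\sigma(f_{x_1}), \dots, \sigma(f_{x_n}))^T$ for the vector of sigmoid outputs, the gradient from Section 3 reads $\mathbf{g} = (\tilde{\mathbf{y}} - \mathbf{\pi}(f)) - K_{XX}^{-1} f$. Substituting our derived forms: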

$$f^{new} = f^{old} - (-(W + K_{XX}^{-1}))^{-1} \left( (\tilde{\mathbf{y}} - \mathbf{\pi}(f^{old})) - K_{XX}^{-1} f^{old} \right)$$

$$f^{new} = f^{old} + (W + K_{XX}^{-1})^{-1} \left( (\tilde{\mathbf{y}} - \mathbf{\pi}(f^{old})) - K_{XX}^{-1} f^{old} \right)$$
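Multiplying out, $f^{old} = (W + K_{XX}^{-1})^{-1}(W + K_{XX}^{-1}) f^{old}$ cancels the $K_{XX}^{-1} f^{old}$ term, and the update simplifies to $f^{new} = (W + K_{XX}^{-1})^{-1}\left(W f^{old} + \tilde{\mathbf{y}} - \mathbf{\pi}(f^{old})\right)$. Using $(W + K_{XX}^{-1})^{-1} = K_{XX}(I + W K_{XX})^{-1}$, the step can be computed without ever inverting $K_{XX}$. The sketch below implements this analytic iteration under stated assumptions (zero prior mean, labels in $\{-1, +1\}$); the function name `newton_map_analytic` and the toy 1-D dataset are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def newton_map_analytic(K, y, n_iter=50, tol=1e-10):
    """Newton iterations f <- K (I + W K)^{-1} (W f + y_tilde - pi(f)).

    K: (n, n) prior covariance (zero prior mean); y: labels in {-1, +1}.
    """
    n = K.shape[0]
    y_tilde = (y + 1) / 2  # map {-1, +1} -> {0, 1}
    f = jnp.zeros(n)
    for _ in range(n_iter):
        pi = jax.nn.sigmoid(f)
        w = pi * (1 - pi)  # diagonal of W
        b = w * f + y_tilde - pi
        # (W + K^{-1})^{-1} b = K (I + W K)^{-1} b, which avoids inverting K
        a = jnp.linalg.solve(jnp.eye(n) + w[:, None] * K, b)
        f_new = K @ a
        converged = jnp.linalg.norm(f_new - f) < tol
        f = f_new
        if converged:
            break
    return f


# Usage on a small hypothetical 1-D dataset (label +1 for x > 0)
X = jnp.linspace(-3.0, 3.0, 6)[:, None]
y = jnp.where(X[:, 0] > 0, 1.0, -1.0)
K = jnp.exp(-0.5 * (X - X.T) ** 2) + 1e-6 * jnp.eye(6)
f_hat = newton_map_analytic(K, y)

# At the mode the gradient (y_tilde - pi) - K^{-1} f vanishes
grad = (y + 1) / 2 - jax.nn.sigmoid(f_hat) - jnp.linalg.solve(K, f_hat)
print("f_hat:", f_hat)
print("max |gradient| at f_hat:", float(jnp.max(jnp.abs(grad))))
```

The autodiff-based `newton_step` below should reach the same mode; the analytic form is mainly useful when the kernel matrix is large and forming its inverse is undesirable.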

With JAX, we can automatically compute these gradients and Hessians, making our code cleaner and less prone to manual derivation errors. We will use `jax.grad` and `jax.hessian` transformations.

In [None]:
@jax.jit  # JIT compile the optimization step for performance
def newton_step(f_X, K_XX, y_labels):
    """Performs one Newton update step for maximizing the log posterior."""
    # We want to maximize log_posterior, so we minimize -log_posterior
    neg_log_posterior = lambda f: -log_posterior(f, K_XX, y_labels)

    # Compute gradient and Hessian of the *negative* log posterior
    grad_neg_log_post = jax.grad(neg_log_posterior)(f_X)
    hess_neg_log_post = jax.hessian(neg_log_posterior)(f_X)

    # Newton update: f_new = f_old - H_inv * grad
    delta_f = jnp.linalg.solve(hess_neg_log_post, grad_neg_log_post)
    f_X_new = f_X - delta_f
    return f_X_new


def find_map_f_newton_jax(
    X_train, y_train, kernel_func, length_scale, max_iter=100, tol=1e-5
):
    """
    Finds the MAP estimate hat_f using Newton's method with JAX autodiff.
    Returns hat_f and the kernel matrix K_XX.
    """
    n = X_train.shape[0]
    f_X = jnp.zeros(n)  # Initialize f_X with zeros
    K_XX = kernel_func(X_train, X_train, length_scale)
    K_XX += 1e-6 * jnp.eye(n)  # Add jitter for numerical stability
    y_labels_jax = jnp.array(y_train)

    print("Starting Newton's method with JAX...")
    change = float("inf")
    for i in range(max_iter):
        f_X_new = newton_step(f_X, K_XX, y_labels_jax)
        change = float(jnp.linalg.norm(f_X_new - f_X))
        f_X = f_X_new  # keep the latest iterate, including on the final step
        if change < tol:
            print(f"Converged in {i + 1} iterations. Final change: {change:.6f}")
            break
    else:
        print(
            f"Newton's method did not converge after {max_iter} iterations. Final change: {change:.6f}"
        )
    return f_X, K_XX


#### 5. Numerical Stability and Efficient Computations (Advanced)

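Computing $(W + K_{XX}^{-1})^{-1}$ directly, both in the Newton step and for the posterior covariance, can be numerically unstable. The main culprit is the explicit inverse $K_{XX}^{-1}$: kernel matrices such as the RBF Gram matrix are often severely ill-conditioned, which is why a small jitter term is added to $K_{XX}$. Forms that use $W^{-1}$ instead run into trouble when some $w_i$ are tiny, i.e. when sigmoid outputs are very close to 0 or 1. Rasmussen & Williams (2006) propose a more stable approach using the matrix inversion lemma and a re-parameterization involving the matrix $B = I + W^{1/2} K_{XX} W^{1/2}$.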

From the matrix inversion lemma, we have:
$$
(W + K_{XX}^{-1})^{-1} = K_{XX} - K_{XX} W^{1/2} (I + W^{1/2} K_{XX} W^{1/2})^{-1} W^{1/2} K_{XX}
$$
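Let $B = I + W^{1/2} K_{XX} W^{1/2}$. Then the inverse becomes $K_{XX} - K_{XX} W^{1/2} B^{-1} W^{1/2} K_{XX}$. This form is more numerically stable because $B$ is symmetric positive definite with eigenvalues $\ge 1$ and, since $0 < w_i \le 1/4$, at most $1 + \frac{1}{4}\lambda_{\max}(K_{XX})$. $B$ is therefore well conditioned and can be factorized safely with a Cholesky decomposition, without ever forming $K_{XX}^{-1}$.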

While we won't implement this optimized form in detail for this introductory notebook, it's important to be aware that practical implementations often incorporate such numerical considerations for robustness.
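The identity itself is easy to check numerically. The sketch below assumes 20 hypothetical 2-D inputs, an RBF kernel with length scale 0.5, and stand-in sigmoid outputs in place of a real MAP estimate. It compares the direct inverse with the lemma form computed through a Cholesky factor of $B$, and prints the condition numbers of $K_{XX}$ and $B$.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

jax.config.update("jax_enable_x64", True)

n = 20
key_x, key_f = jax.random.split(jax.random.PRNGKey(1))
X = jax.random.normal(key_x, (n, 2))  # hypothetical training inputs
sqdist = jnp.sum(X**2, 1)[:, None] + jnp.sum(X**2, 1) - 2 * X @ X.T
K = jnp.exp(-0.5 * sqdist / 0.5**2) + 1e-6 * jnp.eye(n)  # RBF, length scale 0.5
pi = jax.nn.sigmoid(3.0 * jax.random.normal(key_f, (n,)))  # stand-in sigmoid outputs
w = pi * (1 - pi)  # diagonal of W, each entry in (0, 1/4]
sqrt_w = jnp.sqrt(w)

# Direct form: needs the explicit inverse of K
direct = jnp.linalg.inv(jnp.diag(w) + jnp.linalg.inv(K))

# Lemma form: only B = I + W^{1/2} K W^{1/2} is factorized
B = jnp.eye(n) + sqrt_w[:, None] * K * sqrt_w[None, :]
L_B = jnp.linalg.cholesky(B)  # lower-triangular factor, B = L_B L_B^T
# With V = L_B^{-1} W^{1/2} K we get K W^{1/2} B^{-1} W^{1/2} K = V^T V
V = solve_triangular(L_B, sqrt_w[:, None] * K, lower=True)
lemma = K - V.T @ V

print("cond(K):", float(jnp.linalg.cond(K)))
print("cond(B):", float(jnp.linalg.cond(B)))
print("smallest eigenvalue of B:", float(jnp.min(jnp.linalg.eigvalsh(B))))
print("max |direct - lemma|:", float(jnp.max(jnp.abs(direct - lemma))))
```

The bound $\lambda_{\max}(B) \le 1 + \frac{1}{4}\lambda_{\max}(K_{XX}) \le 1 + \frac{n}{4}\max_{ij} (K_{XX})_{ij}$ means `cond(B)` stays modest no matter how ill-conditioned $K_{XX}$ is.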

#### 6. Computing Predictions with Laplace Approximation (Full Posterior)

Once we have the MAP estimate $\hat{f}$ and the approximate posterior covariance $\hat{\Sigma} = (W + K_{XX}^{-1})^{-1}$ (evaluated at $\hat{f}$), we can make predictions for new test points $x_*$. The approximate posterior distribution for the latent function $f_{x_*}$ at a new point $x_*$ is given by:

$$q(f_{x_*} | y) = \mathcal{N}(f_{x_*}; \mu_{q(f_{x_*}|y)}, \Sigma_{q(f_{x_*}|y)})$$

The mean of this predictive distribution (assuming $m_X = 0$):
$$\mu_{q(f_{x_*}|y)} = k_{x_*X} K_{XX}^{-1} \hat{f}$$

The variance of this predictive distribution is (from Lecture 15, Slide 17):
$$\Sigma_{q(f_{x_*}|y)} = k_{x_*x_*} - k_{x_*X} K_{XX}^{-1} k_{X x_*} + k_{x_*X} K_{XX}^{-1} \hat{\Sigma} K_{XX}^{-1} k_{X x_*}$$

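A more compact form for the variance, derived using the matrix inversion lemma (as shown on Slide 17), is:
$$\Sigma_{q(f_{x_*}|y)} = k_{x_*x_*} - k_{x_*X}(K_{XX} + W^{-1})^{-1}k_{X x_*}$$
This form avoids both $K_{XX}^{-1}$ and $\hat{\Sigma}$, but it requires $W^{-1}$, whose entries become very large when some $w_i$ are close to zero (training points classified with near certainty). Implementations therefore either clamp the $w_i$ from below or switch to the $B$-matrix formulation from Section 5.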
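The two variance expressions should agree for any positive diagonal $W$. The sketch below checks this numerically; the training and test inputs are hypothetical, a random vector stands in for $\hat{f}$, and the helper `rbf` assumes an RBF kernel with unit signal variance, so $k(x_*, x_*) = 1$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf(A, B, length_scale=0.7):
    """RBF kernel matrix between the rows of A and B."""
    sqdist = jnp.sum(A**2, 1)[:, None] + jnp.sum(B**2, 1) - 2 * A @ B.T
    return jnp.exp(-0.5 * sqdist / length_scale**2)


key_x, key_s, key_f = jax.random.split(jax.random.PRNGKey(2), 3)
X = jax.random.normal(key_x, (15, 2))          # hypothetical training inputs
X_star = jax.random.normal(key_s, (4, 2))      # hypothetical test inputs
f_hat = 2.0 * jax.random.normal(key_f, (15,))  # stand-in for the MAP estimate
K = rbf(X, X) + 1e-6 * jnp.eye(15)
k_sX = rbf(X_star, X)                          # k_{x_* X}, shape (4, 15)
k_ss = jnp.ones(4)                             # k(x_*, x_*) = 1 for this kernel

pi = jax.nn.sigmoid(f_hat)
w = pi * (1 - pi)
Sigma_hat = jnp.linalg.inv(jnp.diag(w) + jnp.linalg.inv(K))

# Three-term form: prior variance - explained part + posterior spread of f_X
A = jnp.linalg.solve(K, k_sX.T).T              # rows are (K^{-1} k_{X x_*})^T
var_three_term = (
    k_ss - jnp.sum(k_sX * A, axis=1) + jnp.sum((A @ Sigma_hat) * A, axis=1)
)

# Compact form: k_** - k_*X (K + W^{-1})^{-1} k_X*
C = jnp.linalg.solve(K + jnp.diag(1.0 / w), k_sX.T).T
var_compact = k_ss - jnp.sum(k_sX * C, axis=1)

print("three-term form:", var_three_term)
print("compact form:   ", var_compact)
```

Only the diagonal is computed here (one variance per test point), which is all that is needed for the predictive probabilities.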

This predictive distribution $q(f_{x_*} | y)$ is Gaussian. However, to get the predictive probability for the label, $p(y_*=1 | x_*, Y)$, we need to compute $E_{q(f_{x_*}|y)}[\sigma(f_{x_*})]$. This integral is still analytically intractable. Common approximations include:

1.  **Plug-in Approximation**: $\hat{\pi}_{x_*} = \sigma(\mu_{q(f_{x_*}|y)})$. This is the simplest and what we used for plotting the decision boundary in Lecture 14.
2.  **MacKay's Approximation (1992)**: A more sophisticated approximation that accounts for the variance $s^2 = \Sigma_{q(f_{x_*}|y)}$:
    $$E[\sigma(f)] \approx \sigma\left(\frac{\mu}{\sqrt{1 + \frac{\pi}{8}s^2}}\right)$$
    This approximation is often more accurate than the plug-in method, especially when the variance is large.
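To see how close MacKay's formula gets, one can compare it with a numerical estimate of $E[\sigma(f)]$ for $f \sim \mathcal{N}(\mu, s^2)$. The sketch below uses Gauss-Hermite quadrature (a swapped-in numerical method, not part of the lecture) via NumPy's `hermgauss`; the $(\mu, s^2)$ pairs are arbitrary illustrative values.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Gauss-Hermite nodes/weights for integrals of the form  int exp(-t^2) g(t) dt
nodes, weights = np.polynomial.hermite.hermgauss(64)
nodes, weights = jnp.array(nodes), jnp.array(weights)


def expected_sigmoid_quadrature(mu, s2):
    """E[sigma(f)] for f ~ N(mu, s2), via the substitution f = mu + sqrt(2 s2) t."""
    f = mu[..., None] + jnp.sqrt(2.0 * s2)[..., None] * nodes
    return jnp.sum(weights * jax.nn.sigmoid(f), axis=-1) / jnp.sqrt(jnp.pi)


def expected_sigmoid_mackay(mu, s2):
    """MacKay's approximation sigma(mu / sqrt(1 + pi s2 / 8))."""
    return jax.nn.sigmoid(mu / jnp.sqrt(1.0 + jnp.pi * s2 / 8.0))


mu = jnp.array([0.0, 1.0, 2.0, 2.0, 2.0])
s2 = jnp.array([4.0, 1.0, 0.1, 4.0, 25.0])
print("plug-in   :", jax.nn.sigmoid(mu))
print("MacKay    :", expected_sigmoid_mackay(mu, s2))
print("quadrature:", expected_sigmoid_quadrature(mu, s2))
```

Since MacKay's formula only divides $\mu$ by a positive factor, it never changes the sign of $\mu$: a point stays on the same side of 0.5 as under the plug-in rule, and a larger $s^2$ only moves its probability closer to 0.5.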

Let's implement these prediction steps and visualize the results.

In [None]:
def predict_laplace_gp_classification_jax(
    X_test, X_train, hat_f, K_XX, kernel_func, length_scale
):
    """
    Computes the mean and variance of the latent function at test points
    using the Laplace Approximation, and then predictive probabilities (with JAX).
    Only per-point variances are formed, never the full test covariance matrix.
    """
    # 1. Diagonal of W_hat at the MAP estimate hat_f
    pi_hat_f = sigmoid(hat_f)
    w_diag = pi_hat_f * (1 - pi_hat_f)

    # 2. Kernel quantities for the test points
    K_x_X_test = kernel_func(X_test, X_train, length_scale)  # (n_test, n_train)
    # Prior variances k(x*, x*). Building the full (n_test, n_test) matrix would
    # need about 800 MB in float64 for the 100x100 grid used below.
    k_test_diag = jax.vmap(
        lambda x: kernel_func(x[None, :], x[None, :], length_scale)[0, 0]
    )(X_test)

    # 3. Mean of the approximate predictive latent function q(f_* | y):
    #    k_x*X K_XX^-1 hat_f, using a linear solve instead of an explicit inverse
    mu_latent_pred = K_x_X_test @ jnp.linalg.solve(K_XX, hat_f)

    # 4. Variance of q(f_* | y), using the form from slide 17:
    #    k_x*x* - k_x*X (K_XX + W_hat^-1)^-1 k_X x*
    #    Clamp small w_i so that W_hat^-1 stays finite when pi_hat_f is ~0 or ~1.
    w_diag_safe = jnp.maximum(w_diag, 1e-10)
    V = jnp.linalg.solve(K_XX + jnp.diag(1 / w_diag_safe), K_x_X_test.T)
    # Row-wise dot products give diag(K_x_X_test @ V) without the full product;
    # clip tiny negative values caused by round-off.
    s2_latent_pred = jnp.maximum(
        k_test_diag - jnp.sum(K_x_X_test * V.T, axis=1), 0.0
    )

    # 5. Predictive probabilities (approximations to E[sigma(f_*)])
    # Plug-in approximation
    prob_pred_plugin = sigmoid(mu_latent_pred)

    # MacKay's approximation (1992)
    prob_pred_mackay = sigmoid(
        mu_latent_pred / jnp.sqrt(1 + (jnp.pi / 8) * s2_latent_pred)
    )

    return mu_latent_pred, s2_latent_pred, prob_pred_plugin, prob_pred_mackay


Now, let's run the full GP Classification with Laplace Approximation on an overlapping dataset to better observe uncertainty, and visualize the results using Plotly.

In [None]:
# Generate some overlapping data to demonstrate uncertainty
X_train_np, y_train_np = generate_data(type="overlapping", n_samples=100)

# Convert training data to JAX arrays
X_train_jax = jnp.array(X_train_np)
y_train_jax = jnp.array(y_train_np)

# Find the MAP estimate hat_f
hat_f_jax, K_XX_jax = find_map_f_newton_jax(
    X_train_jax, y_train_jax, rbf_kernel, length_scale=1.0
)

print("\nMAP estimate (hat_f) first 5 elements:")
print(hat_f_jax[:5])

# Prepare a grid of test points for visualization
x1_grid = np.linspace(-4, 4, 100)
x2_grid = np.linspace(-4, 4, 100)
X1_mesh, X2_mesh = np.meshgrid(x1_grid, x2_grid)
X_grid_np = np.vstack([X1_mesh.ravel(), X2_mesh.ravel()]).T
X_grid_jax = jnp.array(X_grid_np)

# Compute predictions using the Laplace Approximation with JAX
(
    mu_latent_grid_jax,
    s2_latent_grid_jax,
    prob_pred_plugin_grid_jax,
    prob_pred_mackay_grid_jax,
) = predict_laplace_gp_classification_jax(
    X_grid_jax, X_train_jax, hat_f_jax, K_XX_jax, rbf_kernel, length_scale=1.0
)

# Convert JAX arrays back to numpy for Plotly
prob_pred_plugin_mesh_np = np.array(prob_pred_plugin_grid_jax).reshape(X1_mesh.shape)
prob_pred_mackay_mesh_np = np.array(prob_pred_mackay_grid_jax).reshape(X1_mesh.shape)
s2_latent_mesh_np = np.array(s2_latent_grid_jax).reshape(X1_mesh.shape)

# Plotting Predictive Probabilities (MacKay's Approximation) with Plotly
fig_prob = go.Figure()
fig_prob.add_trace(
    go.Contour(
        z=prob_pred_mackay_mesh_np,
        x=x1_grid,
        y=x2_grid,
        colorscale="RdBu",
        opacity=0.6,
        colorbar_title="Predicted Probability $P(y=1|x)$",
        line_smoothing=0.8,  # Smooth contours
        contours_coloring="heatmap",  # Color between contours
        zmin=0,
        zmax=1,
    )
)

# Add decision boundary (0.5 probability contour)
fig_prob.add_trace(
    go.Contour(
        z=prob_pred_mackay_mesh_np,
        x=x1_grid,
        y=x2_grid,
        showscale=False,
        contours=dict(
            start=0.5,
            end=0.5,
            size=0,
            coloring="none",  # with "lines", Plotly ignores line_color and uses the colorscale
            showlabels=True,
            labelfont=dict(size=12, color="black"),
        ),
        line_color="black",
        line_width=2,
        name="Decision Boundary ($P(y=1|x)=0.5$)",
    )
)

# Add training data points
fig_prob = plot_data_plotly(X_train_np, y_train_np, fig=fig_prob)

fig_prob.update_layout(
    title_text="GP Classification: Predictive Probabilities (MacKay) with JAX and Plotly",
    title_x=0.5,
    xaxis_title="$x_1$",
    yaxis_title="$x_2$",
    height=600,
    width=800,
    legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01,
        bgcolor="rgba(255,255,255,0.7)",
        bordercolor="black",
        borderwidth=1,
        font=dict(size=14),
    ),
)
fig_prob.show()

# Plotting Predictive Uncertainty (Variance of Latent Function) with Plotly
fig_unc = go.Figure()
fig_unc.add_trace(
    go.Contour(
        z=s2_latent_mesh_np,
        x=x1_grid,
        y=x2_grid,
        colorscale="Viridis",
        opacity=0.7,
        colorbar_title="Predictive Variance $s^2$",
        line_smoothing=0.8,
        contours_coloring="heatmap",
    )
)

# Add training data points
fig_unc = plot_data_plotly(X_train_np, y_train_np, fig=fig_unc)

fig_unc.update_layout(
    title_text="GP Classification: Predictive Uncertainty (Latent Variance) with JAX and Plotly",
    title_x=0.5,
    xaxis_title="$x_1$",
    yaxis_title="$x_2$",
    height=600,
    width=800,
    legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01,
        bgcolor="rgba(255,255,255,0.7)",
        bordercolor="black",
        borderwidth=1,
        font=dict(size=14),
    ),
)
fig_unc.show()


### Interpreting Predictive Uncertainty in GP Classification

The **predictive uncertainty** in Gaussian Process (GP) classification, as visualized by the variance of the latent function ($s^2$) in the contour plot, provides crucial insights into the model's confidence about its predictions at different locations in the input space.

#### What Does Predictive Uncertainty Represent?

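- **High latent uncertainty ($s^2$ is large):**
    - The model is **less sure about the value of the latent function** $f(x)$.
    - This typically occurs in regions where:
        - There are **few or no training data points** (far from observed data); there $s^2$ approaches the prior variance $k(x, x)$.
        - The nearby training points are classified very confidently: a small $w_i$ acts like a large effective noise variance $1/w_i$ in $(K_{XX} + W^{-1})^{-1}$, so such points constrain $f$ only weakly.
- **Low latent uncertainty ($s^2$ is small):**
    - The model is **more sure about the value of** $f(x)$.
    - This is usually seen in densely sampled regions, particularly where nearby training points have $\hat{f} \approx 0$ and therefore the largest possible $w_i = 1/4$.
- **Label uncertainty is a different quantity:** the predicted probability is close to 0.5 wherever the predictive mean $\mu$ is close to 0, i.e. along the decision boundary and where the classes overlap, even if $s^2$ is small there. A large $s^2$ additionally pulls the probability toward 0.5.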

#### What Do We See in the Contour Plot?

- The **contour plot of $s^2$** (predictive variance) is shaped mainly by where the training inputs lie:
    - **Low uncertainty (dark/purple regions in the Viridis colormap)** where training points are dense.
    - **High uncertainty (bright/yellow regions)** **away from the training data** (e.g., corners of the plot), where $s^2$ rises toward the prior variance $k(x, x) = 1$ of the RBF kernel.
- The decision boundary need not stand out as a bright band in this plot: training points near the boundary have $\hat{f} \approx 0$ and hence the largest $w_i$, so they reduce $s^2$ the most. Uncertainty about the *label* along the boundary shows up instead in the probability plot, where $P(y=1|x) \approx 0.5$.

#### Why Is This Important?

- **Model Trust:** High uncertainty means the model "knows what it doesn't know"—it is cautious in ambiguous regions.
- **Decision Making:** In applications, you might avoid making hard decisions in high-uncertainty regions, or seek more data there.
- **Active Learning:** The model can suggest where new data would be most informative (regions of high uncertainty).

#### Summary Table

| Region in Input Space                     | Latent Variance $s^2$         | Predicted $P(y=1 \mid x)$            | Confidence in Label |
|-------------------------------------------|-------------------------------|--------------------------------------|---------------------|
| Near decision boundary / class overlap    | Low to moderate               | Close to 0.5 (because $\mu \approx 0$) | Low               |
| Inside dense, well-separated class cluster | Low to moderate              | Close to 0 or 1                      | High                |
| Far from any training points (corners)    | High (approaches $k(x, x)$)   | Pulled toward 0.5                    | Low                 |

#### Visual Example

- In the probability plot, **the least confident predictions ($P(y=1|x) \approx 0.5$) form a band along the decision boundary**; this is where the model is most unsure about the class label.
- In the variance plot, **uncertainty is lowest where training points are dense** and grows toward the edges of the plot, away from the data.

> **In summary:**  
> The predictive uncertainty tells you where the GP classifier is unsure about its predictions. The contour plot visually highlights these regions, guiding interpretation and further data collection.

### Translating Uncertainty from Latent Space to Label Space in GP Classification

In Gaussian Process (GP) classification, the **latent space** refers to the values of the latent function $f(x)$, which are modeled as a Gaussian process. The **label space** refers to the observed class labels $y \in \{-1, +1\}$ (or $\{0, 1\}$).

#### 1. Predictive Distribution in Latent Space

After applying the Laplace approximation, the predictive distribution for the latent function at a new point $x_*$ is:

$$
q(f_{x_*} | y) = \mathcal{N}(f_{x_*}; \mu_{q(f_{x_*}|y)}, \Sigma_{q(f_{x_*}|y)})
$$

- $\mu_{q(f_{x_*}|y)}$: Mean of the latent function at $x_*$.
- $\Sigma_{q(f_{x_*}|y)}$: Variance (uncertainty) of the latent function at $x_*$.

#### 2. From Latent Function to Label Probability

The class label is determined by passing the latent function through a sigmoid (logistic) function:

$$
p(y_* = 1 \mid x_*, \text{data}) \approx \mathbb{E}_{q(f_{x_*}|y)}[\sigma(f_{x_*})]
$$

where $\sigma(f) = \frac{1}{1 + e^{-f}}$ is the logistic sigmoid.

#### 3. Translating Uncertainty

- **Latent Uncertainty ($s^2$):** The variance $\Sigma_{q(f_{x_*}|y)}$ quantifies how uncertain we are about the value of $f_{x_*}$.
- **Label Uncertainty:** The uncertainty in $f_{x_*}$ translates to uncertainty in the predicted class probability $p(y_* = 1 \mid x_*)$.

However, because the sigmoid is a nonlinear function, the mean of the sigmoid is **not** the sigmoid of the mean:

$$
\mathbb{E}[\sigma(f_{x_*})] \neq \sigma(\mathbb{E}[f_{x_*}])
$$
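A quick Monte Carlo check makes this inequality tangible. The sketch below fixes a hypothetical mean $\mu = 1.5$, draws samples of $f$ for growing variances, and averages $\sigma(f)$. The same standard-normal draws are reused for every variance (common random numbers), so the trend is not obscured by sampling noise.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def mc_expected_sigmoid(key, mu, s2, n_samples=200_000):
    """Monte Carlo estimate of E[sigma(f)] for f ~ N(mu, s2)."""
    eps = jax.random.normal(key, (n_samples,))
    return jnp.mean(jax.nn.sigmoid(mu + jnp.sqrt(s2) * eps))


key = jax.random.PRNGKey(0)
mu = 1.5
print(f"sigma(E[f]) = {float(jax.nn.sigmoid(mu)):.3f}")
for s2 in [0.0, 1.0, 4.0, 16.0]:
    estimate = mc_expected_sigmoid(key, mu, s2)  # same key -> same draws
    print(f"s2 = {s2:5.1f}:  E[sigma(f)] ~ {float(estimate):.3f}")
```

With $s^2 = 0$ the estimate equals $\sigma(\mu)$ exactly; as $s^2$ grows, the average moves toward 0.5 while staying above it, because $\mu > 0$.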

#### 4. Practical Approximations

- **Plug-in Approximation:**  
    $$p(y_* = 1 \mid x_*) \approx \sigma(\mu_{q(f_{x_*}|y)})$$  
    Ignores uncertainty in $f_{x_*}$.

- **MacKay's Approximation:**  
    $$p(y_* = 1 \mid x_*) \approx \sigma\left(\frac{\mu_{q(f_{x_*}|y)}}{\sqrt{1 + \frac{\pi}{8} s^2}}\right)$$  
    This accounts for the variance $s^2 = \Sigma_{q(f_{x_*}|y)}$ and gives a more accurate probability, especially when uncertainty is high.

#### 5. Intuitive Summary

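- **High latent variance ($s^2$):** The predictive probability is "squashed" toward 0.5 relative to the plug-in value $\sigma(\mu)$, signaling more uncertainty about the label.
- **Low latent variance:** The predictive probability stays close to the plug-in value $\sigma(\mu)$; it is near 0 or 1 (high confidence in the label) only if $|\mu|$ is also large.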

#### 6. Visual Example

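- In the contour plots, regions far from the data combine a large $s^2$ with a mean near 0, so their predicted probabilities are close to 0.5, indicating the model is unsure about the class label. Along the decision boundary, probabilities are near 0.5 mainly because $\mu \approx 0$.
- Regions with low $s^2$ and large $|\mu|$ correspond to probabilities near 0 or 1, indicating high confidence.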

---

**In summary:**  
Uncertainty in the latent function $f(x)$ is translated to label uncertainty by integrating the sigmoid over the Gaussian predictive distribution. Approximations like MacKay's formula allow us to efficiently compute this, ensuring that high latent uncertainty leads to label probabilities closer to 0.5.

#### Exercises

**Exercise 1: Compare Plug-in vs. MacKay's Approximation**
Plot the predictive probabilities using both `prob_pred_plugin_mesh_np` and `prob_pred_mackay_mesh_np` side-by-side using Plotly subplots. What differences do you observe, especially in regions of high uncertainty or near the decision boundary? Why might MacKay's approximation be preferred?

**Exercise 2: Interpreting Predictive Uncertainty**
Analyze the predictive uncertainty plot (`s2_latent_mesh_np`). Where is the uncertainty highest? How does this relate to the training data distribution and the decision boundary? What does high uncertainty imply in a classification context?

**Exercise 3: Impact of Data Separability on Uncertainty**
Change the `type` argument of `generate_data` to `'separable'` and `'intermingled'`. Rerun the entire notebook. How do the predictive probabilities and uncertainties change for these different data distributions? Discuss the implications for model confidence.

**Exercise 4: Effect of Kernel Length Scale on Predictions**
Experiment with different `length_scale` values (e.g., 0.1, 0.5, 2.0) for the `rbf_kernel`. How does the `length_scale` influence the smoothness of the decision boundary and the distribution of predictive uncertainty? Relate this to the concept of kernel hyperparameters and their role in GP models.

**Exercise 5 (Advanced): Implement Numerical Stability Improvement**
Modify the `predict_laplace_gp_classification_jax` function to compute the latent variances (`s2_latent_pred`) with the numerically more stable form (using the $B$ matrix and matrix inversion lemma) as discussed in Section 5 and Slide 17. Verify that the results are consistent with the current implementation, but note the potential benefits for larger or more ill-conditioned datasets.