# Probabilistic Machine Learning: Lecture 19 - Uses for Uncertainty in Deep Learning II

#### Introduction

Welcome to Lecture 19 of Probabilistic Machine Learning! Building on our previous discussions, particularly Lecture 17 on Probabilistic Deep Learning, this lecture delves deeper into the practical **uses and implications of uncertainty quantification in deep learning**. We will explore how probabilistic interpretations can address common pitfalls of deterministic deep networks, such as pathological overconfidence, and enable advanced capabilities like continual learning.

This notebook will provide detailed explanations and practical code illustrations using **JAX** for efficient numerical computations and **Plotly** for interactive visualizations, demonstrating the power of integrating probabilistic methods into deep learning.

#### 1. Recap: Deep Networks are GPs

As a quick recap from Lecture 17 (Slide 2), we learned how to turn a trained deep neural network into an approximate Gaussian Process (GP) using Laplace approximations. This four-step process provides a probabilistic interpretation to an otherwise deterministic model:

1.  **Realize that the loss is a negative log-posterior**: The Empirical Risk Minimization (ERM) objective function used to train deep networks can be seen as the negative log-posterior of the network's parameters, given the data.
    $$\mathcal{L}(\theta) = -\log p(\theta|\mathcal{D}) + \text{const.}$$

2.  **Train the deep net as usual to find $\theta_*$**: Standard training procedures (e.g., SGD, Adam) find the Maximum A Posteriori (MAP) estimate of the parameters, $\theta_*$.
    $$\theta_* = \arg \max_{\theta \in \mathbb{R}^D} p(\theta|\mathcal{D})$$

3.  **At $\theta_*$, compute a Laplace approximation of the log-posterior**: The posterior $p(\theta|\mathcal{D})$ is approximated by a Gaussian distribution $\mathcal{N}(\theta; \theta_*, -\Psi^{-1})$, where $\Psi := \nabla\nabla^T \log p(\theta_*|\mathcal{D})$ is the Hessian of the log-posterior at $\theta_*$. Because $\theta_*$ is a maximum, $\Psi$ is negative definite, so $-\Psi^{-1}$ is a valid (positive definite) covariance matrix.

4.  **Linearize $f(x, \theta)$ around $\theta_*$**: The deep network's output is approximated linearly around $\theta_*$:
    $$f(x, \theta) \approx f(x, \theta_*) + J(x, \theta_*) (\theta - \theta_*)$$
    where $J(x, \theta_*)$ is the Jacobian of $f(x, \theta)$ with respect to $\theta$.

Combining these steps, the posterior distribution over the function output $f(\bullet)$ given the data $\mathcal{D}$ is approximated as a Gaussian Process:

$$p(f(\bullet)|\mathcal{D}) \approx \mathcal{GP}(f(\bullet); f(\bullet, \theta_*), -J(\bullet)\Psi^{-1}J(\circ)^T)$$

The mean function is the trained network $f(\bullet, \theta_*)$, and the covariance function is the **Laplace tangent kernel** $-J(\bullet)\Psi^{-1}J(\circ)^T$. This framework allows us to quantify uncertainty in deep learning predictions.
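As a minimal sketch of steps 3 and 4, the snippet below uses a toy model that is linear in its parameters. This is an assumption made purely so that the Laplace approximation is exact and the result can be checked in closed form. The names `f`, `log_posterior`, and `laplace_tangent_kernel`, as well as the data and hyperparameters, are illustrative and not part of the lecture code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Toy data and hyperparameters (hypothetical, for illustration only)
X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, -1.0, 0.5])
noise_var, prior_prec = 0.25, 1.0

def f(x, theta):
    """Toy 'network': linear in theta, so the Laplace approximation is exact."""
    return jnp.dot(x, theta)

def log_posterior(theta):
    """Unnormalized log-posterior: Gaussian likelihood plus Gaussian prior."""
    resid = y - jax.vmap(f, in_axes=(0, None))(X, theta)
    return -0.5 * jnp.sum(resid**2) / noise_var - 0.5 * prior_prec * jnp.sum(theta**2)

# Step 2: MAP estimate (closed form here; a deep net would use SGD or Adam)
A = X.T @ X / noise_var + prior_prec * jnp.eye(2)
theta_star = jnp.linalg.solve(A, X.T @ y / noise_var)

# Step 3: Psi is the Hessian of the log-posterior (negative definite at theta_star)
Psi = jax.hessian(log_posterior)(theta_star)
Sigma = jnp.linalg.inv(-Psi)  # Laplace covariance -Psi^{-1}

# Step 4: the Jacobian of f w.r.t. theta defines the Laplace tangent kernel
def laplace_tangent_kernel(xa, xb):
    Ja = jax.jacobian(f, argnums=1)(xa, theta_star)
    Jb = jax.jacobian(f, argnums=1)(xb, theta_star)
    return Ja @ Sigma @ Jb

x_test = jnp.array([2.0, -1.0])
pred_mean = f(x_test, theta_star)
pred_var = laplace_tangent_kernel(x_test, x_test)
print(f"predictive mean: {pred_mean:.4f}, predictive variance: {pred_var:.4f}")
```

For a real deep network only `f` and the way $\theta_*$ is found would change; the Hessian and Jacobian calls stay the same, although their cost grows quickly with the number of parameters.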

Let's set up the necessary imports and utility functions, including the MLP model and the `find_map_mlp_params_classification` function (the classification counterpart of `find_map_mlp_params` from Lecture 17).

In [None]:
import jax.numpy as jnp
import jax
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from jax.flatten_util import ravel_pytree # Utility to flatten/unflatten JAX pytrees

# Set JAX to use 64-bit floats for numerical stability
jax.config.update("jax_enable_x64", True)

# --- Utility Functions ---
def sigmoid(f):
    """Logistic sigmoid function."""
    return 1 / (1 + jnp.exp(-f))

def relu(x):
    """ReLU activation function."""
    return jnp.maximum(0, x)

def rbf_kernel(X1, X2, length_scale=1.0):
    """Radial Basis Function (RBF) kernel."""
    sqdist = jnp.sum(X1**2, 1)[:, None] + jnp.sum(X2**2, 1) - 2 * jnp.dot(X1, X2.T)
    return jnp.exp(-0.5 * (1/length_scale**2) * sqdist)

def generate_data(type='moons', n_samples=100, noise_std=0.1):
    """Generates synthetic 2D classification data (using numpy for random generation)."""
    np.random.seed(42) # Ensure reproducibility for data generation
    if type == 'moons':
        from sklearn.datasets import make_moons
        X, y = make_moons(n_samples=n_samples, noise=noise_std, random_state=42)
        y = np.where(y == 0, -1, 1) # Convert labels to -1 and 1
    elif type == 'circles':
        from sklearn.datasets import make_circles
        X, y = make_circles(n_samples=n_samples, noise=noise_std, factor=0.5, random_state=42)
        y = np.where(y == 0, -1, 1) # Convert labels to -1 and 1
    elif type == 'separable':
        mean1 = [-1, 0.5]; cov1 = [[0.5, 0.2], [0.2, 0.5]]
        data1 = np.random.multivariate_normal(mean1, cov1, n_samples // 2)
        labels1 = np.ones(n_samples // 2) * -1
        mean2 = [1, -0.5]; cov2 = [[0.5, 0.2], [0.2, 0.5]]
        data2 = np.random.multivariate_normal(mean2, cov2, n_samples // 2)
        labels2 = np.ones(n_samples // 2)
        X = np.vstack((data1, data2)); y = np.hstack((labels1, labels2))
    else:
        raise ValueError("Invalid data type. Choose 'moons', 'circles', or 'separable'.")
    return X, y

def plot_classification_plotly(X, y, Z=None, x_grid=None, y_grid=None, title="", fig=None, row=None, col=None, colorscale='RdBu'):
    """Plots 2D classification data and optional contour predictions using Plotly."""
    if fig is None:
        fig = go.Figure()
    
    X_np = np.asarray(X)
    y_np = np.asarray(y)

    # Plot data points
    fig.add_trace(go.Scatter(
        x=X_np[y_np == -1, 0],
        y=X_np[y_np == -1, 1],
        mode='markers',
        marker=dict(color='maroon', symbol='circle'),
        name='Class -1',
        showlegend=True
    ), row=row, col=col)

    fig.add_trace(go.Scatter(
        x=X_np[y_np == 1, 0],
        y=X_np[y_np == 1, 1],
        mode='markers',
        marker=dict(color='skyblue', symbol='circle', line=dict(width=1, color='skyblue')),
        name='Class +1',
        showlegend=True
    ), row=row, col=col)

    # Plot contour if Z and grid are provided
    if Z is not None and x_grid is not None and y_grid is not None:
        fig.add_trace(go.Contour(
            z=np.asarray(Z),
            x=np.asarray(x_grid),
            y=np.asarray(y_grid),
            colorscale=colorscale,
            opacity=0.6,
            colorbar_title='Predicted Probability $P(y=1|x)$' if colorscale=='RdBu' else 'Uncertainty',
            line_smoothing=0.8,
            contours_coloring='heatmap',
            zmin=0, zmax=1 if colorscale=='RdBu' else None
        ), row=row, col=col)
        if colorscale == 'RdBu': # Add decision boundary for probability plots
            fig.add_trace(go.Contour(
                z=np.asarray(Z),
                x=np.asarray(x_grid),
                y=np.asarray(y_grid),
                showscale=False,
                contours=dict(start=0.5, end=0.5, size=0, coloring='lines', showlabels=True, labelfont=dict(size=12, color='black')),
                line_color='black',
                line_width=2,
                name='Decision Boundary'
            ), row=row, col=col)

    fig.update_layout(title_text=title, title_x=0.5)
    fig.update_xaxes(title_text='$x_1$', range=[-3, 3], row=row, col=col)
    fig.update_yaxes(title_text='$x_2$', range=[-3, 3], scaleanchor="x", scaleratio=1, row=row, col=col)
    
    return fig

# --- Simple MLP Implementation in JAX (from Lecture 17) ---
def init_mlp_params(key, layer_sizes):
    """Initializes parameters for a simple MLP."""
    params = []
    for i in range(len(layer_sizes) - 1):
        key, subkey = jax.random.split(key)
        in_dim = layer_sizes[i]
        out_dim = layer_sizes[i+1]
        # Glorot initialization for weights
        limit = jnp.sqrt(6 / (in_dim + out_dim))
        weights = jax.random.uniform(subkey, (in_dim, out_dim), minval=-limit, maxval=limit)
        biases = jnp.zeros(out_dim)
        params.append({'weights': weights, 'biases': biases})
    return params

def mlp_forward(params, x):
    """Forward pass through the MLP."""
    hidden_layers = params[:-1]
    output_layer = params[-1]

    h = x
    for layer in hidden_layers:
        h = jnp.dot(h, layer['weights']) + layer['biases']
        h = relu(h) # Using ReLU as nonlinearity
    
    # Output layer: For classification, we want logits (raw scores) before sigmoid/softmax
    output = jnp.dot(h, output_layer['weights']) + output_layer['biases']
    return output

# --- Loss and Regularization Functions (adapted for classification) ---
def binary_cross_entropy_loss(logits, targets):
    """Binary Cross-Entropy loss for targets -1 and 1."""
    # For targets -1 and 1, BCE is -log(sigmoid(y*f)) = log(1 + exp(-y*f))
    return jnp.mean(jnp.log(1 + jnp.exp(-targets * logits)))

def l2_regularization(params, lambda_reg):
    """L2 regularization (weight decay)."""
    l2_norm = 0.0
    for layer in params:
        l2_norm += jnp.sum(jnp.square(layer['weights']))
    return 0.5 * lambda_reg * l2_norm

def neg_log_posterior_classification(params, X, y, lambda_reg):
    """
    Negative log-posterior for MLP parameters in classification (ERM objective).
    Assumes Bernoulli likelihood (via sigmoid) and Gaussian prior on weights.
    """
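    # mlp_forward returns shape (N, 1); squeeze to (N,) so that y * logits is elementwise
    # instead of broadcasting to an (N, N) matrix
    logits = mlp_forward(params, X).squeeze(-1)
    # Negative log-likelihood (Binary Cross-Entropy); logaddexp(0, z) = log(1 + exp(z)) without overflow
    neg_log_likelihood = jnp.sum(jnp.logaddexp(0.0, -y * logits)) # Sum over samples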
    # Negative log-prior (Gaussian prior on weights, proportional to L2 reg)
    neg_log_prior = l2_regularization(params, lambda_reg) * X.shape[0] # Scale regularization by N for consistency
    return neg_log_likelihood + neg_log_prior

# --- Newton's Method for MAP estimate (adapted for classification) ---
from functools import partial  # used to mark `unflatten_fn` as a static (non-array) jit argument

@partial(jax.jit, static_argnums=(1,))
def newton_step_mlp_classification(flat_params, unflatten_fn, X, y, lambda_reg):
    """Performs one Newton update step for maximizing the log posterior of MLP params (classification)."""
    # Differentiate with respect to the flat parameter vector, so that the gradient
    # is a (D,) vector and the Hessian a proper (D, D) matrix
    def flat_objective(flat_p):
        return neg_log_posterior_classification(unflatten_fn(flat_p), X, y, lambda_reg)

    flat_grad = jax.grad(flat_objective)(flat_params)
    flat_hess = jax.hessian(flat_objective)(flat_params)

    # Newton update: flat_params_new = flat_params_old - H_inv @ grad
    delta_flat_params = jnp.linalg.solve(flat_hess, flat_grad)
    flat_params_new = flat_params - delta_flat_params
    return flat_params_new

def find_map_mlp_params_classification(key, X_train, y_train, layer_sizes, lambda_reg, max_iter=100, tol=1e-5):
    """
    Finds the MAP estimate (theta_star) for MLP parameters using Newton's method (classification).
    Returns (theta_star, unflatten_fn), where unflatten_fn maps a flat parameter vector back to the MLP pytree.
    """
    initial_params = init_mlp_params(key, layer_sizes)
    flat_params, unflatten_fn = ravel_pytree(initial_params)

    print("Starting Newton's method for MLP MAP estimate (classification)...")
    for i in range(max_iter):
        flat_params_new = newton_step_mlp_classification(flat_params, unflatten_fn, X_train, y_train, lambda_reg)
        change = jnp.linalg.norm(flat_params_new - flat_params)
        if change < tol:
            print(f"Converged in {i+1} iterations. Final change: {change:.6f}")
            break
        flat_params = flat_params_new
    else:
        print(f"Newton's method did not converge after {max_iter} iterations. Final change: {change:.6f}")
    
    theta_star = unflatten_fn(flat_params)
    return theta_star, unflatten_fn

# --- Prediction Functions for Laplace Approx. (adapted for classification) ---
def predict_laplace_deep_gp_classification(X_test, X_train, theta_star, unflatten_fn, lambda_reg, layer_sizes):
    """
    Computes predictive probabilities and uncertainty for a Deep GP (Laplace Approx.).
    Returns (mean_logits, variance_latent, prob_plugin, prob_mackay).
    """
    # 1. Compute the Hessian (Psi) of the negative log-posterior at theta_star
    flat_theta_star, _ = ravel_pytree(theta_star)

    def flat_neg_log_posterior_for_hessian(flat_p):
        p = unflatten_fn(flat_p)
        return neg_log_posterior_classification(p, X_train, y_train_jax, lambda_reg)

    hess_neg_log_post_at_theta_star = jax.hessian(flat_neg_log_posterior_for_hessian)(flat_theta_star)
    Psi = -hess_neg_log_post_at_theta_star
    Sigma_theta_star = jnp.linalg.inv(-Psi) # Covariance of theta_star approx

    # 2. Linearize f(x, theta) around theta_star and compute predictions
    predictive_means_logits = []
    predictive_variances_latent = []

    for x_t in X_test:
        # Compute J(x_t, theta_star)
        def mlp_output_for_jacobian(p_flat):
            p = unflatten_fn(p_flat)
            return mlp_forward(p, x_t.reshape(1, -1))
        
        J_xt = jax.jacobian(mlp_output_for_jacobian)(flat_theta_star)

        # The network output has shape (1, 1), so the Jacobian has shape (1, 1, D).
        # Flatten it to (1, D) so the matrix products below are well defined.
        J_xt = J_xt.reshape(1, -1)

        # Predictive Mean of latent function: E[f(x)] = f(x, theta_star)
        mean_logits = mlp_forward(theta_star, x_t.reshape(1, -1)).flatten()[0]
        
        # Predictive Variance of latent function: J(x) @ Sigma_theta_star @ J(x).T
        variance_latent = (J_xt @ Sigma_theta_star @ J_xt.T).flatten()[0]

        predictive_means_logits.append(mean_logits)
        predictive_variances_latent.append(variance_latent)

    mean_logits = jnp.array(predictive_means_logits)
    variance_latent = jnp.array(predictive_variances_latent)

    # 3. Compute predictive probabilities (E[sigma(f)] approximations)
    # Plug-in approximation
    prob_plugin = sigmoid(mean_logits)

    # MacKay's approximation (1992)
    prob_mackay = sigmoid(mean_logits / jnp.sqrt(1 + (jnp.pi/8) * variance_latent))

    return mean_logits, variance_latent, prob_plugin, prob_mackay


#### 2. The Cost of Uncertainty (Numerical Challenges for Laplace Approximations)

While the Laplace approximation provides a powerful way to quantify uncertainty, it comes with computational challenges (as per Slide 3):

* **Exact Hessian computation is $\mathcal{O}(ND^2)$**: For a dataset of size $N$ and a network with $D$ parameters, computing the full Hessian matrix $\Psi$ (or the Hessian of the negative log-posterior) is computationally intensive.
* **Hessian inversion is $\mathcal{O}(D^3)$**: Inverting the $D \times D$ Hessian matrix is even more demanding, especially for large deep networks where $D$ can be in the millions or billions.

These costs, quadratic in $D$ for computing the Hessian and cubic in $D$ for inverting it, make the exact Laplace approximation impractical for very large models. Therefore, various **approximation ideas** have been developed:

* **Sub-sampling the dataset**: Compute the Hessian on a subset of $M$ data points, reducing complexity to $\mathcal{O}(MD^2)$ with $M \ll N$.
* **Structural approximations to the Hessian**:
    * **Diagonal approximation**: Approximating $\Psi$ as a diagonal matrix, reducing inversion to $\mathcal{O}(D)$.
    * **Last-layer approximation**: Only computing the Hessian for the parameters of the last layer, which is often much smaller ($D_L \ll D$), leading to $\mathcal{O}(D_L^3)$ inversion.
    * **Kronecker-factored approximate curvature (KFAC)**: Approximating $\Psi$ as a block-diagonal matrix where each block is a Kronecker product of smaller matrices. This reduces the inversion complexity to $\mathcal{O}\left(\sum_l (\text{in}_l^3 + \text{out}_l^3)\right)$, where $\text{in}_l$ and $\text{out}_l$ are the input and output dimensions of layer $l$.
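    * **Generalized Gauss-Newton (GGN)**: Approximating the Hessian by $J^T \Lambda J$, where $J$ is the Jacobian of the network output with respect to the parameters and $\Lambda$ is the Hessian of the loss with respect to the network output. Only first-order derivatives of the network are needed, and for convex losses such as cross-entropy the result is positive semi-definite, which makes it more structured and easier to invert.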
* **Approximate eigenvalue decompositions**: Using methods like the Lanczos algorithm (as discussed in Lecture 13) to approximate the Hessian's eigenvalues and eigenvectors, which can be useful for various Hessian-based computations without full inversion.
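To make two of these ideas concrete, the following sketch builds a GGN matrix for a small network with a logistic likelihood and compares its full inverse with a diagonal approximation. The tanh network, the random data, and all names (`logit`, `ggn`, `diag_cov`, `full_cov`) are assumptions chosen for illustration; they are not the lecture's implementation.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

jax.config.update("jax_enable_x64", True)

# Hypothetical toy setup: one-hidden-layer tanh net with a scalar logit output
key_w1, key_w2, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
params = {"W1": jax.random.normal(key_w1, (2, 4)), "W2": jax.random.normal(key_w2, (4,))}
flat_params, unflatten = ravel_pytree(params)
X = jax.random.normal(key_x, (8, 2))
prior_prec = 1.0

def logit(flat_p, x):
    p = unflatten(flat_p)
    return jnp.tanh(x @ p["W1"]) @ p["W2"]

# GGN: sum_i J_i^T * lam_i * J_i + prior precision, where lam_i is the second
# derivative of the logistic loss w.r.t. the logit, sigma(f) * (1 - sigma(f)).
J = jax.vmap(jax.jacobian(logit), in_axes=(None, 0))(flat_params, X)  # shape (N, D)
probs = jax.nn.sigmoid(jax.vmap(logit, in_axes=(None, 0))(flat_params, X))
lam = probs * (1.0 - probs)
ggn = J.T @ (lam[:, None] * J) + prior_prec * jnp.eye(flat_params.size)

# Diagonal approximation: keep only the diagonal, so inversion is elementwise, O(D)
diag_cov = 1.0 / jnp.diag(ggn)

# Full inversion for comparison, O(D^3)
full_cov = jnp.linalg.inv(ggn)

print("smallest GGN eigenvalue:", float(jnp.linalg.eigvalsh(ggn).min()))
print("diagonal approx. variances:", diag_cov[:3])
print("exact marginal variances:  ", jnp.diag(full_cov)[:3])
```

For any positive definite precision matrix, the inverse of a diagonal entry is never larger than the corresponding diagonal entry of the full inverse, so the diagonal approximation tends to understate marginal variances. This is one reason to prefer richer structures such as KFAC when they are affordable.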

Let's illustrate the concept of Hessian computation with a simple example on our MLP.

In [None]:
# --- Illustrating Hessian Computation ---

# Generate a small dataset for illustration
X_toy_np, y_toy_np = generate_data(type='separable', n_samples=10, noise_std=0.05)
X_toy_jax = jnp.array(X_toy_np)
y_toy_jax = jnp.array(y_toy_np)

# Define a very small MLP for easier Hessian visualization
toy_input_dim = 2
toy_hidden_dim = 3 # Small hidden layer
toy_output_dim = 1
toy_layer_sizes = [toy_input_dim, toy_hidden_dim, toy_output_dim]

toy_key = jax.random.PRNGKey(0)
toy_params = init_mlp_params(toy_key, toy_layer_sizes)

# Flatten parameters to compute a single Hessian matrix
flat_toy_params, unflatten_toy_params = ravel_pytree(toy_params)

# Define a flattened version of the negative log posterior for Hessian computation
def flat_neg_log_posterior_toy(flat_p):
    p = unflatten_toy_params(flat_p)
    return neg_log_posterior_classification(p, X_toy_jax, y_toy_jax, lambda_reg=0.01)

# Compute the Hessian using JAX
hessian_toy = jax.hessian(flat_neg_log_posterior_toy)(flat_toy_params)

print("\n--- Hessian Computation Example ---")
print(f"Number of parameters (D): {flat_toy_params.shape[0]}")
print(f"Shape of the Hessian matrix: {hessian_toy.shape}")
print("First 5x5 block of the Hessian:\n", hessian_toy[:5, :5])

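# Inverting the Hessian (for illustration, not for large D).
# jnp.linalg.inv does not raise on a singular matrix; it returns inf/nan entries instead,
# so we check the result explicitly rather than relying on an exception.
hessian_toy_inv = jnp.linalg.inv(hessian_toy)
if bool(jnp.all(jnp.isfinite(hessian_toy_inv))):
    print("\nHessian successfully inverted (first 5x5 block of inverse):\n", hessian_toy_inv[:5, :5])
else:
    print("\nHessian is singular or ill-conditioned and could not be inverted.")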


#### 3. What's Not to Like About Deep Learning? (Pathological Overconfidence)

One of the significant drawbacks of standard, deterministic deep learning models is their tendency for **pathological overconfidence**, especially when making predictions on data far from the training distribution (out-of-distribution data) (as per Slide 5).

A key theorem by Hein et al. (2019) states that for ReLU networks trained with cross-entropy loss, for almost any input $x$ and any $\epsilon > 0$, there exist a scaling factor $\alpha > 0$ and a class $k$ such that the scaled input $z = \alpha x$ is assigned a softmax confidence of at least $1 - \epsilon$ for class $k$. In the limit, the confidence saturates completely:

$$\lim_{\alpha \to \infty} \frac{\exp(f_k(\alpha x))}{\sum_{l=1}^K \exp(f_l(\alpha x))} = 1$$

**Intuition**: ReLU networks are piecewise linear. Far enough from the data along a given direction, every ReLU is permanently on or off, so each output behaves like a single linear function $f_i(x, \theta) = A_i(\theta)x$. As $||x|| \to \infty$, the gap between the largest output and every other output, $f_k(x) - f_j(x)$ with $k = \arg\max_i f_i(x)$, grows linearly with $||x||$. The softmax exponentiates this gap and therefore pushes the probability of the dominant class to 1. This means the model becomes *overconfident* in its predictions, even in regions where it has seen no training data.
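The effect can be checked numerically on a very small example. The sketch below uses hand-picked weights for a three-class ReLU network (the weights, the input `x`, and the helper names are made up for illustration) and scales a fixed input by increasing factors $\alpha$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hand-picked weights for a tiny ReLU network with 3 output classes (illustrative only)
W1 = jnp.array([[1.0, -1.0, 0.5],
                [0.5, 1.0, -1.0]])
b1 = jnp.array([0.1, -0.2, 0.3])
W2 = jnp.array([[1.0, 0.0, -1.0],
                [0.0, 1.0, 0.5],
                [-0.5, 0.2, 1.0]])

def logits(x):
    """One hidden ReLU layer followed by a linear output layer."""
    return jnp.maximum(0.0, x @ W1 + b1) @ W2

def max_confidence(x):
    """Largest softmax probability assigned to any class."""
    return jnp.max(jax.nn.softmax(logits(x)))

x = jnp.array([1.0, 2.0])
alphas = [1.0, 10.0, 100.0]
confidences = [float(max_confidence(a * x)) for a in alphas]
for a, c in zip(alphas, confidences):
    print(f"alpha = {a:6.1f}  max softmax confidence = {c:.6f}")
```

Along this direction two hidden units stay active and one stays off, so the logits become linear in $\alpha$ and the gap between the top two classes grows in proportion to $\alpha$.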

Let's illustrate this overconfidence with a simple 2D classification example.

In [None]:
# --- Illustrating Pathological Overconfidence ---

# Generate 'moons' data for a non-linear decision boundary
X_conf_np, y_conf_np = generate_data(type='moons', n_samples=200, noise_std=0.1)
X_conf_jax = jnp.array(X_conf_np)
y_conf_jax = jnp.array(y_conf_np)

# Define MLP architecture for classification
conf_input_dim = 2
conf_hidden_dim = 50 # A reasonably sized hidden layer
conf_output_dim = 1 # Output logits for binary classification
conf_layer_sizes = [conf_input_dim, conf_hidden_dim, conf_hidden_dim, conf_output_dim]

# Train the MLP (find MAP estimate)
conf_key = jax.random.PRNGKey(20)
lambda_reg_conf = 0.001 # Small regularization to allow for overconfidence
theta_star_conf, unflatten_fn_conf = find_map_mlp_params_classification(
    conf_key, X_conf_jax, y_conf_jax, conf_layer_sizes, lambda_reg_conf, max_iter=2000, tol=1e-6
)

print("\nMLP training complete for overconfidence illustration.")

# Create a grid for predictions
x1_grid = np.linspace(-3, 3, 100)
x2_grid = np.linspace(-3, 3, 100)
X1_mesh, X2_mesh = np.meshgrid(x1_grid, x2_grid)
X_grid_np = np.vstack([X1_mesh.ravel(), X2_mesh.ravel()]).T
X_grid_jax = jnp.array(X_grid_np)

# Get deterministic predictions (logits) from the trained MLP
deterministic_logits = mlp_forward(theta_star_conf, X_grid_jax)
deterministic_probs = sigmoid(deterministic_logits).reshape(X1_mesh.shape)

# Plotting the deterministic model's confidence
fig_overconf = plot_classification_plotly(
    X_conf_np, y_conf_np,
    Z=deterministic_probs,
    x_grid=x1_grid, y_grid=x2_grid,
    title='Deterministic MLP: Pathological Overconfidence (Far from Data)'
)
fig_overconf.update_layout(height=600, width=800)
fig_overconf.show()


As you can see in the plot above, even in regions far from the 'moons' data (e.g., the corners of the plot), the deterministic MLP assigns very high (or very low) probabilities, indicating strong confidence. This is problematic because the model hasn't seen any data in those regions and should ideally express uncertainty.

#### 4. Being Bayesian, Even Just a Bit, Fixes Overconfidence

The good news is that introducing even a small amount of Bayesianism (i.e., quantifying uncertainty) can mitigate this overconfidence (as per Slide 6). If we consider a Gaussian measure on the weights $p(\theta) = \mathcal{N}(\theta; \theta_*, \Sigma)$ (e.g., from the Laplace approximation), the associated push-forward GP provides a more calibrated uncertainty estimate.

Far from the data, where the network behaves linearly $f_i(x, \theta) = A_i(\theta)x$, the predictive GP posterior on $f(x)$ is approximately:

$$p(f(x)) \approx \mathcal{GP}(f(x); f(x, \theta_*), J(x, \theta_*) \Sigma J(x, \theta_*)^T)$$

As $||x|| \to \infty$, the mean grows linearly, $\mathbb{E}(f(x)) = \mathcal{O}(||x||)$, but crucially the **variance grows quadratically**, $\text{var}(f(x)) = \mathcal{O}(||x||^2)$, so the standard deviation grows at the same linear rate as the mean.

When we pass this through the sigmoid using MacKay's approximation:

$$\mathbb{E}(\sigma(f(x))) \approx \sigma\left(\frac{\mathbb{E}(f(x))}{\sqrt{1 + \frac{\pi}{8}\text{var}(f(x))}}\right) \xrightarrow{||x|| \to \infty} \sigma\left(\frac{\mathcal{O}(||x||)}{\sqrt{1 + \mathcal{O}(||x||^2)}}\right)$$

As $||x|| \to \infty$, the linear growth of the mean and of the standard deviation cancel, so the argument of the sigmoid approaches a finite constant $c$ that depends on the direction of $x$. The predictive probability therefore approaches $\sigma(c)$, a value strictly between 0 and 1, rather than saturating at 0 or 1. The confidence far from the data is thus bounded away from certainty, which is a more **calibrated** behavior reflecting the model's increased uncertainty in unobserved regions, although it does not in general reach the maximally uncertain value 0.5.
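A small numerical check of this limit, assuming a model that is exactly linear in its weights, $f(x, \theta) = \theta^T x$, so that $J(x) = x$. The values of `theta_star`, `Sigma`, and `direction` are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Illustrative linear model f(x, theta) = theta^T x with Gaussian weights (assumed values)
theta_star = jnp.array([2.0, -1.0])
Sigma = 0.5 * jnp.eye(2)
direction = jnp.array([1.0, 1.0])

def predictive_probs(alpha):
    """Plug-in and MacKay predictive probabilities at x = alpha * direction."""
    x = alpha * direction
    mean = theta_star @ x   # grows like O(||x||)
    var = x @ Sigma @ x     # grows like O(||x||^2), since J(x) = x here
    plugin = jax.nn.sigmoid(mean)
    mackay = jax.nn.sigmoid(mean / jnp.sqrt(1.0 + (jnp.pi / 8.0) * var))
    return plugin, mackay

for alpha in [1.0, 10.0, 1e3, 1e6]:
    plugin, mackay = predictive_probs(alpha)
    print(f"alpha = {alpha:9.1f}  plug-in = {float(plugin):.6f}  MacKay = {float(mackay):.6f}")

# Limiting value sigma(c), with c = m / sqrt(pi/8 * v) for mean = m * alpha and var = v * alpha^2
m = theta_star @ direction
v = direction @ Sigma @ direction
limit = jax.nn.sigmoid(m / jnp.sqrt((jnp.pi / 8.0) * v))
print(f"limit sigma(c) = {float(limit):.6f}")
```

The plug-in probability saturates at 1, while the MacKay probability levels off at $\sigma(c)$, which here is well above 0.5. This matches the statement that the limit is bounded away from certainty without necessarily being maximally uncertain.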

Let's re-evaluate the predictions using the Laplace approximation for the same MLP and data.

In [None]:
# --- Illustrating Calibrated Confidence with Laplace Approximation ---

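# The prediction helper reads the training labels from the module-level name `y_train_jax`,
# so bind it to this experiment's labels before calling it.
y_train_jax = y_conf_jax

# Compute predictions using the Laplace Approximation (Deep GP)
mean_logits_laplace, variance_latent_laplace, prob_plugin_laplace, prob_mackay_laplace = \
    predict_laplace_deep_gp_classification(X_grid_jax, X_conf_jax, theta_star_conf, unflatten_fn_conf, lambda_reg_conf, conf_layer_sizes)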

# Reshape for plotting
prob_mackay_mesh_np = np.asarray(prob_mackay_laplace).reshape(X1_mesh.shape)
variance_latent_mesh_np = np.asarray(variance_latent_laplace).reshape(X1_mesh.shape)

# Plotting the predictive probabilities (MacKay's approximation)
fig_calibrated_prob = plot_classification_plotly(
    X_conf_np, y_conf_np,
    Z=prob_mackay_mesh_np,
    x_grid=x1_grid, y_grid=x2_grid,
    title='Deep GP (Laplace Approx.): Calibrated Probabilities (MacKay)'
)
fig_calibrated_prob.update_layout(height=600, width=800)
fig_calibrated_prob.show()

# Plotting the predictive uncertainty (latent variance)
fig_calibrated_unc = plot_classification_plotly(
    X_conf_np, y_conf_np,
    Z=variance_latent_mesh_np,
    x_grid=x1_grid, y_grid=x2_grid,
    title='Deep GP (Laplace Approx.): Predictive Uncertainty (Latent Variance)',
    colorscale='Viridis' # Use a different colorscale for uncertainty
)
fig_calibrated_unc.update_layout(height=600, width=800)
fig_calibrated_unc.show()


Compare these plots to the deterministic one. You should observe that the probabilities far from the data are now closer to 0.5 (or the class prior), and the uncertainty plot clearly shows higher variance in regions lacking training data. This demonstrates how a probabilistic approach leads to more honest and calibrated predictions.

#### 5. Finite Networks Have Finite Capacity

While the Laplace approximation helps, finite networks, especially those with ReLU activation functions, inherently have limitations in their capacity to express certain types of functions and uncertainties (as per Slide 8).

To achieve truly **calibrated confidence** (e.g., $\mathbb{E}(\sigma(f(x))) = 1/C$ for $C$ classes) as $||x|| \to \infty$, the variance of the latent function $f(x)$ needs to grow faster than quadratically, specifically $\text{var}(f(x)) = \omega(||x||^2)$. This cannot be achieved with finitely many ReLU features in the network, because beyond a certain point, all ReLUs will either be active or inactive, making the network behave linearly. Thus, $f(x) = Ax$ beyond that point, and its variance will grow quadratically at best.

However, we can draw a connection to Gaussian Processes. From Lecture 9, recall the **Integrated Wiener Process (IWP) kernel**, which has $k_{\text{IWP}}(x,x) = \mathcal{O}(x^3)$. This type of kernel can provide the necessary super-quadratic growth in variance far from the data.

This suggests a role for nonparametric modeling in deep learning: combining finite deep networks with GP priors (like the IWP) can yield models with both the strong representational power of deep learning and the robust uncertainty quantification of GPs, particularly in extrapolation scenarios. This leads to concepts like **ReLU-GPs** (Kristiadi, Hein, Hennig, NeurIPS 2021), where a ReLU network is augmented with a GP prior (e.g., $f(x) = f_{\text{NN}}(x) + \hat{f}(x)$ where $\hat{f}(x) \sim \mathcal{GP}(0, k_{\text{IWP}})$).
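The effect of super-quadratic variance growth can be illustrated with MacKay's approximation in one dimension. The sketch below assumes a network whose mean grows like $x$ and whose Laplace variance grows like $x^2$, and adds a cubic term $x^3/3$, which is the marginal variance of a once-integrated Wiener process with unit diffusion. These growth rates and the helper names are illustrative assumptions, not the ReLU-GP construction from the paper.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def mackay_prob(mean, var):
    """MacKay's approximation to E[sigmoid(f)] for f ~ N(mean, var)."""
    return jax.nn.sigmoid(mean / jnp.sqrt(1.0 + (jnp.pi / 8.0) * var))

def k_iwp_diag(x):
    """Assumed cubic marginal variance, k(x, x) = |x|^3 / 3."""
    return jnp.abs(x) ** 3 / 3.0

results = []
for x in [1.0, 1e2, 1e4, 1e8]:
    mean = x                  # assumed linear asymptote of the network output
    var_finite = x**2         # assumed quadratic variance of the finite ReLU network
    var_extended = var_finite + k_iwp_diag(x)  # with the extra cubic GP term
    p_finite = float(mackay_prob(mean, var_finite))
    p_extended = float(mackay_prob(mean, var_extended))
    results.append((x, p_finite, p_extended))
    print(f"x = {x:12.1f}  finite net: {p_finite:.4f}  with cubic variance: {p_extended:.4f}")
```

With quadratic variance the probability levels off at a value above 0.5, whereas the additional cubic term drives it toward $1/C = 0.5$ for the binary case considered here.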

#### 6. Continual Learning

Another significant application of uncertainty in deep learning is in **continual learning** (also known as lifelong learning or incremental learning). This addresses the challenge of updating a trained model when new data arrives sequentially, and previous datasets are no longer available (as per Slide 14).

Standard deep learning approaches often suffer from **catastrophic forgetting**, where training on new data causes the model to forget previously learned information. Common alternatives like "replay" (storing and re-training on old data) are expensive.

**Probabilistic inference naturally deals with continual learning** because "yesterday's posterior is today's prior" (as per Slide 15):

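$$p(\theta | \mathcal{D}_1, \mathcal{D}_2) = \frac{p(\theta | \mathcal{D}_1) \cdot p(\mathcal{D}_2 | \theta)}{p(\mathcal{D}_2 | \mathcal{D}_1)}$$

Here $\mathcal{D}_2$ is assumed to be independent of $\mathcal{D}_1$ given $\theta$. Taking the negative logarithm, the normalizer does not depend on $\theta$ and is absorbed into a constant:
$$-\log p(\theta | \mathcal{D}_1, \mathcal{D}_2) = -\log p(\mathcal{D}_2 | \theta) - \log p(\theta | \mathcal{D}_1) + \text{const.}$$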

If we use a Laplace approximation for the Task 1 posterior, $p(\theta | \mathcal{D}_1) \approx \mathcal{N}(\theta; \theta_1, -\Psi_1^{-1})$, with $\Psi_1 := \nabla\nabla^T \log p(\theta_1 | \mathcal{D}_1)$ the (negative definite) Hessian of the log-posterior at $\theta_1$, then the new objective function for training on $\mathcal{D}_2$ becomes:

$$-\log p(\theta | \mathcal{D}_1, \mathcal{D}_2) \approx \sum_{i=1}^{N_2} \ell(y_{2,i}, f(x_{2,i}, \theta)) - \frac{1}{2}(\theta - \theta_1)^T \Psi_1 (\theta - \theta_1) + \text{const.}$$

Since $-\Psi_1$ is positive definite, the second term is a nonnegative quadratic penalty. This means the posterior from the previous task acts as a **probabilistic regularizer** for the current task. Instead of a simple L2 regularizer (which assumes an isotropic Gaussian prior centered at zero), the Laplace approximation provides a *full precision matrix* $-\Psi_1$ centered at $\theta_1$, preserving more information about each parameter's uncertainty and importance from the previous task. This helps prevent catastrophic forgetting by penalizing changes to parameters that were well-determined by previous data.
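For a model that is linear in its parameters with Gaussian noise, the Laplace approximation is exact, so training on the second task with the first task's posterior as prior should reproduce joint training on both datasets. The sketch below checks this on toy data. The data, `noise_var`, `prior_prec`, and all function names are illustrative assumptions, and `precision_1` denotes the positive definite curvature of the Task 1 negative log-posterior at `theta_1`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Two toy "tasks" for the model f(x, theta) = theta^T x (illustrative data)
X1 = jnp.array([[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]])
y1 = jnp.array([1.0, 0.5, -1.0])
X2 = jnp.array([[2.0, 1.0], [-1.0, 1.0]])
y2 = jnp.array([1.5, -2.0])
noise_var, prior_prec = 0.25, 1.0

def neg_log_lik(theta, X, y):
    return 0.5 * jnp.sum((y - X @ theta) ** 2) / noise_var

def neg_log_post_task1(theta):
    return neg_log_lik(theta, X1, y1) + 0.5 * prior_prec * theta @ theta

def newton_minimize(objective, theta0, n_steps=1):
    """Newton's method; a single step is exact for the quadratic objectives used here."""
    theta = theta0
    for _ in range(n_steps):
        theta = theta - jnp.linalg.solve(jax.hessian(objective)(theta), jax.grad(objective)(theta))
    return theta

# Task 1: MAP estimate and precision of the (here exact) Laplace posterior
theta_1 = newton_minimize(neg_log_post_task1, jnp.zeros(2))
precision_1 = jax.hessian(neg_log_post_task1)(theta_1)

# Task 2: yesterday's posterior is today's prior
def neg_log_post_task2(theta):
    diff = theta - theta_1
    return neg_log_lik(theta, X2, y2) + 0.5 * diff @ precision_1 @ diff

theta_seq = newton_minimize(neg_log_post_task2, theta_1)

# Reference: train on both datasets at once
def neg_log_post_joint(theta):
    return neg_log_lik(theta, X1, y1) + neg_log_lik(theta, X2, y2) + 0.5 * prior_prec * theta @ theta

theta_joint = newton_minimize(neg_log_post_joint, jnp.zeros(2))
print("sequential:", theta_seq)
print("joint:     ", theta_joint)
```

For a deep network the Task 1 posterior is only approximately Gaussian, so the sequential and joint solutions will generally differ; the quadratic penalty is then an approximation to the information contained in the first dataset.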

Let's simulate a simple continual learning scenario to illustrate this concept.
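The cell below calls `find_map_mlp_params`, `neg_log_posterior`, `mse_loss`, and `plot_regression_plotly`, which are not defined in this notebook and are presumably carried over from the regression setting of Lecture 17. The following is a hedged sketch of what they might look like, reconstructed only from the call sites below. All signatures and bodies are assumptions, and `init_mlp_params` and `mlp_forward` are repeated in compact form only so that the sketch runs on its own.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree

jax.config.update("jax_enable_x64", True)

# NOTE: every definition here is an assumed stand-in inferred from how it is called.

def init_mlp_params(key, layer_sizes):
    """Glorot-uniform weights and zero biases (repeated so the sketch is self-contained)."""
    params = []
    for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = jax.random.split(key)
        limit = jnp.sqrt(6 / (in_dim + out_dim))
        weights = jax.random.uniform(subkey, (in_dim, out_dim), minval=-limit, maxval=limit)
        params.append({'weights': weights, 'biases': jnp.zeros(out_dim)})
    return params

def mlp_forward(params, x):
    """ReLU MLP with a linear output layer."""
    h = x
    for layer in params[:-1]:
        h = jnp.maximum(0, h @ layer['weights'] + layer['biases'])
    return h @ params[-1]['weights'] + params[-1]['biases']

def mse_loss(predictions, targets):
    """Mean squared error between predictions and targets of the same shape."""
    return jnp.mean(jnp.square(predictions - targets))

def neg_log_posterior(params, X, y, lambda_reg, noise_variance):
    """Gaussian likelihood with variance `noise_variance` plus a Gaussian prior on the weights."""
    residuals = y - mlp_forward(params, X)
    neg_log_likelihood = 0.5 * jnp.sum(jnp.square(residuals)) / noise_variance
    neg_log_prior = 0.5 * lambda_reg * sum(jnp.sum(jnp.square(l['weights'])) for l in params)
    return neg_log_likelihood + neg_log_prior

def find_map_mlp_params(key, X_train, y_train, layer_sizes, lambda_reg, noise_variance,
                        max_iter=100, tol=1e-5):
    """Newton iterations on the flat parameter vector; returns (theta_star, unflatten_fn)."""
    flat_params, unflatten_fn = ravel_pytree(init_mlp_params(key, layer_sizes))

    def objective(flat_p):
        return neg_log_posterior(unflatten_fn(flat_p), X_train, y_train, lambda_reg, noise_variance)

    grad_fn, hess_fn = jax.jit(jax.grad(objective)), jax.jit(jax.hessian(objective))
    for _ in range(max_iter):
        step = jnp.linalg.solve(hess_fn(flat_params), grad_fn(flat_params))
        flat_params = flat_params - step
        if jnp.linalg.norm(step) < tol:
            break
    return unflatten_fn(flat_params), unflatten_fn

def plot_regression_plotly(X, y, predictions=None, title=""):
    """Line plot of `y` (and optionally `predictions`) against `X`; returns the Plotly figure."""
    import plotly.graph_objects as go  # imported lazily so the numerical helpers run without Plotly
    x_np = np.asarray(X).ravel()
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x_np, y=np.asarray(y).ravel(), mode='lines', name='Function'))
    if predictions is not None:
        fig.add_trace(go.Scatter(x=x_np, y=np.asarray(predictions).ravel(), mode='lines', name='Prediction'))
    fig.update_layout(title_text=title, title_x=0.5)
    return fig
```

As in the classification helper above, the prior covers the weights only, so the Hessian can become ill-conditioned for larger ReLU networks. A damped Newton step or a gradient-based optimizer may be needed in practice.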

In [None]:
# --- Illustrating Continual Learning with Laplace Approximation ---

# Define a simple 1D regression problem with two tasks (shifted sine waves)
def generate_task_data(offset, n_samples=50, noise_std=0.1):
    X = np.linspace(-3 + offset, 3 + offset, n_samples).reshape(-1, 1)
    y = np.sin(X * 2) + noise_std * np.random.randn(n_samples, 1)
    return jnp.array(X), jnp.array(y)

# Task 1 data
X_task1, y_task1 = generate_task_data(offset=0, n_samples=50, noise_std=0.1)

# Task 2 data (shifted)
X_task2, y_task2 = generate_task_data(offset=2.5, n_samples=50, noise_std=0.1)

# MLP architecture
cl_input_dim = 1
cl_hidden_dim = 20
cl_output_dim = 1
cl_layer_sizes = [cl_input_dim, cl_hidden_dim, cl_hidden_dim, cl_output_dim]

lambda_reg_cl = 0.001 # Base regularization
noise_variance_cl = 0.1**2

# --- Scenario 1: Just Keep Training (Catastrophic Forgetting) ---
print("\n--- Scenario 1: Just Keep Training ---")
key_cl1 = jax.random.PRNGKey(30)

# Train on Task 1
theta_task1_naive, _ = find_map_mlp_params(
    key_cl1, X_task1, y_task1, cl_layer_sizes, lambda_reg_cl, noise_variance_cl, max_iter=1000, tol=1e-7
)

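# Train on Task 2 only (naive approach). Note that this call takes a PRNG key rather than
# initial parameters, so it presumably re-initializes the network instead of warm-starting
# from theta_task1_naive; either way, nothing from Task 1 constrains the result.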
theta_task2_naive, _ = find_map_mlp_params(
    key_cl1, X_task2, y_task2, cl_layer_sizes, lambda_reg_cl, noise_variance_cl, max_iter=1000, tol=1e-7
)

# Evaluate performance on both tasks after training on Task 2
X_combined_test = jnp.linspace(-3, 5.5, 200).reshape(-1, 1)
preds_task1_after_task2_naive = mlp_forward(theta_task2_naive, X_task1)
preds_task2_after_task2_naive = mlp_forward(theta_task2_naive, X_task2)

print(f"MSE on Task 1 after Task 2 (Naive): {mse_loss(preds_task1_after_task2_naive, y_task1):.4f}")
print(f"MSE on Task 2 after Task 2 (Naive): {mse_loss(preds_task2_after_task2_naive, y_task2):.4f}")

fig_naive_cl = plot_regression_plotly(
    X_combined_test,
    mlp_forward(theta_task2_naive, X_combined_test),
    predictions=mlp_forward(theta_task2_naive, X_combined_test),
    title='Continual Learning: Naive Approach (After Task 2 Training)'
)
fig_naive_cl.add_trace(go.Scatter(x=X_task1.flatten(), y=y_task1.flatten(), mode='markers', name='Task 1 Data', marker=dict(color='blue')))
fig_naive_cl.add_trace(go.Scatter(x=X_task2.flatten(), y=y_task2.flatten(), mode='markers', name='Task 2 Data', marker=dict(color='green')))
fig_naive_cl.update_layout(height=600, width=800)
fig_naive_cl.show()

# --- Scenario 2: Laplace Approximation for Continual Learning ---
print("\n--- Scenario 2: Laplace Approximation for Continual Learning ---")
key_cl2 = jax.random.PRNGKey(40)

# Train on Task 1 and get its Laplace posterior
theta_task1_laplace, unflatten_fn_cl = find_map_mlp_params(
    key_cl2, X_task1, y_task1, cl_layer_sizes, lambda_reg_cl, noise_variance_cl, max_iter=1000, tol=1e-7
)

# Compute the Hessian (Psi) from Task 1's posterior
flat_theta_task1_laplace, _ = ravel_pytree(theta_task1_laplace)
def flat_neg_log_posterior_task1(flat_p):
    p = unflatten_fn_cl(flat_p)
    return neg_log_posterior(p, X_task1, y_task1, lambda_reg_cl, noise_variance_cl)
Psi_task1 = -jax.hessian(flat_neg_log_posterior_task1)(flat_theta_task1_laplace)

# Define the new objective for Task 2, using Task 1's Laplace posterior as prior
def neg_log_posterior_task2_cl(params, X_current, y_current, theta_prev_map, Psi_prev):
    logits_current = mlp_forward(params, X_current)
    neg_log_likelihood_current = 0.5 * jnp.sum(jnp.square(y_current - logits_current)) / noise_variance_cl

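    # Laplace prior from previous task.
    # Psi_prev is the Hessian of the log-posterior (negative definite), so -Psi_prev is the
    # precision and the penalty below is a nonnegative quadratic form.
    flat_params_current, _ = ravel_pytree(params)
    flat_theta_prev_map, _ = ravel_pytree(theta_prev_map)
    diff = flat_params_current - flat_theta_prev_map
    neg_log_prior_laplace = -0.5 * diff @ Psi_prev @ diff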

    return neg_log_likelihood_current + neg_log_prior_laplace

# Train on Task 2 with Laplace prior from Task 1
def find_map_mlp_params_cl(key, X_train, y_train, initial_params, theta_prev_map, Psi_prev, max_iter=100, tol=1e-5):
    # `key` is unused here; it is kept so the signature mirrors find_map_mlp_params.
    flat_params, unflatten_fn = ravel_pytree(initial_params)

    # Build the CL objective on flat parameters once, outside the Newton loop.
    def objective(p_flat):
        return neg_log_posterior_task2_cl(unflatten_fn(p_flat), X_train, y_train, theta_prev_map, Psi_prev)
    grad_fn = jax.grad(objective)
    hess_fn = jax.hessian(objective)

    print("Starting Newton's method for CL MLP MAP estimate...")
    change = jnp.inf  # defined up front so the final message works even if max_iter == 0
    for i in range(max_iter):
        # Newton step: solve H @ delta = g, then move against delta.
        grad_val = grad_fn(flat_params)
        hess_val = hess_fn(flat_params)
        delta_flat_params = jnp.linalg.solve(hess_val, grad_val)
        flat_params_new = flat_params - delta_flat_params
        change = jnp.linalg.norm(flat_params_new - flat_params)
        if change < tol:
            print(f"Converged in {i+1} iterations. Final change: {change:.6f}")
            break
        flat_params = flat_params_new
    else:
        print(f"Newton's method did not converge after {max_iter} iterations. Final change: {change:.6f}")
    return unflatten_fn(flat_params)

# Fresh initialization for Task 2. Warm-starting from theta_task1_laplace
# (the mode of the Laplace prior) is an alternative.
initial_params_task2 = init_mlp_params(key_cl2, cl_layer_sizes)
theta_task2_laplace_cl = find_map_mlp_params_cl(
    key_cl2, X_task2, y_task2, initial_params_task2, theta_task1_laplace, Psi_task1, max_iter=1000, tol=1e-7
)

# Evaluate performance on both tasks after training on Task 2 with CL
preds_task1_after_task2_cl = mlp_forward(theta_task2_laplace_cl, X_task1)
preds_task2_after_task2_cl = mlp_forward(theta_task2_laplace_cl, X_task2)

print(f"MSE on Task 1 after Task 2 (Laplace CL): {mse_loss(preds_task1_after_task2_cl, y_task1):.4f}")
print(f"MSE on Task 2 after Task 2 (Laplace CL): {mse_loss(preds_task2_after_task2_cl, y_task2):.4f}")

fig_laplace_cl = plot_regression_plotly(
    X_combined_test,
    mlp_forward(theta_task2_laplace_cl, X_combined_test),
    predictions=mlp_forward(theta_task2_laplace_cl, X_combined_test),
    title='Continual Learning: Laplace Approx. Approach (After Task 2 Training)'
)
fig_laplace_cl.add_trace(go.Scatter(x=X_task1.flatten(), y=y_task1.flatten(), mode='markers', name='Task 1 Data', marker=dict(color='blue')))
fig_laplace_cl.add_trace(go.Scatter(x=X_task2.flatten(), y=y_task2.flatten(), mode='markers', name='Task 2 Data', marker=dict(color='green')))
fig_laplace_cl.update_layout(height=600, width=800)
fig_laplace_cl.show()


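You should observe that the Laplace approximation approach (Scenario 2) retains more knowledge of Task 1 than the naive approach (Scenario 1): its MSE on Task 1 after training on Task 2 should be lower. The quadratic penalty around Task 1's MAP estimate discourages moving parameters in directions that Task 1's data constrained tightly, while leaving weakly constrained directions free to fit Task 2. Using the previous posterior, even an approximate one, as the prior for the next task is what makes this sequential update work.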
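To see why carrying the posterior forward works, it can help to look at a case where the Laplace approximation is exact. The sketch below is a hypothetical toy example, not part of the notebook's MLP code: it uses a two-parameter linear-Gaussian model, for which the negative log-posterior is exactly quadratic. All names (`features`, `newton_step`, `w_seq`, `w_joint`) are assumptions made for illustration, and the precision is written directly as the Hessian of the negative log-posterior rather than via `Psi`. Under these assumptions, training on Task 2 with Task 1's Laplace posterior as the prior should recover the same estimate as training on both tasks jointly.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision keeps the comparison tight

# Hypothetical toy model: y = w0 + w1 * x + Gaussian noise
noise_var, prior_prec = 0.1, 1.0

def features(x):
    return jnp.stack([jnp.ones_like(x), x], axis=1)

def neg_log_lik(w, x, y):
    return 0.5 * jnp.sum((y - features(x) @ w) ** 2) / noise_var

def neg_log_post_task1(w, x, y):
    # Gaussian likelihood plus an isotropic Gaussian prior on w
    return neg_log_lik(w, x, y) + 0.5 * prior_prec * jnp.sum(w ** 2)

def neg_log_post_task2(w, x, y, w_prev, precision_prev):
    # Task 1's Laplace posterior N(w_prev, precision_prev^{-1}) acts as the prior
    diff = w - w_prev
    return neg_log_lik(w, x, y) + 0.5 * diff @ precision_prev @ diff

def newton_step(objective, w, *args):
    # For a quadratic objective, a single Newton step lands on the minimizer.
    g = jax.grad(objective)(w, *args)
    H = jax.hessian(objective)(w, *args)
    return w - jnp.linalg.solve(H, g), H

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x1 = jnp.linspace(-3.0, -1.0, 20)  # Task 1 inputs
x2 = jnp.linspace(1.0, 3.0, 20)    # Task 2 inputs
y1 = 0.5 + 2.0 * x1 + jnp.sqrt(noise_var) * jax.random.normal(k1, x1.shape)
y2 = 0.5 + 2.0 * x2 + jnp.sqrt(noise_var) * jax.random.normal(k2, x2.shape)

w_init = jnp.zeros(2)

# Sequential: fit Task 1, then fit Task 2 with Task 1's posterior as the prior.
w_task1, precision_task1 = newton_step(neg_log_post_task1, w_init, x1, y1)
w_seq, _ = newton_step(neg_log_post_task2, w_init, x2, y2, w_task1, precision_task1)

# Joint: fit both tasks at once with the original prior.
x_all, y_all = jnp.concatenate([x1, x2]), jnp.concatenate([y1, y2])
w_joint, _ = newton_step(neg_log_post_task1, w_init, x_all, y_all)

print("max |w_seq - w_joint| =", jnp.max(jnp.abs(w_seq - w_joint)))
```

For the MLP in this section the posterior is not Gaussian, so the sequential and joint solutions will generally differ; the toy case only shows what the Laplace prior is approximating.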

#### 7. Summary

In summary (as per Slide 19):

* **Uncertainty in Deep Learning fixes (asymptotic and local) overconfidence**: By providing a principled way to quantify predictive uncertainty, probabilistic deep learning models can give more honest and reliable predictions, especially in regions far from the training data.
* **It yields the functionality for continual learning**: The Bayesian framework naturally allows for sequential updating of models, where the posterior from a previous task becomes the prior for the next, mitigating catastrophic forgetting.
* **Many other applications not discussed here**: Uncertainty quantification is crucial for active learning, out-of-distribution detection, robust decision-making, and more.
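
The first bullet can be made concrete with a small numerical sketch. It assumes a binary classifier whose logit has a Gaussian predictive distribution, and it uses the probit approximation $\mathbb{E}[\sigma(f)] \approx \sigma\left(m / \sqrt{1 + \pi v / 8}\right)$ to turn the logit mean $m$ and variance $v$ into a class probability. Whether the notebook's prediction function uses exactly this form is an assumption, and the slopes `slope_mean` and `slope_std` are hypothetical values chosen to mimic a logit whose mean and standard deviation both grow linearly with the distance `delta` from the training data.

```python
import jax
import jax.numpy as jnp

def probit_confidence(mean_logit, var_logit):
    # Probit approximation to E[sigmoid(f)] for f ~ N(mean_logit, var_logit)
    return jax.nn.sigmoid(mean_logit / jnp.sqrt(1.0 + jnp.pi * var_logit / 8.0))

# Hypothetical growth rates of the logit mean and standard deviation
slope_mean, slope_std = 1.0, 1.0
delta = jnp.array([1.0, 10.0, 100.0, 1000.0])  # distance from the training data

mean_logit = slope_mean * delta
var_logit = (slope_std * delta) ** 2

map_conf = jax.nn.sigmoid(mean_logit)                  # point estimate: tends to 1
bayes_conf = probit_confidence(mean_logit, var_logit)  # stays bounded away from 1

# Limit of the probit-approximated confidence as delta grows without bound
limit = jax.nn.sigmoid(slope_mean / (slope_std * jnp.sqrt(jnp.pi / 8.0)))

print("MAP confidence:     ", map_conf)
print("Laplace confidence: ", bayes_conf)
print("Asymptotic limit:   ", limit)
```

In this sketch the point-estimate confidence saturates at 1 far from the data, while the uncertainty-aware confidence levels off at a value set by the ratio of the two slopes.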

Ultimately, **Laplace approximations turn deep networks into GPs, inheriting all functionality of GPs**. This powerful connection allows us to leverage the best of both worlds: the representational power of deep learning and the robust probabilistic inference capabilities of Gaussian Processes.

#### Exercises

**Exercise 1: Overconfidence with Different Data Distributions**
Change the `generate_data` type in the overconfidence illustration (Section 3) to `'circles'` or `'separable'`. How does the overconfidence manifest in these different datasets? Does the Laplace approximation still help in mitigating it? Visualize and discuss.

**Exercise 2: Impact of Regularization on Overconfidence**
In the overconfidence illustration, experiment with a much larger `lambda_reg` (e.g., 0.1 or 1.0). How does strong regularization affect the deterministic model's overconfidence? And how does it impact the uncertainty of the Laplace-approximated Deep GP?

**Exercise 3: Simulating More Complex Continual Learning**
Extend the continual learning example (Section 6) to three or more tasks. Observe how catastrophic forgetting progresses in the naive approach and how the Laplace approximation helps. You could try different `offset` values or even a different data `type` for each task (if it can be adapted to 1D).

**Exercise 4 (Advanced): Implementing a Hessian Approximation**
Modify the `predict_laplace_deep_gp_classification` function to use a simpler Hessian approximation, such as a diagonal approximation. Instead of `Psi = -hess_neg_log_post_at_theta_star`, compute `Psi = -jnp.diag(jnp.diag(hess_neg_log_post_at_theta_star))`. How does this simpler approximation affect the predictive uncertainty and the computational cost (conceptually)? Discuss the trade-offs.
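
As a possible starting point for Exercise 4, the sketch below compares a full and a diagonal Gaussian approximation on a small stand-in matrix. It does not use the notebook's model: `hess` is a hypothetical random symmetric positive-definite matrix playing the role of the Hessian of the negative log-posterior at $\theta_*$ (that is, `-Psi` in the notebook's convention), and all variable names are assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the Hessian of the negative log-posterior:
# a random symmetric positive-definite matrix with D = 5 parameters.
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (5, 5))
hess = A @ A.T + 0.5 * jnp.eye(5)

# Full Laplace covariance: a dense inverse, O(D^3) time and O(D^2) memory.
full_cov = jnp.linalg.inv(hess)
full_marginal_var = jnp.diag(full_cov)

# Diagonal approximation: invert each diagonal entry separately, O(D).
# Note that this sketch still forms the full Hessian to read off its diagonal.
diag_var = 1.0 / jnp.diag(hess)

print("marginal variances (full):    ", full_marginal_var)
print("marginal variances (diagonal):", diag_var)
```

For a positive-definite matrix, each entry `1 / hess[i, i]` is no larger than the corresponding diagonal entry of the inverse, so the diagonal approximation never overstates the per-parameter variance relative to the full Laplace posterior, and it discards all correlations between parameters.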