# Probabilistic Machine Learning: Lecture 17 - Probabilistic Deep Learning

#### Introduction

Welcome to Lecture 17 of Probabilistic Machine Learning! In this lecture, we bridge the gap between deep neural networks and Gaussian Processes, exploring how we can imbue deep learning models with probabilistic capabilities. We will focus on the powerful technique of using Laplace approximations to transform a trained deep network into an approximate Gaussian Process, allowing for uncertainty quantification and a deeper probabilistic understanding.

This notebook will provide detailed explanations and practical code illustrations using **JAX** for efficient numerical computations and **Plotly** for interactive visualizations, building upon the foundations laid in previous lectures.

#### 1. Recap: Deep Networks and Empirical Risk Minimization

As we recapped in Lecture 16 (Slide 2), for our purposes, a deep neural network is a function $f(x, \theta): \mathbb{X} \times \mathbb{R}^D \to \mathbb{R}^F$, parametrized by parameters $\theta \in \mathbb{R}^D$ and mapping inputs $x \in \mathbb{X}$ to outputs $f(x, \theta) \in \mathbb{R}^F$. These networks are typically trained by **Empirical Risk Minimization (ERM)** to find parameters $\theta_*$ on a training set $\mathcal{D}=[(x_i,y_i)]_{i=1,...,N}$:

$$\theta_* = \arg \min_{\theta} \mathcal{L}(\theta) = \arg \min_{\theta} \left( \frac{1}{N} \sum_{i=1}^N \ell(y_i, f(x_i, \theta)) + r(\theta) \right)$$

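We also established that minimizing this ERM objective is equivalent to finding the **Maximum A Posteriori (MAP) estimate** of the parameters, once the loss $\ell$ is read as a negative log-likelihood and the regularizer $r$ as a negative log-prior: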

$$\theta_* = \arg \max_{\theta \in \mathbb{R}^D} p(\theta | \mathcal{D})$$
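The ERM/MAP correspondence can be checked numerically on a small problem. The sketch below is an illustration only: it assumes a model that is linear in its parameters, a squared loss, and an L2 regularizer, and all names and values (`Phi`, `lam`, `sigma2`, `alpha`) are hypothetical choices rather than part of the lecture code. With the prior precision set to $\alpha = N\lambda/\sigma^2$, the negative log-posterior is a rescaled copy of the ERM objective, so both should share the same minimizer.

```python
import jax
import jax.numpy as jnp

# Toy data and polynomial features (all values here are illustrative assumptions)
key_x, key_n = jax.random.split(jax.random.PRNGKey(0))
N = 30
x = jax.random.uniform(key_x, (N,), minval=-2.0, maxval=2.0)
Phi = jnp.stack([jnp.ones(N), x, x**2], axis=1)  # (N, 3) feature matrix
y = 0.5 * x**2 - x + 0.3 + 0.1 * jax.random.normal(key_n, (N,))

lam = 0.05                # regularization strength in the ERM objective
sigma2 = 0.1**2           # assumed Gaussian noise variance
alpha = N * lam / sigma2  # prior precision that makes the two objectives proportional


def erm_objective(theta):
    # (1/N) sum_i 0.5 * (y_i - f(x_i, theta))^2 + 0.5 * lam * ||theta||^2
    return jnp.mean(0.5 * (y - Phi @ theta) ** 2) + 0.5 * lam * jnp.sum(theta**2)


def toy_neg_log_posterior(theta):
    # -log N(y; Phi theta, sigma2 I) - log N(theta; 0, I / alpha), dropping constants
    nll = 0.5 * jnp.sum((y - Phi @ theta) ** 2) / sigma2
    return nll + 0.5 * alpha * jnp.sum(theta**2)


def quadratic_minimizer(objective):
    # Both objectives are quadratic in theta, so one Newton step from zero is exact
    theta0 = jnp.zeros(3)
    grad = jax.grad(objective)(theta0)
    hess = jax.hessian(objective)(theta0)
    return theta0 - jnp.linalg.solve(hess, grad)


theta_erm = quadratic_minimizer(erm_objective)
theta_map = quadratic_minimizer(toy_neg_log_posterior)
print("ERM minimizer:", theta_erm)
print("MAP estimate: ", theta_map)
```

The two printed vectors should agree up to floating-point error. For a deep network the same identification holds, but the minimizer is no longer available from a single Newton step.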

This probabilistic interpretation is key to understanding how we can turn a deep network into a Gaussian Process. Let's set up our necessary imports and utility functions, including a simple MLP model from Lecture 16.

In [None]:
import jax.numpy as jnp
import jax
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from jax.flatten_util import ravel_pytree  # Utility to flatten/unflatten JAX pytrees

# Set JAX to use 64-bit floats for numerical stability
jax.config.update("jax_enable_x64", True)


# --- Utility Functions ---
def sigmoid(f):
    """Logistic sigmoid function."""
    return 1 / (1 + jnp.exp(-f))


def relu(x):
    """ReLU activation function."""
    return jnp.maximum(0, x)


def rbf_kernel(X1, X2, length_scale=1.0):
    """Radial Basis Function (RBF) kernel."""
    sqdist = jnp.sum(X1**2, 1)[:, None] + jnp.sum(X2**2, 1) - 2 * jnp.dot(X1, X2.T)
    return jnp.exp(-0.5 * (1 / length_scale**2) * sqdist)


def generate_data(type="sin_wave", n_samples=50, noise_std=0.1):
    """Generates synthetic 1D regression data."""
    np.random.seed(42)
    if type == "sin_wave":
        X = np.linspace(-3, 3, n_samples).reshape(-1, 1)
        y = np.sin(X * 2) + np.cos(X * 3) + noise_std * np.random.randn(n_samples, 1)
    elif type == "quadratic":
        X = np.linspace(-2, 2, n_samples).reshape(-1, 1)
        y = 0.5 * X**2 - 1.0 * X + 0.3 + noise_std * np.random.randn(n_samples, 1)
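    else:
        # Fail loudly instead of hitting an undefined X below
        raise ValueError(f"Unknown data type: {type!r}")
    return X, y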


def plot_regression_plotly(
    X, y, predictions=None, std_dev=None, title="", fig=None, row=None, col=None
):
    """Plots 1D regression data and predictions using Plotly."""
    if fig is None:
        fig = go.Figure()

    X_np = np.asarray(X)
    y_np = np.asarray(y)

    fig.add_trace(
        go.Scatter(
            x=X_np.flatten(),
            y=y_np.flatten(),
            mode="markers",
            marker=dict(color="blue", size=6),
            name="Training Data",
            showlegend=True,
        ),
        row=row,
        col=col,
    )

    if predictions is not None:
        predictions_np = np.asarray(predictions).flatten()
        fig.add_trace(
            go.Scatter(
                x=X_np.flatten(),
                y=predictions_np,
                mode="lines",
                line=dict(color="red", width=2),
                name="Mean Prediction",
                showlegend=True,
            ),
            row=row,
            col=col,
        )

        if std_dev is not None:
            std_dev_np = np.asarray(std_dev).flatten()
            upper_bound = predictions_np + 2 * std_dev_np
            lower_bound = predictions_np - 2 * std_dev_np
            fig.add_trace(
                go.Scatter(
                    x=np.concatenate([X_np.flatten(), X_np.flatten()[::-1]]),
                    y=np.concatenate([upper_bound, lower_bound[::-1]]),
                    fill="toself",
                    fillcolor="rgba(255,0,0,0.1)",
                    line_color="rgba(255,255,255,0)",
                    name="2 Std Dev Uncertainty",
                    showlegend=True,
                ),
                row=row,
                col=col,
            )

    fig.update_layout(title_text=title, title_x=0.5)
    fig.update_xaxes(title_text="X", row=row, col=col)
    fig.update_yaxes(title_text="Y", row=row, col=col)

    return fig


# --- Simple MLP Implementation in JAX ---
def init_mlp_params(key, layer_sizes):
    """Initializes parameters for a simple MLP."""
    params = []
    for i in range(len(layer_sizes) - 1):
        key, subkey = jax.random.split(key)
        in_dim = layer_sizes[i]
        out_dim = layer_sizes[i + 1]
        # Glorot initialization for weights
        limit = jnp.sqrt(6 / (in_dim + out_dim))
        weights = jax.random.uniform(
            subkey, (in_dim, out_dim), minval=-limit, maxval=limit
        )
        biases = jnp.zeros(out_dim)
        params.append({"weights": weights, "biases": biases})
    return params


def mlp_forward(params, x):
    """Forward pass through the MLP."""
    hidden_layers = params[:-1]
    output_layer = params[-1]

    h = x
    for layer in hidden_layers:
        h = jnp.dot(h, layer["weights"]) + layer["biases"]
        h = relu(h)  # Using ReLU as nonlinearity

    # Output layer (no activation for regression, as it's a linear output)
    output = jnp.dot(h, output_layer["weights"]) + output_layer["biases"]
    return output


# --- Loss and Regularization Functions ---
def mse_loss(predictions, targets):
    """Mean Squared Error loss."""
    return jnp.mean(jnp.square(predictions - targets))


def l2_regularization(params, lambda_reg):
    """L2 regularization (weight decay)."""
    l2_norm = 0.0
    for layer in params:
        l2_norm += jnp.sum(jnp.square(layer["weights"]))
    return 0.5 * lambda_reg * l2_norm


def neg_log_posterior(params, X, y, lambda_reg, noise_variance):
    """
    Negative log-posterior for MLP parameters (the ERM objective up to scaling and constants).
    Assumes a Gaussian likelihood with variance noise_variance and a zero-mean Gaussian prior
    with precision lambda_reg on the weights (biases are not regularized, i.e. get a flat prior).
    """
    predictions = mlp_forward(params, X)
    # Negative log-likelihood (Gaussian likelihood)
    neg_log_likelihood = 0.5 * jnp.sum(jnp.square(y - predictions)) / noise_variance
    # Negative log-prior (Gaussian prior on weights)
    neg_log_prior = l2_regularization(params, lambda_reg)
    return neg_log_likelihood + neg_log_prior


# --- Newton's Method for MAP estimate (from Lecture 15, adapted) ---
def newton_step_mlp(flat_params, unflatten_fn, X, y, lambda_reg, noise_variance):
    """Performs one Newton update step for minimizing the negative log-posterior of MLP params."""

    # Express the objective as a function of the flat parameter vector, so that
    # the gradient is a (D,) vector and the Hessian a (D, D) matrix. Taking the
    # Hessian w.r.t. the nested pytree and raveling it would give a 1D array.
    def flat_objective(flat_p):
        params = unflatten_fn(flat_p)
        return neg_log_posterior(params, X, y, lambda_reg, noise_variance)

    flat_grad = jax.grad(flat_objective)(flat_params)
    flat_hess = jax.hessian(flat_objective)(flat_params)

    # Newton update: flat_params_new = flat_params_old - H_inv @ grad
    # (undamped: this assumes the Hessian is invertible at the current iterate)
    delta_flat_params = jnp.linalg.solve(flat_hess, flat_grad)
    flat_params_new = flat_params - delta_flat_params
    return flat_params_new


# unflatten_fn is a Python callable rather than an array, so it has to be
# marked as a static argument for jit
newton_step_mlp = jax.jit(newton_step_mlp, static_argnums=(1,))


def find_map_mlp_params(
    key,
    X_train,
    y_train,
    layer_sizes,
    lambda_reg,
    noise_variance,
    max_iter=100,
    tol=1e-5,
):
    """
    Finds the MAP estimate (theta_star) for MLP parameters using Newton's method.
    """
    initial_params = init_mlp_params(key, layer_sizes)
    flat_params, unflatten_fn = ravel_pytree(initial_params)

    print("Starting Newton's method for MLP MAP estimate...")
    for i in range(max_iter):
        flat_params_new = newton_step_mlp(
            flat_params, unflatten_fn, X_train, y_train, lambda_reg, noise_variance
        )
        change = jnp.linalg.norm(flat_params_new - flat_params)
        flat_params = flat_params_new  # keep the latest iterate, also on convergence
        if change < tol:
            print(f"Converged in {i + 1} iterations. Final change: {change:.2e}")
            break
    else:
        print(
            f"Newton's method did not converge after {max_iter} iterations. Final change: {change:.2e}"
        )

    theta_star = unflatten_fn(flat_params)
    return theta_star


#### 3. Deep Networks are GPs: The Four Easy Steps

The core idea of turning a deep network into a Gaussian Process is elegantly summarized in four steps (as per Slide 4):

1.  **Realize that the loss is a negative log-posterior**: As discussed, the ERM objective $\mathcal{L}(\theta)$ is equivalent to the negative log-posterior of the parameters, up to a constant:
    $$\mathcal{L}(\theta) = \left( \frac{1}{N} \sum_{i=1}^N \ell(y_i, f(x_i, \theta)) \right) + r(\theta) = -\log p(\mathcal{D}|\theta) - \log p(\theta) = -\log p(\theta|\mathcal{D}) + \text{const.}$$

2.  **Train the deep net as usual to find $\theta_*$**: This is the standard deep learning training process, which finds the MAP estimate of the parameters:
    $$\theta_* = \arg \max_{\theta \in \mathbb{R}^D} p(\theta|\mathcal{D})$$

3.  **At $\theta_*$, compute a Laplace approximation of the log-posterior**: We approximate the (typically non-Gaussian) posterior $p(\theta|\mathcal{D})$ with a Gaussian distribution centered at $\theta_*$. This involves computing the Hessian matrix $\Psi$ of the log-posterior at $\theta_*$:
    $$\Psi := \nabla\nabla^T \log p(\theta_*|\mathcal{D})$$
    Because $\theta_*$ is a maximum of the log-posterior, $\Psi$ is negative definite (assuming a non-degenerate maximum), so $-\Psi^{-1}$ is a valid covariance matrix. The Gaussian approximation is then $\mathcal{N}(\theta; \theta_*, -\Psi^{-1})$.

4.  **Linearize $f(x, \theta)$ around $\theta_*$**: We approximate the deep network's output $f(x, \theta)$ with a first-order Taylor expansion around the MAP estimate $\theta_*$:
    $$f(x, \theta) \approx f(x, \theta_*) + J(x, \theta_*) (\theta - \theta_*)$$
    where $J(x, \theta_*)$ is the Jacobian matrix of $f(x, \theta)$ with respect to $\theta$, evaluated at $x$ and $\theta_*$. Its elements are $[J(x)]_{ij} = \frac{\partial f_i(x, \theta_*)}{\partial \theta_j}$.
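To get a feel for the linearization in step 4, the following sketch compares a network output with its first-order Taylor expansion. It is an illustration under assumptions: `tiny_net` is a hypothetical three-parameter tanh model, `theta_star` is a stand-in value rather than a trained MAP estimate, and the Jacobian-vector product is computed with `jax.jvp` instead of forming $J$ explicitly.

```python
import jax
import jax.numpy as jnp


def tiny_net(theta, x):
    """Hypothetical one-hidden-unit tanh network with parameters (w, b, v)."""
    w, b, v = theta
    return v * jnp.tanh(w * x + b)


theta_star = jnp.array([0.8, -0.3, 1.5])  # stand-in for a trained MAP estimate
x = 0.7


def linearized(theta):
    # f(x, theta_*) + J(x, theta_*) (theta - theta_*), via a Jacobian-vector product
    f_star, jvp_out = jax.jvp(
        lambda t: tiny_net(t, x), (theta_star,), (theta - theta_star,)
    )
    return f_star + jvp_out


direction = jnp.array([1.0, -2.0, 0.5])  # arbitrary perturbation direction
step_sizes = [1e-1, 1e-2]
errors = []
for eps in step_sizes:
    theta = theta_star + eps * direction
    err = float(jnp.abs(tiny_net(theta, x) - linearized(theta)))
    errors.append(err)
    print(f"eps={eps:g}  linearization error={err:.2e}")
```

For a smooth model the error should shrink roughly quadratically with the size of the perturbation, which is why the linearization is only trusted close to $\theta_*$.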

Combining these steps, the posterior distribution over the function output $f(\bullet)$ given the data $\mathcal{D}$ can be approximated as a Gaussian Process:

$$p(f(\bullet)|\mathcal{D}) \approx \mathcal{GP}(f(\bullet); f(\bullet, \theta_*), -J(\bullet)\Psi^{-1}J(\circ)^T)$$

Thus:
* The **mean function** of this approximate GP is the trained deep network itself: $\mathbb{E}(f(\bullet)) = f(\bullet, \theta_*)$.
* The **covariance function** is the **Laplace tangent kernel**: $\text{cov}(f(\bullet), f(\circ)) = -J(\bullet)\Psi^{-1}J(\circ)^T$.
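The kernel above can be evaluated for any pair of input sets. The sketch below is a minimal illustration, assuming a model that is linear in its parameters (for which the Laplace approximation happens to be exact, so the result can be cross-checked against Bayesian linear regression). The names `features`, `f`, `sigma2`, and `alpha` are hypothetical and not part of the lecture code.

```python
import jax
import jax.numpy as jnp


def features(x):
    # Hypothetical feature map; x has shape (M,), output has shape (M, 3)
    return jnp.stack([jnp.ones_like(x), x, jnp.sin(x)], axis=-1)


def f(theta, x):
    return features(x) @ theta


X = jnp.linspace(-2.0, 2.0, 15)
y = jnp.sin(2.0 * X)
sigma2, alpha = 0.1**2, 1.0  # assumed noise variance and prior precision


def log_posterior(theta):
    log_lik = -0.5 * jnp.sum((y - f(theta, X)) ** 2) / sigma2
    log_prior = -0.5 * alpha * jnp.sum(theta**2)
    return log_lik + log_prior


# The log-posterior is quadratic here, so one Newton step from zero reaches the mode
theta0 = jnp.zeros(3)
Psi = jax.hessian(log_posterior)(theta0)  # Hessian of the log-posterior (negative definite)
theta_star = theta0 - jnp.linalg.solve(Psi, jax.grad(log_posterior)(theta0))


def laplace_tangent_kernel(xa, xb):
    Ja = jax.jacobian(f)(theta_star, xa)  # (len(xa), D)
    Jb = jax.jacobian(f)(theta_star, xb)  # (len(xb), D)
    return -Ja @ jnp.linalg.solve(Psi, Jb.T)  # -J(xa) Psi^{-1} J(xb)^T


x_test = jnp.array([-3.0, 0.0, 1.0, 3.0])
K = laplace_tangent_kernel(x_test, x_test)
print("Predictive variances of f:", jnp.diag(K))
```

Using `jnp.linalg.solve` avoids forming $\Psi^{-1}$ explicitly, which is usually the numerically safer choice.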

Let's put this into practice with a code example. We'll train a small MLP on a 1D regression task, find its MAP parameters, and then use the Laplace approximation to get predictive mean and uncertainty.

In [None]:
# --- Main Execution for Deep GP Approximation ---

# 1. Generate 1D regression data
X_train_np, y_train_np = generate_data(type="sin_wave", n_samples=50, noise_std=0.2)
X_train_jax = jnp.array(X_train_np)
y_train_jax = jnp.array(y_train_np)

# Define MLP architecture
input_dim = 1
hidden_dim = 20  # A bit wider to allow for more expressiveness
output_dim = 1
layer_sizes = [input_dim, hidden_dim, hidden_dim, output_dim]  # Two hidden layers

# Hyperparameters for training and Laplace approximation
lambda_reg = 0.01  # L2 regularization strength
noise_variance = 0.1**2  # Assumed noise variance for likelihood (note: the data above use noise_std=0.2)

# 2. Find the MAP estimate (theta_star) using Newton's method
key = jax.random.PRNGKey(10)
theta_star = find_map_mlp_params(
    key,
    X_train_jax,
    y_train_jax,
    layer_sizes,
    lambda_reg,
    noise_variance,
    max_iter=1000,
    tol=1e-7,
)

print("\nMAP estimate (theta_star) of MLP parameters found.")

# 3. Compute Psi, the Hessian of the log-posterior, at theta_star
flat_theta_star, unflatten_theta = ravel_pytree(theta_star)


# Define a function that computes the negative log-posterior for flattened parameters
def flat_neg_log_posterior(flat_params, X, y, lambda_reg, noise_variance):
    params = unflatten_theta(flat_params)
    return neg_log_posterior(params, X, y, lambda_reg, noise_variance)


hess_neg_log_post_at_theta_star = jax.hessian(flat_neg_log_posterior)(
    flat_theta_star, X_train_jax, y_train_jax, lambda_reg, noise_variance
)

# Psi is the Hessian of the log-posterior, i.e. the negative of the Hessian computed above
Psi = -hess_neg_log_post_at_theta_star

# The Laplace covariance of theta is -Psi^{-1}, i.e. the inverse Hessian of the
# negative log-posterior
Sigma_theta_star = jnp.linalg.inv(-Psi)  # Approximate posterior covariance of theta

print(f"Shape of Psi (Hessian): {Psi.shape}")
print(f"Shape of Sigma_theta_star: {Sigma_theta_star.shape}")

# 4. Linearize f(x, theta) around theta_star and compute predictions
X_test_np = np.linspace(-4, 4, 200).reshape(-1, 1)
X_test_jax = jnp.array(X_test_np)


# Define a function to get the Jacobian of mlp_forward with respect to parameters
def get_jacobian_fn(params, x_val):
    """Jacobian of mlp_forward(params, x_val) w.r.t. the flattened params, shape (num_outputs, D)."""
    flat_p0, unflatten_fn = ravel_pytree(params)

    # Wrap mlp_forward to take flat parameters and return a flat output vector,
    # so that jax.jacobian yields a 2D (num_outputs, D) matrix
    def f_flat_params(flat_p):
        return mlp_forward(unflatten_fn(flat_p), x_val).reshape(-1)

    return jax.jacobian(f_flat_params)(flat_p0)


predictive_means = []
predictive_variances = []

for x_t in X_test_jax:
    # Compute J(x_t, theta_star)
    J_xt = get_jacobian_fn(
        theta_star, x_t.reshape(1, -1)
    )  # Reshape x_t to (1, input_dim)

    # Ensure J_xt is 2D (num_outputs, num_flat_params): jax.jacobian keeps the
    # output's own shape as leading axes, so collapse them explicitly
    J_xt = J_xt.reshape(-1, flat_theta_star.size)

    # Predictive Mean: E[f(x)] = f(x, theta_star)
    mean_pred = mlp_forward(theta_star, x_t.reshape(1, -1)).flatten()[
        0
    ]  # Ensure scalar output

    # Predictive Covariance: cov(f(x), f(x')) = J(x) @ Sigma_theta_star @ J(x').T
    # For variance at a single point x, it's J(x) @ Sigma_theta_star @ J(x).T
    # Plus noise variance for the observed data (if predicting observed y, otherwise latent f)
    variance_pred_latent = (J_xt @ Sigma_theta_star @ J_xt.T).flatten()[0]
    variance_pred_observed = (
        variance_pred_latent + noise_variance
    )  # Add noise for observed predictions

    predictive_means.append(mean_pred)
    predictive_variances.append(variance_pred_observed)

predictive_means = jnp.array(predictive_means)
predictive_std_devs = jnp.sqrt(jnp.array(predictive_variances))

print("\nPredictions computed using Laplace Approximation (Deep GP)...")

# Plotting the results
fig = plot_regression_plotly(
    X_test_np,
    predictive_means,
    predictions=predictive_means,
    std_dev=predictive_std_devs,
    title="Probabilistic Deep Learning: MLP as a GP (Laplace Approx.)",
)

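# plot_regression_plotly draws its (X, y) arguments as a "Training Data" marker
# trace; here those are the test-grid means, so drop that first trace
fig.data = fig.data[1:]

# Add the actual training data to the plot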
fig.add_trace(
    go.Scatter(
        x=X_train_np.flatten(),
        y=y_train_np.flatten(),
        mode="markers",
        marker=dict(color="black", size=8, symbol="x"),
        name="Training Data (Observed)",
        showlegend=True,
    )
)

fig.update_layout(height=600, width=800)
fig.show()


#### 4. What's to Like? (Advantages of Laplace Approximations for Deep Learning)

This approach offers several compelling advantages for deep learning practitioners (as per Slide 7):

* **You get to keep your beloved point estimate**: The mean function of the resulting GP is simply the prediction from the trained deep network $f(\bullet, \theta_*)$. This means all the performance benefits of the original deep model are retained.
* **You get to keep your beloved training procedure**: The Laplace approximation is constructed *post-hoc*, after the standard deep learning training (e.g., SGD, Adam) is completed. This applies even to pre-trained networks downloaded from the internet, provided the model architecture, training data, and loss function are available.
* **Only auto-diff and numerical linear algebra are needed**: Unlike other probabilistic deep learning methods (e.g., MCMC, variational inference), Laplace approximation avoids sampling, stochasticity, or complex ensemble training. JAX's automatic differentiation makes computing the Hessian and Jacobian straightforward.
* **The result is a GP, with all the trimmings**: By approximating the deep net as a GP, we gain access to all the benefits of Gaussian Processes, including:
    * **Evidence estimation**: For hyperparameter tuning.
    * **Sampling from the posterior**: To understand the range of plausible functions.
    * **Uncertainty quantification**: Providing principled confidence intervals for predictions, which is crucial for safety-critical applications or active learning.
    * **Sparse decompositions**: For scaling to larger datasets.
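As one example of these GP benefits, posterior function samples can be drawn by sampling parameters from the Laplace Gaussian and pushing them through the linearized network. The sketch below is a toy illustration: `tiny_net`, `theta_star`, and `Sigma` are hypothetical stand-ins (in the notebook, `Sigma` would play the role of $-\Psi^{-1}$), and the sample count is an arbitrary choice.

```python
import jax
import jax.numpy as jnp


def tiny_net(theta, x):
    """Hypothetical one-hidden-unit tanh network; x has shape (M,)."""
    w, b, v = theta
    return v * jnp.tanh(w * x + b)


theta_star = jnp.array([0.8, -0.3, 1.5])         # stand-in for the MAP estimate
Sigma = jnp.diag(jnp.array([0.04, 0.01, 0.09]))  # stand-in for -Psi^{-1}

x_grid = jnp.linspace(-3.0, 3.0, 50)
mean = tiny_net(theta_star, x_grid)              # GP mean: the network itself
J = jax.jacobian(tiny_net)(theta_star, x_grid)   # (50, 3) Jacobian w.r.t. theta

# Draw parameter deviations (theta - theta_*) ~ N(0, Sigma) via a Cholesky factor
L = jnp.linalg.cholesky(Sigma)
eps = jax.random.normal(jax.random.PRNGKey(0), (5000, 3))
theta_dev = eps @ L.T

# Linearized function samples: f(x, theta_*) + J(x) (theta - theta_*)
function_samples = mean + theta_dev @ J.T        # (5000, 50)

empirical_var = jnp.var(function_samples, axis=0)
analytic_var = jnp.einsum("md,de,me->m", J, Sigma, J)  # diag of J Sigma J^T
print("max abs difference:", float(jnp.max(jnp.abs(empirical_var - analytic_var))))
```

The empirical variance of the samples should approach the diagonal of $J \Sigma J^T$ as the number of samples grows, which gives a simple sanity check for a sampling routine.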

#### 5. Challenges of Laplace Approximations for Deep Learning

Despite its advantages, the Laplace approximation for deep learning also presents challenges (as per Slide 8):

* **Hessian decomposition is $\mathcal{O}(D^3)$**: The computation and inversion of the Hessian matrix $\Psi \in \mathbb{R}^{D \times D}$ (where $D$ is the number of parameters in the deep network) can be computationally very expensive for large networks. However, active research focuses on approximations (e.g., K-FAC, low-rank approximations) to make this more tractable.
* **Laplace approximations are local**: They are based on a Taylor expansion around a single mode ($\theta_*$) of the posterior. This means they can be "arbitrarily wrong" compared to the full, true posterior, especially if the posterior is multi-modal or highly non-Gaussian. However, they are still generally better than a simple point estimate.
* **Loss functions not designed for generative models**: The loss functions used in deep learning (e.g., cross-entropy, MSE) are primarily designed for predictive performance, not necessarily for accurately capturing the underlying generative process or full uncertainty. This can limit the quality of the probabilistic interpretation derived from the Laplace approximation.
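To make the cost argument concrete, the sketch below contrasts the full Laplace predictive variance, which needs a solve with the $D \times D$ Hessian, with a diagonal approximation that keeps only the Hessian's diagonal. It is a toy illustration under assumptions: the model is linear in its parameters, and `features`, `sigma2`, and `alpha` are hypothetical names and values. The diagonal variant is only one of several possible approximations.

```python
import jax
import jax.numpy as jnp


def features(x):
    # Hypothetical feature map; x has shape (M,), output has shape (M, 4)
    return jnp.stack([jnp.ones_like(x), x, x**2, jnp.sin(x)], axis=-1)


def f(theta, x):
    return features(x) @ theta


X = jnp.linspace(-2.0, 2.0, 20)
y = jnp.sin(2.0 * X)
sigma2, alpha = 0.1**2, 1.0  # assumed noise variance and prior precision


def neg_log_post(theta):
    nll = 0.5 * jnp.sum((y - f(theta, X)) ** 2) / sigma2
    return nll + 0.5 * alpha * jnp.sum(theta**2)


theta0 = jnp.zeros(4)
H = jax.hessian(neg_log_post)(theta0)  # equals -Psi; constant since the objective is quadratic
theta_star = theta0 - jnp.linalg.solve(H, jax.grad(neg_log_post)(theta0))

x_test = jnp.linspace(-3.0, 3.0, 7)
J = jax.jacobian(f)(theta_star, x_test)  # (7, 4)

# Full Laplace: requires a solve with the D x D Hessian, O(D^3) in general
var_full = jnp.einsum("md,dm->m", J, jnp.linalg.solve(H, J.T))

# Diagonal approximation: treat H as diag(H), O(D) per test point
var_diag = jnp.sum(J**2 / jnp.diag(H), axis=1)

print("full:    ", var_full)
print("diagonal:", var_diag)
```

The diagonal version ignores correlations between parameters, so its variances can differ noticeably from the full ones. That gap is the price paid for the reduced cost.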

#### 6. The Laplace Tangent Kernel: Some Context

The covariance function we derived, $-J(\bullet)\Psi^{-1}J(\circ)^T$, is known as the **Laplace Tangent Kernel** (as per Slide 9).

It's important to distinguish this from the **Neural Tangent Kernel (NTK)**, introduced by Jacot, Gabriel, and Hongler (NeurIPS 2018). The NTK is defined as:

$$k_{\text{NTK}}(\bullet, \circ) = J_{\theta_0}(\bullet) J_{\theta_0}^T(\circ) = \sum_{d=1}^D \frac{\partial f(\bullet, \theta_0)}{\partial [\theta_0]_d} \frac{\partial f(\circ, \theta_0)}{\partial [\theta_0]_d}$$

The NTK is typically evaluated at a random initialization $\theta_0$. In the infinite-width limit it stays constant during training, which makes it a theoretical tool for analyzing gradient descent in very wide networks. On its own it does not yield meaningful uncertainty quantification in the way the Laplace approximation does. The Laplace Tangent Kernel, in contrast, uses the *trained* parameters $\theta_*$ and incorporates the curvature of the posterior (via $\Psi^{-1}$), providing a direct measure of uncertainty.
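For comparison, an empirical (finite-width) NTK can be computed from the same Jacobians, just without any curvature term. The sketch below is an illustration under assumptions: `init_params`, `forward`, and `empirical_ntk` are hypothetical helpers for a small tanh MLP, and the width and initialization scale are arbitrary choices.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def init_params(key, sizes):
    # Hypothetical initializer: Gaussian weights scaled by 1/sqrt(fan_in), zero biases
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        weights = jax.random.normal(sub, (d_in, d_out)) / jnp.sqrt(d_in)
        params.append({"weights": weights, "biases": jnp.zeros(d_out)})
    return params


def forward(params, x):
    h = x
    for layer in params[:-1]:
        h = jnp.tanh(h @ layer["weights"] + layer["biases"])
    out = h @ params[-1]["weights"] + params[-1]["biases"]
    return out.reshape(-1)  # flatten the scalar outputs to shape (M,)


theta0, unflatten = ravel_pytree(init_params(jax.random.PRNGKey(0), [1, 16, 1]))


def empirical_ntk(flat_theta, xa, xb):
    # J(xa) J(xb)^T with Jacobians taken w.r.t. the flat parameter vector
    def jac(x):
        return jax.jacobian(lambda t: forward(unflatten(t), x))(flat_theta)

    return jac(xa) @ jac(xb).T


x = jnp.linspace(-2.0, 2.0, 5).reshape(-1, 1)
K_ntk = empirical_ntk(theta0, x, x)
print(K_ntk.shape)
```

Replacing `theta0` with trained parameters and inserting $-\Psi^{-1}$ between the two Jacobians would turn this into the Laplace tangent kernel.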

#### 7. Summary

In summary (as per Slide 10):

* **Laplace approximations turn (nearly) any deep neural network into a Gaussian process.** This is a powerful way to bridge the gap between deterministic deep learning and probabilistic modeling.
* They involve **only auto-differentiation and linear algebra**, both of which are robust and scalable operations within modern deep learning frameworks like JAX.
* **Deep nets thus approximately inherit probabilistic functionality**, gaining the ability to quantify uncertainty in their predictions.
* For large-scale deep networks, **care must be taken to find approximate solutions to the Hessian decomposition** to manage computational complexity.

This lecture demonstrates a practical method for adding probabilistic capabilities to deep learning models, offering a path towards more robust and interpretable AI systems.

#### Exercises

**Exercise 1: Impact of Regularization and Noise Variance**
Experiment with different values for `lambda_reg` (L2 regularization strength) and `noise_variance` in the main execution cell. How do these hyperparameters affect the predictive mean and, more importantly, the predictive uncertainty (the shaded region)? Explain your observations in terms of their probabilistic interpretation (prior strength and likelihood noise).

**Exercise 2: Effect of Network Depth and Width**
Modify the `layer_sizes` in the MLP to create a shallower (e.g., `[input_dim, output_dim]`) or deeper (e.g., `[input_dim, 50, 50, 50, output_dim]`) network. How does this change the training process and the resulting predictive mean and uncertainty? (Note: deeper networks might require a larger `max_iter` or more careful initialization.)

**Exercise 3: Comparing to a Pure GP**
Implement a standard Gaussian Process regression model (using `rbf_kernel` and the exact GP regression formulas) on the same `sin_wave` dataset. Compare its predictive mean and uncertainty to the Laplace-approximated Deep GP. Discuss similarities and differences.

**Exercise 4 (Advanced): Hessian Approximation**
For very large networks, computing the full Hessian is infeasible. Research one method for approximating the Hessian (e.g., K-FAC, diagonal approximation, block-diagonal approximation). Briefly describe how it works and its trade-offs in terms of accuracy and computational cost.