# Probabilistic Machine Learning: Lecture 14 - Logistic Regression with JAX and Plotly

#### Introduction

Welcome to Lecture 14 of Probabilistic Machine Learning, focusing on Logistic Regression! This notebook aims to translate the core concepts from the lecture slides into a clear, understandable format, augmented with Python code examples and exercises to reinforce your learning. We'll be using **JAX** for efficient numerical computations and **Plotly** for interactive visualizations.

#### 1. From Regression to Classification: A Conceptual Shift

In previous lectures, we primarily focused on **regression** problems, where the goal is to predict a continuous output variable $Y$ given input data $X$. For instance, in least-squares regression, we aim to find a function $f: X \to \mathbb{R}^d$ such that $Y \approx f(X)$.

However, many real-world problems involve predicting discrete categories or labels. This is where **classification** comes in. In classification, given supervised data
$$(X, Y) = (x_i, c_i)_{i=1,...,n}$$
where $x_i \in \mathcal{X}$ and $c_i \in \{1, ..., d\}$, our goal is to find a map from inputs to probability distributions over the classes,
$$\pi: \mathcal{X} \to \mathcal{U}^d$$
(where $\mathcal{U}^d = \{p \in [0, 1]^d : \sum_{i=1}^d p_i = 1\}$) such that $\pi$ "models" the labels, $c_i \sim \pi_{x_i}$.

**Key Distinction**: Regression predicts a function, while classification predicts a probability.

For simplicity, we will initially focus on **discriminative binary classification**, where the output $y$ can take one of two values, typically $y \in \{-1, +1\}$. In this setting, we want to learn $\pi(x) = p(y = 1 | x) \in [0, 1]$. The probability distribution for $y$ given $x$ can be expressed as:

$$p(y | x) = \begin{cases} \pi(x) & \text{if } y = 1 \\ 1 - \pi(x) & \text{if } y = -1 \end{cases}$$

##### **Visualizing Classification Problems**

The lecture slides provide several examples of classification problems with varying degrees of separability. Let's recreate these visualizations to understand what we're trying to achieve.

* **Slides 6 & 7**: Clearly separable classes. A linear boundary works well.
* **Slides 8 & 9**: Overlapping classes. A linear boundary can still be drawn, but some points will be misclassified.
* **Slides 10 & 11**: Highly intermingled classes. A simple linear boundary is insufficient.

**Activity 1.1: Recreate Classification Problem Visualizations**

Let's generate some synthetic 2D data to mimic the classification problems shown in the slides. We'll start by defining our necessary imports and utility functions.

In [None]:
import jax.numpy as jnp
import jax
import numpy as np  # Still using numpy for data generation
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Set JAX to use 64-bit floats for numerical stability, matching numpy's default
jax.config.update("jax_enable_x64", True)


# --- Utility Functions ---
def sigmoid(f):
    """Logistic sigmoid function."""
    return 1 / (1 + jnp.exp(-f))


def rbf_kernel(X1, X2, length_scale=1.0):
    """Radial Basis Function (RBF) kernel."""
    sqdist = jnp.sum(X1**2, 1)[:, None] + jnp.sum(X2**2, 1) - 2 * jnp.dot(X1, X2.T)
    return jnp.exp(-0.5 * (1 / length_scale**2) * sqdist)


def generate_data(type="separable", n_samples=100):
    """Generates synthetic 2D classification data (using numpy for random generation)."""
    np.random.seed(42)  # Ensure reproducibility for data generation
    if type == "separable":
        mean1 = [-1, 0.5]
        cov1 = [[0.5, 0.2], [0.2, 0.5]]
        data1 = np.random.multivariate_normal(mean1, cov1, n_samples // 2)
        labels1 = np.ones(n_samples // 2) * -1
        mean2 = [1, -0.5]
        cov2 = [[0.5, 0.2], [0.2, 0.5]]
        data2 = np.random.multivariate_normal(mean2, cov2, n_samples // 2)
        labels2 = np.ones(n_samples // 2)
    elif type == "overlapping":
        mean1 = [-0.5, 0.5]
        cov1 = [[1.0, 0.5], [0.5, 1.0]]
        data1 = np.random.multivariate_normal(mean1, cov1, n_samples // 2)
        labels1 = np.ones(n_samples // 2) * -1
        mean2 = [0.5, -0.5]
        cov2 = [[1.0, 0.5], [0.5, 1.0]]
        data2 = np.random.multivariate_normal(mean2, cov2, n_samples // 2)
        labels2 = np.ones(n_samples // 2)
    elif type == "intermingled":
        r1 = np.random.rand(n_samples // 2) * 2
        theta1 = np.random.rand(n_samples // 2) * 2 * np.pi
        data1 = np.array(
            [
                r1 * np.cos(theta1) + np.random.randn(n_samples // 2) * 0.2,
                r1 * np.sin(theta1) + np.random.randn(n_samples // 2) * 0.2,
            ]
        ).T
        labels1 = np.ones(n_samples // 2) * -1
        r2 = np.random.rand(n_samples // 2) * 2 + 1.5
        theta2 = np.random.rand(n_samples // 2) * 2 * np.pi
        data2 = np.array(
            [
                r2 * np.cos(theta2) + np.random.randn(n_samples // 2) * 0.2,
                r2 * np.sin(theta2) + np.random.randn(n_samples // 2) * 0.2,
            ]
        ).T
        labels2 = np.ones(n_samples // 2)
        # Class -1 fills an inner disc and class +1 a surrounding ring (the radii
        # overlap between 1.5 and 2); both are stacked in the shared code below.
    X = np.vstack((data1, data2))
    y = np.hstack((labels1, labels2))
    return X, y


def plot_data_plotly(X, y, title="", fig=None, row=None, col=None):
    """Plots 2D classification data using Plotly."""
    if fig is None:
        fig = go.Figure()

    # Convert JAX arrays to NumPy for Plotly
    X_np = np.asarray(X)
    y_np = np.asarray(y)

    fig.add_trace(
        go.Scatter(
            x=X_np[y_np == -1, 0],
            y=X_np[y_np == -1, 1],
            mode="markers",
            marker=dict(color="maroon", symbol="circle"),
            name="Class -1",
            showlegend=True,
        ),
        row=row,
        col=col,
    )

    fig.add_trace(
        go.Scatter(
            x=X_np[y_np == 1, 0],
            y=X_np[y_np == 1, 1],
            mode="markers",
            marker=dict(
                color="skyblue", symbol="circle", line=dict(width=1, color="skyblue")
            ),
            name="Class +1",
            showlegend=True,
        ),
        row=row,
        col=col,
    )

    fig.update_layout(title_text=title, title_x=0.5)
    fig.update_xaxes(title_text="$x_1$", range=[-4, 4], row=row, col=col)
    fig.update_yaxes(
        title_text="$x_2$",
        range=[-4, 4],
        scaleanchor="x",
        scaleratio=1,
        row=row,
        col=col,
    )

    return fig


Now, let's generate and plot our synthetic classification datasets.

In [None]:
# Generate different types of data for demonstration
X_sep_np, y_sep_np = generate_data(type="separable", n_samples=100)
X_ovl_np, y_ovl_np = generate_data(type="overlapping", n_samples=100)
X_int_np, y_int_np = generate_data(type="intermingled", n_samples=100)

# Create subplots for data visualization
fig_data = make_subplots(
    rows=1,
    cols=3,
    subplot_titles=("Separable Classes", "Overlapping Classes", "Intermingled Classes"),
)

plot_data_plotly(X_sep_np, y_sep_np, fig=fig_data, row=1, col=1)
plot_data_plotly(X_ovl_np, y_ovl_np, fig=fig_data, row=1, col=2)
plot_data_plotly(X_int_np, y_int_np, fig=fig_data, row=1, col=3)

fig_data.update_layout(showlegend=False, height=500, width=1500)
fig_data.show()

#### 2. Link Functions: Connecting Latent Functions to Probabilities

Since classification predicts probabilities, and probabilities are restricted to the $[0, 1]$ interval, we cannot directly use a linear model $f(x) = w^T x + b$ (which can output any real value). We need a "link function" that maps the real-valued output of a latent function (like $f(x)$) to a probability.

The most common link function in binary classification is the **logistic sigmoid function** (often just called the sigmoid function):

$$\sigma(f) = \frac{1}{1 + \exp(-f)}$$

This function has several desirable properties:
* It maps any real number $f$ to a value between 0 and 1.
* It is monotonically increasing.
* It is differentiable, which is crucial for optimization.

The inverse of the sigmoid function, $f(\pi) = \ln \frac{\pi}{1-\pi}$, is called the **logit function**. The derivative of the sigmoid function with respect to $f$ is given by $\frac{d\pi}{df} = \pi(f)(1 - \pi(f))$.
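
The two identities above can be checked numerically. The following is a minimal sketch, assuming only the sigmoid definition from the text; the helper name `logit` and the test values are illustrative choices, not part of the lecture code.

```python
import jax
import jax.numpy as jnp


def sigmoid(f):
    """Logistic sigmoid, as defined in the text."""
    return 1.0 / (1.0 + jnp.exp(-f))


def logit(p):
    """Inverse of the sigmoid: the log-odds of a probability p in (0, 1)."""
    return jnp.log(p) - jnp.log1p(-p)


f = jnp.linspace(-3.0, 3.0, 7)
pi = sigmoid(f)

# Round trip: logit(sigmoid(f)) should recover f up to floating-point error.
roundtrip_error = jnp.max(jnp.abs(logit(pi) - f))

# Derivative from automatic differentiation versus the closed form pi * (1 - pi).
dsigmoid = jax.vmap(jax.grad(sigmoid))(f)
derivative_error = jnp.max(jnp.abs(dsigmoid - pi * (1.0 - pi)))

print(roundtrip_error, derivative_error)
```

Both printed errors should be at the level of floating-point round-off.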

**Activity 2.1: Plot the Sigmoid Function**

Let's visualize the sigmoid function and its properties.

In [None]:
# Generate a range of f values
f_values = np.linspace(-5, 5, 100)
sigma_f_values = sigmoid(jnp.array(f_values))  # Convert to JAX array for sigmoid

fig_sigmoid = go.Figure()
fig_sigmoid.add_trace(
    go.Scatter(
        x=f_values,
        y=np.asarray(sigma_f_values),
        mode="lines",
        name="$\\sigma(f) = 1 / (1 + \\exp(-f))$",
    )
)
fig_sigmoid.add_shape(
    type="line",
    x0=-5,
    y0=0.5,
    x1=5,
    y1=0.5,
    line=dict(color="gray", dash="dot", width=0.8),
)
fig_sigmoid.add_shape(
    type="line", x0=0, y0=0, x1=0, y1=1, line=dict(color="gray", dash="dot", width=0.8)
)
fig_sigmoid.add_annotation(
    x=0.1, y=0.55, text="$\\sigma(0) = 0.5$", showarrow=False, font=dict(size=10)
)
fig_sigmoid.update_layout(
    title_text="Logistic Sigmoid Link Function",
    title_x=0.5,
    xaxis_title="$f$",
    yaxis_title="$\\sigma(f)$",
)
fig_sigmoid.show()

#### 3. Gaussian Process Model for Classification: Logistic Regression

In a probabilistic setting, we can define a **Gaussian Process (GP) model for classification**. This model is often referred to as **Logistic Regression**, both when the latent function is linear in a set of basis functions and, as here, when the latent function $f$ is a Gaussian Process.

The core idea is to place a Gaussian Process prior over the latent function $f(x)$:

$$p(f) = \mathcal{GP}(f; m, k)$$

where $m$ is the mean function and $k$ is the covariance (kernel) function.

Then, the likelihood of observing a label $y$ given the latent function output $f_x$ at point $x$ is defined using the sigmoid link function:

$$p(y | f_x) = \sigma(y f_x)$$

This can be explicitly written as:

$$p(y | f_x) = \begin{cases} \sigma(f_x) & \text{if } y = 1 \\ 1 - \sigma(f_x) & \text{if } y = -1 \end{cases}$$

*Note: because $\sigma(-f) = 1 - \sigma(f)$, the two cases are $\sigma(f_x)$ for $y = 1$ and $\sigma(-f_x)$ for $y = -1$. This is why the likelihood can be written compactly as $\sigma(y f_x)$.*
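
As a quick sanity check, the case-by-case and compact forms of the likelihood can be compared numerically. This is a small sketch with made-up latent values and labels; the function names `likelihood_cases` and `likelihood_compact` are hypothetical helpers introduced only for this comparison.

```python
import jax.numpy as jnp


def sigmoid(f):
    """Logistic sigmoid, as defined in the text."""
    return 1.0 / (1.0 + jnp.exp(-f))


def likelihood_cases(y, f):
    """p(y | f) written case by case for y in {-1, +1}."""
    return jnp.where(y == 1, sigmoid(f), 1.0 - sigmoid(f))


def likelihood_compact(y, f):
    """p(y | f) written compactly as sigma(y * f)."""
    return sigmoid(y * f)


# Made-up latent values and labels, purely for illustration.
f = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
y = jnp.array([1.0, -1.0, 1.0, -1.0, -1.0])

max_diff = jnp.max(jnp.abs(likelihood_cases(y, f) - likelihood_compact(y, f)))

# The probabilities of the two labels sum to one for every latent value.
total = likelihood_compact(1.0, f) + likelihood_compact(-1.0, f)

print(max_diff, total)
```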

#### 4. The Challenge: Non-Gaussian Posterior

The beauty of Gaussian Processes in regression is that if the likelihood (noise model) is Gaussian, the posterior distribution over the latent function $f$ (and thus the predictions) remains Gaussian, making inference tractable.

However, with the logistic sigmoid likelihood, the posterior distribution $p(f_X | Y)$ is **not Gaussian**. This makes exact inference intractable.

Let's look at the log-posterior:

$$\log p(f_X | Y) = \log p(Y | f_X) + \log p(f_X) - \log p(Y)$$

Substituting the GP prior and the sigmoid likelihood:

$$\log p(f_X | Y) = \sum_{i=1}^n \log \sigma(y_i f_{x_i}) - \frac{1}{2}(f_X - m_X)^T K_{XX}^{-1} (f_X - m_X) + \text{const.}$$

where $\log \sigma(y_i f_{x_i}) = -\log(1 + e^{-y_i f_{x_i}})$.

The $\log(1 + e^{-y_i f_{x_i}})$ terms in the sum are not quadratic in $f_X$, so the log-posterior is not a quadratic form and the posterior is non-Gaussian. This means we cannot simply compute the posterior mean and covariance analytically as we would in GP regression.
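
One way to see the difference is to compare curvatures: the log of a Gaussian has constant second derivative, while the log-sigmoid term does not. The sketch below illustrates this in one dimension; the comparison function `log_gauss` and the evaluation points are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def log_sigmoid(z):
    """log sigma(z) = -log(1 + exp(-z)), the per-point log-likelihood term."""
    return -jnp.log1p(jnp.exp(-z))


def log_gauss(z):
    """Log of an unnormalized standard Gaussian, used only for comparison."""
    return -0.5 * z**2


def second_derivative(fn):
    """Elementwise second derivative of a scalar function via autodiff."""
    return jax.vmap(jax.grad(jax.grad(fn)))


z = jnp.array([-4.0, 0.0, 4.0])
curv_logistic = second_derivative(log_sigmoid)(z)  # changes with z
curv_gauss = second_derivative(log_gauss)(z)  # constant, equal to -1

print(curv_logistic)
print(curv_gauss)
```

The logistic curvature is largest in magnitude at $z = 0$ and shrinks toward zero in the tails, so no single quadratic matches it everywhere.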

**Visualizing Intractability (Slides 21, 22, 23 of Lecture 14)**

The slides illustrate this non-Gaussian posterior with a contour plot. A Gaussian would have perfectly elliptical contours, so the distorted, non-elliptical shape shows directly that the posterior is not Gaussian.

Let's define the `log_posterior` function that we will maximize.

In [None]:
def log_posterior(f_X, K_XX, y_labels):
    """
    Calculates the log posterior of f_X given observations y_labels.
    Assumes mean_X = 0.
    """
    log_likelihood_term = jnp.sum(-jnp.log(1 + jnp.exp(-y_labels * f_X)))
    log_prior_term = -0.5 * f_X.T @ jnp.linalg.solve(K_XX, f_X)
    return log_likelihood_term + log_prior_term


#### 5. Solution: The Laplace Approximation (Introduction)

Since exact inference is intractable, we resort to approximation methods. One popular and computationally efficient method is the **Laplace Approximation**.

The core idea of the Laplace Approximation is to approximate a non-Gaussian probability distribution $p(\theta)$ with a local Gaussian distribution $q(\theta)$. This approximation is built around a (local) maximum of $p(\theta)$ (or equivalently, $\log p(\theta)$), which is known as the Maximum A Posteriori (MAP) estimate.

Here are the steps:
1.  **Find the MAP estimate**: Locate $\hat{\theta} = \arg \max_{\theta} \log p(\theta)$. At this point, the gradient of the log-posterior is zero: $\nabla \log p(\hat{\theta}) = 0$.
2.  **Perform a second-order Taylor expansion**: Expand $\log p(\theta)$ around $\hat{\theta}$:
    $$\log p(\theta) \approx \log p(\hat{\theta}) + \frac{1}{2}(\theta - \hat{\theta})^T \Psi (\theta - \hat{\theta})$$
    where $\Psi = \nabla\nabla^T \log p(\hat{\theta})$ is the Hessian matrix (matrix of second derivatives) evaluated at $\hat{\theta}$.
3.  **Define the Gaussian approximation**: The Laplace approximation $q(\theta)$ is then a Gaussian distribution with mean $\hat{\theta}$ and covariance matrix $-\Psi^{-1}$:
    $$q(\theta) = \mathcal{N}(\theta; \hat{\theta}, -\Psi^{-1})$$

The Laplace approximation is a local approximation and can be arbitrarily wrong, but it tends to be computationally efficient and often works well for logistic regression because the log posterior is concave.
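
The three steps can be sketched on a one-dimensional toy problem. The target below (one sigmoid likelihood term with a standard normal prior) is an assumption chosen for illustration, and the fixed count of 20 Newton iterations is an arbitrary choice rather than a tuned setting.

```python
import jax
import jax.numpy as jnp


def log_p(theta):
    """Unnormalized log density: one sigmoid likelihood term plus a N(0, 1) prior."""
    return -jnp.log1p(jnp.exp(-theta)) - 0.5 * theta**2


grad_log_p = jax.grad(log_p)
hess_log_p = jax.grad(grad_log_p)

# Step 1: find the mode with a few Newton iterations, starting from zero.
theta_hat = jnp.array(0.0)
for _ in range(20):
    theta_hat = theta_hat - grad_log_p(theta_hat) / hess_log_p(theta_hat)

# Step 2: curvature at the mode (negative, because log_p is concave).
psi = hess_log_p(theta_hat)

# Step 3: Gaussian approximation q = N(theta_hat, -1 / psi).
laplace_mean = theta_hat
laplace_var = -1.0 / psi

print(laplace_mean, laplace_var)
```

In one dimension the Hessian is a scalar, so the covariance $-\Psi^{-1}$ reduces to `-1.0 / psi`.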

#### 6. Numerical Implementation: Mode-Finding with Newton's Method (JAX)

To implement the mode-finding part (finding $\hat{f}$), we use Newton's method. This iterative optimization algorithm requires the gradient and Hessian of the log posterior. With JAX, we can automatically compute these derivatives, simplifying our implementation.

Let $L(f_X) = \log p(f_X | Y)$. We aim to maximize $L(f_X)$, which is equivalent to minimizing $-L(f_X)$.

The Newton-Raphson update rule is:
$$f^{new} = f^{old} - \mathbf{H}^{-1} \mathbf{g}$$

where $\mathbf{g}$ is the gradient of $L(f_X)$ and $\mathbf{H}$ is the Hessian of $L(f_X)$. Alternatively, if we work with the negative log posterior, the update is $f^{new} = f^{old} - (\text{Hessian of } -L(f_X))^{-1} (\text{Gradient of } -L(f_X))$.
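
For this model the gradient and Hessian also have simple closed forms: $\mathbf{g} = y \odot \sigma(-y \odot f_X) - K_{XX}^{-1} f_X$ and $\mathbf{H} = -W - K_{XX}^{-1}$ with $W = \mathrm{diag}\big(\sigma(f_X)(1 - \sigma(f_X))\big)$, assuming a zero prior mean. These expressions are not spelled out in the text above; the sketch below compares them against automatic differentiation on a tiny, hand-picked problem (the matrix `K`, labels `y`, and point `f` are made up).

```python
import jax
import jax.numpy as jnp


def sigmoid(f):
    """Logistic sigmoid, as defined in the text."""
    return 1.0 / (1.0 + jnp.exp(-f))


def log_posterior(f, K, y):
    """Unnormalized log posterior with zero prior mean."""
    log_lik = jnp.sum(-jnp.log1p(jnp.exp(-y * f)))
    log_prior = -0.5 * f @ jnp.linalg.solve(K, f)
    return log_lik + log_prior


# Hypothetical 3-point problem with a symmetric positive definite kernel matrix.
K = jnp.array([[1.0, 0.5, 0.2], [0.5, 1.0, 0.5], [0.2, 0.5, 1.0]])
y = jnp.array([1.0, -1.0, 1.0])
f = jnp.array([0.3, -0.1, 0.8])

# Closed-form gradient and Hessian of the log posterior.
K_inv = jnp.linalg.inv(K)
g_closed = y * sigmoid(-y * f) - K_inv @ f
W = jnp.diag(sigmoid(f) * (1.0 - sigmoid(f)))
H_closed = -W - K_inv

# The same quantities from automatic differentiation.
g_auto = jax.grad(log_posterior)(f, K, y)
H_auto = jax.hessian(log_posterior)(f, K, y)

# One Newton step: f_new = f - H^{-1} g.
f_new = f - jnp.linalg.solve(H_closed, g_closed)

print(jnp.max(jnp.abs(g_closed - g_auto)), jnp.max(jnp.abs(H_closed - H_auto)))
```

Because $W$ is positive semi-definite and $K_{XX}^{-1}$ is positive definite, $\mathbf{H}$ is negative definite, which is what makes the Newton step well defined.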

We will define the `newton_step` function using JAX's automatic differentiation capabilities and then the `find_map_f_newton_jax` function to perform the iterations.

In [None]:
@jax.jit  # JIT compile the optimization step for performance
def newton_step(f_X, K_XX, y_labels):
    """Performs one Newton update step for maximizing the log posterior."""
    # We want to maximize log_posterior, so we minimize -log_posterior
    neg_log_posterior = lambda f: -log_posterior(f, K_XX, y_labels)

    # Compute gradient and Hessian of the *negative* log posterior
    grad_neg_log_post = jax.grad(neg_log_posterior)(f_X)
    hess_neg_log_post = jax.hessian(neg_log_posterior)(f_X)

    # Newton update: f_new = f_old - H_inv * grad
    delta_f = jnp.linalg.solve(hess_neg_log_post, grad_neg_log_post)
    f_X_new = f_X - delta_f
    return f_X_new


def find_map_f_newton_jax(
    X_train, y_train, kernel_func, length_scale, max_iter=100, tol=1e-5
):
    """
    Finds the MAP estimate hat_f using Newton's method with JAX autodiff.
    Returns hat_f and the kernel matrix K_XX.
    """
    n = X_train.shape[0]
    f_X = jnp.zeros(n)  # Initialize f_X with zeros
    K_XX = kernel_func(X_train, X_train, length_scale)
    K_XX += 1e-6 * jnp.eye(n)  # Add jitter for numerical stability
    y_labels_jax = jnp.array(y_train)

    print("Starting Newton's method with JAX...")
    for i in range(max_iter):
        f_X_new = newton_step(f_X, K_XX, y_labels_jax)
        change = jnp.linalg.norm(f_X_new - f_X)
        f_X = f_X_new  # Keep the latest iterate, including on the final step
        if change < tol:
            print(f"Converged in {i + 1} iterations. Final change: {change:.6f}")
            break
    else:
        print(
            f"Newton's method did not converge after {max_iter} iterations. Final change: {change:.6f}"
        )
    return f_X, K_XX


Now, let's apply Newton's method to find the MAP estimate for our separable dataset and visualize the resulting decision boundary.

In [None]:
# Convert training data to JAX arrays for use in JAX functions
X_train_jax = jnp.array(X_sep_np)  # Using the separable data generated earlier
y_train_jax = jnp.array(y_sep_np)

# Find the MAP estimate hat_f
hat_f_jax, K_XX_jax = find_map_f_newton_jax(
    X_train_jax, y_train_jax, rbf_kernel, length_scale=1.0
)

print("\nMAP estimate (hat_f) first 5 elements:")
print(hat_f_jax[:5])

# Prepare a grid of test points for visualization
x1_grid = np.linspace(-4, 4, 100)
x2_grid = np.linspace(-4, 4, 100)
X1_mesh, X2_mesh = np.meshgrid(x1_grid, x2_grid)
X_grid_np = np.vstack([X1_mesh.ravel(), X2_mesh.ravel()]).T
X_grid_jax = jnp.array(X_grid_np)

# Compute mean of the latent function at grid points given hat_f
# E[f_* | hat_f] = k_x*X @ K_XX^-1 @ hat_f (assuming m_X=0)
# Use the same length scale as in training so both kernel matrices are consistent
K_x_X_grid = rbf_kernel(X_grid_jax, X_train_jax, length_scale=1.0)
# Solve the linear system instead of forming the explicit inverse of K_XX
mean_latent_predictions_jax = K_x_X_grid @ jnp.linalg.solve(K_XX_jax, hat_f_jax)

# Convert latent predictions to probabilities using sigmoid (plug-in approximation)
prob_predictions_jax = sigmoid(mean_latent_predictions_jax)
prob_predictions_mesh_np = np.array(prob_predictions_jax).reshape(X1_mesh.shape)


In [None]:
# Plotting Predictive Probabilities with Plotly
fig_prob = go.Figure()
fig_prob.add_trace(
    go.Contour(
        z=prob_predictions_mesh_np,
        x=x1_grid,
        y=x2_grid,
        colorscale="RdBu",
        opacity=0.6,
        colorbar_title="Predicted Probability $P(y=1|x)$",
        line_smoothing=0.8,  # Smooth contours
        contours_coloring="heatmap",  # Color between contours
        zmin=0,
        zmax=1,
    )
)

# Add decision boundary (0.5 probability contour)
fig_prob.add_trace(
    go.Contour(
        z=prob_predictions_mesh_np,
        x=x1_grid,
        y=x2_grid,
        showscale=False,
        contours=dict(
            start=0.5,
            end=0.5,
            size=0,
            coloring="lines",
            showlabels=True,
            labelfont=dict(size=12, color="black"),
        ),
        line_width=2,
        name="Decision Boundary ($P(y=1|x)=0.5$)",
    )
)

# Add training data points
fig_prob = plot_data_plotly(X_train_jax, y_train_jax, fig=fig_prob)

fig_prob.update_layout(
    title_text="Logistic Regression: Predictive Probabilities (Plug-in) with JAX and Plotly",
    title_x=0.5,
    xaxis_title="$x_1$",
    yaxis_title="$x_2$",
    height=600,
    width=800,
    template="plotly_white",
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="center",
        x=0.5,
        font=dict(size=14),
    ),
)
fig_prob.show()


### What Does the Contour Plot Show?

The contour plot you see above visualizes the **predicted probability** of class $+1$ (i.e., $P(y=1|x)$) across the 2D input space for the logistic regression model trained on your data.

#### Key Elements of the Plot

- **Background Colors (Contours):**
    - Each color represents a different predicted probability value for $P(y=1|x)$.
    - Redder regions (or one end of the color scale) indicate higher probabilities of class $+1$.
    - Bluer regions (or the other end) indicate higher probabilities of class $-1$.
    - The color smoothly transitions, showing how the model's confidence changes across the input space.

- **Black Line (Decision Boundary):**
    - The thick black contour corresponds to $P(y=1|x) = 0.5$.
    - This line represents the set of points where the model is equally uncertain between class $+1$ and class $-1$; in other words, the model's "best guess" at the boundary between the two classes.

- **Data Points:**
    - The scatter points overlaying the contours show the actual training data, colored by their true class.
    - This helps you see how well the decision boundary separates the two classes and where the model might be uncertain.

#### How to Interpret the Plot

- Regions far from the decision boundary (deep red or blue) indicate high model confidence in predicting a particular class.
- Regions near the decision boundary (where the color is more neutral) indicate uncertainty, with predicted probabilities close to $0.5$.
- The shape of the decision boundary reflects the flexibility of the model and the influence of the kernel (RBF in this case).

This visualization is a powerful way to understand not just the model's predictions, but also its uncertainty and the effect of the chosen kernel and hyperparameters.

### Why Does the Model Become Uncertain Far from the Data?

In the contour plot, you may notice that the model is **most confident** (probabilities close to 0 or 1) near the training data, especially away from the decision boundary. However, **far from the data**, even far from the boundary, the predicted probabilities tend to approach 0.5, indicating **high uncertainty**.

#### **Why Does This Happen?**

- **Gaussian Process Prior:**  
    The GP prior assumes that, in regions where there are no data points, the latent function $f(x)$ reverts to its prior mean (often zero). The model has no evidence to push the prediction toward either class, so it defaults to being maximally uncertain.

- **Sigmoid Link Function:**  
    The sigmoid function $\sigma(f(x))$ maps $f(x) = 0$ to a probability of 0.5. Thus, in regions where $f(x)$ is close to zero (i.e., where the model is unsure), the predicted probability is 0.5.

- **Kernel Influence:**  
    The RBF kernel causes the influence of each data point to decay rapidly with distance. Far from any data, the kernel values are near zero, so the posterior mean of $f(x)$ is close to the prior mean.
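
The reversion to the prior can be illustrated on a tiny example. This is a sketch under stated assumptions: the two training inputs and the latent mode `hat_f` are made up rather than fitted, and the predictive mean uses the same plug-in formula $k_{*X} K_{XX}^{-1} \hat{f}$ with zero prior mean.

```python
import jax.numpy as jnp


def rbf_kernel(X1, X2, length_scale=1.0):
    """Radial Basis Function (RBF) kernel, as defined earlier in the notebook."""
    sqdist = jnp.sum(X1**2, 1)[:, None] + jnp.sum(X2**2, 1) - 2 * X1 @ X2.T
    return jnp.exp(-0.5 * sqdist / length_scale**2)


def sigmoid(f):
    """Logistic sigmoid, as defined in the text."""
    return 1.0 / (1.0 + jnp.exp(-f))


# Hypothetical training inputs and a made-up latent mode at those inputs.
X_train = jnp.array([[-1.0, 0.0], [1.0, 0.0]])
hat_f = jnp.array([-2.0, 2.0])
K_XX = rbf_kernel(X_train, X_train) + 1e-6 * jnp.eye(2)

# Test points at increasing distance from the data along the x_1 axis.
X_test = jnp.array([[1.0, 0.0], [3.0, 0.0], [10.0, 0.0]])
k_star = rbf_kernel(X_test, X_train)

# Plug-in predictive mean and the resulting class probabilities.
mean_f = k_star @ jnp.linalg.solve(K_XX, hat_f)
probs = sigmoid(mean_f)

print(k_star)
print(probs)
```

As the test point moves away from the training inputs, the rows of `k_star` shrink toward zero, the latent mean shrinks with them, and the predicted probability approaches 0.5.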

#### **Summary Table**

| Region                  | Model Confidence | Reason                                      |
|-------------------------|------------------|---------------------------------------------|
| Near data, away from boundary | High (close to 0 or 1) | Strong evidence from nearby labeled data    |
| Near decision boundary  | Low (close to 0.5) | Model is uncertain between classes          |
| Far from all data       | Low (close to 0.5) | No evidence; model reverts to prior         |

#### **Key Takeaway**

> **In Gaussian Process classification, the model is only confident where it has seen data. Far from the data, it expresses uncertainty, which is a desirable property in probabilistic modeling.**

This behavior is a hallmark of Bayesian models: **uncertainty increases in regions with little or no data**.

#### Exercises

**Exercise 1: Impact of Kernel Length Scale**
Modify the `length_scale` parameter in the `rbf_kernel` and observe its effect on the learned decision boundary. How does a very small `length_scale` differ from a very large one? Explain why.

**Exercise 2: Non-Separable Data**
Change the `generate_data` type to `'overlapping'` and `'intermingled'`. Rerun the mode-finding and plotting. How does the decision boundary change? Note that the RBF kernel yields a nonlinear decision boundary. What would be the limitations of a purely linear boundary (for example, from a latent function $f(x) = w^T x + b$) on these datasets?

**Exercise 3: Iteration Count and Convergence**
Experiment with the `max_iter` and `tol` parameters in `find_map_f_newton_jax`. How do they affect the convergence speed and the quality of the solution? What happens if `max_iter` is too small?

**Exercise 4: Exploring the Log Posterior**
For a very simple 1D classification problem with two data points, plot the log posterior function. Can you visually verify its concavity (which allows Newton's method to work well)? *(Hint: This is an advanced exercise. You'll need to define a 1D problem and plot a 2D surface for $\log p(f_1, f_2 | y_1, y_2)$.)*