# Notebook 3: GP Classification with Non-Conjugate Likelihoods

This notebook demonstrates Gaussian Process Classification using non-conjugate likelihoods.

**Learning objectives:**
- Understand non-conjugate likelihoods (Bernoulli, Poisson)
- Compare Gauss-Hermite and Monte Carlo estimators of the expected log-likelihood
- Perform binary classification (multi-class is outlined as an extension)
- Visualize decision boundaries and uncertainty

## Setup

In [None]:
# Enable auto-reload for development
%load_ext autoreload
%autoreload 2

# Fix import path
import sys
if '..' not in sys.path:
    sys.path.insert(0, '..')

import os
os.environ['JAX_PLATFORMS'] = 'cpu'  # must be set before JAX initialises a backend

import jax
jax.config.update('jax_platforms', 'cpu')

import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

from infodynamics_jax.core import Phi
from infodynamics_jax.gp.kernels.params import KernelParams
from infodynamics_jax.gp.kernels.rbf import rbf as rbf_kernel
from infodynamics_jax.gp.likelihoods import get as get_likelihood
from infodynamics_jax.energy import InertialEnergy, InertialCFG
from infodynamics_jax.inference.optimisation import TypeII, TypeIICFG
from infodynamics_jax.infodynamics import run, RunCFG
from infodynamics_jax.gp.ansatz.state import VariationalState
from infodynamics_jax.gp.ansatz.expected import qfi_from_qu_full

print(f"JAX version: {jax.__version__}")

## 1. Understanding Non-Conjugate Likelihoods

### Conjugate vs. Non-Conjugate

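**Conjugate (Gaussian likelihood)**:
- Posterior is Gaussian and available in closed form
- Fast, exact inference
- Restricted to real-valued targets (regression)

**Non-Conjugate (Bernoulli, Poisson, etc.)**:
- No closed-form posterior
- Requires approximate inference. With a Gaussian approximate posterior $q(f)$, the expected log-likelihood $\mathbb{E}_{q(f)}[\log p(y|f)]$ still has no closed form and is estimated with:
  - **Gauss-Hermite (GH)**: deterministic quadrature
  - **Monte Carlo (MC)**: stochastic sampling
- Enables classification, count data, etc.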
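
In the standard sparse variational GP formulation, which we assume (the notebook does not spell it out) is what `InertialEnergy` builds on, the approximate posterior has Gaussian marginals $q(f_n) = \mathcal{N}(m_n, v_n)$, and the non-conjugate part of the objective is a sum of one-dimensional expectations. The two estimators approximate each of them as follows:

$$
\mathcal{L}(q) = \sum_{n=1}^{N} \mathbb{E}_{q(f_n)}\left[\log p(y_n | f_n)\right] - \mathrm{KL}\left(q(\mathbf{u}) \,\|\, p(\mathbf{u})\right)
$$

$$
\text{GH:} \quad \mathbb{E}_{q(f_n)}[\log p(y_n | f_n)] \approx \frac{1}{\sqrt{\pi}} \sum_{i=1}^{K} w_i \log p\left(y_n | m_n + \sqrt{2 v_n}\, t_i\right)
$$

$$
\text{MC:} \quad \mathbb{E}_{q(f_n)}[\log p(y_n | f_n)] \approx \frac{1}{S} \sum_{j=1}^{S} \log p\left(y_n | m_n + \sqrt{v_n}\, \varepsilon_j\right), \quad \varepsilon_j \sim \mathcal{N}(0, 1)
$$

Here $(t_i, w_i)$ are the $K$ Gauss-Hermite nodes and weights for the weight function $e^{-t^2}$ (presumably what `gh_n` controls), and $S$ is the number of Monte Carlo samples.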

### Available Non-Conjugate Likelihoods

1. **Bernoulli**: Binary classification
   - $p(y=1|f) = \sigma(f)$, where $\sigma(f) = 1/(1 + e^{-f})$ is the logistic sigmoid
   
2. **Poisson**: Count data
   - $p(y|f) = \text{Poisson}(y | \lambda = \exp(f))$
   
3. **Negative Binomial**: Over-dispersed counts
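   - Adds a dispersion parameter so the variance can exceed the mean; Poisson is the limiting case as the over-dispersion vanishes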
   
4. **Ordinal**: Ordered categorical data
   - For ratings, severity levels, etc.
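
As an illustration, the Bernoulli and Poisson log-likelihoods can be written in a numerically stable form with JAX. This is a minimal sketch: `bernoulli_logpdf` and `poisson_logpdf` are hypothetical helpers written for this notebook, not the `infodynamics_jax` implementations, whose internals we have not checked.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

# Hypothetical helpers, not the infodynamics_jax likelihood objects.


def bernoulli_logpdf(y, f):
    """log p(y | f) for y in {0, 1}, with p(y=1 | f) = sigmoid(f).

    y * f - softplus(f) equals log sigmoid(f) when y = 1 and
    log(1 - sigmoid(f)) when y = 0, without forming sigmoid(f) explicitly.
    """
    return y * f - jax.nn.softplus(f)


def poisson_logpdf(y, f):
    """log Poisson(y | lambda = exp(f)) = y * f - exp(f) - log(y!)."""
    return y * f - jnp.exp(f) - gammaln(y + 1.0)


f = jnp.array([-3.0, 0.0, 3.0])
print(bernoulli_logpdf(1.0, f))  # log p(y=1 | f) at each f
print(poisson_logpdf(2.0, f))    # log p(y=2 | f) at each f
```

Working on the log scale with `softplus` avoids computing `log(sigmoid(f))` directly, which underflows for large negative `f`.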

## 2. Generate Binary Classification Data

In [None]:
key = jax.random.key(456)

# Generate 2D data
N_train = 100

# Class 0: cluster around (-2, -2)
key, subkey = jax.random.split(key)
X_class0 = jax.random.normal(subkey, (N_train // 2, 2)) * 0.8 + jnp.array([-2.0, -2.0])

# Class 1: cluster around (2, 2)
key, subkey = jax.random.split(key)
X_class1 = jax.random.normal(subkey, (N_train // 2, 2)) * 0.8 + jnp.array([2.0, 2.0])

# Combine
X_train = jnp.vstack([X_class0, X_class1])
Y_train = jnp.concatenate([
    jnp.zeros(N_train // 2),
    jnp.ones(N_train // 2)
])

# Shuffle
key, subkey = jax.random.split(key)
perm = jax.random.permutation(subkey, N_train)
X_train = X_train[perm]
Y_train = Y_train[perm]

print(f"Training set: {N_train} points")
print(f"X shape: {X_train.shape}, Y shape: {Y_train.shape}")
print(f"Class 0: {jnp.sum(Y_train == 0)} points")
print(f"Class 1: {jnp.sum(Y_train == 1)} points")

In [None]:
# Visualize training data
plt.figure(figsize=(8, 6))
plt.scatter(X_train[Y_train == 0, 0], X_train[Y_train == 0, 1], 
           c='blue', s=50, alpha=0.6, label='Class 0', edgecolors='k')
plt.scatter(X_train[Y_train == 1, 0], X_train[Y_train == 1, 1], 
           c='red', s=50, alpha=0.6, label='Class 1', edgecolors='k')
plt.xlabel('X1', fontsize=12)
plt.ylabel('X2', fontsize=12)
plt.title('Binary Classification Data', fontsize=14, fontweight='bold')
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.show()

## 3. Train GP Classifier with Bernoulli Likelihood

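We use Gauss-Hermite quadrature to approximate the expected Bernoulli log-likelihood. Each term $\mathbb{E}_{q(f_n)}[\log p(y_n|f_n)]$ is a one-dimensional integral over a Gaussian marginal of the latent function, so a modest number of quadrature points (`gh_n=20` below) is typically enough.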
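
The following standalone sketch shows the idea behind a Gauss-Hermite estimator for the Bernoulli case, given marginal means `m` and variances `v` of the latent function. The helper names are hypothetical, and we assume (without having checked the library source) that the `gh` estimator in `InertialCFG` computes something equivalent:

```python
import numpy as np
import jax
import jax.numpy as jnp


def bernoulli_log_lik(y, f):
    # log p(y | f) with p(y=1 | f) = sigmoid(f), written stably
    return y * f - jax.nn.softplus(f)


def expected_log_lik_gh(log_lik, y, m, v, n_points=20):
    """Gauss-Hermite estimate of E_{f ~ N(m, v)}[log_lik(y, f)], elementwise.

    Substituting f = m + sqrt(2 v) t turns the Gaussian expectation into
    (1 / sqrt(pi)) * sum_i w_i * log_lik(y, m + sqrt(2 v) t_i).
    """
    # Nodes/weights for the weight function exp(-t^2) (physicists' Hermite)
    t, w = np.polynomial.hermite.hermgauss(n_points)
    t, w = jnp.asarray(t), jnp.asarray(w)
    # Broadcast each (m, v) pair against all nodes: shape (..., n_points)
    f = m[..., None] + jnp.sqrt(2.0 * v[..., None]) * t
    return jnp.sum(w * log_lik(y[..., None], f), axis=-1) / jnp.sqrt(jnp.pi)


# Two data points with the same latent marginal N(0.5, 1.0) but different labels
y = jnp.array([1.0, 0.0])
m = jnp.array([0.5, 0.5])
v = jnp.array([1.0, 1.0])
print(expected_log_lik_gh(bernoulli_log_lik, y, m, v))
```

With `v` set to zero the estimate collapses to the plug-in value `bernoulli_log_lik(y, m)`, since the weights sum to $\sqrt{\pi}$; a larger `n_points` trades compute for accuracy.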

In [None]:
# Get Bernoulli likelihood
bernoulli_likelihood = get_likelihood("bernoulli")

# Initialize kernel parameters
kernel_params = KernelParams(
    lengthscale=jnp.array(1.0),
    variance=jnp.array(1.0),
)

# Create inducing points on a regular n_per_dim x n_per_dim grid
n_per_dim = 4
Z_x1 = jnp.linspace(X_train[:, 0].min(), X_train[:, 0].max(), n_per_dim)
Z_x2 = jnp.linspace(X_train[:, 1].min(), X_train[:, 1].max(), n_per_dim)
Z = jnp.stack(jnp.meshgrid(Z_x1, Z_x2), axis=-1).reshape(-1, 2)
M = Z.shape[0]  # n_per_dim**2 = 16 inducing points

# Create Phi (no noise variance for Bernoulli)
phi_init = Phi(
    kernel_params=kernel_params,
    Z=Z,
    likelihood_params={},  # Bernoulli has no extra parameters
    jitter=1e-5,
)

print(f"Number of inducing points: {len(Z)}")
print(f"Kernel lengthscale: {phi_init.kernel_params.lengthscale}")
print(f"Kernel variance: {phi_init.kernel_params.variance}")

In [None]:
# Create InertialEnergy with Gauss-Hermite estimator
inertial_cfg = InertialCFG(
    estimator="gh",  # Gauss-Hermite quadrature
    gh_n=20,         # Number of quadrature points
    inner_steps=0,   # No inner optimization
)

inertial_energy = InertialEnergy(
    kernel_fn=rbf_kernel,
    likelihood=bernoulli_likelihood,
    cfg=inertial_cfg,
)

print("InertialEnergy created with Bernoulli likelihood!")
print(f"Estimator: {inertial_cfg.estimator}")
print(f"GH quadrature points: {inertial_cfg.gh_n}")

In [None]:
# Configure and run TypeII
typeii_cfg = TypeIICFG(
    steps=150,
    lr=1e-2,
    optimizer="adam",
    jit=True,
    constrain_params=True,
)

method = TypeII(cfg=typeii_cfg)

key, subkey = jax.random.split(key)
out = run(
    key=subkey,
    method=method,
    energy=inertial_energy,
    phi_init=phi_init,
    energy_args=(X_train, Y_train),
    cfg=RunCFG(jit=True),
)

phi_opt = out.result.phi
energy_trace = out.result.energy_trace

print("\nOptimization complete!")
print(f"Final energy: {energy_trace[-1]:.2f}")
print(f"Energy reduction: {energy_trace[0] - energy_trace[-1]:.2f}")
print(f"\nOptimized hyperparameters:")
print(f"  Lengthscale: {float(phi_opt.kernel_params.lengthscale):.3f}")
print(f"  Variance: {float(phi_opt.kernel_params.variance):.3f}")

In [None]:
# Plot optimization trace
plt.figure(figsize=(10, 4))
plt.plot(energy_trace, 'b-', linewidth=2)
plt.xlabel('Iteration', fontsize=12)
plt.ylabel('Energy', fontsize=12)
plt.title('MAP-II Optimization Trace (Bernoulli Likelihood)', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.show()

## 4. Make Predictions and Visualize Decision Boundary

In [None]:
# Create grid for prediction
resolution = 100
x1_min, x1_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1
x2_min, x2_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1
xx1, xx2 = jnp.meshgrid(
    jnp.linspace(x1_min, x1_max, resolution),
    jnp.linspace(x2_min, x2_max, resolution)
)
X_grid = jnp.c_[xx1.ravel(), xx2.ravel()]

print(f"Grid shape: {X_grid.shape}")

In [None]:
# Compute posterior state
state = VariationalState.initialise(phi_opt, X_train, Y_train)

# Make predictions on grid
mu_grid, var_grid = qfi_from_qu_full(
    phi_opt, X_grid, rbf_kernel, state.m_u, state.L_u
)

mu_grid = mu_grid.squeeze()
var_grid = var_grid.squeeze()

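# Plug-in approximation p(y=1|x) ≈ sigmoid(mu(x)): it ignores var_grid, so probabilities
# are overconfident where the latent variance is large (the 0.5 boundary is unaffected)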
prob_class1 = jax.nn.sigmoid(mu_grid)
prob_class1 = prob_class1.reshape(xx1.shape)

print(f"Predictions computed on {len(X_grid)} grid points")
print(f"Probability range: [{prob_class1.min():.3f}, {prob_class1.max():.3f}]")
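
The plug-in probability above ignores `var_grid`. Below is a hedged sketch of two standard alternatives that average the sigmoid over the latent Gaussian: Gauss-Hermite quadrature of $\mathbb{E}[\sigma(f)]$, and MacKay's closed-form approximation $\sigma\big(\mu / \sqrt{1 + \pi v / 8}\big)$, where $v$ is the latent variance. Both helpers are hypothetical, not part of `infodynamics_jax`, and the inputs are toy values:

```python
import numpy as np
import jax
import jax.numpy as jnp


def predictive_prob_gh(mu, var, n_points=20):
    """p(y=1 | x) = E_{f ~ N(mu, var)}[sigmoid(f)] via Gauss-Hermite quadrature."""
    t, w = np.polynomial.hermite.hermgauss(n_points)
    f = mu[..., None] + jnp.sqrt(2.0 * var[..., None]) * jnp.asarray(t)
    return jnp.sum(jnp.asarray(w) * jax.nn.sigmoid(f), axis=-1) / jnp.sqrt(jnp.pi)


def predictive_prob_probit_approx(mu, var):
    """MacKay's closed-form approximation sigmoid(mu / sqrt(1 + pi * var / 8))."""
    return jax.nn.sigmoid(mu / jnp.sqrt(1.0 + jnp.pi * var / 8.0))


# Toy latent means with a large, shared latent variance
mu = jnp.array([-2.0, 0.0, 2.0])
var = jnp.array([4.0, 4.0, 4.0])
print("plug-in:", jax.nn.sigmoid(mu))
print("GH     :", predictive_prob_gh(mu, var))
print("probit :", predictive_prob_probit_approx(mu, var))
```

With the notebook's arrays, `predictive_prob_gh(mu_grid, var_grid).reshape(xx1.shape)` could replace `prob_class1`, assuming `var_grid` holds the marginal latent variance as the uncertainty plot implies. Because $\mathbb{E}[\sigma(f)] = 0.5$ exactly when the latent mean is zero, the decision boundary does not move; only the probabilities away from it become less extreme where the variance is large.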

In [None]:
# Visualize decision boundary
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Plot 1: Decision boundary
ax = axes[0]
contour = ax.contourf(xx1, xx2, prob_class1, levels=20, cmap='RdBu_r', alpha=0.8)
ax.contour(xx1, xx2, prob_class1, levels=[0.5], colors='black', linewidths=2)
ax.scatter(X_train[Y_train == 0, 0], X_train[Y_train == 0, 1], 
          c='blue', s=50, alpha=0.8, edgecolors='k', label='Class 0')
ax.scatter(X_train[Y_train == 1, 0], X_train[Y_train == 1, 1], 
          c='red', s=50, alpha=0.8, edgecolors='k', label='Class 1')
ax.scatter(phi_opt.Z[:, 0], phi_opt.Z[:, 1], 
          c='green', marker='x', s=100, linewidths=2, label='Inducing points')
ax.set_xlabel('X1', fontsize=12)
ax.set_ylabel('X2', fontsize=12)
ax.set_title('Decision Boundary (p(y=1))', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
plt.colorbar(contour, ax=ax, label='P(Class 1)')

# Plot 2: Uncertainty of the latent function f (posterior std), not of p(y=1)
ax = axes[1]
std_grid = jnp.sqrt(var_grid).reshape(xx1.shape)
contour = ax.contourf(xx1, xx2, std_grid, levels=20, cmap='viridis', alpha=0.8)
ax.scatter(X_train[Y_train == 0, 0], X_train[Y_train == 0, 1], 
          c='blue', s=50, alpha=0.5, edgecolors='k')
ax.scatter(X_train[Y_train == 1, 0], X_train[Y_train == 1, 1], 
          c='red', s=50, alpha=0.5, edgecolors='k')
ax.set_xlabel('X1', fontsize=12)
ax.set_ylabel('X2', fontsize=12)
ax.set_title('Latent Function Uncertainty (std of f)', fontsize=14, fontweight='bold')
plt.colorbar(contour, ax=ax, label='Std of f')

plt.tight_layout()
plt.show()

## 5. Evaluate Classification Performance

In [None]:
# Make predictions on training set
mu_train, var_train = qfi_from_qu_full(
    phi_opt, X_train, rbf_kernel, state.m_u, state.L_u
)

mu_train = mu_train.squeeze()
prob_train = jax.nn.sigmoid(mu_train)
Y_pred = (prob_train > 0.5).astype(float)

# Compute metrics
accuracy = float(jnp.mean(Y_pred == Y_train))
true_positives = float(jnp.sum((Y_pred == 1) & (Y_train == 1)))
false_positives = float(jnp.sum((Y_pred == 1) & (Y_train == 0)))
true_negatives = float(jnp.sum((Y_pred == 0) & (Y_train == 0)))
false_negatives = float(jnp.sum((Y_pred == 0) & (Y_train == 1)))

precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

print("Classification Metrics (Training Set):")
print(f"  Accuracy:  {accuracy:.3f}")
print(f"  Precision: {precision:.3f}")
print(f"  Recall:    {recall:.3f}")
print(f"  F1 Score:  {f1_score:.3f}")
print(f"\nConfusion Matrix:")
print(f"  TP: {int(true_positives):<3} FP: {int(false_positives):<3}")
print(f"  FN: {int(false_negatives):<3} TN: {int(true_negatives):<3}")
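
Accuracy, precision and recall only look at thresholded labels. Two standard scores that also use the predicted probabilities are the log loss and the Brier score; a minimal sketch with a hypothetical helper `binary_metrics` is below. With this notebook's arrays it could be called as `binary_metrics(Y_train, prob_train)`. Like the metrics above, training-set scores are optimistic; a held-out split gives a fairer estimate.

```python
import numpy as np


def binary_metrics(y_true, p_pred, threshold=0.5, eps=1e-12):
    """Accuracy, mean log loss and Brier score for binary labels in {0, 1}.

    Log loss and Brier score use the probabilities directly, so they reward
    calibrated uncertainty rather than only correct hard labels.
    """
    y_true = np.asarray(y_true, dtype=float)
    # Clip so log(0) cannot occur for saturated predictions
    p = np.clip(np.asarray(p_pred, dtype=float), eps, 1.0 - eps)
    y_hat = (p > threshold).astype(float)
    return {
        "accuracy": float(np.mean(y_hat == y_true)),
        "log_loss": float(-np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))),
        "brier": float(np.mean((p - y_true) ** 2)),
    }


print(binary_metrics([1, 0, 1, 0], [0.9, 0.2, 0.6, 0.4]))
```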

## 6. Comparing Gauss-Hermite vs. Monte Carlo Estimators

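We now retrain the same model with the Monte Carlo estimator, reusing `phi_init` and the same TypeII optimizer, and compare the energy traces with the Gauss-Hermite run.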

In [None]:
# Train with Monte Carlo estimator
inertial_cfg_mc = InertialCFG(
    estimator="mc",      # Monte Carlo
    n_mc_samples=100,    # Number of MC samples
    inner_steps=0,
)

inertial_energy_mc = InertialEnergy(
    kernel_fn=rbf_kernel,
    likelihood=bernoulli_likelihood,
    cfg=inertial_cfg_mc,
)

key, subkey = jax.random.split(key)
out_mc = run(
    key=subkey,
    method=method,
    energy=inertial_energy_mc,
    phi_init=phi_init,
    energy_args=(X_train, Y_train),
    cfg=RunCFG(jit=True),
)

phi_opt_mc = out_mc.result.phi
energy_trace_mc = out_mc.result.energy_trace

print("\nMonte Carlo Estimator Results:")
print(f"Final energy: {energy_trace_mc[-1]:.2f}")
print(f"Optimized lengthscale: {float(phi_opt_mc.kernel_params.lengthscale):.3f}")
print(f"Optimized variance: {float(phi_opt_mc.kernel_params.variance):.3f}")

In [None]:
# Compare optimization traces
plt.figure(figsize=(10, 4))
plt.plot(energy_trace, 'b-', linewidth=2, label='Gauss-Hermite', alpha=0.8)
plt.plot(energy_trace_mc, 'r--', linewidth=2, label='Monte Carlo', alpha=0.8)
plt.xlabel('Iteration', fontsize=12)
plt.ylabel('Energy', fontsize=12)
plt.title('Estimator Comparison', fontsize=14, fontweight='bold')
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.show()

print("\nComparison:")
print(f"GH final energy:  {energy_trace[-1]:.2f}")
print(f"MC final energy:  {energy_trace_mc[-1]:.2f}")
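print("\nNote: the MC energy is a stochastic estimate, so its trace is noisier")
print("      and its final value is only approximately comparable to GH.")
print("      GH is deterministic for fixed parameters and generally more stable.")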
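
The claim that GH is deterministic while MC is noisy can be checked on a single data point. The sketch below is standalone: the helpers are hypothetical, not the library's estimators, and it assumes a Gaussian marginal $q(f) = \mathcal{N}(m, v)$. It measures the spread of the MC estimate over 200 random keys as the number of samples $S$ grows:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical helpers; q(f) = N(m, v) for a single data point.


def bernoulli_log_lik(y, f):
    # log p(y | f) with p(y=1 | f) = sigmoid(f)
    return y * f - jax.nn.softplus(f)


def ell_gh(y, m, v, n_points=20):
    # Deterministic: the same inputs always give the same estimate
    t, w = np.polynomial.hermite.hermgauss(n_points)
    f = m + jnp.sqrt(2.0 * v) * jnp.asarray(t)
    return jnp.sum(jnp.asarray(w) * bernoulli_log_lik(y, f)) / jnp.sqrt(jnp.pi)


def ell_mc(key, y, m, v, n_samples=100):
    # Stochastic: reparameterised samples f = m + sqrt(v) * eps, eps ~ N(0, 1)
    eps = jax.random.normal(key, (n_samples,))
    return jnp.mean(bernoulli_log_lik(y, m + jnp.sqrt(v) * eps))


y, m, v = 1.0, 0.5, 1.0
ref = ell_gh(y, m, v, n_points=50)  # high-order GH as a reference value
keys = jax.random.split(jax.random.key(0), 200)
stds = {}
for S in (10, 100, 1000):
    est = jax.vmap(lambda k: ell_mc(k, y, m, v, n_samples=S))(keys)
    stds[S] = float(jnp.std(est))
    print(f"S={S:5d}  mean error={float(jnp.mean(est) - ref):+.4f}  std={stds[S]:.4f}")
```

The standard deviation should shrink roughly as $1/\sqrt{S}$, whereas GH returns the same value on every call; this is the trade-off summarised at the end of the notebook.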

## Summary

In this notebook, we demonstrated GP classification with non-conjugate likelihoods:

### Key Concepts

1. **Non-Conjugate Likelihoods**: Bernoulli for binary classification
2. **Approximation Methods**:
   - **Gauss-Hermite**: Deterministic, stable, good for low-dimensional integrals
   - **Monte Carlo**: Stochastic, scales to high dimensions, noisier gradients

3. **GP Classification**:
   - Latent function $f(x)$ modeled by GP
   - Probability: $p(y=1|x) = \sigma(f(x))$
   - Natural uncertainty quantification

### When to Use Each Estimator

| Estimator | Pros | Cons | Best For |
|-----------|------|------|----------|
| **Gauss-Hermite** | Deterministic, stable | Needs $K^d$ nodes for a $d$-dimensional integral ($K$ per dimension) | Single-output likelihoods (binary classification, counts) |
| **Monte Carlo** | Scales to high dims | Noisy, slower convergence | Multi-output, complex models |

### Extensions

- **Multi-class**: Use softmax likelihood
- **Count data**: Poisson or Negative Binomial
- **Ordinal**: Ordered categories (ratings, severity)
- **Robust**: Student-t for outlier resistance
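
For the multi-class extension, the softmax likelihood couples $C$ latent functions, so $\mathbb{E}_q[\log p(y|\mathbf{f})]$ becomes a $C$-dimensional integral; this is where tensor-product GH gets expensive and MC is the usual choice. A minimal sketch, assuming a factorised Gaussian $q(\mathbf{f}) = \mathcal{N}(\mathbf{m}, \operatorname{diag}(\mathbf{v}))$ per data point (a simplifying assumption; the helper is hypothetical and not part of `infodynamics_jax`):

```python
import jax
import jax.numpy as jnp


def expected_log_softmax_mc(key, y, m, v, n_samples=256):
    """MC estimate of E_{f ~ N(m, diag(v))}[log softmax(f)[y]] for one point.

    m, v: (C,) means and variances of the C latent functions; y: int class label.
    Hypothetical helper assuming a factorised Gaussian over the C latents.
    """
    eps = jax.random.normal(key, (n_samples, m.shape[0]))
    f = m + jnp.sqrt(v) * eps  # (n_samples, C) reparameterised samples
    return jnp.mean(jax.nn.log_softmax(f, axis=-1)[:, y])


m = jnp.array([2.0, 0.0, -1.0])
v = jnp.array([0.5, 0.5, 0.5])
print(expected_log_softmax_mc(jax.random.key(0), 0, m, v))
```

For $N$ data points this would be `vmap`ped over rows of `m` and `v`; GH would instead need $K^C$ nodes per point for $C$ classes.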