# Bayesian Hybrid Models for Enzyme Kinetics - Group Exercise
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jonathon-langford/aims-inference-2026/blob/main/3-Bayesian-models-for-kinetics/2_group_work_bayesian_kinetics.ipynb)

## Learning Objectives

By the end of this exercise, you will be able to:
1. Build and compare mechanistic vs. hybrid Bayesian models
2. Diagnose MCMC convergence using trace plots, $\hat{R}$, and ESS
3. Use posterior predictive checks to assess model adequacy
4. Visualise and interpret prediction uncertainty

## Overview

You will analyse enzyme kinetics data where the reaction rate constant k depends on experimental conditions: temperature (T) and pH. 

The challenge: Can a hybrid model (mechanistic ODE + Gaussian process) outperform a simple mechanistic model with constant k?

## Task Distribution (6 students, feel free to redistribute as you like)

- **Task 1**: Check convergence for baseline model
- **Task 2**: Posterior predictive checks for baseline
- **Task 3**: Build the hybrid GP model
- **Task 4**: Check convergence and visualize learned k(T, pH)
- **Task 5**: Compare models with posterior predictive checks
- **Task 6**: Experiment with kernels and length scales

In [None]:
# Install dependencies
!pip install pymc arviz matplotlib numpy pandas --quiet

In [None]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm

rng = np.random.default_rng(42)

print("Imports successful")
print(f"PyMC version: {pm.__version__}")
print(f"ArviZ version: {az.__version__}")

---

# SECTION 1: Load and Explore Data

In [None]:
# Load the kinetic data
url = "https://raw.githubusercontent.com/jonathon-langford/aims-inference-2026/4cba750e6017d7d4b236aef029b556e04369fea2/3-Bayesian-models-for-kinetics/kinetic_data.csv"
data = pd.read_csv(url)

print("Data Overview:")
print(f"Total observations: {len(data)}")
print(f"Unique samples: {data['SampleID'].nunique()}")
print(f"Time points per sample: {data.groupby('SampleID').size().iloc[0]}")
print(f"\nTemperature range: {data['T'].min():.1f} - {data['T'].max():.1f} °C")
print(f"pH range: {data['pH'].min():.2f} - {data['pH'].max():.2f}")
print(f"\nFirst few rows:")
data.head(20)

### Visualise sample time courses at different conditions

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

sample_ids = rng.choice(data["SampleID"].unique(), size=6, replace=False)
for idx, sid in enumerate(sample_ids):
    ax = axes[idx]
    subset = data[data["SampleID"] == sid]
    T_val = subset["T"].iloc[0]
    pH_val = subset["pH"].iloc[0]
    
    ax.plot(subset["Time"], subset["P_obs"], "o-", alpha=0.7)
    ax.set_xlabel("Time (min)")
    ax.set_ylabel("Product [P]")
    ax.set_title(f"Sample {sid}: T={T_val:.1f}°C, pH={pH_val:.2f}")
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Questions:")
print("1. Do all samples reach the same plateau? (Hint: use sharey in plots)")
print("2. Do some samples show faster reaction rates than others?")

### Visualise distribution of experimental conditions

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

sample_conditions = data.groupby("SampleID")[["T", "pH"]].first()

axes[0].hist(sample_conditions["T"], bins=10, edgecolor="black")
axes[0].set_xlabel("Temperature (°C)")
axes[0].set_ylabel("Count")
axes[0].set_title("Temperature Distribution")

axes[1].hist(sample_conditions["pH"], bins=10, edgecolor="black")
axes[1].set_xlabel("pH")
axes[1].set_ylabel("Count")
axes[1].set_title("pH Distribution")

axes[2].scatter(sample_conditions["T"], sample_conditions["pH"], alpha=0.6)
axes[2].set_xlabel("Temperature (°C)")
axes[2].set_ylabel("pH")
axes[2].set_title("Experimental Design Space")
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

# SECTION 2: Reshape Data for Modeling

We reshape the data into matrices for efficient modeling.
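
As a cross-check on the reshape, the same samples × timepoints layout can be produced with a pandas pivot. The sketch below runs on a small made-up table; the `toy` frame and its values are illustrative assumptions, not the course data, and only the column names mirror `kinetic_data.csv`.

```python
import numpy as np
import pandas as pd

# Toy long-format table: 2 samples x 3 timepoints (values are made up)
toy = pd.DataFrame({
    "SampleID": [1, 1, 1, 2, 2, 2],
    "Time": [0.0, 5.0, 10.0, 0.0, 5.0, 10.0],
    "P_obs": [0.0, 1.9, 3.1, 0.1, 2.8, 4.0],
})

# Rows become samples, columns become sorted timepoints
toy_matrix = toy.pivot(index="SampleID", columns="Time", values="P_obs").to_numpy()

print(toy_matrix.shape)
# A missing (SampleID, Time) pair would show up as NaN
print(np.isnan(toy_matrix).any())
```

Note that a pivot orders rows by sorted `SampleID`, which may differ from the order of first appearance, so the row order should be checked before mixing the two approaches.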

In [None]:
# Extract unique samples and times
unique_samples = data[["SampleID", "T", "pH"]].drop_duplicates().reset_index(drop=True)
unique_times = np.sort(data["Time"].unique())

n_samples = len(unique_samples)
n_timepoints = len(unique_times)

print(f"Number of unique samples: {n_samples}")
print(f"Number of unique timepoints: {n_timepoints}")

# Map SampleID to index
sample_index = {sid: i for i, sid in enumerate(unique_samples["SampleID"])}
data["sample_idx"] = data["SampleID"].map(sample_index)

# Create observation matrix (samples x timepoints)
P_obs_matrix = np.full((n_samples, n_timepoints), np.nan)
for i, sid in enumerate(unique_samples["SampleID"]):
    subset = data[data["SampleID"] == sid].sort_values("Time")
    P_obs_matrix[i, :] = subset["P_obs"].values

print(f"\nObservation matrix shape: {P_obs_matrix.shape}")

# Normalise features for GP
T_mean = unique_samples["T"].mean()
T_std_val = unique_samples["T"].std()
pH_mean = unique_samples["pH"].mean()
pH_std_val = unique_samples["pH"].std()

T_normalized = (unique_samples["T"] - T_mean) / T_std_val
pH_normalized = (unique_samples["pH"] - pH_mean) / pH_std_val

X = np.vstack([T_normalized, pH_normalized]).T

print(f"Input matrix X shape: {X.shape}")
print(f"\nData prepared for modeling")

---

# SECTION 3: Baseline Model (Constant k)

## Model Specification

We start with a naive baseline that assumes k is constant across all conditions:

```
k_global ~ HalfNormal(sigma=0.1)
P(t) = S0 * (1 - exp(-k_global * t))
sigma_obs ~ HalfNormal(sigma=0.2)
P_obs ~ Normal(P(t), sigma_obs)
```

This model is provided for you.
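
To get a feel for what this specification implies before fitting it, the generative model can be simulated forward with NumPy. This is only a sketch: the time grid, the seed, and the helper name `simulate_baseline` are assumptions for illustration, while `S0 = 5.0` and the prior scales follow the specification above.

```python
import numpy as np

def simulate_baseline(times, S0=5.0, seed=0):
    """Draw one synthetic time course from the constant-k baseline model."""
    rng = np.random.default_rng(seed)
    # HalfNormal(sigma) draws are absolute values of Normal(0, sigma) draws
    k_sim = np.abs(rng.normal(0.0, 0.1))
    sigma_sim = np.abs(rng.normal(0.0, 0.2))
    # First-order kinetics: product approaches S0 exponentially
    P_true = S0 * (1.0 - np.exp(-k_sim * times))
    # Add Gaussian observation noise
    P_noisy = rng.normal(P_true, sigma_sim)
    return k_sim, sigma_sim, P_true, P_noisy

times_sim = np.linspace(0.0, 60.0, 13)  # hypothetical grid in minutes
k_sim, sigma_sim, P_true, P_noisy = simulate_baseline(times_sim)
print(f"k={k_sim:.3f}, sigma_obs={sigma_sim:.3f}")
print(np.round(P_true, 2))
```

Repeating the draw with different seeds shows which curve shapes the priors consider plausible. Every curve shares the same plateau `S0` and differs only in how fast it gets there.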

In [None]:
S0 = 5.0

coords = {
    "samples": unique_samples["SampleID"].values,
    "time": unique_times,
}

with pm.Model(coords=coords) as model_baseline:
    time_values = pm.Data("time_values", unique_times, dims="time")
    
    k_global = pm.HalfNormal("k_global", sigma=0.1)
    sigma_obs = pm.HalfNormal("sigma_obs", sigma=0.2)
    
    P_pred = S0 * (1 - pm.math.exp(-k_global * time_values))
    
    P_obs = pm.Normal(
        "P_obs",
        mu=P_pred,
        sigma=sigma_obs,
        observed=P_obs_matrix,
        dims=("samples", "time")
    )

pm.model_to_graphviz(model_baseline)

In [None]:
with model_baseline:
    trace_baseline = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        random_seed=42,
        return_inferencedata=True,
        idata_kwargs={"log_likelihood": True}
    )

print("Baseline sampling complete")

---

# SECTION 4: Group Tasks

## TASK 1: Check Convergence for Baseline Model

**Student 1**: Assess whether MCMC sampling worked correctly.
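
To make the $\hat{R}$ diagnostic less of a black box, here is a minimal NumPy sketch of the classic split $\hat{R}$ (between-chain versus within-chain variance). ArviZ reports a rank-normalised refinement, so its numbers may differ slightly. The function name `split_rhat` and the synthetic chains are assumptions for illustration.

```python
import numpy as np

def split_rhat(chains):
    """Classic split R-hat for an array of shape (n_chains, n_draws)."""
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    # Split each chain in two so within-chain drift also inflates R-hat
    split = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = split.shape[1]
    chain_means = split.mean(axis=1)
    W = split.var(axis=1, ddof=1).mean()   # within-chain variance
    B = n * chain_means.var(ddof=1)        # between-chain variance
    var_plus = (n - 1) / n * W + B / n     # pooled variance estimate
    return np.sqrt(var_plus / W)

demo_rng = np.random.default_rng(0)
good = demo_rng.normal(size=(4, 1000))                  # chains agree
bad = good + np.array([0.0, 0.0, 2.0, 2.0])[:, None]    # two chains shifted
print(f"well-mixed chains: {split_rhat(good):.3f}")
print(f"disagreeing chains: {split_rhat(bad):.3f}")
```

When chains explore the same distribution the ratio sits close to 1. Chains stuck in different regions push it well above the 1.01 threshold used in this task.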

In [None]:
# TODO: Print summary statistics using az.summary()
# Check: Is R-hat < 1.01? Is ESS > 400?

summary = az.summary(trace_baseline, var_names=["k_global", "sigma_obs"])
print(summary)

# TODO: Plot trace plots using az.plot_trace()
# Look for: "fuzzy caterpillar", stationarity, chain overlap

# TODO: Plot posterior distributions using arviz
# What is the posterior mean for k_global?

## TASK 2: Posterior Predictive Checks for Baseline

**Student 2**: Generate posterior predictive samples and assess model adequacy.
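
Alongside the visual check, model adequacy can be summarised numerically as the fraction of observations that fall inside a predictive band. The sketch below does this on synthetic arrays; `ppc_demo` and `obs_demo` are made-up stand-ins with an assumed (draws, samples, timepoints) layout, not the notebook's results.

```python
import numpy as np

demo_rng = np.random.default_rng(1)

# Synthetic stand-ins: 2000 predictive draws, 5 curves, 8 timepoints
n_draws, n_curves, n_times = 2000, 5, 8
truth = np.linspace(0.0, 4.0, n_times)
ppc_demo = truth + demo_rng.normal(0.0, 0.3, size=(n_draws, n_curves, n_times))
obs_demo = truth + demo_rng.normal(0.0, 0.3, size=(n_curves, n_times))

# 5th and 95th percentiles over the draw axis give a 90% predictive band
lower, upper = np.percentile(ppc_demo, [5, 95], axis=0)

# Fraction of observations inside the band, overall and per curve
inside = (obs_demo >= lower) & (obs_demo <= upper)
print(f"overall coverage: {inside.mean():.2f}")
print("per-curve coverage:", inside.mean(axis=1))
```

For a well-calibrated model the overall coverage should be close to the nominal 90%. Curves with much lower coverage point to samples the model over- or under-predicts.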

In [None]:
# TODO: Generate posterior predictive samples
# Hint: Use pm.sample_posterior_predictive(trace_baseline)

with model_baseline:
    ppc_baseline = ...

# Extract predictions
ppc_samples = ...
ppc_samples_flat = ppc_samples.reshape(-1, n_samples, n_timepoints) # this helps you to separate the dimensions later

# Plotting code provided (feel free to use arviz instead if you prefer)
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
sample_ids_plot = rng.choice(range(n_samples), size=6, replace=False)

for idx, sid in enumerate(sample_ids_plot):
    ax = axes[idx // 3, idx % 3]
    
    T_val = unique_samples.loc[sid, "T"]
    pH_val = unique_samples.loc[sid, "pH"]
    
    ppc_sample = ppc_samples_flat[:, sid, :]
    
    for pct in [5, 25, 50, 75, 95]:
        alpha = 0.3 if pct in [5, 95] else 0.5
        ax.plot(unique_times, np.percentile(ppc_sample, pct, axis=0), 
               alpha=alpha, color="blue")
    
    ax.scatter(unique_times, P_obs_matrix[sid, :], color="red", s=30, zorder=10, label="Observed")
    
    ax.set_xlabel("Time (min)")
    ax.set_ylabel("Product [P]")
    ax.set_title(f"Sample {sid}: T={T_val:.1f}°C, pH={pH_val:.2f}")
    # Note: sid is a row index into unique_samples here, not the SampleID itself
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle("Baseline Model: Posterior Predictive Checks", fontsize=14, y=1.00)
plt.tight_layout()
plt.show()

# TODO: Answer these questions:
# 1. Does the model capture all the variation in the data?
# 2. Are some samples consistently over/under-predicted?
# 3. What does this tell you about the constant-k assumption?

## TASK 3: Build the Hybrid GP Model

**Student 3**: Construct a Gaussian Process model for k(T, pH).

The data is prepared (X matrix with normalized T and pH). Build the GP model following this structure.

Question: Why do you need a latent GP here?
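
Before writing the PyMC model, it can help to see the same structure in plain NumPy: a squared-exponential kernel with one length scale per input, one prior draw of log k, and the broadcast that turns a vector of rate constants into a samples × timepoints matrix. This is an illustrative sketch, not the PyMC solution; `rbf_kernel`, the `*_demo` inputs, and the hyperparameter values are assumptions.

```python
import numpy as np

def rbf_kernel(X1, X2, ls, eta):
    """Squared-exponential kernel with one length scale per input column."""
    diff = (X1[:, None, :] - X2[None, :, :]) / ls   # (n1, n2, n_features)
    return eta**2 * np.exp(-0.5 * np.sum(diff**2, axis=-1))

demo_rng = np.random.default_rng(2)
X_demo = demo_rng.normal(size=(6, 2))     # stand-in for normalised (T, pH)
ls_demo = np.array([1.0, 0.5])            # hypothetical length scales
K = rbf_kernel(X_demo, X_demo, ls_demo, eta=0.5)
K += 1e-8 * np.eye(len(X_demo))           # jitter for numerical stability

# One prior draw of the latent function log k at the 6 conditions
log_k_demo = demo_rng.multivariate_normal(np.full(6, -1.5), K)
k_demo = np.exp(log_k_demo)               # positive by construction

# Broadcast (n_samples, 1) against (1, n_timepoints)
t_demo = np.linspace(0.0, 60.0, 5)
P_demo = 5.0 * (1.0 - np.exp(-k_demo[:, None] * t_demo[None, :]))
print(K.shape, k_demo.shape, P_demo.shape)
```

In this sketch the GP values are never observed directly: they only reach the data through the nonlinear kinetic equation, which is worth keeping in mind when answering the question above.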

In [None]:
coords = {
    "samples": unique_samples["SampleID"].values,
    "time": unique_times,
    "features": ["T", "pH"],
}

with pm.Model(coords=coords) as model_gp:
    # Data containers
    time_values = pm.Data("time_values", unique_times, dims="time")
    normalised_T_pH = pm.Data("normalised_T_pH", X, dims=("samples", "features"))
    
    # TODO: Define GP hyperparameters
    # Hint: Check PyMC documentation for appropriate priors
    # Hint: You could for example use pm.HalfNormal for amplitude (eta)
    
    ls = ...
    eta = ...
    mean_log_k = pm.Normal("mean_log_k", mu=-1.5, sigma=1.0) # Here I provide you with an example
    
    # TODO: Create RBF covariance function (see later exercises to change this)
    # Hint: Don't forget to use input_dim=2 and ls=ls if you want to use individual length scales for T and pH
    
    cov_func = ...  # Your code here
    
    # TODO: Create Latent GP
    # Hint: Use pm.gp.Latent with mean_func and cov_func
    # Hint: Use pm.gp.mean.Constant(c=mean_log_k) for mean function
    
    gp = ...  # Your code here
    
    # TODO: GP prior over log(k) (I would recommend you use it this way to avoid non-positive k samples)
    # Hint: Use gp.prior("log_k",...)
    # While GP input has two dimensions (T, pH), output (k) is one-dimensional
    
    log_k = ...  # Your code here
    
    # TODO: Transform log(k) to k (ensures k > 0)
    # Hint: k = pm.Deterministic("k", ..., dims="samples")
    
    k = ...  # Your code here
    
    # TODO: Observation noise (sigma can be tuned)
    sigma_obs = pm.HalfNormal("sigma_obs", sigma=0.2)
    
    # TODO: Mechanistic model - compute predicted P from k
    # Hint: See yesterday's notebook and baseline model
    # Hint: Use [:, None] and [None, :] to broadcast multiplication (match the dimensions of samples and times etc.)
    # k has shape (n_samples,), time_values has shape (n_timepoints,)
    # Result should have shape (n_samples, n_timepoints)
    
    P_pred = pm.Deterministic(
        "P_pred",
        ...,  # Your code here: S0 * (1 - pm.math.exp(...))
        dims=("samples", "time")
    )
    
    # Likelihood
    P_obs = pm.Normal(
        "P_obs",
        mu=P_pred,
        sigma=sigma_obs,
        observed=P_obs_matrix,
        dims=("samples", "time")
    )

# Visualize the computation graph
pm.model_to_graphviz(model_gp)

In [None]:
# TODO: Sample from the hybrid model
# Hint: Increase target_accept for better sampling with GPs if necessary
# This will take longer than the baseline model

with model_gp:
    trace_gp = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        return_inferencedata=True,
        idata_kwargs={"log_likelihood": True} # in case you want to do model comparison later
    )

print("Hybrid model sampling complete")

## TASK 4: Check Convergence and Visualize k(T, pH)

**Student 4**: Check diagnostics for the hybrid model and visualize the learned surface.

In [None]:
# TODO: Print summary statistics for GP hyperparameters
# Check: R-hat < 1.01? ESS > 400? Any divergences?

summary = az.summary(trace_gp, var_names=["ls", "eta", "mean_log_k", "sigma_obs"])
print(summary)

divergences = trace_gp.sample_stats["diverging"].sum().item()
print(f"\nDivergences: {divergences}")

# TODO: Plot traces for GP hyperparameters

# TODO: What do the length scales tell you?
# How quickly does k vary with T vs pH?

In [None]:
# TODO: Visualize the learned k(T, pH) surface

# Extract posterior samples of k
k_posterior = trace_gp.posterior["k"].values
k_mean = k_posterior.mean(axis=(0, 1))
k_std = k_posterior.std(axis=(0, 1))

# Plotting code provided
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Posterior mean
scatter1 = axes[0].tricontourf(unique_samples["T"], unique_samples["pH"], k_mean, 
                                levels=20, cmap="viridis")
plt.colorbar(scatter1, ax=axes[0], label="k (1/min)")
axes[0].scatter(unique_samples["T"], unique_samples["pH"], c="red", s=20, alpha=0.3)
axes[0].set_title("Learned k(T, pH) - Posterior Mean")
axes[0].set_xlabel("Temperature (°C)")
axes[0].set_ylabel("pH")

# Uncertainty
scatter2 = axes[1].tricontourf(unique_samples["T"], unique_samples["pH"], k_std, 
                                levels=20, cmap="Reds")
plt.colorbar(scatter2, ax=axes[1], label="Std Dev")
axes[1].scatter(unique_samples["T"], unique_samples["pH"], c="blue", s=20, alpha=0.3)
axes[1].set_title("Posterior Uncertainty")
axes[1].set_xlabel("Temperature (°C)")
axes[1].set_ylabel("pH")

plt.tight_layout()
plt.show()

# TODO: Answer these questions:
# 1. Where does k appear highest?
# 2. Where is uncertainty highest?
# 3. How does uncertainty relate to data coverage?

## TASK 5: Compare Models

**Student 5**: Generate posterior predictive checks for the hybrid model and compare.
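
Beyond visual comparison, models can be scored with the pointwise log-likelihood that both traces store. ArviZ offers `az.loo` and `az.compare` for this. The NumPy sketch below shows the underlying idea with the log pointwise predictive density (lppd) and the WAIC correction, on synthetic data. The helper names and the two made-up "posteriors" are assumptions for illustration.

```python
import numpy as np

def lppd_and_waic(log_lik):
    """log_lik has shape (n_draws, n_obs): pointwise log-likelihood per draw."""
    n_draws = log_lik.shape[0]
    # Numerically stable log-mean-exp over draws for each observation
    m = log_lik.max(axis=0)
    lppd_i = m + np.log(np.exp(log_lik - m).sum(axis=0)) - np.log(n_draws)
    # WAIC penalty: variance of the log-likelihood over draws
    p_waic_i = log_lik.var(axis=0, ddof=1)
    return lppd_i.sum(), (lppd_i - p_waic_i).sum()

def normal_logpdf(y, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at y."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2

demo_rng = np.random.default_rng(3)
y_demo = demo_rng.normal(2.0, 0.3, size=50)   # synthetic observations

# Two hypothetical posteriors over the mean: one well centred, one biased
mu_good = demo_rng.normal(2.0, 0.05, size=(1000, 1))
mu_bad = demo_rng.normal(1.5, 0.05, size=(1000, 1))

for name, mu in [("well centred", mu_good), ("biased", mu_bad)]:
    lppd, elpd_waic = lppd_and_waic(normal_logpdf(y_demo, mu, 0.3))
    print(f"{name}: lppd={lppd:.1f}, elpd_waic={elpd_waic:.1f}")
```

Higher values indicate better expected predictive performance, and the penalty term guards against rewarding flexibility alone.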

In [None]:
# TODO: Generate posterior predictive for hybrid model
# Use the same approach as Task 2

with model_gp:
    ppc_gp = ...

# Extract predictions
ppc_gp_samples = ...
ppc_gp_flat = ...

# Side-by-side comparison (plotting code provided)
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
sample_ids_plot = rng.choice(range(n_samples), size=3, replace=False)

for col, sid in enumerate(sample_ids_plot):
    T_val = unique_samples.loc[sid, "T"]
    pH_val = unique_samples.loc[sid, "pH"]
    
    # Baseline (top row)
    ax = axes[0, col]
    ppc_sample = ppc_samples_flat[:, sid, :]
    
    for pct in [5, 50, 95]:
        alpha = 0.3 if pct in [5, 95] else 0.8
        ax.plot(unique_times, np.percentile(ppc_sample, pct, axis=0), 
               alpha=alpha, color="blue")
    ax.scatter(unique_times, P_obs_matrix[sid, :], color="red", s=30, zorder=10)
    ax.set_title(f"Baseline: T={T_val:.1f}°C, pH={pH_val:.2f}")
    ax.set_ylabel("Product [P]")
    ax.grid(True, alpha=0.3)
    
    # Hybrid (bottom row)
    ax = axes[1, col]
    ppc_sample = ppc_gp_flat[:, sid, :]
    
    for pct in [5, 50, 95]:
        alpha = 0.3 if pct in [5, 95] else 0.8
        ax.plot(unique_times, np.percentile(ppc_sample, pct, axis=0), 
               alpha=alpha, color="green")
    ax.scatter(unique_times, P_obs_matrix[sid, :], color="red", s=30, zorder=10)
    ax.set_title(f"Hybrid: T={T_val:.1f}°C, pH={pH_val:.2f}")
    ax.set_xlabel("Time (min)")
    ax.set_ylabel("Product [P]")
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# TODO: Answer these questions:
# 1. Which model better captures individual curves?
# 2. How do the uncertainty bands differ?
# 3. Where does each model still struggle?
# 4. Can you think of other ways to compare the models than just visually?

## TASK 6: Experiment with Kernels and Length Scales

**Student 6**: Explore different kernel choices and discuss length scale interpretation.

### Part A: Try another kernel, e.g. Matern family (presented by Austin)
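
To see how the two kernel families differ before refitting, their correlation can be tabulated as a function of distance. The sketch below uses the standard closed forms for the squared-exponential and Matern-5/2 kernels. The function names and the distance grid are illustrative assumptions.

```python
import numpy as np

def exp_quad(r, ls):
    """Squared-exponential (RBF) correlation at distance r."""
    return np.exp(-0.5 * (r / ls) ** 2)

def matern52(r, ls):
    """Matern-5/2 correlation at distance r."""
    s = np.sqrt(5.0) * r / ls
    return (1.0 + s + s**2 / 3.0) * np.exp(-s)

r = np.array([0.0, 0.5, 1.0, 2.0, 3.0])
for name, kernel in [("ExpQuad", exp_quad), ("Matern52", matern52)]:
    print(name, np.round(kernel(r, ls=1.0), 3))
```

Both kernels start at 1 and decay with distance. Comparing the two printed rows shows where each kernel keeps or loses correlation for the same length scale.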

In [None]:
# TODO: Copy the GP model from Task 3 and replace the ExpQuad (RBF) kernel with another kernel
# Hint: pm.gp.cov.Matern52(input_dim=2, ls=ls) is the general syntax

# Build model, sample, and compare learned surfaces
# Do the surfaces look similar or different?

### Part B: Discussion Questions

**About Kernels:**

1. What is the difference between RBF and Matern-5/2 kernels?
2. Do the learned surfaces look similar or different? Why?
3. When might kernel choice matter more?

**About Length Scales:**

4. What do the length scales (ls_T and ls_pH) represent?
   - Small ls = ? 
   - Large ls = ?

5. Should length scales be the same for T and pH?
   - Why or why not?
   - What do your posterior values suggest?

6. What happens with a very diffuse prior that piles mass near zero (e.g., ls ~ Gamma(0.1, 0.1))?

7. What happens with a tight prior concentrated around 1 (e.g., ls ~ Gamma(5, 5))?
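
The Gamma priors named in this task can be compared through their moments: for Gamma(alpha, beta) with rate beta, the mean is alpha/beta and the standard deviation is sqrt(alpha)/beta. The sketch below tabulates these together with a Monte Carlo 5-95% interval. The helper name `gamma_summary` is an assumption, and NumPy's scale parameter is the inverse of the rate.

```python
import numpy as np

def gamma_summary(alpha, beta, n_draws=200_000, seed=0):
    """Analytic mean/sd of Gamma(alpha, rate=beta) plus a Monte Carlo 5-95% interval."""
    rng = np.random.default_rng(seed)
    # NumPy parameterises by shape and scale, where scale = 1 / rate
    draws = rng.gamma(shape=alpha, scale=1.0 / beta, size=n_draws)
    q05, q95 = np.percentile(draws, [5, 95])
    return alpha / beta, np.sqrt(alpha) / beta, q05, q95

for alpha, beta in [(0.1, 0.1), (5, 5), (1, 0.5)]:
    mean, sd, q05, q95 = gamma_summary(alpha, beta)
    print(f"Gamma({alpha}, {beta}): mean={mean:.2f}, sd={sd:.2f}, "
          f"5-95% interval=({q05:.3g}, {q95:.3g})")
```

Two priors can share the same mean and still differ widely in spread, so the interval is usually more informative than the mean alone.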

In [None]:
# TODO (Optional): Experiment with different length scale priors
# Try: ls = pm.Gamma("ls", alpha=5, beta=5, dims="features")  # tight
# Try: ls = pm.Gamma("ls", alpha=1, beta=0.5, dims="features")  # diffuse

# How does this affect the learned surface?
# Which prior gives better fits?



# Hint: If you want to visualise different priors beforehand, you can use this code snippet:

# n_prior_draws = 5000  # separate name so the notebook's n_samples is not overwritten

# with pm.Model() as model:
#     gamma1 = pm.Gamma("gamma1", alpha=0.1, beta=0.1)
#     gamma2 = pm.Gamma("gamma2", alpha=5, beta=5)
    
#     prior_samples = pm.sample_prior_predictive(n_prior_draws)

# # Extract samples
# g1 = prior_samples.prior["gamma1"].values.flatten()
# g2 = prior_samples.prior["gamma2"].values.flatten()

# # Plot
# plt.figure()
# plt.hist(g1, bins=100, density=True, alpha=0.5, label="alpha=0.1, beta=0.1")
# plt.hist(g2, bins=100, density=True, alpha=0.5, label="alpha=5, beta=5")

# plt.legend()
# plt.xlabel("x")
# plt.ylabel("Density")
# plt.title("Different Gamma Priors")
# plt.show()

---

# SECTION 5: Bonus Questions

If you finish early, discuss these questions.

## Bonus 1: Optimal Conditions

According to your GP model, what are the optimal T and pH?

```python
idx_max = k_mean.argmax()
T_optimal = unique_samples.loc[idx_max, "T"]
pH_optimal = unique_samples.loc[idx_max, "pH"]
```

How confident is the model at this optimum?
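
One way to quantify that confidence is to ask, draw by draw, which sample has the largest k, and report how often each sample wins. The sketch below does this on a synthetic array shaped like `trace_gp.posterior["k"].values`, i.e. (chains, draws, samples). The values of `k_true` and the noise level are made up for illustration.

```python
import numpy as np

demo_rng = np.random.default_rng(4)

# Synthetic stand-in for the posterior of k: (chains, draws, samples)
k_true = np.array([0.10, 0.20, 0.22, 0.15])
k_post_demo = k_true + demo_rng.normal(0.0, 0.02, size=(4, 500, k_true.size))

# Flatten chains and draws, then find which sample has the largest k in each draw
flat = k_post_demo.reshape(-1, k_true.size)
winner = flat.argmax(axis=1)
p_best = np.bincount(winner, minlength=k_true.size) / flat.shape[0]

print("P(sample is optimum):", np.round(p_best, 3))
```

If the probability is spread over several samples, the posterior mean alone overstates how well the optimum is located.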

## Bonus 2: Experimental Design

If you could collect 5 more samples, where would you place them?
- Strategy 1: High uncertainty regions
- Strategy 2: Around the optimum

Which strategy would you choose and why?

## Bonus 3: Model Assumptions

What assumptions does the hybrid model make?
1. First-order kinetics
2. k varies smoothly with T and pH
3. Independent observations
4. Constant observation noise

Which might be violated in real data?

## Bonus 4: Extensions

How would you extend this to 3+ factors (e.g., T, pH, substrate concentration)?

What about alternative mechanistic models (e.g., Michaelis-Menten) or fully mechanistic models?

---

# Summary

## Key Takeaways

1. **MCMC Diagnostics**: $\hat{R}$, ESS, and trace plots tell you if sampling worked
2. **Posterior Predictive Checks**: Tell you if the model is adequate
3. **Good diagnostics ≠ good model**: Convergence doesn't guarantee correctness
4. **Hybrid Models**: Combine mechanistic knowledge with data-driven flexibility
5. **GPs**: Model unknown functions and provide uncertainty
6. **Kernels**: Encode smoothness assumptions
7. **Length scales**: Tell you how quickly functions vary