# Chemical Kinetics

Chemical Kinetics Example for JAXSR.

Demonstrates discovering rate laws from kinetic data, including:
- Langmuir-Hinshelwood kinetics
- Power law kinetics
- Arrhenius temperature dependence
- Competitive adsorption kinetics

In [1]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats as sp_stats

from jaxsr import BasisLibrary, Constraints, SymbolicRegressor, anova
from jaxsr.plotting import plot_parity

## Discover Langmuir-Hinshelwood rate law.

True model:

$$r = \frac{k\,C_A\,C_B}{1 + K\,C_A}, \quad k=2.5,\; K=1.2$$

In [2]:
# Generate synthetic kinetic data
np.random.seed(42)
n_samples = 100

# Concentration ranges typical for catalytic reactions
C_A = np.random.uniform(0.1, 2.0, n_samples)
C_B = np.random.uniform(0.1, 2.0, n_samples)

# True kinetic parameters
k = 2.5  # Rate constant
K = 1.2  # Adsorption equilibrium constant

# True rate law, plus Gaussian measurement noise
r_true = k * C_A * C_B / (1 + K * C_A)
r = r_true + np.random.randn(n_samples) * 0.05

X = jnp.column_stack([C_A, C_B])
y = jnp.array(r)

print("\nTrue model: r = 2.5*C_A*C_B / (1 + 1.2*C_A)")
print(f"Data: {n_samples} samples")

# Build basis library with appropriate functions for kinetics
library = (
    BasisLibrary(n_features=2, feature_names=["C_A", "C_B"])
    .add_constant()
    .add_linear()
    .add_polynomials(max_degree=2)
    .add_interactions(max_order=2)
    .add_ratios()
    .add_transcendental(["inv"])
)
print(f"Basis library: {len(library)} candidate functions")

# Add constraint: reaction rate must be non-negative
constraints = Constraints().add_bounds("y", lower=0)
# ⚠️ Constraint Enforcement Note: This uses soft enforcement (hard=False by default).
# Predictions may still go slightly negative at boundary regions.
# For strict non-negativity everywhere, use:
#   constraints = Constraints().add_bounds("y", lower=0, hard=True)
#   model = SymbolicRegressor(..., constraint_enforcement="exact")

# Fit model
model = SymbolicRegressor(
    basis_library=library,
    max_terms=6,
    strategy="greedy_forward",
    information_criterion="bic",
    constraints=constraints,
)
model.fit(X, y)

print("\nDiscovered expression:")
print(f"  {model.expression_}")

# Metrics are printed before the discussion below, which refers to "R² shown above".
print("\nMetrics:")
print(f"  R² = {model.metrics_['r2']:.4f}")
print(f"  MSE = {model.metrics_['mse']:.6f}")

print("\n--- Mechanistic vs Empirical Models ---")
print("This polynomial-rational expression provides good empirical fit (R² shown above),")
print("but does NOT give mechanistic parameters k and K from the L-H rate law.")
print()
print("To recover exact L-H form:  r = k*C_A*C_B/(1+K*C_A)")
print("  1. Use add_parametric() to define L-H basis function")
print("  2. Fit with profile_params=['K'] to optimize adsorption constant")
print("  3. Extract k from coefficient, K from optimized parameter")
print("  4. See model_comparison_isotherms.ipynb for complete parametric example")
print()
print("When to use parametric vs exploratory:")
print("  • Parametric (mechanistic): Known functional form, interpretable parameters")
print("  • Exploratory (polynomial): Unknown form, empirical approximation")

### Note on soft bounds constraints

`add_bounds("y", lower=0)` uses **soft enforcement by default** (`hard=False`), which adds a penalty term during coefficient refit but does not guarantee non-negativity everywhere. The model's negative intercept (`-0.10`) means predictions could be negative for very small concentrations. For strict non-negativity, use `hard=True` or pass `constraint_enforcement="exact"` to the `SymbolicRegressor` constructor.

In [3]:
# Parameter significance, diagnostics, and ANOVA
# Uses `model`, `X`, `y` from the previous cell; plt / sp_stats / anova / plot_parity
# are imported in the top import cell.
intervals = model.coefficient_intervals(alpha=0.05)
n_obs, k_terms = len(np.asarray(y)), len(model.selected_features_)
df_resid = n_obs - k_terms  # residual degrees of freedom for the coefficient t-tests

print("Parameter Significance (95% CI):")
print(f"  {'Term':>15s} {'Estimate':>10s} {'Std Err':>9s} {'t':>8s} {'p-value':>10s} 95% CI")
print("  " + "-" * 75)
for name, (est, lo, hi, se) in intervals.items():
    t_val = est / se if abs(se) > 1e-15 else float("inf")
    # sf() avoids the 1 - cdf cancellation that rounds very small p-values to 0.
    # With no residual DOF the test is undefined: report NaN rather than a spurious p = 0.
    p_val = float(2 * sp_stats.t.sf(abs(t_val), df_resid)) if df_resid > 0 else float("nan")
    sig = "***" if p_val < 0.001 else ("**" if p_val < 0.01 else ("*" if p_val < 0.05 else ""))
    print(f"  {name:>15s} {est:10.4f} {se:9.4f} {t_val:8.2f} {p_val:10.2e} [{lo:.4f}, {hi:.4f}] {sig}")
print("  --- *** p<0.001, ** p<0.01, * p<0.05")

# Parity and residual-vs-predicted diagnostics
y_pred = model.predict(X)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
plot_parity(y, y_pred, ax=axes[0], title="Langmuir-Hinshelwood: Parity")
residuals = np.array(y - y_pred)
axes[1].scatter(np.array(y_pred), residuals, alpha=0.6, c="steelblue", edgecolors="white", linewidth=0.5)
axes[1].axhline(y=0, color="r", linestyle="--")
axes[1].set_xlabel("Predicted")
axes[1].set_ylabel("Residuals")
axes[1].set_title("Langmuir-Hinshelwood: Residuals")
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# ANOVA table; "Model"/"Residual"/"Total" are summary rows, everything else is a term.
anova_result = anova(model)
summary_sources = {"Model", "Residual", "Total"}
print("\nANOVA Table (Langmuir-Hinshelwood)")
print("=" * 80)
print(f"  {'Source':25s}  {'DF':>4}  {'Sum Sq':>12}  {'Mean Sq':>12}  {'F':>10}  {'p-value':>10}")
print("-" * 80)
for row in anova_result.rows:
    f_str = f"{row.f_value:10.2f}" if row.f_value is not None else "          "
    p_str = f"{row.p_value:10.4f}" if row.p_value is not None else "          "
    print(f"  {row.source:25s}  {row.df:4d}  {row.sum_sq:12.4f}  {row.mean_sq:12.4f}  {f_str}  {p_str}")
print("-" * 80)
term_rows = [row for row in anova_result.rows if row.source not in summary_sources]
if term_rows:
    model_ss = sum(row.sum_sq for row in term_rows)
    print("\nVariance Contributions:")
    print("(Percentages relative to Model SS, not Total SS — shows relative importance within fitted model)")
    for row in term_rows:
        pct = 100 * row.sum_sq / model_ss if model_ss > 0 else 0
        p = row.p_value
        if p is None:
            sig = ""
        else:
            sig = "***" if p < 0.001 else ("**" if p < 0.01 else ("*" if p < 0.05 else ""))
        print(f"  {row.source:25s}  {pct:6.1f}%  {sig}")

### Note on the discovered expression

The discovered 6-term polynomial-rational expression provides a good empirical fit
($R^2 = 0.993$) but is **not** the true Langmuir-Hinshelwood form
`r = k*C_A*C_B / (1 + K*C_A)`. This is expected: the basis library does not include
the exact L-H functional form as a single basis function, so JAXSR approximates the
underlying relationship using the available building blocks (polynomials, ratios, etc.).

Because the discovered expression is a polynomial approximation rather than the
mechanistic form, we cannot extract the physically meaningful rate constant *k* or
adsorption equilibrium constant *K* from the fitted coefficients.

**Soft bounds constraint.** The `add_bounds("y", lower=0)` constraint used above is
**soft** by default (`hard=False`), meaning it penalizes but does not strictly prevent
negative predictions. Since the model includes a negative intercept (-0.1014),
predictions could go negative for small concentrations near the boundary of the
training domain. Pass `hard=True` to enforce strict non-negativity.

**Recovering the exact L-H form.** If the functional form is known (or hypothesized),
use `add_parametric()` to encode it directly as a basis function. This enables JAXSR
to fit the exact L-H expression with identifiable physical parameters, as demonstrated
in the `model_comparison_isotherms` and `langmuir_doe_active_learning` notebooks.

## Discover power law kinetics.

True model:

$$r = 1.5\,C_A^{1.0}\,C_B^{0.5}$$

In [4]:
np.random.seed(42)
n_samples = 100

C_A = np.random.uniform(0.5, 3.0, n_samples)
C_B = np.random.uniform(0.5, 3.0, n_samples)

# True parameters
k = 1.5
a = 1.0  # First order in A
b = 0.5  # Half order in B

r_true = k * C_A**a * C_B**b
r = r_true + np.random.randn(n_samples) * 0.02

X = jnp.column_stack([C_A, C_B])
y = jnp.array(r)

print("\nTrue model: r = 1.5 * C_A^1.0 * C_B^0.5")

# For power law, include sqrt for half-order
library = (
    BasisLibrary(n_features=2, feature_names=["C_A", "C_B"])
    .add_constant()
    .add_linear()
    .add_polynomials(max_degree=2)
    .add_interactions(max_order=2)
    .add_transcendental(["sqrt"])
)

# Add custom basis function for C_A * sqrt(C_B) (the exact true-model shape)
library.add_custom(
    name="C_A*sqrt(C_B)",
    func=lambda X: X[:, 0] * jnp.sqrt(X[:, 1]),
    complexity=2,
)

model = SymbolicRegressor(
    basis_library=library,
    max_terms=5,
    strategy="greedy_forward",
)
model.fit(X, y)

print("\nDiscovered expression:")
print(f"  {model.expression_}")

# Check for spurious/negligible terms. This is a magnitude screen relative to the
# largest |coefficient|, not a statistical test (p-values are in the next cell).
print("\nTerm significance:")
abs_coefs = [abs(float(c)) for c in model.coefficients_]
max_coef = max(abs_coefs, default=0.0)
# strict=True: feature names and coefficients must align one-to-one; a length
# mismatch would otherwise silently mislabel or drop terms.
for name, coef, abs_coef in zip(model.selected_features_, model.coefficients_, abs_coefs, strict=True):
    rel_magnitude = abs_coef / max_coef if max_coef > 0 else 0.0
    if rel_magnitude < 0.01:
        flag = "(negligible)"
    elif rel_magnitude > 0.5:
        flag = "(dominant)"
    else:
        flag = ""
    print(f"  {name:15s}: {float(coef):10.4f}  {flag}")
print(f"  R² = {model.metrics_['r2']:.4f}")

In [5]:
# Parameter significance, diagnostics, and ANOVA
intervals = model.coefficient_intervals(alpha=0.05)
n_obs, k_terms = len(np.asarray(y)), len(model.selected_features_)
df_resid = n_obs - k_terms  # residual degrees of freedom for the coefficient t-tests

print("Parameter Significance (95% CI):")
print(f"  {'Term':>15s} {'Estimate':>10s} {'Std Err':>9s} {'t':>8s} {'p-value':>10s} 95% CI")
print("  " + "-" * 75)
for name, (est, lo, hi, se) in intervals.items():
    t_val = est / se if abs(se) > 1e-15 else float("inf")
    # sf() avoids the 1 - cdf cancellation that rounds very small p-values to 0.
    # With no residual DOF the test is undefined: report NaN rather than a spurious p = 0.
    p_val = float(2 * sp_stats.t.sf(abs(t_val), df_resid)) if df_resid > 0 else float("nan")
    sig = "***" if p_val < 0.001 else ("**" if p_val < 0.01 else ("*" if p_val < 0.05 else ""))
    print(f"  {name:>15s} {est:10.4f} {se:9.4f} {t_val:8.2f} {p_val:10.2e} [{lo:.4f}, {hi:.4f}] {sig}")
print("  --- *** p<0.001, ** p<0.01, * p<0.05")

# Parity and residual-vs-predicted diagnostics
y_pred = model.predict(X)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
plot_parity(y, y_pred, ax=axes[0], title="Power Law: Parity")
residuals = np.array(y - y_pred)
axes[1].scatter(np.array(y_pred), residuals, alpha=0.6, c="steelblue", edgecolors="white", linewidth=0.5)
axes[1].axhline(y=0, color="r", linestyle="--")
axes[1].set_xlabel("Predicted")
axes[1].set_ylabel("Residuals")
axes[1].set_title("Power Law: Residuals")
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# ANOVA table; "Model"/"Residual"/"Total" are summary rows, everything else is a term.
anova_result = anova(model)
summary_sources = {"Model", "Residual", "Total"}
print("\nANOVA Table (Power Law)")
print("=" * 80)
print(f"  {'Source':25s}  {'DF':>4}  {'Sum Sq':>12}  {'Mean Sq':>12}  {'F':>10}  {'p-value':>10}")
print("-" * 80)
for row in anova_result.rows:
    f_str = f"{row.f_value:10.2f}" if row.f_value is not None else "          "
    p_str = f"{row.p_value:10.4f}" if row.p_value is not None else "          "
    print(f"  {row.source:25s}  {row.df:4d}  {row.sum_sq:12.4f}  {row.mean_sq:12.4f}  {f_str}  {p_str}")
print("-" * 80)
term_rows = [row for row in anova_result.rows if row.source not in summary_sources]
if term_rows:
    model_ss = sum(row.sum_sq for row in term_rows)
    print("\nVariance Contributions:")
    print("(Percentages relative to Model SS, not Total SS — shows relative importance within fitted model)")
    for row in term_rows:
        pct = 100 * row.sum_sq / model_ss if model_ss > 0 else 0
        p = row.p_value
        if p is None:
            sig = ""
        else:
            sig = "***" if p < 0.001 else ("**" if p < 0.01 else ("*" if p < 0.05 else ""))
        print(f"  {row.source:25s}  {pct:6.1f}%  {sig}")

### Note on the spurious C_B^2 term

The discovered expression `y = 1.498*C_A*sqrt(C_B) + 0.002122*C_B^2` includes a
spurious `C_B^2` term with a negligibly small coefficient (0.002). The dominant term
`1.498*C_A*sqrt(C_B)` closely matches the true model `1.5*C_A*sqrt(C_B)`.

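The complexity penalty of the selection criterion was not large enough to prefer the
simpler 1-term model over the 2-term model. (This cell uses the regressor's default
information criterion; it does not set `information_criterion="bic"` explicitly.) The
additional term marginally reduces the residual sum of squares, which was enough to
offset the penalty. In practice, the `C_B^2` contribution is negligible across the data
range and can be safely dropped when interpreting the result.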

## Discover Arrhenius temperature dependence.

True model:

$$\ln k = \ln A - \frac{E_a}{RT}$$

In [6]:
np.random.seed(42)
n_samples = 50

# Temperature range (K)
T = np.random.uniform(300, 500, n_samples)

# Arrhenius parameters
A = 1e6  # Pre-exponential factor
Ea = 50000  # Activation energy (J/mol)
R = 8.314  # Gas constant (J/mol/K)

# True rate constant
k_true = A * np.exp(-Ea / (R * T))
# Work in log space for better fitting
log_k = np.log(k_true) + np.random.randn(n_samples) * 0.05

# Use 1/T as the feature (linearized Arrhenius)
X = jnp.array(1000 / T).reshape(-1, 1)  # 1000/T in 1/K
y = jnp.array(log_k)

print("\nTrue model: ln(k) = ln(A) - Ea/(R*T)")
print(f"Or: ln(k) = {np.log(A):.2f} - {Ea/R/1000:.2f} * (1000/T)")

# Simple linear library for linearized Arrhenius
library = BasisLibrary(n_features=1, feature_names=["1000/T"]).add_constant().add_linear()

model = SymbolicRegressor(
    basis_library=library,
    max_terms=2,
    strategy="exhaustive",
)
model.fit(X, y)

print("\nDiscovered expression:")
print(f"  {model.expression_}")
print(f"  R² = {model.metrics_['r2']:.4f}")

# Extract parameters. The header is printed once up front so the Ea line is still
# labelled even if the constant term was not selected.
print("\nExtracted parameters:")
if "1" in model.selected_features_:
    idx_const = model.selected_features_.index("1")
    ln_A = float(model.coefficients_[idx_const])
    print(f"  ln(A) = {ln_A:.2f} (true: {np.log(A):.2f})")

if "1000/T" in model.selected_features_:
    idx_T = model.selected_features_.index("1000/T")
    slope = float(model.coefficients_[idx_T])
    # ln k = ln A - (Ea / (1000 R)) * (1000/T)  =>  Ea = -slope * R * 1000
    Ea_fit = -slope * R * 1000
    print(f"  Ea = {Ea_fit:.0f} J/mol (true: {Ea} J/mol)")


True model: ln(k) = ln(A) - Ea/(R*T)
Or: ln(k) = 13.82 - 6.01 * (1000/T)



Discovered expression:
  y = 13.79 - 6.006*1000/T
  R² = 0.9996

Extracted parameters:
  ln(A) = 13.79 (true: 13.82)
  Ea = 49932 J/mol (true: 50000 J/mol)


In [7]:
# Parameter significance, diagnostics, and ANOVA
intervals = model.coefficient_intervals(alpha=0.05)
n_obs, k_terms = len(np.asarray(y)), len(model.selected_features_)
df_resid = n_obs - k_terms  # residual degrees of freedom for the coefficient t-tests

print("Parameter Significance (95% CI):")
print(f"  {'Term':>15s} {'Estimate':>10s} {'Std Err':>9s} {'t':>8s} {'p-value':>10s} 95% CI")
print("  " + "-" * 75)
for name, (est, lo, hi, se) in intervals.items():
    t_val = est / se if abs(se) > 1e-15 else float("inf")
    # sf() avoids the 1 - cdf cancellation that rounds very small p-values to 0.
    # With no residual DOF the test is undefined: report NaN rather than a spurious p = 0.
    p_val = float(2 * sp_stats.t.sf(abs(t_val), df_resid)) if df_resid > 0 else float("nan")
    sig = "***" if p_val < 0.001 else ("**" if p_val < 0.01 else ("*" if p_val < 0.05 else ""))
    print(f"  {name:>15s} {est:10.4f} {se:9.4f} {t_val:8.2f} {p_val:10.2e} [{lo:.4f}, {hi:.4f}] {sig}")
print("  --- *** p<0.001, ** p<0.01, * p<0.05")

# Parity and residual-vs-predicted diagnostics
y_pred = model.predict(X)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
plot_parity(y, y_pred, ax=axes[0], title="Arrhenius: Parity")
residuals = np.array(y - y_pred)
axes[1].scatter(np.array(y_pred), residuals, alpha=0.6, c="steelblue", edgecolors="white", linewidth=0.5)
axes[1].axhline(y=0, color="r", linestyle="--")
axes[1].set_xlabel("Predicted")
axes[1].set_ylabel("Residuals")
axes[1].set_title("Arrhenius: Residuals")
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# ANOVA table; "Model"/"Residual"/"Total" are summary rows, everything else is a term.
anova_result = anova(model)
summary_sources = {"Model", "Residual", "Total"}
print("\nANOVA Table (Arrhenius)")
print("=" * 80)
print(f"  {'Source':25s}  {'DF':>4}  {'Sum Sq':>12}  {'Mean Sq':>12}  {'F':>10}  {'p-value':>10}")
print("-" * 80)
for row in anova_result.rows:
    f_str = f"{row.f_value:10.2f}" if row.f_value is not None else "          "
    p_str = f"{row.p_value:10.4f}" if row.p_value is not None else "          "
    print(f"  {row.source:25s}  {row.df:4d}  {row.sum_sq:12.4f}  {row.mean_sq:12.4f}  {f_str}  {p_str}")
print("-" * 80)
term_rows = [row for row in anova_result.rows if row.source not in summary_sources]
if term_rows:
    model_ss = sum(row.sum_sq for row in term_rows)
    print("\nVariance Contributions:")
    print("(Percentages relative to Model SS, not Total SS — shows relative importance within fitted model)")
    for row in term_rows:
        pct = 100 * row.sum_sq / model_ss if model_ss > 0 else 0
        p = row.p_value
        if p is None:
            sig = ""
        else:
            sig = "***" if p < 0.001 else ("**" if p < 0.01 else ("*" if p < 0.05 else ""))
        print(f"  {row.source:25s}  {pct:6.1f}%  {sig}")

## Discover competitive adsorption kinetics.

True model:

$$r = \frac{3.0\,C_A\,C_B}{1 + 0.8\,C_A + 1.5\,C_B}$$

In [8]:
np.random.seed(42)
n_samples = 150

# Reactant concentrations sampled over the same range for both species
C_A = np.random.uniform(0.1, 2.0, n_samples)
C_B = np.random.uniform(0.1, 2.0, n_samples)

# Kinetic parameters: rate constant and the two adsorption constants
k = 3.0
K_A = 0.8
K_B = 1.5

# Competitive-adsorption rate law with additive Gaussian noise
r_true = k * C_A * C_B / (1 + K_A * C_A + K_B * C_B)
r = r_true + np.random.randn(n_samples) * 0.03

X = jnp.column_stack([C_A, C_B])
y = jnp.array(r)

print("\nTrue model: r = 3.0*C_A*C_B / (1 + 0.8*C_A + 1.5*C_B)")

# Build comprehensive library
library = (
    BasisLibrary(n_features=2, feature_names=["C_A", "C_B"])
    .add_constant()
    .add_linear()
    .add_polynomials(max_degree=2)
    .add_interactions(max_order=2)
    .add_ratios()
)

# Custom rational candidates with unit adsorption constants in the denominator,
# registered as (name, function, complexity) triples.
custom_rational_terms = [
    ("C_A*C_B/(1+C_A)", lambda X: X[:, 0] * X[:, 1] / (1 + X[:, 0]), 3),
    ("C_A*C_B/(1+C_B)", lambda X: X[:, 0] * X[:, 1] / (1 + X[:, 1]), 3),
    ("C_A*C_B/(1+C_A+C_B)", lambda X: X[:, 0] * X[:, 1] / (1 + X[:, 0] + X[:, 1]), 4),
]
for term_name, term_func, term_complexity in custom_rational_terms:
    library.add_custom(name=term_name, func=term_func, complexity=term_complexity)

model = SymbolicRegressor(
    basis_library=library,
    max_terms=5,
    strategy="greedy_forward",
)
model.fit(X, y)

print("\nDiscovered expression:")
print(f"  {model.expression_}")
print(f"  R² = {model.metrics_['r2']:.4f}")


True model: r = 3.0*C_A*C_B / (1 + 0.8*C_A + 1.5*C_B)



Discovered expression:
  y = 3.393*C_A*C_B/(1+C_A+C_B) - 0.4771*C_A*C_B/(1+C_A) + 0.02056*C_A^2
  R² = 0.9972


In [9]:
# Parameter significance, diagnostics, and ANOVA
intervals = model.coefficient_intervals(alpha=0.05)
n_obs, k_terms = len(np.asarray(y)), len(model.selected_features_)
df_resid = n_obs - k_terms  # residual degrees of freedom for the coefficient t-tests

print("Parameter Significance (95% CI):")
print(f"  {'Term':>15s} {'Estimate':>10s} {'Std Err':>9s} {'t':>8s} {'p-value':>10s} 95% CI")
print("  " + "-" * 75)
for name, (est, lo, hi, se) in intervals.items():
    t_val = est / se if abs(se) > 1e-15 else float("inf")
    # sf() avoids the 1 - cdf cancellation that rounds very small p-values to 0.
    # With no residual DOF the test is undefined: report NaN rather than a spurious p = 0.
    p_val = float(2 * sp_stats.t.sf(abs(t_val), df_resid)) if df_resid > 0 else float("nan")
    sig = "***" if p_val < 0.001 else ("**" if p_val < 0.01 else ("*" if p_val < 0.05 else ""))
    print(f"  {name:>15s} {est:10.4f} {se:9.4f} {t_val:8.2f} {p_val:10.2e} [{lo:.4f}, {hi:.4f}] {sig}")
print("  --- *** p<0.001, ** p<0.01, * p<0.05")

# Parity and residual-vs-predicted diagnostics
y_pred = model.predict(X)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
plot_parity(y, y_pred, ax=axes[0], title="Competitive Adsorption: Parity")
residuals = np.array(y - y_pred)
axes[1].scatter(np.array(y_pred), residuals, alpha=0.6, c="steelblue", edgecolors="white", linewidth=0.5)
axes[1].axhline(y=0, color="r", linestyle="--")
axes[1].set_xlabel("Predicted")
axes[1].set_ylabel("Residuals")
axes[1].set_title("Competitive Adsorption: Residuals")
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# ANOVA table; "Model"/"Residual"/"Total" are summary rows, everything else is a term.
anova_result = anova(model)
summary_sources = {"Model", "Residual", "Total"}
print("\nANOVA Table (Competitive Adsorption)")
print("=" * 80)
print(f"  {'Source':25s}  {'DF':>4}  {'Sum Sq':>12}  {'Mean Sq':>12}  {'F':>10}  {'p-value':>10}")
print("-" * 80)
for row in anova_result.rows:
    f_str = f"{row.f_value:10.2f}" if row.f_value is not None else "          "
    p_str = f"{row.p_value:10.4f}" if row.p_value is not None else "          "
    print(f"  {row.source:25s}  {row.df:4d}  {row.sum_sq:12.4f}  {row.mean_sq:12.4f}  {f_str}  {p_str}")
print("-" * 80)
term_rows = [row for row in anova_result.rows if row.source not in summary_sources]
if term_rows:
    model_ss = sum(row.sum_sq for row in term_rows)
    print("\nVariance Contributions:")
    print("(Percentages relative to Model SS, not Total SS — shows relative importance within fitted model)")
    for row in term_rows:
        pct = 100 * row.sum_sq / model_ss if model_ss > 0 else 0
        p = row.p_value
        if p is None:
            sig = ""
        else:
            sig = "***" if p < 0.001 else ("**" if p < 0.01 else ("*" if p < 0.05 else ""))
        print(f"  {row.source:25s}  {pct:6.1f}%  {sig}")