# Active Learning

This example shows how to use JAXSR acquisition functions to select
the next experiments to run, balancing exploration and exploitation.

This example covers six scenarios:

1. Pure exploration: reduce uncertainty everywhere
2. Bayesian optimisation: find the minimum of an unknown function
3. Model discrimination: resolve which model structure is correct
4. Batch selection strategies: greedy vs penalized vs kriging believer vs D-optimal
5. Full active learning loop: iteratively improve a model
6. Composite acquisition: combine multiple objectives with weights

In [1]:
import jax.numpy as jnp
import numpy as np
from jaxsr import BasisLibrary, SymbolicRegressor
from jaxsr.acquisition import (
    LCB,
    ActiveLearner,
    AOptimal,
    BMAUncertainty,
    ConfidenceBandWidth,
    DOptimal,
    EnsembleDisagreement,
    ExpectedImprovement,
    ModelDiscrimination,
    ModelMin,
    PredictionVariance,
    ProbabilityOfImprovement,
    ThompsonSampling,
    suggest_points,
)

## Setup: Create a fitted model

We'll use this model across all examples below.

In [2]:
np.random.seed(42)
X = np.random.uniform(0, 5, (40, 1))
y = X[:, 0] ** 2 - 3.0 * X[:, 0] + 2.0 + np.random.randn(40) * 0.3

library = (
    BasisLibrary(n_features=1, feature_names=["x"])
    .add_constant()
    .add_linear()
    .add_polynomials(max_degree=3)
)
model = SymbolicRegressor(basis_library=library, max_terms=4, strategy="greedy_forward")
model.fit(jnp.array(X), jnp.array(y))

SymbolicRegressor(max_terms=4, strategy='greedy_forward', fitted)

## Goal: improve model accuracy uniformly by sampling where prediction uncertainty is highest

Available acquisition functions for this goal:

- `PredictionVariance`: the default. Uses the OLS posterior to compute
  sigma^2(x). Fast, and exact for linear-in-parameter models.

- `ConfidenceBandWidth`: similar, but reports the actual width of the
  confidence band at a specified significance level.

- `EnsembleDisagreement`: uses the Pareto front of models with different
  complexities. Good when you're unsure whether a simpler or more
  complex model is appropriate.

- `BMAUncertainty`: Bayesian Model Averaging. The most comprehensive
  measure: it combines noise uncertainty and model-selection uncertainty.

- `AOptimal`: targets reduction in *parameter* uncertainty (trace of
  the covariance matrix). Use when you care about accurate coefficients.

- `DOptimal`: maximises information gain (determinant of the information
  matrix). Use when you want maximal information per experiment.
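To make the variance-based criteria concrete, here is a minimal NumPy sketch of the quantities they are built on for an ordinary least-squares fit. It is not JAXSR's implementation: the cubic basis, the helper name `design_matrix`, and the candidate grid are assumptions chosen for illustration.

```python
import numpy as np


def design_matrix(x):
    """Cubic polynomial basis [1, x, x^2, x^3] (assumed for illustration)."""
    x = np.asarray(x, dtype=float)
    return np.column_stack([np.ones_like(x), x, x**2, x**3])


rng = np.random.default_rng(42)
x_train = rng.uniform(0.0, 5.0, 40)
y_train = x_train**2 - 3.0 * x_train + 2.0 + rng.normal(0.0, 0.3, 40)

# Ordinary least-squares fit and the usual unbiased noise estimate.
Phi = design_matrix(x_train)
coef, *_ = np.linalg.lstsq(Phi, y_train, rcond=None)
residuals = y_train - Phi @ coef
sigma2 = residuals @ residuals / (len(y_train) - Phi.shape[1])

# Coefficient covariance: sigma^2 * (Phi^T Phi)^{-1}.
info_matrix = Phi.T @ Phi
cov = sigma2 * np.linalg.inv(info_matrix)

# Prediction variance at each candidate: phi(x)^T Cov phi(x).
x_cand = np.linspace(0.0, 5.0, 501)
Phi_cand = design_matrix(x_cand)
pred_var = np.einsum("ij,jk,ik->i", Phi_cand, cov, Phi_cand)

# Scalar summaries related to A- and D-optimality.
a_value = np.trace(cov)                        # total parameter variance
_, d_value = np.linalg.slogdet(info_matrix)    # log det of information matrix

x_next = x_cand[np.argmax(pred_var)]
print(f"highest prediction variance at x = {x_next:.2f}")
```

For a polynomial fit like this one, the variance is typically largest near the edges of the sampled interval, which is why variance-driven strategies tend to propose points close to the bounds.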

In [3]:
bounds = [(0.0, 5.0)]

print(f"Current model: {model.expression_}")
print(f"Current R^2:   {model.score(model._X_train, model._y_train):.4f}")
print(f"Training size: {len(model._y_train)}")
print()

# --- Try different exploration strategies ---
strategies = [
    ("PredictionVariance", PredictionVariance()),
    ("ConfidenceBandWidth(95%)", ConfidenceBandWidth(alpha=0.05)),
    ("EnsembleDisagreement", EnsembleDisagreement()),
    ("BMAUncertainty", BMAUncertainty(criterion="bic")),
    ("AOptimal", AOptimal()),
    ("DOptimal", DOptimal()),
]

for name, acq in strategies:
    result = suggest_points(model, bounds, acq, n_points=3, random_state=42)
    pts = np.array(result.points).ravel()
    print(f"  {name:30s} -> x = [{', '.join(f'{p:.2f}' for p in pts)}]")

print()

Current model: y = 0.0132*x^3 + 0.8891*x^2 + 1.827 - 2.736*x
Current R^2:   0.9941
Training size: 40



  PredictionVariance             -> x = [5.00, 0.01, 0.00]
  ConfidenceBandWidth(95%)       -> x = [5.00, 0.01, 0.00]
  EnsembleDisagreement           -> x = [0.01, 0.01, 0.00]
  BMAUncertainty                 -> x = [4.95, 1.69, 1.59]
  AOptimal                       -> x = [0.01, 0.01, 0.00]
  DOptimal                       -> x = [5.00, 0.01, 0.00]



## Goal: find x that minimises y, using the fitted model as a surrogate

Available acquisition functions for this goal:

- `ModelMin` / `ModelMax`: pure exploitation. No exploration at all,
  just returns the predicted optimum. Use only when you fully trust
  the model.

- `LCB` (Lower Confidence Bound): y_hat - kappa*sigma. The kappa
  parameter controls exploration vs exploitation:
  - kappa=0 -> pure exploitation (equivalent to `ModelMin`)
  - kappa~2 -> balanced
  - kappa>3 -> heavy exploration

- `UCB` (Upper Confidence Bound): the mirror image for maximisation,
  y_hat + kappa*sigma.

- `ExpectedImprovement` (EI): a widely used default in Bayesian
  optimisation. Balances exploration and exploitation with only one
  small margin parameter, xi, to set. Recommended as the default for
  optimisation.

- `ProbabilityOfImprovement` (PI): similar to EI but only cares about
  the *probability* of beating the current best, not the magnitude
  of improvement. More exploitative than EI for the same xi.

- `ThompsonSampling`: draws a random model from the posterior and
  optimises that. Produces diverse batches naturally.
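The criteria above can be written in a few lines once a predictive mean and standard deviation are available. The sketch below uses the standard textbook formulas for minimisation under a Gaussian predictive distribution. It is not JAXSR's code, and the function names and toy numbers are assumptions made for illustration.

```python
import math

import numpy as np


def normal_cdf(z):
    """Standard normal CDF, vectorised over NumPy arrays."""
    return 0.5 * (1.0 + np.vectorize(math.erf)(z / math.sqrt(2.0)))


def normal_pdf(z):
    """Standard normal density."""
    return np.exp(-0.5 * z**2) / math.sqrt(2.0 * math.pi)


def lcb(mu, sigma, kappa=2.0):
    """Lower confidence bound; smaller is better when minimising."""
    return mu - kappa * sigma


def expected_improvement(mu, sigma, y_best, xi=0.01):
    """Expected amount by which a candidate beats y_best (minimisation)."""
    sigma = np.maximum(sigma, 1e-12)  # guard against division by zero
    improvement = y_best - mu - xi
    z = improvement / sigma
    return improvement * normal_cdf(z) + sigma * normal_pdf(z)


def probability_of_improvement(mu, sigma, y_best, xi=0.01):
    """Probability that a candidate beats y_best (minimisation)."""
    sigma = np.maximum(sigma, 1e-12)
    return normal_cdf((y_best - mu - xi) / sigma)


# Three hypothetical candidates: near the incumbent, slightly better and
# well known, and poorly predicted but very uncertain.
mu = np.array([0.0, -0.2, 0.5])
sigma = np.array([0.1, 0.05, 1.0])
y_best = -0.1

ei = expected_improvement(mu, sigma, y_best)
pi_scores = probability_of_improvement(mu, sigma, y_best)
print("EI picks candidate", int(np.argmax(ei)))
print("PI picks candidate", int(np.argmax(pi_scores)))
print("LCB(kappa=0) picks candidate", int(np.argmin(lcb(mu, sigma, 0.0))))
print("LCB(kappa=2) picks candidate", int(np.argmin(lcb(mu, sigma, 2.0))))
```

In this toy setting PI favours the candidate that is almost certainly a small improvement, while EI and a large-kappa LCB favour the uncertain candidate that could be a large one.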

In [4]:
bounds = [(0.0, 5.0)]

print(f"Model: {model.expression_}")
print("True minimum at x=1.5 (y = 2.25 - 4.5 + 2 = -0.25)")
print()

strategies = [
    ("ModelMin (exploit only)", ModelMin()),
    ("LCB kappa=0.5 (exploitative)", LCB(kappa=0.5)),
    ("LCB kappa=2 (balanced)", LCB(kappa=2.0)),
    ("LCB kappa=5 (exploratory)", LCB(kappa=5.0)),
    ("Expected Improvement", ExpectedImprovement(minimize=True)),
    ("Prob. of Improvement", ProbabilityOfImprovement(minimize=True)),
    ("Thompson Sampling", ThompsonSampling(minimize=True, seed=42)),
]

for name, acq in strategies:
    result = suggest_points(model, bounds, acq, n_points=3, random_state=42)
    pts = np.array(result.points).ravel()
    print(f"  {name:35s} -> x = [{', '.join(f'{p:.2f}' for p in pts)}]")

print()

Model: y = 0.0132*x^3 + 0.8891*x^2 + 1.827 - 2.736*x
True minimum at x=1.5 (y = 2.25 - 4.5 + 2 = -0.25)

  ModelMin (exploit only)             -> x = [1.40, 1.57, 1.41]
  LCB kappa=0.5 (exploitative)        -> x = [1.40, 1.57, 1.41]
  LCB kappa=2 (balanced)              -> x = [1.40, 1.41, 1.57]
  LCB kappa=5 (exploratory)           -> x = [1.58, 1.41, 1.57]
  Expected Improvement                -> x = [1.40, 1.41, 1.57]
  Prob. of Improvement                -> x = [1.40, 1.41, 1.57]
  Thompson Sampling                   -> x = [1.58, 1.58, 1.57]



## Goal: figure out which model form is correct

When the Pareto front contains models of different complexities that
all fit the data similarly, you need data points that *discriminate*
between them.

- `ModelDiscrimination`: scores candidates by the maximum disagreement
  among Pareto-front models.

- `EnsembleDisagreement`: standard deviation across Pareto models.
  Similar idea but uses std instead of max-min range.
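The two disagreement measures can be compared directly on a grid. The sketch below hard-codes the three Pareto-front expressions that this example prints for the fitted model and evaluates a max-min range and a standard deviation across them. It is a plain NumPy illustration of the idea, not JAXSR's scoring code, and the grid resolution is an assumption.

```python
import numpy as np

# Pareto-front expressions as printed in this example's output.
pareto_models = [
    lambda x: 0.09309 * x**3,
    lambda x: 0.1661 * x**3 - 0.3433 * x**2 + 0.4584,
    lambda x: 0.0132 * x**3 + 0.8891 * x**2 + 1.827 - 2.736 * x,
]

grid = np.linspace(0.0, 5.0, 501)
preds = np.stack([f(grid) for f in pareto_models])  # (n_models, n_grid)

# Maximum disagreement: spread between the most extreme predictions.
max_min_range = preds.max(axis=0) - preds.min(axis=0)

# Ensemble disagreement: standard deviation across models.
ensemble_std = preds.std(axis=0)

x_range = grid[np.argmax(max_min_range)]
x_std = grid[np.argmax(ensemble_std)]
print(f"largest range at x = {x_range:.2f}, largest std at x = {x_std:.2f}")
```

On this grid both measures peak at the lower bound, where the three models differ mainly in their constant terms, which matches the suggestions near x = 0 in the output.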

In [5]:
bounds = [(0.0, 5.0)]

print(f"Best model: {model.expression_}")
print(f"Pareto front has {len(model.pareto_front_)} models:")
for r in model.pareto_front_:
    print(f"  complexity={r.complexity}, BIC={r.bic:.1f}: {r.expression()}")
print()

acqs = [
    ("ModelDiscrimination", ModelDiscrimination()),
    ("EnsembleDisagreement", EnsembleDisagreement()),
]

for name, acq in acqs:
    result = suggest_points(model, bounds, acq, n_points=5, random_state=42)
    pts = np.array(result.points).ravel()
    print(f"  {name:25s} -> x = [{', '.join(f'{p:.2f}' for p in pts)}]")

print()

Best model: y = 0.0132*x^3 + 0.8891*x^2 + 1.827 - 2.736*x
Pareto front has 3 models:
  complexity=3, BIC=81.8: y = 0.09309*x^3
  complexity=5, BIC=63.0: y = 0.1661*x^3 - 0.3433*x^2 + 0.4584
  complexity=6, BIC=20.9: y = 0.0132*x^3 + 0.8891*x^2 + 1.827 - 2.736*x

  ModelDiscrimination       -> x = [0.02, 0.02, 0.01, 0.01, 0.00]
  EnsembleDisagreement      -> x = [0.02, 0.02, 0.01, 0.01, 0.00]



## Goal: select a *batch* of points that are collectively informative

When you select the top-k by acquisition score (greedy), the points
can cluster in one region. Batch strategies address this:

- `greedy`: top-k by raw score. Fast but may cluster.

- `penalized`: after selecting the best candidate, nearby candidates
  are penalised before selecting the next. Simple diversity.

- `kriging_believer`: after selecting each point, the model is
  temporarily updated with a "fantasy" observation (y_hat) and
  re-scored. More sophisticated: later selections account for
  information gained by earlier ones.

- `d_optimal`: selects the batch that maximises det(Phi^T Phi),
  ignoring the acquisition function entirely. Best when the only
  goal is to maximise the information in the design.
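The difference between greedy and penalised selection is easy to see on a toy score. The sketch below is an assumption-laden illustration: the Gaussian distance penalty, the `length_scale` parameter, and the function names are hypothetical and may differ from what JAXSR does internally.

```python
import numpy as np


def select_greedy(candidates, scores, k):
    """Take the k highest-scoring candidates, ignoring their spacing."""
    order = np.argsort(scores)[::-1][:k]
    return candidates[order]


def select_penalized(candidates, scores, k, length_scale=0.5):
    """Pick points one at a time, damping scores near each pick.

    Assumes non-negative scores. The Gaussian penalty is an illustrative
    choice; it drives the score at a chosen point to zero.
    """
    scores = np.array(scores, dtype=float)  # work on a copy
    chosen = []
    for _ in range(k):
        idx = int(np.argmax(scores))
        chosen.append(candidates[idx])
        dist = np.abs(candidates - candidates[idx])
        scores = scores * (1.0 - np.exp(-0.5 * (dist / length_scale) ** 2))
    return np.array(chosen)


def min_gap(points):
    """Smallest distance between neighbouring selected points."""
    return float(np.diff(np.sort(points)).min())


candidates = np.linspace(0.0, 5.0, 101)
scores = (candidates - 2.5) ** 2  # toy variance-like score, high at the edges

greedy_pts = select_greedy(candidates, scores, k=5)
penalized_pts = select_penalized(candidates, scores, k=5)
print("greedy:   ", np.sort(greedy_pts).round(2))
print("penalized:", np.sort(penalized_pts).round(2))
```

With this score the greedy batch piles up at the two boundaries, while the penalised batch is forced into the interior after the first two picks.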

In [6]:
bounds = [(0.0, 5.0)]

learner = ActiveLearner(model, bounds, PredictionVariance(), random_state=42)

for strategy in ["greedy", "penalized", "kriging_believer", "d_optimal"]:
    result = learner.suggest(n_points=5, batch_strategy=strategy)
    pts = sorted(np.array(result.points).ravel())
    spread = pts[-1] - pts[0]
    print(
        f"  {strategy:20s} -> x = [{', '.join(f'{p:.2f}' for p in pts)}]"
        f"  (spread={spread:.2f})"
    )

print()

  greedy               -> x = [0.00, 0.01, 0.01, 4.99, 5.00]  (spread=4.99)
  penalized            -> x = [0.00, 1.40, 3.15, 4.25, 5.00]  (spread=5.00)
  kriging_believer     -> x = [0.00, 0.01, 0.01, 4.99, 5.00]  (spread=5.00)
  d_optimal            -> x = [0.00, 1.36, 3.49, 3.73, 5.00]  (spread=5.00)



## Goal: iteratively improve a model by running experiments

The workflow is:

1. Fit an initial model on a small dataset.
2. Use an acquisition function to suggest new points.
3. "Run the experiment" (here: evaluate the true function + noise).
4. Update the model with the new data.
5. Repeat until converged or budget exhausted.
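The stopping logic in step 5 can be sketched independently of any library. The loop below is a minimal NumPy illustration with a quadratic least-squares model. The convergence test on coefficient changes, the tolerance `tol`, and the variance-based selection rule are all assumptions for illustration, not JAXSR behaviour.

```python
import numpy as np

rng = np.random.default_rng(0)


def oracle(x):
    """Stand-in experiment: true function plus noise."""
    return x**2 - 3.0 * x + 2.0 + rng.normal(0.0, 0.2, size=x.shape)


def basis(x):
    """Quadratic basis [1, x, x^2] (assumed model form)."""
    return np.column_stack([np.ones_like(x), x, x**2])


# Step 1: small initial dataset.
x_obs = rng.uniform(0.0, 5.0, 8)
y_obs = oracle(x_obs)

candidates = np.linspace(0.0, 5.0, 201)
budget, tol = 10, 0.05
prev_coef = None

for step in range(budget):
    # Step 4 (and the initial fit): refit on all data so far.
    Phi = basis(x_obs)
    coef, *_ = np.linalg.lstsq(Phi, y_obs, rcond=None)

    # Step 5: stop once the coefficients have stopped moving.
    if prev_coef is not None and np.max(np.abs(coef - prev_coef)) < tol:
        break
    prev_coef = coef

    # Step 2: suggest the candidate with the largest scaled prediction variance.
    gram_inv = np.linalg.inv(Phi.T @ Phi)
    Phi_cand = basis(candidates)
    leverage = np.einsum("ij,jk,ik->i", Phi_cand, gram_inv, Phi_cand)
    x_new = candidates[np.argmax(leverage)]

    # Step 3: run the experiment and append the result.
    x_obs = np.append(x_obs, x_new)
    y_obs = np.append(y_obs, oracle(np.array([x_new])))

print(f"stopped with {len(x_obs)} observations, coefficients = {coef.round(3)}")
```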

In [7]:
# True function (unknown to the model)
def oracle(X):
    X = np.array(X)
    return X[:, 0] ** 2 - 3.0 * X[:, 0] + 2.0 + np.random.randn(len(X)) * 0.2

# Start with very few points
np.random.seed(0)
X_init = np.random.uniform(0, 5, (15, 1))
y_init = oracle(X_init)

library = (
    BasisLibrary(n_features=1, feature_names=["x"])
    .add_constant()
    .add_linear()
    .add_polynomials(max_degree=3)
)
model = SymbolicRegressor(basis_library=library, max_terms=4, strategy="greedy_forward")
model.fit(jnp.array(X_init), jnp.array(y_init))

print(f"Initial model ({len(y_init)} points): {model.expression_}")
print(f"  R^2 = {model.score(model._X_train, model._y_train):.4f}")
print(f"  MSE = {model.metrics_['mse']:.4f}")

# Active learning loop
learner = ActiveLearner(
    model,
    bounds=[(0.0, 5.0)],
    acquisition=ExpectedImprovement(minimize=True),
    random_state=42,
)

n_iterations = 5
points_per_iteration = 5

for i in range(n_iterations):
    result = learner.suggest(
        n_points=points_per_iteration,
        batch_strategy="penalized",
    )

    y_new = oracle(np.array(result.points))
    learner.update(result.points, jnp.array(y_new))

    print(
        f"  Iteration {i + 1}: "
        f"n={learner.n_observations}, "
        f"R^2={model.score(model._X_train, model._y_train):.4f}, "
        f"MSE={model.metrics_['mse']:.4f}, "
        f"model={model.expression_}"
    )

print(f"\nFinal model ({learner.n_observations} points): {model.expression_}")

Initial model (15 points): y = 0.01273*x^3 - 2.784*x + 1.862 + 0.9046*x^2
  R^2 = 0.9981
  MSE = 0.0236


  Iteration 1: n=20, R^2=0.9908, MSE=0.1040, model=y = 0.1458*x^3 - 0.2237*x^2
  Iteration 2: n=25, R^2=0.9902, MSE=0.1002, model=y = 0.1483*x^3 - 0.2344*x^2
  Iteration 3: n=30, R^2=0.9969, MSE=0.0280, model=y = 0.02827*x^3 + 0.777*x^2 + 1.729 - 2.496*x
  Iteration 4: n=35, R^2=0.9962, MSE=0.0300, model=y = 0.0279*x^3 + 0.7865*x^2 + 1.756 - 2.538*x
  Iteration 5: n=40, R^2=0.9957, MSE=0.0292, model=y = 0.01841*x^3 + 0.8644*x^2 + 1.875 - 2.722*x

Final model (40 points): y = 0.01841*x^3 + 0.8644*x^2 + 1.875 - 2.722*x


## Goal: combine multiple objectives using weighted acquisition

You can weight and add acquisition functions together to balance
different goals simultaneously. Each component is min-max normalised
before weighting so the weights are meaningful.

Common recipes:

- Balanced optimisation: `0.7 * EI + 0.3 * PredictionVariance`
- Exploration with model improvement: `0.5 * PredictionVariance + 0.5 * AOptimal`
- Multi-objective: `0.4 * ModelMin + 0.3 * PredictionVariance + 0.3 * DOptimal`
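Why normalisation matters can be shown with two toy score vectors on very different scales. The sketch below is a plain NumPy illustration of min-max normalising each component before the weighted sum. The helper names and the score values are hypothetical, and JAXSR's composite objects may handle edge cases differently.

```python
import numpy as np


def min_max(scores):
    """Rescale scores to [0, 1]; a constant vector maps to all zeros."""
    scores = np.asarray(scores, dtype=float)
    span = scores.max() - scores.min()
    if span == 0.0:
        return np.zeros_like(scores)
    return (scores - scores.min()) / span


def weighted_acquisition(components, weights):
    """Weighted sum of min-max normalised component scores."""
    return sum(w * min_max(s) for w, s in zip(weights, components))


# Hypothetical scores at six candidates: an EI-like score with tiny values
# and a variance-like score with large values.
ei_like = np.array([0.001, 0.004, 0.002, 0.0, 0.0, 0.0])
var_like = np.array([50.0, 10.0, 5.0, 5.0, 10.0, 40.0])

combined = weighted_acquisition([ei_like, var_like], weights=[0.7, 0.3])
raw_sum = 0.7 * ei_like + 0.3 * var_like  # no normalisation

print("normalised choice:", int(np.argmax(combined)))
print("raw-sum choice:   ", int(np.argmax(raw_sum)))
```

Without normalisation the large-scale component decides the outcome regardless of the weights. With it, the 0.7 weight on the EI-like score actually dominates.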

In [8]:
bounds = [(0.0, 5.0)]

composites = [
    (
        "0.7*EI + 0.3*Variance",
        0.7 * ExpectedImprovement(minimize=True) + 0.3 * PredictionVariance(),
    ),
    (
        "0.5*LCB + 0.5*DOptimal",
        0.5 * LCB(kappa=2) + 0.5 * DOptimal(),
    ),
    (
        "Equal: EI + Var + AOptimal",
        ExpectedImprovement(minimize=True) + PredictionVariance() + AOptimal(),
    ),
]

for name, acq in composites:
    result = suggest_points(model, bounds, acq, n_points=3, random_state=42)
    pts = np.array(result.points).ravel()
    print(f"  {name:30s} -> x = [{', '.join(f'{p:.2f}' for p in pts)}]")

print()

  0.7*EI + 0.3*Variance          -> x = [4.99, 4.99, 5.00]
  0.5*LCB + 0.5*DOptimal         -> x = [1.23, 1.24, 1.24]
  Equal: EI + Var + AOptimal     -> x = [4.99, 4.99, 5.00]



## Decision Guide: Choosing an Acquisition Function

**WHAT IS YOUR GOAL?**

1. **IMPROVE MODEL ACCURACY** (explore everywhere)
   - Simple & fast? → `PredictionVariance`
   - Need coverage guarantee? → `ConfidenceBandWidth(alpha=0.05)`
   - Unsure about model form? → `EnsembleDisagreement` or `BMAUncertainty`
   - Tighten coefficient CIs? → `AOptimal`
   - Max info per experiment? → `DOptimal`

2. **FIND THE OPTIMUM** (minimise or maximise y)
   - Trust the model fully? → `ModelMin` / `ModelMax`
   - Want balanced exploration? → `ExpectedImprovement` (recommended)
   - Need probability of beating a threshold? → `ProbabilityOfImprovement`
   - Want explicit exploration knob? → `LCB(kappa)` / `UCB(kappa)`
   - Want randomised exploration? → `ThompsonSampling`

3. **DECIDE WHICH MODEL IS CORRECT**
   - Pareto models disagree? → `ModelDiscrimination`
   - Quantify structural uncertainty? → `EnsembleDisagreement`

4. **MULTIPLE OBJECTIVES**
   - Combine with weights: `0.7 * EI + 0.3 * PredictionVariance`

**BATCH STRATEGY SELECTION:**
- Fast, don't care about diversity? → `greedy`
- Want spatial diversity? → `penalized`
- Want information-aware batches? → `kriging_believer`
- Want maximum design efficiency? → `d_optimal`