Online weight optimization via Thompson Sampling. Learns optimal configurations from outcome feedback — no grid search, no manual tuning. Converges in ~50 observations. +26% NDCG@5 over fixed-weight baselines in controlled experiments.
pip install fonz-dial
from thompson_bandits import ThompsonBandit, InMemoryStore
store = InMemoryStore(arm_ids=["relevance_heavy", "balanced", "recency_heavy"])
bandit = ThompsonBandit(store)
# Run the loop: select → observe → update
for query in queries:
    arm = bandit.select()
    reward = run_query(query, strategy=arm)
    bandit.update(arm, reward=reward)
print(bandit.get_summary())

After 50 iterations:
BanditSummary(
    best_arm='relevance_heavy',
    total_pulls=50,
    arms=[
        ArmSummary(arm_id='balanced', mean=0.5765, pulls=11),
        ArmSummary(arm_id='recency_heavy', mean=0.4210, pulls=8),
        ArmSummary(arm_id='relevance_heavy', mean=0.8903, pulls=31),
    ]
)
The bandit explores all three options early, then converges — 31 of 50 pulls on the winner, without you telling it which arm is best.
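Mechanically, each call to select() draws one random sample from every arm's Beta posterior and plays the arm with the largest draw: arms with strong track records win most draws, while uncertain arms still win occasionally because their posteriors are wide. A minimal sketch of that single step using NumPy directly (illustrative numbers, not the library's internals):

import numpy as np

rng = np.random.default_rng()

# Hypothetical posterior state after some observations: (alpha, beta) per arm.
posteriors = {
    "relevance_heavy": (28.0, 5.0),   # strong evidence, high mean
    "balanced": (7.0, 6.0),           # middling
    "recency_heavy": (4.0, 6.0),      # weak
}

# One Thompson Sampling step: sample every posterior, play the best draw.
samples = {arm: rng.beta(a, b) for arm, (a, b) in posteriors.items()}
chosen = max(samples, key=samples.get)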
vs. grid search / random search — Those require a batch of upfront runs across candidate configurations. Dial learns online, one observation at a time. No batch experiments needed.
vs. manual tuning — Manual weights are a guess that stays frozen. Dial adapts when the best option shifts — user behavior drifts, data distributions change, what worked in January fails in March.
vs. contextual bandits (LinUCB, neural) — Those need feature engineering and thousands of observations. Dial works with 50 observations and zero features. Start with Dial; graduate to contextual bandits when you have the data to justify them.
vs. Bayesian optimization (Optuna, Ax) — Those optimize over continuous parameter spaces. Dial optimizes over discrete options (strategies, presets, model choices). Different problem shape.
- Retrieval weight tuning — learn the optimal blend of relevance, recency, and importance for RAG systems
- Model routing — discover which LLM performs best for different query types
- Prompt selection — A/B test prompt variants with automatic convergence
- Feature flag rollout — promote variants based on measured outcomes
- Any multi-option decision where you can observe a reward signal
- Beta posteriors — each arm maintains a Beta(alpha, beta) distribution updated with observed rewards
- Discounted Thompson Sampling — optional decay factor for non-stationary environments where the best arm shifts over time
- Cost-aware rewards — built-in cost_aware_reward() scales outcomes by resource efficiency
- Pluggable storage — InMemoryStore for testing, SQLiteStore for persistence, or implement the ArmStore protocol for anything else
- Zero SQLite dependency in core — bandit logic talks only to the ArmStore protocol
- Type-safe — full annotations, runtime_checkable Protocol
from thompson_bandits import InMemoryStore
store = InMemoryStore(arm_ids=["a", "b", "c"], prior_alpha=1.0, prior_beta=1.0)

from thompson_bandits import SQLiteStore
# From a file path (store owns the connection)
store = SQLiteStore.from_path("bandits.db", arm_ids=["a", "b", "c"])
# From an existing connection (you own the connection)
import sqlite3
conn = sqlite3.connect("bandits.db")
store = SQLiteStore(conn, arm_ids=["a", "b", "c"])

Implement the ArmStore protocol — any class with the right methods works, no inheritance required:
from thompson_bandits import ArmStore, ArmStats
class RedisStore:
    def get_stats(self, arm_id: str) -> ArmStats | None: ...
    def update_stats(self, arm_id: str, alpha_delta: float, beta_delta: float, reward: float) -> None: ...
    def get_all_arms(self) -> list[ArmStats]: ...
    def decay(self, arm_id: str, factor: float) -> None: ...

When the best option changes over time, enable discounting:
from thompson_bandits import ThompsonBandit, InMemoryStore, BanditConfig
config = BanditConfig(discount=0.95) # decay factor in (0, 1)
bandit = ThompsonBandit(store, config=config)

Before each update, existing evidence is decayed by the discount factor. Recent observations carry more weight than old ones.
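For intuition, here is the arithmetic of one discounted update, assuming decay multiplies an arm's accumulated pseudo-counts by the factor and a reward r adds r to alpha and 1 - r to beta (a standard Beta-Bernoulli treatment; the exact rule lives behind the store's decay() and update_stats() methods):

alpha, beta, discount, reward = 12.0, 4.0, 0.95, 1.0

# Decay existing evidence, then fold in the new observation.
alpha, beta = alpha * discount, beta * discount        # 11.4, 3.8
alpha, beta = alpha + reward, beta + (1.0 - reward)    # 12.4, 3.8

# Evidence from n updates ago ends up weighted by discount**n, so the
# effective memory is roughly 1 / (1 - discount) = 20 observations here.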
When options have different costs (tokens, latency, dollars), scale rewards accordingly:
from thompson_bandits import cost_aware_reward
raw_reward = 0.9
token_cost = 1500
baseline_cost = 1000
adjusted = cost_aware_reward(raw_reward, cost=token_cost, baseline_cost=baseline_cost)
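# 0.9 * (1000 / 1500) = 0.6: the cost overrun shrinks the reward (formula documented below)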
bandit.update(arm, reward=adjusted)

To inspect what the bandit has learned so far:

summary = bandit.get_summary()
print(summary.best_arm) # 'relevance_heavy'
print(summary.total_pulls) # 50
for arm in summary.arms:
print(f"{arm.arm_id}: mean={arm.mean:.3f}, pulls={arm.pulls}")
# balanced: mean=0.577, pulls=11
# recency_heavy: mean=0.421, pulls=8
# relevance_heavy: mean=0.890, pulls=31

When you have prior knowledge (from a previous experiment, a related task, or domain expertise), encode it as informative priors instead of starting from uniform:
from thompson_bandits import ThompsonBandit, InMemoryStore, BanditConfig
# Previous experiment found relevance_heavy won ~63% of pulls.
# Encode that as Beta(6.3, 3.7) instead of the default Beta(1, 1).
config = BanditConfig(prior_alpha=1.0, prior_beta=1.0)  # defaults for the arms without history
store = InMemoryStore(arm_ids=["relevance_heavy", "balanced", "recency_heavy"])
# Override priors for the arm with known history
arm = store.get_stats("relevance_heavy")
arm.alpha = 6.3
arm.beta = 3.7
bandit = ThompsonBandit(store, config=config)

The bandit starts biased toward the prior winner but remains open to switching if the data disagrees. With shrinkage (e.g., scaling the prior by 0.15), the prior influence fades within ~20 observations.
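One plausible way to apply that shrinkage (an illustration, not a library helper) is to scale the transferred pseudo-counts before writing them, which keeps the prior's mean but weakens its strength:

shrinkage = 0.15

arm = store.get_stats("relevance_heavy")
arm.alpha = 6.3 * shrinkage   # 0.945
arm.beta = 3.7 * shrinkage    # 0.555

# The mean stays at 0.63, but the prior is now worth only ~1.5 pseudo-observations,
# so roughly 20 real observations will swamp it.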
ThompsonBandit(store: ArmStore, config: BanditConfig | None = None, rng: numpy.random.Generator | None = None)

The main entry point. Wraps any ArmStore backend with Thompson Sampling logic.
- select() -> str — Draws a sample from each arm's Beta posterior and returns the arm_id with the highest sample. Raises ValueError if the store is empty.
- update(arm_id: str, reward: float) -> None — Updates the posterior for the given arm. Reward must be in [0, 1]. If discounting is enabled, existing evidence is decayed before the new observation is applied.
- get_summary() -> BanditSummary — Returns a serialization-friendly snapshot of all arms.
- get_arm(arm_id: str) -> ArmStats | None — Returns stats for one arm, or None.
- get_arms() -> list[ArmStats] — Returns stats for all arms.
BanditConfig(discount: float | None = None, prior_alpha: float = 1.0, prior_beta: float = 1.0)

- discount — Decay factor in (0, 1) for non-stationary environments. None disables discounting.
- prior_alpha, prior_beta — Initial Beta parameters for new arms. Defaults give a uniform Beta(1, 1) prior.
ArmStats(arm_id: str, alpha: float = 1.0, beta: float = 1.0, pulls: int = 0, total_reward: float = 0.0)

State of one arm. Also exposes two computed properties:

- mean -> float — Expected value: alpha / (alpha + beta).
- variance -> float — Variance of the Beta distribution.
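For example, reusing the quickstart bandit (all names below come from the API above):

stats = bandit.get_arm("relevance_heavy")
if stats is not None:
    print(stats.mean, stats.variance, stats.pulls)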
BanditSummary(arms: list[ArmSummary], best_arm: str | None, total_pulls: int)
ArmSummary(arm_id: str, alpha: float, beta: float, mean: float, pulls: int)

Returned by get_summary(). best_arm is the arm_id with the highest mean, or None if no arms exist.
InMemoryStore(arm_ids: list[str] | None = None, prior_alpha: float = 1.0, prior_beta: float = 1.0)

Dict-backed, ephemeral. Good for testing and short-lived processes.
SQLiteStore(conn: sqlite3.Connection, arm_ids: list[str] | None = None, prior_alpha: float = 1.0, prior_beta: float = 1.0, table_name: str = "bandit_arms")
SQLiteStore.from_path(db_path: str | Path, ...) -> SQLiteStore  # creates connection with WAL mode

Persistent storage. from_path owns the connection; the constructor form lets you pass an existing connection. Creates the table if it doesn't exist. Safe to re-initialize — existing arms are preserved (INSERT OR IGNORE).
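Because re-initialization preserves existing rows, a restarted process resumes from the posteriors it had already learned:

# Same path, same arm_ids as the previous run: learned stats are kept, not reset.
store = SQLiteStore.from_path("bandits.db", arm_ids=["relevance_heavy", "balanced", "recency_heavy"])
bandit = ThompsonBandit(store)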
class ArmStore(Protocol):
    def get_stats(self, arm_id: str) -> ArmStats | None: ...
    def update_stats(self, arm_id: str, alpha_delta: float, beta_delta: float, reward: float) -> None: ...
    def get_all_arms(self) -> list[ArmStats]: ...
    def decay(self, arm_id: str, factor: float) -> None: ...

The protocol is runtime_checkable. Implement these four methods on any class to create a custom backend — no inheritance required.
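That means conformance can be checked structurally with isinstance, for example against the built-in store:

from thompson_bandits import ArmStore, InMemoryStore

assert isinstance(InMemoryStore(arm_ids=["a"]), ArmStore)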
cost_aware_reward(raw_reward: float, cost: float, baseline_cost: float = 1.0) -> float

Scales a reward by cost efficiency: raw_reward * (baseline_cost / cost), clamped to [0, 1]. Cheaper-than-baseline episodes get a boost; more expensive ones get penalized. Zero cost passes through unchanged.
Dial extracts the Thompson Sampling engine from a research experiment on gradient-free retrieval weight learning. The experiment ran 1,200 episodes across 4 conditions on a $50/month API budget.
Citation (BibTeX)
@article{dirocco2026gradient,
  title  = {Gradient-Free Retrieval Weight Learning via Thompson Sampling
            with LLM Self-Assessment},
  author = {DiRocco, Alfonso},
  year   = {2026},
  url    = {https://github.com/kusp-dev/retrieval-weight-experiment},
  note   = {1,200 episodes, 4 conditions, +26\% NDCG@5 over fixed baselines}
}

git clone https://github.com/fonz-ai/dial.git
cd dial
uv sync --extra dev
uv run pytest tests/ -v
uv run ruff check src/ tests/

MIT