v0.4.0

@jaepil jaepil released this 24 Feb 13:30
· 40 commits to main since this release
5091ed1

What's Changed

  • LearnableLogOddsWeights for per-signal reliability learning (Remark 5.3.2)

    • Starts from the Naive Bayes uniform initialization (w_i = 1/n) and learns per-signal reliability weights via a softmax parameterization over unconstrained logits
    • Completes the correspondence to a fully parameterized single-layer network in log-odds space: logit -> weighted sum -> sigmoid
    • Hebbian gradient: dL/dz_j = n^alpha * (p - y) * w_j * (x_j - x_bar_w) (pre-synaptic activity × post-synaptic error, backprop-free)
    • Batch fit() via gradient descent on BCE loss
    • Online update() via SGD with EMA-smoothed gradients, bias correction, L2 gradient clipping, learning rate decay, and Polyak averaging of weights in the simplex
    • Alpha (confidence scaling) is fixed; only the weights are learned. The two are orthogonal (Paper 2, Section 4.2)
  • Theorem verification tests for Remark 5.3.2

    • Naive Bayes initialization: uniform 1/n weights match unweighted conjunction
    • Hebbian gradient structure: zero gradient when signals identical, correct direction for overestimating signals
    • Theorem 5.3.1: equal-quality signals maintain approximately uniform weights
  • Learnable weights benchmark (benchmarks/learnable_weights.py)

    • Weight recovery accuracy across 2–5 signals with varying noise
    • Fusion quality comparison: uniform vs oracle vs learned weights (BCE, MSE, Spearman)
    • Online convergence tracking: update() vs fit() target
    • Timing measurements for fit() and update() at various scales
  • Learnable fusion example (examples/learnable_fusion.py)

    • Batch fit, online update, Polyak-averaged inference, and alpha confidence scaling for a 3-signal hybrid search system
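
The Hebbian gradient listed under LearnableLogOddsWeights can be checked by hand. The following is a minimal standalone sketch, not the library's implementation. It assumes the fused probability is sigmoid(n^alpha * sum_j w_j * x_j), with x_j = logit(p_j) and w = softmax(z), which is the form under which the stated gradient is the exact derivative of the BCE loss with respect to the logits z. The helper names (logit, softmax, fuse, bce, hebbian_grad) are hypothetical and chosen for illustration only.

```python
import math


def logit(p):
    """Map a probability in (0, 1) to log-odds."""
    return math.log(p) - math.log1p(-p)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax(z):
    """Softmax over unconstrained logits; the result lies on the simplex."""
    m = max(z)  # shift by the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def fuse(probs, z, alpha):
    """Assumed fusion rule: sigmoid(n^alpha * sum_j w_j * logit(p_j))."""
    x = [logit(p) for p in probs]
    w = softmax(z)
    n = len(x)
    return sigmoid(n ** alpha * sum(wj * xj for wj, xj in zip(w, x)))


def bce(p, y):
    """Binary cross-entropy for one prediction p and label y."""
    return -(y * math.log(p) + (1.0 - y) * math.log1p(-p))


def hebbian_grad(probs, y, z, alpha):
    """dL/dz_j = n^alpha * (p - y) * w_j * (x_j - x_bar_w)."""
    x = [logit(p) for p in probs]
    w = softmax(z)
    scale = len(x) ** alpha
    x_bar_w = sum(wj * xj for wj, xj in zip(w, x))  # weighted mean log-odds
    p = sigmoid(scale * x_bar_w)                    # fused probability
    return [scale * (p - y) * wj * (xj - x_bar_w) for wj, xj in zip(w, x)]


# Uniform logits give the Naive Bayes initialization w_j = 1/n.
grad = hebbian_grad([0.9, 0.6, 0.3], 1.0, [0.0, 0.0, 0.0], alpha=0.5)
```

Under this assumed form, two properties follow directly from the formula: the gradient components sum to zero (the weights stay on the simplex), and the gradient vanishes when all signals are identical, since every x_j then equals the weighted mean x_bar_w.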

Quick Start

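from bayesian_bm25 import LearnableLogOddsWeights

# training_probs, training_labels, feedback_stream, and test_probs are
# placeholders for your own data.
learner = LearnableLogOddsWeights(n_signals=3, alpha=0.0)

# Batch fit: gradient descent on BCE loss.
learner.fit(training_probs, training_labels, learning_rate=0.1)

# Online updates: one SGD step per feedback observation.
for probs, label in feedback_stream:
    learner.update(probs, label, learning_rate=0.05, momentum=0.9)

# Inference with the Polyak-averaged weights.
fused = learner(test_probs, use_averaged=True)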
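
For intuition about what an online step of this kind might involve, here is a rough standalone sketch that combines the ingredients named in the release notes: EMA-smoothed gradients, bias correction, L2 gradient clipping, learning rate decay, and Polyak averaging of the simplex weights. It is not the library's code. The class name OnlineLogOddsWeights, the parameters max_grad_norm and lr_decay, the 1 / (1 + decay * t) schedule, and the order of the steps are all assumptions made for illustration.

```python
import math


def _logit(p):
    return math.log(p) - math.log1p(-p)


def _sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def _softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


class OnlineLogOddsWeights:
    """Hypothetical illustration; not bayesian_bm25.LearnableLogOddsWeights."""

    def __init__(self, n_signals, alpha=0.0, max_grad_norm=1.0, lr_decay=0.0):
        self.z = [0.0] * n_signals           # unconstrained logits (uniform start)
        self.alpha = alpha                   # fixed confidence scaling
        self.max_grad_norm = max_grad_norm   # assumed clipping threshold
        self.lr_decay = lr_decay             # assumed decay rate
        self.ema = [0.0] * n_signals         # EMA of clipped gradients
        self.w_avg = [1.0 / n_signals] * n_signals  # Polyak-averaged weights
        self.t = 0

    def weights(self):
        return _softmax(self.z)

    def gradient(self, probs, label):
        """Gradient of BCE w.r.t. z: n^alpha * (p - y) * w_j * (x_j - x_bar_w)."""
        x = [_logit(p) for p in probs]
        w = self.weights()
        scale = len(x) ** self.alpha
        x_bar_w = sum(wj * xj for wj, xj in zip(w, x))
        p = _sigmoid(scale * x_bar_w)
        return [scale * (p - label) * wj * (xj - x_bar_w) for wj, xj in zip(w, x)]

    def update(self, probs, label, learning_rate=0.05, momentum=0.9):
        """One SGD step; momentum is expected to lie in [0, 1)."""
        self.t += 1
        g = self.gradient(probs, label)
        # L2 gradient clipping.
        norm = math.sqrt(sum(v * v for v in g))
        if norm > self.max_grad_norm:
            g = [v * self.max_grad_norm / norm for v in g]
        # EMA smoothing with bias correction for the zero-initialized average.
        self.ema = [momentum * m + (1 - momentum) * v for m, v in zip(self.ema, g)]
        correction = 1.0 - momentum ** self.t
        # Assumed learning rate decay schedule.
        lr = learning_rate / (1.0 + self.lr_decay * self.t)
        self.z = [zj - lr * m / correction for zj, m in zip(self.z, self.ema)]
        # Polyak averaging: running mean of the post-update simplex weights.
        w = self.weights()
        self.w_avg = [a + (wj - a) / self.t for a, wj in zip(self.w_avg, w)]


learner = OnlineLogOddsWeights(n_signals=2)
learner.update([0.9, 0.2], 1.0)  # signal 0 agrees with the positive label
```

Because the average is taken over softmax outputs rather than over the logits, the averaged weights remain a valid point on the simplex after every step.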

Full Changelog: v0.3.2...v0.4.0