elonlit/PopRiskMinimization


Population Risk Minimization for Neural Networks

popriskmin gives you PRM, a small modification of AdamW that trains on population risk instead of raw empirical risk. It keeps AdamW's usual one forward pass and one backward pass per step, then preconditions each parameter update with the population-risk mask from Litman & Guo (2026), A Theory of Generalization in Deep Learning.

PRM tracks one extra tensor per parameter tensor: an exponential moving average of the centered minibatch gradient variance. For each scalar parameter $k$, it asks whether the squared batch-mean gradient $\mu_k^2$ exceeds the noise estimate $\sigma_k^2$ scaled by $\alpha$:

$$ \mu_k^2 > \alpha \sigma_k^2 $$

Parameters that pass get the Adam update. Parameters that fail get shrunk or zeroed, depending on the mask. On the fresh-batch boundary, $\alpha = 1$ and the streaming variance estimates the leave-one-out penalty $\Sigma_B / (b - 1)$.
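
The bookkeeping above can be sketched for a single scalar parameter. This is an illustration only, not the package's code: it assumes the centered variance is an EMA of squared deviations from the running mean gradient, it omits bias correction, and the helper names `update_stats`, `passes_boundary`, and `run` are hypothetical.

```python
def update_stats(m, s, grad, beta1=0.9, rho=0.99):
    """One EMA update of the mean gradient m and the centered variance s."""
    centered = grad - m  # deviation from the running mean (assumed centering)
    m = beta1 * m + (1.0 - beta1) * grad
    s = rho * s + (1.0 - rho) * centered * centered
    return m, s


def passes_boundary(m, s, alpha=1.0):
    """Hard test: the squared mean gradient must exceed alpha times the variance."""
    return m * m > alpha * s


def run(grads):
    """Feed a sequence of scalar gradients through the two EMAs."""
    m, s = 0.0, 0.0
    for g in grads:
        m, s = update_stats(m, s, g)
    return m, s


steady = run([1.0] * 50)       # consistent gradient: the mean dominates
noisy = run([1.0, -1.0] * 25)  # sign-flipping gradient: the variance dominates

print("steady passes:", passes_boundary(*steady))
print("noisy passes:", passes_boundary(*noisy))
```

Under these assumptions, a gradient that keeps its sign clears the boundary, while one that flips sign every step does not.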

Install

uv pip install -e .
uv sync --extra test

PRM requires torch>=2.0.

Quick start

from popriskmin import PRM

# Assumes `model` (a torch.nn.Module), `loss_fn`, and `loader` are defined elsewhere.

optimizer = PRM(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.01,
    softness=1.0,
    batch_size=32,
)

for batch in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(batch.x), batch.y)
    loss.backward()
    optimizer.step()

Useful options:

PRM(model.parameters(), mask="snr")         # default, smooth SNR mask
PRM(model.parameters(), mask="soft")        # strict Algorithm 1 cutoff
PRM(model.parameters(), mask="hard")        # 0/1 theorem mask
PRM(model.parameters(), reduction="per_tensor")

Use reduction="per_tensor" when each scalar parameter is noisy but the whole parameter tensor has a clear signal. This is often the more useful setting for large generative models, diffusion, and CFM-style training.
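
To see why pooling can change the outcome, here is a toy comparison of the two reductions. The README does not specify how reduction="per_tensor" pools its statistics, so this sketch assumes a plain mean of the squared signal and of the variance over the tensor; `per_param_pass` and `per_tensor_pass` are hypothetical helpers, not part of popriskmin.

```python
def per_param_pass(m_sq, s, alpha=1.0):
    """Test every scalar entry against its own boundary."""
    return [a > alpha * b for a, b in zip(m_sq, s)]


def per_tensor_pass(m_sq, s, alpha=1.0):
    """Test the whole tensor once, using mean-pooled statistics (assumed pooling)."""
    mean_signal = sum(m_sq) / len(m_sq)
    mean_noise = sum(s) / len(s)
    return mean_signal > alpha * mean_noise


# Squared mean gradients and variances for a 4-entry tensor (made-up numbers).
m_sq = [0.8, 0.9, 0.7, 6.0]
s = [1.0, 1.0, 1.0, 1.0]

print(per_param_pass(m_sq, s))   # entry-by-entry decisions
print(per_tensor_pass(m_sq, s))  # one decision for the whole tensor
```

With these numbers, three of the four entries fall below their own boundary, yet the pooled statistics clear it, so the whole tensor would be updated.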

Masks

| mask | Formula | Behavior |
| --- | --- | --- |
| snr | $\hat{m}^2 / (\hat{m}^2 + \lambda_p \alpha \hat{s} + \varepsilon)$ | Default. Smooth, never fully shuts off. With softness=1, it gives $q = 1/2$ on the boundary. |
| soft | $\max(\hat{m}^2 - \alpha \hat{s}, 0) / (\max(\hat{m}^2 - \alpha \hat{s}, 0) + \lambda_p \hat{s} + \varepsilon)$ | Strict Algorithm 1 mask. Zero below the boundary. |
| hard | $\mathbf{1}[\hat{m}^2 > \alpha \hat{s}]$ | Binary indicator from Theorem 6.5. Mostly useful for ablations. |

softness is $\lambda_p$. Larger values make the mask more conservative. A reasonable first sweep is 0.3, 1, 3, 10.
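
The three formulas can be transcribed directly into scalar Python to get a feel for their behavior. This is a sketch of the formulas as written in the table, not the package's `mask.py`; the function names are hypothetical, `m_sq` stands for $\hat{m}^2$, and `s` stands for $\hat{s}$.

```python
def snr_mask(m_sq, s, alpha=1.0, softness=1.0, eps=1e-8):
    """Smooth mask: m^2 / (m^2 + softness * alpha * s + eps)."""
    return m_sq / (m_sq + softness * alpha * s + eps)


def soft_mask(m_sq, s, alpha=1.0, softness=1.0, eps=1e-8):
    """Strict mask: zero until m^2 exceeds alpha * s, then rises smoothly."""
    excess = max(m_sq - alpha * s, 0.0)
    return excess / (excess + softness * s + eps)


def hard_mask(m_sq, s, alpha=1.0):
    """Binary indicator of m^2 > alpha * s."""
    return 1.0 if m_sq > alpha * s else 0.0


# Below, on, and above the boundary (s = 1, alpha = 1).
for m_sq in (0.5, 1.0, 4.0):
    print(m_sq, snr_mask(m_sq, 1.0), soft_mask(m_sq, 1.0), hard_mask(m_sq, 1.0))

# Effect of the suggested softness sweep on a parameter above the boundary.
for softness in (0.3, 1.0, 3.0, 10.0):
    print(softness, snr_mask(4.0, 1.0, softness=softness))
```

On the boundary the snr mask sits at one half while the soft and hard masks are both zero, and raising softness lowers the snr mask for the same signal-to-noise ratio.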

Arguments

| Argument | Default | Notes |
| --- | --- | --- |
| lr | 1e-3 | Tune as you would for AdamW. |
| betas | (0.9, 0.999) | Adam moment decay rates. |
| rho | 0.99 | Decay for the centered gradient variance. Usually set below beta2, which gives a shorter averaging horizon. |
| eps | 1e-8 | Stabilizer for Adam and the mask denominator. |
| weight_decay | 0.01 | Decoupled AdamW-style weight decay by default. |
| softness | 1.0 | Population-risk mask regularizer. |
| batch_size | None | Optional for boundary="batch". Required for boundary="empirical". |
| boundary | "batch" | Use "batch" for online or fresh-batch training. Use "empirical" for finite-dataset leave-one-out. |
| n_dataset | None | Required when boundary="empirical". |
| mask | "snr" | One of "snr", "soft", or "hard". |
| reduction | "per_param" | Use one mask per scalar parameter. Set "per_tensor" to pool each parameter tensor first. |
| warmup_steps | 0 | Force the mask to 1 for the first N optimizer steps. |
| amsgrad | False | Use AMSGrad for the Adam denominator. |
| bias_correction | True | Apply Adam-style bias correction to m, v, and s. |
| decoupled_weight_decay | True | Set to False for coupled L2. |
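
To show how several of these arguments could interact, here is a single-scalar sketch of one masked AdamW step. It is not the package's `optimizer.py`. It assumes the snr mask multiplies the Adam direction, that decoupled weight decay is applied unmasked, that the variance is centered on the previous running mean, and that $\alpha = 1$; `prm_scalar_step` and `new_state` are hypothetical names.

```python
import math


def prm_scalar_step(p, grad, state, lr=1e-3, betas=(0.9, 0.999), rho=0.99,
                    eps=1e-8, weight_decay=0.01, softness=1.0, alpha=1.0,
                    warmup_steps=0):
    """One masked AdamW update for a scalar parameter (illustrative only)."""
    beta1, beta2 = betas
    state["step"] += 1
    t = state["step"]

    # EMAs of the gradient (m), squared gradient (v), and centered variance (s).
    centered = grad - state["m"]
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad
    state["s"] = rho * state["s"] + (1 - rho) * centered * centered

    # Adam-style bias correction applied to m, v, and s.
    m_hat = state["m"] / (1 - beta1 ** t)
    v_hat = state["v"] / (1 - beta2 ** t)
    s_hat = state["s"] / (1 - rho ** t)

    # During warmup the mask is forced to 1; afterwards use the snr formula.
    if t <= warmup_steps:
        q = 1.0
    else:
        q = m_hat ** 2 / (m_hat ** 2 + softness * alpha * s_hat + eps)

    p = p * (1 - lr * weight_decay)                    # decoupled weight decay
    p = p - lr * q * m_hat / (math.sqrt(v_hat) + eps)  # masked Adam step
    return p, q


def new_state():
    return {"step": 0, "m": 0.0, "v": 0.0, "s": 0.0}


p_warm, q_warm = prm_scalar_step(1.0, 0.5, new_state(), warmup_steps=1)
p_masked, q_masked = prm_scalar_step(1.0, 0.5, new_state(), warmup_steps=0)
print(q_warm, p_warm)
print(q_masked, p_masked)
```

With warmup active the step is a plain AdamW step; without it the mask shrinks the same update, so the parameter moves less.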

Diagnostics

stats = optimizer.get_mask_stats()
print(stats)

Example output:

{
    "mean_q": 0.62,
    "active_fraction": 0.71,
    "min_q": 0.00,
    "max_q": 0.99,
    "parameter_count": 1_245_184,
    "noise_scale": 12.4,
    "signal_sq": 18.7,
    "snr": 1.51,
}

mean_q tells you how open the mask is on average. If active_fraction falls near zero, almost every parameter is below the leave-one-out boundary. Try a lower softness, a higher learning rate, or reduction="per_tensor" before assuming the method is broken.
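
A small helper can turn that advice into an automated check during training. The sketch below works on a plain dict shaped like the example output; the `min_active` threshold of 0.05 and the name `check_mask_stats` are assumptions for illustration. It also recomputes `signal_sq / noise_scale`, which matches the `snr` field in the example output, although the README does not state that this is how `snr` is defined.

```python
def check_mask_stats(stats, min_active=0.05):
    """Return (warnings, implied_snr) for a mask-stats dict."""
    warnings = []
    if stats["active_fraction"] < min_active:
        warnings.append(
            "mask is almost closed: try a lower softness, a higher learning "
            'rate, or reduction="per_tensor"'
        )
    noise = stats["noise_scale"]
    implied_snr = stats["signal_sq"] / noise if noise > 0 else float("inf")
    return warnings, implied_snr


example = {
    "mean_q": 0.62,
    "active_fraction": 0.71,
    "min_q": 0.00,
    "max_q": 0.99,
    "parameter_count": 1_245_184,
    "noise_scale": 12.4,
    "signal_sq": 18.7,
    "snr": 1.51,
}

warnings, implied_snr = check_mask_stats(example)
print(warnings)
print(round(implied_snr, 2))
```

In a real run you would pass the dict returned by optimizer.get_mask_stats() instead of the hard-coded example.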

When it helps

PRM is meant for settings where plain empirical-risk training fits structured noise or memorizes before it generalizes. The paper reports improvements in:

| Setting | AdamW | PRM |
| --- | --- | --- |
| Modular division, 25% train fraction | Groks at step 29,450 | Groks at step 5,950 |
| PINN with noisy initial condition, $\beta = 5$ | Best LR-tuned run: 3,300 iterations | 1,400 iterations to relative $\ell_2 \le 0.40$ |
| Qwen2.5-0.5B-Instruct with 30% noisy DPO | Reward accuracy 0.566, drift 0.41 | Reward accuracy 0.641, drift 0.14 |

The paper's appendix reports more experiments. PRM is not magic, however: if AdamW already reaches the solution without overfitting, PRM may not help.
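
For a rough sense of scale, the reported figures can be turned into ratios. These ratios are derived here from the table and are not stated in the source; the PINN ratio additionally assumes both iteration counts refer to the same error threshold.

```python
# Steps to grok on modular division (AdamW vs PRM).
grok_speedup = 29_450 / 5_950

# Iterations on the noisy-initial-condition PINN (AdamW vs PRM).
pinn_speedup = 3_300 / 1_400

# Noisy DPO: reward-accuracy gain and drift reduction factor.
accuracy_gain = 0.641 - 0.566
drift_ratio = 0.41 / 0.14

print(f"grokking speedup: {grok_speedup:.2f}x")
print(f"PINN speedup: {pinn_speedup:.2f}x")
print(f"reward accuracy gain: {accuracy_gain:.3f}")
print(f"drift reduction: {drift_ratio:.2f}x")
```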

Example

Run the smoke test:

uv run python examples/synthetic_regression.py

The script trains AdamW and several PRM variants on a noisy regression problem. You should see the optimizer learning the underlying signal instead of the corrupted training labels.

Layout

popriskmin/
|-- popriskmin/
|   |-- __init__.py
|   |-- mask.py
|   `-- optimizer.py
|-- examples/
|   `-- synthetic_regression.py
|-- tests/
|   `-- test_optimizer.py
`-- pyproject.toml

Citation

@misc{litman2026theory,
  title         = {A Theory of Generalization in Deep Learning},
  author        = {Litman, Elon and Guo, Gabe},
  year          = {2026},
  eprint        = {2605.01172},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  doi           = {10.48550/arXiv.2605.01172}
}

About

An implementation of the Population Risk Minimization algorithm from "A Theory of Generalization in Deep Learning."
