popriskmin gives you PRM, a small modification of AdamW that trains on population risk instead of raw empirical risk. It keeps AdamW's usual one forward pass and one backward pass per step, then preconditions each parameter update with the population-risk mask from Litman & Guo (2026), A Theory of Generalization in Deep Learning.
PRM tracks one extra tensor per parameter tensor: an exponential moving average of centered minibatch gradient variance. For each scalar parameter
Parameters that pass get the Adam update. Parameters that fail get shrunk or zeroed, depending on the mask. On the fresh-batch boundary,
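The extra per-parameter state described above can be sketched in plain Python. This is a minimal illustration under assumptions: the recurrence below (an exponential moving average of the squared deviation of each minibatch gradient from its running mean, with decay `rho`) is one common way to estimate centered gradient variance, and is not taken from the paper or the popriskmin source. The class name `CenteredVarianceEMA`, the ratio `mean**2 / var`, and the omission of bias correction are all choices made for this sketch.

```python
class CenteredVarianceEMA:
    """Running mean and centered variance of one scalar gradient stream.

    Illustrative only: the exact recurrence PRM uses may differ.
    """

    def __init__(self, beta1=0.9, rho=0.99):
        self.beta1 = beta1  # decay for the running gradient mean
        self.rho = rho      # decay for the centered variance
        self.mean = 0.0
        self.var = 0.0

    def update(self, grad):
        # Center against the mean from before this step, then update both.
        centered = grad - self.mean
        self.var = self.rho * self.var + (1.0 - self.rho) * centered * centered
        self.mean = self.beta1 * self.mean + (1.0 - self.beta1) * grad
        return self.mean, self.var


def signal_to_noise(tracker, eps=1e-8):
    # Assumed form of the ratio: squared mean gradient over centered variance.
    return tracker.mean ** 2 / (tracker.var + eps)


consistent = CenteredVarianceEMA()
alternating = CenteredVarianceEMA()
for step in range(200):
    consistent.update(1.0)                               # same gradient every step
    alternating.update(1.0 if step % 2 == 0 else -1.0)   # sign flips every step

# The consistent stream ends with a high ratio, the alternating one with a low ratio.
print(signal_to_noise(consistent), signal_to_noise(alternating))
```

A gradient that points the same way on every minibatch accumulates signal while its variance estimate decays; a gradient that flips sign keeps a mean near zero and a variance near one.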
```bash
uv pip install -e .
uv sync --extra test
```

PRM requires `torch>=2.0`.
```python
from popriskmin import PRM

# model, loader, and loss_fn are your own objects.
optimizer = PRM(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.01,
    softness=1.0,
    batch_size=32,
)

for batch in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(batch.x), batch.y)
    loss.backward()
    optimizer.step()
```

Useful options:
```python
PRM(model.parameters(), mask="snr")   # default, smooth SNR mask
PRM(model.parameters(), mask="soft")  # strict Algorithm 1 cutoff
PRM(model.parameters(), mask="hard")  # 0/1 theorem mask
PRM(model.parameters(), reduction="per_tensor")
```

Use `reduction="per_tensor"` when each scalar parameter is noisy but the whole parameter tensor has a clear signal. This is often the more useful setting for large generative models, diffusion, and CFM-style training.
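The difference between the two reductions can be illustrated with a small arithmetic sketch. It rests on two assumptions that are not confirmed by the library: that the signal-to-noise ratio is the squared mean gradient over the gradient variance, and that `per_tensor` pools by summing signal and noise across the tensor before taking the ratio. The helper names and the threshold of 1 are hypothetical.

```python
def per_param_snr(means, variances, eps=1e-8):
    # One ratio per scalar parameter.
    return [m * m / (v + eps) for m, v in zip(means, variances)]


def per_tensor_snr(means, variances, eps=1e-8):
    # Pool signal and noise over the whole tensor, then take a single ratio.
    signal = sum(m * m for m in means)
    noise = sum(variances)
    return signal / (noise + eps)


# Made-up running statistics for a four-element parameter tensor.
means = [0.3, 0.1, 0.25, 0.1]
variances = [0.03, 0.02, 0.02, 0.03]

scalar_ratios = per_param_snr(means, variances)
pooled_ratio = per_tensor_snr(means, variances)

# Against an illustrative threshold of 1, count scalars that would be gated down.
below = sum(1 for r in scalar_ratios if r < 1.0)
print(scalar_ratios, pooled_ratio, below)
```

In this toy case two of the four scalars fall below the threshold on their own, while the pooled ratio stays above it, so a per-tensor mask would treat all four entries the same way.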
| `mask` | Formula | Behavior |
|---|---|---|
| `snr` | | Default. Smooth, never fully shuts off. With softness=1, it gives |
| `soft` | | Strict Algorithm 1 mask. Zero below the boundary. |
| `hard` | | Binary indicator from Theorem 6.5. Mostly useful for ablations. |
A reasonable sweep for `softness` is 0.3, 1, 3, 10.
| Argument | Default | Notes |
|---|---|---|
| `lr` | `1e-3` | Tune as you would for AdamW. |
| `betas` | `(0.9, 0.999)` | Adam moment decay rates. |
| `rho` | `0.99` | Decay for the centered gradient variance. Usually shorter than `beta2`. |
| `eps` | `1e-8` | Stabilizer for Adam and the mask denominator. |
| `weight_decay` | `0.01` | Decoupled AdamW-style weight decay by default. |
| `softness` | `1.0` | Population-risk mask regularizer. |
| `batch_size` | `None` | Optional for `boundary="batch"`. Required for `boundary="empirical"`. |
| `boundary` | `"batch"` | Use `"batch"` for online or fresh-batch training. Use `"empirical"` for finite-dataset leave-one-out. |
| `n_dataset` | `None` | Required when `boundary="empirical"`. |
| `mask` | `"snr"` | One of `"snr"`, `"soft"`, or `"hard"`. |
| `reduction` | `"per_param"` | Use one mask per scalar parameter. Set `"per_tensor"` to pool each parameter tensor first. |
| `warmup_steps` | `0` | Force the mask to 1 for the first N optimizer steps. |
| `amsgrad` | `False` | Use AMSGrad for the Adam denominator. |
| `bias_correction` | `True` | Apply Adam-style bias correction to `m`, `v`, and `s`. |
| `decoupled_weight_decay` | `True` | Set to `False` for coupled L2. |
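The requirements attached to `boundary` and `mask` in the table can be checked before the optimizer is constructed. The helper below is a hypothetical convenience, not part of popriskmin; it encodes only the "Required" and "One of" notes from the table, and the dataset size and warmup length are made-up example values.

```python
def check_prm_kwargs(kwargs):
    """Return a list of problems with PRM keyword arguments (hypothetical helper)."""
    problems = []

    # The table marks batch_size and n_dataset as required for the empirical boundary.
    if kwargs.get("boundary", "batch") == "empirical":
        for name in ("batch_size", "n_dataset"):
            if kwargs.get(name) is None:
                problems.append(f"{name} is required when boundary='empirical'")

    # The table lists exactly three mask names.
    mask = kwargs.get("mask", "snr")
    if mask not in ("snr", "soft", "hard"):
        problems.append(f"unknown mask: {mask!r}")

    return problems


# Finite-dataset training with a short warmup; values are illustrative.
finite_data = dict(
    lr=3e-4,
    boundary="empirical",
    batch_size=32,
    n_dataset=50_000,
    warmup_steps=500,
)
print(check_prm_kwargs(finite_data))  # prints []

# With a real model, the same dictionary would be passed on as:
# optimizer = PRM(model.parameters(), **finite_data)
```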
```python
stats = optimizer.get_mask_stats()
print(stats)
```

Example output:

```python
{
    "mean_q": 0.62,
    "active_fraction": 0.71,
    "min_q": 0.00,
    "max_q": 0.99,
    "parameter_count": 1_245_184,
    "noise_scale": 12.4,
    "signal_sq": 18.7,
    "snr": 1.51,
}
```

`mean_q` tells you how open the mask is on average. If `active_fraction` falls near zero, almost every parameter is below the leave-one-out boundary. Try a lower `softness`, a higher learning rate, or `reduction="per_tensor"` before assuming the method is broken.
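This check is easy to automate during training. The sketch below assumes only that the stats dictionary contains an `active_fraction` key, as in the example output; the function name `diagnose_mask` and the cutoff `closed_threshold=0.05` are hypothetical and not library defaults.

```python
def diagnose_mask(stats, closed_threshold=0.05):
    """Suggest next steps when the mask is almost fully closed.

    closed_threshold is an arbitrary illustrative cutoff.
    """
    if stats["active_fraction"] >= closed_threshold:
        return []
    # Same remedies as listed in the text, in the same order.
    return [
        "lower softness",
        "raise the learning rate",
        'try reduction="per_tensor"',
    ]


healthy = {"mean_q": 0.62, "active_fraction": 0.71}
print(diagnose_mask(healthy))  # prints []
```

Calling it every few hundred steps on `optimizer.get_mask_stats()` gives an early warning before a run is wasted on a closed mask.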
PRM is meant for settings where plain empirical-risk training fits structured noise or memorizes before it generalizes. The paper reports improvements in:
| Setting | AdamW | PRM |
|---|---|---|
| Modular division, 25% train fraction | Groks at step 29,450 | Groks at step 5,950 |
| PINN with noisy initial condition, | Best LR-tuned run: 3,300 iterations | 1,400 iterations to relative |
| Qwen2.5-0.5B-Instruct with 30% noisy DPO | Reward accuracy 0.566, drift 0.41 | Reward accuracy 0.641, drift 0.14 |
The paper's appendix reports more experiments. PRM is not magic, however: if AdamW already reaches the solution without overfitting, PRM may not help.
Run the smoke test:

```bash
uv run python examples/synthetic_regression.py
```

The script trains AdamW and several PRM variants on a noisy regression problem. You should see the optimizer learning the underlying signal instead of the corrupted training labels.
```text
popriskmin/
|-- popriskmin/
|   |-- __init__.py
|   |-- mask.py
|   `-- optimizer.py
|-- examples/
|   `-- synthetic_regression.py
|-- tests/
|   `-- test_optimizer.py
`-- pyproject.toml
```
```bibtex
@misc{litman2026theory,
  title = {A Theory of Generalization in Deep Learning},
  author = {Litman, Elon and Guo, Gabe},
  year = {2026},
  eprint = {2605.01172},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG},
  doi = {10.48550/arXiv.2605.01172}
}
```