
LAdam

Laplacian Adam — spatially-aware adaptive optimizer for PyTorch


LAdam is a drop-in Adam replacement that applies discrete Laplacian regularization to Adam's second-moment estimate (v_t). This couples neighboring weight learning rates, producing spatially-smoothed adaptive optimization.

Why LAdam?

Adam computes independent per-parameter learning rates. But adjacent weights in trained networks are often functionally correlated — the per-parameter variance estimates should reflect this structure.

LAdam adds one operation to Adam: a Laplacian diffusion step on v_t, controlled by a single scalar c2. The Laplacian allows each weight's learning rate to be informed by its neighbors, smoothing the optimization landscape.

Results

| Task | Architecture | Metric | Adam | LAdam | Improvement | Seeds |
|------|--------------|--------|------|-------|-------------|-------|
| Wave Equation PINN | 5×128 MLP | L2 error | 0.0067 ± 0.0015 | 0.0066 ± 0.0010 | +0.8% | 3 |
| Regression | MLP | MSE | 0.213 | 0.184 | -13.5% | 3 |
| FashionMNIST | Transformer | Accuracy | 89.46 ± 0.10% | 89.66 ± 0.06% | +0.20% (p=0.0005) | 5 |
| CIFAR-10 | ResNet + Chi-Anneal | Accuracy | 67.96% | 73.39% | +5.43% | 3 |
| FashionMNIST | MLP | Accuracy | 89.10% | 89.12% | +0.02% (n.s.) | 1 |
| FashionMNIST | CNN | Accuracy | 91.15% | 91.14% | -0.01% (tie) | 1 |

LAdam works best on structured regression, physics-informed networks, and classification with chi-annealing. It has lower variance across seeds on PINNs. It does NOT help on LLMs — tested on GPT-2 fine-tuning where it significantly hurts performance.

Installation

```bash
pip install ladam
```

Optimizers

LAdam ships three Laplacian-enhanced optimizers:

| Optimizer | Base | Laplacian target | Best for |
|-----------|------|------------------|----------|
| LAdam | Adam | Second moment v_t | PINNs, transformers, CNNs |
| LAdaGrad | AdaGrad | Cumulative sum G_t | Sparse features, NLP |
| LRMSProp | RMSProp | Running average v_t | RNNs, non-stationary losses |

All three share the same Laplacian kernel infrastructure and c2 parameter.

Usage

Basic — Drop-in Adam replacement

```python
from ladam import LAdam

optimizer = LAdam(model.parameters(), lr=1e-3, c2=1e-4)

# Training loop is identical to Adam
for batch in dataloader:
    loss = criterion(model(batch))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

LAdaGrad and LRMSProp

```python
from ladam import LAdaGrad, LRMSProp

# AdaGrad with Laplacian smoothing on cumulative squared gradients
optimizer = LAdaGrad(model.parameters(), lr=1e-2, c2=1e-4)

# RMSProp with Laplacian smoothing on running variance
optimizer = LRMSProp(model.parameters(), lr=1e-2, alpha=0.99, c2=1e-4)
```

Per-layer c2 with parameter groups

```python
optimizer = LAdam([
    {'params': model.attention.parameters(), 'c2': 1e-4},   # Transformer attention
    {'params': model.ffn.parameters(), 'c2': 1e-5},         # Feed-forward
    {'params': model.norm.parameters(), 'c2': 0.0},         # Skip for norms
], lr=3e-4)
```

Architecture-aware defaults

```python
from ladam import LAdam, suggest_c2

c2 = suggest_c2('pinn')         # Returns 1e-5
c2 = suggest_c2('transformer')  # Returns 1e-4

optimizer = LAdam(model.parameters(), lr=1e-3, c2=c2)
```

Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| lr | 1e-3 | Learning rate |
| betas | (0.9, 0.999) | EMA coefficients (same as Adam) |
| eps | 1e-8 | Numerical stability (same as Adam) |
| weight_decay | 0 | L2 regularization (same as AdamW behavior) |
| c2 | 1e-4 | Laplacian coupling strength; controls how much neighboring variance estimates influence each other |
| mode | 'variance_lap' | Which quantity to smooth; 'variance_lap' performed best |
| stencil | '9point' | Discrete Laplacian stencil: '9point' (isotropic, 0.46% anisotropy) or '5point' (legacy, 12.3% anisotropy) |
| min_spatial_size | 16 | Skip the Laplacian for params with fewer elements (biases, LayerNorm) |

Stencil Selection

The stencil parameter controls the discrete Laplacian kernel used for spatial coupling:

  • '9point' (default): Isotropic stencil with face + edge neighbors. Treats diagonal neighbors with 1/6 weight vs 4/6 for face neighbors.
  • '5point': Standard cross-pattern stencil (faces only). Slightly faster, but roughly 27× more anisotropic (12.3% vs 0.46%).

At typical c2 values (1e-5 to 1e-3), the effective learning rate difference between stencils is <0.3%. The 9-point default is recommended for correctness.
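To make the two stencils concrete, here is a sketch using the textbook 3×3 discrete Laplacian kernels applied with a single `F.conv2d` call. The kernel values and the `laplacian` helper are illustrative assumptions, not LAdam's source; the library's internal kernels and scaling may differ.

```python
import torch
import torch.nn.functional as F

# Textbook 3x3 Laplacian stencils (illustrative; LAdam's internal kernels
# may differ). Both sum to zero, so a constant field has zero Laplacian
# away from the zero-padded border.
K5 = torch.tensor([[0., 1., 0.],
                   [1., -4., 1.],
                   [0., 1., 0.]])
K9 = torch.tensor([[1., 4., 1.],
                   [4., -20., 4.],
                   [1., 4., 1.]]) / 6.0

def laplacian(v, kernel):
    """Apply a 3x3 stencil to a 2D weight matrix with one conv2d call."""
    return F.conv2d(v[None, None], kernel[None, None], padding=1)[0, 0]

v = torch.full((8, 8), 3.0)              # constant field
interior = laplacian(v, K9)[2:6, 2:6]    # exactly zero in the interior
```

Note how the 9-point kernel gives diagonal neighbors 1/6 weight against 4/6 for face neighbors, which is what reduces the directional bias.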

Choosing c2

c2 is the only new hyperparameter. It's robust across 3 orders of magnitude:

| c2 | Best for | Notes |
|----|----------|-------|
| 1e-5 | PINNs, scientific ML | Gentle coupling, biggest error reduction |
| 1e-4 | Transformers, general use | Safe default |
| 1e-3 | Aggressive smoothing | Works, but slightly less stable |
| 0 | Disable | Reduces to standard Adam |

All 7 values tested in [1e-6, 1e-3] outperformed Adam on transformers (B12 sweep).

How It Works

Standard Adam computes per-parameter adaptive learning rates from the second moment:

```
v_t = β₂·v_{t-1} + (1-β₂)·g_t²        # variance estimate
lr_effective = lr / (√v_t + ε)        # per-parameter learning rate
```

LAdam adds a Laplacian coupling step:

```
v_smooth = v_t + c2·∇²v_t             # spatial smoothing
lr_effective = lr / (√v_smooth + ε)   # coupled learning rate
```

Here ∇² is the discrete Laplacian, computed via a single F.conv2d kernel (9-point isotropic by default), which is efficient and GPU-friendly. The Laplacian treats each weight matrix as a 2D field, coupling each weight's learning rate with its spatial neighbors.

Overhead: ~2-5% wall-clock time increase per step. The Laplacian is a single fused convolution kernel, not point-wise iteration.
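The coupled update above can be sketched end to end in a few lines of PyTorch. This is an illustrative reimplementation under stated assumptions (a zero-sum 9-point stencil, a hypothetical `smooth_v` helper), not LAdam's actual source.

```python
import torch
import torch.nn.functional as F

# Illustrative sketch of the coupled update (not LAdam's source).
K9 = torch.tensor([[1., 4., 1.],
                   [4., -20., 4.],
                   [1., 4., 1.]]) / 6.0   # zero-sum 9-point stencil

def smooth_v(v, c2):
    """v_smooth = v + c2 * lap(v), one conv2d per weight matrix."""
    lap = F.conv2d(v[None, None], K9[None, None], padding=1)[0, 0]
    return v + c2 * lap

lr, eps, beta2 = 1e-3, 1e-8, 0.999
g = torch.randn(32, 32)                  # gradient for one weight matrix
v = (1 - beta2) * g**2                   # second-moment estimate (first step)
lr_eff = lr / (smooth_v(v, c2=1e-4).sqrt() + eps)
```

With c2 = 0 the smoothing step is the identity and the update reduces to standard Adam.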

Benchmarks

PINN: Wave Equation (u_tt = c^2 u_xx)

5-layer, 128-unit tanh MLP trained for 5000 steps on the 1D wave equation. 3 seeds, best L2 per seed.

| Optimizer | L2 error (mean ± std) | vs Adam |
|-----------|-----------------------|---------|
| Adam (lr=1e-3) | 0.0067 ± 0.0015 | |
| LAdam (c²=1e-5) | 0.0066 ± 0.0010 | +0.8%, lower variance |

LAdam converges to similar L2 but with 34% lower variance across seeds (0.0010 vs 0.0015), indicating more stable optimization.

Note: An earlier single-seed benchmark with gradient clipping showed -44.6%. Multi-seed testing without gradient clipping shows the advantage is primarily in convergence stability, not final error magnitude.

Transformer: FashionMNIST Classification

4-head, 128-dim, 2-layer transformer, 30 epochs, 5 independent seeds.

| Optimizer | Accuracy (mean ± std) | p-value (vs Adam) |
|-----------|-----------------------|-------------------|
| Adam | 89.46 ± 0.10% | |
| LAdam (c²=1e-4) | 89.66 ± 0.06% | 0.0005 |

c² Robustness Sweep

7 c² values on the same transformer task. All 7 beat Adam:

| c² | Accuracy | Δ vs Adam |
|----|----------|-----------|
| 1e-6 | 89.62% | +0.16% |
| 5e-6 | 89.73% | +0.27% |
| 1e-5 | 89.79% | +0.33% |
| 5e-5 | 89.75% | +0.29% |
| 1e-4 | 89.67% | +0.21% |
| 5e-4 | 89.64% | +0.18% |
| 1e-3 | 89.66% | +0.20% |

FAQ

Q: Does this work for LLMs / GPT-scale models? A: No. LAdam hurts LLM training (tested on GPT-2/WikiText-2). Attention weight matrices encode semantic structure, not spatial structure — the Laplacian destroys per-feature specialization. Use standard Adam/AdamW for LLMs.

Q: Why not smooth the gradient instead of the variance? A: Osher et al. (2018) explored Laplacian smoothing of gradients. We found that smoothing the variance estimate is more effective because it smooths the learning rate landscape rather than the descent direction. These are mathematically distinct: ∇²(EMA(g²)) ≠ (∇²g)².
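The distinction is easy to verify numerically. This illustrative snippet (using a standard 5-point stencil, not LAdam's internals) checks that the Laplacian of a squared field differs from the square of the field's Laplacian:

```python
import torch
import torch.nn.functional as F

# Quick numerical check that lap(g**2) != (lap g)**2 in general.
K5 = torch.tensor([[0., 1., 0.],
                   [1., -4., 1.],
                   [0., 1., 0.]])

def lap(x):
    return F.conv2d(x[None, None], K5[None, None], padding=1)[0, 0]

g = torch.randn(16, 16)
same = torch.allclose(lap(g**2), lap(g)**2)   # almost surely False
```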

Q: Why does this help PINNs so much? A: PDE-based loss landscapes have inherent spatial structure from the differential operators in the loss function. The Laplacian on v_t aligns the optimizer's internal representation with this structure.

Q: Can I use this with learning rate schedulers? A: Yes. LAdam is fully compatible with any torch.optim.lr_scheduler.
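For illustration, the scheduler wiring looks like any other optimizer's. The sketch below uses `torch.optim.Adam` as a stand-in so it runs without the ladam package installed; per the answer above, swapping in `LAdam(..., c2=1e-4)` would be the only change.

```python
import torch

# Scheduler wiring sketch; Adam stands in for LAdam so this runs without
# the ladam package installed.
model = torch.nn.Linear(4, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)   # or: LAdam(..., c2=1e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)

for _ in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    sched.step()
```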

Citation

If you use LAdam in your research, please cite:

```bibtex
@software{partin2026ladam,
  author = {Partin, Greg},
  title  = {LAdam: Spatially-Aware Adaptive Optimization via Laplacian-Regularized Variance Estimates},
  year   = {2026},
  url    = {https://github.com/gpartin/ladam}
}
```

License

MIT. See LICENSE for details.
