A lightweight, pip-installable PyTorch library for robust deep learning training under noise.
HASA tracks each training sample's loss trajectory over a sliding window and uses loss variance as a noise indicator. Clean samples stabilise quickly (low variance); noisy/mislabelled samples oscillate (high variance). After a warm-up phase the algorithm masks out high-variance samples so that gradients are computed only from likely-clean data.
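To make the variance signal concrete, the plain-Python sketch below compares two made-up loss trajectories: one that settles, as a clean sample would, and one that keeps jumping, as a mislabelled sample might. It uses population variance from the standard library. Whether HASA itself uses population or unbiased variance is an assumption left open here; the choice does not change the ranking as long as every sample has the same window length, because the two differ only by the constant factor n/(n-1).

```python
from statistics import pvariance

# Hypothetical loss trajectories over a window of T = 6 epochs (illustrative numbers).
clean = [0.90, 0.42, 0.31, 0.29, 0.28, 0.28]   # drops, then settles
noisy = [2.10, 0.80, 2.40, 0.70, 2.30, 0.90]   # oscillates as the wrong label fights the model


def loss_variance(history):
    """Population variance of one sample's loss history."""
    return pvariance(history)


print(f"clean: {loss_variance(clean):.4f}")
print(f"noisy: {loss_variance(noisy):.4f}")
```

The oscillating trajectory has roughly ten times the variance of the settling one, which is the gap the selection phase exploits.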
```bash
pip install hasa
```

Or install directly from GitHub:

```bash
pip install git+https://github.com/msc35/hasa-py.git
```

For development:

```bash
git clone https://github.com/msc35/hasa-py.git
cd hasa-py
pip install -e ".[dev]"
```

HASA needs to map each loss value back to a specific dataset sample. The simplest approach is an `IndexedDataset` wrapper:
```python
from torch.utils.data import Dataset


class IndexedDataset(Dataset):
    """Wraps a dataset so that every item also returns its own index."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        return idx, x, y
```

A minimal training loop:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from hasa import HASA

dataset = IndexedDataset(my_dataset)  # my_dataset: any dataset returning (x, y)
loader = DataLoader(dataset, batch_size=128, shuffle=True)

model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(reduction='none')  # MUST be unreduced
selector = HASA(num_samples=len(dataset), window_size=15, select_ratio=0.8)

for epoch in range(150):
    for indices, x, y in loader:
        logits = model(x)
        losses = criterion(logits, y)  # one loss value per sample
        mask = selector.step(indices, losses.detach())

        # Divide by mask.sum() (not batch_size) to preserve gradient magnitude;
        # clamp(min=1) avoids 0/0 if a small final batch selects nothing.
        loss = (losses * mask).sum() / mask.sum().clamp(min=1)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    selector.end_epoch()  # exactly once per epoch, after all batches
```

HASA can also inject Gaussian noise into the model parameters after each optimiser step (Langevin-style updates), motivated by interpreting SGD as approximate Bayesian inference (Mandt, Hoffman & Blei, 2018):
```python
selector = HASA(
    num_samples=len(dataset),
    window_size=15,
    select_ratio=0.8,
    langevin_noise=1e-4,
)

# Inside the training loop, after optimizer.step():
selector.inject_langevin_noise(model)
```

- Warm-up phase (first `window_size` epochs): all samples are used while the per-sample loss history buffer fills up.
- Selection phase (after warm-up): for each batch, compute each sample's loss variance from the history buffer and keep only the `select_ratio` fraction with the lowest variance. High-variance samples (likely mislabelled) are masked out.
```
            Warm-up (epochs 0..T-1)     Selection (epochs T+)
            ┌─────────────────────┐     ┌──────────────────────────────┐
per-sample  │ Record losses into  │     │ Var(loss history) per sample │
losses ────>│ ring buffer, train  │────>│ Keep lowest-variance k%      │
            │ on ALL samples      │     │ Mask out the rest            │
            └─────────────────────┘     └──────────────────────────────┘
```
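In the selection phase, masked-out samples contribute nothing to the loss, which is why the quick-start loop divides by `mask.sum()` rather than the batch size. The plain-Python check below (lists stand in for tensors) shows the difference; the zero-selection guard in `masked_mean` is an addition of this sketch, not something the README specifies.

```python
def masked_mean(losses, mask):
    """Average loss over the kept samples only, as the quick-start loop computes it."""
    kept = sum(1 for m in mask if m)
    if kept == 0:  # guard: nothing selected, so contribute no loss
        return 0.0
    return sum(l for l, m in zip(losses, mask) if m) / kept


def masked_over_batch(losses, mask):
    """Dividing by the full batch size instead rescales the loss by the kept fraction."""
    return sum(l for l, m in zip(losses, mask) if m) / len(losses)


losses = [1.0, 1.0, 1.0, 1.0, 5.0]
mask = [True, True, True, True, False]  # select_ratio = 0.8 drops the outlier

print(masked_mean(losses, mask))        # 1.0
print(masked_over_batch(losses, mask))  # 0.8
```

Dividing by the batch size would scale every gradient by roughly `select_ratio`, which behaves like a silent learning-rate cut once selection starts.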
| Parameter | Meaning | Tested Range | Default |
|---|---|---|---|
| `num_samples` | Total dataset size (for buffer allocation) | n/a | required |
| `window_size` | Loss values stored per sample (T) | 5, 10, 15 | 15 |
| `select_ratio` | Fraction of each batch to keep (k) | 0.5 – 0.9 | 0.8 |
| `langevin_noise` | Scale of injected Gaussian noise | 0 or small | 0.0 |
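Because `num_samples` sizes the history buffer, its memory cost is easy to estimate. The sketch below assumes a dense `num_samples × window_size` buffer of 4-byte (float32) values; the actual dtype and layout are not documented here, and `history_buffer_bytes` is a hypothetical helper.

```python
def history_buffer_bytes(num_samples, window_size, bytes_per_value=4):
    """Size of a dense (num_samples x window_size) loss buffer.

    bytes_per_value=4 assumes float32 storage (an assumption, not documented).
    """
    return num_samples * window_size * bytes_per_value


# e.g. a 50,000-sample dataset with the default window_size=15
size = history_buffer_bytes(50_000, 15)
print(size, f"{size / 2**20:.2f} MiB")  # 3000000 2.86 MiB
```

Under this assumption the buffer stays small next to a typical model, even for datasets with millions of samples.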
HASA supports full state serialisation for resuming training:

```python
state = selector.state_dict()
torch.save(state, "hasa_checkpoint.pt")

# Restore later
selector.load_state_dict(torch.load("hasa_checkpoint.pt"))
```

For a simpler API that handles the full training loop:
```python
from hasa.callbacks import HASATrainer

trainer = HASATrainer(model, optimizer, criterion, selector, device="cuda")

for epoch in range(num_epochs):
    metrics = trainer.train_epoch(dataloader)
    print(f"Epoch {epoch}: loss={metrics['loss']:.4f}, selected={metrics['selected_frac']:.2%}")
```

HASA is model-agnostic: it only looks at per-sample loss values. It works with any architecture (CNNs, Transformers, MLPs, etc.) and any loss function that supports `reduction='none'`:

- `nn.CrossEntropyLoss(reduction='none')`: classification
- `nn.MSELoss(reduction='none')`: regression
- `nn.BCEWithLogitsLoss(reduction='none')`: binary classification
- Any custom loss returning per-sample values

With `reduction='none'`, `nn.MSELoss` and `nn.BCEWithLogitsLoss` return element-wise losses shaped like their input, so for multi-output targets (e.g. shape `[batch, features]`) reduce the non-batch dimensions first, e.g. `criterion(pred, target).mean(dim=1)`, to get one value per sample.
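HASA needs one loss per sample because a single reduced scalar cannot be mapped back to dataset indices. To make `reduction='none'` concrete, here is a plain-Python re-implementation of per-sample cross-entropy for class-index targets (no class weights or label smoothing). It mirrors the math of `nn.CrossEntropyLoss(reduction='none')` for 2-D logits but is only an illustration; `cross_entropy_per_sample` is a hypothetical helper.

```python
import math


def cross_entropy_per_sample(logits, targets):
    """Per-sample cross-entropy: one value per row, like reduction='none'."""
    losses = []
    for row, t in zip(logits, targets):
        m = max(row)  # subtract the max before exp() for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        losses.append(log_sum_exp - row[t])  # equals -log(softmax(row)[t])
    return losses


logits = [[2.0, 0.0], [0.0, 0.0], [0.0, 3.0]]
targets = [0, 0, 0]
per_sample = cross_entropy_per_sample(logits, targets)  # one loss per sample
mean_loss = sum(per_sample) / len(per_sample)           # what reduction='mean' would collapse to
print(per_sample, mean_loss)
```

The third sample, confidently predicted as the wrong class, gets by far the largest loss; once averaged into `mean_loss`, that per-sample information is gone.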
`HASA`:

- `step(sample_indices, losses) -> BoolTensor`: update the loss history and return the selection mask.
- `end_epoch()`: advance the epoch counter. Must be called exactly once per epoch.
- `inject_langevin_noise(model)`: add Gaussian noise to the model parameters.
- `state_dict()` / `load_state_dict(d)`: checkpoint support.
- `epoch`: current epoch (read-only property).
- `in_warmup`: `True` during the warm-up phase.
Loss-history ring buffer:

- `update(indices, losses)`: write losses into the ring buffer.
- `variance(indices) -> Tensor`: compute per-sample loss variance.
- `is_ready(epoch) -> bool`: `True` once warm-up is over.
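The three buffer methods above can be illustrated with a minimal pure-Python stand-in. The class name `LossRingBuffer`, the list-of-lists storage, the use of population variance, and returning plain lists instead of tensors are all assumptions for illustration, not the library's internals. `is_ready` follows the warm-up rule stated earlier: epochs `0..T-1` are warm-up, so selection starts at epoch `T = window_size`.

```python
from statistics import pvariance


class LossRingBuffer:
    """Hypothetical stand-in: keeps the last `window_size` losses per sample."""

    def __init__(self, num_samples, window_size):
        self.window_size = window_size
        self.values = [[0.0] * window_size for _ in range(num_samples)]
        self.counts = [0] * num_samples  # total losses written per sample

    def update(self, indices, losses):
        for i, loss in zip(indices, losses):
            slot = self.counts[i] % self.window_size  # wrap around: overwrite the oldest entry
            self.values[i][slot] = float(loss)
            self.counts[i] += 1

    def variance(self, indices):
        out = []
        for i in indices:
            n = min(self.counts[i], self.window_size)  # only the slots filled so far
            out.append(pvariance(self.values[i][:n]) if n > 0 else 0.0)
        return out

    def is_ready(self, epoch):
        # Warm-up covers epochs 0..T-1, so selection begins at epoch T.
        return epoch >= self.window_size


buf = LossRingBuffer(num_samples=3, window_size=4)
epoch_losses = [
    [0.5, 2.0, 0.9],
    [0.4, 0.3, 0.8],
    [0.4, 2.2, 0.7],
    [0.4, 0.2, 0.8],
    [0.4, 2.1, 0.9],  # 5th write wraps around and overwrites slot 0
]
for losses in epoch_losses:
    buf.update([0, 1, 2], losses)
print(buf.variance([0, 1, 2]))
```

Variance ignores order, so the buffer never needs to be rotated: overwriting the oldest slot in place is enough.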
Returns a mask that keeps the `select_ratio` fraction of samples with the lowest loss variance.
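That selection rule can be sketched as a function over one batch of variances. The name `lowest_variance_mask` is hypothetical, and so are its rounding (nearest integer, at least one sample kept) and tie-breaking (stable, so ties keep batch order); the README specifies neither.

```python
def lowest_variance_mask(variances, select_ratio):
    """Return a boolean mask keeping the lowest-variance `select_ratio` fraction."""
    n = len(variances)
    # Assumption: keep round(select_ratio * n) samples, but never zero in a non-empty batch.
    k = max(1, round(select_ratio * n)) if n else 0
    order = sorted(range(n), key=lambda i: variances[i])  # stable sort: ties keep batch order
    keep = set(order[:k])
    return [i in keep for i in range(n)]


print(lowest_variance_mask([0.1, 5.0, 0.2, 0.3, 0.05], 0.8))  # [True, False, True, True, True]
```

On tensors, the same selection is commonly expressed with `torch.topk(variances, k, largest=False)`, which returns the indices of the `k` smallest values.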
To run the test suite with coverage:

```bash
pip install -e ".[dev]"
pytest --cov=hasa
```