feat(hslm): OHEM — online hard example mining for fast convergence #318

@gHashTag

Task

Select only high-loss examples for backward pass. Don't waste compute on already-learned patterns.
Combine with curriculum learning for maximum sample efficiency.

Scientific Background

OHEM Original Paper (CVPR 2016, Shrivastava et al.)

  • 2-5% mAP improvement on PASCAL VOC with negligible overhead
  • Core: compute loss for all → sort → backward only top-k hardest
  • Effectiveness increases with dataset size — more examples to select from
  • For detection: auto-selects hard negatives, replaces manual mining

Curriculum Learning for LLM Pretraining (arxiv:2506.11300, June 2025)

  • 200+ models evaluated, up to 100B tokens, 3 strategies, 6 difficulty metrics
  • 18-45% fewer steps to reach baseline performance
  • Best difficulty signals: compression ratio, lexical diversity (MTLD), Flesch readability
  • As warmup strategy: sustained 3.5% improvement
  • Orthogonal to OHEM — can combine both!
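
As a rough illustration of the compression-ratio signal listed above, here is a minimal Python sketch (Python is used for brevity; the project itself is Zig). The choice of `zlib` as the compressor, the function names, and the reading "more compressible means easier" are all assumptions made for this example, not details taken from the paper or from the codebase.

```python
import zlib

def compression_ratio(text: str) -> float:
    """Compressed size divided by raw size; lower means more redundant text."""
    raw = text.encode("utf-8")
    if not raw:
        return 0.0
    return len(zlib.compress(raw, 9)) / len(raw)

def order_by_difficulty(stories: list[str]) -> list[str]:
    """Sort stories from most compressible (assumed easiest) to least compressible."""
    return sorted(stories, key=compression_ratio)

# Hypothetical sample texts, only to show the ordering.
repetitive = "the cat sat. " * 40
varied = "Quartz glyphs vex the bold jury; fjords whisper back to zany moths."
ordered = order_by_difficulty([varied, repetitive])
```

The score needs no model in the loop, so it could be computed once per story before training starts.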

Hard Example Mining Survey (PubMed 2025)

  • Token-level HEM for LMs: high-loss tokens = most informative gradients
  • Gradient-based selection complements loss-based: dual criterion more robust
  • Convergence speedup: 18-45% across domains
  • Clean Hard Examples vs Mislabeled Easy Examples: need filtering (Early Cutting technique)

Tiny Model Triage Effect (scaling laws research)

  • Small models (<10M) develop triage strategy: concentrate on easy, abandon hardest
  • Gini coefficient 0.26 at 22K params vs 0.09 at 4.7M
  • OHEM directly counters this: forces model to allocate capacity to hard examples
  • Especially impactful for tiny models like our 1.95M
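
For context on the statistic quoted above: a Gini coefficient measures how unevenly a quantity is spread, with 0 meaning perfectly even and values near 1 meaning highly concentrated. The issue does not say which per-example quantity the 0.26 and 0.09 figures were computed over, so `values` in this Python sketch is a hypothetical stand-in (for instance, per-example loss improvement).

```python
def gini(values: list[float]) -> float:
    """Gini coefficient of non-negative values (0 = perfectly even spread)."""
    xs = sorted(values)
    n = len(xs)
    total = sum(xs)
    if n == 0 or total == 0:
        return 0.0
    # Rank-weighted sum over the ascending-sorted values, ranks starting at 1.
    weighted = sum((2 * i - n - 1) * x for i, x in enumerate(xs, start=1))
    return weighted / (n * total)

even = gini([1.0, 1.0, 1.0, 1.0])          # everything shared equally
concentrated = gini([0.0, 0.0, 0.0, 1.0])  # everything on one example
```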

Token-to-Parameter Ratio

  • Current: 100K steps × batch=128 = 12.8M tokens → ratio ≈6.6:1 (12.8M / 1.95M params)
  • Chinchilla optimal: 20:1 (39M tokens)
  • Tsinghua MiniCPM: 192:1 for tiny models
  • With OHEM: each token provides more gradient info → effective ratio higher
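
The ratios above can be checked with a few lines of arithmetic. This Python sketch follows the issue's own accounting, which treats `batch=128` as 128 tokens per step; if the batch actually counts sequences, the token total would need to be multiplied by the sequence length.

```python
PARAMS = 1_950_000   # model size quoted in the issue (1.95M)
STEPS = 100_000
BATCH = 128          # treated as tokens per step, as in the issue's arithmetic

tokens_seen = STEPS * BATCH
ratio = tokens_seen / PARAMS
chinchilla_tokens = 20 * PARAMS    # tokens needed for a 20:1 ratio
minicpm_tokens = 192 * PARAMS      # tokens needed for a 192:1 ratio

print(f"tokens seen: {tokens_seen:,} (ratio {ratio:.2f}:1)")
print(f"20:1 target: {chinchilla_tokens:,} tokens")
print(f"192:1 target: {minicpm_tokens:,} tokens")
```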

Implementation Plan

Phase 1: Token-level OHEM

// Forward pass: compute the loss for ALL tokens in the batch.
const losses = computeTokenLosses(batch);  // flat, length batch_size * seq_len

// Rank tokens by loss (hardest first) and keep the top-k.
const top_k = losses.len / 2;  // top 50% initially
const sorted_indices = argsort(losses, .descending);
const hard_indices = sorted_indices[0..top_k];

// Backward pass: propagate gradients only from the selected hard tokens.
backwardWithMask(hard_indices);
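
The selection step above can be sketched in plain Python to make the shapes concrete. This is a language-neutral illustration, not the Zig trainer: `ohem_mask` and `masked_mean_loss` are hypothetical names, and the sketch assumes per-token losses are already available as a `[batch_size][seq_len]` nested list.

```python
def ohem_mask(token_losses: list[list[float]], keep_ratio: float) -> list[list[bool]]:
    """Mark the keep_ratio fraction of tokens with the highest loss in the batch."""
    flat = [(loss, b, t)
            for b, row in enumerate(token_losses)
            for t, loss in enumerate(row)]
    k = max(1, int(len(flat) * keep_ratio))
    hardest = sorted(flat, key=lambda item: item[0], reverse=True)[:k]
    mask = [[False] * len(row) for row in token_losses]
    for _, b, t in hardest:
        mask[b][t] = True
    return mask

def masked_mean_loss(token_losses: list[list[float]], mask: list[list[bool]]) -> float:
    """Average loss over selected tokens only; the rest contribute no gradient."""
    selected = [loss
                for row, mask_row in zip(token_losses, mask)
                for loss, keep in zip(row, mask_row) if keep]
    return sum(selected) / len(selected)

# Hypothetical 2 x 4 batch of per-token losses.
losses = [[0.1, 2.0, 0.3, 1.5],
          [0.2, 0.05, 3.0, 0.4]]
mask = ohem_mask(losses, keep_ratio=0.5)
hard_loss = masked_mean_loss(losses, mask)
```

Selection here is global across the batch rather than per sequence, so a sequence made up entirely of easy tokens can end up contributing nothing to the update.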

Phase 2: Dynamic threshold annealing

Step 0-20K:   top_k = 100% (all tokens, standard training)
Step 20K-50K: top_k = 75% → 50% (progressive hardening)
Step 50K-100K: top_k = 50% → 25% (only hardest quarter)
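
One way to read the table above is as a piecewise-linear schedule. The Python sketch below assumes linear interpolation inside each range and keeps the jump from 100% to 75% at step 20K exactly as the table shows it; `ohem_keep_ratio` is a hypothetical name.

```python
def ohem_keep_ratio(step: int) -> float:
    """Fraction of tokens kept for the backward pass at a given training step."""
    if step < 20_000:
        return 1.0                                 # standard training, all tokens
    if step < 50_000:
        frac = (step - 20_000) / 30_000
        return 0.75 + frac * (0.50 - 0.75)         # 75% -> 50%
    if step < 100_000:
        frac = (step - 50_000) / 50_000
        return 0.50 + frac * (0.25 - 0.50)         # 50% -> 25%
    return 0.25                                    # hold at the hardest quarter
```

If the jump at 20K turns out to destabilize training, starting the anneal from 100% instead of 75% would be a one-line change.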

Phase 3: Curriculum + OHEM combo

  1. Sort TinyStories by difficulty (compression ratio / Flesch score)
  2. First 20% of training: easy stories only (curriculum warmup)
  3. Then: all stories with OHEM selecting hard tokens within each batch
  4. Dual signal: global order (curriculum) + local selection (OHEM)
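
The four steps above could be wired together roughly as follows. This Python sketch assumes the stories are already sorted easiest first; the issue does not define what counts as "easy", so the `easy_frac=0.5` cutoff is a placeholder, and `training_plan` is a hypothetical name.

```python
def training_plan(step: int, total_steps: int, stories_sorted: list[str],
                  warmup_frac: float = 0.2, easy_frac: float = 0.5):
    """Return (stories eligible for sampling, whether OHEM selection is active)."""
    warmup_steps = round(total_steps * warmup_frac)
    if step < warmup_steps:
        # Curriculum warmup: sample only from the easiest slice, no OHEM yet.
        cutoff = max(1, int(len(stories_sorted) * easy_frac))
        return stories_sorted[:cutoff], False
    # After warmup: the full dataset, with OHEM picking hard tokens per batch.
    return stories_sorted, True

stories = ["s0", "s1", "s2", "s3"]  # hypothetical IDs, sorted easiest first
warmup_pool, warmup_ohem = training_plan(0, 100_000, stories)
later_pool, later_ohem = training_plan(20_000, 100_000, stories)
```

With 100K total steps the warmup ends at step 20K, which lines up with the Phase 2 schedule keeping 100% of tokens until that point.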

Phase 4: Gradient-based selection (advanced)

// Dual criterion: high loss AND high gradient norm
const importance = loss * gradient_norm;  // combined score
const hard_mask = topk(importance, batch_size / 2);  // keep the top-scoring half
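
A minimal Python sketch of the combined score, assuming one loss and one gradient norm per candidate are already available (how the gradient norms are obtained cheaply is left open here). `dual_criterion_topk` is a hypothetical name.

```python
def dual_criterion_topk(losses: list[float], grad_norms: list[float], k: int) -> list[int]:
    """Indices of the k candidates with the largest loss * gradient-norm score."""
    if len(losses) != len(grad_norms):
        raise ValueError("losses and grad_norms must have the same length")
    scores = [loss * norm for loss, norm in zip(losses, grad_norms)]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])

# Hypothetical values: index 1 has a high loss but a tiny gradient norm.
losses = [0.2, 2.5, 1.0, 3.0]
grad_norms = [0.5, 0.1, 2.0, 1.5]
selected = dual_criterion_topk(losses, grad_norms, k=len(losses) // 2)
```

In this example index 1 is dropped despite its high loss, because its small gradient norm pulls the product below that of index 2. A candidate can still score well on one very large factor alone, so the product is a softer test than requiring both to be high.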

Changes

  • src/hslm/trainer.zig: token-level loss computation + top-k selection
  • src/hslm/trainer.zig: dynamic threshold scheduler (linear anneal)
  • src/hslm/data.zig: difficulty scoring for curriculum ordering
  • Flag: --ohem-ratio=0.5 (fraction of tokens for backward)
  • Flag: --curriculum (enable difficulty-ordered training)

Expected

  • 20-30% PPL reduction (PPL 125 → 87-100) from OHEM alone
  • Additional 3-5% from curriculum learning warmup
  • Combined: PPL 125 → 80-95 with same 100K steps
  • With extended training (200K steps + OHEM): PPL target 57-65
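
The ranges above can be sanity-checked by compounding the percentages. This Python sketch assumes the curriculum gain applies multiplicatively on top of the OHEM result, which the issue does not state explicitly.

```python
BASELINE_PPL = 125.0

def after_reduction(ppl: float, fraction: float) -> float:
    """Perplexity after a relative reduction, e.g. fraction=0.2 for 20%."""
    return ppl * (1.0 - fraction)

ohem_best = after_reduction(BASELINE_PPL, 0.30)    # 30% reduction from OHEM
ohem_worst = after_reduction(BASELINE_PPL, 0.20)   # 20% reduction from OHEM
combo_best = after_reduction(ohem_best, 0.05)      # plus 5% from curriculum
combo_worst = after_reduction(ohem_worst, 0.03)    # plus 3% from curriculum

print(f"OHEM alone: {ohem_best:.1f} to {ohem_worst:.1f}")
print(f"OHEM + curriculum: {combo_best:.1f} to {combo_worst:.1f}")
```

Under that assumption the combined range comes out at roughly 83 to 97, slightly higher than the 80-95 quoted, so the quoted figure is best read as a rounded estimate.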

Priority: HIGHEST — biggest single PPL improvement of all L8-L12 tasks
