feat(hslm): knowledge distillation from FP32 teacher to ternary student #322

@gHashTag

Task

Train a full-precision (FP32) teacher with the same architecture, then distill it into the ternary student.
The teacher's soft probability distributions provide richer supervision than hard labels.
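The soft-versus-hard-label point can be made concrete with a small pure-Python sketch over a hypothetical 4-token vocabulary. The teacher logits are made-up numbers, not HSLM outputs. A one-hot label gives zero probability to every wrong token, while the teacher's softmax ranks them. Raising the temperature `tau` flattens the distribution so that ranking carries more weight.

```python
import math


def softmax(logits, tau=1.0):
    """Temperature-scaled softmax: exp(z_i / tau) / sum_j exp(z_j / tau)."""
    scaled = [z / tau for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy in nats; a softer distribution has higher entropy."""
    return -sum(x * math.log(x) for x in p if x > 0)


# Hypothetical 4-token vocabulary; the true next token is index 0.
hard_label = [1.0, 0.0, 0.0, 0.0]       # one-hot: no information about tokens 1..3
teacher_logits = [6.0, 4.0, 1.0, -2.0]  # made-up teacher logits

for tau in (1.0, 5.0):
    p = softmax(teacher_logits, tau)
    print(f"tau={tau}: probs={[round(x, 3) for x in p]}, entropy={entropy(p):.3f}")
```

Going from tau = 1 to tau = 5, every non-target probability grows while their order is preserved; that ordering is the inter-class information the student receives from soft targets.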

Scientific Background

Distillation Effectiveness (NeurIPS 2021)

  • Students fail to match teacher predictions exactly, a limitation traced to the optimization landscape
  • Soft targets nonetheless give a 2-5% improvement over hard-label training
  • Mechanism: soft distributions encode inter-class relationships
  • Temperature τ = 4-8 is optimal for ternary students
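For reference, the standard temperature-scaled distillation objective (Hinton et al., 2015) is written below, with z the student logits, v the teacher logits, y the hard label, N the vocabulary size and α the weight on the hard-label term. The last line is that paper's high-temperature approximation (assuming zero-mean logits and τ large relative to the logit magnitudes); it is why the KD term is multiplied by τ². It does not by itself justify the specific τ = 4-8 range quoted above.

```latex
p_i^{(\tau)} = \frac{\exp(z_i/\tau)}{\sum_{j=1}^{N}\exp(z_j/\tau)}, \qquad
q_i^{(\tau)} = \frac{\exp(v_i/\tau)}{\sum_{j=1}^{N}\exp(v_j/\tau)}

\mathcal{L}_{\text{total}} = \alpha\,\mathcal{L}_{\text{CE}}\big(p^{(1)}, y\big)
  + (1-\alpha)\,\tau^{2}\,\mathrm{KL}\big(q^{(\tau)} \,\|\, p^{(\tau)}\big)

\frac{\partial\,\mathrm{KL}\big(q^{(\tau)} \,\|\, p^{(\tau)}\big)}{\partial z_i}
  = \frac{1}{\tau}\big(p_i^{(\tau)} - q_i^{(\tau)}\big)
  \approx \frac{z_i - v_i}{N\tau^{2}}
```

Because the raw KD gradient shrinks like 1/τ², the τ² factor keeps the soft-target term on roughly the same scale as the cross-entropy term when τ is changed.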

Llama 3.2 Distillation (Meta, 2024)

  • 1B and 3B distilled from 8B and 70B teachers
  • Logit-level distillation during pretraining
  • Loss ratio: 0.5 cross-entropy + 0.5 distillation (optimal blend)
  • Improvements especially on reasoning and multi-step tasks

Quantization-Aware Distillation (QAD)

  • The teacher supervises while the student trains under fake quantization
  • Recovers 1-3% of the accuracy lost to aggressive quantization
  • At PPL = 125, this translates to an estimated 3-5 point PPL improvement
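The fake quantization the student experiences can be sketched as follows, assuming the absmean ternary scheme from BitNet b1.58 (scale = mean |W|, round, clip to {-1, 0, +1}); whether HSLM uses exactly this scheme is an assumption, and `ternary_fake_quant` is a hypothetical name. Only the forward transform is shown: during training the backward pass would use a straight-through estimator so the latent FP32 weights keep receiving gradients, while the frozen FP32 teacher supplies targets for the student's quantized forward pass.

```python
def ternary_fake_quant(weights, eps=1e-8):
    """Absmean ternary fake quantization (BitNet b1.58-style forward pass).

    Returns (dequantized_weights, scale). Each output is -scale, 0 or +scale,
    with scale = mean(|W|). In real training the backward pass would treat this
    function as the identity (straight-through estimator) so the latent FP32
    weights keep receiving gradients; plain Python has no autograd, so only the
    forward transform is shown.
    """
    scale = sum(abs(w) for w in weights) / len(weights)
    dequantized = []
    for w in weights:
        level = round(w / (scale + eps))   # nearest integer (Python rounds ties to even)
        level = max(-1, min(1, level))     # clip to the ternary set {-1, 0, +1}
        dequantized.append(level * scale)  # map back to the FP scale used in the forward pass
    return dequantized, scale


weights = [0.9, -0.05, 0.3, -1.2]  # hypothetical FP32 latent weights
q, scale = ternary_fake_quant(weights)
print(scale, q)
```

With these weights the scale is 0.6125 and the ternary levels are [1, 0, 0, -1]: small weights collapse to zero, which is the kind of information loss the teacher's soft targets are meant to partly offset.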

BitNet b1.58 Note

  • The authors did NOT extensively explore distillation
  • Direct training is competitive for large models (3B+)
  • For tiny models (1.95M parameters), distillation is likely MORE impactful

Implementation

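// Step 1: Train the FP32 teacher (same architecture, ~20K extra steps)
// Step 2: Freeze the teacher and distill into the ternary student with a dual loss

const L_total = alpha * L_ce(student_logits, labels) +   // hard-label cross-entropy (tau = 1)
                (1 - alpha) * L_kd(                       // L_kd: KL(teacher || student)
                    softmax(student_logits / tau),        // softened student distribution
                    softmax(teacher_logits / tau)         // softened teacher targets (no gradient)
                ) * tau * tau;                            // tau^2 keeps KD gradients on the CE scale

// alpha = 0.5, tau = 5.0 (temperature)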
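A minimal, dependency-free Python sketch of the dual loss, assuming L_kd is KL(teacher || student) on the temperature-softened distributions (the usual choice; the issue does not pin it down). The function names are hypothetical and do not correspond to anything in src/hslm/distill.zig. Each row of a batch is one token position's logits over the vocabulary.

```python
import math


def log_softmax(logits, tau=1.0):
    """Log of the temperature-scaled softmax, computed stably via log-sum-exp."""
    scaled = [z / tau for z in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - lse for s in scaled]


def ce_loss(student_logits, label):
    """Hard-label cross-entropy at tau = 1."""
    return -log_softmax(student_logits)[label]


def kd_loss(student_logits, teacher_logits, tau):
    """KL(teacher || student) between temperature-softened distributions."""
    log_p = log_softmax(student_logits, tau)  # student
    log_q = log_softmax(teacher_logits, tau)  # teacher, treated as a constant target
    return sum(math.exp(lq) * (lq - lp) for lq, lp in zip(log_q, log_p))


def distill_loss(student_logits, teacher_logits, label, alpha=0.5, tau=5.0):
    """alpha * CE + (1 - alpha) * tau^2 * KD for one token position."""
    return (alpha * ce_loss(student_logits, label)
            + (1 - alpha) * kd_loss(student_logits, teacher_logits, tau) * tau * tau)


def batch_distill_loss(student_batch, teacher_batch, labels, alpha=0.5, tau=5.0):
    """Mean dual loss over token positions."""
    losses = [distill_loss(s, t, y, alpha, tau)
              for s, t, y in zip(student_batch, teacher_batch, labels)]
    return sum(losses) / len(losses)


student = [[2.0, 0.5, -1.0, 0.0]]  # hypothetical logits from the fake-quantized student
teacher = [[3.0, 1.0, -2.0, 0.5]]  # hypothetical logits from the frozen FP32 teacher
print(batch_distill_loss(student, teacher, labels=[0]))
```

Using KL rather than cross-entropy against the soft targets changes the loss value only by the teacher's entropy, which is constant with respect to the student, so the gradients are identical; KL has the convenience of reaching exactly zero when the student matches the teacher.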

Changes

  • src/hslm/trainer.zig: teacher model loading + dual loss computation
  • src/hslm/distill.zig: KD loss function with temperature scaling
  • Two-phase training: tri train --phase=teacher, then tri train --phase=distill
  • Teacher checkpoint: data/checkpoints/teacher_fp32.bin

Expected

  • Teacher training: +20K steps overhead
  • 3-5 PPL improvement (PPL 125 → 120-122, or combined with OHEM: 80 → 76)
  • Greatest value when combined with other techniques (gains multiply)
  • Low risk: if distillation doesn't help, fall back to direct training
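Because perplexity is the exponential of the mean per-token cross-entropy, the PPL targets above can be translated into loss deltas, which compare more fairly across baselines. This assumes PPL is computed from natural-log cross-entropy; the helper name is hypothetical.

```python
import math


def loss_delta_for_ppl(ppl_before, ppl_after):
    """Cross-entropy reduction (nats/token) implied by a perplexity change.

    PPL = exp(mean CE), so ln(PPL_before) - ln(PPL_after) is the loss reduction.
    """
    return math.log(ppl_before) - math.log(ppl_after)


for before, after in [(125, 122), (125, 120), (80, 76)]:
    print(f"PPL {before} -> {after}: {loss_delta_for_ppl(before, after):.4f} nats/token")
```

The three targets come out to roughly 0.024, 0.041 and 0.051 nats/token, so the OHEM-combined 80 → 76 goal asks for a larger loss reduction than even the optimistic 125 → 120 case.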

Priority: MEDIUM (modest standalone gain, but a good compound effect)

Metadata

Assignees: No one assigned
Labels: agent:spawn (Auto-spawn agent container)
Project status: Done
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
