# DistilBERT Quantization — Consolidated Notebook

This single notebook consolidates the flows from `nlp_quant_test.ipynb` and `weight_activation_test.ipynb`, using the shared utilities in `quant_utils.py` and the evaluation helpers in `quant_eval`.
- Choose a quantization mode: dynamic INT8 activations + weights (`int8_dynamic`) or weight-only (`w4`, `w8`).
- Evaluate via the tokenized pipeline (accuracy + latency) or a dataloader (accuracy + per-batch latency).
- Optionally fine-tune the baseline before quantization (set `TRAIN_EPOCHS > 0`).
- Optionally save the quantized model for reuse (`SAVE = True`).


In [None]:
import os
from copy import deepcopy
from pathlib import Path
import torch

from quant_utils import (
    load_sst2_dataloaders,
    train_sst2_baseline,
    quantize_model,
    save_quantized_model,
    load_base_model,
    evaluate_dataloader,
    eval_sst2_acc,
    bench_latency,
)

from quant_eval import EvalConfig, evaluate_pair

def pick_device(name: str) -> str:
    """Map a requested device name to one that is actually usable here.

    Only the exact strings ``'cuda'`` and ``'mps'`` are honoured, and only when
    the corresponding backend reports itself available. Every other request
    (unknown names, different casing, unavailable accelerators) yields ``'cpu'``.
    """
    if name == 'cuda':
        usable = torch.cuda.is_available()
    elif name == 'mps':
        # Older torch builds have no `torch.backends.mps` at all.
        usable = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    else:
        usable = False
    return name if usable else 'cpu'


## Config

In [None]:
# Model / data
MODEL_NAME = 'distilbert-sst2-finetuned-128-edited'

def _resolve_model_dir(model_name: 'str | None' = None) -> str:
    """Locate the local checkpoint directory for ``model_name``.

    Candidates are searched in this order:
      1. ``<cwd>/models/<name>``
      2. ``<cwd>/QAT_pipeline/models/<name>``
      3. ``<script dir>/../models/<name>``. When ``__file__`` is undefined,
         e.g. inside a notebook kernel, the cwd stands in for the script dir.

    Only existing *directories* are accepted. A stray regular file with the
    model's name is skipped instead of being returned as a checkpoint dir.

    Args:
        model_name: Directory name to look for. Defaults to the module-level
            ``MODEL_NAME``, read at call time so re-assigning it is honoured.

    Returns:
        The first matching candidate, resolved to an absolute path string.

    Raises:
        FileNotFoundError: If none of the candidates is an existing directory.
    """
    name = MODEL_NAME if model_name is None else model_name
    cwd = Path.cwd()
    try:
        nb_dir = Path(__file__).resolve().parent
    except NameError:
        # Notebook kernels do not define __file__; fall back to the cwd.
        nb_dir = cwd
    candidates = [
        (cwd / 'models' / name).resolve(),
        (cwd / 'QAT_pipeline' / 'models' / name).resolve(),
        (nb_dir / '..' / 'models' / name).resolve(),
    ]
    # Candidates can coincide (e.g. script dir's parent == cwd); drop repeats
    # while preserving search order so the error message stays readable.
    candidates = list(dict.fromkeys(candidates))
    for path in candidates:
        if path.is_dir():
            return str(path)
    searched = ', '.join(str(p) for p in candidates)
    raise FileNotFoundError(
        f"Could not locate '{name}' in any expected location: {searched}"
    )

MODEL_DIR = _resolve_model_dir()                  # raises FileNotFoundError if the checkpoint dir is missing
DEVICE    = 'cuda'                                # 'cpu' | 'cuda' | 'mps'; pick_device falls back to 'cpu' if unavailable
MAX_LEN   = 128                                   # max sequence length for fine-tuning and evaluation

# Quantization
MODE       = 'int8_dynamic'  # 'int8_dynamic' | 'w4' | 'w8'
GROUP_SIZE = 32              # always passed to quantize_model; presumably only the weight-only modes use it

# Evaluation
EVAL_MODE        = 'tokenized'  # 'tokenized' | 'dataloader'
BATCH_SIZE       = 128           # dataloader-eval batch size; ALSO the fine-tuning batch size
MEASURE_BATCHES  = 200           # for dataloader eval
WARMUP           = 2             # for dataloader eval
BENCH_BATCH      = 32            # for tokenized latency bench
BENCH_RUNS       = 50            # for tokenized latency bench
BENCH_WARMUP     = 5             # for tokenized latency bench
COMPARE_FP32     = True          # also evaluate and print the unquantized baseline

# Optional quick fine-tune (set >0 to enable)
TRAIN_EPOCHS = 0
LR           = 5e-5             # fine-tuning learning rate

# Saving
SAVE     = True
SAVE_DIR = ''  # empty -> f"{MODEL_DIR}-quantized-{MODE}" (a sibling of MODEL_DIR)



## Load (or train) baseline

In [None]:
device = pick_device(DEVICE)
print(f'Using device: {device}')

# Either load the checkpoint as-is, or fine-tune it for TRAIN_EPOCHS epochs first.
# Both helpers hand back a (model, tokenizer) pair.
if not (TRAIN_EPOCHS and TRAIN_EPOCHS > 0):
    model, tok = load_base_model(MODEL_DIR, device=device)
else:
    model, tok = train_sst2_baseline(
        model_name_or_dir=MODEL_DIR,
        device=device,
        epochs=TRAIN_EPOCHS,
        lr=LR,
        batch_size=BATCH_SIZE,
        max_len=MAX_LEN,
    )

# Put the baseline in eval mode. It is kept unmodified for comparison because
# the quantization step below operates on a deep copy of it.
model = model.eval()
print('Baseline loaded.')


## Quantize (torchao)

In [None]:
qmodel = quantize_model(
    deepcopy(model),  # copy first so quantization cannot mutate the FP32 baseline
    mode=MODE,
    group_size=GROUP_SIZE,
)
# Move the quantized copy to the target device and switch it to inference mode.
qmodel = qmodel.to(device)
qmodel = qmodel.eval()
print(f'Quantized with MODE={MODE}, group_size={GROUP_SIZE}')


## Evaluate

In [None]:
eval_cfg = EvalConfig(
    mode=EVAL_MODE,
    max_len=MAX_LEN,
    compare_fp32=COMPARE_FP32,
    # dataloader-eval knobs
    batch_size=BATCH_SIZE,
    measure_batches=MEASURE_BATCHES,
    warmup=WARMUP,
    # tokenized latency-bench knobs
    bench_batch=BENCH_BATCH,
    bench_runs=BENCH_RUNS,
    bench_warmup=BENCH_WARMUP,
)

results = evaluate_pair(
    baseline=model, quantized=qmodel, tokenizer=tok,
    model_dir=MODEL_DIR, device=device, config=eval_cfg,
)

if eval_cfg.mode != 'dataloader':
    # Tokenized mode: compact accuracy / average-batch-latency summary.
    q = results['quantized']
    print(f"Q({MODE}): acc={q['acc']:.4f},  avg_batch_latency={q['avg_batch_latency_ms']:.1f} ms")
    if COMPARE_FP32 and 'fp32' in results:
        fp = results['fp32']
        print(f"FP32 : acc={fp['acc']:.4f},  avg_batch_latency={fp['avg_batch_latency_ms']:.1f} ms")
else:
    # Dataloader mode: print each result entry exactly as returned.
    print('Quantized:', results['quantized'])
    if COMPARE_FP32 and 'fp32' in results:
        print('FP32:', results['fp32'])


## Save quantized model (optional)

In [None]:
if SAVE:
    if SAVE_DIR:
        out_dir = SAVE_DIR
    else:
        # Default: a sibling of the model dir named "<MODEL_DIR>-quantized-<MODE>".
        # pathlib normalises trailing separators on every OS. The previous
        # rstrip('/') missed a trailing '\\' on Windows, which would have placed
        # the output *inside* the model directory.
        model_path = Path(MODEL_DIR)
        out_dir = str(model_path.with_name(f'{model_path.name}-quantized-{MODE}'))
    save_quantized_model(qmodel, tok, out_dir)
    print('Saved quantized model to:', out_dir)
else:
    print('Skipping save (set SAVE=True to enable).')
