# Quantize Model & Save Model

In [8]:
# ===== Torch-only, torchao-based quant baselines for DistilBERT (SST-2) =====
import time, numpy as np, torch
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# 1) Baseline selection. Supported values of MODE:
#      "int8_dynamic" -> W8A8 with dynamic activations (recommended, robust)
#      "w4"           -> INT4 weight-only (W4); activations stay float
MODE = "int8_dynamic"   # or "w4"

# Checkpoint location, tokenizer truncation length, and target device.
MODEL_DIR = "./distilbert-sst2-finetuned-128"
MAX_LEN   = 128
DEVICE    = "cpu"        # keep it torch-only; you can switch to "cuda" if available

# Tokenizer and classifier both come from the same fine-tuned checkpoint.
tok = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model = model.eval().to(DEVICE)

# 2) Quantize in place through torchao.quantization.quantize_
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig, Int4WeightOnlyConfig

# Reject unknown modes up front, then build the matching config and apply it once.
if MODE not in ("int8_dynamic", "w4"):
    raise ValueError("MODE must be 'int8_dynamic' or 'w4'")

if MODE == "int8_dynamic":
    # dynamic per-token activations (int8) + per-channel int8 weights on Linear layers
    _quant_config = Int8DynamicActivationInt8WeightConfig()
else:
    # weight-only int4 per-group quant on Linear layers (activations stay float);
    # group_size=32 is a common default, adjust if you want
    _quant_config = Int4WeightOnlyConfig(group_size=32)

quantize_(model, _quant_config)

# 3) helpers: batch prediction, SST-2 accuracy, and latency benchmark
@torch.no_grad()
def predict_logits(m, texts, max_len=128):
    """Run `m` on a batch of raw strings and return its logits as a NumPy array.

    The batch is tokenized with the module-level tokenizer `tok` (padded,
    truncated to `max_len` tokens), moved to the module-level `DEVICE`, and fed
    to the model as keyword arguments. Gradients are disabled by the decorator.

    Args:
        m: model whose output exposes a `.logits` tensor.
        texts: text input accepted by the tokenizer (e.g. a list of strings).
        max_len: truncation length passed to the tokenizer as `max_length`.

    Returns:
        The logits, detached and copied to CPU, as a `numpy.ndarray`.
    """
    encoded = tok(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    )
    encoded = encoded.to(DEVICE)
    outputs = m(**encoded)
    return outputs.logits.detach().cpu().numpy()

def eval_sst2_acc(m, split="validation", bs=32):
    """Compute the accuracy of `m` on a split of GLUE SST-2.

    The split is loaded with `load_dataset("glue", "sst2", split=split)`,
    sentences are scored in batches of `bs` via `predict_logits` (truncated to
    the module-level `MAX_LEN`), and the arg-max class is compared to the
    dataset's `label` column.

    Args:
        m: model passed through to `predict_logits`.
        split: dataset split name.
        bs: number of sentences per forward pass.

    Returns:
        Fraction of correctly classified examples as a Python float in [0, 1].

    Raises:
        ValueError: if the split contains no examples.
    """
    ds = load_dataset("glue", "sst2", split=split)
    labels = np.asarray(ds["label"], dtype=np.int64)
    # Read the text column once; indexing the dataset by column name inside the
    # loop would repeat that lookup for every batch.
    sentences = ds["sentence"]

    batch_preds = [
        np.argmax(predict_logits(m, sentences[start:start + bs], MAX_LEN), axis=-1)
        for start in range(0, len(sentences), bs)
    ]
    if not batch_preds:
        raise ValueError(f"split {split!r} is empty; nothing to evaluate")

    preds = np.concatenate(batch_preds)
    return float((preds == labels).mean())

def bench_latency(m, text="this movie was great!", bs=32, runs=50, warmup=5):
    """Measure the average wall-clock time of one batched `predict_logits` call.

    A batch of `bs` copies of `text` is pushed through `predict_logits`
    (tokenization included) with the module-level `MAX_LEN`. One unconditional
    untimed pass plus `warmup` further untimed passes run first; then `runs`
    passes are timed as a whole.

    Args:
        m: model passed through to `predict_logits`.
        text: sentence replicated to build the batch.
        bs: batch size.
        runs: number of timed passes (must be non-zero).
        warmup: number of additional untimed passes before timing starts.

    Returns:
        Average seconds per batch over the timed passes.
    """
    batch = [text] * bs

    # Untimed passes: one unconditional call, then `warmup` more.
    predict_logits(m, batch, MAX_LEN)
    for _ in range(warmup):
        predict_logits(m, batch, MAX_LEN)

    # perf_counter is monotonic and high-resolution; time.time() can jump when
    # the system clock is adjusted and may be too coarse for short intervals.
    start = time.perf_counter()
    for _ in range(runs):
        predict_logits(m, batch, MAX_LEN)
    elapsed = time.perf_counter() - start
    return elapsed / runs

# 4) (optional) FP32 reference: reload the unquantized checkpoint and run the
#    same accuracy / latency measurements on it and on the quantized model.
base = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
base = base.eval().to(DEVICE)

# Accuracy is evaluated first (FP32, then quantized), followed by latency.
acc_fp, acc_q = [eval_sst2_acc(net) for net in (base, model)]
lat_fp, lat_q = [bench_latency(net) for net in (base, model)]

# Report: latencies are converted from seconds to milliseconds for display.
print(
    f"\n=== MODE = {MODE} ===",
    f"FP32 : acc={acc_fp:.4f},  avg_batch_latency={lat_fp*1000:.1f} ms",
    f"Q({MODE}): acc={acc_q:.4f},  avg_batch_latency={lat_q*1000:.1f} ms",
    sep="\n",
)


=== MODE = int8_dynamic ===
FP32 : acc=0.9140,  avg_batch_latency=18.4 ms
Q(int8_dynamic): acc=0.9117,  avg_batch_latency=904.1 ms


In [12]:
import time
import torch

@torch.inference_mode()
def evaluate(model, dataloader, device="cpu", warmup=2, measure_batches=200):
    """Measure classification accuracy and mean forward-pass latency.

    The model is put in eval mode and run under `torch.inference_mode()`.
    Up to `warmup` batches are run untimed, then the dataloader is iterated
    again from the start and up to `measure_batches` batches are timed and
    scored. Only the forward pass is timed; moving the batch to `device` and
    computing predictions are excluded.

    Args:
        model: callable module; it is called as `model(**batch)` and its output
            must expose `.logits`.
        dataloader: re-iterable collection of dict batches whose values have a
            `.to(device)` method and which contain a "labels" entry. It does not
            need to implement `__len__`.
        device: device the batch tensors are moved to. The model itself is not
            moved by this function.
        warmup: maximum number of untimed warmup batches.
        measure_batches: maximum number of timed batches.

    Returns:
        dict with:
            "acc": accuracy in percent (0.0 if no examples were scored),
            "mean_ms_per_batch": mean forward time in milliseconds (NaN if no
                batch was timed),
            "batches_measured": number of timed batches.
    """
    model.eval()

    # CUDA kernels are launched asynchronously, so the host timer must wait for
    # the device before and after the forward pass to measure real latency.
    on_cuda = torch.device(device).type == "cuda"

    def _sync():
        if on_cuda:
            torch.cuda.synchronize()

    # Warmup passes (not timed). zip() stops at whichever runs out first and
    # avoids len(dataloader), which iterable-style loaders may not support.
    for _, batch in zip(range(warmup), dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        _ = model(**batch).logits

    # Timed passes. range() comes first in zip() so no batch beyond the limit
    # is fetched from the dataloader.
    correct = total = 0
    times = []
    for _, batch in zip(range(measure_batches), dataloader):
        labels = batch["labels"]  # kept as yielded for the comparison below
        batch = {k: v.to(device) for k, v in batch.items()}

        _sync()
        t0 = time.perf_counter()
        logits = model(**batch).logits
        _sync()
        times.append((time.perf_counter() - t0) * 1000.0)  # ms

        preds = logits.argmax(dim=-1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    counted = len(times)
    acc = 100.0 * correct / total if total else 0.0
    mean_ms = sum(times) / counted if counted else float("nan")
    return {"acc": acc, "mean_ms_per_batch": mean_ms, "batches_measured": counted}