In [29]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup,
)

device = torch.device("cuda" if torch.cuda.is_available() else "mps")

# 1) Dataset (GLUE/SST-2)
raw_ds = load_dataset("glue", "sst2")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize_fn(ex):
    return tokenizer(ex["sentence"], truncation=True)

# Keep only 'label'; drop 'sentence' and 'idx'
cols_to_remove = [c for c in raw_ds["train"].column_names if c not in ["label"]]
tokenized = raw_ds.map(tokenize_fn, batched=True, remove_columns=cols_to_remove)

collate = DataCollatorWithPadding(tokenizer=tokenizer)
train_loader = DataLoader(tokenized["train"], batch_size=128, shuffle=True, collate_fn=collate)
val_loader   = DataLoader(tokenized["validation"], batch_size=128, shuffle=False, collate_fn=collate)

# 3) Model + optim
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# (optional) scheduler for a few epochs
num_epochs = 3
num_steps = num_epochs * len(train_loader)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * num_steps), num_training_steps=num_steps)

# 4) Train loop
for epoch in range(1, num_epochs + 1):
    model.train()
    running = 0.0
    for batch in train_loader:
        # batch is a dict: input_ids, attention_mask, (optionally token_type_ids), labels
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        out = model(**batch)          # returns loss + logits when labels present
        loss = out.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        running += loss.item()
    print(f"Epoch {epoch} | train loss: {running/len(train_loader):.4f}")

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


KeyboardInterrupt: 

In [24]:
save_dir = "./distilbert-sst2-finetuned-batch-128"

# Save model weights + config
model.save_pretrained(save_dir)

# Save tokenizer too (important for re-tokenization later)
tokenizer.save_pretrained(save_dir)

('./distilbert-sst2-finetuned/tokenizer_config.json',
 './distilbert-sst2-finetuned/special_tokens_map.json',
 './distilbert-sst2-finetuned/vocab.txt',
 './distilbert-sst2-finetuned/added_tokens.json',
 './distilbert-sst2-finetuned/tokenizer.json')

# Quantize Model & Save Model

In [25]:
from copy import deepcopy
from torchao.quantization import quantize_, Int8WeightOnlyConfig
torch.backends.quantized.engine = "qnnpack"

qmodel = deepcopy(fp32_model)        # don’t mutate your original
qmodel.eval()                        # eval mode is important
qmodel = qmodel.to("cpu")            # quant path is CPU

# Quantize linear weights (grouped int8). Use a group size that divides hidden dims (e.g., 64).
quantize_(qmodel, Int8WeightOnlyConfig(group_size=64))

In [26]:
import os, torch
qsave_dir = "./distilbert-sst2-finetuned-quantized"
os.makedirs(save_dir, exist_ok=True)

# 1) save state dict with torch.save (works with torchao tensor subclasses)
torch.save(qmodel.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))

# 2) save the HF config so from_pretrained can rebuild the architecture
qmodel.config.save_pretrained(qsave_dir)

# 3) tokenizer stays the same
tokenizer.save_pretrained(qsave_dir)

('./distilbert-sst2-finetuned-quantized/tokenizer_config.json',
 './distilbert-sst2-finetuned-quantized/special_tokens_map.json',
 './distilbert-sst2-finetuned-quantized/vocab.txt',
 './distilbert-sst2-finetuned-quantized/added_tokens.json',
 './distilbert-sst2-finetuned-quantized/tokenizer.json')

## Load and Eval Models

In [27]:
qmodel = AutoModelForSequenceClassification.from_pretrained(qsave_dir)  # builds the arch
state = torch.load(os.path.join(save_dir, "pytorch_model.bin"), map_location="cpu")
qmodel.load_state_dict(state, strict=True)
# may need to load tokenizer later for both if testing on new text

<All keys matched successfully>

In [28]:
fp32_model = AutoModelForSequenceClassification.from_pretrained("./distilbert-sst2-finetuned").to("cpu")
fp32_metrics = evaluate(fp32_model, val_loader, device="cpu")
qmetrics = evaluate(qmodel, val_loader, device="cpu")
print("FP32:", fp32_metrics)
print("Quantized:", qmetrics)

FP32: {'acc': 89.90825688073394, 'mean_ms_per_batch': 309.9399224322821, 'batches_measured': 7}
Quantized: {'acc': 90.59633027522936, 'mean_ms_per_batch': 330.66494028233654, 'batches_measured': 7}


# Model Evaluator Helper

In [5]:
import time
import torch

@torch.inference_mode()
def evaluate(model, dataloader, device="cpu", warmup=2, measure_batches=20):
    model.eval()
    correct = total = 0
    times = []

    # warmup passes (not timed)
    it = iter(dataloader)
    for _ in range(min(warmup, len(dataloader))):
        batch = next(it, None)
        if batch is None: break
        batch = {k: v.to(device) for k, v in batch.items()}
        _ = model(**batch).logits

    # timed passes
    counted = 0
    for batch in dataloader:
        if counted >= measure_batches: break
        labels = batch["labels"]
        batch = {k: v.to(device) for k, v in batch.items()}

        t0 = time.perf_counter()
        logits = model(**batch).logits
        dt = (time.perf_counter() - t0) * 1000.0  # ms
        times.append(dt)

        preds = logits.argmax(dim=-1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        counted += 1

    acc = 100.0 * correct / total if total else 0.0
    mean_ms = sum(times) / len(times) if times else float("nan")
    return {"acc": acc, "mean_ms_per_batch": mean_ms, "batches_measured": counted}