In [1]:
import torch
# torch.quantization has no `quantize_static` (importing it raises ImportError);
# eager-mode static quantization is prepare() -> calibrate -> convert().
from torch.quantization import get_default_qconfig, prepare, convert
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

In [4]:
# Load the pre-trained DistilBERT model
# Tokenizer and classifier come from the same SST-2 fine-tuned checkpoint, so the id lives in one place.
MODEL_CHECKPOINT = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_CHECKPOINT)
# Switch to inference mode (dropout off) before calibrating / quantizing.
model.eval()

Downloading (…)okenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [ ]:
# Set the quantization configuration for x86 CPUs
qconfig_x86 = get_default_qconfig("fbgemm")

# Eager-mode static quantization only works between explicit QuantStub/DeQuantStub
# boundaries, and the stock Hugging Face model has none. A model-wide qconfig
# (`model.qconfig = ...`) would also reach nn.Embedding, which eager mode can only
# quantize with float_qparams_weight_only_qconfig. It would also make convert() swap
# nn.Linear / nn.LayerNorm for quantized modules that then receive float tensors.
# Instead, wrap every nn.Linear in a QuantWrapper (quantize -> int8 linear -> dequantize).
# Embeddings, LayerNorm, attention softmax and activations keep running in float.
# NOTE(review): this changes the module tree, and therefore the keys of the saved
# state_dict. Whatever loads the quantized weights must rebuild the same wrapping.

# Collect first: replacing modules while walking named_modules() would mutate the tree mid-walk.
linear_layers = [
    (name, module)
    for name, module in model.named_modules()
    if type(module) is torch.nn.Linear  # exact type: leave Linear subclasses alone
]
for name, linear in linear_layers:
    # Top-level layers give parent_name == "", and get_submodule("") returns the model itself.
    parent_name, _, attr_name = name.rpartition(".")
    wrapped = torch.quantization.QuantWrapper(linear)
    # prepare() propagates this qconfig to the wrapper's QuantStub, Linear and DeQuantStub.
    wrapped.qconfig = qconfig_x86
    setattr(model.get_submodule(parent_name), attr_name, wrapped)

# Prepare the model for static quantization (inserts observers in place)
prepare(model, inplace=True);

In [5]:
# Calibrate the model with positive and negative sentiment news headlines (the target application domain of the model)
# Static quantization takes activation scales/zero-points from the observers that
# prepare() inserted. They only hold statistics after representative inputs have run
# through the model; without this step convert() falls back to default qparams.
# TODO: replace this small placeholder sample with a larger set of real headlines
# from the production feed. Calibration quality depends on how representative it is.
CALIBRATION_HEADLINES = [
    "Volunteers rebuild local playground after storm damage",
    "Researchers report promising results in early cancer detection trial",
    "Community fundraiser sets record for children's hospital",
    "Rescued sea turtles released back into the ocean",
    "Stock markets tumble as recession fears grow",
    "Wildfire forces thousands to evacuate their homes",
    "Factory closure leaves hundreds of workers without jobs",
    "Severe flooding cuts off villages across the region",
]

calibration_batch = tokenizer(
    CALIBRATION_HEADLINES, padding=True, truncation=True, return_tensors="pt"
)
# Forward passes only feed the observers; outputs are discarded and no gradients are needed.
with torch.no_grad():
    model(**calibration_batch)


In [ ]:
# Convert the model to a quantized version.
# inplace=False converts a deep copy and leaves `model` as the float reference.
# With inplace=True, convert() returns the same object, so `quantized_model is model`
# and the comparison below would compare the model against itself.
quantized_model = convert(model, inplace=False)

In [6]:
# Test the original and quantized models on the same input
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
    quant_logits = quantized_model(**inputs).logits

# argmax over the class axis. A bare argmax() flattens the whole tensor and only
# matches the class index because the batch size is 1.
predicted_class_id = logits.argmax(dim=-1).item()
quant_predicted_class_id = quant_logits.argmax(dim=-1).item()

# Largest absolute logit deviation: a direct measure of quantization error that a
# matching label alone can hide (0.0 would mean both names point to the same model).
max_abs_logit_diff = (logits - quant_logits).abs().max().item()
max_abs_logit_diff

In [7]:
# Display results
# Both class ids are decoded with the float model's id2label mapping
# (the quantized model is converted from it).
for caption, class_id in (
    ("Original model prediction:", predicted_class_id),
    ("Quantized model prediction:", quant_predicted_class_id),
):
    print(caption, model.config.id2label[class_id])

Original model prediction: POSITIVE
Quantized model prediction: POSITIVE


In [8]:
# Save the quantized model state_dict (tensors and quantization params only, not the module tree).
# NOTE(review): the consumer presumably rebuilds the same quantized module structure
# before load_state_dict — confirm in goodnewsonly.
STATE_DICT_PATH = "../goodnewsonly/resources/static_quantized_model_state_dict.pth"
quantized_state_dict = quantized_model.state_dict()
torch.save(quantized_state_dict, STATE_DICT_PATH)