In [3]:
import torch
import torch.quantization
from torch.quantization import get_default_qat_qconfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import logging

In [4]:
logging.basicConfig(level=logging.INFO)


In [5]:
MODEL_NAME = "facebook/opt-125m"
DATASET_NAME = "medmcqa"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


In [6]:
model.train()
model.qconfig = get_default_qat_qconfig('fbgemm')  # Use fbgemm for x86 CPUs or 'qnnpack' for mobile
model = torch.quantization.prepare_qat(model, inplace=False)



In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Load dataset
dataset = load_dataset(DATASET_NAME)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Generating train split: 100%|██████████| 182822/182822 [00:00<00:00, 243672.66 examples/s]
Generating test split: 100%|██████████| 6150/6150 [00:00<00:00, 392187.70 examples/s]
Generating validation split: 100%|██████████| 4183/4183 [00:00<00:00, 267834.60 examples/s]


In [10]:
def train_qat(model, tokenizer, dataset, num_epochs=1, max_steps=100):
    model.train()
    for epoch in range(num_epochs):
        logging.info(f"Epoch {epoch + 1}/{num_epochs}")
        for i, batch in enumerate(dataset["train"]):
            if i >= max_steps:
                break

            question = batch["question"]
            options = [batch["opa"], batch["opb"], batch["opc"], batch["opd"]]
            answer_idx = batch["cop"]

            prompt = f"Question: {question}\nOptions:\nA. {options[0]}\nB. {options[1]}\nC. {options[2]}\nD. {options[3]}\nAnswer: {['A','B','C','D'][answer_idx]}"
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device)
            
            model.zero_grad()
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                logging.info(f"Step {i}, Loss: {loss.item():.4f}")

    # Convert to quantized model after training
    model = torch.quantization.convert(model, inplace=False)
    return model

# Evaluate function
def evaluate_qat_model(model, tokenizer, dataset, limit=200):
    model.eval()
    predictions = []
    references = []

    with torch.no_grad():
        for i, batch in enumerate(dataset["validation"]):
            if i >= limit:
                break

            question = batch["question"]
            options = [batch["opa"], batch["opb"], batch["opc"], batch["opd"]]
            answer_idx = batch["cop"]

            prompt = f"Question: {question}\nOptions:\nA. {options[0]}\nB. {options[1]}\nC. {options[2]}\nD. {options[3]}\nAnswer:"
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device)
            
            outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=10)
            pred_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            answer_letter = pred_text.strip().upper().split("Answer:")[-1].strip()[0]

            label_map = {"A": 0, "B": 1, "C": 2, "D": 3}
            if answer_letter in label_map:
                predictions.append(label_map[answer_letter])
            else:
                predictions.append(None)

            references.append(answer_idx)

    # Filter out None predictions
    filtered = [(p, r) for p, r in zip(predictions, references) if p is not None]
    predictions, references = zip(*filtered) if filtered else ([], [])

    accuracy = accuracy_score(references, predictions) if predictions else 0
    return accuracy

In [11]:
quantized_model = train_qat(model, tokenizer, dataset, num_epochs=1, max_steps=100)

INFO:root:Epoch 1/1
INFO:root:Step 0, Loss: 12.0137
INFO:root:Step 10, Loss: 5.8261
INFO:root:Step 20, Loss: 5.0928
INFO:root:Step 30, Loss: 5.2674
INFO:root:Step 40, Loss: 4.0199
INFO:root:Step 50, Loss: 3.6182
INFO:root:Step 60, Loss: 2.2397
INFO:root:Step 70, Loss: 2.5552
INFO:root:Step 80, Loss: 2.9768
INFO:root:Step 90, Loss: 3.0076


AssertionError: Embedding quantization is only supported with float_qparams_weight_only_qconfig.

In [None]:
accuracy = evaluate_qat_model(quantized_model, tokenizer, dataset, limit=200)
print(f"Accuracy on MedMCQA after QAT: {accuracy:.2%}")

In [None]:

original_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)

def evaluate_model(model, tokenizer, dataset, limit=200):
    """Generic evaluation function that works for both original and quantized models"""
    model.eval()
    predictions = []
    references = []

    with torch.no_grad():
        for i, batch in enumerate(dataset["validation"]):
            if i >= limit:
                break

            question = batch["question"]
            options = [batch["opa"], batch["opb"], batch["opc"], batch["opd"]]
            answer_idx = batch["cop"]

            prompt = f"Question: {question}\nOptions:\nA. {options[0]}\nB. {options[1]}\nC. {options[2]}\nD. {options[3]}\nAnswer:"
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device)
            
            outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=10)
            pred_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            answer_letter = pred_text.strip().upper().split("Answer:")[-1].strip()[0]

            label_map = {"A": 0, "B": 1, "C": 2, "D": 3}
            if answer_letter in label_map:
                predictions.append(label_map[answer_letter])
            else:
                predictions.append(None)

            references.append(answer_idx)

    # Filter out None predictions
    filtered = [(p, r) for p, r in zip(predictions, references) if p is not None]
    predictions, references = zip(*filtered) if filtered else ([], [])

    accuracy = accuracy_score(references, predictions) if predictions else 0
    return accuracy

# Evaluate original model BEFORE QAT
original_accuracy = evaluate_model(original_model, tokenizer, dataset, limit=200)
print(f"Accuracy on MedMCQA BEFORE QAT: {original_accuracy:.2%}")


# Evaluate after QAT
quantized_accuracy = evaluate_model(quantized_model, tokenizer, dataset, limit=200)
print(f"Accuracy on MedMCQA AFTER QAT: {quantized_accuracy:.2%}")

# Print comparison
print("\nComparison:")
print(f"Original model accuracy: {original_accuracy:.2%}")
print(f"Quantized model accuracy: {quantized_accuracy:.2%}")
print(f"Accuracy difference: {(quantized_accuracy - original_accuracy):.2f} percentage points")