
### AWQ Baseline Quantization for OPT-125M
# Prepared by: Hermela Wosene, Hiwot Teshome, Melat Dagnachew

# Project Scope
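# This notebook sets up a 4-bit weight-quantization baseline for the OPT-125M model, as the starting point for activation-aware weight quantization (AWQ).
# The current code simulates ("fake-quantizes") 4-bit weights with per-tensor min-max rounding. Weights are rounded to 16 levels and immediately mapped back, so they remain float32 and memory use does not shrink, and no activation-aware scaling is applied yet.
# We report the perplexity of this quantized model on an excerpt of the WikiText-2 test split as the reference for comparison with metaheuristic search strategies.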
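For context, the activation-aware step that distinguishes AWQ from plain rounding can be sketched for a single linear layer. Each input channel of the weight is multiplied by a scale s = (mean |x|)^α computed from calibration activations x, the scaled weight is quantized, and the scale is divided back out, so channels that see large activations lose less precision. The exponent α is chosen by grid search to minimize the layer's output error. This is a minimal illustration under our own assumptions: the helper names (`fake_quantize_groupwise`, `awq_search_scale`), the group size of 128, the 20-point grid, and the sqrt(max·min) normalization of s are modeled on common AWQ setups, not taken from this project's code.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def fake_quantize_groupwise(w: torch.Tensor, bits: int = 4, group_size: int = 128) -> torch.Tensor:
    """Asymmetric round-to-nearest quantize-dequantize, one scale/zero per group of input columns."""
    out_features, in_features = w.shape
    assert in_features % group_size == 0, "in_features must be divisible by group_size"
    qmax = 2 ** bits - 1
    g = w.reshape(out_features, in_features // group_size, group_size)
    w_min = g.amin(dim=-1, keepdim=True)
    w_max = g.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-8) / qmax
    zero = torch.round(-w_min / scale)
    q = torch.clamp(torch.round(g / scale) + zero, 0, qmax)
    return ((q - zero) * scale).reshape(out_features, in_features)


@torch.no_grad()
def awq_search_scale(linear: nn.Linear, x: torch.Tensor, bits: int = 4,
                     group_size: int = 128, n_grid: int = 20):
    """Grid-search per-input-channel scales s = mean|x|^alpha that minimise output MSE.

    x holds calibration activations of shape [n_tokens, in_features] feeding `linear`.
    Returns (best_mse, best_alpha, best_scales). The bias is ignored because weight
    quantization does not change it. The layer's weight is not modified.
    """
    w = linear.weight
    y_ref = x @ w.T                      # full-precision layer output
    act_mag = x.abs().mean(dim=0)        # average magnitude of each input channel
    best_mse, best_alpha, best_s = float("inf"), 0.0, torch.ones_like(act_mag)
    for i in range(n_grid):
        alpha = i / n_grid               # alpha = 0 gives s = 1, i.e. plain quantization
        s = act_mag.pow(alpha).clamp(min=1e-4)
        s = s / (s.max() * s.min()).sqrt()
        # Scale salient input channels up before quantizing, then undo the scale.
        w_q = fake_quantize_groupwise(w * s, bits, group_size) / s
        mse = (x @ w_q.T - y_ref).pow(2).mean().item()
        if mse < best_mse:
            best_mse, best_alpha, best_s = mse, alpha, s
    return best_mse, best_alpha, best_s
```

To apply the selected scales in this simulation, run `linear.weight.copy_(fake_quantize_groupwise(linear.weight * s) / s)` under `torch.no_grad()`. A full AWQ pipeline would instead fold 1/s into the preceding operation (for example a LayerNorm), collect `x` with forward hooks on calibration text, and also search weight-clipping ranges; none of that is shown here.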

In [11]:
# 1. Install dependencies
!pip install -q transformers accelerate datasets


In [12]:
# 2. Import libraries
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import numpy as np

In [13]:
# 3. Load OPT-125M model
model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_id).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
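Because the fake quantization in this notebook rounds weights and maps them straight back to float32, it measures the accuracy impact only; the storage saving that motivates 4-bit quantization comes from packing the integer codes. A back-of-envelope estimate, assuming one fp16 scale and one fp16 zero-point per group of 128 weights and a nominal 125M parameters (the exact count for the loaded model is `sum(p.numel() for p in model.parameters())`), might look like this; `estimate_weight_bytes` is a hypothetical helper:

```python
from typing import Optional


def estimate_weight_bytes(n_params: int, bits: int, group_size: Optional[int] = None,
                          meta_bits: int = 16) -> float:
    """Rough storage for `n_params` weights stored at `bits` bits each.

    With group-wise quantization, each group of `group_size` weights also stores one
    scale and one zero-point of `meta_bits` bits (assumed fp16 here).
    """
    total_bits = n_params * bits
    if group_size is not None:
        n_groups = -(-n_params // group_size)  # ceiling division
        total_bits += n_groups * 2 * meta_bits
    return total_bits / 8


nominal_params = 125_000_000  # nominal size implied by the model name, not an exact count
fp32_mb = estimate_weight_bytes(nominal_params, bits=32) / 1e6
int4_mb = estimate_weight_bytes(nominal_params, bits=4, group_size=128) / 1e6
print(f"fp32: {fp32_mb:.1f} MB, 4-bit (g128): {int4_mb:.1f} MB")
```

Under these assumptions this prints roughly 500 MB for float32 versus about 66 MB for packed 4-bit weights. Real checkpoints often keep embeddings and LayerNorm parameters in 16-bit, so actual files come out somewhat larger.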

In [14]:
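# 4. Fake-quantize model weights (simulated 4-bit, per-tensor min-max)
# This is plain round-to-nearest quantization with one range per tensor; it does not
# apply AWQ's activation-aware scaling. The 'weight' filter also matches LayerNorm
# weights and the token/position embeddings, not only the linear layers.
def fake_quantize(model, bitwidth=4):
    qmax = 2 ** bitwidth - 1  # highest integer level (15 for 4 bits)
    for name, param in model.named_parameters():
        if 'weight' in name and param.requires_grad:
            with torch.no_grad():
                min_val = param.min()
                max_val = param.max()
                value_range = max_val - min_val + 1e-8  # epsilon avoids division by zero
                param.sub_(min_val).div_(value_range)  # map to [0, 1]
                param.mul_(qmax).round_().div_(qmax)   # snap to qmax + 1 evenly spaced levels
                param.mul_(value_range).add_(min_val)  # map back to the original range
    return model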

model = fake_quantize(model, bitwidth=4)
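Two properties of `fake_quantize` probably contribute to the very high perplexity recorded in the next cell: a single min/max range per tensor lets a few outlier weights stretch the 16 levels over a wide interval, and the `'weight'` name filter also rounds LayerNorm weights and the embeddings, which AWQ-style pipelines normally leave in higher precision. A hedged alternative: asymmetric round-to-nearest quantization with one scale and zero-point per group of 128 input columns, applied only to `nn.Linear` layers and skipping `lm_head`, which OPT ties to the token embedding. The function names and the group size are our assumptions; OPT-125M's hidden size (768) and feed-forward size (3072) are both multiples of 128, so the divisibility check passes.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def fake_quantize_groupwise(w: torch.Tensor, bits: int = 4, group_size: int = 128) -> torch.Tensor:
    """Asymmetric quantize-dequantize with one (scale, zero) per group of input columns."""
    out_features, in_features = w.shape
    assert in_features % group_size == 0, "in_features must be divisible by group_size"
    qmax = 2 ** bits - 1
    g = w.reshape(out_features, in_features // group_size, group_size)
    w_min = g.amin(dim=-1, keepdim=True)
    w_max = g.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-8) / qmax   # step between adjacent levels
    zero = torch.round(-w_min / scale)                # integer code that represents 0.0
    q = torch.clamp(torch.round(g / scale) + zero, 0, qmax)
    return ((q - zero) * scale).reshape(out_features, in_features)


@torch.no_grad()
def fake_quantize_linear_layers(model: nn.Module, bits: int = 4, group_size: int = 128,
                                skip: tuple = ("lm_head",)) -> nn.Module:
    """Fake-quantize only nn.Linear weights; LayerNorm and embeddings stay in full precision."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and not any(key in name for key in skip):
            module.weight.copy_(fake_quantize_groupwise(module.weight, bits, group_size))
    return model
```

To try it, reload the model and call `fake_quantize_linear_layers(model, bits=4, group_size=128)` in place of `fake_quantize(model, bitwidth=4)`, then rerun the evaluation cell. The result is still a simulation, not AWQ, until activation-aware scaling is added.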


In [15]:
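# 5. Estimate perplexity on WikiText-2 (rough proxy)
# The first 200 rows of the test split are joined into one string, but truncation
# keeps at most 512 tokens, so this scores a single window, not the full test set.
test_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
text = " ".join(test_dataset["text"][:200])
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

with torch.no_grad():
    # Passing labels=input_ids makes the model return the mean next-token cross-entropy.
    outputs = model(**inputs, labels=inputs["input_ids"])
    loss = outputs.loss
    perplexity = torch.exp(loss).item()  # perplexity = exp(mean negative log-likelihood)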

print(f"Simulated 4-bit Quantized Model Perplexity: {perplexity:.2f}")


Simulated 4-bit Quantized Model Perplexity: 19021.55
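The 19021.55 above comes from one 512-token window and has no full-precision reference next to it, so it is hard to interpret on its own. A fuller evaluation, in the spirit of the strided approach described in the Hugging Face perplexity guide, tokenizes the whole test split, slides a window of `max_length` tokens forward by `stride`, masks already-scored positions with the label value -100 so each token is scored once, and combines the window losses weighted by token count. The helpers here (hypothetical names `plan_windows` and `corpus_perplexity`) sketch just the bookkeeping in plain Python:

```python
import math
from typing import List, Tuple


def plan_windows(seq_len: int, max_length: int, stride: int) -> List[Tuple[int, int, int, int]]:
    """Split a token sequence into overlapping evaluation windows.

    Returns (begin, end, trg_len, n_scored) per window: tokens [begin, end) are fed to
    the model, only the last trg_len positions keep their labels (the rest would be
    set to -100), and n_scored counts the next-token predictions that enter the loss.
    """
    assert 0 < stride <= max_length
    windows = []
    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_length, seq_len)
        trg_len = end - prev_end
        # The model shifts labels by one, so a window's first token is never scored.
        # That costs one target only when the window has no masked prefix.
        n_scored = min(trg_len, end - begin - 1)
        windows.append((begin, end, trg_len, n_scored))
        prev_end = end
        if end == seq_len:
            break
    return windows


def corpus_perplexity(mean_losses: List[float], n_scored: List[int]) -> float:
    """Combine per-window mean NLLs into one perplexity, weighting by scored tokens."""
    pairs = [(loss, n) for loss, n in zip(mean_losses, n_scored) if n > 0]
    total_nll = sum(loss * n for loss, n in pairs)
    return math.exp(total_nll / sum(n for _, n in pairs))
```

With the notebook's objects, each window would be evaluated as `ids = enc.input_ids[:, begin:end]`, `labels = ids.clone()`, `labels[:, :-trg_len] = -100`, and `loss = model(ids, labels=labels).loss.item()`, where `enc = tokenizer("\n\n".join(test_dataset["text"]), return_tensors="pt")`. Running the same loop on an unquantized copy of OPT-125M gives the full-precision baseline against which the quantized perplexity should be compared.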
