In [1]:
import numpy as np
import random
import torch
import transformers
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, pipeline
import gc
import os 
from torch.ao.quantization import (
    QuantStub, 
    DeQuantStub, 
    prepare_qat, 
    convert, 
    FakeQuantize,
    MinMaxObserver,
    float_qparams_weight_only_qconfig,
    QConfig,
    get_default_qat_qconfig, 
    propagate_qconfig_,
)
# from torch.ao.quantization.observer import 


# Seed every RNG the notebook touches so reruns start from the same state.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
from tqdm import tqdm

# Report the hardware that is present, but log the device that is actually used.
# NOTE(review): the notebook pins execution to CPU; presumably because the
# fbgemm eager-mode quantization path used later targets CPU — confirm.
detected_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = "cpu"
print(f"torch {torch.__version__} | detected device: {detected_device} | using device: {device}")

# Empty VRAM cache
gc.collect()
torch.cuda.empty_cache()

model_name = "models/llama3-8b/"

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [2]:
# Step 1: Load the model and its tokenizer
print("Loading model...")
# Weights are requested in float32 from the location given by `model_name`.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_name)


Loading model...


Loading checkpoint shards: 100%|██████████| 4/4 [01:20<00:00, 20.23s/it]


In [3]:
# Step 2: Define Custom QConfig for 4-bit MinMax Quantization
def create_4bit_qconfig_with_embedding():
    """
    Build the quantization settings for 4-bit MinMax fake quantization.

    Returns:
        dict: A mapping with two entries:
            - "" -> the default QConfig. Activations are fake-quantized to the
              range [0, 15] (quint8, per-tensor affine) and weights to
              [-8, 7] (qint8, per-tensor symmetric), both using MinMaxObserver.
            - torch.nn.Embedding -> float_qparams_weight_only_qconfig.
    """
    # Unsigned 16-level range for activations.
    activation_fake_quant = FakeQuantize.with_args(
        observer=MinMaxObserver,
        quant_min=0,
        quant_max=15,
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
    )
    # Signed 16-level range for weights.
    weight_fake_quant = FakeQuantize.with_args(
        observer=MinMaxObserver,
        quant_min=-8,
        quant_max=7,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
    )
    four_bit_qconfig = QConfig(
        activation=activation_fake_quant,
        weight=weight_fake_quant,
    )

    return {
        # Fallback entry used for every layer without a more specific one
        "": four_bit_qconfig,
        # Embedding layers get their own weight-only configuration
        torch.nn.Embedding: float_qparams_weight_only_qconfig,
    }

print("Assigning quantization configuration...")
# The helper returns a {module name or type: QConfig} mapping. `model.qconfig`
# must hold a single QConfig, so the root gets the "" entry and the full
# mapping is handed to propagate_qconfig_, which applies the per-type entries
# (e.g. the Embedding one) to the matching submodules.
qconfig_dict = create_4bit_qconfig_with_embedding()
model.qconfig = qconfig_dict[""]
propagate_qconfig_(model, qconfig_dict=qconfig_dict)

Assigning quantization configuration...


In [None]:
# Step 3: Prepare the Model for Quantization-Aware Training (QAT)
print("Preparing model for QAT...")
# prepare_qat requires the model to be in training mode.
model.train()
# Use the custom 4-bit settings. Assigning get_default_qat_qconfig("fbgemm")
# here would replace them with PyTorch's default QAT configuration and drop
# the Embedding-specific entry.
qconfig_dict = create_4bit_qconfig_with_embedding()
# The root attribute must be one QConfig (not the mapping); an existing
# `qconfig` attribute takes precedence during propagation, so set it explicitly.
model.qconfig = qconfig_dict[""]
try:
    propagate_qconfig_(model, qconfig_dict=qconfig_dict)
except AttributeError as e:
    print("Error during qconfig propagation. Double-check the qconfig assignment.", str(e))
    raise
# Prepare the model for QAT; inplace=False leaves `model` itself unconverted.
prepared_model = prepare_qat(model, inplace=False)
print("Model prepared for QAT.")

Preparing model for QAT...




Model prepared for QAT.


In [None]:
# Step 4: Calibrate the Model
def calibrate_model(model, tokenizer, num_samples=1):
    """
    Feeds representative data into the model to calibrate for quantization.

    Args:
        model: Model exposing `eval()` and callable as `model(**inputs)`.
        tokenizer: Callable tokenizer with `pad_token` / `eos_token`
            attributes. If it has no pad token, its `pad_token` is set to its
            `eos_token` (the tokenizer object is mutated).
        num_samples: Number of forward passes to run. The built-in texts are
            reused cyclically when more samples than texts are requested.

    Raises:
        ValueError: If `num_samples` is smaller than 1, since calibrating on
            no data would leave the observers without statistics.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # Reuse EOS token for padding

    model.eval()
    base_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Transformers are revolutionizing natural language processing."
    ]
    # Cycle through the texts so exactly `num_samples` inputs are produced.
    # (The previous `texts * (num_samples // 2)` yielded an empty list for
    # num_samples == 1, so no calibration pass ever ran.)
    calibration_texts = [base_texts[i % len(base_texts)] for i in range(num_samples)]

    print("Calibrating model...")
    with torch.no_grad():
        for text in tqdm(calibration_texts, desc="Calibrating"):
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
            # `device` is the notebook-level target device.
            inputs = {k: v.to(device) for k, v in inputs.items()}
            model(**inputs)

# Run the calibration pass on the QAT-prepared copy of the model.
calibrate_model(model=prepared_model, tokenizer=tokenizer)

Calibrating model...


In [7]:
# Step 5: Convert the Model to a Quantized Version
print("Converting model to a quantized version...")
# inplace=False keeps `prepared_model` intact and returns a converted copy.
quantized_model = convert(prepared_model, inplace=False)

# Step 6: Save the Quantized Model
quantized_model_path = os.path.join(model_name, "models/llama_quantized_minmax_4bit.pt")
# torch.save does not create missing parent directories, so make sure the
# nested output directory exists first.
os.makedirs(os.path.dirname(quantized_model_path), exist_ok=True)
torch.save(quantized_model.state_dict(), quantized_model_path)
print(f"Quantized model saved at: {quantized_model_path}")
# Bytes / 1e9 is gigabytes (the label previously said KB).
print(f"Quantized model size: {os.path.getsize(quantized_model_path) / 1e9:.2f} GB")

Converting model to a quantized version...


: 

In [None]:
# Step 7: Test the quantized model
def test_quantized_model(model, tokenizer):
    """
    Generate a short completion for a fixed prompt and print it.

    Args:
        model: Model exposing `eval()` and `generate(**inputs, max_new_tokens=...)`.
        tokenizer: Callable tokenizer with `pad_token` / `eos_token`
            attributes and a `decode` method. If it has no pad token, its
            `pad_token` is set to its `eos_token` (the tokenizer is mutated).
    """
    # The tokenizer call below uses padding=True. Set the pad token here the
    # same way calibrate_model does, so this function does not depend on
    # calibration having been run first.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model.eval()
    test_text = "What is the capital of France?"
    inputs = tokenizer(test_text, return_tensors="pt", truncation=True, padding=True)
    # `device` is the notebook-level target device.
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=10)
    print("Quantized model output:", tokenizer.decode(output[0], skip_special_tokens=True))


In [None]:
print("Testing quantized model...")
# Evaluate the converted model. `model` is still the unconverted original
# because prepare_qat and convert were both called with inplace=False.
test_quantized_model(quantized_model, tokenizer)
