In [1]:
!pip install transformers accelerate  datasets



In [2]:
!pip install -U bitsandbytes



In [4]:
!pip install transformers accelerate bitsandbytes>0.37.0

In [7]:
import bitsandbytes as bnb

print(bnb.__version__)

0.44.1


In [10]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import time
import numpy as np
from tqdm import tqdm

# Define model and tokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the FP32 model
model_fp32 = AutoModelForCausalLM.from_pretrained(model_name)
print(f"FP32 Model Memory Usage: {model_fp32.get_memory_footprint() / (1024 ** 3):.2f} GB")

# Measure latency for FP32 model
input_text = "The history of quantum mechanics begins with"
input_ids = tokenizer.encode(input_text, return_tensors="pt")
start_time = time.time()
with torch.no_grad():
    model_fp32(input_ids)
latency_fp32 = time.time() - start_time
print(f"FP32 Model Latency: {latency_fp32:.2f} seconds")

# Configure NF4 quantization
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
)

# Load the model with NF4 quantization
model_nf4 = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config)
print(f"NF4 Model Memory Usage: {model_nf4.get_memory_footprint() / (1024 ** 3):.2f} GB")

# Measure latency for NF4 quantized model
start_time = time.time()
with torch.no_grad():
    model_nf4(input_ids)
latency_nf4 = time.time() - start_time
print(f"NF4 Model Latency: {latency_nf4:.2f} seconds")

# Load a subset of the Wikipedia dataset
dataset = load_dataset("wikipedia", "20220301.en", split="train[:3000]")

# Helper function to compute perplexity with tqdm
def calculate_perplexity(model, tokenizer, dataset):
    texts = [entry["text"] for entry in dataset]
    encodings = tokenizer("\n\n".join(texts), return_tensors="pt", truncation=True, max_length=512)
    stride = 512
    nlls = []
    total_steps = (encodings.input_ids.size(1) // stride) + 1
    for i in tqdm(range(0, encodings.input_ids.size(1), stride), desc="Calculating perplexity", total=total_steps):
        input_ids = encodings.input_ids[:, i:i + stride].to(model.device)
        target_ids = input_ids.clone()
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            nll = outputs.loss.item() * input_ids.size(1)
            nlls.append(nll)
    perplexity = np.exp(np.sum(nlls) / len(encodings.input_ids[0]))
    return perplexity

# Compute perplexity for all models with progress tracking
print("Calculating perplexity...")
perplexity_fp32 = calculate_perplexity(model_fp32, tokenizer, dataset)
perplexity_nf4 = calculate_perplexity(model_nf4, tokenizer, dataset)

print(f"FP32 Model Perplexity: {perplexity_fp32:.2f}")
print(f"NF4 Model Perplexity: {perplexity_nf4:.2f}")



FP32 Model Memory Usage: 0.48 GB


`low_cpu_mem_usage` was None, now set to True since model is quantized.


FP32 Model Latency: 0.24 seconds
NF4 Model Memory Usage: 0.12 GB
NF4 Model Latency: 0.04 seconds


Downloading data:   0%|          | 0/41 [00:00<?, ?files/s]

Generating train split:   0%|          | 0/6458670 [00:00<?, ? examples/s]

Calculating perplexity...


Calculating perplexity:  50%|█████     | 1/2 [00:06<00:06,  6.19s/it]
Calculating perplexity:  50%|█████     | 1/2 [00:01<00:01,  1.50s/it]


FP32 Model Perplexity: 24.92
NF4 Model Perplexity: 26.09


In [11]:

# Load the model with FP4 quantization
linear_4bit_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="fp4",  # Change to 'fp4'
)



# Load the model with linear 4-bit quantization
model_linear_4bit = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=linear_4bit_config)
print(f"Linear 4-bit Model Memory Usage: {model_linear_4bit.get_memory_footprint() / (1024 ** 3):.2f} GB")

# Measure latency for linear 4-bit quantized model
start_time = time.time()
with torch.no_grad():
    model_linear_4bit(input_ids)
latency_linear_4bit = time.time() - start_time
print(f"Linear 4-bit Model Latency: {latency_linear_4bit:.2f} seconds")
perplexity_linear_4bit = calculate_perplexity(model_linear_4bit, tokenizer, dataset)
print(f"Linear 4-bit Model Perplexity: {perplexity_linear_4bit:.2f}")


`low_cpu_mem_usage` was None, now set to True since model is quantized.


Linear 4-bit Model Memory Usage: 0.12 GB
Linear 4-bit Model Latency: 0.04 seconds


Calculating perplexity:  50%|█████     | 1/2 [00:00<00:00,  3.96it/s]


Linear 4-bit Model Perplexity: 27.51
