# Notebook to replicate perplexity results

In this notebook we will show the code and all the explanations and considerations needed to replicate the results obtained for perplexity with LLaMA 2 7b and Wikitext2 dataset in a text generation task.

First, we need to set up a Python virtual environment. The usual command can be used for this, for example:

python3 -m venv .env

After we have the environment created, we need to install all the requirements. For this, just execute the following command with the environment activated:

pip install -r requirements.txt

Once everything is set up, we just need to import everything and compute the perplexity for each quantization method.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GPTQConfig, HqqConfig
from optimum.quanto import QuantizedModelForCausalLM, qint4, qint8
from datasets import load_dataset
from tqdm import tqdm
from codecarbon import OfflineEmissionsTracker
import math
import os
import wandb

Let's first calculate the perplexity of the model without quantization. We define the model-loading function for BitsAndBytes, which also allows loading the model without quantization.

Clearly, the lines related to the use of CodeCarbon and Wandb can be changed or even commented out. They can be customized as needed.

In [None]:
# Order GPUs by PCI bus id and expose only one of them to this process.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # Change depending on GPU used

# tracker1 measures the model-loading step, tracker2 the perplexity evaluation.
# tracker2 must be created here as well: the evaluation code at the end of this
# cell calls tracker2.start()/stop(), and on a fresh kernel no earlier cell
# defines it.
tracker1 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_not_quant_perplexity.csv",
    gpu_ids=[1],
)
tracker2 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_not_quant_eval_perplexity.csv",
    gpu_ids=[1],
)
wandb.init(project="Perplexity", name="not_quant")

def load_llama2_model(model_id, quantization=None):
    """Load a tokenizer and causal LM, optionally quantized with BitsAndBytes.

    Args:
        model_id: Identifier passed to both ``from_pretrained`` calls.
        quantization: ``"4bit"`` (NF4 with double quantization and float16
            compute), ``"8bit"``, or a falsy value such as ``None`` to load the
            model without quantization in float32.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: If ``quantization`` is a truthy value other than
            ``"4bit"`` or ``"8bit"``.
    """
    if quantization == "4bit":
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif quantization == "8bit":
        quant_config = BitsAndBytesConfig(load_in_8bit=True)
    elif not quantization:
        quant_config = None
    else:
        # Without this check an unknown mode (e.g. "3bit") would fall through
        # to an unquantized float16 model and be reported as if quantized.
        raise ValueError(
            f"Unsupported BitsAndBytes quantization {quantization!r}; "
            "expected '4bit', '8bit' or None"
        )

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=quant_config,
        torch_dtype=torch.float16 if quantization else torch.float32
    )
    return tokenizer, model

def calculate_perplexity_precise(model, tokenizer, texts, max_length=1024, stride=512, device='cuda'):
    """Compute corpus-level perplexity with a strided sliding window.

    Every text is tokenized separately and scored in windows of at most
    ``max_length`` tokens whose start advances by ``stride``. Positions that a
    previous window already scored get the label ``-100`` so that they are not
    counted as targets again. Each window's mean loss is multiplied by its
    token count, the products are summed over all texts, and the exponential
    of the resulting average is returned.

    Args:
        model: Object with ``eval()`` that is called as
            ``model(input_ids, labels=...)`` and returns an output exposing a
            scalar ``loss`` tensor.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=False)`` that
            returns an object with an ``input_ids`` tensor.
        texts: Iterable of strings to score.
        max_length: Maximum number of tokens in one window.
        stride: Step between the starts of consecutive windows.
        device: Device the token ids are moved to before the forward pass.

    Returns:
        The perplexity as a float; it is also logged to wandb under the key
        ``"Perplexity"``.

    Raises:
        ValueError: If no token could be scored (for example ``texts`` is
            empty), since the average loss would be undefined.
    """
    model.eval()

    nll_sum = 0.0
    n_tokens = 0

    for text in tqdm(texts, desc="Calculating perplexity"):
        encodings = tokenizer(text, return_tensors="pt", truncation=False)
        input_ids = encodings.input_ids.to(device)
        seq_len = input_ids.size(1)
        prev_end_loc = 0

        for begin_loc in range(0, seq_len, stride):
            end_loc = min(begin_loc + max_length, seq_len)
            # Number of tokens in this window not covered by the previous one.
            trg_len = end_loc - prev_end_loc

            input_ids_chunk = input_ids[:, begin_loc:end_loc]
            target_ids = input_ids_chunk.clone()
            target_ids[:, :-trg_len] = -100

            valid_tokens = (target_ids != -100).sum().item()
            # One label per sequence is subtracted to account for the label
            # shift done inside the model (presumably the usual causal-LM
            # next-token shift). NOTE(review): kept unchanged so that the
            # reported numbers stay reproducible.
            effective_tokens = valid_tokens - target_ids.size(0)

            # A window with no scorable token (e.g. a one-token text) gives an
            # undefined mean loss; multiplying NaN by zero would turn the
            # whole running sum into NaN, so such windows are skipped.
            if effective_tokens > 0:
                with torch.no_grad():
                    outputs = model(input_ids_chunk, labels=target_ids)
                nll_sum += outputs.loss.item() * effective_tokens
                n_tokens += effective_tokens

            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

    if n_tokens == 0:
        raise ValueError("Cannot compute perplexity: no tokens were scored")

    avg_nll = nll_sum / n_tokens
    ppl = math.exp(avg_nll)
    wandb.log({"Perplexity": ppl})
    return ppl


# Baseline run: full-precision LLaMA 2 7B, no quantization.
model_id = "meta-llama/Llama-2-7b-hf"
quantization = None  # Change to "8bit" or "4bit" if needed

# Model loading is measured by tracker1.
print(f"Loading model {model_id} with quantization: {quantization or 'not quantized'}")
tracker1.start()
tokenizer, model = load_llama2_model(model_id, quantization)
tracker1.stop()

# Use the complete WikiText-2 test split and drop blank lines.
print("Loading WikiText-2...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:100%]")
texts = [line for line in (sample["text"] for sample in dataset) if line.strip()]

# The evaluation itself is measured by tracker2.
print("Calculating perplexity...")
tracker2.start()
ppl = calculate_perplexity_precise(model, tokenizer, texts)
tracker2.stop()
print(f"\n✅ Perplexity ({quantization or 'not quantized'}): {ppl:.2f}")


Now we proceed with the BitsAndBytes results (8 and 4 bits):

In [None]:
# BitsAndBytes 8-bit run: tracker1 covers model loading, tracker2 the evaluation.
tracker1 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_8bit_ByB_quant_perplexity.csv",
    gpu_ids=[1],
)
tracker2 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_8bit_ByB_quant_eval_perplexity.csv",
    gpu_ids=[1],
)
wandb.init(project="Perplexity", name="8bit_ByB")

model_id = "meta-llama/Llama-2-7b-hf"
quantization = "8bit"

print(f"Loading model {model_id} with quantization: {quantization or 'not quantized'}")
tracker1.start()
tokenizer, model = load_llama2_model(model_id, quantization)
tracker1.stop()

# Use the complete WikiText-2 test split and drop blank lines.
print("Loading WikiText-2...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:100%]")
texts = [line for line in (sample["text"] for sample in dataset) if line.strip()]

print("Calculating perplexity...")
tracker2.start()
ppl = calculate_perplexity_precise(model, tokenizer, texts)
tracker2.stop()
print(f"\n✅ Perplexity ({quantization or 'not quantized'}): {ppl:.2f}")

In [None]:
# BitsAndBytes 4-bit run: tracker1 covers model loading, tracker2 the evaluation.
tracker1 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_4bit_ByB_quant_perplexity.csv",
    gpu_ids=[1],
)
tracker2 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_4bit_ByB_quant_eval_perplexity.csv",
    gpu_ids=[1],
)
wandb.init(project="Perplexity", name="4bit_ByB")

model_id = "meta-llama/Llama-2-7b-hf"
quantization = "4bit"

print(f"Loading model {model_id} with quantization: {quantization or 'not quantized'}")
tracker1.start()
tokenizer, model = load_llama2_model(model_id, quantization)
tracker1.stop()

# Use the complete WikiText-2 test split and drop blank lines.
print("Loading WikiText-2...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:100%]")
texts = [line for line in (sample["text"] for sample in dataset) if line.strip()]

print("Calculating perplexity...")
tracker2.start()
ppl = calculate_perplexity_precise(model, tokenizer, texts)
tracker2.stop()
print(f"\n✅ Perplexity ({quantization or 'not quantized'}): {ppl:.2f}")

Now we proceed with GPTQ (8, 4 and 3 bits):

In [None]:
def load_llama2_model(model_id, quantization=None):
    """Load a tokenizer and causal LM, optionally quantized with GPTQ.

    Args:
        model_id: Identifier passed to both ``from_pretrained`` calls.
        quantization: ``"8bit"``, ``"4bit"`` or ``"3bit"`` to quantize with
            GPTQ using the ``"c4"`` calibration dataset, or a falsy value such
            as ``None`` to load the model without quantization in float32.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: If ``quantization`` is a truthy value that is not one of
            the supported modes.
    """
    gptq_bits = {"8bit": 8, "4bit": 4, "3bit": 3}

    # Validate before any loading work: an unknown mode would otherwise fall
    # through to an unquantized float16 model and be reported as if quantized.
    if quantization and quantization not in gptq_bits:
        raise ValueError(
            f"Unsupported GPTQ quantization {quantization!r}; "
            "expected '8bit', '4bit', '3bit' or None"
        )

    # The tokenizer is loaded first because the GPTQ config receives it.
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    if quantization:
        quant_config = GPTQConfig(bits=gptq_bits[quantization], dataset="c4", tokenizer=tokenizer)
    else:
        quant_config = None

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=quant_config,
        torch_dtype=torch.float16 if quantization else torch.float32
    )
    return tokenizer, model

In [None]:
# GPTQ 8-bit run: tracker1 covers model loading, tracker2 the evaluation.
tracker1 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_8bit_GPTQ_quant_perplexity.csv",
    gpu_ids=[1],
)
tracker2 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_8bit_GPTQ_quant_eval_perplexity.csv",
    gpu_ids=[1],
)
wandb.init(project="Perplexity", name="8bit_GPTQ")

model_id = "meta-llama/Llama-2-7b-hf"
quantization = "8bit"

print(f"Loading model {model_id} with quantization: {quantization or 'not quantized'}")
tracker1.start()
tokenizer, model = load_llama2_model(model_id, quantization)
tracker1.stop()

# Use the complete WikiText-2 test split and drop blank lines.
print("Loading WikiText-2...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:100%]")
texts = [line for line in (sample["text"] for sample in dataset) if line.strip()]

print("Calculating perplexity...")
tracker2.start()
ppl = calculate_perplexity_precise(model, tokenizer, texts)
tracker2.stop()
print(f"\n✅ Perplexity ({quantization or 'not quantized'}): {ppl:.2f}")

In [None]:
# GPTQ 4-bit run: tracker1 covers model loading, tracker2 the evaluation.
tracker1 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_4bit_GPTQ_quant_perplexity.csv",
    gpu_ids=[1],
)
tracker2 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_4bit_GPTQ_quant_eval_perplexity.csv",
    gpu_ids=[1],
)
wandb.init(project="Perplexity", name="4bit_GPTQ")

model_id = "meta-llama/Llama-2-7b-hf"
quantization = "4bit"

print(f"Loading model {model_id} with quantization: {quantization or 'not quantized'}")
tracker1.start()
tokenizer, model = load_llama2_model(model_id, quantization)
tracker1.stop()

# Use the complete WikiText-2 test split and drop blank lines.
print("Loading WikiText-2...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:100%]")
texts = [line for line in (sample["text"] for sample in dataset) if line.strip()]

print("Calculating perplexity...")
tracker2.start()
ppl = calculate_perplexity_precise(model, tokenizer, texts)
tracker2.stop()
print(f"\n✅ Perplexity ({quantization or 'not quantized'}): {ppl:.2f}")

In [None]:
# GPTQ 3-bit run: tracker1 covers model loading, tracker2 the evaluation.
tracker1 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_3bit_GPTQ_quant_perplexity.csv",
    gpu_ids=[1],
)
tracker2 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_3bit_GPTQ_quant_eval_perplexity.csv",
    gpu_ids=[1],
)
wandb.init(project="Perplexity", name="3bit_GPTQ")

model_id = "meta-llama/Llama-2-7b-hf"
quantization = "3bit"

print(f"Loading model {model_id} with quantization: {quantization or 'not quantized'}")
tracker1.start()
tokenizer, model = load_llama2_model(model_id, quantization)
tracker1.stop()

# Use the complete WikiText-2 test split and drop blank lines.
print("Loading WikiText-2...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:100%]")
texts = [line for line in (sample["text"] for sample in dataset) if line.strip()]

print("Calculating perplexity...")
tracker2.start()
ppl = calculate_perplexity_precise(model, tokenizer, texts)
tracker2.stop()
print(f"\n✅ Perplexity ({quantization or 'not quantized'}): {ppl:.2f}")

Now we proceed with HQQ (8, 4 and 3 bits):

In [None]:

def load_llama2_model(model_id, quantization=None):
    """Load a tokenizer and causal LM, optionally quantized with HQQ.

    Args:
        model_id: Identifier passed to both ``from_pretrained`` calls.
        quantization: ``"8bit"``, ``"4bit"`` or ``"3bit"`` to quantize with
            HQQ using a group size of 64, or a falsy value such as ``None`` to
            load the model without quantization in float32.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: If ``quantization`` is a truthy value that is not one of
            the supported modes.
    """
    hqq_bits = {"8bit": 8, "4bit": 4, "3bit": 3}

    if not quantization:
        quant_config = None
    elif quantization in hqq_bits:
        quant_config = HqqConfig(nbits=hqq_bits[quantization], group_size=64)
    else:
        # Without this check an unknown mode would fall through to an
        # unquantized float16 model and be reported as if quantized.
        raise ValueError(
            f"Unsupported HQQ quantization {quantization!r}; "
            "expected '8bit', '4bit', '3bit' or None"
        )

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=quant_config,
        torch_dtype=torch.float16 if quantization else torch.float32
    )
    return tokenizer, model

In [None]:
# HQQ 8-bit run: tracker1 covers model loading, tracker2 the evaluation.
tracker1 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_8bit_HQQ_quant_perplexity.csv",
    gpu_ids=[1],
)
tracker2 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_8bit_HQQ_quant_eval_perplexity.csv",
    gpu_ids=[1],
)
wandb.init(project="Perplexity", name="8bit_HQQ")

model_id = "meta-llama/Llama-2-7b-hf"
quantization = "8bit"

print(f"Loading model {model_id} with quantization: {quantization or 'not quantized'}")
tracker1.start()
tokenizer, model = load_llama2_model(model_id, quantization)
tracker1.stop()

# Use the complete WikiText-2 test split and drop blank lines.
print("Loading WikiText-2...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:100%]")
texts = [line for line in (sample["text"] for sample in dataset) if line.strip()]

print("Calculating perplexity...")
tracker2.start()
ppl = calculate_perplexity_precise(model, tokenizer, texts)
tracker2.stop()
print(f"\n✅ Perplexity ({quantization or 'not quantized'}): {ppl:.2f}")

In [None]:
# HQQ 4-bit run: tracker1 covers model loading, tracker2 the evaluation.
tracker1 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_4bit_HQQ_quant_perplexity.csv",
    gpu_ids=[1],
)
tracker2 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_4bit_HQQ_quant_eval_perplexity.csv",
    gpu_ids=[1],
)
wandb.init(project="Perplexity", name="4bit_HQQ")

model_id = "meta-llama/Llama-2-7b-hf"
# Must match the "4bit_HQQ" run name and output files above; it was previously
# "8bit", which silently re-measured the 8-bit configuration.
quantization = "4bit"

print(f"Loading model {model_id} with quantization: {quantization or 'not quantized'}")
tracker1.start()
tokenizer, model = load_llama2_model(model_id, quantization)
tracker1.stop()

# Use the complete WikiText-2 test split and drop blank lines.
print("Loading WikiText-2...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:100%]")
texts = [sample["text"] for sample in dataset if sample["text"].strip()]

print("Calculating perplexity...")
tracker2.start()
ppl = calculate_perplexity_precise(model, tokenizer, texts)
tracker2.stop()
print(f"\n✅ Perplexity ({quantization or 'not quantized'}): {ppl:.2f}")

In [None]:
# HQQ 3-bit run: tracker1 covers model loading, tracker2 the evaluation.
tracker1 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_3bit_HQQ_quant_perplexity.csv",
    gpu_ids=[1],
)
tracker2 = OfflineEmissionsTracker(
    country_iso_code="ESP",
    allow_multiple_runs=True,
    output_file="./emissions_3bit_HQQ_quant_eval_perplexity.csv",
    gpu_ids=[1],
)
wandb.init(project="Perplexity", name="3bit_HQQ")

model_id = "meta-llama/Llama-2-7b-hf"
# Must match the "3bit_HQQ" run name and output files above; it was previously
# "8bit", which silently re-measured the 8-bit configuration.
quantization = "3bit"

print(f"Loading model {model_id} with quantization: {quantization or 'not quantized'}")
tracker1.start()
tokenizer, model = load_llama2_model(model_id, quantization)
tracker1.stop()

# Use the complete WikiText-2 test split and drop blank lines.
print("Loading WikiText-2...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:100%]")
texts = [sample["text"] for sample in dataset if sample["text"].strip()]

print("Calculating perplexity...")
tracker2.start()
ppl = calculate_perplexity_precise(model, tokenizer, texts)
tracker2.stop()
print(f"\n✅ Perplexity ({quantization or 'not quantized'}): {ppl:.2f}")