In [7]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import evaluate
import numpy as np
import os
from torch.nn import CrossEntropyLoss
from evaluate import logging

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "google/gemma-2b"  

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

In [None]:

# model_8bit = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit = True)
# model_4bit = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit = True)

# Load the WikiText-2 dataset
# dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

# # Tokenize the dataset
# def tokenize_function(examples):
#     return tokenizer(examples["text"], return_tensors="pt", truncation=True, padding="max_length", max_length=512)

# tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

In [4]:
print(model.get_memory_footprint())
print(model_8bit.get_memory_footprint())
print(model_4bit.get_memory_footprint())

10024698880
3030545408
2039641088


In [None]:
def compute_perplexity(model, model_name):
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load the WikiText-2 dataset
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Preprocess the dataset
    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, padding='max_length', max_length=512)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])
    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

    # Initialize DataLoader
    dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=8)

    model.eval()
    total_loss = 0
    total_count = 0

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            outputs = model(inputs, attention_mask=attention_mask, labels=inputs)
            total_loss += outputs.loss.item() * inputs.size(0)
            total_count += inputs.size(0)

    return torch.exp(total_loss / total_count)

# Usage
model = model
model_name = model_name
print(f"Perplexity: {compute_perplexity(model, model_name)}")

In [None]:
def compute(predictions, model, tokenizer, batch_size: int = 8, add_start_token: bool = True, device=None, max_length=None):

    if device is not None:
        assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu."
        if device == "gpu":
            device = "cuda"
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model.to(device)

    tokenizer = tokenizer

    # if batch_size > 1 (which generally leads to padding being required), and
    # if there is not an already assigned pad_token, assign an existing
    # special token to also be the padding token
    if tokenizer.pad_token is None and batch_size > 1:
        existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
        # check that the model already has at least one special token defined
        assert (
            len(existing_special_tokens) > 0
        ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
        # assign one of the special tokens to also be the pad token
        tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

    if add_start_token and max_length:
        # leave room for <BOS> token to be added:
        assert (
            tokenizer.bos_token is not None
        ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    encodings = tokenizer(
        predictions,
        add_special_tokens=False,
        padding=True,
        truncation=True if max_tokenized_len else False,
        max_length=max_tokenized_len,
        return_tensors="pt",
        return_attention_mask=True,
    ).to(device)

    encoded_texts = encodings["input_ids"]
    attn_masks = encodings["attention_mask"]

    # check that each input is long enough:
    if add_start_token:
        assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
    else:
        assert torch.all(
            torch.ge(attn_masks.sum(1), 2)
        ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

    ppls = []
    loss_fct = CrossEntropyLoss(reduction="none")

    for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
        end_index = min(start_index + batch_size, len(encoded_texts))
        encoded_batch = encoded_texts[start_index:end_index]
        attn_mask = attn_masks[start_index:end_index]

        if add_start_token:
            bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
            encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
            attn_mask = torch.cat(
                [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
            )

        labels = encoded_batch

        with torch.no_grad():
            out_logits = model(encoded_batch, attention_mask=attn_mask).logits

        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        perplexity_batch = torch.exp(
            (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
            / shift_attention_mask_batch.sum(1)
        )

        ppls += perplexity_batch.tolist()

    return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}

In [8]:
perplexity = evaluate.load("perplexity", module_type="metric")
input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"] #[:10] # doctest: +SKIP
input_texts = [s for s in input_texts if s!='']
results = perplexity.compute(model_id='gpt2', batch_size=4, predictions=input_texts)
print(list(results.keys()))

TypeError: EvaluationModule.compute() missing 1 required positional argument: 'self'

: 

In [8]:
print(round(results["mean_perplexity"], 4))
print(results["perplexities"])

827.0198
[2904.349609375, 10.540328025817871, 11.525307655334473, 1188.0777587890625, 20.60600471496582]


In [12]:
import math
from torch.utils.data import DataLoader

def collate_fn(batch):
    input_ids = torch.stack([item["input_ids"].squeeze() for item in batch]).to(device)
    attention_mask = torch.stack([item["attention_mask"].squeeze() for item in batch]).to(device)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

def evaluate_perplexity(model, tokenized_dataset, batch_size=8):
    model.eval()
    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, collate_fn=collate_fn)
    total_loss = 0
    total_tokens = 0

    with torch.no_grad():
        for batch in dataloader:
            outputs = model(**batch, labels=batch["input_ids"])
            total_loss += outputs.loss.item() * batch["input_ids"].size(1)
            total_tokens += batch["input_ids"].size(1)

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    return perplexity

# Evaluate perplexity for full precision model
perplexity_full_precision = evaluate_perplexity(model, tokenized_dataset)
print(f"Full Precision Model Perplexity: {perplexity_full_precision}")

# Evaluate perplexity for 8-bit model
perplexity_8bit = evaluate_perplexity(model_8bit, tokenized_dataset)
print(f"8-bit Model Perplexity: {perplexity_8bit}")

# Evaluate perplexity for 4-bit model if available
perplexity_4bit = evaluate_perplexity(model_4bit, tokenized_dataset)
print(f"4-bit Model Perplexity: {perplexity_4bit}")


AttributeError: 'list' object has no attribute 'squeeze'