<a href="https://colab.research.google.com/github/bnbryan/hpml-project/blob/master/prune_on_gpt2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load the model

Install the required libraries.

In [None]:
# Install the Hugging Face libraries used below: `transformers` (GPT-2 model and
# tokenizer) and `datasets` (WikiText-2 loader).
!pip install transformers
!pip install datasets

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

Load the pre-trained GPT-2 model and its tokenizer from the Hugging Face `transformers` library.

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Tokenizer and language-modeling-head model, both from the "gpt2" checkpoint.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# The stock GPT-2 tokenizer defines no pad token; reuse the end-of-text token so
# that `padding="max_length"` during tokenization has something to pad with.
tokenizer.pad_token = tokenizer.eos_token

## Data pre-processing

In [None]:
from datasets import load_dataset

# Load the raw WikiText-2 test split.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
# Drop empty / whitespace-only lines: they contain no text to score, and after
# padding to a fixed length they would consist solely of padding tokens, which
# only adds forward passes (and, if unmasked, skews the loss toward padding).
dataset = dataset.filter(lambda example: example["text"].strip() != "")

# Tokenization
def tokenize_text(examples):
    """Tokenize a batch of raw text examples for GPT-2.

    Intended for ``Dataset.map(..., batched=True)``: ``examples["text"]`` is a
    list of strings. Each string is truncated to at most 512 tokens and padded
    up to exactly 512 tokens with the tokenizer's pad token.

    Returns:
        The tokenizer's output for the batch (e.g. ``input_ids`` and
        ``attention_mask`` lists).
    """
    settings = dict(truncation=True, padding="max_length", max_length=512)
    return tokenizer(examples["text"], **settings)

tokenized_dataset = dataset.map(tokenize_text, batched=True)
# Column names must be strings (a bare `input_ids` is an undefined name). Keep
# the attention mask too: the pad token is the EOS token, so the mask is the
# reliable way to tell padding apart from real tokens during evaluation.
tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask"])

## Fine-tune the pre-trained model on WikiText-2

## Get the baseline perplexity

In [None]:
# Release unoccupied memory held by PyTorch's CUDA caching allocator
# (does nothing if CUDA has not been initialized in this process).
torch.cuda.empty_cache()

In [None]:
# Release cached GPU memory left over from earlier cells before evaluating.
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Corpus-level accumulators: summed negative log-likelihood and the number of
# scored tokens it was summed over, so total_loss / total_tokens is the
# average per-token loss.
total_loss = 0.0
total_tokens = 0

dataloader = DataLoader(tokenized_dataset, batch_size=8, shuffle=False)


def _to_batch_tensor(column):
    """Return one collated dataset column as a (batch, seq_len) tensor.

    With ``set_format("torch")`` the default collate already yields a
    (batch, seq_len) tensor. Without it, a column of equal-length int lists is
    collated into a list of seq_len tensors of shape (batch,), so they are
    stacked along dim 1 (stacking along dim 0 would silently transpose).
    """
    if torch.is_tensor(column):
        return column
    return torch.stack(column, dim=1)


with tqdm(dataloader, desc="Evaluating", unit="batch") as pbar:
    for batch in pbar:
        tokens = _to_batch_tensor(batch["input_ids"]).to(device)
        if "attention_mask" in batch:
            attention_mask = _to_batch_tensor(batch["attention_mask"]).to(device)
        else:
            # Fallback: the pad token is the EOS token, so treat EOS as padding.
            attention_mask = (tokens != tokenizer.pad_token_id).long()

        # Label padding positions with -100 so the LM loss ignores them;
        # otherwise the loss is dominated by predicting padding after padding.
        labels = tokens.masked_fill(attention_mask == 0, -100)
        # The model predicts token t+1 from tokens <= t, so labels[:, 0] is never
        # scored; count only the shifted targets that are not ignored.
        n_targets = (labels[:, 1:] != -100).sum().item()
        if n_targets == 0:
            # Nothing to score in this batch; the mean loss would be NaN.
            continue

        with torch.no_grad():
            outputs = model(
                input_ids=tokens, attention_mask=attention_mask, labels=labels
            )
        # outputs.loss is the mean NLL over the scored targets; weight it by
        # their count so each batch contributes in proportion to its real tokens.
        total_loss += outputs.loss.item() * n_targets
        total_tokens += n_targets

In [None]:
# Turn the accumulated totals into the average loss and perplexity
# (this is perplexity, not accuracy).
import math

print(f"Total Loss: {total_loss}")
print(f"Total Tokens: {total_tokens}")
if total_tokens == 0:
    # Fail with a clear message instead of an opaque ZeroDivisionError.
    raise ValueError("No tokens were evaluated; cannot compute perplexity.")
average_loss = total_loss / total_tokens
print(f"Average Loss: {average_loss}")
# Perplexity = exp(average loss).
perplexity = math.exp(average_loss)
print(f"Perplexity: {perplexity:.4f}")

# Pruning