<a href="https://colab.research.google.com/github/bnbryan/hpml-project/blob/master/prune_on_gpt2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load the model

Install the required libraries.

In [None]:
!pip install transformers
!pip install datasets

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

Load the pre-trained GPT-2 model and its tokenizer from the Hugging Face `transformers` library.

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Configure the tokenizer first. GPT-2 has no dedicated padding token, so the
# end-of-sequence token is reused as padding to allow fixed-length batches.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Pre-trained "gpt2" checkpoint together with its language-modeling head.
model = GPT2LMHeadModel.from_pretrained("gpt2")

## Data pre-processing

In [None]:
from datasets import load_dataset

# Load the test split of WikiText-103 (raw variant) as the evaluation corpus.
dataset = load_dataset(path="wikitext", name="wikitext-103-raw-v1", split="test")
# Tokenization function applied batch-wise via Dataset.map.
def tokenize_text(examples):
    """Tokenize a batch of raw text lines into fixed-length GPT-2 inputs.

    Every line is truncated to, and right-padded up to, 512 tokens using the
    module-level ``tokenizer``. A special-tokens mask is also requested.

    Args:
        examples: A batch mapping whose ``"text"`` entry is the list of lines.

    Returns:
        The tokenizer's encoding for the batch.
    """
    lines = examples["text"]
    encoding = tokenizer(
        lines,
        max_length=512,
        truncation=True,
        padding="max_length",
        return_special_tokens_mask=True,
    )
    return encoding

# WikiText's raw splits contain many empty-string lines (blank separators).
# After padding, each one becomes an all-padding 512-token row. Such a row
# costs a full forward pass but contributes no scored tokens, and a batch made
# only of such rows has an undefined (NaN) mean loss. Drop them before tokenizing.
tokenized_dataset = (
    dataset
    .filter(lambda example: len(example["text"]) > 0)
    .map(tokenize_text, batched=True, remove_columns=['text'])
)
# Expose only the model inputs, as torch tensors.
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

## Get the baseline perplexity

In [None]:
# Return cached-but-unused GPU memory from PyTorch's allocator; nothing to do without CUDA.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.eval()

# Running sum of token-level cross-entropy and the number of tokens it was
# computed over, so total_loss / total_tokens is a true per-token mean.
total_loss = 0.0
total_tokens = 0

# Create the DataLoader over the tokenized evaluation set.
dataloader = DataLoader(tokenized_dataset, batch_size=8)

# Iterate over the dataset; evaluation never needs gradients.
# (No per-batch empty_cache(): it only slows the loop down.)
with torch.no_grad(), tqdm(dataloader, desc="Evaluating", unit="batch") as pbar:
    for batch in pbar:
        tokens = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Padding positions must not be scored; -100 is the label index that
        # GPT2LMHeadModel's loss ignores.
        labels = tokens.clone()
        labels[attention_mask == 0] = -100

        # The model shifts labels internally (position t predicts token t+1),
        # so outputs.loss is the mean over labels[:, 1:] that are not -100.
        # attention_mask.sum() would overcount by one token per non-empty row.
        num_predicted_tokens = (labels[:, 1:] != -100).sum().item()
        if num_predicted_tokens == 0:
            # Nothing to score: the mean loss would be NaN (0/0) and would
            # poison total_loss even when multiplied by zero.
            continue

        outputs = model(input_ids=tokens, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Convert the per-token mean back into a sum before accumulating.
        total_loss += loss.item() * num_predicted_tokens
        total_tokens += num_predicted_tokens

In [None]:
# 计算准确率
# Compute corpus-level perplexity from the accumulated token losses.
import math

if total_tokens == 0:
    raise ValueError("No tokens were scored during evaluation; perplexity is undefined.")

average_loss = total_loss / total_tokens
print(f"Total Loss: {total_loss}")
print(f"Total Tokens: {total_tokens}")
print(f"Average Loss: {average_loss}")

# A NaN/inf average means some batch produced an invalid loss; report it
# instead of printing a meaningless perplexity.
if not math.isfinite(average_loss):
    raise ValueError(f"Average loss is not finite ({average_loss}); check the evaluation loop.")

# Perplexity = exp(mean per-token cross-entropy).
perplexity = torch.exp(torch.tensor(average_loss))
print(f"Perplexity: {perplexity:.4f}")

# Pruning