In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForCausalLM
from transformers import DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, GPT2Config, GPT2ForQuestionAnswering
from transformers import TrainingArguments, Trainer
from torch.utils.data import Dataset
import copy
import pandas as pd
import re
import transformers
import torch

# Restrict transformers logging to error-level messages.
transformers.logging.set_verbosity_error()

# Medium has 24 layers/GPT2Blocks
med_tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")

# Large has 36 layers/GPT2Blocks
# NOTE(review): large_tokenizer is not used by any cell visible in this notebook — confirm it is still needed.
large_tokenizer = AutoTokenizer.from_pretrained("gpt2-large")


2024-02-20 01:27:44.515105: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-20 01:27:44.546448: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-20 01:27:44.546484: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-20 01:27:44.547492: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-20 01:27:44.552963: I tensorflow/core/platform/cpu_feature_guar

In [2]:
# Load the SQuAD dataset once and reuse the returned DatasetDict for every split.
squad = load_dataset("squad")

# Fixed seed so the held-out test split is the same on every Restart & Run All.
SPLIT_SEED = 42

# Hold out 12% of the official train split as a test set; the official validation
# split is used as-is. Resulting sizes are roughly 7:1:1, train:valid:test.
train_squad = squad["train"].train_test_split(test_size=0.12, seed=SPLIT_SEED)
train_dataset = train_squad["train"]
valid_dataset = squad["validation"]
test_dataset = train_squad["test"]

In [3]:
# sanity checking: split sizes first, then one raw training example.
# The sizes are printed explicitly because a bare expression is only displayed
# when it is the last statement of a cell.
print(len(train_dataset), len(valid_dataset), len(test_dataset))
print(train_dataset[0])

{'id': '570d61a5b3d812140066d7a7', 'title': 'Valencia', 'context': 'In March 2012, the newspaper El Mundo published a story according to which FGV had instructed employees who were to testify at the crash commission investigation, providing a set of possible questions and guidelines to prepare the answers. In April 2013, the television program Salvados questioned the official version of the incident as there were indications that the Valencian Government had tried to downplay the accident, which coincided with the visit of the pope to Valencia, or even to hide evidence, as the book of train breakdowns was never found. The day after the broadcast of this report, which received extensive media coverage, several voices called for the reopening of the investigation. The investigation was effectively reopened and the accident is currently under re-examination.', 'question': 'What evidence related to the crash remains missing?', 'answers': {'text': ['book of train breakdowns'], 'answer_start

In [4]:
def encode(examples, tokenizer):
    """Turn a batch of SQuAD examples into tokenized training text.

    Each example becomes one string made of its context, its question and its
    first reference answer, separated by newlines.

    Args:
        examples: batch mapping with parallel "context", "question" and
            "answers" sequences; every answers entry exposes a "text" list.
        tokenizer: callable applied to the list of strings with truncation
            enabled and padding="max_length".

    Returns:
        Whatever the tokenizer returns for the batch.
    """
    samples = []
    rows = zip(examples["context"], examples["question"], examples["answers"])
    for context, question, answer in rows:
        # Only the first reference answer is used as the target text.
        samples.append(f"{context}\n{question}\n{answer['text'][0]}")
    return tokenizer(samples, padding="max_length", truncation=True)

In [7]:
# Pad on the left so the real tokens sit at the end of every padded sequence.
med_tokenizer.padding_side = "left"
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if med_tokenizer.pad_token is None:
    med_tokenizer.pad_token = med_tokenizer.eos_token

# Tokenize the train and validation splits in batches with encode();
# test_dataset is not tokenized in this cell.
train_dataset_med = train_dataset.map(lambda x: encode(x, med_tokenizer), batched=True)
valid_dataset_med = valid_dataset.map(lambda x: encode(x, med_tokenizer), batched=True)

Map:   0%|          | 0/77087 [00:00<?, ? examples/s]

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [8]:
def print_decodes(decodes):
    """Print each decoded string prefixed with its index, followed by a blank line."""
    for position, text in enumerate(decodes):
        print(f"{position}: {text}\n")

def get_question(sample):
    """Build the prompt for one example: its context, a newline, then its question."""
    context = sample["context"]
    question = sample["question"]
    return f"{context}\n{question}"

def get_prediction(prompt, model, tokenizer, max_tokens=50):
    """Generate a continuation for a single prompt and print it.

    Args:
        prompt: text fed to the model.
        model: model exposing eval(), generate() and a device attribute.
        tokenizer: tokenizer exposing encode() and decode().
        max_tokens: maximum number of tokens generated beyond the prompt.

    Returns:
        A one-element list holding the decoded sequence returned by
        generate() (in the recorded output this is the prompt followed by
        the continuation).
    """
    model.eval()
    # encode() yields a flat list of token ids; unsqueeze adds the batch dimension.
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    # Put the prompt on the model's device; otherwise generate() fails once the
    # model has been moved off the CPU (e.g. after training on an accelerator).
    input_ids = input_ids.to(model.device)
    generated = model.generate(
        input_ids,
        max_length=input_ids.shape[-1] + max_tokens,
    )
    out0 = [tokenizer.decode(generated[0, :])]
    print_decodes(out0)
    return out0

def get_model_answer(index, dataset, model, tokenizer, max_tokens=50):
    """Run the model on one dataset example and show the reference answer.

    The prompt is built with get_question(), the prediction is produced (and
    printed) by get_prediction(), and the example's first reference answer is
    printed afterwards as the answer key.

    Returns:
        The value returned by get_prediction().
    """
    sample = dataset[index]
    prediction = get_prediction(get_question(sample), model, tokenizer, max_tokens)
    reference = sample["answers"]["text"][0]
    print("\nAnswer key: ", reference)
    return prediction

In [10]:
# download pruned model
# Loads the "han2lin/gpt2_med_s23" checkpoint as a causal LM; the following cell
# reports 23 layers for it, versus 24 for stock gpt2-medium.
gpt2_med_s23 = AutoModelForCausalLM.from_pretrained("han2lin/gpt2_med_s23")

model.safetensors:   0%|          | 0.00/1.37G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/144 [00:00<?, ?B/s]

In [12]:
# sanity check: the pruned checkpoint should report one layer fewer than the
# 24 GPT2Blocks of gpt2-medium.
print(f"num layers={gpt2_med_s23.config.n_layer}")

num layers=23


In [13]:
# Test the pruned model (as downloaded) on one training example.
# get_model_answer only reads the raw context/question/answers fields, so the
# tokenized columns of train_dataset_med are not used here.
INDEX = 1
answer = get_model_answer(INDEX, train_dataset_med, gpt2_med_s23, med_tokenizer)

0: Hunting and gathering was humanity's first and most successful adaptation, occupying at least 90 percent of human history. Following the invention of agriculture, hunter-gatherers have been displaced or conquered by farming or pastoralist groups in most parts of the world.
What are the basic types of agricultural groups?
The most common agricultural groups are hunter-gatherers, who live in groups of several hundred individuals, and pastoralists, who live in groups of several hundred individuals. The most common types of agricultural groups are:
1. Group A:


Answer key:  farming or pastoralist groups


In [None]:
# for setting up wandb:
# import wandb
# wandb.login()

# wandb.init(
#     project="gpt2-pruning",
#     config={
#         # "batch_size": BATCH_SIZE,
#         # "learning_rate": LEARNING_RATE,
#         "dataset": "SQuAD",
#     },
# )

In [None]:
# Training hyperparameters read by fine_tune_gpt2 when it builds TrainingArguments.
BATCH_SIZE = 64  # per-device batch size, used for both training and evaluation
EPOCHS = 1  # number of training epochs
LEARNING_RATE = 1e-5  # learning rate handed to TrainingArguments
LOGGING_STEPS = 1000  # logging interval, in steps
SAVE_STEPS = 1000  # checkpoint interval, in steps

In [None]:
def fine_tune_gpt2(model,
                   tokenizer,
                   train_dataset,
                   valid_dataset,
                   train_output_dir,
                   save_model_dir,
                   batch_size=None,
                   epochs=None,
                   learning_rate=None,
                   logging_steps=None,
                   save_steps=None,
                   report_to="wandb"):
    """Fine-tune a causal language model with the Hugging Face Trainer and save it.

    Args:
        model: model to train; it is updated in place by the Trainer.
        tokenizer: tokenizer used by the data collator and saved next to the model.
        train_dataset: training examples handed to the Trainer.
        valid_dataset: evaluation examples handed to the Trainer.
        train_output_dir: directory passed to TrainingArguments as output_dir.
        save_model_dir: directory receiving the final model and tokenizer.
        batch_size: per-device train/eval batch size; defaults to BATCH_SIZE.
        epochs: number of training epochs; defaults to EPOCHS.
        learning_rate: optimizer learning rate; defaults to LEARNING_RATE.
        logging_steps: logging interval in steps; defaults to LOGGING_STEPS.
        save_steps: checkpoint interval in steps; defaults to SAVE_STEPS.
        report_to: logging integration passed to TrainingArguments. Defaults to
            "wandb" as before; pass "none" to train without a wandb session.

    Returns:
        None. The training log history is printed as a DataFrame.
    """
    # Resolve defaults at call time so edits to the notebook's config cell are
    # picked up without redefining this function.
    if batch_size is None:
        batch_size = BATCH_SIZE
    if epochs is None:
        epochs = EPOCHS
    if learning_rate is None:
        learning_rate = LEARNING_RATE
    if logging_steps is None:
        logging_steps = LOGGING_STEPS
    if save_steps is None:
        save_steps = SAVE_STEPS

    # Create data collator for language modeling
    # (mlm=False selects causal rather than masked language modeling).
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False)

    # Set training arguments
    training_args = TrainingArguments(
        output_dir=train_output_dir,
        evaluation_strategy="steps",
        disable_tqdm=False,
        logging_steps=logging_steps,
        logging_strategy="steps",
        save_steps=save_steps,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        # optim="paged_adamw_32bit",
        report_to=report_to,
    )

    # Train the model
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
    )

    trainer.train()

    # One row per logged event recorded by the Trainer.
    log_history = pd.DataFrame(trainer.state.log_history)
    print(log_history)

    # Save the fine-tuned model
    model.save_pretrained(save_model_dir)
    tokenizer.save_pretrained(save_model_dir)

In [None]:
# Fine-tune and save: Trainer output goes to "train_log", the final model and
# tokenizer to "trained_model".
# NOTE(review): base_model, base_tokenizer, small_train_dataset and
# small_valid_dataset are not defined in any cell visible in this notebook —
# confirm they exist (or substitute e.g. gpt2_med_s23, med_tokenizer,
# train_dataset_med, valid_dataset_med) before running on a fresh kernel.
# NOTE(review): only the 'input_ids' column is passed, so no attention_mask
# reaches the Trainer — confirm this is intended.
fine_tune_gpt2(base_model, 
               base_tokenizer, 
               small_train_dataset['input_ids'], 
               small_valid_dataset['input_ids'],
               "train_log",
               "trained_model")