### Note about the dataset
You should start by running the data preprocessing code in the GitHub repo (`data/preprocessing/get_data.ipynb`), or simply clone the repo, to obtain a copy of `limericks.json`, which is then used to fine-tune the GPT-2 model.

In [None]:
# Start by installing required libraries (mainly Transformers)
# !pip install transformers==4.17.0
# !pip install scikit-learn

In [None]:
# Only needed when running in colab
# from google.colab import drive
# drive.mount("/content/drive/", force_remount=True)

In [None]:
import glob
import json
import math
import numpy as np
import os
import random
import shutil
import string
import torch
import torch.optim as optim
import tqdm.notebook as tqdm

from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from transformers import DataCollatorForLanguageModeling
from transformers import GPT2Model
from transformers import GPT2LMHeadModel
from transformers import GPT2Tokenizer
from transformers import AdamW, get_scheduler

In [None]:
# Directories holding the dataset (`limericks.json` is read from `data_dir`)
# and the training checkpoints. Change them if needed.
data_dir = "/content/drive/MyDrive/11-785-final/data/"
ckpt_dir = "/content/drive/MyDrive/11-785-final/ckpt/"

# Create the checkpoint directory up front; no error if it already exists
os.makedirs(ckpt_dir, exist_ok=True)

In [None]:
# Load the raw dataset. The context manager guarantees the file handle is
# closed, and the explicit encoding avoids depending on the platform default.
with open(f"{data_dir}/limericks.json", encoding="utf-8") as fp:
    data = json.load(fp)

# `limericks` becomes a list of limericks, each a list of cleaned-up lines
limericks = []

for limerick in data['limericks'].values():
    lines = limerick['lines']

    # Drop the whole limerick as soon as one of its lines is empty
    if any(len(line) == 0 for line in lines):
        continue

    # Remove the final punctuation of each line
    # (we'll use a special separator instead).
    # A new list is built so the loaded JSON data is left untouched.
    limericks.append([
        line[:-1] if line[-1] in string.punctuation else line
        for line in lines])

In [None]:
# Compare the number of raw entries in the JSON file with the number of
# limericks kept in `limericks`
print(f"# of limericks before clean-up: {len(data['limericks'])}")
print(f"# of limericks after clean-up: {len(limericks)}")

In [None]:
# Load the pretrained GPT-2 tokenizer and register two extra special tokens:
#   <LINE> (sep_token) -- the separator between limerick lines
#   <PAD>  (pad_token) -- used to pad the samples of a batch to equal length
# The padding is meant to be ineffective during training thanks to the
# attention_mask.
# NOTE(review): attention_mask controls what the model attends to; whether the
# padded positions are also excluded from the loss depends on how `labels`
# are built in the collate function -- confirm.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({
    "sep_token": "<LINE>",
    "pad_token": "<PAD>"})
print(f"New sep_token: {tokenizer.sep_token} ({tokenizer.sep_token_id})")
print(f"New pad_token: {tokenizer.pad_token} ({tokenizer.pad_token_id})")

In [None]:
# We can construct a training sample of limericks by merging the lines
# with the separator attached at the end of each line
def merge_lines(lines):
    """Join limerick lines into one string, closing every line with <LINE>.

    Consecutive lines are separated by a single space and each line is
    followed by " <LINE>", e.g. ["a", "b"] -> "a <LINE> b <LINE>".

    Args:
        lines: iterable of strings, one per limerick line.

    Returns:
        The merged training string.
    """
    separator = " <LINE>"
    return f"{separator} ".join(lines) + separator

In [None]:
# Sanity check: round-trip one random limerick through the tokenizer.
# The merged text is deliberately NOT stored in a variable called `string`:
# at notebook scope that would rebind the imported `string` module and make
# any later use of `string.punctuation` (e.g. re-running the clean-up cell)
# fail.
sample = random.choice(limericks)
merged = merge_lines(sample)
print(f"Lines with separator: {merged}")
input_ids = tokenizer(merged)['input_ids']
print(f"Tokens: {input_ids}")
decoded_string = tokenizer.decode(input_ids)
print(f"Decoding result: {decoded_string}")

In [None]:
# Hold out 10% of the limericks for validation.
# The split is seeded so that it is identical on every run: training can be
# resumed from a checkpoint in a fresh session, and an unseeded split would
# then move limericks the model was already trained on into the validation
# set, making the perplexities of different sessions incomparable.
split_seed = 11785
train_data, val_data = train_test_split(
    limericks, train_size=0.9, random_state=split_seed)
print(f"# of training samples: {len(train_data)}")
print(f"# of validation samples: {len(val_data)}")

In [None]:
class LimerickDataset(Dataset):
    """Dataset whose items are limericks merged into single strings.

    Every limerick (a list of lines) is converted once, at construction
    time, with `merge_lines`; indexing returns the resulting string.
    """

    def __init__(self, data):
        # Pre-compute the merged text of every limerick
        self.data = list(map(merge_lines, data))

    def __len__(self):
        """Return the number of limericks."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the merged string stored at position `idx`."""
        return self.data[idx]

In [None]:
def gen_collate_fn(tokenizer, device="cuda"):
    """Build a DataLoader `collate_fn` that tokenizes a batch of strings.

    Args:
        tokenizer: callable invoked as
            `tokenizer(batch, padding="longest", return_tensors="pt")`; its
            result must be a mapping holding `input_ids` and
            `attention_mask` tensors.
        device: device every tensor of the batch is moved to. Defaults to
            "cuda", which matches the previously hard-coded `.cuda()`.

    Returns:
        A function mapping a list of strings to the tokenizer output,
        extended with a `labels` entry and moved to `device`.
    """
    def collate_fn(batch):
        batch = tokenizer(batch, padding="longest", return_tensors="pt")

        # Labels are the inputs themselves (the model shifts them internally),
        # except on padded positions: these are set to -100, the index the
        # language-modeling loss ignores. Without this the <PAD> tokens count
        # as prediction targets and distort both loss and perplexity.
        labels = batch['input_ids'].clone()
        labels[batch['attention_mask'] == 0] = -100
        batch['labels'] = labels

        # Snapshot the keys first since the values are replaced in the loop
        for key in list(batch.keys()):
            batch[key] = batch[key].to(device)
        return batch

    return collate_fn

In [None]:
# optimizer
learning_rate = 5e-5
weight_decay = 0.0  # only applied to the parameter group that is decayed
# scheduler
scheduler_type = "linear"  # passed to get_scheduler as `name`
num_warmup_steps = 0
# training loop
epochs = 20
batch_size = 32
gradient_accumulation_steps = 1  # micro-batches per optimizer step
# ckpt
exp_name = "standard-gpt2"  # sub-directory of ckpt_dir used by this run
debug = False  # True: use only batch_size * 8 train / batch_size * 2 val samples

In [None]:
# Per-experiment directory: holds the epoch checkpoints, the best model and
# the text log that gets one line appended per epoch
exp_dir = f"{ckpt_dir}/{exp_name}"
os.makedirs(exp_dir, exist_ok=True)
log_file = f"{exp_dir}/log.txt"

In [None]:
# In debug mode only a few batches worth of samples are kept, so that a
# complete train/validation cycle finishes quickly.
if debug:
    train_dataset = LimerickDataset(train_data[:batch_size * 8])
    val_dataset = LimerickDataset(val_data[:batch_size * 2])
else:
    train_dataset = LimerickDataset(train_data)
    val_dataset = LimerickDataset(val_data)

# Training: reshuffled batches, the trailing incomplete batch is dropped.
train_loader = DataLoader(
    dataset=train_dataset,
    collate_fn=gen_collate_fn(tokenizer),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True)

# Validation: fixed order, every sample is evaluated.
val_loader = DataLoader(
    dataset=val_dataset,
    collate_fn=gen_collate_fn(tokenizer),
    batch_size=batch_size,
    shuffle=False,
    drop_last=False)

In [None]:
# Initialize the model from the pretrained "gpt2" weights, resize the
# embeddings to the tokenizer's current size (the tokenizer gained new special
# tokens), then move the model to the GPU
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))
model = model.cuda()

In [None]:
# Reference: https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
# Biases and layer-norm weights are exempt from weight decay. The GPT-2
# implementation names its layer norms `ln_1`, `ln_2` and `ln_f`, so the
# BERT-style "LayerNorm.weight" pattern alone never matches them; they are
# listed explicitly.
no_decay = [
    "bias",
    "LayerNorm.weight",
    "ln_1.weight",
    "ln_2.weight",
    "ln_f.weight",
]
optimizer_grouped_parameters = [
    {
        # parameters that are decayed
        "params": [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        # parameters that are never decayed
        "params": [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = optim.AdamW(optimizer_grouped_parameters, lr=learning_rate)

# Optimizer steps per epoch. True division followed by ceil counts the final,
# possibly incomplete accumulation group; the previous floor division made the
# ceil a no-op and under-counted the steps given to the scheduler.
T_epoch = math.ceil(len(train_loader) / gradient_accumulation_steps)
scheduler = get_scheduler(
    name=scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=epochs * T_epoch)
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Resume from the most recent epoch checkpoint of this experiment, if any
files = glob.glob(f"{exp_dir}/epoch-*.ckpt")
if files:
    # Sort numerically on the epoch number embedded in "epoch-<N>.ckpt"
    # (a plain string sort would put epoch-10 before epoch-2)
    files = sorted(files, key=lambda x: int(os.path.basename(x)[6:-5]))
    states = torch.load(files[-1])

    model.load_state_dict(states['model_state_dict'])
    optimizer.load_state_dict(states['optimizer_state_dict'])
    scheduler.load_state_dict(states['scheduler_state_dict'])
    scaler.load_state_dict(states['scaler_state_dict'])
    start_epoch = states['epoch'] + 1
    # Restore the best perplexity seen so far, not the perplexity of the last
    # epoch: otherwise a later epoch that merely beats the last one would
    # overwrite best-model.ckpt with a worse model. Checkpoints without the
    # 'best_perplexity' entry fall back to the old behaviour.
    best_perplexity = states.get('best_perplexity', states['perplexity'])
else:
    start_epoch = 0
    best_perplexity = 1e30

if start_epoch == 0:
    print("Start training from scratch")
else:
    print(f"Resume training from epoch {start_epoch + 1}")

In [None]:
# Reference: https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
def train_epoch(model, train_loader, optimizer, scheduler, scaler):
    """Train the model for one epoch.

    Gradients are accumulated over `gradient_accumulation_steps` (read from
    the notebook scope) micro-batches before each optimizer/scheduler step.

    Args:
        model: model called as `model(**batch)`; its output must expose
            `.loss`.
        train_loader: sized iterable yielding keyword-argument batches.
        optimizer: optimizer stepped through `scaler.step`.
        scheduler: learning-rate scheduler, stepped once per optimizer step.
        scaler: gradient scaler used for the backward pass and the step.

    Returns:
        The mean of the (unscaled) per-batch losses over the epoch.
    """
    model.train()
    optimizer.zero_grad()

    bar = tqdm.tqdm(train_loader, leave=False)
    num_batches = len(train_loader)
    loss_total = 0.

    for step, batch in enumerate(bar):
        # NOTE(review): the forward pass is not wrapped in autocast, so the
        # scaler only rescales a full-precision loss -- confirm whether mixed
        # precision was intended.
        outputs = model(**batch)
        loss = outputs.loss
        loss_total += loss.item()
        # Divide so that the accumulated gradient is an average
        loss = loss / gradient_accumulation_steps
        scaler.scale(loss).backward()

        # Step once every `gradient_accumulation_steps` micro-batches, i.e.
        # when (step + 1) is a multiple of it. Testing `step` itself would
        # already step after the very first micro-batch. The last batch
        # always steps so that no accumulated gradient is discarded.
        is_last_batch = step == num_batches - 1
        if (step + 1) % gradient_accumulation_steps == 0 or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

        bar.set_postfix({"Loss": f"{loss_total / (step + 1):.4f}"})

    return loss_total / num_batches

In [None]:
# Reference: https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
def validation(model, val_loader):
    """Evaluate the model and return its perplexity on `val_loader`.

    The loss of every batch is weighted by the number of samples in that
    batch, so a smaller final batch does not get over-represented; the
    perplexity is the exponential of this weighted mean loss.

    Args:
        model: model called as `model(**batch)`; its output must expose
            `.loss`.
        val_loader: iterable yielding keyword-argument batches that contain
            an `input_ids` tensor.

    Returns:
        The perplexity as a float; `inf` if the exponential overflows or if
        `val_loader` yields no batch at all.
    """
    model.eval()

    bar = tqdm.tqdm(val_loader, leave=False)
    loss_sum = 0.0
    num_samples = 0

    for batch in bar:
        with torch.no_grad():
            outputs = model(**batch)

        samples_in_batch = batch['input_ids'].shape[0]
        loss_sum += outputs.loss.item() * samples_in_batch
        num_samples += samples_in_batch

    # Nothing was evaluated: report the worst possible score instead of
    # failing on an unbound result
    if num_samples == 0:
        return float('inf')

    # Computed once after the loop rather than after every batch
    try:
        return math.exp(loss_sum / num_samples)
    except OverflowError:
        return float('inf')

In [None]:
# Reference: https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py
# Main loop: train, validate, log, checkpoint every epoch and keep a copy of
# the checkpoint with the lowest validation perplexity.
epoch_bar = tqdm.trange(start_epoch, epochs, leave=False)

for epoch in epoch_bar:
    loss = train_epoch(model, train_loader, optimizer, scheduler, scaler)
    perplexity = validation(model, val_loader)

    # Same line goes to the progress bar and to the persistent log file
    log = f"Epoch {epoch+1} Loss: {loss:.4f} Perplexity {perplexity:.4f}"
    epoch_bar.write(log)
    with open(log_file, 'a') as file:
        file.write(f"{log}\n")

    # Update the running best before saving so the checkpoint records it
    improved = perplexity < best_perplexity
    if improved:
        best_perplexity = perplexity

    ckpt_path = f"{exp_dir}/epoch-{epoch+1}.ckpt"
    if scheduler is not None:
        scheduler_state = scheduler.state_dict()
    else:
        scheduler_state = None

    epoch_bar.write(f"Save model at epoch {epoch+1}")
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler_state,
        'scaler_state_dict': scaler.state_dict(),
        'epoch': epoch,
        'perplexity': perplexity,
        'best_perplexity': best_perplexity
    }
    torch.save(checkpoint, ckpt_path)

    # The best model is simply a copy of this epoch's checkpoint
    if improved:
        print(f"Save best model at epoch {epoch+1}")
        shutil.copyfile(ckpt_path, f"{exp_dir}/best-model.ckpt")

In [None]:
# Export the best checkpoint in the `save_pretrained` format and load it back
# as a fresh model for generation.
tmp_dir = "/content/test"

states = torch.load(f"{exp_dir}/best-model.ckpt")
model.load_state_dict(states['model_state_dict'])

model.save_pretrained(tmp_dir)
# GPT2LMHeadModel is the class the model was built from and is already
# imported; AutoModelForCausalLM is not imported in this notebook and raised
# a NameError here.
new_model = GPT2LMHeadModel.from_pretrained(tmp_dir)

In [None]:
# Sample a continuation for a hand-written first line; special tokens are
# kept in the decoded text so the <LINE> separators stay visible.
prompt = "if you're using a subsurface map <LINE>"
encoded_prompt = tokenizer(prompt, return_tensors="pt")
input_ids = encoded_prompt["input_ids"]
outputs = new_model.generate(
    input_ids,
    do_sample=True,
    max_length=100)
tokenizer.batch_decode(outputs, skip_special_tokens=False)