## Fairy Tale Generator with GPT-2

In [None]:
# import libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torchtext
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader, Dataset

import re, random, math, time
from tqdm.notebook import tqdm

# ignore warnings
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device: a CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Seed torch's RNG and ask cuDNN for deterministic kernels so runs repeat.
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

### 1. Loading data

In [None]:
# Read 'grimms_fairytales.csv' (relative to the working directory) into a
# DataFrame and preview its first rows.
df = pd.read_csv('grimms_fairytales.csv')

df.head()

In [None]:
# Show column dtypes and non-null counts.
df.info()

In [None]:
# Remove the 'Unnamed: 0' column (presumably an index saved into the CSV)
# and preview the cleaned frame.
df = df.drop(columns='Unnamed: 0')

df.head()

In [None]:
# Inspect the 'Text' value at index label 0 before newline cleanup.
df['Text'][0]

In [None]:
# Replace every newline character inside each 'Text' string with a space
# (regex=True makes the pattern match within strings, not whole cells).
df['Text'] = df['Text'].replace('\n', ' ', regex=True)

In [None]:
# Re-inspect the same row to confirm the newlines were replaced.
df['Text'][0]

In [None]:
# Split the stories 80/20 into train and validation sets (there is no
# separate test set); random_state=SEED makes the split reproducible.
train_df, val_df = train_test_split(df, test_size=0.2, random_state=SEED)

train_df.shape, val_df.shape

In [None]:
# Wrap the pandas splits as Hugging Face datasets keyed by split name.
# NOTE(review): this import rebinds `Dataset`, shadowing the
# torch.utils.data.Dataset imported in the first cell.
from datasets import Dataset, DatasetDict

ds_train, ds_valid = (Dataset.from_pandas(frame) for frame in (train_df, val_df))

raw_datasets = DatasetDict({"train": ds_train, "valid": ds_valid})

raw_datasets

In [None]:
# Drop the '__index_level_0__' column that Dataset.from_pandas carried over
# from the (shuffled, non-default) pandas index of each split.
for split in ("train", "valid"):
    raw_datasets[split] = raw_datasets[split].remove_columns("__index_level_0__")

raw_datasets

In [None]:
# Print the first 200 characters of every field of the first training row.
first_example = raw_datasets["train"][0]
for key, value in first_example.items():
    print(f"{key}: {value[:200]}")

### 2. Preprocessing

In [None]:
# create the tokenizer
# Load the pretrained GPT-2 fast tokenizer and try it on the first two
# training stories to see how long texts are split into chunks.
from transformers import AutoTokenizer

# Maximum tokens per training chunk.
context_length = 256
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

# return_overflowing_tokens=True keeps the tokens past max_length as extra
# chunks; overflow_to_sample_mapping records which story each chunk came from.
outputs = tokenizer(
    raw_datasets["train"][:2]["Text"],
    truncation=True,
    max_length=context_length,
    return_overflowing_tokens=True,
    return_length=True,
)

print(f"Input IDs length: {len(outputs['input_ids'])}")
print(f"Input chunk lengths: {(outputs['length'])}")
print(f"Chunk mapping: {outputs['overflow_to_sample_mapping']}")

In [None]:
# tokenize the data
def tokenize(element):
    """Tokenize a batch of stories into fixed-size chunks of input ids.

    Each text in ``element["Text"]`` is split into chunks of at most
    ``context_length`` tokens. Only chunks of exactly ``context_length``
    tokens are kept, so each story's trailing partial chunk (and any story
    shorter than ``context_length``) is discarded.

    Returns:
        ``{"input_ids": [list of token-id lists]}``.
    """
    encoded = tokenizer(
        element["Text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    full_chunks = [
        chunk_ids
        for n_tokens, chunk_ids in zip(encoded["length"], encoded["input_ids"])
        if n_tokens == context_length
    ]
    return {"input_ids": full_chunks}

# Apply `tokenize` batch-wise to every split. The original columns are
# removed because the output has a different number of rows (one per kept
# chunk) than the input, leaving only 'input_ids'.
tokenized_datasets = raw_datasets.map(
    tokenize, batched=True, remove_columns=raw_datasets["train"].column_names
)

tokenized_datasets

### 3. Preparing data loaders

In [None]:
from torch.utils.data.dataloader import DataLoader

# Make dataset indexing return torch tensors rather than Python lists.
tokenized_datasets.set_format("torch")

batch_size = 8
# Shuffle only the training split; validation order stays fixed.
train_dataloader = DataLoader(
    tokenized_datasets["train"], shuffle=True, batch_size=batch_size
)
eval_dataloader = DataLoader(tokenized_datasets["valid"], batch_size=batch_size)

len(train_dataloader), len(eval_dataloader)

In [None]:
# Print the shape of the first training batch's input ids (if any batch
# exists); presumably (batch_size, context_length) — confirm when run.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    print(first_batch['input_ids'].shape)

### 4. Modeling

In [None]:
# define the model
# Build a GPT-2 config sized to the tokenizer and instantiate a model from
# it (random weights, not pretrained) just to report its parameter count.
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig

config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
    n_ctx=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

model = GPT2LMHeadModel(config)
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT-2 size: {model_size/1000**2:.1f}M parameters")

In [None]:
# Collect the ids of key phrases that encode to exactly one token; the loss
# up-weights samples containing them. Multi-token phrases are reported and
# skipped.
# NOTE(review): every phrase below spans several words, so each encodes to
# more than one id and keytoken_ids ends up empty — the loss weighting is
# then a no-op. Consider single-token keywords instead.
keywords = [
    "Once upon a time",
    "Long long ago",
    "In a faraway land",
]
keytoken_ids = []
for keyword in keywords:
    token_ids = tokenizer([keyword]).input_ids[0]
    if len(token_ids) != 1:
        print(f"Keyword has not single token: {keyword}")
        continue
    keytoken_ids.append(token_ids[0])

In [None]:
# define the loss function and the optimizer
from torch.nn import CrossEntropyLoss

def keytoken_weighted_loss(inputs, logits, keytoken_ids, alpha=1.0):
    """Causal-LM cross-entropy that up-weights samples containing key tokens.

    Args:
        inputs: token ids, indexed below as (batch, seq_len).
        logits: model logits, indexed below as (batch, seq_len, vocab_size).
        keytoken_ids: iterable of token ids to count in each sample; may be
            empty, in which case every sample gets weight ``alpha``.
        alpha: multiplier applied to every sample's weight.

    Returns:
        Scalar tensor: batch mean of
        ``alpha * (1 + key-token count) * per-sample mean token loss``.
    """
    # Shift so that tokens < n predict n
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    # Per-token loss; reduction="none" replaces the deprecated reduce=False.
    loss_fct = CrossEntropyLoss(reduction="none")
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    # Resize and average loss per sample
    loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(dim=1)
    # Count key-token occurrences per sample. torch.stack raises on an empty
    # list, so an empty keytoken_ids yields zero counts (unweighted loss).
    keytoken_ids = list(keytoken_ids)
    if keytoken_ids:
        weights = torch.stack([(inputs == kt).float() for kt in keytoken_ids]).sum(
            dim=[0, 2]
        )
    else:
        weights = torch.zeros(inputs.size(0), device=inputs.device)
    weights = alpha * (1.0 + weights)
    # Calculate weighted average
    weighted_loss = (loss_per_sample * weights).mean()
    return weighted_loss

weight_decay = 0.1

def get_grouped_params(model, no_decay=("bias", "LayerNorm.weight", "ln_")):
    """Split a model's parameters into AdamW groups with and without decay.

    A parameter is excluded from weight decay when its name contains any of
    the substrings in ``no_decay``. GPT-2 names its layer norms ``ln_1``,
    ``ln_2`` and ``ln_f`` rather than ``LayerNorm``, hence the ``"ln_"``
    entry. The default is a tuple so it cannot be mutated between calls.

    Returns:
        Two optimizer param groups: the decayed parameters with the
        module-level ``weight_decay``, then the rest with ``0.0``.
    """
    params_with_wd, params_without_wd = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            params_without_wd.append(param)
        else:
            params_with_wd.append(param)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]

# Re-instantiate a fresh, randomly initialised model on `device` (the one
# built earlier was only used to count parameters) and create AdamW with
# the decay / no-decay parameter groups.
model = GPT2LMHeadModel(config).to(device)

from torch.optim import AdamW

optimizer = AdamW(get_grouped_params(model), lr=5e-4)

In [None]:
# use Accelerator to speed up training
from accelerate import Accelerator

# fp16 mixed precision needs a CUDA GPU; Accelerator raises a ValueError for
# fp16 without one (confirm for your accelerate version). Fall back to full
# precision so the CPU path chosen for `device` still works.
mixed_precision = "fp16" if torch.cuda.is_available() else "no"
accelerator = Accelerator(mixed_precision=mixed_precision)

# Wrap model, optimizer and dataloaders for the accelerator's device setup.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [None]:
# use get_scheduler library to schedule the learning rate
from transformers import get_scheduler

num_train_epochs = 1
# Number of batches per epoch and in total (the training progress bar counts
# batches).
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

# The training loop calls lr_scheduler.step() only once every
# `gradient_accumulation_steps` batches, so the schedule must span the
# number of optimizer steps, not batches. Otherwise the linear decay would
# stop ~1/gradient_accumulation_steps of the way to zero.
# Keep this value in sync with the training cell.
gradient_accumulation_steps = 10
num_optimizer_steps = max(
    1, num_train_epochs * (num_update_steps_per_epoch // gradient_accumulation_steps)
)

lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_optimizer_steps,
)

In [None]:
# log in to huggingface hub
# Interactive login prompt; the stored token is used by the Repository
# clone and push_to_hub calls below.
from huggingface_hub import notebook_login

notebook_login()

In [None]:
# create a repository to save the model
# Resolve the fully qualified hub repo id (user/namespace + model_name).
from huggingface_hub import Repository, get_full_repo_name

model_name = "fairy-tale-generator-accelerate"
repo_name = get_full_repo_name(model_name)
repo_name

In [None]:
# clone the repository in a local folder
# Checkpoints and the tokenizer are saved into output_dir during training
# and pushed from there.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

output_dir = "fairy-tale-generator-accelerate"
repo = Repository(output_dir, clone_from=repo_name)

### 5. Training

In [None]:
# function to evaluate the model during training
def evaluate():
    """Compute the mean validation loss and perplexity over `eval_dataloader`.

    Switches `model` to eval mode and leaves it there; callers restore
    training mode with `model.train()`.

    Returns:
        ``(loss, perplexity)`` as Python floats. Perplexity is ``inf`` when
        ``exp(loss)`` overflows.
    """
    model.eval()
    losses = []
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(batch["input_ids"], labels=batch["input_ids"])
        # Make the scalar loss 1-D so gathered losses can be concatenated.
        losses.append(accelerator.gather(outputs.loss.reshape(1)))
    loss = torch.mean(torch.cat(losses)).item()
    # math.exp raises OverflowError for huge inputs. torch.exp would silently
    # return inf, which left the original except-branch unreachable.
    try:
        perplexity = math.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss, perplexity

In [None]:
# test the evaluate function
# Baseline (loss, perplexity) of the untrained model before any updates.
evaluate()

In [None]:
# Display the total number of training batches across all epochs.
num_training_steps

In [None]:
# train the model
# Gradient accumulation: an optimizer/scheduler step every
# `gradient_accumulation_steps` batches. Evaluation, checkpointing and a hub
# push happen every `eval_steps * gradient_accumulation_steps` batches.
# NOTE(review): if len(train_dataloader) is not a multiple of
# gradient_accumulation_steps, the final partial accumulation of each epoch
# is never applied by an optimizer step.

gradient_accumulation_steps = 10
eval_steps = 2

model.train()
completed_steps = 0
for epoch in range(num_train_epochs):
    for step, batch in tqdm(
        enumerate(train_dataloader, start=1), total=len(train_dataloader)
    ):
        logits = model(batch["input_ids"]).logits
        loss = keytoken_weighted_loss(batch["input_ids"], logits, keytoken_ids)
        if step % 100 == 0:
            # `loss` has not been divided by gradient_accumulation_steps yet,
            # so it already is the per-batch training loss.
            accelerator.print(
                {
                    "steps": completed_steps,
                    "loss/train": loss.item(),
                }
            )
        loss = loss / gradient_accumulation_steps
        accelerator.backward(loss)
        if step % gradient_accumulation_steps == 0:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            completed_steps += 1
        if (step % (eval_steps * gradient_accumulation_steps)) == 0:
            eval_loss, perplexity = evaluate()
            accelerator.print({"loss/eval": eval_loss, "perplexity": perplexity})
            model.train()
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
            if accelerator.is_main_process:
                tokenizer.save_pretrained(output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress step {step}", blocking=False
                )

### 6. Inference

#### Greedy search

#### Beam search