Fine-Tune a GPT-2 Model for Creative Story Generation Using the WritingPrompts Dataset

In [None]:
import numpy as np
import pandas as pd
import torch
import logging
from tqdm import tqdm
import math
import argparse
import os

In [None]:
# I use the dataset of writing prompts and stories from https://github.com/pytorch/fairseq/tree/master/examples/stories to fine-tune GPT-2, then use the fine-tuned model to generate stories.

In [None]:
!git clone https://github.com/huggingface/transformers
!pip install transformers/
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

In [None]:
# Arguments. parse_known_args (rather than parse_args) skips flags it does not
# recognise, such as the connection-file flag the Jupyter kernel is started with.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=88888)
parser.add_argument("--model_name", default="gpt2", type=str)
parser.add_argument("--max_seq_length", default=512, type=int)    # max tokens per example after truncation
parser.add_argument("--train_batch_size", default=4, type=int)
parser.add_argument("--valid_batch_size", default=4, type=int)
parser.add_argument("--num_train_epochs", default=1, type=int)
parser.add_argument("--warmup", default=0.1, type=float)          # fraction of total steps used for LR warm-up
parser.add_argument("--learning_rate", default=5e-5, type=float)
# NOTE(review): not referenced by the cells shown here; the data files are read
# from the working directory.
parser.add_argument("--input_text_path", default='../input/story-text', type=str)
args, _ = parser.parse_known_args()

# Apply the seed (previously parsed but never used) so data shuffling, dropout
# and sampling are repeatable across Restart & Run All.
np.random.seed(args.seed)
torch.manual_seed(args.seed)  # seeds the CPU and all CUDA generators

In [None]:
# Download Writing Prompt Dataset from https://www.kaggle.com/datasets/ratthachat/writing-prompts
# It is already split into train, validation and test sets, but the prompts and the stories are stored in separate files (*.wp_source and *.wp_target).

In [None]:
# The train split is very large, so to keep training manageable the validation split
# is used as the training set and the test split as the validation set.

# Every line in the combined file holds a prompt and its corresponding story, concatenated as 'prompt + <sep> + story', which is the input format for the GPT-2 model.

In [None]:
# Function to combine prompts and stories
def combinetext(prompt, story, max_story_words=300):
    """Pair each prompt line with its story line as ``'prompt <sep> story'``.

    Args:
        prompt: path to the prompts file (one prompt per line).
        story: path to the stories file (one story per line, aligned with
            ``prompt`` line by line).
        max_story_words: keep only the first this-many whitespace-separated
            tokens of each story. Defaults to 300, the previously hard-coded
            cap; ``None`` keeps whole stories.

    Returns:
        list[str]: one combined string per line pair. Trailing whitespace is
        stripped from the prompt, and runs of whitespace in the story collapse
        to single spaces.

    Raises:
        AssertionError: if the two files have different numbers of lines.
    """
    # Context managers guarantee both handles are closed (the old code leaked them).
    with open(prompt, 'r', encoding='utf8') as prompt_file:
        prompts = prompt_file.readlines()
    with open(story, 'r', encoding='utf8') as story_file:
        stories = story_file.readlines()
    assert len(prompts) == len(stories), (
        f'{prompt} has {len(prompts)} lines but {story} has {len(stories)}')
    return [
        p.rstrip() + ' <sep> ' + " ".join(s.split()[:max_story_words])
        for p, s in zip(prompts, stories)
    ]

# Preprocess the text (remove spaces before punctuation, rejoin contractions, restore newlines)
def cleanpunctuation(s):
    """Undo space-separated punctuation and contractions in one text line.

    For example ``"I do n't know , he said ."`` becomes
    ``"I don't know, he said."``, and every ``<newline>`` marker becomes a
    real ``'\\n'``.

    Args:
        s: the text to clean.

    Returns:
        str: the cleaned text.
    """
    # Drop the space in front of sentence punctuation: "word ." -> "word."
    for mark in '!,.:;?':
        s = s.replace(' ' + mark, mark)

    # Rejoin split contractions. Order matters, so apply strictly top to bottom.
    contraction_fixes = (
        (" n't", "n't"),
        (" 's", "'s"),
        (" 're", "'re"),
        (" 've", "'ve"),
        (" 'll", "'ll"),
        (" 'am", "'am"),
        (" 'm", "'m"),
        (" ' m", "'m"),
        # Repeated on purpose: collapsing " ' m" can expose a fresh " 'm" when
        # the text had two spaces before the apostrophe.
        (" 'm", "'m"),
        (" ' ve", "'ve"),
        (" ' s", "'s"),
    )
    for old, new in contraction_fixes:
        s = s.replace(old, new)

    # The dataset stores line breaks as a literal <newline> token.
    return s.replace('<newline>', '\n')

# Combine prompts with stories and clean punctuation for both splits. The
# validation files serve as training data and the test files as validation data.
train_text = [cleanpunctuation(example) for example in combinetext('valid.wp_source', 'valid.wp_target')]
valid_text = [cleanpunctuation(example) for example in combinetext('test.wp_source', 'test.wp_target')]


In [None]:
# Tokenize both splits up front. GPT-2's BPE tokenizer has no pad token, so the
# EOS token is reused for padding; padded positions are later excluded from the
# loss via their attention mask.
# The checkpoint comes from --model_name (default "gpt2", the previous hard-coded value).
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name)
tokenizer.pad_token = tokenizer.eos_token

# Truncate examples longer than max_seq_length; pad shorter ones to the longest
# example in the split (each split is tokenized as a single batch).
inputs_train = tokenizer(train_text, padding=True, truncation=True, max_length=args.max_seq_length)
inputs_valid = tokenizer(valid_text, padding=True, truncation=True, max_length=args.max_seq_length)

In [None]:
# Create a label sequence for every input_ids sequence
def create_labels(inputs):
    """Add a ``'labels'`` entry to a tokenizer output for causal-LM training.

    Each label sequence is a copy of ``input_ids`` in which every position
    whose attention mask is 0 (padding) is replaced with -100, the label value
    the model's loss ignores, so padding never contributes to the loss. Labels
    are NOT shifted here; GPT2LMHeadModel shifts them internally for
    next-token prediction.

    Masking per position (instead of assuming all padding sits at the end)
    makes this correct for left-padded input too.

    Args:
        inputs: mapping with parallel ``'input_ids'`` and ``'attention_mask'``
            lists of lists (e.g. a tokenizer output). Modified in place.
    """
    inputs['labels'] = [
        [token if keep else -100 for token, keep in zip(ids, attention_mask)]
        for ids, attention_mask in zip(inputs['input_ids'], inputs['attention_mask'])
    ]

# Attach loss labels to both tokenized splits (modified in place).
for encoded_split in (inputs_train, inputs_valid):
    create_labels(encoded_split)


In [None]:
class StoryDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized stories.

    Wraps the tokenizer output together with the ``labels`` added by
    ``create_labels`` and serves one example at a time as tensors, so it can
    be fed to ``torch.utils.data.DataLoader``.

    Args:
        inputs: mapping with parallel ``'input_ids'``, ``'attention_mask'`` and
            ``'labels'`` lists, one inner list per example.
    """

    def __init__(self, inputs):
        self.ids = inputs['input_ids']
        self.attention_mask = inputs['attention_mask']
        self.labels = inputs['labels']

    def __len__(self):
        """Return the total number of examples in the dataset."""
        return len(self.ids)

    def __getitem__(self, item):
        """Return ``[input_ids, attention_mask, labels]`` for example ``item`` as int64 tensors."""
        return [torch.tensor(self.ids[item], dtype=torch.long),
                torch.tensor(self.attention_mask[item], dtype=torch.long),
                torch.tensor(self.labels[item], dtype=torch.long)]


In [None]:
# Training loader: reshuffle every epoch so optimizer steps do not always see
# the examples in file order (shuffle=False previously disabled this).
train_batch_size = args.train_batch_size
valid_batch_size = args.valid_batch_size
traindata = StoryDataset(inputs_train)
train_dataloader = torch.utils.data.DataLoader(
    traindata,
    shuffle=True,
    batch_size=train_batch_size)

# Validation loader: a fixed order is fine for evaluation.
validdata = StoryDataset(inputs_valid)
valid_dataloader = torch.utils.data.DataLoader(
    validdata,
    shuffle=False,
    batch_size=valid_batch_size)

In [None]:
# Pretrained GPT-2 with a language-modeling head. The checkpoint comes from
# --model_name (default "gpt2", the previous hard-coded value).
model = GPT2LMHeadModel.from_pretrained(args.model_name)

Evaluate the Model in the Zero-Shot Setting (Without Fine-Tuning) and Calculate Perplexity

In [None]:
# Zero-shot baseline: validation loss and perplexity of the pretrained model
# before any fine-tuning. Uses the GPU when available and falls back to the CPU
# instead of crashing on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
eval_loss = []
for inputs in tqdm(valid_dataloader, desc="eval"):
    d1, d2, d3 = (t.to(device) for t in inputs)  # input_ids, attention_mask, labels
    with torch.no_grad():
        output = model(input_ids=d1, attention_mask=d2, labels=d3)
    # output[0] is the batch's LM loss (label positions set to -100 are ignored).
    eval_loss.append(output[0].item())
# Perplexity = exp(mean of the per-batch losses). The post-epoch evaluation
# averages the same way, so the before/after numbers are comparable.
eval_loss = np.mean(eval_loss)
perplexity = math.exp(eval_loss)
print(f'The average perplexity for valid dataset before fine-tuning is {perplexity}')

In [None]:
# START: COPIED FROM <emily2008/fine-tune-gpt-2-to-generate-stories>
# Using generate function from the model
def generate_story(prompt, k=0, p=0.7, output_length=500, temperature=1, num_return_sequences=1, repetition_penalty=1.0, device='cpu'):
    """Sample story continuations of ``prompt`` from the global ``model``.

    Prints the prompt and every generated story, and also returns the stories
    (previously the function returned None).

    NOTE(review): training examples were formatted ``'prompt <sep> story'``
    (see ``combinetext``); callers may want to end ``prompt`` with ``' <sep>'``
    so the model starts a story instead of continuing the prompt.

    Args:
        prompt: text to condition on (encoded without special tokens).
        k: top-k filtering size; 0 disables top-k filtering.
        p: nucleus (top-p) cumulative-probability threshold.
        output_length: maximum total length in tokens, prompt included.
        temperature: softmax temperature for next-token sampling.
        num_return_sequences: number of independent samples to draw.
        repetition_penalty: 1.0 means no penalty.
        device: device ``model`` is moved to before generating; it stays there
            afterwards. Defaults to ``'cpu'``, the previous hard-coded behaviour.

    Returns:
        list[str]: the decoded stories, each cut at the first EOS token if any.
    """
    print("----prompt----\n")
    print(prompt + "\n")

    model.to(device)
    model.eval()
    encoded_prompt = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(device)
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        # One unpadded prompt: every position is real. Passing the mask and the
        # pad id explicitly matches transformers' fallback without its warnings.
        attention_mask=torch.ones_like(encoded_prompt),
        pad_token_id=tokenizer.eos_token_id,
        max_length=output_length,
        temperature=temperature,
        top_k=k,
        top_p=p,
        repetition_penalty=repetition_penalty,
        do_sample=True,  # False would fall back to greedy decoding
        num_return_sequences=num_return_sequences,
    )

    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    stories = []
    for idx, generated_sequence in enumerate(output_sequences):
        print("---- STORY {} ----".format(idx + 1))
        text = tokenizer.decode(generated_sequence.tolist(), clean_up_tokenization_spaces=True)
        # Keep only the text before the first EOS. str.find returns -1 when no
        # EOS was generated, and text[:-1] would silently drop the last character.
        eos_at = text.find(tokenizer.eos_token)
        if eos_at != -1:
            text = text[:eos_at]
        print(text)
        stories.append(text)
    return stories

Fine-Tuning the Model

In [None]:
# number of training samples = 15620

# Optimizer and LR schedule: linear warm-up over the first `args.warmup`
# fraction of optimizer steps, then linear decay over the remaining steps.
num_train_epochs = args.num_train_epochs
training_steps_per_epoch = len(train_dataloader)  # one optimizer step per batch
total_num_training_steps = int(training_steps_per_epoch * num_train_epochs)
weight_decay = 0
learning_rate = args.learning_rate
adam_epsilon = 1e-8
warmup_steps = int(total_num_training_steps * args.warmup)

# Parameters that conventionally get no weight decay. GPT-2 names its layer
# norms ln_1 / ln_2 / ln_f, so the BERT-style "LayerNorm.weight" alone matched
# none of them. With weight_decay = 0 both groups currently behave the same.
no_decay = ["bias", "LayerNorm.weight", "ln_1.weight", "ln_2.weight", "ln_f.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
# torch.optim.AdamW replaces transformers' deprecated AdamW. It applies the same
# decoupled weight-decay update with bias correction, and each group sets its
# own weight_decay explicitly.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_num_training_steps
)

In [None]:
# Fine-tune, then report validation loss and perplexity after every epoch.
# Uses the GPU when available and falls back to the CPU instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("  Num Epochs = {}".format(num_train_epochs))
print(f"  Train_batch_size per device = {train_batch_size}")
print(f"  Valid_batch_size per device = {valid_batch_size}")
model.to(device)
for epoch in range(num_train_epochs):
    print(f"Start epoch{epoch+1} of {num_train_epochs}")
    train_loss = 0
    epoch_iterator = tqdm(train_dataloader, desc='Iteration')
    model.train()
    model.zero_grad()
    for inputs in epoch_iterator:
        d1, d2, d3 = (t.to(device) for t in inputs)  # input_ids, attention_mask, labels
        output = model(input_ids=d1, attention_mask=d2, labels=d3)
        batch_loss = output[0]
        batch_loss.backward()
        optimizer.step()
        # The schedule was sized in optimizer steps, so advance it every batch.
        scheduler.step()
        model.zero_grad()
        train_loss += batch_loss.item()
        epoch_iterator.set_description('(batch loss=%g)' % batch_loss.item())
        del batch_loss
    # This is the mean of per-batch losses, i.e. per batch rather than per example.
    print(f'Average train loss per example={train_loss/training_steps_per_epoch} in epoch{epoch+1}')

    print(f'Starting evaluate after epoch {epoch+1}')
    eval_loss = []
    model.eval()
    for inputs in tqdm(valid_dataloader, desc="eval"):
        d1, d2, d3 = (t.to(device) for t in inputs)
        with torch.no_grad():
            output = model(input_ids=d1, attention_mask=d2, labels=d3)
        eval_loss.append(output[0].item())
    # Same averaging as the zero-shot evaluation, so the perplexities are comparable.
    eval_loss = np.mean(eval_loss)
    perplexity = math.exp(eval_loss)
    print(f'Average valid loss per example={eval_loss} in epoch{epoch+1}')
    print(f'Perplextiy for valid dataset in epoch{epoch+1} is {perplexity}')

# Perplexity is used as the metric to check whether fine-tuning improves performance
# END: COPIED FROM <emily2008/fine-tune-gpt-2-to-generate-stories>

Generate Stories using Fine-tuned Model (Example)

In [None]:
# Generate one story for each of three captions; the "[start] ... [end]"
# wrapping is part of each prompt string passed to the model.
# NOTE(review): training examples were 'prompt <sep> story', and these prompts
# omit ' <sep>' — consider appending it.
s1, s2, s3 = (
    generate_story(prompt=caption)
    for caption in (
        '[start] two street signs at an intersection of emerald and university [end]',
        '[start] a view of a stove that is built into the cabinets [end]',
        '[start] three adults watch a child holding a toy doll [end]',
    )
)

In [None]:
# Generate one story from all three captions at once, joined into a single prompt.
generate_story(
    prompt=(
        'two street signs at an intersection of emerald and university,  '
        'a view of a stove that is built into the cabinets, '
        'three adults watch a child holding a toy doll'
    )
)