In [None]:
import pandas as pd
import re
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, set_seed
import transformers
import accelerate
from torch.utils.data import Dataset, DataLoader

In [None]:
# Load the raw Goodreads export.
# The previous `converters={'COLUMN_NAME': pd.eval}` argument was a leftover
# template placeholder: no column called COLUMN_NAME is used anywhere in this
# notebook (the next cell keeps only four named columns), and routing file
# content through `pd.eval` evaluates text from the CSV as expressions.
data = pd.read_csv("data/goodreads_books.csv")

In [None]:
# Narrow the frame to the columns used further down. `filter` keeps the listed
# labels that exist and silently ignores any that do not.
data = data.filter(
    items=["title", "rating_count", "average_rating", "genre_and_votes"],
    axis="columns",
)

In [None]:
def genre_break(genre):
    """Remove the vote counts from a list of genre strings, in place.

    Every run of "whitespace followed by digits" is deleted from each entry,
    so ``"Young Adult 161"`` becomes ``"Young Adult"``.

    Args:
        genre: Mutable sequence of strings. It is modified in place.

    Returns:
        The same ``genre`` object, with every entry cleaned.
    """
    # Raw string: '\s' in a normal string literal is an invalid escape
    # sequence and triggers a warning on current Python versions.
    vote_count = re.compile(r'\s[0-9]+')

    # Slice assignment keeps the in-place contract (callers may hold a
    # reference to this very list) while still returning it.
    genre[:] = [vote_count.sub('', item) for item in genre]
    return genre

In [None]:
# Turn the comma-separated "genre_and_votes" string into a list stored under
# "genre", then discard the source column. "genre" ends up as the last column.
data = (
    data
    .assign(genre=data["genre_and_votes"].str.split(", "))
    .drop(columns=["genre_and_votes"])
)

In [None]:
# Discard every row that has a missing value in any remaining column.
data = data.dropna(axis="index")

In [None]:
# Sanity check after dropping incomplete rows: print the frame's summary.
data.info()

In [None]:
# Strip the vote counts from every genre list and store the result back on the
# frame explicitly. Previously the Series returned by `apply` was discarded, so
# the cell only worked through genre_break's in-place side effect and dumped
# the whole Series as cell output.
data = data.assign(genre=data["genre"].apply(genre_break))

In [None]:
# Display the cleaned frame to eyeball the result of the steps above.
data

## Fine-Tuning Set-Up

In [None]:
class BookTitles(Dataset):
    """Tokenized "Tags / Title" training examples for GPT-2 fine-tuning.

    Each row of ``content`` becomes one example: the text
    ``<startoftext>Tags: {tags}`` + newline + ``Title: {title}<endoftext>``,
    encoded to a 1-D tensor of token ids. The title is read from position 0 of
    the row and the tags from position 3.

    Args:
        content: Either a pandas DataFrame (iterated row by row) or any
            iterable of indexable rows. Iterating a DataFrame directly yields
            its column *names*, so DataFrames are walked with ``itertuples``.
        gpt2_type: Name of the pretrained tokenizer to load when ``tokenizer``
            is not supplied.
        max_length: Maximum number of tokens per example; longer examples are
            truncated.
        tokenizer: Optional ready-made tokenizer exposing
            ``encode(text, truncation=..., max_length=...)``. When omitted, a
            GPT-2 tokenizer is loaded with ``<startoftext>``, ``<endoftext>``
            and ``<pad>`` registered as special tokens, matching the tokenizer
            built in the model set-up cell of this notebook.
    """

    def __init__(self, content, gpt2_type="gpt2", max_length=1024, tokenizer=None):
        if tokenizer is None:
            # Register the markers used in the example text as special tokens;
            # otherwise they are split into ordinary sub-word pieces.
            tokenizer = GPT2Tokenizer.from_pretrained(
                gpt2_type,
                bos_token='<startoftext>',
                eos_token='<endoftext>',
                pad_token='<pad>',
            )
        self.tokenizer = tokenizer

        # DataFrames must be iterated by row, not by column label.
        if hasattr(content, "itertuples"):
            rows = content.itertuples(index=False, name=None)
        else:
            rows = content

        self.title = []
        for row in rows:
            tags = row[3]
            # Render a list of genres as "A, B, C" rather than a Python repr.
            if isinstance(tags, (list, tuple)):
                tags = ", ".join(str(tag) for tag in tags)
            text = f'<startoftext>Tags: {tags}\nTitle: {row[0]}<endoftext>'
            token_ids = self.tokenizer.encode(
                text, truncation=True, max_length=max_length
            )
            self.title.append(torch.tensor(token_ids))
        self.titles_count = len(self.title)

    def __len__(self):
        """Return the number of encoded examples."""
        return self.titles_count

    def __getitem__(self, item):
        """Return the token-id tensor of the example at index ``item``."""
        # Index the list of tensors; `titles_count` is an int and cannot be
        # subscripted.
        return self.title[item]

In [None]:
# Build the fine-tuning dataset from the cleaned frame.
train_set = BookTitles(content=data)

## Model Set-Up

In [None]:
# Tokenizer with the start/end/pad markers registered as special tokens.
tokenizer = GPT2Tokenizer.from_pretrained(
    'gpt2',
    bos_token='<startoftext>',
    eos_token='<endoftext>',
    pad_token='<pad>',
)

# Use the GPU when one is present, otherwise fall back to the CPU; an
# unconditional `.cuda()` raises on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)

# The tokenizer gained new special tokens, so the embedding matrix must grow
# to match the new vocabulary size.
model.resize_token_embeddings(len(tokenizer))

In [None]:
set_seed(32)
text = "Generate a book title using the word 'Death'."

# The prompt tensor is created on the CPU; move it to wherever the model lives
# so `generate` does not fail with a device mismatch.
encoded_input = tokenizer.encode(text, return_tensors='pt').to(model.device)

outputs = model.generate(
    encoded_input,
    max_length=40,
    no_repeat_ngram_size=2,
    early_stopping=True,
    num_return_sequences=5,
    num_beams=5,
    # Pad finished beams with the tokenizer's own pad token.
    pad_token_id=tokenizer.pad_token_id,
)

# Show the generated candidates as text instead of leaving raw token ids unseen.
tokenizer.batch_decode(outputs, skip_special_tokens=True)