In [25]:
import pandas as pd
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, TrainingArguments, Trainer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
import torch.nn.functional as F
import csv

In [None]:
import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from string import punctuation

# Fetch NLTK's 'punkt' resource at notebook start.
# NOTE(review): presumably needed by word_tokenize, which is imported above — confirm it is still required.
nltk.download('punkt')

Dataset: https://www.kaggle.com/datasets/aditya2803/eminem-lyrics

In [None]:
# Load the lyrics table. The file is parsed as tab-separated, ISO-8859-1 encoded text,
# and anything following a '#' on a line is treated as a comment and dropped.
main_data = pd.read_csv(
    "Eminem_Lyrics.csv",
    sep='\t',
    comment='#',
    encoding="ISO-8859-1",
)
# Preview the first rows only.
main_data.head()

In [28]:
def intro_words_filter(text):
    """Strip bracketed section tags from a lyric and tidy the spaces left behind.

    Args:
        text: Raw lyric string, possibly containing tags such as ``[Intro]``
            or ``[Verse 1]``.

    Returns:
        The lyric with every ``[...]`` span removed and every run of two or
        more spaces collapsed to a single space. Newlines are left untouched.
    """
    # Drop everything from an opening '[' up to the first closing ']'.
    without_tags = re.sub(r'\[([^]]*)]', '', text)
    # A single str.replace("  ", " ") pass leaves runs of 3+ spaces only partly
    # collapsed; a regex over the whole run collapses them completely.
    prepared_text = re.sub(r' {2,}', ' ', without_tags)

    return prepared_text

In [None]:
# Derive a cleaned copy of each lyric (section tags removed) in a new column,
# leaving the raw "Lyrics" column untouched.
main_data["Prepared_Lyrics"] = main_data["Lyrics"].apply(intro_words_filter)

main_data

In [30]:
def filtring_punct(text):
    """Return ``text`` with every character listed in ``string.punctuation`` removed.

    Args:
        text: Input string to clean.

    Returns:
        A new string made of the characters of ``text`` that are not
        punctuation, in their original order.
    """
    return ''.join(char for char in text if char not in punctuation)

In [31]:
# NLTK helpers: a WordNet lemmatizer instance and an alias for word_tokenize.
# NOTE(review): neither name is referenced by the other cells visible in this notebook — confirm they are still needed.
lemmatizer = WordNetLemmatizer()
word_tokenizer = word_tokenize  

In [32]:
# Remove punctuation from the already tag-free lyrics, overwriting the column in place.
main_data["Prepared_Lyrics"] = main_data["Prepared_Lyrics"].apply(filtring_punct)

In [None]:
# Flatten each cleaned lyric onto one line: newlines become spaces and
# leading/trailing whitespace is trimmed.
lyrics = main_data["Prepared_Lyrics"].apply(
    lambda lyric: lyric.replace('\n', ' ').strip()
)
lyrics

In [None]:
# Seed every random number generator this notebook imports, not just torch,
# so that any use of `random` or `numpy` is reproducible as well.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Tokenizer with explicit start/end/padding markers for the lyric samples.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium', bos_token='<|startoftext|>',
                                          eos_token='<|endoftext|>', pad_token='<|pad|>')
# Use the GPU when one is available, otherwise fall back to CPU instead of
# crashing on an unconditional .cuda() call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium').to(device)
# Resize the embedding matrix to the tokenizer's vocabulary size so that any
# marker tokens registered above have an embedding row.
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Cap every lyric at its first 750 words.
# split() with no argument splits on any whitespace run, so the empty strings
# that split(" ") yields for consecutive spaces no longer count toward the cap.
lyrics = lyrics.apply(lambda text: " ".join(text.split()[:750]))
lyrics

In [None]:
# Measure the longest sample as it will actually be encoded by the dataset,
# i.e. wrapped in the start/end markers. Measuring the bare lyric makes the
# padded length too short, so truncation cuts the end marker off the longest sample.
longest_sample = max(
    len(tokenizer.encode('<|startoftext|>' + lyric + '<|endoftext|>')) for lyric in lyrics
)
# Never exceed the number of positions the model supports.
max_length = min(longest_sample, model.config.n_positions)
max_length

In [38]:
class EminemDataset(Dataset):
    """Dataset of lyric strings encoded into fixed-length token tensors.

    Each text is wrapped in ``<|startoftext|>`` / ``<|endoftext|>`` and passed to
    ``tokenizer`` with truncation and padding to ``max_length``.

    Args:
        txt_list: Iterable of lyric strings.
        tokenizer: Callable returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` entries for a given string.
        max_length: Length every encoded sample is truncated/padded to.
    """

    def __init__(self, txt_list, tokenizer, max_length):
        # `labels` is initialised but never filled here.
        self.input_ids, self.attn_masks, self.labels = [], [], []
        for lyric in txt_list:
            wrapped = '<|startoftext|>' + lyric + '<|endoftext|>'
            encoded = tokenizer(
                wrapped,
                truncation=True,
                max_length=max_length,
                padding="max_length",
            )
            self.input_ids.append(torch.tensor(encoded['input_ids']))
            self.attn_masks.append(torch.tensor(encoded['attention_mask']))

    def __len__(self):
        """Number of encoded samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask)`` pair stored at ``idx``."""
        ids = self.input_ids[idx]
        mask = self.attn_masks[idx]
        return ids, mask

In [39]:
from torch.utils.data import random_split

In [40]:
# Encode every lyric, then hold out 10% of the samples for validation.
dataset = EminemDataset(lyrics, tokenizer, max_length=max_length)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

In [None]:
import gc
# Force a garbage-collection pass; the returned count is displayed as the cell output.
gc.collect()

In [42]:
# Release cached CUDA memory held by PyTorch's allocator before training starts.
torch.cuda.empty_cache()

In [None]:
# Training configuration: one epoch, batch size 1 for both training and
# evaluation, results in ./results and logs in ./logs, no external reporting.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    logging_steps=100,
    save_steps=5000,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=10,
    weight_decay=0.05,
    logging_dir='./logs',
    report_to='none',
)

In [None]:
def collate_lyrics(batch):
    """Stack ``(input_ids, attention_mask)`` samples into a model-ready batch.

    Args:
        batch: List of ``(input_ids, attention_mask)`` tensor pairs as returned
            by the dataset's ``__getitem__``.

    Returns:
        Dict with ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` tensors.
        Labels are a copy of the input ids with padded positions set to -100.
    """
    input_ids = torch.stack([sample[0] for sample in batch])
    attention_mask = torch.stack([sample[1] for sample in batch])
    # Copying input_ids verbatim would make the loss cover the padding as well.
    # -100 is the label value the language-model loss ignores, so masked
    # positions no longer contribute to training.
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}


# Fine-tune the model; the training summary is displayed as the cell output.
Trainer(model=model, args=training_args, train_dataset=train_dataset,
        eval_dataset=val_dataset, data_collator=collate_lyrics).train()

In [None]:
# Encode the prompt and place it on whatever device the model currently lives
# on, instead of assuming CUDA is available.
generated = tokenizer("<|startoftext|>Look, I was gonna go easy on you not to hurt your feelings", return_tensors="pt").input_ids.to(model.device)

In [None]:
# Training leaves the model in train mode; switch to eval mode so dropout is
# disabled while sampling.
model.eval()
# Sample 20 continuations of the prompt. max_length counts the prompt tokens too.
# pad_token_id is passed explicitly so sequences that finish early are padded
# with the dedicated pad token.
sample_outputs = model.generate(generated, do_sample=True, top_k=50,
                                max_length=50, top_p=0.95, temperature=1.9, num_return_sequences=20,
                                pad_token_id=tokenizer.pad_token_id)

In [None]:
# Print each sampled continuation with its index, leaving out special tokens.
for index, token_ids in enumerate(sample_outputs):
    decoded_text = tokenizer.decode(token_ids, skip_special_tokens=True)
    print("{}: {}".format(index, decoded_text))