In [81]:
import pandas as pd
import transformers
import torch 
import math
from itertools import chain

In [85]:
GENRE = 'Rap'
# use all songs
N_SONGS = 1000000000

In [86]:
# read in  data
df = pd.read_csv('clean_data.csv')
# filter data by genre
genre_df = df[df.genres.apply(lambda x: GENRE in x)]

# save lyrics
lyrics = genre_df.lyrics.values
# split lyrics by line
lines = [song.split('\n') for song in lyrics[0:N_SONGS]]
lines = list(chain.from_iterable(lines))
# remove empty lines
lines = [line for line in lines if len(line)>0]


In [28]:
# name model
MODEL_NAME = 'gpt2'
pipe = transformers.pipeline(task='text-generation', model=MODEL_NAME, device='cpu')

In [29]:
# set model 
model = transformers.GPT2LMHeadModel.from_pretrained(MODEL_NAME)

# set model configurations
config = transformers.GPT2Config.from_pretrained(MODEL_NAME)
config.do_sample = True
config.max_length = 100
model = transformers.GPT2LMHeadModel.from_pretrained(MODEL_NAME,
                                                      config=config)

In [30]:
# set tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

In [31]:
# encode training data
enc_tokens = [tokenizer(line, return_tensors='pt') for line in lines]
enc_tokens = [enc['input_ids'].tolist()[0] for enc in enc_tokens]

In [32]:
class MyDset(torch.utils.data.Dataset):
     """A custom dataset"""
     def __init__(self, data: list[list[int]]):
         self.data = []
         for d in data:
             input_ids = torch.tensor(d, dtype=torch.int64)
             attention_mask = torch.ones(len(d), dtype=torch.int64)
             self.data.append({'input_ids': input_ids,
                  'attention_mask': attention_mask, 'labels': input_ids})
 
     def __len__(self):
         return len(self.data)
 
     def __getitem__(self, idx: int):
         return self.data[idx]


In [33]:
training_args = transformers.TrainingArguments(
     output_dir="idiot_save/",
     learning_rate=1e-3,
     per_device_train_batch_size=1,
     per_device_eval_batch_size=1,
     num_train_epochs=1,
     evaluation_strategy='epoch',
     save_strategy='no',
 )

In [34]:
# create training, valdation, and testing data intervals
END1 = math.ceil(len(enc_tokens)*0.8)
END2 = END1 + math.ceil(len(enc_tokens)*0.1)

# create training, valdation, and testing data
dset_train = MyDset(enc_tokens[0:END1])
dset_val = MyDset(enc_tokens[END1:END2])
dset_test = MyDset(enc_tokens[END2:])

In [11]:
# train model
trainer = transformers.Trainer(
     model=model,
     args=training_args,
     train_dataset=dset_train,
     eval_dataset=dset_val,
 )

trainer.train()

                                               
100%|██████████| 38/38 [01:16<00:00,  2.01s/it]


{'eval_loss': 5.124989986419678, 'eval_runtime': 1.054, 'eval_samples_per_second': 8.539, 'eval_steps_per_second': 8.539, 'epoch': 1.0}
{'train_runtime': 76.2414, 'train_samples_per_second': 0.498, 'train_steps_per_second': 0.498, 'train_loss': 6.826578240645559, 'epoch': 1.0}


TrainOutput(global_step=38, training_loss=6.826578240645559, metrics={'train_runtime': 76.2414, 'train_samples_per_second': 0.498, 'train_steps_per_second': 0.498, 'train_loss': 6.826578240645559, 'epoch': 1.0})

In [76]:
inputs = tokenizer(" ", return_tensors="pt")
generation_output = model.generate(**inputs, pad_token_id=50256)

In [56]:
tokenizer.batch_decode(generation_output)

['I am vernacular-minded that our culture has an inherent moral dimension to it.\n\n\nEven before we could talk about this, the question of morality and morality should have been completely central to the thinking of the Western civilization. We certainly know that morality is a product of human nature, and the very qualities of moral behavior are important parts of it. It is up to us to explain it and determine the course of action toward it.\n\n\nBut what should the moral character of a man']

In [77]:
lines

['Go, go, go, go',
 'Go, go, go shawty',
 "It's your birthday",
 "We gon' party like it's your birthday",
 "We gon' sip Bacardi like it's yo birthday",
 "And you know we don't give a fuck",
 "It's not your birthday!",
 '[Chorus (2x)]',
 'You can find me in the club,',
 'bottle full of Bud',
 'Look mami I got the X',
 'if you into taking drugs',
 "I'm into having sex, I ain't into making love",
 'So come give me a hug if you into getting rubbed',
 'When I pull out up front, you see the Benz on dubs',
 "When I roll 20 deep, it's 20 knives in the club",
 'Niggas heard I fuck with Dre,',
 'now they wanna show me love',
 'When you sell like Eminem,',
 'and the hoes they wanna fuck',
 "But homie ain't nothing change ho's down, G's up",
 'I see Xzibit in the Cut',
 'that nigga roll that weed up',
 'If you watch how I move',
 "you'll mistake me for a playa or pimp",
 'Been hit wit a few shells but I dont walk wit a limp',
 'In the hood, in L.A. they saying "50 you hot"',
 'They like me,',
 "I 

In [55]:
generation_output

tensor([[   40,   716,   220,   933, 12754,    12, 14543,   326,   674,  3968,
           468,   281, 11519,  6573, 15793,   284,   340,    13,   628,   198,
          6104,   878,   356,   714,  1561,   546,   428,    11,   262,  1808,
           286, 18016,   290, 18016,   815,   423,   587,  3190,  4318,   284,
           262,  3612,   286,   262,  4885, 14355,    13,   775,  3729,   760,
           326, 18016,   318,   257,  1720,   286,  1692,  3450,    11,   290,
           262,   845, 14482,   286,  6573,  4069,   389,  1593,  3354,   286,
           340,    13,   632,   318,   510,   284,   514,   284,  4727,   340,
           290,  5004,   262,  1781,   286,  2223,  3812,   340,    13,   628,
           198,  1537,   644,   815,   262,  6573,  2095,   286,   257,   582]])