In [1]:
import wandb
from torch.utils.data import Dataset, random_split
from transformers import GPT2Tokenizer, TrainingArguments, Trainer, GPT2LMHeadModel
import pickle
import torch
from tqdm.auto import tqdm
from IPython.display import Markdown, display

In [2]:
class LyricsDataset(Dataset):
    def __init__(self, txt_list, tokenizer, max_length):
        self.input_ids = []
        self.attn_masks = []
        self.labels = []
        for txt in tqdm(txt_list):
            encodings_dict = tokenizer('<|startoftext|>' + txt + '<|endoftext|>', truncation=True,
                                       max_length=max_length, padding="max_length")
            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx]

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
max_length = 1024

In [4]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium', bos_token='<|startoftext|>',
                                          eos_token='<|endoftext|>', pad_token='<|pad|>')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium').to(device)
model.resize_token_embeddings(len(tokenizer))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Embedding(50259, 1024)

In [5]:
with open('data/artist_lyrics.p', 'rb') as f:
      artist_lyrics = pickle.load(f)

In [6]:
lyrics = []
for songs in artist_lyrics.values():
      lyrics.extend(songs)

In [7]:
%env WANDB_PROJECT=lyrics_gen
wandb.login()


env: WANDB_PROJECT=lyrics_gen


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjanithw[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
dataset = LyricsDataset(lyrics, tokenizer, max_length=max_length)
train_size = int(0.9 * len(dataset))
train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size])

training_args = TrainingArguments(output_dir='/scratch/jnw301/lyrics_results/', num_train_epochs=8, logging_steps=100, save_steps=10000,
                                  per_device_train_batch_size=1, per_device_eval_batch_size=1,report_to="wandb", 
                                  warmup_steps=10, weight_decay=0.05, logging_dir='./logs')

training_op = Trainer(model=model,  args=training_args, train_dataset=train_dataset, 
        eval_dataset=val_dataset, data_collator=lambda data: {'input_ids': torch.stack([f[0] for f in data]),
                                                              'attention_mask': torch.stack([f[1] for f in data]),
                                                              'labels': torch.stack([f[0] for f in data])}).train()

  0%|          | 0/3851 [00:00<?, ?it/s]

PyTorch: setting up devices
***** Running training *****
  Num examples = 3465
  Num Epochs = 8
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 27720
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss
100,0.7589
200,0.8091
300,0.7669
400,0.7684
500,0.7484
600,0.8278
700,0.7405
800,0.8162
900,0.7437
1000,0.8121


Saving model checkpoint to /scratch/jnw301/lyrics_results/checkpoint-10000
Configuration saved in /scratch/jnw301/lyrics_results/checkpoint-10000/config.json
Model weights saved in /scratch/jnw301/lyrics_results/checkpoint-10000/pytorch_model.bin


KeyboardInterrupt: 

In [4]:
model2 = GPT2LMHeadModel.from_pretrained('/scratch/jnw301/lyrics_results/checkpoint-10000').to(device)