## 11. GPT for learning chords

Now we train a model with Hugging Face's GPT-2 architecture from scratch to learn a chord representation, using the cleaned chords generated in `08-make-cleaned-chord-dataset.ipynb`. Each token is a 36-character string of `0`s and `1`s that marks which notes are active.
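
To make the token format concrete, here is a minimal sketch of how a set of active notes could be written as such a string and read back. The helper names `notes_to_token` and `token_to_notes` are hypothetical, and so is the convention that character `i` of the string corresponds to note index `i`; the actual mapping is whatever the dataset notebook defines.

```python
N_NOTES = 36  # length of every chord token in this notebook


def notes_to_token(active_notes, n_notes=N_NOTES):
    """Return a string with '1' at each active note index (assumed layout)."""
    active = set(active_notes)
    if any(not 0 <= n < n_notes for n in active):
        raise ValueError(f"note indices must lie in [0, {n_notes})")
    return ''.join('1' if i in active else '0' for i in range(n_notes))


def token_to_notes(token):
    """Return the sorted indices of the characters that are '1'."""
    return [i for i, bit in enumerate(token) if bit == '1']


# One of the tokens that appears in this dataset's vocabulary.
token = '000000010010010010000000000000000000'
notes = token_to_notes(token)
print(notes)
print(notes_to_token(notes) == token)
```

Because a whole chord is one token, the model sees a piece as a sequence of chord symbols rather than a sequence of individual notes.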

### Load necessary libraries and objects

In [1]:
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from pathlib import Path
import wandb
import os
import json
import itertools
import numpy as np

In [2]:
TOKENIZER_SAVEDIR = Path('tokenizers/chord-augmented-tokenizer')
LM_MODEL_SAVEDIR = Path('models/gpt-chords-augmented/')
LM_MODEL_SAVEDIR.mkdir(exist_ok=True, parents=True)
TXT_LOCATION = Path('chords-txt-augmented/')

In [None]:
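# Use the first GPU when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    torch.cuda.set_device(0)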

In [3]:
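# Registering <pad> lets the data collator exclude padded positions from the loss.
tokenizer = GPT2TokenizerFast.from_pretrained(TOKENIZER_SAVEDIR,
                                              bos_token="<start>",
                                              eos_token="</start>",
                                              unk_token="<unk>",
                                              pad_token="<pad>")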

Sanity-check the tokenizer by listing the ten entries with the lowest token ids:

In [4]:
sorted(tokenizer.vocab.items(), key=lambda x: x[1])[:10]

[('<start>', 0),
 ('</start>', 1),
 ('<pad>', 2),
 ('<unk>', 3),
 ('000000010010010010000000000000000000', 4),
 ('000000000010010010000000000000000000', 5),
 ('000000010000010010000000000000000000', 6),
 ('000000010000000000010000000000010000', 7),
 ('000010010010010010000000000000000000', 8),
 ('000000010010010010010000000000000000', 9)]

In [5]:
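config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_head=12,
    # Use this tokenizer's special-token ids rather than GPT-2's defaults,
    # which lie outside the chord vocabulary.
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)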

In [6]:
model = GPT2LMHeadModel(config=config)
print('Num parameters:', model.num_parameters())

Num parameters: 90743808
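
As a cross-check, the reported count can be reproduced by hand. The sketch below assumes the `GPT2Config` defaults for everything except `vocab_size` (768-dimensional embeddings, 12 layers, 1024 positions, a feed-forward width of four times the embedding size) and assumes the output head shares its weights with the token embedding, so it adds no parameters. `gpt2_param_count` is a hypothetical helper, not a library function.

```python
def gpt2_param_count(vocab_size, n_embd=768, n_layer=12, n_positions=1024):
    """Count trainable parameters of a GPT-2 style model with tied embeddings."""
    n_inner = 4 * n_embd
    layer_norm = 2 * n_embd  # scale and bias
    # Fused query/key/value projection plus the attention output projection.
    attention = (n_embd * 3 * n_embd + 3 * n_embd) + (n_embd * n_embd + n_embd)
    # Two feed-forward projections, each with a bias.
    mlp = (n_embd * n_inner + n_inner) + (n_inner * n_embd + n_embd)
    block = 2 * layer_norm + attention + mlp
    embeddings = vocab_size * n_embd + n_positions * n_embd
    return embeddings + n_layer * block + layer_norm  # final layer norm


reported = 90743808
without_vocab = gpt2_param_count(0)
implied_vocab_size = (reported - without_vocab) // 768
print(implied_vocab_size)                              # 6382
print(gpt2_param_count(implied_vocab_size) == reported)  # True
```

Under these assumptions the count corresponds to a vocabulary of 6382 tokens, so only about 4.9 million of the roughly 90.7 million parameters sit in the chord embedding table.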


### Define dataset

In [7]:
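class CustomDataset(Dataset):
    """Tokenizes each chord file and splits it into fixed-length, padded chunks."""

    def __init__(self, src_files, tokenizer, max_len):
        self.examples = []
        # Look up the padding id directly in the vocabulary.
        self.pad_token = tokenizer.convert_tokens_to_ids('<pad>')
        for src_file in tqdm(src_files):
            words = src_file.read_text(encoding="utf-8")
            # Wrap every file in the start and end markers.
            words = '<start> ' + words + ' </start>'
            tokenized = tokenizer.encode(words)
            # Cut the sequence into windows of max_len tokens and pad the last one.
            for i in range(0, len(tokenized), max_len):
                chunk = tokenized[i:i + max_len]
                tensor = torch.full((max_len,), self.pad_token, dtype=torch.int64)
                tensor[:len(chunk)] = torch.tensor(chunk)
                self.examples.append(tensor)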

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]

In [8]:
src_files = list(TXT_LOCATION.glob("**/*.txt"))
dataset = CustomDataset(src_files, tokenizer, max_len=128)

100%|███████████████████████████████████████████████████████████████████████████████████████████████| 3408/3408 [00:09<00:00, 365.04it/s]
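
The chunking logic is easier to follow on a toy input. This sketch reproduces the windowing and padding with plain Python lists instead of tensors. `chunk_and_pad` is a hypothetical helper and the token ids are made up, with `0` and `1` standing in for `<start>` and `</start>` and `2` for `<pad>`.

```python
def chunk_and_pad(token_ids, max_len, pad_id):
    """Split token_ids into windows of max_len and right-pad the last window."""
    chunks = []
    for i in range(0, len(token_ids), max_len):
        chunk = token_ids[i:i + max_len]
        chunks.append(chunk + [pad_id] * (max_len - len(chunk)))
    return chunks


toy_ids = [0, 4, 5, 6, 7, 8, 9, 1]  # <start>, six chords, </start>
chunks = chunk_and_pad(toy_ids, max_len=3, pad_id=2)
for chunk in chunks:
    print(chunk)
# [0, 4, 5]
# [6, 7, 8]
# [9, 1, 2]
```

Only the first window of a file begins with `<start>`; later windows begin mid-piece, and only the last one can contain padding.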


Sanity check an example from the dataset:

In [10]:
tokenizer.decode(dataset[108])

'000000000000010010000010010000000000 000010000000000010000000000000010000 000000010000010000010100000000000000 010000000010000000000000010000000000 000000000000000010000010010000010000 000000000000000010010000000010010000 010000010010000000000000000000000000 000000000000010000010010000000000000 010000010010010010000000000000000000 000000000000000010000000000010010000 000010000000000000000000000010000010 000000000000010000010010010000000000 000000010000010000010000000000000000 010000010000010000010000000000000000 000000000010010010000010000000000000 000010010000000000000000000010000000 000010000000000000000100000010000000 000000010000000010000010010000000000 000010010000000010000000000000000000 000000000000000000000010000100010010 000010000000010000000000000000010000 000000010000010000000000000010000000 000000000000010000010000010000000000 000000000000000010000010010000000000 010000010000010000010010000000000000 000000000000010000000100010010000000 000010000000000000010000000000010000 

### Training

Finally, we create a data collator for causal language modeling (`mlm=False`), define the training arguments, train, and save the model.

In [20]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
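
With `mlm=False`, the collator builds labels as a copy of the input ids. If the tokenizer has a padding token configured, positions holding that id are replaced by `-100` so the loss ignores them; the model then shifts the labels by one position internally. A plain-Python sketch of that labeling step, where the hypothetical helper `make_causal_lm_labels` stands in for the collator and the ids are made up:

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def make_causal_lm_labels(input_ids, pad_token_id=None):
    """Copy input_ids and mask padding, mirroring causal-LM label creation."""
    if pad_token_id is None:
        # Without a configured padding id nothing is masked.
        return list(input_ids)
    return [IGNORE_INDEX if t == pad_token_id else t for t in input_ids]


example = [9, 1, 2, 2]  # two real tokens followed by two <pad> ids
labels = make_causal_lm_labels(example, pad_token_id=2)
print(labels)  # [9, 1, -100, -100]
```

If no padding token is configured, the padded tail of each final chunk counts toward the loss like any other token.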

In [21]:
training_args = TrainingArguments(
    output_dir=str(LM_MODEL_SAVEDIR),  # pass a plain string so the arguments serialize cleanly
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=32,
    save_steps=10000,
    logging_steps=2000,
    save_total_limit=1,
    prediction_loss_only=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

In [None]:
ret = trainer.train()

In [None]:
trainer.save_model(LM_MODEL_SAVEDIR)