## 9. GPT for learning note pairs

Now we apply Hugging Face's GPT-2 model to learn the chord representation from note pairs. Recall that each chord is written as the list of all of its note pairs: for example, `"<chord> C4,E4 C4,G4 C4,B4 E4,G4 E4,B4 G4,B4 </chord>"` represents the chord consisting of the notes `C4, E4, G4, B4`.

### Load necessary libraries and objects

In [3]:
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from pathlib import Path
import os
import json
import itertools
import numpy as np

In [4]:
# Inputs: the saved pair tokenizer and the directory of plain-text pair files.
TOKENIZER_SAVEDIR = Path('tokenizers/pair-tokenizer')
TXT_LOCATION = Path('corpus-pairs-txt/')

# Output: directory for the language-model checkpoints, created up front
# (including any missing parents; no error if it already exists).
LM_MODEL_SAVEDIR = Path('models/gpt-pairs/')
LM_MODEL_SAVEDIR.mkdir(parents=True, exist_ok=True)

In [None]:
# Select the first GPU when CUDA is available. On a CPU-only machine
# torch.cuda.set_device would raise, so the call is skipped there.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

In [5]:
# Load the trained pair tokenizer and register its special tokens.
# "<pad>" is registered as the pad token because the dataset pads every example
# with it; without a pad token the language-modeling collator cannot exclude
# the padded positions from the loss.
tokenizer = GPT2TokenizerFast.from_pretrained(
    TOKENIZER_SAVEDIR,
    bos_token="<start>",
    eos_token="</start>",
    unk_token="<unk>",
    pad_token="<pad>",
)

In [6]:
# Size the embedding table with len(tokenizer) rather than tokenizer.vocab_size:
# vocab_size counts only the base vocabulary, while len() also counts added
# (special) tokens, so every id the tokenizer can emit has an embedding row.
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_head=12,
)

In [7]:
# Build a GPT-2 language-model head network from the config above and report
# how many parameters it has.
model = GPT2LMHeadModel(config)
print('Num parameters:', model.num_parameters())

Num parameters: 88399104


In [8]:
def split_into_chords(tokens, boc_token, eoc_token):
    """Group a flat token sequence into chord-sized units.

    Tokens from ``boc_token`` up to and including the next ``eoc_token`` form
    one group. Every token seen outside a chord becomes its own one-element
    group. All input tokens are kept, in their original order.

    Args:
        tokens: Iterable of tokens (e.g. token ids).
        boc_token: Token marking the beginning of a chord.
        eoc_token: Token marking the end of a chord.

    Returns:
        A list of token lists, one per chord or stand-alone token.
    """
    groups = []
    current = []
    in_chord = False
    for token in tokens:
        if token == boc_token:
            in_chord = True
            current.append(token)
        elif token == eoc_token:
            current.append(token)
            groups.append(current)
            current = []
            in_chord = False
        elif in_chord:
            current.append(token)
        else:
            groups.append([token])
    # A chord that was opened but never closed (e.g. truncated input) is still
    # emitted as a group instead of being silently discarded.
    if current:
        groups.append(current)
    return groups

In [9]:
class CustomDataset(Dataset):
    """Torch dataset of equal-length token tensors built from chord text files.

    Each source file is wrapped in ``<start>`` / ``</start>``, tokenized, and
    split into chord groups with ``split_into_chords``. Examples are overlapping
    windows of ``num_chords`` consecutive groups, so an example never begins or
    ends in the middle of a chord. A file with fewer than ``num_chords`` groups
    yields no examples, and trailing groups that do not fill a whole window
    are not used.

    Items are padded on the right with the ``<pad>`` token id up to the length
    of the longest example in the dataset.

    Args:
        src_files: Iterable of path objects providing ``read_text``.
        tokenizer: Tokenizer providing ``encode``; it must encode
            ``'<chord> </chord> <pad>'`` to exactly three ids.
        num_chords: Number of chord groups per example (at least 1).

    Raises:
        ValueError: If ``num_chords`` is smaller than 1.
    """

    def __init__(self, src_files, tokenizer, num_chords):
        if num_chords < 1:
            raise ValueError(f"num_chords must be at least 1, got {num_chords}")
        super().__init__()
        self.max_len = 0
        boc_token, eoc_token, self.pad_token = tokenizer.encode('<chord> </chord> <pad>')
        self.examples = []
        for src_file in tqdm(src_files):
            words = '<start> ' + src_file.read_text(encoding="utf-8") + ' </start>'
            chords = split_into_chords(tokenizer.encode(words), boc_token, eoc_token)
            for window in self._windows(chords, num_chords):
                example = torch.tensor(window, dtype=torch.int64)
                self.examples.append(example)
                self.max_len = max(self.max_len, len(example))

    @staticmethod
    def _windows(chords, num_chords):
        """Return flattened windows of ``num_chords`` consecutive chord groups."""
        # Consecutive windows overlap by half. The stride is clamped to 1 so
        # that num_chords == 1 does not produce a zero step for range().
        stride = max(1, num_chords // 2)
        return [
            list(itertools.chain.from_iterable(chords[start:start + num_chords]))
            for start in range(0, len(chords) - num_chords + 1, stride)
        ]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        # Start from an all-padding tensor and overwrite the leading positions
        # with the example's tokens.
        tensor = torch.ones(self.max_len, dtype=torch.int64) * self.pad_token
        example = self.examples[i]
        tensor[:len(example)] = example
        return tensor

Create dataset:

In [10]:
# Gather every .txt file under the corpus directory (recursively). The list is
# sorted because glob order depends on the filesystem; sorting keeps example
# indices (such as the one used in the sanity check) reproducible across runs.
src_files = sorted(Path(TXT_LOCATION).glob("**/*.txt"))
dataset = CustomDataset(src_files, tokenizer, num_chords=10)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 284/284 [00:05<00:00, 55.63it/s]


Sanity check:

In [11]:
# Decode one padded example back to text to eyeball chord boundaries and padding.
tokenizer.decode(dataset[104])

'<chord> C3,C4 </chord> <chord> A-3,F4 </chord> <nochord> <chord> B-2,A-3 B-2,B-3 B-2,D-4 B-2,G-4 A-3,B-3 A-3,D-4 A-3,G-4 B-3,D-4 B-3,G-4 D-4,G-4 </chord> <nochord> <chord> A-3,D4 A-3,G-4 A-3,D-5 A-3,G-5 D4,G-4 D4,D-5 D4,G-5 G-4,D-5 G-4,G-5 D-5,G-5 </chord> <nochord> <chord> A-3,D4 A-3,G-4 A-3,B-4 A-3,D-5 A-3,G-5 D4,G-4 D4,B-4 D4,D-5 D4,G-5 G-4,B-4 G-4,D-5 G-4,G-5 B-4,D-5 B-4,G-5 D-5,G-5 </chord> <nochord> <chord> A-3,D4 A-3,G-4 A-3,B4 A-3,E-5 A-3,G-5 A-3,B5 D4,G-4 D4,B4 D4,E-5 D4,G-5 D4,B5 G-4,B4 G-4,E-5 G-4,G-5 G-4,B5 B4,E-5 B4,G-5 B4,B5 E-5,G-5 E-5,B5 G-5,B5 </chord> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pa

### Training

Now, we simply create a data collator, define the training arguments, train, and save the model.

In [77]:
# Batch collator for language modeling; mlm=False turns off masked-LM masking.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

In [79]:
training_args = TrainingArguments(
    # Where checkpoints go, and how many are kept on disk.
    output_dir=LM_MODEL_SAVEDIR,
    overwrite_output_dir=True,
    save_total_limit=1,
    save_steps=10000,
    # Training schedule.
    num_train_epochs=10,
    per_device_train_batch_size=32,
    # Reporting.
    logging_steps=2000,
    prediction_loss_only=False,
)

# Tie together the model, the padded dataset and the collator defined earlier.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)

In [None]:
# Run training with the trainer configured above; `ret` keeps whatever
# Trainer.train() returns so it can be inspected afterwards.
ret = trainer.train()