## Generating note pairs with GPT2

We now use the model trained in `09-gpt-model-pairs.ipynb` to sample note-pair sequences, group them into chords, and export the model's token embeddings.

In [1]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from pathlib import Path
import numpy as np
import os
import re

### Load and prepare model

In [2]:
# Relative locations of the artefacts this notebook works with.
# NOTE(review): TXT_FILES is not referenced by any cell visible here — confirm it is still needed.
TOKENIZER_SAVEDIR = Path('tokenizers') / 'pair-tokenizer'
LM_MODEL_SAVEDIR = Path('models') / 'gpt-pairs'
TXT_FILES = Path('corpus-pairs-txt')

In [3]:
# Special-token strings handed to the tokenizer when it is loaded from disk.
_SPECIAL_TOKENS = {
    "bos_token": "<start>",
    "eos_token": "</start>",
    "unk_token": "<unk>",
}
tokenizer = GPT2TokenizerFast.from_pretrained(TOKENIZER_SAVEDIR, **_SPECIAL_TOKENS)

In [4]:
# Encode the padding and end-of-sequence markers together; the unpacking below
# only succeeds when the tokenizer returns exactly two ids for this string.
_pad_and_eos_ids = tokenizer.encode('<pad> </start>')
pad_token_id, eos_token_id = _pad_and_eos_ids

In [5]:
# `from_pretrained` is a classmethod, so it is called on the class directly.
# The previous form, `GPT2LMHeadModel(config=config).from_pretrained(...)`,
# first built a complete randomly-initialised network and then threw it away;
# the hand-written config was never applied to the model that was returned.
model = GPT2LMHeadModel.from_pretrained(str(LM_MODEL_SAVEDIR))

# Keep `config` bound for any later cell that reads it, but point it at the
# configuration the loaded model actually uses.
config = model.config

### Define helper functions

In [6]:
def split_into_chords(tokens):
    ''' Group a flat token stream into chord strings.

    Everything from a '<chord>' token through the next '</chord>' token is
    joined with single spaces into one entry. A '<nochord>' token seen outside
    a chord becomes an entry of its own. Any other token outside a chord is
    ignored, and a chord that is never closed is not emitted.

    Args:
        tokens: iterable of token strings.

    Returns:
        List of strings, in the order the groups were completed.
    '''
    groups = []
    buffer = []
    is_open = False
    for tok in tokens:
        if tok == '</chord>':
            # Closing marker: flush whatever has been buffered (even if no
            # '<chord>' was seen) and start over.
            buffer.append(tok)
            groups.append(' '.join(buffer))
            buffer, is_open = [], False
            continue
        if tok == '<chord>':
            is_open = True
        if is_open:
            buffer.append(tok)
        elif tok == '<nochord>':
            groups.append(tok)
    return groups

In [7]:
# Semitone offset of each natural note letter, counted upwards from C.
step_to_index = dict(zip('CDEFGAB', (0, 2, 4, 5, 7, 9, 11)))

def lettername_to_base_index(lettername):
    ''' Convert a note name without octave (e.g. 'C', 'F#', 'B-') to a semitone offset.

    The first character selects the natural note from `step_to_index`; every
    further character counts as one accidental. Only the first accidental
    character is inspected: '#' raises the note, anything else lowers it.

    Raises:
        KeyError: if the first character is not one of C, D, E, F, G, A, B.
    '''
    step, accidentals = lettername[0], lettername[1:]
    base = step_to_index[step]
    if not accidentals:
        return base
    direction = 1 if accidentals[0] == '#' else -1
    return base + direction * len(accidentals)

def pitch_str_to_pitch_index(pitch_str):
    ''' Convert a pitch such as 'C4' or 'B-2' to an index in the range 0..87.

    The final character is read as a single-digit octave; everything before it
    is the note name. With this formula 'A0' maps to 0 and 'C8' maps to 87
    (presumably the 88 keys of a piano — confirm).

    Returns:
        The integer index, or None when the pitch lies outside 0..87.
    '''
    name = pitch_str[:-1]
    octave = int(pitch_str[-1])
    index = lettername_to_base_index(name) + 12 * octave - 9
    return index if 0 <= index <= 87 else None

def pitch_to_index_soft(pitch):
    ''' Sort-key helper: pitch index for note names, 0 for marker tokens.

    Any string starting with '<' (e.g. '<nochord>') is given index 0; every
    other string is passed on to `pitch_str_to_pitch_index`, so the result can
    also be None for an out-of-range pitch.
    '''
    return 0 if pitch[0] == '<' else pitch_str_to_pitch_index(pitch)

In [13]:
def parse_chords(chord):
    ''' Given a bunch of note pairs, just take the unique notes given.

    Args:
        chord: one chord string, e.g. '<chord> F3,C4 C4,F4 </chord>'. Notes may
            be separated by spaces or commas; the '<chord>' and '</chord>'
            markers are dropped.

    Returns:
        List of distinct note strings ordered by `pitch_to_index_soft`. Notes
        with equal keys keep their first-seen order, and notes whose key is
        None (pitch outside the supported range) are placed last.
    '''
    markers = ('<chord>', '</chord>')
    # dict.fromkeys removes duplicates while keeping first-seen order. A set
    # would make the relative order of equal-key notes (e.g. 'C#4' and 'D-4')
    # vary between interpreter runs because of string hash randomisation.
    notes = dict.fromkeys(
        note for note in re.split(' |,', chord) if note and note not in markers
    )

    def sort_key(note):
        index = pitch_to_index_soft(note)
        # None cannot be compared with int, so sort on (is_missing, index).
        return (index is None, 0 if index is None else index)

    return sorted(notes, key=sort_key)

### Generate example outputs

In [10]:
# Sampling is stochastic (do_sample=True); seed torch's global RNG so a fresh
# "Restart & Run All" reproduces the same sequence.
GENERATION_SEED = 42
torch.manual_seed(GENERATION_SEED)

# Prompt the model with the start marker followed by an opening chord marker.
input_tensors = tokenizer.encode('<start> <chord>', return_tensors="pt")

# generate() returns a batch; [0] selects the single generated sequence of ids.
output_ids = model.generate(input_tensors,
                            pad_token_id=pad_token_id,
                            eos_token_id=eos_token_id,
                            temperature=1,
                            max_length=256,
                            do_sample=True)[0]

# Decode to text and split on whitespace to get one string per token.
# `output_tokens` now only ever holds the list of strings, never the id tensor.
output_tokens = tokenizer.decode(output_ids).split()

In [12]:
# Group the generated tokens into chords and show the distinct notes of each,
# one chord per line, separated by single spaces.
chords = split_into_chords(output_tokens)
for chord in chords:
    print(*parse_chords(chord))

F3 C4 F4 A4
<nochord>
B-2 A-3 D-4 F4 B-4
<nochord>
B-2 A-3 D-4 F4 D-5
B-2 D-4 F4 D-5
B-2 A-3 D-4 F4 D-5
B-2 A-3 F4
B-2 A-3 D-5


In [14]:
# parents=True also creates the top-level 'embeddings/' directory; without it
# this cell raises FileNotFoundError when that directory does not exist yet.
Path('embeddings/gpt-pairs/').mkdir(parents=True, exist_ok=True)

In [15]:
# Token-embedding matrix of the model: one row per vocabulary entry.
embeddings = model.transformer.wte.weight

# Write one tab-separated row of floats per line. Each row is converted to
# Python floats with a single tolist() call instead of one .item() call per
# scalar; the text written is the same. detach() keeps autograd out of the
# iteration over the parameter.
with open('embeddings/gpt-pairs/embedding.tsv', 'w') as f:
    for row in tqdm(embeddings.detach()):
        f.write('\t'.join(map(str, row.tolist())) + "\n")

100%|███████████████████████████████████████████████████████████████████████████████████████████████| 3329/3329 [00:06<00:00, 530.51it/s]


In [16]:
# Write the decoded text of every token id, one per line, in the same row
# order as the embedding matrix.
with open('embeddings/gpt-pairs/vocab.tsv', 'w') as f:
    f.writelines(
        tokenizer.decode([token_id]) + '\n'
        for token_id in range(len(embeddings))
    )