In [1]:
from harmony_tokenizers_m21 import ChordSymbolTokenizer, PitchClassTokenizer, MelodyPitchTokenizer, MergedMelHarmTokenizer
from data_utils import StructGPTMelHarmDataset, GenCollator
from torch.utils.data import DataLoader
import torch
from transformers import AutoConfig, GPT2LMHeadModel
import numpy as np

from a_star import AStarGPT

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Held-out HookTheory split (absolute, machine-specific path - adjust per machine).
test_dir = '/media/maindisk/maximos/data/hooktheory_test'

# Melody tokenizer merged with a harmony tokenizer. The harmony tokenizer chosen here
# must match the checkpoint loaded in the model cell (checkpoints are named per tokenizer).
cstok = ChordSymbolTokenizer()
pctok = PitchClassTokenizer()
meltok = MelodyPitchTokenizer()
# tokenizer = MergedMelHarmTokenizer(meltok, cstok)
tokenizer = MergedMelHarmTokenizer(meltok, pctok)

test_dataset = StructGPTMelHarmDataset(test_dir, tokenizer, max_length=512, num_bars=16, return_harmonization_labels=True)

collator = GenCollator(tokenizer)

# Seed the shuffle so the example drawn below is the same on every Restart & Run All.
SEED = 42
loader_generator = torch.Generator().manual_seed(SEED)
# NOTE: despite its name this loader iterates the *test* split; the name is kept
# because later cells refer to `trainloader`.
trainloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collator,
    generator=loader_generator,
)

In [3]:
# GPT-2 architecture matching the trained checkpoint; vocab size and special-token
# ids are taken from the merged tokenizer built above.
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer.vocab),
    n_positions=512,
    n_layer=8,
    n_head=8,
    pad_token_id=tokenizer.vocab[tokenizer.pad_token],
    bos_token_id=tokenizer.vocab[tokenizer.bos_token],
    eos_token_id=tokenizer.vocab[tokenizer.eos_token],
    n_embd=512
)

model = GPT2LMHeadModel(config)

# The checkpoint must correspond to the harmony tokenizer used to build `tokenizer`.
# model_path = 'saved_models/gpt/ChordSymbolTokenizer/ChordSymbolTokenizer.pt'
model_path = 'saved_models/gpt/PitchClassTokenizer/PitchClassTokenizer.pt'

# device_name = 'cuda:0'
device_name = 'cpu'
# Resolve the device up front. Previously an unavailable CUDA device left `device`
# undefined (NameError at `model.to(device)`) and made torch.load fail on map_location.
if device_name != 'cpu' and not torch.cuda.is_available():
    print('Selected device not available: ' + device_name)
    print('Falling back to cpu')
    device_name = 'cpu'
device = torch.device(device_name)

# weights_only=True restricts unpickling to tensors/plain containers (safer loading).
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(211, 512)
    (wpe): Embedding(512, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=1536, nx=512)
          (c_proj): Conv1D(nf=512, nx=512)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=2048, nx=512)
          (c_proj): Conv1D(nf=512, nx=2048)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=211, bias=False)
)

In [4]:
# Draw a single example (the loader uses batch_size=1) from the shuffled test loader.
loader_iterator = iter(trainloader)
batch = next(loader_iterator)

  return self.iter().getElementsByClass(classFilterList)


In [5]:
# Split the example into the decoding prompt and the constraint sequence for AStarGPT.
all_ids = batch['input_ids']
print(all_ids)

# Convert once; list.index needs plain Python ints for the special-token lookups.
sequence_ids = all_ids[0].tolist()
melody_end_index = sequence_ids.index(tokenizer.vocab['</m>'])
harmony_start_index = sequence_ids.index(tokenizer.vocab['<h>'])
print(melody_end_index)

# NOTE(review): this slice runs to the end of the sequence, so besides the constraint
# section between '</m>' and '<h>' it also contains the ground-truth harmony after '<h>'.
# Confirm AStarGPT only reads the part it needs.
constraint_ids = all_ids[0][melody_end_index:]
print(constraint_ids)

# Prompt = everything up to and including '<h>' plus the single token after it
# (a '<bar>' in the observed output), so decoding continues inside the first harmony bar.
NUM_HARMONY_PREFIX_TOKENS = 2
input_ids = all_ids[0][:harmony_start_index + NUM_HARMONY_PREFIX_TOKENS].unsqueeze(0)
print(input_ids.shape)

tensor([[  2,   6, 183,  98,  65, 106,  65, 114,  63, 122,  63,   6,  98,  62,
         114,  61, 122,  58,   6,  98,  65, 106,  65, 114,  61, 122,  61,   6,
          98,  61, 106,  61, 114,  61, 122,  60,   6,  98,  65, 106,  65, 114,
          63, 122,  63,   6,  98,  62, 114,  61, 122,  58,   6,  98,  65, 106,
          65, 114,  61, 122,  61,   6,  98,  61, 106,  61, 114,  61, 122,   4,
         126,  61, 128,  63,   8,   6,   9,   6,   9,   6,   9,   6,   9,   6,
           9,   6,   9,   6,   9,   6,  98, 199, 201, 204, 208,   9,   3,   7,
           6,  98, 199, 203, 204, 208, 114, 201, 204, 208,   6,  98, 200, 203,
         208, 114, 199, 203, 208, 122, 201, 205, 208,   6,  98, 199, 203, 204,
         208, 114, 199, 203, 206,   6,  98, 199, 201, 204, 208,   6,  98, 199,
         203, 204, 208, 114, 201, 204, 208,   6,  98, 200, 203, 208, 114, 199,
         203, 208, 122, 201, 205, 208,   6,  98, 199, 203, 204, 208, 114, 199,
         203, 206,   6,  98, 199, 201, 204, 208,   3

In [6]:
# Constrained A* decoder over the GPT model. max_length is tied to the model's
# context size (config.n_positions) instead of repeating the literal 512, so the
# decode limit cannot drift away from the architecture.
astar = AStarGPT(
    model,
    tokenizer,
    input_ids,
    constraint_ids,
    max_length=config.n_positions,
    beam_width=20,
    lookahead_k=10,
)

In [7]:
# Show the constraint AStarGPT appears to extract from constraint_ids:
# target bar, position token, and required chord tokens.
for parsed_field in (astar.constraint_bar, astar.position_token, astar.chord_tokens):
    print(parsed_field)

8
position_0x00
['chord_pc_0', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9']


In [8]:
# Run the constrained decoding; the verbose trace in this cell's output is printed
# by AStarGPT.decode itself.
decode_result = astar.decode()
generated_ids, finished = decode_result
# Empty container for the decoded token strings.
generated_tokens = list()

children
constraint_checker: ['<h>', '<bar>', 'position_0x00']
num_bars: 1 - num_tokens: 3
new_logprob: -0.011933247558772564
constraint_checker: ['<h>', '<bar>', '<bar>']
num_bars: 2 - num_tokens: 3
new_logprob: -4.847202777862549
constraint_checker: ['<h>', '<bar>', 'position_1x00']
num_bars: 1 - num_tokens: 3
new_logprob: -6.219756126403809
constraint_checker: ['<h>', '<bar>', 'position_2x00']
num_bars: 1 - num_tokens: 3
new_logprob: -6.772315979003906
constraint_checker: ['<h>', '<bar>', 'position_0x50']
num_bars: 1 - num_tokens: 3
new_logprob: -8.349837303161621
constraint_checker: ['<h>', '<bar>', 'position_3x00']
num_bars: 1 - num_tokens: 3
new_logprob: -8.549269676208496
constraint_checker: ['<h>', '<bar>', '</s>']
num_bars: 1 - num_tokens: 3
constraint_checker: ['<h>', '<bar>', 'position_3x50']
num_bars: 1 - num_tokens: 3
new_logprob: -10.112442016601562
constraint_checker: ['<h>', '<bar>', 'position_0x25']
num_bars: 1 - num_tokens: 3
new_logprob: -10.242109298706055
constrain

In [9]:
# Shape of the decoded id tensor (batch, sequence length).
print(generated_ids.size())

torch.Size([1, 148])


In [10]:
# How many hypotheses finished, and the token ids of the first one.
num_finished = len(finished)
print(num_finished)
first_finished = finished[0]
print(first_finished.tokens)

1
tensor([[  2,   6, 183,  98,  65, 106,  65, 114,  63, 122,  63,   6,  98,  62,
         114,  61, 122,  58,   6,  98,  65, 106,  65, 114,  61, 122,  61,   6,
          98,  61, 106,  61, 114,  61, 122,  60,   6,  98,  65, 106,  65, 114,
          63, 122,  63,   6,  98,  62, 114,  61, 122,  58,   6,  98,  65, 106,
          65, 114,  61, 122,  61,   6,  98,  61, 106,  61, 114,  61, 122,   4,
         126,  61, 128,  63,   8,   6,   9,   6,   9,   6,   9,   6,   9,   6,
           9,   6,   9,   6,   9,   6,  98, 199, 201, 204, 208,   9,   3,   7,
           6,  98, 203, 206, 210,   6,  98, 200, 203, 208,   6,  98, 200, 203,
         208,   6,  98, 199, 201, 204, 208,   6,  98, 199, 201, 204, 208,   6,
          98, 203, 206, 210,   6,  98, 200, 203, 208,   6,  98, 199, 201, 204,
         208,   6,  98, 199, 201, 204, 208,   3]])


In [11]:
# Convert generated ids to token strings. The list is rebuilt from scratch here, so
# re-running this cell no longer appends a second copy of every token.
# Spaces in token names are replaced by 'x' to match the spelling AStarGPT reports
# (cf. astar.position_token, printed as 'position_0x00').
generated_tokens = [
    tokenizer.ids_to_tokens[token_id].replace(' ', 'x')
    for token_id in generated_ids[0].tolist()
]
print(generated_tokens[harmony_start_index:])

['<h>', '<bar>', 'position_0x00', 'chord_pc_4', 'chord_pc_7', 'chord_pc_11', '<bar>', 'position_0x00', 'chord_pc_1', 'chord_pc_4', 'chord_pc_9', '<bar>', 'position_0x00', 'chord_pc_1', 'chord_pc_4', 'chord_pc_9', '<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', '<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', '<bar>', 'position_0x00', 'chord_pc_4', 'chord_pc_7', 'chord_pc_11', '<bar>', 'position_0x00', 'chord_pc_1', 'chord_pc_4', 'chord_pc_9', '<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', '<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', '</s>']


In [12]:
# Print the generated harmony one bar per line: every token containing 'bar'
# (after the first token) starts a new line.
harmony_tokens = generated_tokens[harmony_start_index:]
bar_line = []
for token in harmony_tokens:
    if bar_line and 'bar' in token:
        print(bar_line)
        bar_line = []
    bar_line.append(token)
print(bar_line)

['<h>']
['<bar>', 'position_0x00', 'chord_pc_4', 'chord_pc_7', 'chord_pc_11']
['<bar>', 'position_0x00', 'chord_pc_1', 'chord_pc_4', 'chord_pc_9']
['<bar>', 'position_0x00', 'chord_pc_1', 'chord_pc_4', 'chord_pc_9']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9']
['<bar>', 'position_0x00', 'chord_pc_4', 'chord_pc_7', 'chord_pc_11']
['<bar>', 'position_0x00', 'chord_pc_1', 'chord_pc_4', 'chord_pc_9']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', '</s>']
