In [68]:
import os
from itertools import islice

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import AutoConfig, GPT2LMHeadModel

from a_star import AStarGPT
from data_utils import StructGPTMelHarmDataset, GenCollator
from harmony_tokenizers_m21 import ChordSymbolTokenizer, PitchClassTokenizer, MelodyPitchTokenizer, MergedMelHarmTokenizer

In [69]:
# Root of the HookTheory test split. Can be overridden with the HOOKTHEORY_TEST_DIR
# environment variable so the notebook runs on machines other than the author's;
# the default is the original hard-coded location.
test_dir = os.environ.get('HOOKTHEORY_TEST_DIR', '/media/maindisk/maximos/data/hooktheory_test')

cstok = ChordSymbolTokenizer()
pctok = PitchClassTokenizer()
meltok = MelodyPitchTokenizer()
# Melody tokenizer merged with a harmony tokenizer. The harmony tokenizer chosen here
# must match the model checkpoint path selected in the model-loading cell
# (ChordSymbolTokenizer vs. PitchClassTokenizer).
tokenizer = MergedMelHarmTokenizer(meltok, cstok)
# tokenizer = MergedMelHarmTokenizer(meltok, pctok)

test_dataset = StructGPTMelHarmDataset(test_dir, tokenizer, max_length=512, num_bars=16, return_harmonization_labels=True)

collator = GenCollator(tokenizer)

# NOTE: despite its name, `trainloader` iterates the *test* split; the name is kept
# because later cells refer to it. batch_size=1 with shuffle=False keeps the order
# deterministic, so a batch index always identifies the same test example.
trainloader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collator)

In [70]:
# GPT-2 architecture matching the trained checkpoint. Vocabulary size and special-token
# ids are taken from the same tokenizer built in the data-setup cell.
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer.vocab),
    n_positions=512,
    n_layer=8,
    n_head=8,
    pad_token_id=tokenizer.vocab[tokenizer.pad_token],
    bos_token_id=tokenizer.vocab[tokenizer.bos_token],
    eos_token_id=tokenizer.vocab[tokenizer.eos_token],
    n_embd=512
)

model = GPT2LMHeadModel(config)

model_path = 'saved_models/gpt/ChordSymbolTokenizer/ChordSymbolTokenizer.pt'
# Alternative checkpoint; only valid together with the PitchClassTokenizer variant of `tokenizer`.
# model_path = 'saved_models/gpt/PitchClassTokenizer/PitchClassTokenizer.pt'

# device_name = 'cuda:0'
device_name = 'cpu'
if device_name != 'cpu' and not torch.cuda.is_available():
    # Previously `device` was left unbound on this path, so the cell crashed with a
    # NameError at `model.to(device)`. Fall back to CPU instead.
    print('Selected device not available: ' + device_name + ' - falling back to cpu')
    device_name = 'cpu'
device = torch.device(device_name)

# Map tensors onto the *resolved* device, so a checkpoint saved on GPU still loads
# on a CPU-only machine.
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(547, 512)
    (wpe): Embedding(512, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=1536, nx=512)
          (c_proj): Conv1D(nf=512, nx=512)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=2048, nx=512)
          (c_proj): Conv1D(nf=512, nx=2048)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=547, bias=False)
)

In [71]:
# Pick one fixed test example. With batch_size=1 and shuffle=False, batch number
# BATCH_INDEX is always the same example, so results are comparable across runs.
BATCH_INDEX = 54

# Skip the first BATCH_INDEX batches and take the next one. The original counter loop
# silently kept the *last* batch when the loader was too short; fail loudly instead.
batch = next(islice(trainloader, BATCH_INDEX, None), None)
if batch is None:
    raise IndexError(f'trainloader has fewer than {BATCH_INDEX + 1} batches')

  return self.iter().getElementsByClass(classFilterList)


In [72]:
all_ids = batch['input_ids']
print(all_ids)

# Work on the single sequence in the batch (the loader uses batch_size=1).
first_sequence = all_ids[0].tolist()
melody_end_index = first_sequence.index(tokenizer.vocab['</m>'])
harmony_start_index = first_sequence.index(tokenizer.vocab['<h>'])
print(melody_end_index)

# Prompt = every token up to and including '<h>', plus one more token after it
# (presumably the first harmony '<bar>' — confirm against the tokenizer).
prompt_ids = all_ids[0][:harmony_start_index + 2]
# Constraints = the span from '</m>' through '<h>', both inclusive.
constraint_ids = prompt_ids[melody_end_index:harmony_start_index + 1]
# The model expects a leading batch dimension.
input_ids = prompt_ids.reshape(1, -1)

print(input_ids)
print(constraint_ids)

tensor([[  2,   6, 183,  98,  67, 114,  65,   6,  98,  58, 114,  65,   6,  98,
          61, 114,  58,   6,  98,  56, 114,  53,   6,  98,  65, 114,  65,   6,
          98,  65, 114,  65,   6,  98,  65, 114,  65,   6,  98,  65, 114,  65,
           6,  98,  67, 114,  65,   6,  98,  58, 114,  65,   6,  98,  61, 114,
          58,   6,  98,  56, 114,  53,   6,  98,  63, 114,  63,   6,  98,  63,
         114,  63,   6,  98,  63, 114,  63,   8,   6,   9,   6,   9,   6,   9,
           6,  98, 499,   9,   6,   9,   6,   9,   6,   9,   6,   9,   6,   9,
           6,   9,   6,   9,   6,   9,   6,   9,   6,   9,   6,   9,   3,   7,
           6,  98, 215, 114, 533,   6,  98, 419, 114, 272,   6,  98, 199, 114,
         419,   6,  98, 499, 114, 468,   6,  98, 533, 114, 533,   6,  98, 533,
         114, 533,   6,  98, 533, 114, 533,   6,  98, 533, 114, 533,   6,  98,
         215, 114, 533,   6,  98, 419, 114, 272,   6,  98, 199, 114, 419,   6,
          98, 499, 114, 468,   6,  98, 257, 114, 257

In [73]:
# Look up which token string id 4 corresponds to.
probe_ids = [4]
tokenizer.convert_ids_to_tokens(probe_ids)

['<rest>']

In [74]:
# Set up A* decoding (project module `a_star`) from the prompt and the
# '</m>'...'<h>' constraint span built above.
astar = AStarGPT(
    model,
    tokenizer,
    input_ids,
    constraint_ids,
    max_length=512,
    beam_width=20,
    lookahead_k=10,
)

constraint tokens:  ['</m>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', 'position_0x00', 'A#:maj6', '<fill>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', '<fill>', '</s>', '<h>']
self.constraint_bar:  4
self.position_token:  position_0x00
self.chord_tokens:  ['A#:maj6']


In [75]:
# Show the constraint attributes AStarGPT extracted from constraint_ids.
for attribute_name in ('constraint_bar', 'position_token', 'chord_tokens'):
    print(getattr(astar, attribute_name))

4
position_0x00
['A#:maj6']


In [76]:
# Run the A* search. Based on the later cells, the first value is a (1, seq_len) id
# tensor that includes the prompt, and the second is a step count — confirm in a_star.
generated_ids, model_steps = astar.decode()
# Holder for decoded token strings.
generated_tokens = list()

num_bars: 4 - num_tokens: 41

KeyboardInterrupt: 

In [65]:
print(generated_ids.size())

torch.Size([1, 208])


In [66]:
print(f'{model_steps}')

117


In [67]:
# Convert the generated ids back to token strings, replacing spaces with 'x'.
# The list is rebuilt from scratch: the old version appended to a list created in
# another cell, so re-running this cell duplicated every token.
generated_tokens = [
    tokenizer.ids_to_tokens[token_id].replace(' ', 'x')
    for token_id in generated_ids[0].tolist()
]
# Show only the harmony part (from '<h>' onward).
print(generated_tokens[harmony_start_index:])

['<h>', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'E:7', '<bar>', 'position_0x00', 'F:maj7', '<bar>', 'position_0x00', 'E:7', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'E:7', '<bar>', 'position_0x00', 'F:maj7', '<bar>', 'position_0x00', 'E:7', '</s>']


In [12]:
# Print the harmony one bar per line: a line is flushed whenever the *next* token
# contains 'bar', and whatever remains is printed at the end.
harmony_tokens = generated_tokens[harmony_start_index:]
current_bar = []
for token, next_token in zip(harmony_tokens, harmony_tokens[1:] + [None]):
    current_bar.append(token)
    if next_token is not None and 'bar' in next_token:
        print(current_bar)
        current_bar = []
print(current_bar)

['<h>']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_10']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_10']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_10']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_10', '</s>']
