In [458]:
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import AutoConfig, GPT2LMHeadModel

from data_utils import StructGPTMelHarmDataset, GenCollator
from harmony_tokenizers_m21 import ChordSymbolTokenizer, PitchClassTokenizer, MelodyPitchTokenizer, MergedMelHarmTokenizer

In [459]:
# Directory of the held-out HookTheory test split. Set the HOOKTHEORY_TEST_DIR
# environment variable to point elsewhere instead of editing this absolute path.
test_dir = os.environ.get('HOOKTHEORY_TEST_DIR', '/media/maindisk/maximos/data/hooktheory_test')

In [460]:
# Tokenizers: melody-pitch tokens and chord-symbol tokens are merged into one
# vocabulary (MergedMelHarmTokenizer); its vocab sizes the GPT-2 model below.
cstok = ChordSymbolTokenizer()
# NOTE(review): pctok is not referenced by any later cell shown here; presumably
# a leftover from experimenting with the pitch-class representation.
pctok = PitchClassTokenizer()
meltok = MelodyPitchTokenizer()
tokenizer = MergedMelHarmTokenizer(meltok, cstok)

In [461]:
# Held-out split encoded with the merged melody/harmony tokenizer.
# max_length=512 matches the model's n_positions configured below.
test_dataset = StructGPTMelHarmDataset(
    test_dir,
    tokenizer,
    max_length=512,
    num_bars=64,
    return_harmonization_labels=True,
)

In [462]:
# Quick sanity check that indexing the dataset works.
# NOTE(review): `d` is not referenced by any later cell shown here.
d = test_dataset[0]

In [463]:
# Collate function for the DataLoader below; presumably pads/stacks items into
# the 'input_ids'/'attention_mask'/'labels'/'constraints_mask' batch dict — confirm in data_utils.
collator = GenCollator(tokenizer)

In [464]:
# Loader over the held-out test split (despite the variable name).
# Shuffling is driven by a seeded generator so the sampled example is the same
# after every kernel restart.
LOADER_SEED = 0
trainloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collator,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

In [465]:
# Draw one batch from the (shuffled) loader; the generation cell iterates over its
# 'input_ids' rows.
batch = next(iter(trainloader))

In [466]:
# Small GPT-2 (8 layers, 8 heads, 512-dim) sized to the merged melody/harmony
# vocabulary; n_positions equals the dataset's max_length (512).
config = AutoConfig.from_pretrained(
    "gpt2",
    # architecture
    n_embd=512,
    n_layer=8,
    n_head=8,
    n_positions=512,
    # vocabulary and special-token ids from the merged tokenizer
    vocab_size=len(tokenizer.vocab),
    bos_token_id=tokenizer.vocab[tokenizer.bos_token],
    eos_token_id=tokenizer.vocab[tokenizer.eos_token],
    pad_token_id=tokenizer.vocab[tokenizer.pad_token],
)

model = GPT2LMHeadModel(config)

In [467]:
# state_dict checkpoint for the ChordSymbolTokenizer model; loaded with
# load_state_dict in the next cell.
model_path = 'saved_models/gpt/ChordSymbolTokenizer/ChordSymbolTokenizer.pt'

In [468]:
# Preferred device. If CUDA is requested but not available, fall back to CPU so
# that `device` is always defined and the checkpoint can still be mapped.
device_name = 'cuda:0'
if device_name != 'cpu' and not torch.cuda.is_available():
    print('Selected device not available: ' + device_name)
    device_name = 'cpu'
device = torch.device(device_name)

# weights_only=True limits unpickling to tensors and plain containers.
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(547, 512)
    (wpe): Embedding(512, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=1536, nx=512)
          (c_proj): Conv1D(nf=512, nx=512)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=2048, nx=512)
          (c_proj): Conv1D(nf=512, nx=2048)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=547, bias=False)
)

In [469]:
num_beams = 5
# Token ids are looked up once; they are the same for every sequence.
start_harmony_id = tokenizer.vocab[tokenizer.harmony_tokenizer.start_harmony_token]
bar_token_id = tokenizer.vocab['<bar>']
eos_token_id = tokenizer.eos_token_id
pad_token_id = tokenizer.pad_token_id
# GPT-2's learned position table has n_positions entries, so prompt + generated
# tokens must fit inside it; requesting more new tokens indexes past the table.
max_positions = model.config.n_positions


def _ids_to_display_tokens(ids):
    """Map token ids to vocabulary strings, replacing spaces with 'x' for display."""
    return [tokenizer.ids_to_tokens[int(i)].replace(' ', 'x') for i in ids]


for b in batch['input_ids']:
    # NOTE: the results below are overwritten for every sequence, so after the loop
    # they describe only the last sequence of the batch (the loader uses batch_size=1).
    harmony_hits = (b == start_harmony_id).nonzero(as_tuple=True)[0]
    if harmony_hits.numel() == 0:
        raise ValueError('no start-harmony token found in sequence')
    # The prompt is everything up to and including the first start-harmony token.
    start_harmony_position = int(harmony_hits[0])
    real_ids = b
    input_ids = b[:start_harmony_position + 1].to(device)
    melody_tokens = _ids_to_display_tokens(input_ids)
    # Ground-truth harmony from the start-harmony token on, with padding removed.
    real_tokens = _ids_to_display_tokens(
        t for t in real_ids[start_harmony_position:] if t != pad_token_id
    )
    # Number of '<bar>' tokens in *this* sequence (the constraint cell below
    # recomputes bars_count with a different meaning).
    bars_count = int((b == bar_token_id).sum())

    prompt = input_ids.unsqueeze(0)
    max_new_tokens = max_positions - prompt.shape[1]
    try:
        outputs = model.generate(
            input_ids=prompt,
            eos_token_id=eos_token_id,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
        )
    except Exception as exc:
        # Best-effort retry with a narrower beam (presumably aimed at GPU memory
        # pressure); other failures will simply re-raise from the retry.
        print('exception: ', repr(exc), input_ids)
        outputs = model.generate(
            input_ids=prompt,
            eos_token_id=eos_token_id,
            max_new_tokens=max_new_tokens,
            num_beams=2,
        )
    # Generated harmony, starting at the prompt's start-harmony token.
    generated_tokens = _ids_to_display_tokens(outputs[0][start_harmony_position:])

In [470]:
print(batch.keys())
print(melody_tokens)
# find where melody ends
melody_end_index = melody_tokens.index('</m>')
print('melody_end_index: ', melody_end_index)
# keep the '</m>' marker and the per-bar constraint section that follows it
after_melody = melody_tokens[melody_end_index:]
print(after_melody)


def find_first_constraint(tokens, bar_token='<bar>', fill_token='<fill>', stop_tokens=('<h>',)):
    """Find the first constrained bar in the constraint section of a prompt.

    ``tokens`` is the section after '</m>'. Each bar starts with ``bar_token`` and
    holds either ``fill_token`` (unconstrained) or explicit tokens (the constraint).
    Scanning stops at any of ``stop_tokens``, so the start-harmony token that ends
    the prompt is never mistaken for a constraint.

    Returns ``(bars_count, constraint_tokens)``. ``bars_count`` is the number of
    ``bar_token``s seen up to and including the constrained bar, and
    ``constraint_tokens`` holds that bar's tokens. If no bar is constrained,
    ``constraint_tokens`` is empty and ``bars_count`` is the number of bars scanned.
    """
    bars_seen = 0
    collected = []
    for token in tokens:
        if token in stop_tokens:
            break
        if token == bar_token or token == fill_token:
            if collected:
                # the constraint ends at the next bar/fill marker
                break
            if token == bar_token:
                bars_seen += 1
        else:
            collected.append(token)
    return bars_seen, collected


# after_melody[0] is '</m>' itself, so the scan starts right after it
bars_count, constraint_tokens = find_first_constraint(after_melody[1:])
constraint_found = bool(constraint_tokens)
if not constraint_found:
    print('no constrained bar found in the prompt')

dict_keys(['input_ids', 'attention_mask', 'labels', 'constraints_mask'])
['<s>', '<bar>', 'ts_4x4', 'position_0x00', 'P:64', 'position_0x50', 'P:64', 'position_1x00', 'P:64', 'position_1x25', 'P:64', 'position_1x75', 'P:64', 'position_3x25', 'P:64', 'position_3x50', 'P:64', 'position_3x75', 'P:62', '<bar>', 'position_0x00', 'P:64', 'position_0x50', 'P:64', 'position_0x75', 'P:64', 'position_1x25', 'P:62', 'position_1x50', 'P:61', 'position_1x75', 'P:57', 'position_2x75', 'P:57', 'position_3x00', 'P:64', 'position_3x25', 'P:64', 'position_3x75', 'P:64', '<bar>', 'position_0x00', 'P:62', 'position_1x50', 'P:64', 'position_1x75', 'P:64', 'position_2x25', 'P:62', 'position_3x25', 'P:62', 'position_3x50', 'P:60', 'position_3x75', 'P:57', '<bar>', 'position_0x00', 'P:57', 'position_2x00', '<rest>', 'position_3x00', '<rest>', 'position_3x25', 'P:64', 'position_3x75', 'P:64', '</m>', '<bar>', '<fill>', '<bar>', '<fill>', '<bar>', 'position_0x00', 'D:min', '<fill>', '<bar>', '<fill>', '<h>']
me

In [471]:
# Show the bar number of the constraint and the constraint tokens, one per line.
print(bars_count, constraint_tokens, sep='\n')

3
['position_0x00', 'D:min']


In [472]:
# Harmony generated for the last sequence processed in the generation cell,
# starting at the start-harmony token.
print(generated_tokens)

['<h>', '<bar>', 'position_0x00', 'A:maj', 'position_2x00', 'A:maj', '<bar>', 'position_0x00', 'A:maj', 'position_2x00', 'A:maj', '<bar>', 'position_0x00', 'A:maj', 'position_2x00', 'A:maj', '<bar>', 'position_0x00', 'A:maj', 'position_2x00', '</s>']


In [473]:
def is_sublist_contiguous(q, d):
    """Return True if list ``q`` occurs as a contiguous run inside list ``d``.

    An empty ``q`` matches trivially (at position 0). A ``q`` longer than ``d``
    never matches.
    """
    window = len(q)
    return any(d[start:start + window] == q for start in range(len(d) - window + 1))

# Locate, in the generated harmony, the bar that carried the constraint in the prompt.
# bars_count (from the constraint cell) is 1-based: the constraint sat in the
# bars_count-th '<bar>' after '</m>'. This assumes generated bars line up one-to-one
# with prompt bars.
bar_idxs = [i for i, token in enumerate(generated_tokens) if token == '<bar>']
print(bar_idxs)
constraints_area = None
if 1 <= bars_count <= len(bar_idxs):
    start_index = bar_idxs[bars_count - 1]
    # The bar runs up to the next '<bar>', or to the end of the sequence for the last bar.
    end_index = bar_idxs[bars_count] if bars_count < len(bar_idxs) else len(generated_tokens)
    constraints_area = generated_tokens[start_index:end_index]
else:
    start_index = -1  # not applicable: the generated harmony has no such bar
print(constraint_tokens)
print(constraints_area)
# An empty constraint list means there was nothing to check. Report False rather
# than the vacuous match an empty sublist would produce.
res = (
    constraints_area is not None
    and len(constraint_tokens) > 0
    and is_sublist_contiguous(constraint_tokens, constraints_area)
)
print(res)

[1, 6, 11, 16]
['position_0x00', 'D:min']
['<bar>', 'position_0x00', 'A:maj', 'position_2x00', 'A:maj']
False
