In [30]:
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq

from harmony_tokenizers_m21 import ChordSymbolTokenizer, PitchClassTokenizer, MelodyPitchTokenizer, MergedMelHarmTokenizer
from data_utils import StructBARTMelHarmDataset
from a_star import AStarBART

In [31]:
# Location of the HookTheory test split. Override it with the HOOKTHEORY_TEST_DIR
# environment variable instead of editing this machine-specific default path.
test_dir = os.environ.get('HOOKTHEORY_TEST_DIR', '/media/maindisk/maximos/data/hooktheory_test')

cstok = ChordSymbolTokenizer()
pctok = PitchClassTokenizer()
meltok = MelodyPitchTokenizer()
# The harmony tokenizer must match the checkpoint loaded with the model:
# MergedMelHarmTokenizer(meltok, cstok) for the ChordSymbolTokenizer checkpoint,
# MergedMelHarmTokenizer(meltok, pctok) for the PitchClassTokenizer checkpoint.
tokenizer = MergedMelHarmTokenizer(meltok, pctok)

# max_length should not exceed the BART model's max_position_embeddings (512).
test_dataset = StructBARTMelHarmDataset(test_dir, tokenizer, max_length=512, num_bars=16)

In [32]:
# BART architecture. These hyperparameters must match the ones used to train the
# checkpoint loaded below; load_state_dict is strict by default and raises on
# missing/unexpected keys or mismatched shapes.
bart_config = BartConfig(
    vocab_size=len(tokenizer.vocab),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    max_position_embeddings=512,
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    d_model=512,
    encoder_layerdrop=0.25,
    decoder_layerdrop=0.25,
    dropout=0.25
)

model = BartForConditionalGeneration(bart_config)

# Checkpoint must correspond to the harmony tokenizer in use; for the chord-symbol
# variant use 'saved_models/bart/ChordSymbolTokenizer/ChordSymbolTokenizer.pt'.
model_path = 'saved_models/bart/PitchClassTokenizer/PitchClassTokenizer.pt'

# Requested device (e.g. 'cuda:0' or 'cpu'). If a CUDA device is requested but
# CUDA is unavailable, fall back to CPU so `device` is always defined and the
# checkpoint is mapped onto a device that actually exists.
device_name = 'cpu'
if device_name != 'cpu' and not torch.cuda.is_available():
    print('Selected device not available: ' + device_name)
    device_name = 'cpu'
device = torch.device(device_name)

checkpoint = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(211, 512, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(211, 512, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(514, 512)
      (layers): ModuleList(
        (0-7): 8 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=512, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=512, bias=True)
          (final_layer_norm

In [33]:
def create_data_collator(tokenizer, model):
    """Build a seq2seq collator that pads each batch dynamically.

    Args:
        tokenizer: the merged melody/harmony tokenizer used for padding.
        model: the BART model; handed to the collator so it can prepare
            decoder inputs from labels when labels are present.

    Returns:
        A ``DataCollatorForSeq2Seq`` configured with ``padding=True``.
    """
    return DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)

collator = create_data_collator(tokenizer, model=model)

# Seed the shuffle so Restart & Run All selects the same test piece every time.
# Set SHUFFLE_SEED = None to draw a different piece on each run.
SHUFFLE_SEED = 42
shuffle_generator = (
    torch.Generator().manual_seed(SHUFFLE_SEED) if SHUFFLE_SEED is not None else None
)

# NOTE: this iterates the *test* split despite its name; the name is kept
# because later cells reference `trainloader`.
trainloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collator,
    generator=shuffle_generator,
)

In [34]:
# Pull the first (shuffled) batch from the loader; with batch_size=1 this is a single piece.
loader_iter = iter(trainloader)
batch = next(loader_iter)

In [35]:
# Show the collated batch returned by the loader (includes 'input_ids').
print(f'{batch}')

{'input_ids': tensor([[  2,   6, 183,  98,  65, 102,  68, 104,  63, 108,  63, 110,  61, 112,
          63, 118,   4, 122,  61, 124,  63, 126,   4, 128,  65,   6,  98,  65,
         102,  68, 104,  63, 108,  63, 110,  61, 112,  63, 118,   4, 122,  61,
         124,  61, 128,  60,   6,  98,  60, 102,  60, 104,  60, 110,  60, 112,
          60, 118,  60, 120,  58, 124,  56, 128,  56,   6,  98,  58, 100,  56,
         102,  58, 104,  56, 106,  58, 108,  65, 112,  65, 118,   4, 122,  61,
         124,  63, 126,   4, 128,  65,   6,  98,  65, 102,  68, 104,  63, 108,
          63, 110,  61, 112,  63, 118,   4, 122,  61, 124,  63, 126,   4, 128,
          65,   6,  98,  65, 102,  68, 104,  63, 108,  63, 110,  61, 112,  63,
         118,   4, 126,  61, 128,  63,   6,  98,  65, 102,   4, 106,  65, 110,
           4, 114,  65, 118,   4, 122,  65, 124,  68, 128,  70,   6,  98,  70,
         100,  70, 104,   4, 120,   4, 128,   4,   8,   6,   9,   6,  98, 201,
         204, 208,   9,   6,   9,   6,

In [36]:
# Split the encoder input of the first example: everything from the '</m>'
# token onward (presumably end-of-melody) is used as the constraint sequence
# for the A* decoder, while the full sequence is kept as input_ids.
all_ids = batch['input_ids']
print(all_ids)

first_example = all_ids[0]
melody_end_token_id = tokenizer.vocab['</m>']
melody_end_index = first_example.tolist().index(melody_end_token_id)
print(melody_end_index)

constraint_ids = first_example[melody_end_index:]
print(constraint_ids)

input_ids = all_ids.clone()
print(input_ids.shape)
print(tokenizer.vocab['</s>'])

tensor([[  2,   6, 183,  98,  65, 102,  68, 104,  63, 108,  63, 110,  61, 112,
          63, 118,   4, 122,  61, 124,  63, 126,   4, 128,  65,   6,  98,  65,
         102,  68, 104,  63, 108,  63, 110,  61, 112,  63, 118,   4, 122,  61,
         124,  61, 128,  60,   6,  98,  60, 102,  60, 104,  60, 110,  60, 112,
          60, 118,  60, 120,  58, 124,  56, 128,  56,   6,  98,  58, 100,  56,
         102,  58, 104,  56, 106,  58, 108,  65, 112,  65, 118,   4, 122,  61,
         124,  63, 126,   4, 128,  65,   6,  98,  65, 102,  68, 104,  63, 108,
          63, 110,  61, 112,  63, 118,   4, 122,  61, 124,  63, 126,   4, 128,
          65,   6,  98,  65, 102,  68, 104,  63, 108,  63, 110,  61, 112,  63,
         118,   4, 126,  61, 128,  63,   6,  98,  65, 102,   4, 106,  65, 110,
           4, 114,  65, 118,   4, 122,  65, 124,  68, 128,  70,   6,  98,  70,
         100,  70, 104,   4, 120,   4, 128,   4,   8,   6,   9,   6,  98, 201,
         204, 208,   9,   6,   9,   6,   9,   6,   9

In [37]:
# Constrained A* decoder over the loaded BART model; the search
# hyperparameters (length limit, beam width, lookahead) are hard-coded here.
astar = AStarBART(
    model,
    tokenizer,
    input_ids,
    constraint_ids,
    max_length=512,
    beam_width=20,
    lookahead_k=10,
)

In [38]:
# Inspect the constraint as parsed by AStarBART: bar index, position token, chord tokens.
for constraint_part in (astar.constraint_bar, astar.position_token, astar.chord_tokens):
    print(constraint_part)

2
position_0x00
['chord_pc_2', 'chord_pc_5', 'chord_pc_9']


In [39]:
# Run the constrained search. `finished` appears to be a list of hypothesis
# objects exposing `.tokens`; generated_tokens is filled by a later cell.
decode_result = astar.decode()
generated_ids, finished = decode_result
generated_tokens = list()

constraint_checker: ['<s>', '<h>']
num_bars: 0 - num_tokens: 2
constraint_checker: ['<s>', '<bar>']
num_bars: 1 - num_tokens: 2
constraint_checker: ['<s>', 'chord_pc_9']
num_bars: 0 - num_tokens: 2
constraint_checker: ['<s>', '</s>']
num_bars: 0 - num_tokens: 2
constraint_checker: ['<s>', 'chord_pc_2']
num_bars: 0 - num_tokens: 2
constraint_checker: ['<s>', 'chord_pc_4']
num_bars: 0 - num_tokens: 2
constraint_checker: ['<s>', '<s>']
num_bars: 0 - num_tokens: 2
inconsistent
constraint_checker: ['<s>', 'chord_pc_7']
num_bars: 0 - num_tokens: 2
constraint_checker: ['<s>', 'chord_pc_5']
num_bars: 0 - num_tokens: 2
constraint_checker: ['<s>', 'chord_pc_6']
num_bars: 0 - num_tokens: 2
constraint_checker: ['<s>', '<h>', '<bar>']
num_bars: 1 - num_tokens: 3
constraint_checker: ['<s>', '<h>', 'position_3x75']
num_bars: 0 - num_tokens: 3
inconsistent
constraint_checker: ['<s>', '<h>', 'chord_pc_11']
num_bars: 0 - num_tokens: 3
constraint_checker: ['<s>', '<h>', 'position_2x00']
num_bars: 0 - num

In [40]:
# Shape of the decoded id tensor (batch, length).
print(generated_ids.size())

torch.Size([1, 84])


In [41]:
# How many hypotheses finished, and the token ids of the first one.
num_finished = len(finished)
print(num_finished)
first_finished = finished[0]
print(first_finished.tokens)

1
tensor([[  2,   7,   6,  98, 199, 203, 206, 114, 199, 203, 206,   6,  98, 201,
         204, 208, 114, 201, 206, 210,   6,  98, 199, 203, 206, 114, 199, 203,
         206,   6,  98, 201, 204, 208, 114, 201, 206, 210,   6,  98, 199, 203,
         206, 114, 199, 203, 206,   6,  98, 201, 204, 208, 114, 201, 206, 210,
           6,  98, 199, 203, 206, 114, 199, 203, 206,   6,  98, 201, 204, 208,
         114, 201, 206, 210,   6,  98, 199, 203, 206, 114, 199, 203, 206,   3]])


In [42]:
# Map the generated ids back to token strings. The list is rebuilt here rather
# than appended to a list created in another cell, so re-running this cell does
# not duplicate every token. Spaces inside token names become 'x'
# (e.g. 'position_0x00').
generated_tokens = [
    tokenizer.ids_to_tokens[int(token_id)].replace(' ', 'x')
    for token_id in generated_ids[0]
]
print(generated_tokens)

['<s>', '<h>', '<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', 'position_2x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', '<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', 'position_2x00', 'chord_pc_2', 'chord_pc_7', 'chord_pc_11', '<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', 'position_2x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', '<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', 'position_2x00', 'chord_pc_2', 'chord_pc_7', 'chord_pc_11', '<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', 'position_2x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', '<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', 'position_2x00', 'chord_pc_2', 'chord_pc_7', 'chord_pc_11', '<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', 'position_2x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', '<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', 'position_2x00', 'chord_p

In [43]:
# Print the decoded sequence one bar per line: a line is flushed right before
# every token containing 'bar' (e.g. '<bar>'), so leading tokens such as
# '<s>' and '<h>' land on their own first line.
t = generated_tokens
line = []
for i, token in enumerate(t):
    line.append(token)
    upcoming = t[i + 1] if i + 1 < len(t) else None
    if upcoming is not None and 'bar' in upcoming:
        print(line)
        line = []
print(line)

['<s>', '<h>']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', 'position_2x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', 'position_2x00', 'chord_pc_2', 'chord_pc_7', 'chord_pc_11']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', 'position_2x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', 'position_2x00', 'chord_pc_2', 'chord_pc_7', 'chord_pc_11']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', 'position_2x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', 'position_2x00', 'chord_pc_2', 'chord_pc_7', 'chord_pc_11']
['<bar>', 'position_0x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7', 'position_2x00', 'chord_pc_0', 'chord_pc_4', 'chord_pc_7']
['<bar>', 'position_0x00', 'chord_pc_2', 'chord_pc_5', 'chord_pc_9', 'position_2x00', 