In [7]:
# USER OPTIONS
# Name of the harmony tokenizer to load - must be one of the keys of the
# `tokenizers` dictionary defined in a later cell.
tokenizer_name = 'ChordSymbolTokenizer' # or any other name from the keys in tokenizers dictionary
# tokenizer_name = 'RootTypeTokenizer'
# tokenizer_name = 'PitchClassTokenizer'
# Folder containing the XML files.
# NOTE(review): the dataset cell further down builds the dataset from its own
# hard-coded `test_dir` and does not appear to read `val_dir` - confirm which
# path is intended.
val_dir = '/media/maindisk/maximos/data/hooktheory_test'
# val_dir = '/media/maindisk/maximos/data/gjt_melodies/Library_melodies'
# val_dir = '/media/datadisk/datasets/gjt_melodies/Library_melodies'
# val_dir = '/media/maximos/9C33-6BBD/data/gjt_melodies/Library_melodies'

# Batch size; choose it according to GPU availability / free memory.
# NOTE(review): the DataLoader cell further down uses batch_size=1 - confirm
# whether this option is still meant to be used.
batchsize = 16
# Device name - could be 'cpu', 'cuda', 'cuda:0', 'cuda:1'...
device_name = 'cuda'

In [8]:
from data_utils import SeparatedMelHarmMarkovDataset
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
from tqdm import tqdm
from models import TransGraphVAE
import csv

In [9]:
# Registry mapping each selectable harmony tokenizer name (the value given to
# `tokenizer_name`) to the tokenizer class that implements it.
tokenizers = dict(
    ChordSymbolTokenizer=ChordSymbolTokenizer,
    RootTypeTokenizer=RootTypeTokenizer,
    PitchClassTokenizer=PitchClassTokenizer,
    RootPCTokenizer=RootPCTokenizer,
    GCTRootPCTokenizer=GCTRootPCTokenizer,
    GCTSymbolTokenizer=GCTSymbolTokenizer,
    GCTRootTypeTokenizer=GCTRootTypeTokenizer,
)

In [10]:
# Load the saved melody tokenizer, then the harmony tokenizer selected through
# `tokenizer_name`, and merge both into the single tokenizer used below.
melody_tokenizer = MelodyPitchTokenizer.from_pretrained(
    'saved_tokenizers/MelodyPitchTokenizer'
)
harmony_tokenizer = tokenizers[tokenizer_name].from_pretrained(
    f'saved_tokenizers/{tokenizer_name}'
)

tokenizer = MergedMelHarmTokenizer(melody_tokenizer, harmony_tokenizer)

# Checkpoint path layout: saved_models/bart/<tokenizer_name>/<tokenizer_name>.pt
model_path = f'saved_models/bart/{tokenizer_name}/{tokenizer_name}.pt'

In [11]:
# Resolve the compute device from the `device_name` user option instead of
# hard-coding "cuda", so that choices such as 'cpu' or 'cuda:1' are honoured.
# If a CUDA device is requested but CUDA is unavailable, fall back to the CPU.
if device_name.startswith('cuda') and not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    device = torch.device(device_name)

In [12]:
# BART architecture used to rebuild the network before loading its weights.
# NOTE(review): these values presumably mirror the configuration used when the
# checkpoint at `model_path` was trained - load_state_dict() below is strict,
# so a mismatch in keys or shapes raises an error.
bart_config = BartConfig(
    vocab_size=len(tokenizer.vocab),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    max_position_embeddings=512,
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    d_model=512,
    encoder_layerdrop=0.3,
    decoder_layerdrop=0.3,
    dropout=0.3
)

bart = BartForConditionalGeneration(bart_config)

# Always deserialize the checkpoint onto the CPU. Without `map_location`,
# torch.load tries to restore every tensor on the device it was saved from,
# which fails when that device (e.g. 'cuda:1') does not exist on this machine.
# load_state_dict() copies the values into the model's own parameters, so the
# location of the loaded tensors does not affect the result.
checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
bart.load_state_dict(checkpoint)

<All keys matched successfully>

In [13]:
# Folder holding the test pieces, and the dataset built from it with the
# merged melody/harmony tokenizer.
test_dir = '/mnt/ssd2/maximos/data/hooktheory_test'
test_dataset = SeparatedMelHarmMarkovDataset(
    test_dir,
    tokenizer,
    max_length=512,
    num_bars=64,
)

# Data collator for BART
def create_data_collator(tokenizer, model):
    """Return a DataCollatorForSeq2Seq for the given tokenizer and model.

    Args:
        tokenizer: tokenizer handed to the collator.
        model: model handed to the collator.

    Returns:
        A DataCollatorForSeq2Seq created with padding=True.
    """
    collator_kwargs = {
        'tokenizer': tokenizer,
        'model': model,
        'padding': True,
    }
    return DataCollatorForSeq2Seq(**collator_kwargs)
# end create_data_collator

In [14]:
collator = create_data_collator(tokenizer, model=bart)
# One piece per batch, drawn in shuffled order. The shuffling RNG is seeded
# explicitly so that the sample inspected by the following cells is the same
# after every "Restart & Run All"; change the seed to look at another piece.
valloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collator,
    generator=torch.Generator().manual_seed(0),
)

In [15]:
# Hyper-parameters forwarded as keyword arguments to TransGraphVAE.
config = dict(
    hidden_dim_LSTM=512,
    hidden_dim_GNN=256,
    latent_dim=512,
    condition_dim=256,
    use_attention=False,
)

# Wrap the loaded BART model in the VAE, move it to the selected device and
# switch it to evaluation mode (the last expression displays the module).
model = TransGraphVAE(
    transformer=bart,
    device=device,
    tokenizer=tokenizer,
    **config
)
model.to(device)
model.eval()

TransGraphVAE(
  (transformer): BartForConditionalGeneration(
    (model): BartModel(
      (shared): BartScaledWordEmbedding(545, 512, padding_idx=1)
      (encoder): BartEncoder(
        (embed_tokens): BartScaledWordEmbedding(545, 512, padding_idx=1)
        (embed_positions): BartLearnedPositionalEmbedding(514, 512)
        (layers): ModuleList(
          (0-7): 8 x BartEncoderLayer(
            (self_attn): BartSdpaAttention(
              (k_proj): Linear(in_features=512, out_features=512, bias=True)
              (v_proj): Linear(in_features=512, out_features=512, bias=True)
              (q_proj): Linear(in_features=512, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=512, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=512, out_features=512, bias=True)
            (fc2): Linear(in_feat

In [16]:
# Run the model on the first batch of `valloader` only (the loop is left right
# after it) and turn both returned id sequences into token strings.
with torch.no_grad(), tqdm(valloader, unit='batch') as tepoch:
    tepoch.set_description('Running')
    print(tepoch)
    for batch in tepoch:
        input_ids = batch['input_ids'].to(device)
        transitions = batch['transitions'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        output_tokens = []
        output_recon_tokens = []
        print(input_ids)
        print(transitions)
        outputs = model(
            input_ids,
            transitions,
            encoder_attention=attention_mask,
            generate_max_tokens=500,
        )
        # Only the first sequence of the batch is decoded; blanks inside a
        # token string are replaced by 'x'.
        output_tokens.extend(
            tokenizer.ids_to_tokens[int(token_id)].replace(' ', 'x')
            for token_id in outputs['generated_ids'][0]
        )
        output_recon_tokens.extend(
            tokenizer.ids_to_tokens[int(token_id)].replace(' ', 'x')
            for token_id in outputs['generated_recon_ids'][0]
        )
        break

  return self.iter().getElementsByClass(classFilterList)
  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


Running:   0%|          | 0/1520 [00:00<?, ?batch/s]
tensor([[  2,   6, 180,  95,  46, 103,   4, 119,  46, 123,  48,   6,  95,  48,
          99,  45, 111,   4, 119,  41, 123,  41,   6,  95,  46,  99,  46, 103,
          46, 107,  46, 111,  46, 115,  46, 119,  41, 123,  43,   6,  95,  43,
         103,   4, 111,  46, 115,  46, 119,  48, 123,  46,   6,  95,  46, 107,
           4, 119,  46, 123,  48,   6,  95,  48,  99,  45, 111,   4, 119,  41,
         123,  41,   6,  95,  46,  99,  46, 103,  46, 107,  46, 111,  46, 115,
          46, 119,  41, 123,  43,   6,  95,  43, 103,   4, 111,  46, 115,  46,
         119,  46, 123,  46,   6,  95,  46, 103,  41, 107,  41, 119,  48, 123,
          48,   6,  95,  48, 103,  41, 107,  41, 119,  50, 123,  50,   6,  95,
          50,  99,  48, 101,  46, 103,  46, 107,  46, 119,  51, 123,  51,   6,
          95,  51, 103,   4, 111,  50, 115,  48, 119,  46, 123,  46,   6,  95,
          46, 107,   4, 119,  46, 123,  48,   6,  95,  48,  99,  45, 111,   4,

Running:   0%|          | 0/1520 [00:02<?, ?batch/s]


In [17]:
# Token strings decoded from outputs['generated_ids'] of the inspected batch.
print(output_tokens)

['<h>', 'G#:min', 'P:99', 'position_9x83', 'position_8x00', 'F:min13', 'E:maj', 'P:85', 'C#:7(#9)', 'C:1', 'E:maj9', 'G#:minmaj7', 'C:minmaj7', 'P:69', 'position_6x66', 'F:hdim7', 'F#:dim', 'C#:maj6', 'P:85', 'A#:dim7', 'E:maj', 'P:49', 'F#:9', 'C:9', 'A:maj13', 'C:min', 'P:85', 'position_9x16', 'position_5x25', 'position_5x00', 'P:85', 'D#:maj6', 'F:min13', 'position_6x66', 'P:62', 'G:7(#9)', '<rest>', 'position_8x83', 'P:96', 'position_9x83', 'A:aug', 'E:hdim7', 'G:7(b9)', 'G#:minmaj7', 'A#:maj7', 'G#:1', 'C#:min6', 'A:maj7', 'D#:aug', 'G:maj9', 'P:108', 'G:7(#11)', 'F#:minmaj7', 'E:min9', 'position_3x83', 'A:maj7', 'A:min6', 'position_1x16', 'C:min6', 'P:92', 'P:43', 'C#:min6', 'P:53', 'P:84', 'position_8x75', 'P:52', 'C#:min6', 'C#:7(#11)', 'position_6x16', 'G:aug', 'P:60', 'position_4x33', 'G#:dim7', 'P:25', 'D:7', 'position_0x00', 'ts_6x4', 'D#:13', 'C#:dim', 'E:11', 'G:13', 'F#:min7', 'G:9', 'G#:min', 'C:minmaj7', 'G:sus2', 'P:86', 'G:13', 'F#:sus4', 'F#:min7', 'D#:hdim7', 'posi

In [18]:
# Token strings decoded from outputs['generated_recon_ids'] of the inspected batch.
print(output_recon_tokens)

['<h>', 'P:87', 'C#:7', 'E:minmaj7', 'position_0x50', 'E:maj13', 'D#:min13', 'A:dim', 'position_3x75', 'D:7(b13)', 'G#:7(#11)', 'P:92', 'P:103', 'P:57', 'position_0x25', 'position_8x83', 'D:13', 'position_9x83', 'D#:7(b13)', 'position_1x16', 'position_6x66', 'C:dim', 'ts_9x4', 'G:min9', 'A#:5', 'F:5', 'F#:maj7', 'ts_1x8', 'A#:maj', 'A#:minmaj7', 'D:11', 'C:dim', 'D:maj7', 'D#:7', 'C:sus2', 'position_9x83', 'D:min6', 'P:69', 'P:43', 'A#:5', 'C:', 'P:64', 'G#:7(b13)', 'P:77', 'A#:sus2', 'D#:min', 'position_2x66', 'position_0x75', 'D:13', 'F#:aug', 'F:hdim7', 'D#:maj6', 'position_9x66', 'F#:7(#9)', 'position_8x16', 'position_9x83', 'P:53', 'C#:min7', 'G:hdim7', 'P:38', 'P:108', 'position_2x66', 'position_3x33', 'C:sus2', 'D:aug', 'E:minmaj7', 'D:dim', 'C:minmaj7', 'F#:min', 'C#:hdim7', 'P:85', 'P:39', 'F:hdim7', 'P:59', 'D#:min7', 'E:maj9', 'E:7(#9)', 'P:72', 'position_6x66', 'F:hdim7', 'G#:7(#11)', 'P:87', 'C#:maj6', 'D:aug', 'D#:aug', 'G:hdim7', 'A#:5', 'position_9x83', 'G#:maj9', 'P:25

In [19]:
# Show the input ids of the inspected batch, followed by the token strings of
# its first sequence (blanks inside a token are replaced by 'x').
input_ids = batch['input_ids'].to(device)
print(input_ids)
input_tokens = [
    tokenizer.ids_to_tokens[int(token_id)].replace(' ', 'x')
    for token_id in input_ids[0]
]
print(input_tokens)

tensor([[  2,   6, 180,  95,  46, 103,   4, 119,  46, 123,  48,   6,  95,  48,
          99,  45, 111,   4, 119,  41, 123,  41,   6,  95,  46,  99,  46, 103,
          46, 107,  46, 111,  46, 115,  46, 119,  41, 123,  43,   6,  95,  43,
         103,   4, 111,  46, 115,  46, 119,  48, 123,  46,   6,  95,  46, 107,
           4, 119,  46, 123,  48,   6,  95,  48,  99,  45, 111,   4, 119,  41,
         123,  41,   6,  95,  46,  99,  46, 103,  46, 107,  46, 111,  46, 115,
          46, 119,  41, 123,  43,   6,  95,  43, 103,   4, 111,  46, 115,  46,
         119,  46, 123,  46,   6,  95,  46, 103,  41, 107,  41, 119,  48, 123,
          48,   6,  95,  48, 103,  41, 107,  41, 119,  50, 123,  50,   6,  95,
          50,  99,  48, 101,  46, 103,  46, 107,  46, 119,  51, 123,  51,   6,
          95,  51, 103,   4, 111,  50, 115,  48, 119,  46, 123,  46,   6,  95,
          46, 107,   4, 119,  46, 123,  48,   6,  95,  48,  99,  45, 111,   4,
         119,  41, 123,  41,   6,  95,  50,  99,  50

In [20]:
# Sample a harmonization from the plain BART model for the same input batch.
# The attention mask produced by the collator is passed explicitly so that
# padded positions are ignored by the encoder should a batch ever contain
# sequences of different lengths.
outputs = bart.generate(
    input_ids=input_ids,
    attention_mask=batch['attention_mask'].to(device),
    bos_token_id=tokenizer.vocab[tokenizer.harmony_tokenizer.start_harmony_token],
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=500,
    do_sample=True,
    temperature=1.0
)

In [21]:
# Show the ids generated by BART alone, followed by the token strings of the
# first generated sequence (blanks inside a token are replaced by 'x').
print(outputs)
bart_only_outputs = [
    tokenizer.ids_to_tokens[int(token_id)].replace(' ', 'x')
    for token_id in outputs[0]
]
print(bart_only_outputs)

tensor([[  2, 196,   6,  95, 197,   6,  95, 342,   6,  95, 197,   6,  95, 400,
           6,  95, 197,   6,  95, 349,   6,  95, 197,   6,  95, 342,   6,  95,
         197,   6,  95, 342,   6,  95, 197,   6,  95, 400,   6,  95, 197,   6,
          95, 349,   6,  95, 197,   6,  95, 342, 111, 400,   3]],
       device='cuda:0')
['<s>', '<h>', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'F:maj', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'F:maj7', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'F:maj', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'F:maj', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'F:maj7', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'F:maj', 'position_2x00', 'G:maj', '</s>']
