In [1]:
# USER OPTIONS
# Harmony tokenizer name - must be one of the keys of the `tokenizers`
# dictionary in the cell further down: 'ChordSymbolTokenizer',
# 'RootTypeTokenizer', 'PitchClassTokenizer', 'RootPCTokenizer',
# 'GCTRootPCTokenizer', 'GCTSymbolTokenizer' or 'GCTRootTypeTokenizer'.
tokenizer_name = 'ChordSymbolTokenizer'
# tokenizer_name = 'RootTypeTokenizer'
# tokenizer_name = 'PitchClassTokenizer'

# Folder containing the XML files to evaluate on.
# Alternatives used on other machines:
#   '/media/maindisk/maximos/data/gjt_melodies/Library_melodies'
#   '/media/datadisk/datasets/gjt_melodies/Library_melodies'
#   '/media/maximos/9C33-6BBD/data/gjt_melodies/Library_melodies'
val_dir = '/media/maindisk/maximos/data/hooktheory_test'

# Batch size - choose depending on GPU availability / memory.
batchsize = 16
# Torch device name, e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1', ...
device_name = 'cuda'

In [2]:
from data_utils import SeparatedMelHarmMarkovDataset
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
from tqdm import tqdm
from models import TransGraphVAE
import csv

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Registry of harmony tokenizer classes keyed by name; `tokenizer_name` from
# the user options picks which class is loaded.
tokenizers = dict(
    ChordSymbolTokenizer=ChordSymbolTokenizer,
    RootTypeTokenizer=RootTypeTokenizer,
    PitchClassTokenizer=PitchClassTokenizer,
    RootPCTokenizer=RootPCTokenizer,
    GCTRootPCTokenizer=GCTRootPCTokenizer,
    GCTSymbolTokenizer=GCTSymbolTokenizer,
    GCTRootTypeTokenizer=GCTRootTypeTokenizer,
)

In [4]:
# Load the pretrained melody and harmony tokenizers and combine them into a
# single merged tokenizer.
melody_tokenizer = MelodyPitchTokenizer.from_pretrained('saved_tokenizers/MelodyPitchTokenizer')
harmony_tokenizer_cls = tokenizers[tokenizer_name]
harmony_tokenizer = harmony_tokenizer_cls.from_pretrained(f'saved_tokenizers/{tokenizer_name}')

tokenizer = MergedMelHarmTokenizer(melody_tokenizer, harmony_tokenizer)

# NOTE(review): `model_path` is not referenced by any later cell visible here,
# so no trained weights appear to be loaded - confirm whether a checkpoint
# should be restored before running inference.
model_path = f'saved_models/bart/{tokenizer_name}/{tokenizer_name}.pt'

In [5]:
# Special-token ids are taken from the merged tokenizer so the model and the
# tokenizer agree on padding / start / end markers.
bart_special_tokens = dict(
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
)

# Size / regularisation hyper-parameters of the BART backbone.
bart_architecture = dict(
    max_position_embeddings=512,
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    d_model=512,
    encoder_layerdrop=0.3,
    decoder_layerdrop=0.3,
    dropout=0.3,
)

bart_config = BartConfig(vocab_size=len(tokenizer.vocab), **bart_special_tokens, **bart_architecture)

# Built from a config only, i.e. no pretrained weights are loaded here.
bart = BartForConditionalGeneration(bart_config)

In [6]:
# Read the evaluation data from the folder configured in the USER OPTIONS cell
# instead of a second hard-coded path that silently overrode that option.
# NOTE(review): the previous run of this notebook used
# '/mnt/ssd2/maximos/data/hooktheory_test' - set `val_dir` to that path if it is
# the intended test set on this machine.
test_dir = val_dir
# max_length=512 matches max_position_embeddings of the BART config.
test_dataset = SeparatedMelHarmMarkovDataset(test_dir, tokenizer, max_length=512, num_bars=64)

def create_data_collator(tokenizer, model):
    """Build the sequence-to-sequence data collator used by the DataLoader.

    Args:
        tokenizer: tokenizer forwarded to ``DataCollatorForSeq2Seq``.
        model: model forwarded to ``DataCollatorForSeq2Seq``.

    Returns:
        A ``DataCollatorForSeq2Seq`` configured with ``padding=True``.
    """
    collator_kwargs = {'tokenizer': tokenizer, 'model': model, 'padding': True}
    return DataCollatorForSeq2Seq(**collator_kwargs)

# Honour the `device_name` chosen in the USER OPTIONS cell (it was previously
# ignored); fall back to the CPU when CUDA is not available.
device = torch.device(device_name if torch.cuda.is_available() else "cpu")

In [7]:
collator = create_data_collator(tokenizer, model=bart)
# Only the first sequence of a single batch is inspected during inference, so
# batch_size=1. Shuffling is seeded so the inspected example is reproducible
# across kernel restarts.
shuffle_generator = torch.Generator().manual_seed(42)
valloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collator,
    generator=shuffle_generator,
)

In [8]:
# Extra keyword arguments for TransGraphVAE (LSTM / GNN hidden sizes, latent and
# condition dimensions, attention switch).
config = dict(
    hidden_dim_LSTM=512,
    hidden_dim_GNN=256,
    latent_dim=512,
    condition_dim=256,
    use_attention=False,
)

# Wrap the BART backbone, move the whole model to the selected device and switch
# it to evaluation mode (disables dropout) for inference.
# NOTE(review): no state_dict is loaded anywhere before inference, so the
# weights are presumably random initialisations (the decoded tokens look
# untrained) - confirm which checkpoint should be restored here.
model = TransGraphVAE(transformer=bart, device=device, **config)
model.to(device)
model.eval()

TransGraphVAE(
  (transformer): BartForConditionalGeneration(
    (model): BartModel(
      (shared): BartScaledWordEmbedding(545, 512, padding_idx=1)
      (encoder): BartEncoder(
        (embed_tokens): BartScaledWordEmbedding(545, 512, padding_idx=1)
        (embed_positions): BartLearnedPositionalEmbedding(514, 512)
        (layers): ModuleList(
          (0-7): 8 x BartEncoderLayer(
            (self_attn): BartSdpaAttention(
              (k_proj): Linear(in_features=512, out_features=512, bias=True)
              (v_proj): Linear(in_features=512, out_features=512, bias=True)
              (q_proj): Linear(in_features=512, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=512, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=512, out_features=512, bias=True)
            (fc2): Linear(in_feat

In [9]:
def decode_token_ids(token_ids, id_to_token):
    """Convert a 1-D sequence of token ids into readable token strings.

    Args:
        token_ids: iterable of ids (ints or 0-d tensors; each is cast with int()).
        id_to_token: mapping from integer id to token string.

    Returns:
        List of token strings with spaces replaced by 'x'.
    """
    return [id_to_token[int(token_id)].replace(' ', 'x') for token_id in token_ids]


# Define the outputs up front so later cells still find these names even if the
# loader yields no batch.
output_tokens = []
output_recon_tokens = []

with torch.no_grad():
    with tqdm(valloader, unit='batch') as tepoch:
        tepoch.set_description('Running')
        for batch in tepoch:
            input_ids = batch['input_ids'].to(device)
            transitions = batch['transitions'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            print(input_ids)
            print(transitions)
            outputs = model(input_ids, transitions, encoder_attention=attention_mask, generate_max_tokens=500)
            # Decode the first sequence of the batch for both generation heads.
            output_tokens = decode_token_ids(outputs['generated_ids'][0], tokenizer.ids_to_tokens)
            output_recon_tokens = decode_token_ids(outputs['generated_recon_ids'][0], tokenizer.ids_to_tokens)
            # Only a single example is inspected.
            break

  return self.iter().getElementsByClass(classFilterList)
  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


Running:   0%|          | 0/1520 [00:00<?, ?batch/s]
tensor([[  2,   6, 180,  95,   4, 123,  46,   6,  95,  43, 103,  46, 111,  50,
         119,  48, 123,  46,   6,  95,  46, 103,  48, 107,  46, 111,  50, 119,
          46,   6,  95,  43,  99,  46, 107,  50, 119,  48, 123,  46,   6,  95,
          48,  99,  46, 103,  46, 107,  48, 111,  50, 119,  46,   6,  95,  43,
          99,  46, 107,  50, 119,  48, 123,  46,   6,  95,  46, 103,  46, 107,
          48, 111,  50, 123,   4,   6,  95,  51,  99,  50, 103,  51, 107,  50,
         111,  51, 119,  51, 123,  50,   6,  95,  48, 111,   4,   6,  95,  51,
         103,  51, 111,  51, 119,  53, 123,  50,   6,  95,  48, 111,   4, 119,
          48, 123,  50,   6,  95,  51,  99,  51, 107,  51, 111,  53, 119,  43,
           6,  95,  46,   6,  95,   4,   6,  95,   4, 123,  50,   6,  95,  51,
         103,  50, 107,  51, 115,  50, 119,  53, 123,  46,   6,  95,  46, 111,
           4]], device='cuda:0')
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
    

Running:   0%|          | 0/1520 [00:01<?, ?batch/s]


In [11]:
# Decoded tokens from outputs['generated_ids'] of the inspected example.
print(str(output_tokens))

['<s>', 'F:min7', 'E:min', 'C:minmaj7', 'P:21', 'P:54', 'C#:sus4', 'G:dim', 'A:min11', 'F:dim7', 'G#:dim7', 'F#:11', 'G#:5', 'A#:aug', 'E:dim', 'P:100', 'F#:7', 'G#:maj7', 'ts_9x8', 'position_9x50', 'ts_5x4', 'D:dim7', 'position_3x33', 'A:min6', 'A#:', 'P:24', 'F:7(#9)', 'position_1x66', 'D:7(#9)', 'E:sus4', 'F#:13', 'A:min13', 'position_3x83', 'F:maj', 'C#:7(#11)', 'position_3x66', 'A#:maj7', 'position_4x66', 'F:min9', 'G#:maj13', 'A:maj6', 'F#:7(b9)', 'A:dim7', 'G:hdim7', 'D#:dim', 'ts_13x8', 'G:sus2', 'F:dim', 'P:91', 'F#:7(b9)', 'D:min9', 'P:40', 'A#:hdim7', 'C:maj7', 'F:13', 'P:21', 'G:minmaj7', 'position_1x66', 'P:96', 'position_7x25', 'position_2x83', 'F:min6', 'P:96', 'C#:7(b13)', 'A#:1', 'P:54', 'F#:7(b13)', 'F#:', 'F#:maj6', 'A#:maj9', 'F#:13', 'P:88', 'position_4x83', 'position_0x33', 'ts_5x8', '<rest>', 'P:100', 'P:103', 'position_9x16', 'P:107', 'C:min9', 'F:aug', 'A:', 'G:5', 'ts_8x4', 'F#:min13', 'C#:7(#9)', 'A#:11', 'D#:11', 'P:52', 'C:5', 'A#:13', 'F:7(#11)', 'P:93', '

In [13]:
# Decoded tokens from outputs['generated_recon_ids'] of the inspected example.
print(str(output_recon_tokens))

['<s>', 'F:7', 'P:64', 'position_8x16', 'G#:7(b9)', 'F#:min9', 'P:42', 'position_2x66', 'A#:dim7', 'P:38', 'ts_6x4', 'D#:9', 'position_5x66', 'position_5x25', 'D:maj9', 'C#:dim7', 'F#:sus4', 'position_4x16', 'F#:min11', 'G:minmaj7', 'P:104', 'ts_5x4', 'D:13', 'position_2x75', 'C:min', 'G#:maj6', '<rest>', 'ts_5x4', '<pad>', 'P:49', 'F:sus2', 'F#:min', 'A#:5', 'C:maj6', 'G:min6', 'G#:minmaj7', 'position_4x25', 'D:7(b13)', 'D#:7', 'E:min13', 'position_6x25', 'A:hdim7', 'C#:1', 'G:7(#11)', 'position_2x75', 'F:11', 'G#:dim7', 'F#:maj6', 'A#:7', 'D#:7(b13)', 'P:91', 'P:99', 'F#:maj9', 'A:1', 'D:7(#9)', 'P:22', 'P:71', 'G:13', 'P:95', 'C#:minmaj7', 'G:dim', 'position_6x00', 'E:5', 'E:maj13', 'F#:min9', 'G#:dim7', 'P:72', 'G#:minmaj7', 'D:1', '<h>', 'position_9x50', 'position_2x16', 'D#:7(b9)', 'C#:dim', 'A:7', 'E:sus2', 'G#:hdim7', 'P:49', 'position_4x50', 'position_7x75', 'C:min', 'D#:sus4', 'P:24', 'E:aug', 'D:7(#9)', 'D#:sus4', 'C#:aug', 'G#:min9', 'A#:aug', 'position_8x75', 'position_8x7