In [3]:
# USER OPTIONS
# define tokenizer name - must be one of the keys of the `tokenizers` dictionary defined in a later cell
tokenizer_name = 'ChordSymbolTokenizer' # or any other name from the keys in tokenizers dictionary
# tokenizer_name = 'RootTypeTokenizer'
# tokenizer_name = 'PitchClassTokenizer'
# folder to xmls
# NOTE(review): the dataset cell further down builds its own `test_dir` and does not read `val_dir` - confirm which path is intended
val_dir = '/media/maindisk/maximos/data/hooktheory_test'
# val_dir = '/media/maindisk/maximos/data/gjt_melodies/Library_melodies'
# val_dir = '/media/datadisk/datasets/gjt_melodies/Library_melodies'
# val_dir = '/media/maximos/9C33-6BBD/data/gjt_melodies/Library_melodies'

# define batch size depending on GPU availability / status
# NOTE(review): the DataLoader cell further down hard-codes batch_size=1 and does not read `batchsize`
batchsize = 16
# select device name - could be 'cpu', 'cuda', 'cuda:0', 'cuda:1'...
device_name = 'cuda'

In [4]:
import pickle
# Load the pre-computed idiom data from two local pickle files.
# NOTE: pickle.load can execute arbitrary code contained in the file - only load pickles from a trusted source.
# `idioms` is indexed in later cells by keys of the form '<idiom name>_[<numbers>]'.
with open('data/idioms_mir_quick_reference.pickle', 'rb') as f:
    idioms = pickle.load(f)
# `info_idioms` is indexed in later cells as info_idioms[<idiom name>]['[<numbers>]']['chords'].
with open('data/idioms_mir_full_info.pickle', 'rb') as f:
    info_idioms = pickle.load(f)

In [5]:
# Pick individual entries out of `idioms`; each key is '<idiom name>_[<numbers>]'
# (the bracketed numbers are presumably the pitch classes of the scale - TODO confirm).
bc_major = idioms['BachChorales_[0 2 4 5 7 9 11]']
bc_minor = idioms['BachChorales_[0 2 3 5 7 8 10]']
jazz_maj = idioms['Jazz_[0 2 4 5 7 9 11]']
# NOTE(review): this key has 9 where the other two minor keys have 8 - confirm it is the intended entry.
jazz_min = idioms['Jazz_[0 2 3 5 7 9 10]']
organum = idioms['organum_[0 2 3 5 7 8 10]']

In [6]:
# Show the 'match_symbol' entry of every chord recorded for the organum idiom.
organum_chords = info_idioms['organum']['[0 2 3 5 7 8 10]']['chords']
for chord_info in organum_chords.values():
    print(chord_info['match_symbol'])

A:5
C:5
D:5
E:5
D#:aug
A:1
G:1
F:5
D:1
C:1
D:1
E:1
G:5
B:1
D:min
B:5
C:maj


In [7]:
from data_utils import SeparatedMelHarmMarkovDataset
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
from tqdm import tqdm
from models import TransGraphVAE
import csv
from transformers import LogitsProcessor, StoppingCriteria, StoppingCriteriaList

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Maps every selectable harmony tokenizer name to its class;
# `tokenizer_name` from the user options is looked up in this dictionary.
tokenizers = {
    'ChordSymbolTokenizer': ChordSymbolTokenizer,
    'RootTypeTokenizer': RootTypeTokenizer,
    'PitchClassTokenizer': PitchClassTokenizer,
    'RootPCTokenizer': RootPCTokenizer,
    'GCTRootPCTokenizer': GCTRootPCTokenizer,
    'GCTSymbolTokenizer': GCTSymbolTokenizer,
    'GCTRootTypeTokenizer': GCTRootTypeTokenizer
}

In [9]:
# Load the saved melody tokenizer and the harmony tokenizer selected by `tokenizer_name`.
melody_tokenizer = MelodyPitchTokenizer.from_pretrained('saved_tokenizers/MelodyPitchTokenizer')
harmony_tokenizer = tokenizers[tokenizer_name].from_pretrained('saved_tokenizers/' + tokenizer_name)

# Wrap both in the merged tokenizer that the rest of the notebook uses.
tokenizer = MergedMelHarmTokenizer(melody_tokenizer, harmony_tokenizer)

# Path of the BART checkpoint that belongs to the selected tokenizer.
bart_path = 'saved_models/bart/' + tokenizer_name + '/' + tokenizer_name + '.pt'

In [10]:
# Use the device chosen in the user options (`device_name`, e.g. 'cpu', 'cuda', 'cuda:1');
# fall back to the CPU when CUDA is not available on this machine.
device = torch.device(device_name if torch.cuda.is_available() else "cpu")

In [11]:
# chord symbol tokenizer blues
toks = ['C:7', 'F:7','C:7','C:7', \
    'F:7','F:7','C:7','C:7',\
    'G:7','F:7','C:7','G:7']
blues_maj = tokenizer.make_markov_from_tokens_list(toks)
print(np.nonzero(blues_maj))

(array([  6,   6,   6, 151, 151, 209]), array([  6, 151, 209,   6, 151, 151]))


In [12]:
# Special-token ids are taken from the merged tokenizer.
special_token_ids = dict(
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
)

# Size and regularisation settings of the encoder/decoder stacks.
architecture = dict(
    max_position_embeddings=512,
    d_model=512,
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    encoder_layerdrop=0.3,
    decoder_layerdrop=0.3,
    dropout=0.3,
)

bart_config = BartConfig(
    vocab_size=len(tokenizer.vocab),
    **special_token_ids,
    **architecture,
)

bart = BartForConditionalGeneration(bart_config)

# Without CUDA the tensors are remapped to the CPU; with CUDA the default
# (map_location=None) is used, exactly as when the argument is omitted.
checkpoint = torch.load(
    bart_path,
    map_location=None if torch.cuda.is_available() else "cpu",
    weights_only=True,
)
bart.load_state_dict(checkpoint)

  return self.fget.__get__(instance, owner)()


<All keys matched successfully>

In [13]:
# Fetch the encoder and decoder sub-modules in two ways:
# by direct attribute access and through the model's getter methods.
enc1 = bart.model.encoder
enc2 = bart.get_encoder()
dec1 = bart.model.decoder
dec2 = bart.get_decoder()

In [14]:
# Check that both ways of reaching the decoder compare equal (the result is shown as the cell output).
dec1 == dec2

True

In [15]:
# Directory holding the test pieces.
# NOTE(review): this hard-coded path differs from the `val_dir` user option, which is not read here - confirm which one is intended.
test_dir = '/mnt/ssd2/maximos/data/hooktheory_test'
# The meaning of max_length / num_bars is defined by SeparatedMelHarmMarkovDataset (not visible in this notebook).
test_dataset = SeparatedMelHarmMarkovDataset(test_dir, tokenizer, max_length=512, num_bars=64)

# Data collator for BART
def create_data_collator(tokenizer, model):
    """Return a DataCollatorForSeq2Seq for the given tokenizer and model, with padding enabled."""
    seq2seq_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding=True,
    )
    return seq2seq_collator
# end create_data_collator

In [16]:
# Collator bound to the merged tokenizer and the BART model.
collator = create_data_collator(tokenizer, model=bart)
# One piece per batch, drawn in random order (shuffle=True), so the piece picked below changes between runs.
# NOTE(review): batch_size is fixed to 1 here; the `batchsize` user option is not used.
valloader = DataLoader(test_dataset, batch_size=1, shuffle=True, collate_fn=collator)

In [17]:
# Checkpoint of the full model for the selected tokenizer.
model_path = 'saved_models/bart_cvae/' + tokenizer_name + '/' + tokenizer_name + '.pt'

# Keyword arguments forwarded to the TransGraphVAE constructor.
config = {
    'hidden_dim_LSTM': 1024,
    'hidden_dim_GNN': 1024,
    'latent_dim': 1024,
    'condition_dim': 1024,
    'use_attention': False
}

model = TransGraphVAE(transformer=bart, device=device, tokenizer=tokenizer, **config)

# Without CUDA the tensors are remapped to the CPU; with CUDA the default
# (map_location=None) is used, exactly as when the argument is omitted.
checkpoint = torch.load(
    model_path,
    map_location=None if torch.cuda.is_available() else "cpu",
    weights_only=True,
)
model.load_state_dict(checkpoint)

# Move the model to the selected device and switch it to evaluation mode.
model.to(device)
model.eval()

TransGraphVAE(
  (transformer): BartForConditionalGeneration(
    (model): BartModel(
      (shared): BartScaledWordEmbedding(545, 512, padding_idx=1)
      (encoder): BartEncoder(
        (embed_tokens): BartScaledWordEmbedding(545, 512, padding_idx=1)
        (embed_positions): BartLearnedPositionalEmbedding(514, 512)
        (layers): ModuleList(
          (0-7): 8 x BartEncoderLayer(
            (self_attn): BartSdpaAttention(
              (k_proj): Linear(in_features=512, out_features=512, bias=True)
              (v_proj): Linear(in_features=512, out_features=512, bias=True)
              (q_proj): Linear(in_features=512, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=512, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=512, out_features=512, bias=True)
            (fc2): Linear(in_feat

In [18]:
# Draw a single batch from the shuffled loader; the following cells all work on this one batch.
batch = next(iter(valloader))

  return self.iter().getElementsByClass(classFilterList)
  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


In [19]:
# Run the model once on the drawn batch, conditioned on a chosen transition matrix,
# and turn both returned id sequences into token strings.
with torch.no_grad():
    input_ids = batch['input_ids'].to(device)
    # Alternative conditioning matrices - swap the active `transitions` line for one of these:
    # transitions = batch['transitions'].to(device)
    # transitions = torch.FloatTensor(bc_major).reshape(1,bc_major.shape[0], bc_major.shape[1]).to(device)
    # transitions = torch.FloatTensor(bc_minor).reshape(1,bc_minor.shape[0], bc_minor.shape[1]).to(device)
    # transitions = torch.FloatTensor(organum).reshape(1,organum.shape[0], organum.shape[1]).to(device)
    # transitions = torch.FloatTensor(jazz_maj).reshape(1,jazz_maj.shape[0], jazz_maj.shape[1]).to(device)
    # transitions = torch.FloatTensor(jazz_min).reshape(1,jazz_min.shape[0], jazz_min.shape[1]).to(device)
    # Add a leading batch dimension of 1. The target shape is read from `blues_maj` itself so the
    # reshape stays valid regardless of the shape of any other idiom matrix.
    transitions = torch.FloatTensor(blues_maj).reshape(1, blues_maj.shape[0], blues_maj.shape[1]).to(device)
    # attention_mask = batch['attention_mask'].to(device)
    attention_mask = None
    # Number of '<bar>' tokens in each input row, kept as a column of shape (batch, 1).
    num_bars = (input_ids == tokenizer.vocab['<bar>']).sum(dim=1).reshape(input_ids.shape[0],-1)
    print(input_ids)
    print(transitions)
    print(num_bars.shape)
    outputs = model(input_ids, transitions, encoder_attention=attention_mask, generate_max_tokens=500, num_bars=num_bars, temperature=1.0)
    # Convert the ids of the first (and only) sequence to token strings; spaces become 'x'.
    output_tokens = [
        tokenizer.ids_to_tokens[ int(tok_id) ].replace(' ','x')
        for tok_id in outputs['generated_ids'][0]
    ]
    output_recon_tokens = [
        tokenizer.ids_to_tokens[ int(tok_id) ].replace(' ','x')
        for tok_id in outputs['generated_recon_ids'][0]
    ]

tensor([[  2,   6, 180,  95,   4, 111,   4, 115,  50, 119,  50, 123,  48,   6,
          95,  50, 103,  50, 111,  50, 115,  50, 119,  50, 123,  50,   6,  95,
          48,  99,  48, 103,  48, 107,  48, 111,  48, 115,  46, 119,  45,   6,
          95,  46, 103,  46, 107,  50, 111,  50, 115,  48, 119,  46, 123,  45,
           6,  95,  43, 103,   4, 115,  43, 119,  43, 123,  50,   6,  95,  48,
         103,  48, 111,  48, 115,  48, 119,  48, 123,  48,   6,  95,  48,  99,
          48, 103,  48, 107,  48, 111,  48, 119,  46,   6,  95,  46, 103,   4,
         107,  50, 111,  50, 115,  48, 119,  46, 123,  45,   6,  95,  43, 103,
           4, 115,  43, 119,  43, 123,  50,   6,  95,  48, 103,  48, 111,  48,
         115,  48, 119,  48, 123,  48,   6,  95,  48,  99,  48, 103,  48, 107,
          48, 111,  48, 119,  48,   6,  95,  50, 103,   4, 107,  50, 111,  50,
         115,  48, 119,  46, 123,  45,   6,  95,  43, 103,   4, 115,  43, 119,
          43, 123,  50,   6,  95,  48, 103,  48, 111

In [20]:
# Shape of the conditioning matrix after the batch dimension was added.
print(transitions.shape)

torch.Size([1, 348, 348])


In [21]:
# Token strings decoded from outputs['generated_ids'].
print(output_tokens)

['<s>', '<h>', '<bar>', 'position_0x00', 'F:maj', '<bar>', 'position_0x00', 'E:sus4', 'position_2x00', 'D:min', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'F:maj', '<bar>', 'position_0x00', 'E:sus4', 'position_2x00', 'D:min', '<bar>', 'position_0x00', 'G:maj', 'position_2x00', 'D:min7', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'F:maj', '<bar>', 'position_0x00', 'D:min7', '<bar>', 'position_0x00', 'F:maj', '<bar>', 'position_0x00', 'C:maj', 'position_2x00', 'D:min', 'position_3x00', 'C:maj', 'position_3x25', 'C:maj7', 'position_3x50', 'C:maj7', 'position_3x75', 'A#:maj13', 'position_0x75', 'A:maj', 'position_0x75', 'A:maj', 'position_1x50', 'C#:maj7', 'position_2x00', 'F:maj9', 'position_2x75', 'F:maj9', 'position_3x50', 'F:maj', 'position_3x75', 'F:maj', 'position_3x75', 'F:ma

In [22]:
# Token strings decoded from outputs['generated_recon_ids'].
print(output_recon_tokens)

['<s>', '<h>', 'A:maj', 'A:maj', 'A:maj', 'A:maj', 'position_5x50', 'A:maj', 'position_4x00', 'A:maj', 'C#:maj7', 'A:maj', 'G#:maj', 'A:maj', 'A:maj', 'A:maj', 'A:maj9', 'A:maj', 'A:maj', 'A:maj', 'A:maj', 'A:maj7', 'C#:maj', 'G#:sus4', 'A:maj', 'A:maj', 'A:maj', 'A:maj', 'A:maj', 'A:maj7', '<bar>', 'A:maj', 'A:maj', 'D:maj', 'A:maj7', 'A:maj', 'G:5', 'position_4x00', 'A:maj', 'A:maj7', 'A:maj7', 'A:maj', 'D:maj', 'D#:maj7', 'G:maj', 'B:min9', 'C#:maj7', 'D#:sus2', 'A:maj', 'A:maj', 'A:maj', 'D:min6', 'A:maj7', 'A:maj7', 'G:5', 'A#:hdim7', 'A:maj', 'A:maj', 'A:maj', 'G#:sus2', 'A:maj7', 'position_4x00', 'A:maj', 'A:maj', 'A:maj', 'G#:sus4', 'A:maj', 'A:maj', 'A:maj', 'A:maj', 'A:maj', 'A:maj', 'A:maj', 'position_4x00', 'A:maj7', 'A:maj', 'A:maj', '<rest>', 'A:maj', 'position_4x00', 'A:maj', 'A:maj7', 'C#:maj7', 'A:maj6', 'A:maj', 'position_4x00', 'A:maj', 'A:maj', 'D:maj', 'A:maj', 'A:maj', 'A:maj', 'A:maj', 'A:maj', 'A:maj', 'A#:hdim7', 'A:maj', 'A:maj', 'C#:maj7', 'A:maj', 'G:maj', '

In [23]:
# Decode the input ids of the first (only) sequence into token strings; spaces become 'x'.
input_ids = batch['input_ids'].to(device)
print(input_ids)
input_tokens = [
    tokenizer.ids_to_tokens[ int(tok_id) ].replace(' ','x')
    for tok_id in input_ids[0]
]
print(input_tokens)

tensor([[  2,   6, 180,  95,   4, 111,   4, 115,  50, 119,  50, 123,  48,   6,
          95,  50, 103,  50, 111,  50, 115,  50, 119,  50, 123,  50,   6,  95,
          48,  99,  48, 103,  48, 107,  48, 111,  48, 115,  46, 119,  45,   6,
          95,  46, 103,  46, 107,  50, 111,  50, 115,  48, 119,  46, 123,  45,
           6,  95,  43, 103,   4, 115,  43, 119,  43, 123,  50,   6,  95,  48,
         103,  48, 111,  48, 115,  48, 119,  48, 123,  48,   6,  95,  48,  99,
          48, 103,  48, 107,  48, 111,  48, 119,  46,   6,  95,  46, 103,   4,
         107,  50, 111,  50, 115,  48, 119,  46, 123,  45,   6,  95,  43, 103,
           4, 115,  43, 119,  43, 123,  50,   6,  95,  48, 103,  48, 111,  48,
         115,  48, 119,  48, 123,  48,   6,  95,  48,  99,  48, 103,  48, 107,
          48, 111,  48, 119,  48,   6,  95,  50, 103,   4, 107,  50, 111,  50,
         115,  48, 119,  46, 123,  45,   6,  95,  43, 103,   4, 115,  43, 119,
          43, 123,  50,   6,  95,  48, 103,  48, 111

In [24]:
class ExactTokenCountLogitsProcessor(LogitsProcessor):
    """Forbid `token_id` in every sequence that already contains it `max_count` times.

    Args:
        token_id: vocabulary id of the token whose occurrences are limited.
        max_count: number of occurrences after which the token is masked out.
    """

    def __init__(self, token_id, max_count):
        self.token_id = token_id
        self.max_count = max_count

    def __call__(self, input_ids, scores):
        """Mask `token_id` in the rows of `scores` whose sequence reached the limit.

        Args:
            input_ids: ids generated so far, one row per sequence.
            scores: next-token scores, one row per sequence; modified in place.

        Returns:
            The same `scores` tensor.
        """
        # Count per row: with several sequences (batching / beams) one sequence
        # reaching the limit must not mask the token for all the others.
        token_counts = (input_ids == self.token_id).sum(dim=-1)
        exhausted = token_counts >= self.max_count
        scores[exhausted, self.token_id] = -float("inf")  # Mask the token
        return scores

class ExactTokenCountStoppingCriteria(StoppingCriteria):
    """Stop generating a sequence once it contains `token_id` at least `max_count` times.

    Args:
        token_id: vocabulary id of the token that is counted.
        max_count: number of occurrences at which generation stops.
    """

    def __init__(self, token_id, max_count):
        self.token_id = token_id
        self.max_count = max_count

    def __call__(self, input_ids, scores, **kwargs):
        """Return a boolean tensor with one entry per row of `input_ids`.

        An entry is True when that row has reached the count. The count is taken
        per row so that one finished sequence does not stop the whole batch.
        NOTE(review): a per-row bool tensor is what recent transformers releases
        expect from a stopping criterion - confirm against the installed version.
        """
        token_counts = (input_ids == self.token_id).sum(dim=-1)
        return token_counts >= self.max_count  # Stop when the count is reached


In [25]:
# Define the token ID and the exact count you want
token_id = 200
exact_count = 10

bart_outputs = bart.generate(
    input_ids=input_ids,
    # attention_mask=batch['attention_mask'][bi],
    bos_token_id=tokenizer.vocab[tokenizer.harmony_tokenizer.start_harmony_token],
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=500,
    do_sample=True,
    temperature=1.0,
    # logits_processor=[ExactTokenCountLogitsProcessor(token_id, exact_count)],
    # stopping_criteria=StoppingCriteriaList([ExactTokenCountStoppingCriteria(token_id, exact_count)])
)

In [26]:
# Show the raw generated ids, then the token strings of the first sequence (spaces become 'x').
print(bart_outputs)
bart_only_outputs = [
    tokenizer.ids_to_tokens[ int(tok_id) ].replace(' ','x')
    for tok_id in bart_outputs[0]
]
print(bart_only_outputs)

tensor([[  2, 196,   6,  95, 459,   6,  95, 197,   6,  95, 400,   6,  95, 256,
           6,  95, 459,   6,  95, 197,   6,  95, 400,   6,  95, 256,   6,  95,
         459,   6,  95, 197,   6,  95, 400,   6,  95, 256,   6,  95, 459,   6,
          95, 197,   6,  95, 400,   6,  95, 256,   6,  95, 459,   6,  95, 459,
           3]], device='cuda:0')
['<s>', '<h>', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'D:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'D:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'D:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'D:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', '</s>'

In [27]:
# Print the three decoded sequences one after another for comparison:
# plain BART, model 'generated_ids', model 'generated_recon_ids'.
print(bart_only_outputs)
print(output_tokens)
print(output_recon_tokens)

['<s>', '<h>', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'D:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'D:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'D:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'D:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', '</s>']
['<s>', '<h>', '<bar>', 'position_0x00', 'F:maj', '<bar>', 'position_0x00', 'E:sus4', 'position_2x00', 'D:min', '<bar>', 'position_0x00', 'G:maj', '<bar>', 'position_0x00', 'C:maj', '<bar>', 'position_0x00', 'F:maj', '<bar>', 'position_0x00', 'E:sus4', 'position_2x00', 'D:min', '<bar>', 'position_0x00', 'G:maj', 'position_2x00', 'D:min7', '<bar>

In [28]:
# Write each result to a file: the input tokens followed by the output tokens.
# The [1:] slice drops the first output token (the printed outputs start with '<s>').
# NOTE(review): assumes the 'examples' directory already exists - confirm how tokenizer.decode handles a missing folder.
tokenizer.decode( input_tokens + bart_only_outputs[1:], output_format='file', output_path='examples/bart.mxl' )
tokenizer.decode( input_tokens + output_tokens[1:], output_format='file', output_path='examples/encdec.mxl' )
tokenizer.decode( input_tokens + output_recon_tokens[1:], output_format='file', output_path='examples/recon.mxl' )

Saved as examples/bart.mxl
Saved as examples/encdec.mxl
unknown chord symbol token:  P:61
Saved as examples/recon.mxl


In [29]:
# For each matrix returned by the model, print the square root of the summed squared
# element-wise differences to the conditioning matrix (compared on the CPU).
target_markov = transitions.to('cpu')
for markov_key in ('generated_markov', 'recon_markov'):
    squared_error = (outputs[markov_key] - target_markov).pow(2)
    print( squared_error.sum().sqrt() )

tensor(3.7821, dtype=torch.float64)
tensor(6.6577, dtype=torch.float64)
