In [58]:
from data_utils import SeparatedMelHarmTextDataset, MelHarmTextCollatorForSeq2Seq
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
import torch
from torch.optim import AdamW
from tqdm import tqdm
from models import TransTextVAE

In [59]:
chordSymbolTokenizer = ChordSymbolTokenizer.from_pretrained('saved_tokenizers/ChordSymbolTokenizer')
rootTypeTokenizer = RootTypeTokenizer.from_pretrained('saved_tokenizers/RootTypeTokenizer')
pitchClassTokenizer = PitchClassTokenizer.from_pretrained('saved_tokenizers/PitchClassTokenizer')
rootPCTokenizer = RootPCTokenizer.from_pretrained('saved_tokenizers/RootPCTokenizer')
melodyPitchTokenizer = MelodyPitchTokenizer.from_pretrained('saved_tokenizers/MelodyPitchTokenizer')

In [60]:
m_chordSymbolTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, chordSymbolTokenizer)
m_rootTypeTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, rootTypeTokenizer)
m_pitchClassTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, pitchClassTokenizer)
m_rootPCTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, rootPCTokenizer)

In [61]:
# tokenizer = m_chordSymbolTokenizer
# tokenizer_name = 'ChordSymbolTokenizer'
# tokenizer = m_rootTypeTokenizer
# tokenizer_name = 'RootTypeTokenizer'
# tokenizer = m_pitchClassTokenizer
# tokenizer_name = 'PitchClassTokenizer'
tokenizer = m_rootPCTokenizer
tokenizer_name = 'RootPCTokenizer'

root_dir = '/media/maindisk/maximos/data/hooktheory_test'
dataset = SeparatedMelHarmTextDataset(root_dir, tokenizer, max_length=512, num_bars=64)
def create_data_collator(tokenizer, model):
    return MelHarmTextCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)
# end create_data_collator

In [62]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [63]:
bart_config = BartConfig(
    vocab_size=len(tokenizer.vocab),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    max_position_embeddings=512,
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    d_model=512,
    encoder_layerdrop=0.3,
    decoder_layerdrop=0.3,
    dropout=0.3
)

bart = BartForConditionalGeneration(bart_config)

bart_path = 'saved_models/bart/' + tokenizer_name + '/' + tokenizer_name + '.pt'
if device == 'cpu':
    checkpoint = torch.load(bart_path, map_location="cpu", weights_only=True)
else:
    checkpoint = torch.load(bart_path, weights_only=True)
bart.load_state_dict(checkpoint)

bart.to(device)
bart.eval()

bart_encoder, bart_decoder = bart.get_encoder(), bart.get_decoder()
bart_encoder.to(device)
bart_decoder.to(device)

# Freeze BART parameters
for param in bart_encoder.parameters():
    param.requires_grad = False
for param in bart_encoder.parameters():
    param.requires_grad = False

In [64]:
collator = create_data_collator(tokenizer, model=bart)

In [65]:
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collator)

In [66]:
b = next(iter(dataloader))

In [67]:
config = {
    'lstm_dim': 2048,
    'roberta_model': "roberta-base",
    'latent_dim': 2048,
    'freeze_roberta': True
}

model = TransTextVAE(bart, tokenizer=tokenizer, device=device, config=config)
model.to(device)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


TransTextVAE(
  (transformer): BartForConditionalGeneration(
    (model): BartModel(
      (shared): BartScaledWordEmbedding(221, 512, padding_idx=1)
      (encoder): BartEncoder(
        (embed_tokens): BartScaledWordEmbedding(221, 512, padding_idx=1)
        (embed_positions): BartLearnedPositionalEmbedding(514, 512)
        (layers): ModuleList(
          (0-7): 8 x BartEncoderLayer(
            (self_attn): BartSdpaAttention(
              (k_proj): Linear(in_features=512, out_features=512, bias=True)
              (v_proj): Linear(in_features=512, out_features=512, bias=True)
              (q_proj): Linear(in_features=512, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=512, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=512, out_features=512, bias=True)
            (fc2): Linear(in_featu

In [68]:
b['input_ids']

tensor([[  2,   6, 178,  95,  62, 111,  60, 115,  65,   6,  95,  62, 111,  60,
         115,  65,   6,  95,  62, 111,   4,   6,  95,  60, 103,  58, 111,  57,
           6,  95,  55, 103,   4, 111,  53, 115,  60,   6,  95,  55, 111,  53,
           6,  95,  55, 107,  55, 111,  62, 115,  60,   6,  95,  62, 111,   4,
           6,  95,  62, 111,  60, 115,  65,   6,  95,  62, 111,  60, 115,  65,
           6,  95,  62, 111,   4,   6,  95,  60, 103,  58, 111,  57,   6,  95,
          55, 103,   4, 111,  53, 115,  60,   6,  95,  55, 103,   4, 111,  53,
         115,  62,   6,  95,  55]])

In [69]:
# from transformers import RobertaModel, RobertaTokenizer
# roberta_model = "roberta-base"
# # Load RoBERTa
# roberta = RobertaModel.from_pretrained(roberta_model)
# text_tokenizer = RobertaTokenizer.from_pretrained(roberta_model)

In [70]:
# txts = ['Bar number 2 begins with a A:(7b9) chord.']
# roberta_inputs = text_tokenizer(
#     txts, padding=True, truncation=True, return_tensors="pt"
# ).to(device)
# print(roberta_inputs)
# print(text_tokenizer.decode(roberta_inputs['input_ids'][0]))

In [77]:
with torch.no_grad():
    input_ids = b['input_ids'].to(device)
    txts = ['Bar number 0 begins with a G:maj chord.']
    num_bars = (input_ids == tokenizer.vocab['<bar>']).sum(dim=1).reshape(input_ids.shape[0],-1)
    outputs = model(input_ids, txts, encoder_attention=None, generate_max_tokens=500, num_bars=num_bars, temperature=1.0)

recon generation
bars_left: tensor([[15]], device='cuda:0')
bars_left: tensor([[15]], device='cuda:0')
bars_left: tensor([[14]], device='cuda:0')
bars_left: tensor([[14]], device='cuda:0')
bars_left: tensor([[14]], device='cuda:0')
bars_left: tensor([[14]], device='cuda:0')
bars_left: tensor([[14]], device='cuda:0')
bars_left: tensor([[14]], device='cuda:0')
bars_left: tensor([[14]], device='cuda:0')
bars_left: tensor([[14]], device='cuda:0')
bars_left: tensor([[14]], device='cuda:0')
bars_left: tensor([[13]], device='cuda:0')
bars_left: tensor([[13]], device='cuda:0')
bars_left: tensor([[13]], device='cuda:0')
bars_left: tensor([[13]], device='cuda:0')
bars_left: tensor([[13]], device='cuda:0')
bars_left: tensor([[13]], device='cuda:0')
bars_left: tensor([[13]], device='cuda:0')
bars_left: tensor([[13]], device='cuda:0')
bars_left: tensor([[13]], device='cuda:0')
bars_left: tensor([[12]], device='cuda:0')
bars_left: tensor([[12]], device='cuda:0')
bars_left: tensor([[12]], device='cud

In [78]:
print(outputs)

{'loss': tensor(0.9255, device='cuda:0'), 'recon_loss': tensor(0.9174, device='cuda:0'), 'kl_loss': tensor(0.0081, device='cuda:0'), 'x': tensor([[[-0.3162, -0.8094,  2.9027,  ...,  0.2310,  0.5949, -0.1098],
         [-0.9619,  0.0118,  0.5210,  ...,  0.3674,  1.3328,  1.5655],
         [ 1.2174, -0.7365,  1.8608,  ...,  0.3479,  1.0249,  0.8572],
         ...,
         [ 0.1112, -0.7174,  0.5119,  ..., -1.4962, -0.5612,  0.1700],
         [ 0.9170, -0.7247, -0.8587,  ..., -0.6898,  0.0823, -1.6773],
         [ 1.9352,  0.5214, -0.8843,  ..., -1.0361, -0.1259,  0.6265]]],
       device='cuda:0'), 'recon_x': tensor([[[-9.3821e-02, -2.6970e-02, -2.2643e-01,  ..., -6.9977e-02,
           2.0411e-03,  5.7349e-02],
         [-8.8398e-02, -5.4961e-02, -2.2857e-01,  ..., -5.3729e-02,
          -2.2503e-02,  6.5256e-02],
         [-8.1637e-02, -7.1162e-02, -2.2807e-01,  ..., -4.1609e-02,
          -3.2813e-02,  7.0468e-02],
         ...,
         [-8.6004e-02, -9.1715e-02, -2.0698e-01,  ..., 

In [79]:
output_tokens = []
output_recon_tokens = []

for i in outputs['generated_ids'][0]:
    output_tokens.append( tokenizer.ids_to_tokens[ int(i) ].replace(' ','x') )
for i in outputs['generated_recon_ids'][0]:
    output_recon_tokens.append( tokenizer.ids_to_tokens[ int(i) ].replace(' ','x') )

In [80]:
print(output_tokens)
print(output_recon_tokens)

['<s>', '<h>', '<bar>', 'position_0x00', 'chord_root_5', 'chord_pc_9', 'chord_pc_0', 'chord_pc_4', '<bar>', 'position_0x00', 'chord_root_7', 'chord_pc_11', 'chord_pc_2', '<bar>', 'position_0x00', 'chord_root_9', 'chord_pc_0', 'chord_pc_4', '<bar>', 'position_0x00', 'chord_root_9', 'chord_pc_0', 'chord_pc_4', '<bar>', 'position_0x00', 'chord_root_5', 'chord_pc_9', 'chord_pc_0', 'chord_pc_4', '<bar>', 'position_0x00', 'chord_root_7', 'chord_pc_11', 'chord_pc_2', '<bar>', 'position_0x00', 'chord_root_9', 'chord_pc_2', 'chord_pc_4', '<bar>', 'position_0x00', 'chord_root_9', 'chord_pc_1', 'chord_pc_4', '<bar>', 'position_0x00', 'chord_root_5', 'chord_pc_9', 'chord_pc_0', 'chord_pc_4', '<bar>', 'position_0x00', 'chord_root_7', 'chord_pc_11', 'chord_pc_2', '<bar>', 'position_0x00', 'chord_root_9', 'chord_pc_2', 'chord_pc_4', '<bar>', 'position_0x00', 'chord_root_9', 'chord_pc_2', 'chord_pc_4', '<bar>', 'position_0x00', 'chord_root_5', 'chord_pc_9', 'chord_pc_0', 'chord_pc_4', '<bar>', 'positi

In [81]:
input_ids = b['input_ids'].to(device)
print(input_ids)
input_tokens = []
for i in input_ids[0]:
    input_tokens.append( tokenizer.ids_to_tokens[ int(i) ].replace(' ','x') )
print(input_tokens)

tensor([[  2,   6, 178,  95,  62, 111,  60, 115,  65,   6,  95,  62, 111,  60,
         115,  65,   6,  95,  62, 111,   4,   6,  95,  60, 103,  58, 111,  57,
           6,  95,  55, 103,   4, 111,  53, 115,  60,   6,  95,  55, 111,  53,
           6,  95,  55, 107,  55, 111,  62, 115,  60,   6,  95,  62, 111,   4,
           6,  95,  62, 111,  60, 115,  65,   6,  95,  62, 111,  60, 115,  65,
           6,  95,  62, 111,   4,   6,  95,  60, 103,  58, 111,  57,   6,  95,
          55, 103,   4, 111,  53, 115,  60,   6,  95,  55, 103,   4, 111,  53,
         115,  62,   6,  95,  55]], device='cuda:0')
['<s>', '<bar>', 'ts_3x4', 'position_0x00', 'P:76', 'position_2x00', 'P:74', 'position_2x50', 'P:79', '<bar>', 'position_0x00', 'P:76', 'position_2x00', 'P:74', 'position_2x50', 'P:79', '<bar>', 'position_0x00', 'P:76', 'position_2x00', '<rest>', '<bar>', 'position_0x00', 'P:74', 'position_1x00', 'P:72', 'position_2x00', 'P:71', '<bar>', 'position_0x00', 'P:69', 'position_1x00', '<rest>', 'p

In [82]:
os.makedirs('examples', exist_ok=True)
tokenizer.decode( input_tokens + output_tokens[1:], output_format='file', output_path='examples/encdec.mxl' )
tokenizer.decode( input_tokens + output_recon_tokens[1:], output_format='file', output_path='examples/recon.mxl' )

Saved as examples/encdec.mxl
Saved as examples/recon.mxl
