In [17]:
from data_utils import SeparatedMelHarmTextDataset, MelHarmTextCollatorForSeq2Seq
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
import torch
from torch.optim import AdamW
from tqdm import tqdm
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer
from models import TextGuidedHarmonizationModel

In [18]:
chordSymbolTokenizer = ChordSymbolTokenizer.from_pretrained('saved_tokenizers/ChordSymbolTokenizer')
rootTypeTokenizer = RootTypeTokenizer.from_pretrained('saved_tokenizers/RootTypeTokenizer')
pitchClassTokenizer = PitchClassTokenizer.from_pretrained('saved_tokenizers/PitchClassTokenizer')
rootPCTokenizer = RootPCTokenizer.from_pretrained('saved_tokenizers/RootPCTokenizer')
melodyPitchTokenizer = MelodyPitchTokenizer.from_pretrained('saved_tokenizers/MelodyPitchTokenizer')

In [19]:
m_chordSymbolTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, chordSymbolTokenizer)
m_rootTypeTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, rootTypeTokenizer)
m_pitchClassTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, pitchClassTokenizer)
m_rootPCTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, rootPCTokenizer)

In [20]:
# tokenizer = m_chordSymbolTokenizer
# tokenizer_name = 'ChordSymbolTokenizer'
# tokenizer = m_rootTypeTokenizer
# tokenizer_name = 'RootTypeTokenizer'
# tokenizer = m_pitchClassTokenizer
# tokenizer_name = 'PitchClassTokenizer'
tokenizer = m_rootPCTokenizer
tokenizer_name = 'RootPCTokenizer'

root_dir = '/media/maindisk/maximos/data/hooktheory_test'
dataset = SeparatedMelHarmTextDataset(root_dir, tokenizer, max_length=512, num_bars=64)
def create_data_collator(tokenizer, model):
    return MelHarmTextCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)
# end create_data_collator

In [21]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_name = 'cpu'
device = torch.device("cpu")

In [22]:
bart_config = BartConfig(
    vocab_size=len(tokenizer.vocab),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    max_position_embeddings=512,
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    d_model=512,
    encoder_layerdrop=0.3,
    decoder_layerdrop=0.3,
    dropout=0.3
)

bart = BartForConditionalGeneration(bart_config)

bart_path = 'saved_models/bart/' + tokenizer_name + '/' + tokenizer_name + '.pt'
checkpoint = torch.load(bart_path, map_location=device_name, weights_only=True)
bart.load_state_dict(checkpoint)

bart.to(device)
bart.eval()

bart_encoder, bart_decoder = bart.get_encoder(), bart.get_decoder()
bart_encoder.to(device)
bart_decoder.to(device)

# # Freeze BART parameters
# for param in bart_encoder.parameters():
#     param.requires_grad = False
# for param in bart_encoder.parameters():
#     param.requires_grad = False

BartDecoder(
  (embed_tokens): BartScaledWordEmbedding(221, 512, padding_idx=1)
  (embed_positions): BartLearnedPositionalEmbedding(514, 512)
  (layers): ModuleList(
    (0-7): 8 x BartDecoderLayer(
      (self_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=512, out_features=512, bias=True)
        (v_proj): Linear(in_features=512, out_features=512, bias=True)
        (q_proj): Linear(in_features=512, out_features=512, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (activation_fn): GELUActivation()
      (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=512, out_features=512, bias=True)
        (v_proj): Linear(in_features=512, out_features=512, bias=True)
        (q_proj): Linear(in_features=512, out_features=512, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (enc

In [23]:
collator = create_data_collator(tokenizer, model=bart)

In [24]:
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collator)

In [25]:
b = next(iter(dataloader))

  return self.iter().getElementsByClass(classFilterList)


In [26]:
print(b.keys())

dict_keys(['input_ids', 'attention_mask', 'harmony_input_ids', 'labels', 'decoder_input_ids', 'txt'])


In [27]:
print(b['labels'])

tensor([[ 196,    6,   95,  202,  218,  209,  119,  204,  220,  211,  123,  206,
          209,  213,    6,   95,  206,  209,  213,    6,   95,  204,  220,  211,
          119,  206,  209,  213,  123,  201,  216,  220,    6,   95,  201,  216,
          220,  115,  199,  214,  218,  119,  201,  216,  220,    6,   95,  202,
          218,  209,  119,  204,  220,  211,  123,  206,  209,  213,    6,   95,
          206,  209,  213,    6,   95,  204,  220,  211,  119,  197,  213,  216,
          123,  197,  213,  216,    6,   95,  197,  213,  216,    3, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -

In [28]:
model = TextGuidedHarmonizationModel(bart, device=device)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [29]:
model_input_ids = b['input_ids'].to(device)
melody_attention_mask = b['attention_mask'].to(device)
harmony_input_ids = b['harmony_input_ids'].to(device)
labels = b['labels'].to(device)
texts = b['txt']

In [30]:
decoder_loss, decoder_logits = model(model_input_ids, melody_attention_mask, harmony_input_ids, texts, labels=labels)

hat_h:  torch.Size([1, 120, 512])
labels:  torch.Size([1, 392])


In [31]:
print(decoder_loss)
print(decoder_logits)

tensor(0.5530, grad_fn=<NllLossBackward0>)
tensor([[[ -8.9886,  -9.1118,   2.4815,  ...,   0.3629,   0.7342,   0.6897],
         [-22.2846, -22.1352,  -2.1934,  ...,   0.3874,  -0.8075,   0.7557],
         [-20.6434, -20.6101,  -1.5172,  ...,  -2.3137,  -3.0728,  -2.0902],
         ...,
         [-12.8782, -12.6370,  -2.0237,  ...,  -1.3478,  -0.8248,  -1.4532],
         [-12.9575, -12.7007,  -2.0124,  ...,  -1.8641,  -1.0179,  -1.7308],
         [-13.0036, -12.7806,  -2.0334,  ...,  -1.9939,  -1.4933,  -2.4119]]],
       grad_fn=<AddBackward0>)
