In [None]:
from data_utils import SeparatedMelHarmTextDataset, MelHarmTextCollatorForSeq2Seq
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
import torch
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm import tqdm
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer
from models import TextGuidedHarmonizationModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the previously saved harmony tokenizers and the melody tokenizer from
# their directories under saved_tokenizers/ (paths are relative to the
# notebook's working directory).
chordSymbolTokenizer = ChordSymbolTokenizer.from_pretrained('saved_tokenizers/ChordSymbolTokenizer')
rootTypeTokenizer = RootTypeTokenizer.from_pretrained('saved_tokenizers/RootTypeTokenizer')
pitchClassTokenizer = PitchClassTokenizer.from_pretrained('saved_tokenizers/PitchClassTokenizer')
rootPCTokenizer = RootPCTokenizer.from_pretrained('saved_tokenizers/RootPCTokenizer')
melodyPitchTokenizer = MelodyPitchTokenizer.from_pretrained('saved_tokenizers/MelodyPitchTokenizer')

In [3]:
# Pair the shared melody tokenizer with each harmony tokenizer; every merged
# tokenizer below is a candidate for the `tokenizer` used in the rest of the
# notebook.
m_chordSymbolTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, chordSymbolTokenizer)
m_rootTypeTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, rootTypeTokenizer)
m_pitchClassTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, pitchClassTokenizer)
m_rootPCTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, rootPCTokenizer)

In [4]:
# Selectable merged tokenizers, keyed by the name that is also used to build
# the BART checkpoint path. Choosing both from a single key keeps the
# tokenizer and the checkpoint name from drifting apart.
merged_tokenizers = {
    'ChordSymbolTokenizer': m_chordSymbolTokenizer,
    'RootTypeTokenizer': m_rootTypeTokenizer,
    'PitchClassTokenizer': m_pitchClassTokenizer,
    'RootPCTokenizer': m_rootPCTokenizer,
}
# Change only this line to switch tokenizer.
tokenizer_name = 'RootPCTokenizer'
tokenizer = merged_tokenizers[tokenizer_name]

# Dataset location. The HOOKTHEORY_TEST_DIR environment variable overrides the
# default so the notebook can run on machines with a different disk layout.
root_dir = os.environ.get(
    'HOOKTHEORY_TEST_DIR',
    '/media/maindisk/maximos/data/hooktheory_test'
)

dataset = SeparatedMelHarmTextDataset(
    root_dir,
    tokenizer,
    max_length=512,
    num_bars=64,
    description_mode='specific_chord',
    alteration=True
)

def create_data_collator(tokenizer, model):
    """Build a MelHarmTextCollatorForSeq2Seq with padding enabled.

    Args:
        tokenizer: tokenizer forwarded to the collator.
        model: model forwarded to the collator.

    Returns:
        The collator instance, constructed with ``padding=True``.
    """
    collator_kwargs = dict(tokenizer=tokenizer, model=model, padding=True)
    return MelHarmTextCollatorForSeq2Seq(**collator_kwargs)
# end create_data_collator

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Device selection: pinned to CPU here. Swap in the commented line below to
# use CUDA when it is available.
device_name = 'cpu'
# device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
# device_name stays a plain string because it is also passed as map_location
# when the BART checkpoint is loaded.
device = torch.device(device_name)

In [6]:
# BART configuration: vocabulary size and special-token ids are taken from the
# selected merged tokenizer; encoder and decoder use the same depth and width.
bart_config = BartConfig(
    vocab_size=len(tokenizer.vocab),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    max_position_embeddings=512,
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    d_model=512,
    encoder_layerdrop=0.3,
    decoder_layerdrop=0.3,
    dropout=0.3
)

# Build the model from the config; its weights are replaced by the checkpoint
# loaded below.
bart = BartForConditionalGeneration(bart_config)

# The checkpoint path is derived from tokenizer_name, so it must correspond to
# the tokenizer selected for this run.
bart_path = 'saved_models/bart/' + tokenizer_name + '/' + tokenizer_name + '.pt'
# map_location places the loaded tensors on the chosen device.
checkpoint = torch.load(bart_path, map_location=device_name, weights_only=True)
bart.load_state_dict(checkpoint)

bart.to(device)
# Evaluation mode (dropout inactive).
# NOTE(review): an optimizer is created later in the notebook; confirm the
# training code switches the relevant modules back to train mode if intended.
bart.eval()

bart_encoder, bart_decoder = bart.get_encoder(), bart.get_decoder()
# NOTE(review): bart.to(device) above already covers the submodules, so these
# two calls look redundant; the last one is the cell's final expression and is
# what renders the decoder summary as the cell output.
bart_encoder.to(device)
bart_decoder.to(device)

# # Freeze BART parameters (disabled)
# for param in bart_encoder.parameters():
#     param.requires_grad = False
# for param in bart_decoder.parameters():
#     param.requires_grad = False

BartDecoder(
  (embed_tokens): BartScaledWordEmbedding(221, 512, padding_idx=1)
  (embed_positions): BartLearnedPositionalEmbedding(514, 512)
  (layers): ModuleList(
    (0-7): 8 x BartDecoderLayer(
      (self_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=512, out_features=512, bias=True)
        (v_proj): Linear(in_features=512, out_features=512, bias=True)
        (q_proj): Linear(in_features=512, out_features=512, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (activation_fn): GELUActivation()
      (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=512, out_features=512, bias=True)
        (v_proj): Linear(in_features=512, out_features=512, bias=True)
        (q_proj): Linear(in_features=512, out_features=512, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (enc

In [7]:
# Collator built from the selected tokenizer and the loaded BART model.
collator = create_data_collator(tokenizer, model=bart)

In [8]:
# One example per batch; shuffle=True means the iteration order (and therefore
# the batch inspected below) changes between runs unless a seed is fixed.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collator)

In [9]:
# Pull a single collated batch for inspection in the following cells.
b = next(iter(dataloader))

  return self.iter().getElementsByClass(classFilterList)


In [10]:
# Show which fields the collator put into the batch.
print(b.keys())

dict_keys(['input_ids', 'attention_mask', 'harmony_input_ids', 'labels', 'decoder_input_ids', 'txt'])


In [11]:
# Inspect the label ids of the sampled batch.
print(b['labels'])

tensor([[-100,    6,   95,  197,  213,  216,  111,  204,  220,  211,    6,   95,
          197,  213,  216,  111,  204,  220,  211,    6,   95,  197,  213,  216,
          111,  202,  218,  209,    6,   95,  204,  220,  211,    6,   95,  197,
          213,  216,  111,  204,  220,  211,    6,   95,  197,  213,  216,  111,
          204,  220,  211,    6,   95,  197,  213,  216,  111,  202,  218,  209,
            6,   95,  197,  211,  213,  216,  220,    1]])


In [12]:
# Wrap the loaded BART in the text-guided harmonization model on the chosen
# device. The constructor's output reports a roberta-base RobertaModel being
# initialised with a newly initialised pooler.
model = TextGuidedHarmonizationModel(bart, device=device)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Move the tensor fields of the sampled batch onto the target device, in the
# order the model call expects them.
model_input_ids, melody_attention_mask, harmony_input_ids, labels = (
    b[field].to(device)
    for field in ('input_ids', 'attention_mask', 'harmony_input_ids', 'labels')
)
# The text field is taken from the batch as-is (no device transfer).
texts = b['txt']

In [None]:
# Single forward pass on the sampled batch; passing labels makes the model
# return a loss alongside the logits.
output = model(
    model_input_ids,
    melody_attention_mask,
    harmony_input_ids,
    texts,
    labels=labels,
)
decoder_loss, decoder_logits = output['loss'], output['logits']

hat_h:  torch.Size([1, 102, 512])
labels:  torch.Size([1, 68])


In [15]:
# Sanity check of the forward pass: scalar loss followed by the logits tensor.
print(decoder_loss)
print(decoder_logits)

tensor(1.0673, grad_fn=<NllLossBackward0>)
tensor([[[ -8.8498,  -8.9703,   2.5205,  ...,   0.3193,   0.6228,   0.6747],
         [-22.2938, -22.1433,  -2.1604,  ...,   0.3856,  -0.9911,   0.7216],
         [-20.5718, -20.5356,  -1.4832,  ...,  -2.3109,  -3.1869,  -2.1093],
         ...,
         [-17.3761, -17.5985,  -2.0772,  ...,   1.9245,  -1.8549,  -1.1557],
         [-19.1630, -19.1568,  -2.4330,  ...,   3.9359,   6.1086,   8.8422],
         [-18.7110, -18.9539,  -3.0130,  ...,  -2.3950,  -2.7397,   0.4671]]],
       grad_fn=<AddBackward0>)


In [None]:
# Number of full passes over the dataloader.
num_epochs = 5

# Hand the optimizer only the parameters that still require gradients, so any
# frozen weights are left out.
optimizer = AdamW(
    (param for param in model.parameters() if param.requires_grad),
    lr=5e-5,
)

# Linear schedule with no warmup, spanning every optimisation step of the run.
num_training_steps = num_epochs * len(dataloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)