In [1]:
from data_utils import SeparatedMelHarmTextDataset, MelHarmTextCollatorForSeq2Seq
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
import torch
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm import tqdm
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer
from models import TextGuidedHarmonizationModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load each pretrained tokenizer from its directory under saved_tokenizers/.
# The four harmony tokenizers are alternative chord vocabularies; the melody
# tokenizer is shared by all of them (see the merge in the next cell).
chordSymbolTokenizer, rootTypeTokenizer, pitchClassTokenizer, rootPCTokenizer, melodyPitchTokenizer = (
    tokenizer_cls.from_pretrained(tokenizer_dir)
    for tokenizer_cls, tokenizer_dir in (
        (ChordSymbolTokenizer, 'saved_tokenizers/ChordSymbolTokenizer'),
        (RootTypeTokenizer, 'saved_tokenizers/RootTypeTokenizer'),
        (PitchClassTokenizer, 'saved_tokenizers/PitchClassTokenizer'),
        (RootPCTokenizer, 'saved_tokenizers/RootPCTokenizer'),
        (MelodyPitchTokenizer, 'saved_tokenizers/MelodyPitchTokenizer'),
    )
)

In [3]:
# Pair the shared melody tokenizer with each harmony tokenizer, producing one
# merged melody+harmony tokenizer per chord representation.
m_chordSymbolTokenizer, m_rootTypeTokenizer, m_pitchClassTokenizer, m_rootPCTokenizer = (
    MergedMelHarmTokenizer(melodyPitchTokenizer, harmony_tokenizer)
    for harmony_tokenizer in (
        chordSymbolTokenizer,
        rootTypeTokenizer,
        pitchClassTokenizer,
        rootPCTokenizer,
    )
)

In [4]:
# Pick the merged melody+harmony tokenizer by name. Deriving `tokenizer` from
# `tokenizer_name` keeps the two in sync: `tokenizer_name` is also used to
# build the checkpoint path further down, so toggling only one of them would
# point the notebook at a checkpoint trained for a different tokenizer.
merged_tokenizers = {
    'ChordSymbolTokenizer': m_chordSymbolTokenizer,
    'RootTypeTokenizer': m_rootTypeTokenizer,
    'PitchClassTokenizer': m_pitchClassTokenizer,
    'RootPCTokenizer': m_rootPCTokenizer,
}
tokenizer_name = 'ChordSymbolTokenizer'
tokenizer = merged_tokenizers[tokenizer_name]

# How the text descriptions are generated; also part of the checkpoint path.
description_mode = 'specific_chord'

# Data location; override with the HOOKTHEORY_DATA_ROOT environment variable
# on machines where the data lives elsewhere. Default is the original path.
data_root = os.environ.get('HOOKTHEORY_DATA_ROOT', '/media/maindisk/maximos/data')
train_dir = os.path.join(data_root, 'hooktheory_train')
test_dir = os.path.join(data_root, 'hooktheory_test')

# Settings shared by both splits so train and test cannot drift apart.
# NOTE(review): max_length presumably must not exceed the BART
# max_position_embeddings (512) configured below — confirm.
dataset_kwargs = dict(
    max_length=512,
    num_bars=64,
    description_mode=description_mode,
    alteration=True,
)

train_dataset = SeparatedMelHarmTextDataset(train_dir, tokenizer, **dataset_kwargs)
test_dataset = SeparatedMelHarmTextDataset(test_dir, tokenizer, **dataset_kwargs)

def create_data_collator(tokenizer, model):
    """Return the seq2seq batch collator used by the DataLoaders.

    Args:
        tokenizer: merged melody+harmony tokenizer used for padding.
        model: the BART model handed to the collator (the notebook passes
            ``model.bart``, not the text-guided wrapper).

    Returns:
        A ``MelHarmTextCollatorForSeq2Seq`` constructed with ``padding=True``.
    """
    collator_kwargs = {'tokenizer': tokenizer, 'model': model, 'padding': True}
    return MelHarmTextCollatorForSeq2Seq(**collator_kwargs)
# end create_data_collator

In [5]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
# `device_name` (a plain string) is reused below as torch.load's map_location.
device_name = 'cpu'
if torch.cuda.is_available():
    device_name = 'cuda'
device = torch.device(device_name)

In [6]:
# BART backbone. The architecture must match the one the checkpoint loaded
# below was saved with, otherwise the (strict) load_state_dict rejects it.
bart_config = BartConfig(
    # Vocabulary and special tokens come from the merged melody+harmony tokenizer.
    vocab_size=len(tokenizer.vocab),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    # Model size.
    d_model=512,
    max_position_embeddings=512,
    encoder_layers=8,
    decoder_layers=8,
    encoder_attention_heads=8,
    decoder_attention_heads=8,
    encoder_ffn_dim=512,
    decoder_ffn_dim=512,
    # Regularisation; only relevant in training mode (the model is put in
    # eval mode at the end of this cell).
    dropout=0.3,
    encoder_layerdrop=0.3,
    decoder_layerdrop=0.3,
)

bart = BartForConditionalGeneration(bart_config)
# Optional: initialise the backbone from a separately saved BART checkpoint.
# bart_path = 'saved_models/bart/' + tokenizer_name + '/' + tokenizer_name + '.pt'
# checkpoint = torch.load(bart_path, map_location=device_name, weights_only=True)
# bart.load_state_dict(checkpoint)

model = TextGuidedHarmonizationModel(bart, device=device)

# The full text-guided checkpoint is keyed by tokenizer and description mode.
model_path = f'saved_models/bart_text_cvae/{tokenizer_name}/{description_mode}/{tokenizer_name}_{description_mode}.pt'
checkpoint = torch.load(model_path, map_location=device_name, weights_only=True)
model.load_state_dict(checkpoint)

# Module.to / Module.eval both return the module itself, so this chain
# leaves `model` on `device` in eval mode and displays it as the cell output.
model.to(device).eval()

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


TextGuidedHarmonizationModel(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
             

In [7]:
# The collator receives the inner BART model, not the text-guided wrapper.
collator = create_data_collator(tokenizer=tokenizer, model=model.bart)

In [8]:
# One piece per batch. The validation loader is shuffled too, so drawing its
# first batch yields a random test piece for inspection.
trainloader = DataLoader(train_dataset, collate_fn=collator, shuffle=True, batch_size=1)
valloader = DataLoader(test_dataset, collate_fn=collator, shuffle=True, batch_size=1)

In [9]:
# Draw a single validation batch and list the fields the collator produced.
val_batches = iter(valloader)
b = next(val_batches)
print(b.keys())

dict_keys(['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'txt', 'harmony_input_ids'])


  return self.iter().getElementsByClass(classFilterList)
  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
  [torch.tensor(h) for h in harmony_inputs],


In [10]:
# Count the '<bar>' tokens in each conditioning sequence so generation can be
# told how many bars to produce; assumes input_ids is (batch, seq), giving a
# (batch, 1) tensor.
bar_token_id = tokenizer.vocab['<bar>']
num_bars = (b['input_ids'] == bar_token_id).sum(dim=1).reshape(len(b['input_ids']), -1)
print(num_bars)
outputs = model.generate(
    tokenizer,
    b['input_ids'],
    b['attention_mask'],
    b['txt'],
    max_length=500,
    num_bars=num_bars,
    temperature=1.0,
)

tensor([[16]])


In [11]:
# Raw generated token ids (decoded to strings in the next cell).
print(str(outputs))

tensor([[  2, 196,   6,   6,  95,   6, 459,  95,   6, 459,  95,   6, 459,  95,
           6, 459,  95,   6, 459,  95,   6, 459,  95, 107, 459, 459, 119,   6,
         459,  95,   6, 454,  95, 107, 459, 459, 119, 107, 459, 459,   6,   6,
          95,  95, 459, 459, 107,   6, 459,  95,   3]], device='cuda:0')


In [None]:
# Decode the first (with batch_size=1, the only) generated sequence into token
# strings, replacing spaces with 'x' so each token prints as a single word.
# `.tolist()` copies the ids off the device once, instead of forcing a
# device->host sync for every element as per-element int() conversion does.
output_tokens = [
    tokenizer.ids_to_tokens[token_id].replace(' ', 'x')
    for token_id in outputs[0].tolist()
]

In [13]:
# Conditioning text on one line, the decoded generation on the next.
print(b['txt'], output_tokens, sep='\n')

['Bar number 1 begins with a G#:7(#11) chord.']
['<s>', '<h>', '<bar>', '<bar>', 'position_0x00', '<bar>', 'A:min', 'position_0x00', '<bar>', 'A:min', 'position_0x00', '<bar>', 'A:min', 'position_0x00', '<bar>', 'A:min', 'position_0x00', '<bar>', 'A:min', 'position_0x00', '<bar>', 'A:min', 'position_0x00', 'position_1x50', 'A:min', 'A:min', 'position_3x00', '<bar>', 'A:min', 'position_0x00', '<bar>', 'G#:7(b9)', 'position_0x00', 'position_1x50', 'A:min', 'A:min', 'position_3x00', 'position_1x50', 'A:min', 'A:min', '<bar>', '<bar>', 'position_0x00', 'position_0x00', 'A:min', 'A:min', 'position_1x50', '<bar>', 'A:min', 'position_0x00', '</s>']


In [15]:
# Forward pass with labels on the sampled batch, used only to inspect the loss
# and token accuracy below. No backward pass follows, so autograd tracking is
# disabled to avoid building (and holding in memory) the computation graph.
with torch.no_grad():
    loss_output = model(
        b['input_ids'].to(device),
        b['attention_mask'].to(device),
        b['harmony_input_ids'].to(device),
        b['txt'],
        labels=b['labels'].to(device),
    )

In [19]:
# Report the loss and the token-level accuracy of the argmax predictions over
# labelled positions (-100 is the conventional ignore index for labels).
print(loss_output['loss'])
predictions = loss_output['logits'].argmax(dim=-1)
mask = b['labels'] != -100
num_scored = mask.sum().item()
num_correct = (predictions[mask] == b['labels'][mask].to(device)).sum().item()
# A batch with no labelled positions would otherwise raise ZeroDivisionError.
running_accuracy = num_correct / num_scored if num_scored else float('nan')
print(running_accuracy)
print(predictions)

tensor(0.4857, device='cuda:0', grad_fn=<NllLossBackward0>)
0.9090909090909091
tensor([[196,   6,  95, 459,   6,  95, 459,   6,  95, 459,   6,  95, 459,   6,
          95, 459,   6, 400,   6,  95, 342, 111, 342,   6,  95, 459, 111, 400,
           6,  95, 314, 111, 342,   6,  95, 459, 111,  95, 459,   6,  95, 459,
         111,  95, 314,   6,  95, 459, 111, 400,   6,  95, 314, 111, 342,   6,
          95, 459, 111, 400,   6,  95, 314, 111, 342,   3, 111, 111, 111,   3,
           6, 111,   6,   3,   6,   3,   6,   3,   6,   3,   6,   3, 111, 111,
           6,   6,   6, 111,   6,   3,   6,   6, 111, 111, 111,   6,   6,   3,
           6, 111,   6,   3, 111, 111,   6,   3, 111, 111, 111, 111, 111, 111,
         111,   3,   6, 111, 111, 111,   6, 111, 111,   6,   6,   6,   6, 111,
           6, 111, 111, 111,   6,   3,   6, 111,   6, 111, 111, 111,   6,   3,
           6,   3,   6, 111, 111, 111, 111, 111, 111, 111,   6, 111,   6, 111,
         111,   6,   6,   6,   6, 111, 111, 111,   6

In [22]:
# Decode predictions and reference labels of the first sequence into token
# strings (spaces replaced by 'x'), keeping ONLY labelled positions (label
# ids >= 0; ignored positions carry -100). Filtering both lists by the same
# positions keeps them aligned index-by-index, matching the positions the
# accuracy above is scored on; previously predictions were decoded for every
# position while labels were filtered, so the two lists did not line up.
# `.tolist()` moves each tensor off the device in one copy.
label_ids = b['labels'][0].tolist()
predicted_ids = predictions[0].tolist()
scored_pairs = [
    (pred_id, label_id)
    for pred_id, label_id in zip(predicted_ids, label_ids)
    if label_id >= 0
]
prediction_tokens = [
    tokenizer.ids_to_tokens[pred_id].replace(' ', 'x') for pred_id, _ in scored_pairs
]
label_tokens = [
    tokenizer.ids_to_tokens[label_id].replace(' ', 'x') for _, label_id in scored_pairs
]

In [23]:
# Reference harmony on one line, model predictions on the next.
print(label_tokens, prediction_tokens, sep='\n')

['<h>', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'G#:7(#11)', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', 'position_2x00', 'G:maj', '<bar>', 'position_0x00', 'E:min', 'position_2x00', 'F:maj', '<bar>', 'position_0x00', 'A:min', 'position_2x00', 'G:maj', '<bar>', 'position_0x00', 'E:min', 'position_2x00', 'F:maj', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', 'position_2x00', 'G:maj', '<bar>', 'position_0x00', 'E:min', 'position_2x00', 'F:maj', '<bar>', 'position_0x00', 'A:min', 'position_2x00', 'G:maj', '<bar>', 'position_0x00', 'E:min', 'position_2x00', 'F:maj', '</s>']
['<h>', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'position_0x00', 'A:min', '<bar>', 'G:maj', '<bar>', 'p