In [1]:
from data_utils import SeparatedMelHarmTextDataset, MelHarmTextCollatorForSeq2Seq
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
import torch
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm import tqdm
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer
from models import TextGuidedHarmonizationModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Restore every tokenizer from its saved directory. The (class, path) pairs are
# listed in the same order as the names they are unpacked into.
(
    chordSymbolTokenizer,
    rootTypeTokenizer,
    pitchClassTokenizer,
    rootPCTokenizer,
    melodyPitchTokenizer,
) = (
    tokenizer_cls.from_pretrained(saved_path)
    for tokenizer_cls, saved_path in (
        (ChordSymbolTokenizer, 'saved_tokenizers/ChordSymbolTokenizer'),
        (RootTypeTokenizer, 'saved_tokenizers/RootTypeTokenizer'),
        (PitchClassTokenizer, 'saved_tokenizers/PitchClassTokenizer'),
        (RootPCTokenizer, 'saved_tokenizers/RootPCTokenizer'),
        (MelodyPitchTokenizer, 'saved_tokenizers/MelodyPitchTokenizer'),
    )
)

In [3]:
# Pair the shared melody tokenizer with each harmony tokenizer in turn.
(
    m_chordSymbolTokenizer,
    m_rootTypeTokenizer,
    m_pitchClassTokenizer,
    m_rootPCTokenizer,
) = (
    MergedMelHarmTokenizer(melodyPitchTokenizer, harmony_tokenizer)
    for harmony_tokenizer in (
        chordSymbolTokenizer,
        rootTypeTokenizer,
        pitchClassTokenizer,
        rootPCTokenizer,
    )
)

In [4]:
# Active tokenizer. `tokenizer_name` is also used further down to locate the
# matching BART checkpoint, so the two assignments must be changed together.
# Alternatives (merged tokenizer / name):
#   m_chordSymbolTokenizer / 'ChordSymbolTokenizer'
#   m_rootTypeTokenizer    / 'RootTypeTokenizer'
#   m_pitchClassTokenizer  / 'PitchClassTokenizer'
tokenizer_name = 'RootPCTokenizer'
tokenizer = m_rootPCTokenizer

# Dataset locations (absolute, machine-specific paths).
train_dir = '/media/maindisk/maximos/data/hooktheory_train'
test_dir = '/media/maindisk/maximos/data/hooktheory_test'

# Both splits are constructed with exactly the same options.
dataset_options = dict(
    max_length=512,
    num_bars=64,
    description_mode='specific_chord',
    alteration=True,
)

train_dataset = SeparatedMelHarmTextDataset(train_dir, tokenizer, **dataset_options)

test_dataset = SeparatedMelHarmTextDataset(test_dir, tokenizer, **dataset_options)

def create_data_collator(tokenizer, model):
    """Build the batch collator used by the data loaders.

    Args:
        tokenizer: tokenizer forwarded to the collator.
        model: model forwarded to the collator.

    Returns:
        A ``MelHarmTextCollatorForSeq2Seq`` created with ``padding=True``.
    """
    collator_options = {'tokenizer': tokenizer, 'model': model, 'padding': True}
    return MelHarmTextCollatorForSeq2Seq(**collator_options)
# end create_data_collator

In [5]:
# Sanity check: show the 'harmony_input_ids' tensor of the first training item.
print(train_dataset[0]['harmony_input_ids'])

tensor([196, 196,   6,  95, 197, 213, 216, 111, 197, 211, 216,   6,  95, 202,
        218, 209, 107, 197, 211, 216,   6,  95, 197, 213, 216, 111, 197, 211,
        216,   6,  95, 197, 214, 216, 107, 202, 218, 209,   6,  95, 197, 213,
        216, 111, 197, 211, 216,   6,  95, 202, 218, 209, 107, 197, 211, 216,
          6,  95, 207, 211, 214, 109, 202, 218, 209,   6,  95, 197, 213, 216,
        107, 204, 220, 211,   6,  95, 197, 213, 216, 111, 197, 211, 216,   6,
         95, 197, 213, 216, 218, 216,   6,  95, 197, 213, 216, 111, 197, 211,
        216,   6,  95, 197, 214, 216, 107, 202, 218, 209,   6,  95, 197, 213,
        216, 111, 197, 211, 216,   6,  95, 202, 218, 209, 107, 197, 211, 216,
          6,  95, 207, 211, 214, 109, 202, 218, 209,   6,  95, 197, 213, 216,
        107, 204, 220, 211,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,  

  return self.iter().getElementsByClass(classFilterList)


In [6]:
# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
# `device_name` is kept as a plain string because it is reused as `map_location`
# when the checkpoint is loaded; set it to 'cpu' by hand to force CPU execution.
device_name = 'cpu'
if torch.cuda.is_available():
    device_name = 'cuda'
device = torch.device(device_name)

In [7]:
# Special-token ids all come from the selected tokenizer.
special_token_ids = dict(
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
)

# Encoder and decoder stacks use the same sizes.
stack_options = dict(
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    encoder_layerdrop=0.3,
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    decoder_layerdrop=0.3,
)

# NOTE(review): the state dict below is loaded with the default (strict)
# settings, so this configuration presumably has to mirror the one the
# checkpoint was trained with - confirm before changing any value.
bart_config = BartConfig(
    vocab_size=len(tokenizer.vocab),
    max_position_embeddings=512,
    d_model=512,
    dropout=0.3,
    **special_token_ids,
    **stack_options,
)

bart = BartForConditionalGeneration(bart_config)

# Checkpoints are stored per tokenizer: saved_models/bart/<name>/<name>.pt
bart_path = f'saved_models/bart/{tokenizer_name}/{tokenizer_name}.pt'
checkpoint = torch.load(bart_path, map_location=device_name, weights_only=True)
bart.load_state_dict(checkpoint)

bart.to(device).eval()

# Optional: freeze BART so that only the remaining parameters are trained.
# for param in bart_encoder.parameters():
#     param.requires_grad = False
# for param in bart_decoder.parameters():
#     param.requires_grad = False

bart_encoder = bart.get_encoder()
bart_decoder = bart.get_decoder()
bart_encoder.to(device)
# Last expression of the cell: its return value (the decoder) is what gets displayed.
bart_decoder.to(device)

BartDecoder(
  (embed_tokens): BartScaledWordEmbedding(221, 512, padding_idx=1)
  (embed_positions): BartLearnedPositionalEmbedding(514, 512)
  (layers): ModuleList(
    (0-7): 8 x BartDecoderLayer(
      (self_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=512, out_features=512, bias=True)
        (v_proj): Linear(in_features=512, out_features=512, bias=True)
        (q_proj): Linear(in_features=512, out_features=512, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (activation_fn): GELUActivation()
      (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=512, out_features=512, bias=True)
        (v_proj): Linear(in_features=512, out_features=512, bias=True)
        (q_proj): Linear(in_features=512, out_features=512, bias=True)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (enc

In [8]:
# Collator shared by the training and validation data loaders.
collator = create_data_collator(tokenizer, model=bart)

In [9]:
# Training batches are reshuffled on every pass. The validation loader keeps
# the dataset order: shuffling evaluation data brings no benefit and changes
# the batch composition (and thus the per-batch averaged metrics) on every pass.
trainloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collator)
valloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collator)

In [10]:
# Pull a single collated batch from the training loader for inspection below.
b = next(iter(trainloader))

  return self.iter().getElementsByClass(classFilterList)
  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
  [torch.tensor(h) for h in harmony_inputs],


In [11]:
# List the fields the collator puts into a batch.
print(b.keys())

dict_keys(['input_ids', 'attention_mask', 'labels', 'decoder_input_ids', 'txt', 'harmony_input_ids'])


In [12]:
# Inspect the label tensor; the printed output shows -100 in the trailing positions.
print(b['labels'])

tensor([[ 196,    6,   95,  ..., -100, -100, -100],
        [ 196,    6,    6,  ..., -100, -100, -100],
        [ 196,    6,   95,  ..., -100, -100, -100],
        ...,
        [ 196,    6,   95,  ..., -100, -100, -100],
        [ 196,    6,   95,  ..., -100, -100, -100],
        [ 196,    6,    6,  ..., -100, -100, -100]])


In [13]:
# Wrap the loaded BART in the text-guided model. Per the cell output, construction
# also loads the 'roberta-base' checkpoint (its pooler weights are newly initialised).
model = TextGuidedHarmonizationModel(bart, device=device)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Move the tensor fields of the sample batch to the target device; the text
# descriptions under 'txt' are taken from the batch unchanged.
model_input_ids, melody_attention_mask, harmony_input_ids, labels = (
    b[field].to(device)
    for field in ('input_ids', 'attention_mask', 'harmony_input_ids', 'labels')
)
texts = b['txt']

In [15]:
# Single forward pass on the sample batch to check that the model runs end to end.
output = model(
    model_input_ids,
    melody_attention_mask,
    harmony_input_ids,
    texts,
    labels=labels,
)
decoder_loss, decoder_logits = output['loss'], output['logits']

In [16]:
# Show the loss and the logits returned by the smoke-test forward pass.
print(decoder_loss)
print(decoder_logits)

tensor(1.1272, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor([[[-8.5915e+00, -8.7255e+00,  2.5597e+00,  ...,  4.8873e-01,
           8.6521e-01,  6.4023e-01],
         [-2.2204e+01, -2.2071e+01, -2.0999e+00,  ...,  4.6540e-01,
          -7.4663e-01,  5.8135e-01],
         [-2.0728e+01, -2.0710e+01, -1.4445e+00,  ..., -2.1973e+00,
          -3.0354e+00, -2.2271e+00],
         ...,
         [-6.4683e+00, -6.2666e+00, -6.7605e-01,  ..., -1.2500e+00,
          -2.0705e+00, -1.1193e+00],
         [-6.3887e+00, -6.2053e+00, -6.5636e-01,  ..., -1.2764e+00,
          -2.1578e+00, -1.2972e+00],
         [-6.4046e+00, -6.2078e+00, -6.7884e-01,  ..., -1.3459e+00,
          -1.9831e+00, -1.1115e+00]],

        [[-8.8575e+00, -8.9734e+00,  2.5176e+00,  ...,  4.1744e-01,
           7.4860e-01,  6.6337e-01],
         [-2.2240e+01, -2.2079e+01, -2.1657e+00,  ...,  4.8006e-01,
          -8.0228e-01,  7.0033e-01],
         [-2.0656e+01, -2.0613e+01, -1.4958e+00,  ..., -2.2319e+00,
          -3.069

In [17]:
epochs = 5

# Hand only the parameters that still require gradients to the optimizer.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_parameters, lr=5e-5)

# Linear learning-rate schedule without warmup, spanning every optimizer step
# of the run (batches per epoch times number of epochs).
num_training_steps = epochs * len(trainloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [18]:
# Put the model in training mode. As the last expression of the cell, the
# returned module is also displayed as the cell output.
model.train()

TextGuidedHarmonizationModel(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
             

In [19]:
# Training loop: every epoch runs one optimisation pass over `trainloader`
# followed by one evaluation pass over `valloader`. Loss and accuracy are
# reported as running means over the batches seen so far in the current pass.
# Perplexity and token-entropy tracking are not implemented here.
for epoch in range(epochs):  # Number of epochs
    # ---- training pass ----
    # The validation pass below switches the model to eval mode, so training
    # mode has to be restored at the start of every epoch.
    model.train()
    train_loss = 0
    running_loss = 0
    batch_num = 0
    running_accuracy = 0
    train_accuracy = 0
    print('training')
    with tqdm(trainloader, unit='batch') as tepoch:
        tepoch.set_description(f'Epoch {epoch} | trn')
        for batch in tepoch:
            model_input_ids = batch['input_ids'].to(device)
            melody_attention_mask = batch['attention_mask'].to(device)
            harmony_input_ids = batch['harmony_input_ids'].to(device)
            labels = batch['labels'].to(device)
            texts = batch['txt']

            # Clear gradients left over from the previous step.
            optimizer.zero_grad()
            output = model(
                model_input_ids,
                melody_attention_mask,
                harmony_input_ids,
                texts,
                labels=labels
            )
            loss = output['loss']
            logits = output['logits']

            loss.backward()  # Compute gradients
            optimizer.step()  # Update trainable weights
            lr_scheduler.step()  # Update learning rate

            # update loss
            batch_num += 1
            running_loss += loss.item()
            train_loss = running_loss/batch_num
            # accuracy over the positions whose label is not -100
            predictions = logits.argmax(dim=-1)
            mask = labels != -100
            # max(..., 1) guards against a batch without any scored position
            num_scored = max(mask.sum().item(), 1)
            running_accuracy += (predictions[mask] == labels[mask]).sum().item()/num_scored
            train_accuracy = running_accuracy/batch_num

            tepoch.set_postfix(loss=train_loss, accuracy=train_accuracy)

    # ---- validation pass ----
    # Eval mode disables dropout/layerdrop (both configured in this notebook's
    # BART config), which would otherwise perturb the validation metrics.
    # NOTE(review): assumes the wrapped BART is registered as a submodule of
    # `model`, so that eval()/train() propagate to it - confirm in models.py.
    model.eval()
    val_loss = 0
    running_loss = 0
    batch_num = 0
    running_accuracy = 0
    val_accuracy = 0
    print('validation')
    with torch.no_grad():
        with tqdm(valloader, unit='batch') as tepoch:
            tepoch.set_description(f'Epoch {epoch} | val')
            for batch in tepoch:
                model_input_ids = batch['input_ids'].to(device)
                melody_attention_mask = batch['attention_mask'].to(device)
                harmony_input_ids = batch['harmony_input_ids'].to(device)
                labels = batch['labels'].to(device)
                texts = batch['txt']

                output = model(
                    model_input_ids,
                    melody_attention_mask,
                    harmony_input_ids,
                    texts,
                    labels=labels
                )
                loss = output['loss']
                logits = output['logits']

                # update loss
                batch_num += 1
                running_loss += loss.item()
                val_loss = running_loss/batch_num
                # accuracy over the positions whose label is not -100
                predictions = logits.argmax(dim=-1)
                mask = labels != -100
                num_scored = max(mask.sum().item(), 1)
                running_accuracy += (predictions[mask] == labels[mask]).sum().item()/num_scored
                val_accuracy = running_accuracy/batch_num

                tepoch.set_postfix(loss=val_loss, accuracy=val_accuracy)

training


Epoch 0 | trn:   0%|          | 0/855 [00:00<?, ?batch/s]

  return self.iter().getElementsByClass(classFilterList)
  [torch.tensor(h) for h in harmony_inputs],
Epoch 0 | trn:  15%|█▍        | 128/855 [01:10<06:38,  1.82batch/s, accuracy=0.8, loss=0.667]  


KeyboardInterrupt: 