In [15]:
from data_utils import SeparatedMelHarmDataset
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
import torch
from torch.optim import AdamW
from tqdm import tqdm

In [16]:
# Directory holding the Hooktheory training data.
# The location can be overridden per machine through the HOOKTHEORY_TRAIN_DIR
# environment variable; when it is unset the original path is used unchanged,
# so existing setups keep working.
root_dir = os.environ.get(
    'HOOKTHEORY_TRAIN_DIR',
    '/media/maindisk/maximos/data/hooktheory_train',
)

In [17]:
# Restore every saved tokenizer from its directory under saved_tokenizers/.
# The generator is consumed in order by the tuple unpacking, so each class is
# loaded in the same sequence as the names on the left-hand side.
(
    chordSymbolTokenizer,
    rootTypeTokenizer,
    pitchClassTokenizer,
    rootPCTokenizer,
    gctRootPCTokenizer,
    gctSymbolTokenizer,
    gctRootTypeTokenizer,
    melodyPitchTokenizer,
) = (
    tokenizer_cls.from_pretrained(saved_path)
    for tokenizer_cls, saved_path in (
        (ChordSymbolTokenizer, 'saved_tokenizers/ChordSymbolTokenizer'),
        (RootTypeTokenizer, 'saved_tokenizers/RootTypeTokenizer'),
        (PitchClassTokenizer, 'saved_tokenizers/PitchClassTokenizer'),
        (RootPCTokenizer, 'saved_tokenizers/RootPCTokenizer'),
        (GCTRootPCTokenizer, 'saved_tokenizers/GCTRootPCTokenizer'),
        (GCTSymbolTokenizer, 'saved_tokenizers/GCTSymbolTokenizer'),
        (GCTRootTypeTokenizer, 'saved_tokenizers/GCTRootTypeTokenizer'),
        (MelodyPitchTokenizer, 'saved_tokenizers/MelodyPitchTokenizer'),
    )
)

In [18]:
# Pair the melody tokenizer with each harmony tokenizer, producing one merged
# tokenizer per harmony representation. Construction order follows the order
# of the names on the left-hand side.
(
    m_chordSymbolTokenizer,
    m_rootTypeTokenizer,
    m_pitchClassTokenizer,
    m_rootPCTokenizer,
    m_gctRootPCTokenizer,
    m_gctSymbolTokenizer,
    m_gctRootTypeTokenizer,
) = (
    MergedMelHarmTokenizer(melodyPitchTokenizer, harmony_tokenizer)
    for harmony_tokenizer in (
        chordSymbolTokenizer,
        rootTypeTokenizer,
        pitchClassTokenizer,
        rootPCTokenizer,
        gctRootPCTokenizer,
        gctSymbolTokenizer,
        gctRootTypeTokenizer,
    )
)

In [19]:
# Active tokenizer for the rest of the notebook: the merged melody +
# chord-symbol tokenizer. Swap in another m_* tokenizer here to train on a
# different harmony representation.
tokenizer = m_chordSymbolTokenizer

# Build the training dataset from root_dir using the active tokenizer.
# NOTE(review): max_length=512 presumably caps the tokenized sequence length;
# it equals max_position_embeddings=512 in the BartConfig cell, so keep the two
# in sync if either changes. The exact meaning of num_bars=64 is defined in
# data_utils.SeparatedMelHarmDataset — confirm there.
dataset = SeparatedMelHarmDataset(root_dir, tokenizer, max_length=512, num_bars=64)
# Batch collator factory used for BART training
def create_data_collator(tokenizer, model):
    """Build the seq2seq batch collator for the given tokenizer and model.

    Args:
        tokenizer: tokenizer forwarded to the collator.
        model: model forwarded to the collator.

    Returns:
        A DataCollatorForSeq2Seq constructed with padding=True.
    """
    collator_options = dict(tokenizer=tokenizer, model=model, padding=True)
    return DataCollatorForSeq2Seq(**collator_options)

In [20]:
# Hyper-parameters for the BART model, collected in one mapping and then
# expanded into BartConfig.
bart_hparams = dict(
    # vocabulary size and special-token ids are taken from the active tokenizer
    vocab_size=len(tokenizer.vocab),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    # same value as the max_length=512 passed to the dataset
    max_position_embeddings=512,
    # encoder stack
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    # decoder stack
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    # model width
    d_model=512,
    # regularisation
    encoder_layerdrop=0.3,
    decoder_layerdrop=0.3,
    dropout=0.3,
)
bart_config = BartConfig(**bart_hparams)

# The model is constructed directly from the config; no pretrained checkpoint
# is loaded in this cell.
model = BartForConditionalGeneration(bart_config)

In [21]:
# Collator for the active tokenizer/model pair, plugged into a shuffled
# DataLoader that yields batches of 32 examples.
collator = create_data_collator(tokenizer=tokenizer, model=model)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collator,
)

In [22]:
# Pull a single batch for inspection. A fresh iterator is created here and the
# dataloader was built with shuffle=True, so the batch differs between runs.
b = next(iter(dataloader))

In [23]:
# Sanity-check the sample batch: first the shape of every field, then the
# first example of every field, in the same field order both times.
batch_fields = ('input_ids', 'attention_mask', 'labels')

for field in batch_fields:
    print(b[field].shape)

for field in batch_fields:
    print(b[field][0])

torch.Size([32, 256])
torch.Size([32, 256])
torch.Size([32, 468])
tensor([  2,   6, 180,  95,   4,  99,  48, 101,  50, 107,  53, 111,   4, 115,
         48, 117,  50, 123,  57,   6,  95,   4,  99,  48, 101,  50, 107,  53,
        111,   4, 115,  48, 117,  50, 123,  53,   6,  95,   4,  99,  48, 101,
         50, 107,  53, 111,   4, 115,  48, 117,  50, 123,  57,   6,  95,   4,
         99,  48, 101,  50, 107,  53, 111,   4, 115,  48, 117,  50, 123,  53,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1

In [25]:
# Prefer the GPU when CUDA is available; otherwise stay on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Move the model to the chosen device before the optimizer is created over
# its parameters.
model.to(device)
optimizer = AdamW(params=model.parameters(), lr=5e-5)

In [26]:
# Train the model for NUM_EPOCHS passes over `dataloader`.
# The tqdm bar shows, per epoch, the running mean of the batch losses and the
# running mean of the per-batch token accuracies.
NUM_EPOCHS = 3  # number of full passes over the training data

model.train()

for epoch in range(NUM_EPOCHS):
    # Per-epoch accumulators; the train_* values are the means shown in the bar.
    train_loss = 0
    running_loss = 0
    batch_num = 0
    running_accuracy = 0
    train_accuracy = 0
    with tqdm(dataloader, unit='batch') as tepoch:
        tepoch.set_description(f"Epoch {epoch} | trn")
        for batch in tepoch:
            # Move batch to the same device as the model
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels  # BART handles shifting internally
            )

            loss = outputs.loss  # Loss is directly computed by the model
            logits = outputs.logits

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running mean of the loss over the batches seen so far this epoch
            batch_num += 1
            running_loss += loss.item()
            train_loss = running_loss / batch_num

            # Token accuracy over positions whose label is not -100 (the value
            # excluded by the mask below). This is metric bookkeeping only, so
            # it runs without gradient tracking.
            with torch.no_grad():
                predictions = logits.argmax(dim=-1)
                mask = labels != -100
                num_targets = mask.sum().item()
                num_correct = (predictions[mask] == labels[mask]).sum().item()
            # max(..., 1) avoids a ZeroDivisionError when a batch has no
            # unmasked label positions; such a batch contributes accuracy 0.
            running_accuracy += num_correct / max(num_targets, 1)
            train_accuracy = running_accuracy / batch_num

            tepoch.set_postfix(loss=train_loss, accuracy=train_accuracy)


  return self.iter().getElementsByClass(classFilterList)
Epoch 0 | trn: 100%|██████████| 428/428 [07:23<00:00,  1.04s/batch, accuracy=0.407, loss=2.49] 
Epoch 1 | trn:   6%|▌         | 24/428 [00:25<07:14,  1.07s/batch, accuracy=0.555, loss=1.72]


KeyboardInterrupt: 