In [13]:
from data_utils import SeparatedMelHarmTextDataset
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
import torch
from torch.optim import AdamW
from tqdm import tqdm

In [14]:
# Directory handed to SeparatedMelHarmTextDataset as the data root.
# Set the HOOKTHEORY_TRAIN_DIR environment variable to run the notebook on a
# machine with a different layout; when it is unset, the original location is
# used, so existing runs behave exactly as before.
root_dir = os.environ.get(
    'HOOKTHEORY_TRAIN_DIR',
    '/mnt/ssd2/maximos/data/hooktheory_train',
)

In [15]:
# Each tokenizer class is paired with the path it is restored from via
# from_pretrained. The pairs are listed in the same order as the names they
# are unpacked into below, so the load order is unchanged.
_tokenizer_sources = (
    (ChordSymbolTokenizer, 'saved_tokenizers/ChordSymbolTokenizer'),
    (RootTypeTokenizer, 'saved_tokenizers/RootTypeTokenizer'),
    (PitchClassTokenizer, 'saved_tokenizers/PitchClassTokenizer'),
    (RootPCTokenizer, 'saved_tokenizers/RootPCTokenizer'),
    (MelodyPitchTokenizer, 'saved_tokenizers/MelodyPitchTokenizer'),
)
(
    chordSymbolTokenizer,
    rootTypeTokenizer,
    pitchClassTokenizer,
    rootPCTokenizer,
    melodyPitchTokenizer,
) = [tok_cls.from_pretrained(tok_path) for tok_cls, tok_path in _tokenizer_sources]

In [16]:
# Wrap the melody tokenizer together with each harmony tokenizer in a
# MergedMelHarmTokenizer. The generator yields the merged tokenizers in the
# same order as the target names on the left-hand side.
(
    m_chordSymbolTokenizer,
    m_rootTypeTokenizer,
    m_pitchClassTokenizer,
    m_rootPCTokenizer,
) = (
    MergedMelHarmTokenizer(melodyPitchTokenizer, harmony_tokenizer)
    for harmony_tokenizer in (
        chordSymbolTokenizer,
        rootTypeTokenizer,
        pitchClassTokenizer,
        rootPCTokenizer,
    )
)

In [17]:
# Merged tokenizer used for this run, plus a string label for it.
tokenizer_name = 'ChordSymbolTokenizer'
tokenizer = m_chordSymbolTokenizer

# Candidate values for the dataset's description_mode argument.
description_modes = ['specific_chord', 'chord_root', 'pitch_class']

# Index 2 selects 'pitch_class' as the description mode.
dataset = SeparatedMelHarmTextDataset(
    root_dir,
    tokenizer,
    max_length=512,
    num_bars=64,
    description_mode=description_modes[2],
)
def create_data_collator(tokenizer, model):
    """Build the seq2seq data collator for BART.

    Args:
        tokenizer: Tokenizer forwarded to ``DataCollatorForSeq2Seq``.
        model: Model forwarded to ``DataCollatorForSeq2Seq``.

    Returns:
        A ``DataCollatorForSeq2Seq`` created with ``padding=True``.
    """
    collator_kwargs = dict(tokenizer=tokenizer, model=model, padding=True)
    return DataCollatorForSeq2Seq(**collator_kwargs)

In [18]:
# Sanity check: print the first dataset item. The recorded output below shows
# a dict with 'input_ids', 'attention_mask' and 'labels' tensors, where the
# tail of 'labels' is filled with -100.
print(dataset[0])

{'input_ids': tensor([  2,   6, 180,  95,  69,   6,  95,  69,   6,  95,  69, 103,   4,   6,
         95,   4, 119,  67, 123,  69,   6,  95,  70,  99,  72, 103,  67,   6,
         95,  67,   6,  95,  67,   6,  95,  67, 119,  64, 123,  69]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 'labels': tensor([ 196,    6,   95,  379,  103,  379,  119,  379,    6,   95,  379,  111,
         379,    6,   95,  379,  103,  379,  119,  379,    6,   95,  379,  111,
         379,    6,   95,  466,  103,  466,  119,  466,    6,   95,  466,  111,
         466,    6,   95,  466,  103,  466,  119,  466,    6,   95,  466,  111,
         466,    3, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -1