In [1]:
from harmony_tokenizers_m21 import ChordSymbolTokenizer, PitchClassTokenizer, MelodyPitchTokenizer, MergedMelHarmTokenizer
from data_utils import StructBARTMelHarmDataset
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

# Location of the training split. The default is the original machine-specific
# path; set the HOOKTHEORY_TRAIN_DIR environment variable to point the notebook
# at the data on another machine without editing this cell.
train_dir = os.environ.get(
    'HOOKTHEORY_TRAIN_DIR',
    '/media/maindisk/maximos/data/hooktheory_train',
)

In [3]:
# Build the individual tokenizers, then merge two of them into the single
# `tokenizer` that the dataset and the BART config in this notebook use.
# NOTE(review): cstok is not passed to the merged tokenizer in the cells shown
# here — confirm it is still needed further down.
cstok = ChordSymbolTokenizer()
pctok = PitchClassTokenizer()
meltok = MelodyPitchTokenizer()
# The merged tokenizer is constructed from the melody and pitch-class tokenizers.
tokenizer = MergedMelHarmTokenizer(meltok, pctok)

In [4]:
# Vocabulary sanity check: print the total vocabulary size, then the ids of the
# three special tokens looked up below, one value per line.
print(len(tokenizer.vocab))
for special_token in ('</m>', '<h>', '<fill>'):
    print(tokenizer.vocab[special_token])

211
8
7
9


In [5]:
# Training dataset read from train_dir and encoded with the merged tokenizer.
# NOTE(review): max_length=512 equals max_position_embeddings in the BART config
# of this notebook — presumably intentional; keep the two in sync.
train_dataset = StructBARTMelHarmDataset(
    train_dir,
    tokenizer,
    max_length=512,
    num_bars=64,
)

In [6]:
# Fetch the first training example to inspect what the dataset yields.
d = train_dataset[0]

  return self.iter().getElementsByClass(classFilterList)


In [7]:
# Show the first example; the printed dict includes an 'input_ids' tensor.
print(d)

{'input_ids': tensor([  2,   6, 183,  98,   4, 106,  53, 108,  56, 112,  58, 118,  56, 122,
         58, 126,  56,   6,  98,  61, 106,  58, 110,  58, 114,  56, 122,  54,
        126,  56,   6,  98,  56, 102,   4,   6,  98,   4,   6,  98,   4, 106,
         53, 108,  56, 112,  58, 118,  56, 122,  58, 128,  56,   6,  98,  61,
        102,  61, 106,   4, 110,  58, 114,  56, 122,  54, 126,  56,   6,  98,
         56, 102,   4,   6,  98,   4,   6,  98,   4, 102,  49, 106,  53, 108,
         56, 112,  58, 118,  56, 122,   4, 126,  56,   6,  98,  61, 106,  58,
        110,  58, 114,  56, 122,  54, 126,  56,   6,  98,  56, 102,  53, 110,
          4, 126,   4,   6,  98,   4,   6,  98,   4, 106,  53, 112,  56, 114,
         58, 118,  56, 126,  61,   6,  98,  61, 102,  61, 106,  58, 110,  56,
        114,  56, 122,  54, 126,  56,   6,  98,  56, 106,   4,   6,  98,   4,
        110,   4, 126,   4,   8,   6,   9,   6,   9,   6,   9,   6,   9,   6,
          9,   6,   9,   6,   9,   6,   9,   6,   

In [8]:
# BART configuration sized to the merged melody/harmony vocabulary:
# 8 encoder + 8 decoder layers, model width 512.
bart_config = BartConfig(
    # Vocabulary size and special-token ids all come from the merged tokenizer.
    vocab_size=len(tokenizer.vocab),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,  # decoder is started from BOS
    forced_eos_token_id=tokenizer.eos_token_id,
    # Shared dimensions.
    d_model=512,
    max_position_embeddings=512,
    # Encoder stack.
    encoder_layers=8,
    encoder_attention_heads=8,  # 16 was noted as an alternative
    encoder_ffn_dim=512,
    # Decoder stack.
    decoder_layers=8,
    decoder_attention_heads=8,  # 16 was noted as an alternative
    decoder_ffn_dim=512,
    # Regularisation (0.1 was noted as an alternative for each of these).
    dropout=0.25,
    encoder_layerdrop=0.25,
    decoder_layerdrop=0.25,
)

# Instantiate the model directly from the config; no pretrained checkpoint is
# loaded in this cell.
model = BartForConditionalGeneration(bart_config)

In [9]:
def create_data_collator(tokenizer, model):
    """Build the seq2seq batch collator for the given tokenizer and model.

    Args:
        tokenizer: Tokenizer forwarded to ``DataCollatorForSeq2Seq``.
        model: Model forwarded to ``DataCollatorForSeq2Seq``.

    Returns:
        A ``DataCollatorForSeq2Seq`` created with ``padding=True``.
    """
    seq2seq_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding=True,
    )
    return seq2seq_collator

# Collator that batches examples for the seq2seq model.
collator = create_data_collator(tokenizer=tokenizer, model=model)

# Shuffled training loader: 16 examples per batch, collated by `collator`.
trainloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collator,
)

In [10]:
# Pull one collated batch from the loader as a smoke test of dataset + collator.
batch = next(iter(trainloader))

  return self.iter().getElementsByClass(classFilterList)
  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


In [11]:
# Inspect the token ids of the first sequence in the batch.
# NOTE(review): the trailing run of 1s in the output is presumably padding —
# confirm against tokenizer.pad_token_id.
print(batch['input_ids'][0])

tensor([  2,   6, 183,  98,  58, 102,  56,   6,  98,  51, 102,  49, 118,  49,
        122,   4,   6,  98,  49, 110,  49, 114,  51, 118,  53, 122,  56, 126,
         53,   6,  98,  53, 114,   4,   8,   6,  98, 199, 203, 206, 210,   9,
          6,   9,   6,   9,   6,   9,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,  