In [1]:
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator
from transformers import DataCollatorForLanguageModeling
from miditok.utils import split_files_for_training
from torch.utils.data import DataLoader
from pathlib import Path
from random import shuffle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the tokenizer
prefix = '/media/datadisk/datasets/GiantMIDI-PIano/aug'
saved_tokenizer_path = f'{prefix}_REMI_BPE_tokenizer.json'
path_to_dataset = prefix
path_to_train_splits = f'{prefix}_splits_REMI_BPE/train/midis'
path_to_valid_splits = f'{prefix}_splits_REMI_BPE/valid/midis'

max_seq_len = 1024

tokenizer = REMI(params=Path(saved_tokenizer_path))

tokenizer.pad_token = tokenizer.special_tokens[0]
tokenizer.mask_token = tokenizer.special_tokens[1]

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15
)

In [3]:
tokenizer.pad_token = tokenizer.special_tokens[0]

In [4]:
tokenizer.special_tokens

['PAD_None', 'MASK_None', 'BOS_None', 'EOS_None']

In [5]:
files_paths = list(Path(path_to_dataset).glob("**/*.mid"))
shuffle(files_paths)

In [6]:
total_num_files = len(files_paths)
num_files_valid = round(total_num_files * 0.10)

In [7]:
print(total_num_files, num_files_valid)

160832 16083


In [8]:
midi_paths_valid = files_paths[:num_files_valid]
midi_paths_train = files_paths[num_files_valid:]

In [19]:
# Split MIDIs into smaller chunks for validation
dataset_chunks_dir = Path(path_to_valid_splits)
split_files_for_training(
    files_paths=midi_paths_valid,
    tokenizer=tokenizer,
    save_dir=dataset_chunks_dir,
    max_seq_len=max_seq_len,
)

Splitting music files (/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/valid/midis): 100%|██████████| 160832/160832 [21:17<00:00, 125.90it/s]


[PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/valid/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_0.mid'),
 PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/valid/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_1.mid'),
 PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/valid/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_2.mid'),
 PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/valid/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_3.mid'),
 PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/valid/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_4.mid'),
 PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/valid/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_5.mid'),
 PosixPath

In [20]:

# Split MIDIs into smaller chunks for training
dataset_chunks_dir = Path(path_to_train_splits)
split_files_for_training(
    files_paths=midi_paths_train,
    tokenizer=tokenizer,
    save_dir=dataset_chunks_dir,
    max_seq_len=max_seq_len,
)

Splitting music files (/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/train/midis): 100%|██████████| 160832/160832 [17:58<00:00, 149.15it/s]


[PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/train/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_0.mid'),
 PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/train/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_1.mid'),
 PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/train/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_2.mid'),
 PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/train/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_3.mid'),
 PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/train/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_4.mid'),
 PosixPath('/media/datadisk/datasets/GiantMIDI-PIano/aug_splits_REMI_BPE/train/midis/Bach, Johann Sebastian, Toccata in D major, BWV 912, I8tDTeajtCs#d-192_5.mid'),
 PosixPath

In [9]:
from transformers import AutoModelForMaskedLM, TrainingArguments, Trainer

In [25]:
train_dataset = DatasetMIDI(
    files_paths=Path(path_to_train_splits).glob('**/*.mid'),
    tokenizer=tokenizer,
    max_seq_len=1024,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)
valid_dataset = DatasetMIDI(
    files_paths=Path(path_to_valid_splits).glob('**/*.mid'),
    tokenizer=tokenizer,
    max_seq_len=1024,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)
collator = DataCollator(tokenizer.pad_token_id, copy_inputs_as_labels=True)
# dataloader = DataLoader(dataset, batch_size=64, collate_fn=collator)


In [24]:
print(train_dataset[0])

{'input_ids': tensor([1317, 1894,  233,  ...,  241, 1199,  321])}


In [11]:
model = AutoModelForMaskedLM.from_pretrained("distilbert/distilroberta-base")

Some weights of the model checkpoint at distilbert/distilroberta-base were not used when initializing RobertaForMaskedLM: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
training_args = TrainingArguments(
    output_dir='models/robertaGM',
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
)

In [26]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=data_collator,
)

trainer.train()

AttributeError: 'REMI' object has no attribute 'pad'