In [1]:
import os
import random

import torch
from torch.utils.data import DataLoader
from miditok import REMI, TokenizerConfig, MIDITokenizer, TokSequence
from miditok.pytorch_data import DatasetMIDI, DatasetJSON, DataCollator, split_midis_for_training
from miditok.data_augmentation import augment_midi_dataset
from miditok.utils import get_midi_programs
from pathlib import Path
from symusic import Score
import wandb
from tqdm import tqdm

from transformers.models.opt.modeling_opt import OPTForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


## Tokenizer

In [None]:
# Recursively collect every .mid file below the raw data directory.
midi_paths = list(Path("/home/lklimkiewicz/priv/midi/data").rglob("*.mid"))

In [3]:
# Configure and build the REMI tokenizer: num_velocities=16, with chord and
# program tokens switched on.
tokenizer_config = TokenizerConfig(num_velocities=16, use_chords=True, use_programs=True)
tokenizer = REMI(tokenizer_config)

In [None]:
# Learn a 30000-entry BPE vocabulary from a random subset of 1000 files.
# Glob results come back in filesystem-dependent order, so sort first and
# shuffle with a dedicated, seeded RNG: the same 1000 files are then picked
# on every run and the learned vocabulary is reproducible.
midi_paths.sort()
random.Random(0).shuffle(midi_paths)
tokenizer.learn_bpe(vocab_size=30000, files_paths=midi_paths[:1000])

In [5]:
# saving
# Write the tokenizer to ../logs/tokenizer2 so it can be reloaded later
# with from_pretrained instead of being re-trained.
tokenizer.save_pretrained('../logs/tokenizer2')

In [2]:
# loading
# Restore the tokenizer previously written to ../logs/tokenizer2.
# NOTE(review): this loads through MIDITokenizer (presumably the generic base
# class) rather than REMI — confirm the result is the REMI tokenizer that was saved.
tokenizer = MIDITokenizer.from_pretrained('../logs/tokenizer2')

config.json not found in /home/lklimkiewicz/priv/midi/logs/tokenizer2


## Model

In [3]:
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

In [4]:
# Take the facebook/opt-125m configuration as the architecture template and
# override the special-token ids and vocabulary size with the values of the
# MIDI tokenizer. Only the configuration is fetched here, not model weights.
config = AutoConfig.from_pretrained(
    'facebook/opt-125m',
    bos_token_id=tokenizer['BOS_None'],
    eos_token_id=tokenizer['EOS_None'],
    pad_token_id=tokenizer['PAD_None'],
    vocab_size=len(tokenizer),
    prefix=None,
    # presumably generation defaults picked up by model.generate() — TODO confirm
    max_length=1024,
    do_sample=True,
)

In [5]:
# Build the causal LM from the configuration only (from_config), i.e. the
# model is trained from scratch rather than fine-tuned from a checkpoint.
model = AutoModelForCausalLM.from_config(config)

## Split dataset

In [6]:
# Recursively collect every .mid file below the raw data directory.
midi_paths = list(Path("/home/lklimkiewicz/priv/midi/data").rglob("*.mid"))

In [9]:
def filter_dataset(paths, val_fun, delete_unreadable=True):
    """Return the paths that load as a ``Score`` and are accepted by ``val_fun``.

    WARNING: with ``delete_unreadable=True`` (the default, matching the
    historical behaviour of this cell) every file that ``Score`` fails to
    load is permanently removed from disk.

    Args:
        paths: Iterable of MIDI file paths to check.
        val_fun: Callable receiving the loaded ``Score``; a truthy return
            value keeps the file.
        delete_unreadable: Remove files that cannot be loaded. Pass ``False``
            for a non-destructive dry run.

    Returns:
        List of accepted paths, in the order they were given.
    """
    correct = []
    validator_errors = 0
    for path in tqdm(paths):
        # Only a failure to *load* the file marks it as corrupt. Catching
        # ``Exception`` instead of using a bare ``except`` lets
        # KeyboardInterrupt / SystemExit propagate, so interrupting this
        # long-running loop no longer deletes the file being read.
        try:
            midi = Score(path)
        except Exception:
            if delete_unreadable:
                os.remove(path)
            continue

        # An error raised by the validator says nothing about the file being
        # corrupt (it may be a bug in ``val_fun``): reject the file but keep
        # it on disk.
        try:
            keep = val_fun(midi)
        except Exception:
            validator_errors += 1
            keep = False

        if keep:
            correct.append(path)

    if validator_errors:
        print(f'Validator raised for {validator_errors} file(s); they were skipped, not deleted')
    return correct

# Number of candidate files before any filtering is applied.
print('Initial count:', len(midi_paths))

def midi_valid(midi) -> bool:
    """Decide whether a loaded score is usable.

    A score is accepted when it has at least 50 notes, at least one time
    signature, at least one tempo, and no time signature whose numerator or
    denominator is zero.
    """
    has_enough_content = (
        midi.note_num() >= 50
        and len(midi.time_signatures) > 0
        and len(midi.tempos) > 0
    )
    if not has_enough_content:
        return False

    # A zero in either field marks a malformed time signature.
    return not any(
        ts.denominator == 0 or ts.numerator == 0
        for ts in midi.time_signatures
    )

# Keep only the files accepted by midi_valid; midi_paths is overwritten with
# the filtered list.
midi_paths = filter_dataset(midi_paths, midi_valid)

# Number of files that survived filtering.
print('Filtered count:', len(midi_paths))

Initial count: 1067913


 14%|█▍        | 151477/1067913 [01:55<10:38, 1434.76it/s] Division type 1 have no tpq.
 15%|█▌        | 162690/1067913 [02:03<10:08, 1488.41it/s]Division type 1 have no tpq.
 38%|███▊      | 403524/1067913 [04:56<07:09, 1548.06it/s]Division type 1 have no tpq.
 39%|███▉      | 420722/1067913 [05:07<06:53, 1566.63it/s]Division type 1 have no tpq.
 41%|████      | 439315/1067913 [05:18<06:28, 1619.31it/s]Division type 1 have no tpq.
 52%|█████▏    | 558252/1067913 [06:51<13:24, 633.39it/s]  Division type 1 have no tpq.
 53%|█████▎    | 569424/1067913 [07:02<05:39, 1467.48it/s]Division type 1 have no tpq.
100%|██████████| 1067913/1067913 [13:06<00:00, 1357.29it/s]


Filtered count: 766717


In [None]:
# Pre-split the filtered files for training and write the result under
# ./chunks_for_training (presumably chunks of at most max_seq_len=1024
# tokens each — TODO confirm against the miditok documentation).
# NOTE(review): the "Load dataset" section reads from
# /home/lklimkiewicz/priv/midi/data_prim/chunks_for_training_2, not from this
# save_dir — confirm which directory holds the chunks actually used.
split_midis_for_training(
    files_paths=midi_paths,
    tokenizer=tokenizer,
    save_dir=Path('./chunks_for_training'),
    max_seq_len=1024,
)

## Load dataset

In [6]:
# Recursively collect every pre-split .mid chunk used for training.
midi_paths = list(Path("/home/lklimkiewicz/priv/midi/data_prim/chunks_for_training_2").rglob("*.mid"))

In [None]:
# Number of chunk files before any filtering is applied.
print('Initial count:', len(midi_paths))

def filter_dataset(paths, val_fun):
    """Return the paths that load as a ``Score`` and are accepted by ``val_fun``.

    Files that raise while being loaded or validated are skipped (nothing is
    deleted); the number of such files is reported once at the end.

    Args:
        paths: Iterable of MIDI file paths to check.
        val_fun: Callable receiving the loaded ``Score``; a truthy return
            value keeps the file.

    Returns:
        List of accepted paths, in the order they were given.
    """
    correct = []
    skipped = 0
    for path in tqdm(paths):
        try:
            if val_fun(Score(path)):
                correct.append(path)
        # ``Exception`` rather than a bare ``except``: KeyboardInterrupt and
        # SystemExit must propagate, otherwise interrupting the kernel only
        # skips one file and the loop keeps running.
        except Exception:
            skipped += 1

    if skipped:
        print(f'Skipped {skipped} file(s) that raised while loading or validating')
    return correct

def midi_valid(midi) -> bool:
    """Tell whether a loaded score passes the dataset sanity checks.

    Rejected are scores with fewer than 50 notes, without any time signature,
    without any tempo, or containing a time signature with a zero numerator
    or denominator.
    """
    if midi.note_num() < 50:
        return False
    if len(midi.time_signatures) == 0:
        return False
    if len(midi.tempos) == 0:
        return False

    # Every time signature must have non-zero fields.
    return all(
        sig.denominator != 0 and sig.numerator != 0
        for sig in midi.time_signatures
    )

# Keep only the chunk files accepted by midi_valid; midi_paths is overwritten
# with the filtered list.
midi_paths = filter_dataset(midi_paths, midi_valid)

# Number of files that survived filtering.
print('Filtered count:', len(midi_paths))

In [7]:
# Dataset over the collected MIDI files, using the notebook's tokenizer.
# max_seq_len=1024 is the same value given as max_length in the model config;
# the BOS/EOS ids are handed over so the dataset can use them (presumably to
# wrap each sequence — TODO confirm against the miditok documentation).
dataset = DatasetMIDI(
    files_paths=midi_paths,
    tokenizer=tokenizer,
    max_seq_len=1024,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)

# Pads each batch with the PAD token and copies the inputs as labels.
# shift_labels stays False on purpose: Hugging Face causal-LM heads (and the
# Trainer's label smoother, active because label_smoothing_factor is set)
# already shift the labels by one position when computing the loss. Shifting
# in the collator as well would make the model learn to predict the token two
# steps ahead instead of the next one.
collator = DataCollator(
    tokenizer["PAD_None"],
    copy_inputs_as_labels=True,
    shift_labels=False,
)

## Train

In [8]:
from transformers import Trainer, TrainingArguments, TrainerCallback, TrainerState, TrainerControl

In [9]:
class MidiGenerationCallback(TrainerCallback):
    """Periodically sample a MIDI file from the model during training.

    Every ``every_n_steps`` optimizer steps a sequence is generated from a
    single BOS token, decoded with the notebook-level ``tokenizer`` and
    written to ``<output_dir>/step-<global_step>.mid``.

    Args:
        every_n_steps: Interval, in global steps, between two samples.
        output_dir: Directory receiving the generated files; created on demand.
        max_new_tokens: Upper bound on the number of generated tokens.
    """

    def __init__(self, every_n_steps=500, output_dir='outputs/v1', max_new_tokens=1024):
        self.every_n_steps = every_n_steps
        self.output_dir = Path(output_dir)
        self.max_new_tokens = max_new_tokens

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.global_step % self.every_n_steps != 0:
            return

        # Prefer the model the Trainer passes to its callbacks; fall back to
        # the notebook-level ``model`` when the callback is invoked manually.
        gen_model = kwargs.get('model')
        if gen_model is None:
            gen_model = model

        # Sample in eval mode so dropout does not perturb the generation, and
        # always restore the previous mode so training is unaffected.
        was_training = gen_model.training
        gen_model.eval()
        try:
            prompt = torch.tensor([[tokenizer['BOS_None']]], device=gen_model.device)
            generated = gen_model.generate(prompt, max_new_tokens=self.max_new_tokens)
        finally:
            gen_model.train(was_training)

        generated_ts = TokSequence(ids=generated.tolist()[0], ids_bpe_encoded=True)
        generated_score = tokenizer(generated_ts)

        # The target directory may not exist yet on a fresh checkout.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        generated_score.dump_midi(str(self.output_dir / f'step-{state.global_step}.mid'))

In [None]:
# Training configuration: one epoch, bf16 mixed precision, torch.compile,
# cosine learning-rate schedule with 2000 warmup steps, metrics logged to
# Weights & Biases every 100 steps, a checkpoint every 1000 steps (at most
# five kept on disk).
#
# `deepspeed` is deliberately not set: the argument expects the path to a
# DeepSpeed JSON config file or an already-loaded config dict, not a boolean,
# so `deepspeed=True` cannot be turned into a valid DeepSpeed configuration.
# To enable DeepSpeed, pass e.g. deepspeed="ds_config.json".
training_args = TrainingArguments(
    output_dir="../logs/tmp/output",
    per_device_train_batch_size=4,
    report_to="wandb",
    bf16=True,
    dataloader_num_workers=24,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    lr_scheduler_type="cosine",
    warmup_steps=2000,
    save_steps=1000,
    save_total_limit=5,
    num_train_epochs=1,
    label_smoothing_factor=0.2,
    torch_compile=True,
)

In [11]:
# Wire the model, training arguments, dataset and collator into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
    tokenizer=tokenizer,
    # Periodic MIDI sampling during training is currently disabled:
    # callbacks=[MidiGenerationCallback()]
)

In [12]:
# Weights & Biases settings, read from the environment by the wandb integration.
# Runs are logged under the "midi" project.
os.environ["WANDB_PROJECT"] = "midi"
# presumably makes the integration upload saved checkpoints to W&B — TODO confirm
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

fp32: 2.2 it/s
bf16: 3.3 it/s
bf16 + compile: 3.4 it/s

In [13]:
# Start training with the Trainer configured above (long-running).
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mweights-and-biases[0m ([33mklima7-team[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 100/974907 [01:40<119:01:07,  2.28it/s]

{'loss': 10.4085, 'learning_rate': 2.5e-06, 'epoch': 0.0}


  0%|          | 200/974907 [02:26<122:25:11,  2.21it/s]

{'loss': 10.2045, 'learning_rate': 5e-06, 'epoch': 0.0}


  0%|          | 218/974907 [02:34<124:51:32,  2.17it/s]

KeyboardInterrupt: 

In [None]:
# Save the model weights and config to ../logs/model2.
# NOTE(review): `tokenizer=` does not look like a documented argument of the
# model's save_pretrained and is probably ignored — confirm; the tokenizer is
# saved separately with tokenizer.save_pretrained.
model.save_pretrained('../logs/model2', tokenizer=tokenizer)

In [16]:
# Save the tokenizer to ../logs/tokenizer2, the directory the loading cell reads from.
tokenizer.save_pretrained('../logs/tokenizer2')

In [None]:
# Close the active Weights & Biases run.
wandb.finish()