In [1]:
import os
import random

import torch
from torch.utils.data import DataLoader
from miditok import REMI, TokenizerConfig, MIDITokenizer, TokSequence
from miditok.pytorch_data import DatasetMIDI, DatasetJSON, DataCollator, split_midis_for_training
from miditok.data_augmentation import augment_midi_dataset
from miditok.utils import get_midi_programs
from pathlib import Path
from symusic import Score
import wandb
from tqdm import tqdm

from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


## Tokenizer

In [3]:
midi_paths = list(Path("/home/lklimkiewicz/priv/midi/data/vgmusic").glob("**/*.mid")) + \
    list(Path("/home/lklimkiewicz/priv/midi/data/khinsider").glob("**/*.mid"))

In [3]:
tokenizer_config = TokenizerConfig(num_velocities=16, use_chords=True, use_programs=True)
tokenizer = REMI(tokenizer_config)

In [None]:
random.shuffle(midi_paths)
tokenizer.learn_bpe(vocab_size=30000, files_paths=midi_paths[:10000])

In [6]:
# saving
tokenizer.save_pretrained('../logs/tokenizer')

In [6]:
# loading
tokenizer = MIDITokenizer.from_pretrained('../logs/tokenizer')

config.json not found in /home/lklimkiewicz/priv/midi/logs/tokenizer


In [None]:
tokenizer = AutoTokenizer.from_pretrained('../logs/tokenizer')

In [11]:
tokenizer.push_to_hub('midi-ganerator-game')

CommitInfo(commit_url='https://huggingface.co/lklimkiewicz/midi-ganerator-game/commit/777c949a787add8470f4bc1f8a922a40a8fdf47e', commit_message='Push model using huggingface_hub.', commit_description='', oid='777c949a787add8470f4bc1f8a922a40a8fdf47e', pr_url=None, pr_revision=None, pr_num=None)

In [9]:
MIDITokenizer.from_pretrained('lklimkiewicz/midi_tokenizer')

30000 tokens with ('T',) io format(one token stream), with BPE

## Model

In [6]:
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

In [7]:
config = AutoConfig.from_pretrained(
    'facebook/opt-125m',
    bos_token_id=tokenizer['BOS_None'],
    eos_token_id=tokenizer['EOS_None'],
    pad_token_id=tokenizer['PAD_None'],
    vocab_size=len(tokenizer),
    prefix=None,
    max_length=1024,
    do_sample=True,
)

In [9]:
model = AutoModelForCausalLM.from_config(config)

## Split dataset

In [None]:
midi_paths = list(Path("/home/lklimkiewicz/priv/midi/data/vgmusic").glob("**/*.mid")) + \
    list(Path("/home/lklimkiewicz/priv/midi/data/khinsider").glob("**/*.mid"))

In [None]:
def filter_dataset(paths, val_fun):
    correct = []
    for path in tqdm(paths):
        try:
            midi = Score(path)
            if val_fun(midi):
                correct.append(path)
        except:
            pass
    return correct

In [None]:
print('Initial count:', len(midi_paths))

def midi_valid(midi) -> bool:
    if midi.note_num() < 50 or len(midi.time_signatures) == 0 or len(midi.tempos) == 0:
        return False
    
    for time_sig in midi.time_signatures:
        if time_sig.denominator == 0 or time_sig.numerator == 0:
            return False
    
    return True

midi_paths = filter_dataset(midi_paths, midi_valid)

print('Filtered count:', len(midi_paths))

Initial count: 40447


100%|██████████| 40447/40447 [00:10<00:00, 3773.71it/s]

Filtered count: 40442





In [None]:
split_midis_for_training(
    files_paths=midi_paths,
    tokenizer=tokenizer,
    save_dir=Path('./chunks_for_training'),
    max_seq_len=1024,
)

## Tokenize dataset

In [4]:
midi_paths = list(Path("/home/lklimkiewicz/priv/midi/src/chunks_for_training/khinsider").glob("**/*.mid"))
tokenizer.tokenize_midi_dataset(midi_paths, out_dir="tokenized_dataset/khinsider", save_programs=True)

Tokenizing MIDIs (tokenized_dataset/khinsider):  44%|████▍     | 42160/95341 [09:03<15:17, 57.94it/s] 

.

Tokenizing MIDIs (tokenized_dataset/khinsider):  44%|████▍     | 42198/95341 [09:04<16:28, 53.78it/s]

.

Tokenizing MIDIs (tokenized_dataset/khinsider): 100%|██████████| 95341/95341 [21:09<00:00, 75.09it/s] 


## Augment dataset

In [7]:
augment_midi_dataset(
    Path('/home/lklimkiewicz/priv/midi/src/tokenized_dataset'),
    pitch_offsets=[-12, 12],
    velocity_offsets=[-4, 5],
    duration_offsets=[-0.5, 1],
    out_path="./augmented_dataset",
)

Performing data augmentation: 0it [00:00, ?it/s]


## Load dataset

In [8]:
json_paths = list(Path("/home/lklimkiewicz/priv/midi/src/tokenized_dataset").glob("**/*.json"))

dataset = DatasetJSON(
    files_paths=json_paths,
    max_seq_len=1024,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)

collator = DataCollator(
    tokenizer["PAD_None"],
    copy_inputs_as_labels=True,
    shift_labels=True,
)

## Train

In [11]:
from transformers import Trainer, TrainingArguments, TrainerCallback, TrainerState, TrainerControl

In [12]:
class MidiGenerationCallback(TrainerCallback):
    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.global_step % 500 == 0:
            
            input = torch.tensor([[tokenizer['BOS_None']]], device=model.device)
            midi = model.generate(input, max_new_tokens=1024)
            generated_ts = TokSequence(ids=midi.tolist()[0], ids_bpe_encoded=True)
            generated_score = tokenizer(generated_ts)
            generated_score.dump_midi(f'outputs/v1/step-{state.global_step}.mid')
            
            input = torch.tensor([[4]], device=model.device)
            midi = model.generate(input, max_new_tokens=1024)
            generated_ts = TokSequence(ids=midi.tolist()[0], ids_bpe_encoded=True)
            generated_score = tokenizer(generated_ts)
            generated_score.dump_midi(f'outputs/v2/step-{state.global_step}.mid')

In [13]:
training_args = TrainingArguments(
    output_dir="test_trainer",
    per_device_train_batch_size=4,
    report_to="wandb",
    bf16=True,
    dataloader_num_workers=16,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    lr_scheduler_type="cosine",
    warmup_steps=600,
    save_steps=1000,
    save_total_limit=5,
    num_train_epochs=2,
    label_smoothing_factor=0.2,
    torch_compile=True,
)

In [14]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
    tokenizer=tokenizer,
    callbacks=[MidiGenerationCallback()]
)

In [15]:
os.environ["WANDB_PROJECT"] = "midi"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

In [None]:
trainer.train()

In [None]:
model.save_pretrained('../logs/model', tokenizer=tokenizer)

In [18]:
tokenizer.save_pretrained('./logs/tokenizer')

In [None]:
wandb.finish()