In [1]:
import os
import random

import torch
from torch.utils.data import DataLoader
from miditok import REMI, TokenizerConfig, MIDITokenizer, TokSequence
from miditok.pytorch_data import DatasetMIDI, DatasetJSON, DataCollator, split_midis_for_training
from miditok.data_augmentation import augment_midi_dataset
from miditok.utils import get_midi_programs
from pathlib import Path
from symusic import Score
import wandb
from tqdm import tqdm

from transformers.models.opt.modeling_opt import OPTForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


## Tokenizer

In [2]:
midi_paths = list(Path("/home/lklimkiewicz/priv/midi/data").glob("**/*.mid"))

In [3]:
tokenizer_config = TokenizerConfig(num_velocities=16, use_chords=True, use_programs=True)
tokenizer = REMI(tokenizer_config)

In [4]:
random.shuffle(midi_paths)
tokenizer.learn_bpe(vocab_size=30000, files_paths=midi_paths[:1000])






In [5]:
# saving
tokenizer.save_pretrained('../logs/tokenizer2')

In [3]:
# loading
tokenizer = MIDITokenizer.from_pretrained('../logs/tokenizer2')

config.json not found in /home/lklimkiewicz/priv/midi/logs/tokenizer2


## Model

In [4]:
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

In [6]:
config = AutoConfig.from_pretrained(
    'facebook/opt-125m',
    bos_token_id=tokenizer['BOS_None'],
    eos_token_id=tokenizer['EOS_None'],
    pad_token_id=tokenizer['PAD_None'],
    vocab_size=len(tokenizer),
    prefix=None,
    max_length=1024,
    do_sample=True,
)

In [7]:
model = AutoModelForCausalLM.from_config(config)

## Split dataset

In [6]:
midi_paths = list(Path("/home/lklimkiewicz/priv/midi/data").glob("**/*.mid"))

In [8]:
def filter_dataset(paths, val_fun):
    correct = []
    for path in tqdm(paths):
        try:
            midi = Score(path)
            if val_fun(midi):
                correct.append(path)
        except:
            os.remove(path)
    return correct

In [9]:
print('Initial count:', len(midi_paths))

def midi_valid(midi) -> bool:
    if midi.note_num() < 50 or len(midi.time_signatures) == 0 or len(midi.tempos) == 0:
        return False
    
    for time_sig in midi.time_signatures:
        if time_sig.denominator == 0 or time_sig.numerator == 0:
            return False
    
    return True

midi_paths = filter_dataset(midi_paths, midi_valid)

print('Filtered count:', len(midi_paths))

Initial count: 1067913


 14%|█▍        | 151477/1067913 [01:55<10:38, 1434.76it/s] Division type 1 have no tpq.
 15%|█▌        | 162690/1067913 [02:03<10:08, 1488.41it/s]Division type 1 have no tpq.
 38%|███▊      | 403524/1067913 [04:56<07:09, 1548.06it/s]Division type 1 have no tpq.
 39%|███▉      | 420722/1067913 [05:07<06:53, 1566.63it/s]Division type 1 have no tpq.
 41%|████      | 439315/1067913 [05:18<06:28, 1619.31it/s]Division type 1 have no tpq.
 52%|█████▏    | 558252/1067913 [06:51<13:24, 633.39it/s]  Division type 1 have no tpq.
 53%|█████▎    | 569424/1067913 [07:02<05:39, 1467.48it/s]Division type 1 have no tpq.
100%|██████████| 1067913/1067913 [13:06<00:00, 1357.29it/s]


Filtered count: 766717


In [None]:
split_midis_for_training(
    files_paths=midi_paths,
    tokenizer=tokenizer,
    save_dir=Path('./chunks_for_training'),
    max_seq_len=1024,
)

## Tokenize dataset

In [4]:
midi_paths = list(Path("/home/lklimkiewicz/priv/midi/src/chunks_for_training/khinsider").glob("**/*.mid"))
tokenizer.tokenize_midi_dataset(midi_paths, out_dir="tokenized_dataset/khinsider", save_programs=True)

Tokenizing MIDIs (tokenized_dataset/khinsider):  44%|████▍     | 42160/95341 [09:03<15:17, 57.94it/s] 

.

Tokenizing MIDIs (tokenized_dataset/khinsider):  44%|████▍     | 42198/95341 [09:04<16:28, 53.78it/s]

.

Tokenizing MIDIs (tokenized_dataset/khinsider): 100%|██████████| 95341/95341 [21:09<00:00, 75.09it/s] 


## Augment dataset

In [7]:
augment_midi_dataset(
    Path('/home/lklimkiewicz/priv/midi/src/tokenized_dataset'),
    pitch_offsets=[-12, 12],
    velocity_offsets=[-4, 5],
    duration_offsets=[-0.5, 1],
    out_path="./augmented_dataset",
)

Performing data augmentation: 0it [00:00, ?it/s]


## Load dataset

In [8]:
midi_paths = list(Path("/home/lklimkiewicz/priv/midi/src/chunks_for_training_2").glob("**/*.mid"))

dataset = DatasetMIDI(
    files_paths=midi_paths,
    tokenizer=tokenizer,
    max_seq_len=1024,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)

# json_paths = list(Path("/home/lklimkiewicz/priv/midi/src/tokenized_dataset").glob("**/*.json"))

# dataset = DatasetJSON(
#     files_paths=json_paths,
#     max_seq_len=1024,
#     bos_token_id=tokenizer["BOS_None"],
#     eos_token_id=tokenizer["EOS_None"],
# )

collator = DataCollator(
    tokenizer["PAD_None"],
    copy_inputs_as_labels=True,
    shift_labels=True,
)

## Train

In [9]:
from transformers import Trainer, TrainingArguments, TrainerCallback, TrainerState, TrainerControl

In [10]:
class MidiGenerationCallback(TrainerCallback):
    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.global_step % 500 == 0:
            
            input = torch.tensor([[tokenizer['BOS_None']]], device=model.device)
            midi = model.generate(input, max_new_tokens=1024)
            generated_ts = TokSequence(ids=midi.tolist()[0], ids_bpe_encoded=True)
            generated_score = tokenizer(generated_ts)
            generated_score.dump_midi(f'outputs/v1/step-{state.global_step}.mid')
            
            input = torch.tensor([[4]], device=model.device)
            midi = model.generate(input, max_new_tokens=1024)
            generated_ts = TokSequence(ids=midi.tolist()[0], ids_bpe_encoded=True)
            generated_score = tokenizer(generated_ts)
            generated_score.dump_midi(f'outputs/v2/step-{state.global_step}.mid')

In [11]:
training_args = TrainingArguments(
    output_dir="test_trainer",
    per_device_train_batch_size=4,
    report_to="wandb",
    bf16=True,
    dataloader_num_workers=24,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    lr_scheduler_type="cosine",
    warmup_steps=2000,
    save_steps=1000,
    save_total_limit=5,
    num_train_epochs=1,
    label_smoothing_factor=0.2,
    torch_compile=True,
)

In [12]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
    tokenizer=tokenizer,
    callbacks=[MidiGenerationCallback()]
)

In [13]:
os.environ["WANDB_PROJECT"] = "midi"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

In [15]:
trainer.train(resume_from_checkpoint=True)

There were missing keys in the checkpoint model loaded: ['lm_head.weight'].
  2%|▏         | 21540/974907 [1:06:34<49:06:46,  5.39it/s]

[A                                                    

{'loss': 6.3445, 'learning_rate': 4.9952466897011097e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.3532, 'learning_rate': 4.9951968025086607e-05, 'epoch': 0.02}



[A                                                     

{'loss': 6.3451, 'learning_rate': 4.995146655143412e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.4201, 'learning_rate': 4.995096247610592e-05, 'epoch': 0.02}



[A                                                        

{'loss': 6.3569, 'learning_rate': 4.995045579915457e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.381, 'learning_rate': 4.994994652063291e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.2546, 'learning_rate': 4.9949434640594016e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.3057, 'learning_rate': 4.9948920159091294e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.325, 'learning_rate': 4.994840307617836e-05, 'epoch': 0.02}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 6.3965, 'learning_rate': 4.994788339190916e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.386, 'learning_rate': 4.994736110633785e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.3309, 'learning_rate': 4.994683621951891e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.3059, 'learning_rate': 4.994630873150706e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.2829, 'learning_rate': 4.994577864235731e-05, 'epoch': 0.02}



[A                                                        

{'loss': 6.3038, 'learning_rate': 4.994524595212491e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.275, 'learning_rate': 4.994471066086543e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.3376, 'learning_rate': 4.9944172768634666e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.3322, 'learning_rate': 4.994363227548872e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.2695, 'learning_rate': 4.994308918148394e-05, 'epoch': 0.02}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 6.2534, 'learning_rate': 4.994254348667694e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.3325, 'learning_rate': 4.9941995191124647e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.2533, 'learning_rate': 4.994144429488421e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.2796, 'learning_rate': 4.9940890798013085e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.2784, 'learning_rate': 4.994033470056898e-05, 'epoch': 0.02}



[A                                                        

{'loss': 6.2712, 'learning_rate': 4.9939776002609865e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.2678, 'learning_rate': 4.993921470419402e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.216, 'learning_rate': 4.9938650805379946e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.1999, 'learning_rate': 4.9938084306226465e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.2412, 'learning_rate': 4.993751520679263e-05, 'epoch': 0.02}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 6.2864, 'learning_rate': 4.993694350713778e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.2021, 'learning_rate': 4.993636920732153e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.2183, 'learning_rate': 4.993579230740377e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.2403, 'learning_rate': 4.9935212807444635e-05, 'epoch': 0.02}



[A                                                      

{'loss': 6.1698, 'learning_rate': 4.993463070750457e-05, 'epoch': 0.03}



[A                                                        

{'loss': 6.2505, 'learning_rate': 4.993404600764425e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.2077, 'learning_rate': 4.993345870792465e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.2387, 'learning_rate': 4.9932868808407015e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1824, 'learning_rate': 4.993227630915284e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1831, 'learning_rate': 4.993168121022392e-05, 'epoch': 0.03}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 6.1938, 'learning_rate': 4.9931083511682284e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1588, 'learning_rate': 4.993048321359028e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1931, 'learning_rate': 4.992988031601048e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.2253, 'learning_rate': 4.992927481900576e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1416, 'learning_rate': 4.992866672263924e-05, 'epoch': 0.03}



[A                                                        

{'loss': 6.1867, 'learning_rate': 4.992805602697435e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1757, 'learning_rate': 4.992744273207475e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1713, 'learning_rate': 4.9926826838004385e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1829, 'learning_rate': 4.992620834482748e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1794, 'learning_rate': 4.992558725260852e-05, 'epoch': 0.03}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 6.1448, 'learning_rate': 4.992496356141227e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1141, 'learning_rate': 4.992433727130377e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1606, 'learning_rate': 4.9923708382348314e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.134, 'learning_rate': 4.992307689461148e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1955, 'learning_rate': 4.992244280815911e-05, 'epoch': 0.03}



[A                                                        

{'loss': 6.1302, 'learning_rate': 4.9921806123057316e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0804, 'learning_rate': 4.992116683937249e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0962, 'learning_rate': 4.992052495717129e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0462, 'learning_rate': 4.991988047652064e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1287, 'learning_rate': 4.991923339748774e-05, 'epoch': 0.03}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 6.1091, 'learning_rate': 4.9918583720140067e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0524, 'learning_rate': 4.9917931444545364e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.099, 'learning_rate': 4.991727657077163e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.083, 'learning_rate': 4.991661909888716e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0561, 'learning_rate': 4.9915959028960506e-05, 'epoch': 0.03}



[A                                                        

{'loss': 6.1227, 'learning_rate': 4.9915296361060495e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.1264, 'learning_rate': 4.991463109525622e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0485, 'learning_rate': 4.9913963231617035e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0951, 'learning_rate': 4.99132927702126e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0849, 'learning_rate': 4.9912619711112815e-05, 'epoch': 0.03}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 6.0404, 'learning_rate': 4.991194405438786e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0425, 'learning_rate': 4.9911265800108185e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0573, 'learning_rate': 4.991058494834451e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.028, 'learning_rate': 4.990990149916782e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0573, 'learning_rate': 4.9909215452649393e-05, 'epoch': 0.03}



[A                                                        

{'loss': 6.0384, 'learning_rate': 4.990852680886075e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.059, 'learning_rate': 4.990783556787371e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0454, 'learning_rate': 4.990714172976033e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0047, 'learning_rate': 4.990644529459297e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0883, 'learning_rate': 4.990574626244424e-05, 'epoch': 0.03}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 6.0284, 'learning_rate': 4.990504463338703e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0818, 'learning_rate': 4.99043404074945e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.008, 'learning_rate': 4.990363358484007e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9438, 'learning_rate': 4.9902924165497456e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9483, 'learning_rate': 4.990221214954061e-05, 'epoch': 0.03}



[A                                                        

{'loss': 6.0103, 'learning_rate': 4.990149753704379e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0651, 'learning_rate': 4.99007803280815e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.0033, 'learning_rate': 4.990006052272853e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9648, 'learning_rate': 4.989933812105992e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9618, 'learning_rate': 4.9898613123151014e-05, 'epoch': 0.03}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 5.9765, 'learning_rate': 4.989788552907739e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9676, 'learning_rate': 4.989715533891492e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9761, 'learning_rate': 4.989642255273974e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9665, 'learning_rate': 4.989568717062826e-05, 'epoch': 0.03}



[A                                                      

{'loss': 6.007, 'learning_rate': 4.989494919265715e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8786, 'learning_rate': 4.989420861890337e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.903, 'learning_rate': 4.9893465449444135e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9174, 'learning_rate': 4.9892719684356925e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9053, 'learning_rate': 4.9891971323719516e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9588, 'learning_rate': 4.989122036760993e-05, 'epoch': 0.03}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 5.9694, 'learning_rate': 4.989046681610647e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.8963, 'learning_rate': 4.988971066928771e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9134, 'learning_rate': 4.988895192723249e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9024, 'learning_rate': 4.988819059001993e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.856, 'learning_rate': 4.988742665772941e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8871, 'learning_rate': 4.9886660130440575e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.943, 'learning_rate': 4.988589100823337e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9122, 'learning_rate': 4.988511929118798e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9667, 'learning_rate': 4.988434497938487e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.9073, 'learning_rate': 4.988356807290477e-05, 'epoch': 0.03}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 5.8452, 'learning_rate': 4.9882788571828714e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.8586, 'learning_rate': 4.988200647623795e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.8756, 'learning_rate': 4.988122178621405e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.8898, 'learning_rate': 4.988043450183882e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.8614, 'learning_rate': 4.987964462319436e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8887, 'learning_rate': 4.987885215036301e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.8868, 'learning_rate': 4.987805708342742e-05, 'epoch': 0.03}



[A                                                      

{'loss': 5.8665, 'learning_rate': 4.987725942247049e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8592, 'learning_rate': 4.987645916757538e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8599, 'learning_rate': 4.987565631882554e-05, 'epoch': 0.03}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 5.8416, 'learning_rate': 4.987485087630469e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8461, 'learning_rate': 4.987404284009679e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8974, 'learning_rate': 4.987323221028612e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8459, 'learning_rate': 4.987241898695718e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.849, 'learning_rate': 4.9871603170194784e-05, 'epoch': 0.03}



[A                                                          

{'loss': 5.7995, 'learning_rate': 4.9870784760083985e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8475, 'learning_rate': 4.986996375671013e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8236, 'learning_rate': 4.986914016015881e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8089, 'learning_rate': 4.986831397051591e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8159, 'learning_rate': 4.986748518786758e-05, 'epoch': 0.03}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 5.7959, 'learning_rate': 4.9866653812300225e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8007, 'learning_rate': 4.986581984390054e-05, 'epoch': 0.03}



[A                                                        

{'loss': 5.8428, 'learning_rate': 4.986498328275547e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7582, 'learning_rate': 4.9864144128952265e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7856, 'learning_rate': 4.9863302382578415e-05, 'epoch': 0.04}



[A                                                          

{'loss': 5.8355, 'learning_rate': 4.986245804372167e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7896, 'learning_rate': 4.98616111124701e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7247, 'learning_rate': 4.9860761588911985e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.8014, 'learning_rate': 4.985990947313592e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.8288, 'learning_rate': 4.985905476523075e-05, 'epoch': 0.04}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 5.8628, 'learning_rate': 4.9858197465285594e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.8373, 'learning_rate': 4.985733757338985e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7812, 'learning_rate': 4.985647508963317e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7676, 'learning_rate': 4.985561001410549e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.729, 'learning_rate': 4.985474234689701e-05, 'epoch': 0.04}



[A                                                          

{'loss': 5.7344, 'learning_rate': 4.98538720880982e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7182, 'learning_rate': 4.98529992377998e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7396, 'learning_rate': 4.985212379609282e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7297, 'learning_rate': 4.985124576306855e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.725, 'learning_rate': 4.985036513881853e-05, 'epoch': 0.04}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 5.7562, 'learning_rate': 4.984948192343459e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.727, 'learning_rate': 4.9848596117008825e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7186, 'learning_rate': 4.9847707719633594e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7488, 'learning_rate': 4.984681673140153e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7212, 'learning_rate': 4.984592315240554e-05, 'epoch': 0.04}



[A                                                          

{'loss': 5.7179, 'learning_rate': 4.984502698273878e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7244, 'learning_rate': 4.9844128222494714e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7047, 'learning_rate': 4.984322687176704e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7009, 'learning_rate': 4.984232293064975e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7641, 'learning_rate': 4.9841416399237105e-05, 'epoch': 0.04}



Non-default generation parameters: {'max_length': 1024, 'do_sample': True}


{'loss': 5.7288, 'learning_rate': 4.984050727762361e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.6733, 'learning_rate': 4.983959556590408e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.7043, 'learning_rate': 4.9838681264173546e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.6791, 'learning_rate': 4.9837764372527374e-05, 'epoch': 0.04}



[A                                                        

{'loss': 5.743, 'learning_rate': 4.9836844891061155e-05, 'epoch': 0.04}




IndexError: Caught IndexError in DataLoader worker process 4.
Original Traceback (most recent call last):
  File "/home/lklimkiewicz/miniconda3/envs/midi2/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/lklimkiewicz/miniconda3/envs/midi2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/lklimkiewicz/miniconda3/envs/midi2/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/lklimkiewicz/miniconda3/envs/midi2/lib/python3.10/site-packages/miditok/pytorch_data/datasets.py", line 187, in __getitem__
    tokseq = self._tokenize_midi(midi)
  File "/home/lklimkiewicz/miniconda3/envs/midi2/lib/python3.10/site-packages/miditok/pytorch_data/datasets.py", line 227, in _tokenize_midi
    tokseq.ids = self._preprocess_token_ids(
  File "/home/lklimkiewicz/miniconda3/envs/midi2/lib/python3.10/site-packages/miditok/pytorch_data/datasets.py", line 53, in _preprocess_token_ids
    if isinstance(token_ids[0], list):
IndexError: list index out of range


In [10]:
model.save_pretrained('./new_model', tokenizer=tokenizer)

NameError: name 'model' is not defined

In [16]:
tokenizer.save_pretrained('./tokenizer')

In [None]:
wandb.finish()