### imports, set up

In [5]:
from datasets import load_dataset
import os
from miditoolkit import MidiFile
from miditok import REMI
import torch
from transformers import (
    AutoModelForCausalLM,
    Trainer,
    default_data_collator,
    set_seed,
)
from itertools import chain

In [21]:
# Notebook-wide settings: length (in tokens) of each training chunk, and the
# cap on how many training rows are kept for quick test runs.
block_size = 1024
max_train_samples = 10_000

# This per-example signature does not support batched processing.
def midifn2tokens(x, data_dir, tokenizer, return_dict=True):
    """Tokenize the MIDI file referenced by a single dataset row.

    Args:
        x: Mapping for one example. Its 'midi_filename' entry is joined onto
            ``data_dir`` to locate the MIDI file.
        data_dir: Directory prefix for 'midi_filename'.
        tokenizer: Object exposing ``midi_to_tokens(midi_file)``; only the
            first element of what it returns is kept.
        return_dict: If true, return the row together with its tokens;
            otherwise return only the tokens.

    Returns:
        When ``return_dict`` is true, a new dict holding every entry of ``x``
        with 'midi_filename' replaced by the joined path and 'input_ids' set
        to the tokens. Otherwise just the tokens.

    The input mapping is never modified, so calling this twice on the same
    row cannot prefix ``data_dir`` twice.
    """
    midi_path = os.path.join(data_dir, x['midi_filename'])
    tokens = tokenizer.midi_to_tokens(MidiFile(midi_path))[0]
    if not return_dict:
        return tokens
    # Work on a shallow copy so the caller's row stays untouched.
    example = dict(x)
    example['midi_filename'] = midi_path
    example['input_ids'] = tokens
    return example

# Dataset root. Set the EGMD_DATA_DIR environment variable to point at another
# copy of the dataset; the original local path is kept as the fallback.
DATA_DIR = os.environ.get("EGMD_DATA_DIR", "/Users/juanigp/Desktop/data/e-gmd-v1.0.0")
# Path of the metadata CSV inside the dataset root.
csv_dir = os.path.join(DATA_DIR, "e-gmd-v1.0.0.csv")
dataset = load_dataset('csv', data_files=csv_dir)
# A fixed seed keeps the train/test partition identical across runs.
dataset = dataset['train'].train_test_split(test_size=0.2, seed=42)
train_dataset = dataset['train']
# For quick testing: keep at most max_train_samples rows, clamped to the
# number of rows actually available so select() cannot index out of range.
train_dataset = train_dataset.select(range(min(max_train_samples, len(train_dataset))))
tokenizer = REMI()
column_names = train_dataset.column_names


def map_func(x):
    """Tokenize one dataset row with the notebook-level DATA_DIR and tokenizer."""
    return midifn2tokens(x, DATA_DIR, tokenizer, True)


# midifn2tokens handles one example per call, so the map is not batched.
# Removing the original columns makes grouping into chunks easier later on.
train_dataset = train_dataset.map(
    map_func,
    num_proc=8,
    remove_columns=column_names,
)

Using custom data configuration default-51ad44df56849337
Reusing dataset csv (/Users/juanigp/.cache/huggingface/datasets/csv/default-51ad44df56849337/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)
100%|████████████████████████████████████████████| 1/1 [00:00<00:00, 202.35it/s]
#0:   0%|                                              | 0/1250 [00:00<?, ?ex/s]



#3:   0%|                                              | 0/1250 [00:00<?, ?ex/s][A[A[A[A

#2:   0%|                                              | 0/1250 [00:00<?, ?ex/s][A[A



#4:   0%|                                              | 0/1250 [00:00<?, ?ex/s][A[A[A[A




#5:   0%|                                              | 0/1250 [00:00<?, ?ex/s][A[A[A[A[A





#6:   0%|                                              | 0/1250 [00:00<?, ?ex/s][A[A[A[A[A[A






#7:   0%|                                              | 0/1250 [00:00<?, ?ex/s][A[A[A[A[A[A[A
#0:   1%|▎              

#6:   7%|██▋                                  | 89/1250 [00:02<00:40, 28.90ex/s][A[A[A[A[A[A

#2:   7%|██▌                                  | 86/1250 [00:02<00:30, 37.58ex/s][A[A






#7:   6%|██▏                                  | 74/1250 [00:02<00:35, 33.40ex/s][A[A[A[A[A[A[A


#3:   5%|█▋                                   | 57/1250 [00:02<00:45, 26.47ex/s][A[A[A



#4:   9%|███                                 | 108/1250 [00:02<00:24, 45.95ex/s][A[A[A[A






#7:   6%|██▎                                  | 80/1250 [00:02<00:30, 38.52ex/s][A[A[A[A[A[A[A




#5:   8%|██▊                                  | 94/1250 [00:02<00:37, 30.45ex/s][A[A[A[A[A
#1:  10%|███▌                                | 123/1250 [00:02<00:20, 54.07ex/s][A

#2:   8%|██▊                                  | 95/1250 [00:02<00:25, 44.50ex/s][A[A



#4:   9%|███▎                                | 114/1250 [00:02<00:23, 48.48ex/s][A[A[A[A






#7:   7%|██▌                       

#5:  14%|████▉                               | 173/1250 [00:05<00:25, 42.62ex/s][A[A[A[A[A


#3:  10%|███▌                                | 124/1250 [00:05<00:34, 33.11ex/s][A[A[A
#1:  18%|██████▍                             | 224/1250 [00:05<00:31, 32.29ex/s][A

#2:  16%|█████▋                              | 198/1250 [00:05<00:30, 34.48ex/s][A[A




#5:  14%|█████▏                              | 180/1250 [00:05<00:22, 46.72ex/s][A[A[A[A[A


#3:  11%|███▉                                | 137/1250 [00:05<00:22, 50.41ex/s][A[A[A






#7:  13%|████▊                               | 165/1250 [00:05<00:46, 23.12ex/s][A[A[A[A[A[A[A
#1:  18%|██████▋                             | 231/1250 [00:05<00:27, 36.47ex/s][A




#5:  15%|█████▌                              | 191/1250 [00:05<00:18, 57.95ex/s][A[A[A[A[A


#3:  12%|████▏                               | 144/1250 [00:05<00:21, 50.60ex/s][A[A[A






#7:  14%|█████                               | 174/1250 [0

#7:  25%|████████▉                           | 309/1250 [00:07<00:13, 69.90ex/s][A[A[A[A[A[A[A
#1:  24%|████████▋                           | 302/1250 [00:07<00:23, 40.49ex/s][A

#2:  24%|████████▋                           | 301/1250 [00:07<00:24, 38.35ex/s][A[A


#3:  20%|███████▎                            | 253/1250 [00:07<00:30, 32.33ex/s][A[A[A




#5:  24%|████████▊                           | 306/1250 [00:07<00:19, 47.39ex/s][A[A[A[A[A





#6:  20%|███████▎                            | 254/1250 [00:07<00:57, 17.46ex/s][A[A[A[A[A[A



#4:  21%|███████▍                            | 260/1250 [00:07<00:27, 35.64ex/s][A[A[A[A
#0:  25%|█████████                           | 314/1250 [00:08<00:30, 30.67ex/s][A




#5:  25%|█████████▏                          | 317/1250 [00:08<00:16, 57.21ex/s][A[A[A[A[A


#3:  21%|███████▋                            | 265/1250 [00:08<00:22, 42.99ex/s][A[A[A





#6:  21%|███████▍                            | 257/12

#7:  30%|██████████▉                         | 378/1250 [00:10<00:24, 35.32ex/s][A[A[A[A[A[A[A



#4:  28%|██████████▏                         | 355/1250 [00:10<00:24, 35.82ex/s][A[A[A[A
#0:  32%|███████████▍                        | 397/1250 [00:10<00:43, 19.48ex/s][A




#5:  30%|██████████▉                         | 379/1250 [00:10<00:30, 28.98ex/s][A[A[A[A[A

#2:  32%|███████████▌                        | 402/1250 [00:10<00:17, 48.64ex/s][A[A

#2:  33%|███████████▊                        | 409/1250 [00:10<00:17, 49.26ex/s][A[A
#0:  33%|███████████▊                        | 412/1250 [00:10<00:25, 32.66ex/s][A





#6:  26%|█████████▌                          | 330/1250 [00:10<00:23, 39.89ex/s][A[A[A[A[A[A






#7:  31%|███████████                         | 383/1250 [00:10<00:32, 26.79ex/s][A[A[A[A[A[A[A




#5:  31%|███████████                         | 384/1250 [00:10<00:32, 26.35ex/s][A[A[A[A[A

#2:  33%|████████████                        

#0:  41%|██████████████▉                     | 518/1250 [00:13<00:16, 45.62ex/s][A[A[A[A[A

#2:  38%|█████████████▋                      | 475/1250 [00:13<00:25, 30.02ex/s][A[A




#5:  36%|████████████▉                       | 451/1250 [00:13<00:31, 25.11ex/s][A[A[A[A[A



#0:  42%|███████████████▏                    | 526/1250 [00:13<00:14, 48.70ex/s][A[A[A[A


#3:  44%|███████████████▉                    | 552/1250 [00:13<00:09, 70.99ex/s][A[A[A






#7:  39%|██████████████▏                     | 493/1250 [00:13<00:18, 41.81ex/s][A[A[A[A[A[A[A





#6:  33%|███████████▋                        | 407/1250 [00:13<00:34, 24.58ex/s][A[A[A[A[A[A




#5:  37%|█████████████▏                      | 460/1250 [00:13<00:21, 36.10ex/s][A[A[A[A[A

#0:  43%|███████████████▎                    | 533/1250 [00:13<00:15, 46.88ex/s][A[A





#6:  33%|███████████▊                        | 411/1250 [00:13<00:33, 25.00ex/s][A[A[A[A[A[A




#5:  37%|███████████

#0:  52%|██████████████████▋                 | 648/1250 [00:15<00:15, 37.90ex/s][A

#2:  46%|████████████████▌                   | 575/1250 [00:15<00:19, 34.46ex/s][A[A



#4:  42%|███████████████▏                    | 526/1250 [00:15<00:24, 30.11ex/s][A[A[A[A
#1:  46%|████████████████▋                   | 579/1250 [00:15<00:14, 47.60ex/s][A




#5:  43%|███████████████▍                    | 537/1250 [00:15<00:39, 18.05ex/s][A[A[A[A[A


#0:  52%|██████████████████▉                 | 656/1250 [00:16<00:13, 43.65ex/s][A[A[A



#4:  43%|███████████████▍                    | 537/1250 [00:15<00:16, 42.57ex/s][A[A[A[A






#7:  47%|████████████████▊                   | 583/1250 [00:15<00:15, 42.45ex/s][A[A[A[A[A[A[A

#2:  46%|████████████████▋                   | 580/1250 [00:16<00:21, 31.56ex/s][A[A


#0:  53%|███████████████████                 | 663/1250 [00:16<00:13, 45.01ex/s][A[A[A




#5:  44%|███████████████▋                    | 546/1250 [00:16<00:30

#2:  53%|███████████████████                 | 662/1250 [00:18<00:09, 60.15ex/s][A[A


#3:  59%|█████████████████████▎              | 740/1250 [00:18<00:15, 33.94ex/s][A[A[A





#0:  57%|████████████████████▌               | 715/1250 [00:18<00:16, 31.70ex/s][A[A[A[A[A[A




#5:  49%|█████████████████▋                  | 615/1250 [00:18<00:22, 28.47ex/s][A[A[A[A[A






#7:  54%|███████████████████▎                | 670/1250 [00:18<00:21, 27.04ex/s][A[A[A[A[A[A[A



#0:  58%|████████████████████▋               | 719/1250 [00:18<00:16, 32.71ex/s][A[A[A[A
#1:  52%|██████████████████▊                 | 653/1250 [00:18<00:32, 18.22ex/s][A





#6:  48%|█████████████████▍                  | 605/1250 [00:18<00:13, 47.99ex/s][A[A[A[A[A[A






#7:  54%|███████████████████▍                | 674/1250 [00:18<00:19, 28.99ex/s][A[A[A[A[A[A[A


#0:  58%|████████████████████▉               | 729/1250 [00:18<00:11, 44.59ex/s][A[A[A





#6:  49%|███████████

#5:  56%|████████████████████▎               | 705/1250 [00:20<00:09, 55.01ex/s][A[A[A[A[A



#4:  58%|████████████████████▋               | 720/1250 [00:20<00:20, 26.25ex/s][A[A[A[A

#2:  62%|██████████████████████▏             | 772/1250 [00:20<00:10, 44.28ex/s][A[A





#6:  58%|████████████████████▊               | 721/1250 [00:20<00:11, 45.13ex/s][A[A[A[A[A[A


#0:  64%|██████████████████████▉             | 796/1250 [00:20<00:12, 35.93ex/s][A[A[A



#4:  58%|████████████████████▊               | 724/1250 [00:20<00:19, 27.22ex/s][A[A[A[A





#6:  60%|█████████████████████▋              | 751/1250 [00:20<00:05, 88.70ex/s][A[A[A[A[A[A

#2:  62%|██████████████████████▍             | 781/1250 [00:20<00:09, 48.05ex/s][A[A




#5:  57%|████████████████████▌               | 714/1250 [00:21<00:10, 49.23ex/s][A[A[A[A[A


#0:  65%|███████████████████████▎            | 810/1250 [00:21<00:09, 48.13ex/s][A[A[A



#4:  58%|████████████████████▉           

#2:  67%|████████████████████████▎           | 843/1250 [00:23<00:08, 48.75ex/s][A[A





#6:  68%|████████████████████████▍           | 848/1250 [00:23<00:07, 54.35ex/s][A[A[A[A[A[A


#3:  76%|███████████████████████████▍        | 953/1250 [00:23<00:08, 36.42ex/s][A[A[A




#5:  65%|███████████████████████▍            | 812/1250 [00:23<00:10, 41.43ex/s][A[A[A[A[A






#0:  73%|██████████████████████████          | 907/1250 [00:23<00:14, 24.29ex/s][A[A[A[A[A[A[A



#4:  67%|████████████████████████▏           | 839/1250 [00:23<00:09, 44.64ex/s][A[A[A[A






#7:  67%|████████████████████████            | 836/1250 [00:23<00:13, 29.95ex/s][A[A[A[A[A[A[A
#0:  73%|██████████████████████████▎         | 912/1250 [00:23<00:13, 24.77ex/s][A




#5:  65%|███████████████████████▌            | 818/1250 [00:23<00:11, 37.29ex/s][A[A[A[A[A



#4:  68%|████████████████████████▎           | 846/1250 [00:23<00:09, 43.62ex/s][A[A[A[A

#2:  68%|███████████████

#1:  74%|██████████████████████████▊         | 931/1250 [00:25<00:11, 26.88ex/s][A




#5:  71%|█████████████████████████▋          | 890/1250 [00:25<00:17, 21.08ex/s][A[A[A[A[A






#0:  80%|████████████████████████████▋       | 998/1250 [00:25<00:07, 32.65ex/s][A[A[A[A[A[A[A
#1:  76%|███████████████████████████▎        | 948/1250 [00:25<00:06, 47.09ex/s][A


#3:  85%|█████████████████████████████▊     | 1066/1250 [00:25<00:05, 34.05ex/s][A[A[A





#6:  77%|███████████████████████████▊        | 965/1250 [00:25<00:07, 39.80ex/s][A[A[A[A[A[A



#4:  75%|██████████████████████████▉         | 937/1250 [00:26<00:07, 44.21ex/s][A[A[A[A

#2:  77%|███████████████████████████▊        | 966/1250 [00:26<00:06, 42.64ex/s][A[A




#5:  72%|█████████████████████████▊          | 898/1250 [00:26<00:13, 25.67ex/s][A[A[A[A[A
#0:  80%|████████████████████████████       | 1002/1250 [00:26<00:08, 29.79ex/s][A
#1:  78%|████████████████████████████▏       | 977/1250 [00:

#3:  90%|███████████████████████████████▍   | 1124/1250 [00:28<00:03, 34.93ex/s][A[A[A






#7:  81%|████████████████████████████▏      | 1008/1250 [00:28<00:14, 17.11ex/s][A[A[A[A[A[A[A
#1:  85%|█████████████████████████████▊     | 1066/1250 [00:28<00:03, 56.90ex/s][A





#6:  85%|█████████████████████████████▋     | 1062/1250 [00:28<00:04, 37.81ex/s][A[A[A[A[A[A




#0:  85%|█████████████████████████████▊     | 1065/1250 [00:28<00:06, 26.73ex/s][A[A[A[A[A




#5:  81%|████████████████████████████▍      | 1015/1250 [00:28<00:04, 49.83ex/s][A[A[A[A[A


#3:  90%|███████████████████████████████▋   | 1130/1250 [00:28<00:03, 35.72ex/s][A[A[A
#1:  86%|██████████████████████████████     | 1074/1250 [00:28<00:03, 54.06ex/s][A





#6:  85%|█████████████████████████████▉     | 1067/1250 [00:28<00:05, 35.13ex/s][A[A[A[A[A[A

#0:  86%|█████████████████████████████▉     | 1070/1250 [00:28<00:06, 26.94ex/s][A[A


#3:  91%|███████████████████████████████▊   

#2:  91%|████████████████████████████████   | 1143/1250 [00:31<00:02, 43.64ex/s][A[A






#7:  89%|███████████████████████████████    | 1109/1250 [00:31<00:02, 48.36ex/s][A[A[A[A[A[A[A




#0:  94%|████████████████████████████████▊  | 1172/1250 [00:31<00:02, 33.01ex/s][A[A[A[A[A



#4:  92%|████████████████████████████████   | 1146/1250 [00:31<00:02, 35.64ex/s][A[A[A[A


#3:  98%|██████████████████████████████████▏| 1221/1250 [00:31<00:00, 35.60ex/s][A[A[A





#6:  92%|████████████████████████████████▎  | 1153/1250 [00:31<00:02, 37.01ex/s][A[A[A[A[A[A
#1:  91%|███████████████████████████████▋   | 1132/1250 [00:31<00:04, 24.69ex/s][A

#2:  92%|████████████████████████████████▏  | 1149/1250 [00:31<00:02, 37.33ex/s][A[A





#6:  94%|████████████████████████████████▋  | 1169/1250 [00:31<00:01, 57.73ex/s][A[A[A[A[A[A




#5:  90%|███████████████████████████████▌   | 1128/1250 [00:31<00:03, 40.56ex/s][A[A[A[A[A






#0:  94%|██████████████████████

#7: 100%|███████████████████████████████████| 1250/1250 [00:34<00:00, 36.66ex/s][A[A[A[A[A[A[A


In [38]:
# Main data processing function: concatenates every sequence of a batch and
# re-splits the result into chunks of a fixed length.
def group_seqs(examples, chunk_size=None):
    """Concatenate the sequences of a batch and cut them into fixed-size chunks.

    Args:
        examples: Mapping from column name to a list of sequences. It must
            contain an 'input_ids' column.
        chunk_size: Length of every produced chunk. Defaults to the
            notebook-level ``block_size`` when omitted.

    Returns:
        A dict with the same columns, each a list of ``chunk_size``-long
        slices, plus a 'labels' column that is a copy of the 'input_ids' list.

    Raises:
        ValueError: If ``chunk_size`` is not positive.

    The trailing remainder is always dropped. A batch with fewer than
    ``chunk_size`` tokens therefore yields no chunks at all, rather than one
    short chunk that could not be batched with full-length ones.
    """
    if chunk_size is None:
        chunk_size = block_size
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    # Flatten each column into one long sequence.
    concatenated_examples = {
        k: list(chain.from_iterable(examples[k])) for k in examples.keys()
    }
    # The first column's length decides how many chunks are produced.
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Round down to a whole number of chunks (possibly zero).
    total_length = (total_length // chunk_size) * chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# Regroup the tokenized examples into fixed-length chunks. group_seqs receives
# whole batches (batched=True) and the work is spread over 8 processes.
train_dataset = train_dataset.map(group_seqs, batched=True, num_proc=8)

#0:   0%|                                                 | 0/2 [00:00<?, ?ba/s]
#1:   0%|                                                 | 0/2 [00:00<?, ?ba/s][A

#2:   0%|                                                 | 0/2 [00:00<?, ?ba/s][A[A



#4:   0%|                                                 | 0/2 [00:00<?, ?ba/s][A[A[A[A


#3:   0%|                                                 | 0/2 [00:00<?, ?ba/s][A[A[A





#6:   0%|                                                 | 0/2 [00:00<?, ?ba/s][A[A[A[A[A[A




#5:   0%|                                                 | 0/2 [00:00<?, ?ba/s][A[A[A[A[A






#7:   0%|                                                 | 0/2 [00:00<?, ?ba/s][A[A[A[A[A[A[A
#1:  50%|████████████████████▌                    | 1/2 [00:00<00:00,  1.50ba/s][A


#3:  50%|████████████████████▌                    | 1/2 [00:00<00:00,  1.53ba/s][A[A[A



#4:  50%|████████████████████▌                    | 1/2 [00:00<00:00

In [39]:
# Load the tiny GPT-2 checkpoint, then resize its token embeddings to the size
# of the MIDI tokenizer's vocabulary. The resized embedding module is the last
# expression, so it is shown as the cell output.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
model.resize_token_embeddings(new_num_tokens=len(tokenizer.vocab))

Embedding(218, 2)

In [40]:
# No `args` are given, so the Trainer uses its default TrainingArguments.
# No tokenizer is passed either, so the Trainer would already fall back to
# default_data_collator (DataCollatorWithPadding is only the default when a
# tokenizer is supplied). It is set explicitly to make that choice visible:
# the training examples are already chunked by group_seqs and need no padding.
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    data_collator=default_data_collator,
)

In [41]:
# Run the training loop; the value returned by train() is the cell's output.
trainer.train()

***** Running training *****
  Num examples = 11098
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 4164


Step,Training Loss
500,5.3421
1000,5.276
1500,5.218
2000,5.1674
2500,5.1257
3000,5.0928
3500,5.0703
4000,5.0576


Saving model checkpoint to tmp_trainer/checkpoint-500
Configuration saved in tmp_trainer/checkpoint-500/config.json
Model weights saved in tmp_trainer/checkpoint-500/pytorch_model.bin
Saving model checkpoint to tmp_trainer/checkpoint-1000
Configuration saved in tmp_trainer/checkpoint-1000/config.json
Model weights saved in tmp_trainer/checkpoint-1000/pytorch_model.bin
Saving model checkpoint to tmp_trainer/checkpoint-1500
Configuration saved in tmp_trainer/checkpoint-1500/config.json
Model weights saved in tmp_trainer/checkpoint-1500/pytorch_model.bin
Saving model checkpoint to tmp_trainer/checkpoint-2000
Configuration saved in tmp_trainer/checkpoint-2000/config.json
Model weights saved in tmp_trainer/checkpoint-2000/pytorch_model.bin
Saving model checkpoint to tmp_trainer/checkpoint-2500
Configuration saved in tmp_trainer/checkpoint-2500/config.json
Model weights saved in tmp_trainer/checkpoint-2500/pytorch_model.bin
Saving model checkpoint to tmp_trainer/checkpoint-3000
Configuration

TrainOutput(global_step=4164, training_loss=5.164203669450011, metrics={'train_runtime': 1280.0138, 'train_samples_per_second': 26.011, 'train_steps_per_second': 3.253, 'total_flos': 31092867072.0, 'train_loss': 5.164203669450011, 'epoch': 3.0})

### To do:
- funciones para computar métricas
- callback para guardar modelos
- guardar métricas de training
- eval loop y guardar métricas
- __script + hydra__
- ¿data parallelism?
- ¿guardar dataset preprocesado? (`load_from_cache_file` en `dataset.map`)