# Model Mistral + REMI + Maestro

First baseline model using a Mistral transformer and the REMI tokenizer, trained on the MAESTRO dataset. Inspired by the MidiTok example notebook: https://github.com/Natooz/MidiTok/blob/main/colab-notebooks/Example_HuggingFace_Mistral_Transformer.ipynb

In [1]:
from copy import deepcopy
from datetime import datetime
import os
import subprocess

from evaluate import load as load_metric
from miditok import REMI, TokenizerConfig
from miditok.data_augmentation import augment_dataset
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training
from pathlib import Path
from sklearn.model_selection import train_test_split
from torch import argmax
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig
from transformers.trainer_utils import set_seed
import wandb

In [2]:
# Components of the experiment identifier; they are combined into MODEL_NAME,
# which names the model directory and the wandb run.
TRANSFORMER_NAME = "mistral-309M"
TOKENIZER_NAME = "remi"
DATASET_NAME = "maestro"
MODEL_VERSION = "1"

# <transformer>_<tokenizer>_<dataset>_v<version>
MODEL_NAME = "_".join(
    (TRANSFORMER_NAME, TOKENIZER_NAME, DATASET_NAME, "v" + MODEL_VERSION)
)

print(f"Model:\n{MODEL_NAME}")

Model:
mistral-309M_remi_maestro_v1


In [3]:
# Root of the project tree. Swap in the commented line to use the alternative root.
BASE_PATH = Path(".")
#BASE_PATH = Path("/hpcwork/lect0148")

# Raw input data (shared by all models).
DATA_RAW_PATH = BASE_PATH.joinpath("data")

# Everything produced for this model lives under models/<MODEL_NAME>/.
MODEL_BASE_PATH = BASE_PATH.joinpath("models", MODEL_NAME)
DATA_PROCESSED_PATH, MODEL_PATH, RUNS_PATH, OUTPUT_PATH = (
    MODEL_BASE_PATH / subdir
    for subdir in ("data_processed", "model", "runs", "output")
)

In [4]:
# Tell wandb which entity and project the runs are logged to, then authenticate.
os.environ.update(
    WANDB_ENTITY="jonathanlehmkuhl-rwth-aachen-university",
    WANDB_PROJECT="piano-transformer",
)
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mnico-bremes[0m ([33mjonathanlehmkuhl-rwth-aachen-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# Single seed shared by the data split and the Trainer configuration below;
# set_seed applies it to the global random number generators.
SEED = 222
set_seed(SEED)

In [6]:
# Collect every MAESTRO MIDI file below the raw data directory.
# The list is sorted because glob yields entries in filesystem order, which
# differs between machines; without a stable order the seeded
# train/validation/test split below would not be reproducible.
maestro_files = sorted((DATA_RAW_PATH / "maestro").resolve().glob("**/*.midi"))

## Tokenizer

In [7]:
# Load the REMI tokenizer from the configuration file stored with this model.
# NOTE(review): tokenizer.json must already exist under MODEL_BASE_PATH —
# presumably written by a separate tokenizer setup step; confirm.
tokenizer = REMI(params=MODEL_BASE_PATH / "tokenizer.json")

## Data Preparation

In [8]:
# Two-stage split into train / validation / test sets:
#   1. hold out 30 % of the files,
#   2. cut the held-out part in half,
# which yields roughly 70 % / 15 % / 15 %.
midi_files_train, midi_files_temp = train_test_split(
    maestro_files,
    test_size=0.3,
    random_state=SEED,
)
midi_files_valid, midi_files_test = train_test_split(
    midi_files_temp,
    test_size=0.5,
    random_state=SEED,
)

In [9]:
# Cut every MIDI file into smaller chunks whose token length roughly matches
# the training sequence length. Each subset gets its own output directory.
subsets = {
    "train": midi_files_train,
    "valid": midi_files_valid,
    "test": midi_files_test,
}
for subset_name, midi_files in subsets.items():
    subset_chunks_dir = DATA_PROCESSED_PATH / f"maestro_{subset_name}"
    split_files_for_training(
        files_paths=midi_files,
        tokenizer=tokenizer,
        save_dir=subset_chunks_dir,
        max_seq_len=2048,
        num_overlap_bars=2,
    )

Splitting music files (models/mistral-309M_remi_maestro_v1/data_processed/maestro_train):  80%|███████▉  | 712/893 [00:13<00:03, 51.41it/s]


KeyboardInterrupt: 

In [None]:
# Augment the training chunks only; validation and test data stay untouched.
train_chunks_dir = DATA_PROCESSED_PATH / "maestro_train"
augment_dataset(
    train_chunks_dir,
    pitch_offsets=[-12, 12],
    velocity_offsets=[-4, 4],
    duration_offsets=[-0.5, 0.5],
)

In [None]:
# Create pytorch datasets from the chunked files written by the splitting step
# (the train directory also contains the augmented copies).
# The lists are sorted so the sample order does not depend on the filesystem.
midi_files_train = sorted((DATA_PROCESSED_PATH / "maestro_train").glob("**/*.midi"))
midi_files_valid = sorted((DATA_PROCESSED_PATH / "maestro_valid").glob("**/*.midi"))
# The test set must come from its own directory; reading "maestro_train" here
# would make the test data identical to the training data.
midi_files_test = sorted((DATA_PROCESSED_PATH / "maestro_test").glob("**/*.midi"))

# Settings shared by all three datasets: same sequence length as used for
# chunking, and explicit BOS/EOS token ids taken from the tokenizer vocabulary.
dataset_kwargs = {
    "max_seq_len": 2048,
    "tokenizer": tokenizer,
    "bos_token_id": tokenizer["BOS_None"],
    "eos_token_id": tokenizer["EOS_None"],
}
dataset_train = DatasetMIDI(midi_files_train, **dataset_kwargs)
dataset_valid = DatasetMIDI(midi_files_valid, **dataset_kwargs)
dataset_test = DatasetMIDI(midi_files_test, **dataset_kwargs)

# Pads batches with the PAD token and copies the inputs as labels
# (causal language-modelling objective).
collator = DataCollator(tokenizer["PAD_None"], copy_inputs_as_labels=True)

In [None]:
# Number of chunk files found in the training directory.
len(midi_files_train)

## Model

In [None]:
# Architecture hyper-parameters of the Mistral decoder.
HIDDEN_SIZE = 896
NUM_HEADS = 14       # key/value heads equal the attention heads
CONTEXT_LENGTH = 2048  # same as the dataset max_seq_len

model_config = MistralConfig(
    vocab_size=len(tokenizer),
    hidden_size=HIDDEN_SIZE,
    intermediate_size=4 * HIDDEN_SIZE,
    num_hidden_layers=24,
    num_attention_heads=NUM_HEADS,
    num_key_value_heads=NUM_HEADS,
    sliding_window=CONTEXT_LENGTH,
    max_position_embeddings=CONTEXT_LENGTH,
    # Special token ids come from the tokenizer vocabulary.
    pad_token_id=tokenizer["PAD_None"],
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)

# Build the model from the config (no pretrained weights are loaded here).
model = AutoModelForCausalLM.from_config(model_config)

In [None]:
# Count all parameters and, separately, those that receive gradients,
# in a single pass over the model.
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## Training

In [None]:
def preprocess_logits(logits, _):
    """Reduce logits to predicted token ids.

    Takes the index of the largest value along the last dimension of
    ``logits``. The second argument is accepted but ignored.
    """
    return logits.argmax(dim=-1)

In [None]:
# Choose the mixed-precision mode for training and evaluation:
#   - no GPU               -> full precision
#   - GPU with bf16        -> bf16
#   - GPU without bf16     -> fp16
USE_CUDA = cuda_available()

# bf16 support is only queried when a GPU is present.
BF16 = BF16_EVAL = USE_CUDA and is_bf16_supported()
FP16 = FP16_EVAL = USE_CUDA and not BF16

In [None]:
# Training hyper-parameters. Evaluation and checkpointing both happen once per
# epoch, which is required for load_best_model_at_end to work.
trainer_config = TrainingArguments(
    output_dir=RUNS_PATH,
    overwrite_output_dir=False,
    do_train=True,
    do_eval=True,
    do_predict=False,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=3,
    eval_strategy="epoch",
    eval_accumulation_steps=None,
    #eval_steps=5,
    learning_rate=1e-4,
    weight_decay=0.01,
    max_grad_norm=3.0,
    #max_steps=5,
    num_train_epochs=20,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.3,
    log_level="debug",
    logging_strategy="steps",
    logging_steps=20,
    save_strategy="epoch",
    #save_steps=5,
    #save_total_limit=5,
    # `use_cpu` is the current name of the deprecated `no_cuda` flag.
    use_cpu=not USE_CUDA,
    seed=SEED,
    fp16=FP16,
    fp16_full_eval=FP16_EVAL,
    bf16=BF16,
    bf16_full_eval=BF16_EVAL,
    load_best_model_at_end=True,
    label_smoothing_factor=0.,
    optim="adamw_torch",
    report_to=["wandb"],
    run_name=MODEL_NAME,
    gradient_checkpointing=True,
)

# Trainer takes only a train and an eval dataset; it has no `test_dataset`
# parameter (passing one raises TypeError). The test set is used later by the
# generation step instead.
trainer = Trainer(
    model=model,
    args=trainer_config,
    data_collator=collator,
    train_dataset=dataset_train,
    eval_dataset=dataset_valid,
    #compute_metrics=compute_metrics,
    callbacks=None,
    # Reduce logits to token ids before they are accumulated for metrics.
    preprocess_logits_for_metrics=preprocess_logits,
)

In [None]:
# Run training, then persist the model, the training metrics and the trainer state.
train_result = trainer.train()

trainer.save_model(MODEL_PATH)

train_metrics = train_result.metrics
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)

trainer.save_state()

## Generation

In [None]:
# Load the model saved under MODEL_PATH so generation can run without re-training.
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)

In [None]:
# Prompts are padded on the left so that the last position along the time
# dimension is always the last real token of each sequence; this allows
# efficient batched generation. No EOS token is appended to the prompts.
collator.pad_on_left = True
collator.eos_token = None

# Sampling settings used for every generation call.
sampling_options = dict(
    max_new_tokens=200,  # extends samples by 200 tokens
    num_beams=1,
    do_sample=True,
    temperature=0.9,
    top_k=15,
    top_p=0.95,
    epsilon_cutoff=3e-4,
    eta_cutoff=1e-3,
    pad_token_id=tokenizer.pad_token_id,
)
generation_config = GenerationConfig(**sampling_options)

# Switch the model to inference mode.
model.eval()

In [None]:
def generate(dataset, output):
    """Generate a continuation for every sample of ``dataset`` and save the results.

    Uses the module-level ``model``, ``collator``, ``generation_config`` and
    ``tokenizer``. For each sample, four files are written to ``output``
    (created if missing), numbered consecutively from 0:

    * ``<i>_generated.midi`` - only the newly generated tokens
    * ``<i>_prompt.midi``    - the prompt fed to the model
    * ``<i>_full.midi``      - prompt followed by the continuation
    * ``<i>.json``           - the three token sequences (generated, prompt, full)

    Args:
        dataset: Dataset whose samples are used as prompts.
        output: Directory the files are written to.
    """
    output_path = Path(output)
    output_path.mkdir(parents=True, exist_ok=True)
    dataloader = DataLoader(dataset, batch_size=16, collate_fn=collator)

    track_labels = ("Generated continuation", "Original prompt", "Full sequence")
    file_suffixes = ("generated", "prompt", "full")

    sample_idx = 0
    for batch in tqdm(dataloader, desc="Generating outputs"):
        prompts = batch["input_ids"]
        continuations = model.generate(
            inputs=prompts.to(model.device),
            attention_mask=batch["attention_mask"].to(model.device),
            generation_config=generation_config
        )

        # Saves the generated music, as MIDI files and tokens (json)
        for prompt, continuation in zip(prompts, continuations):
            prompt_ids = prompt.tolist()
            full_ids = continuation.tolist()
            # The model output starts with the prompt; everything after it is new.
            generated_ids = full_ids[len(prompt_ids):]
            tokens = [generated_ids, prompt_ids, full_ids]

            # Decode copies so the token lists saved below stay untouched.
            midis = [tokenizer.decode([deepcopy(ids)]) for ids in tokens]

            # Name the tracks
            for midi, label, ids in zip(midis, track_labels, tokens):
                if midi.tracks:
                    midi.tracks[0].name = f"{label} ({len(ids)} tokens)"

            # Save each as a separate MIDI file
            for midi, suffix in zip(midis, file_suffixes):
                midi.dump_midi(output_path / f"{sample_idx}_{suffix}.midi")
            tokenizer.save_tokens(tokens, output_path / f"{sample_idx}.json")

            sample_idx += 1

In [None]:
# Generate continuations for the test dataset and write them to OUTPUT_PATH / "test".
generate(dataset_test, OUTPUT_PATH / "test")

## Convert to WAV

In [None]:
# Render every generated MIDI file to WAV with FluidSynth.
soundfont_path = "FluidR3_GM.sf2"
midi_folder = OUTPUT_PATH / "test"
output_folder = OUTPUT_PATH / "test_wav"

# Nothing else creates the WAV directory, so make sure it exists before
# FluidSynth is asked to write into it.
output_folder.mkdir(parents=True, exist_ok=True)

# Sorted so the files are converted in a stable, predictable order.
for midi_path in sorted(midi_folder.iterdir()):
    if midi_path.suffix.lower() != ".midi":
        continue

    wav_path = output_folder / f"{midi_path.stem}.wav"

    print(f"Converting {midi_path.name} to {wav_path.name}...")

    # Build FluidSynth command. Options are placed before the soundfont and
    # MIDI file, following the documented usage
    # `fluidsynth [options] [soundfonts] [midifiles]`.
    command = [
        "fluidsynth",
        "-ni",                 # no interactive mode
        "-F", str(wav_path),   # output file
        "-r", "44100",         # sample rate
        soundfont_path,
        str(midi_path),
    ]

    # check=True aborts the loop if a conversion fails.
    subprocess.run(command, check=True)