In [279]:
from miditok import CPWord, MuMIDI
from miditok.utils import get_midi_programs
from miditoolkit import MidiFile
import numpy as np
import torch
# Report whether torch can see a CUDA device (decides DEVICE in the next cell).
print("Cuda available: " + str(torch.cuda.is_available()))

Cuda available: True


In [280]:
# Token (as a tuple) marking the start of a bar; split_track_by_bars compares tuple(token) against it.
# NOTE(review): assumed to be the CPWord "Bar" token under the default vocabulary — confirm.
BAR_TOKEN = (1, 2, 1, 1, 1)
# Passed to the tokenizer when decoding: one (program, flag) pair per track (right hand, left hand).
# NOTE(review): presumably (program, is_drum) per miditok's convention — confirm.
PIANO_INSTRUMENTS = ((0, False),) * 2
# Device that batch tensors are moved to: GPU if CUDA is visible, otherwise CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

def split_track_by_bars(track):
    """Split a token sequence into per-bar chunks.

    A new chunk starts at every token equal to BAR_TOKEN (after tuple
    conversion), except at position 0. Every token before the first bar
    token lands in the first chunk. Chunks are slices of ``track``. An
    empty track yields a single empty slice.
    """
    # Indices where a new bar begins; position 0 never opens an extra chunk.
    boundaries = [
        position
        for position, tok in enumerate(track)
        if tuple(tok) == BAR_TOKEN and position != 0
    ]
    chunk_starts = [0] + boundaries
    chunk_ends = boundaries + [len(track)]
    return [track[start:end] for start, end in zip(chunk_starts, chunk_ends)]

def split_song_by_bars(song):
    """Split the first two tracks of a tokenized song into bars.

    ``song[0]`` is treated as the right hand and ``song[1]`` as the left hand.
    Both must contain the same number of bars (checked with an assert).
    Returns ``(right_bars, left_bars)``.
    """
    right_hand, left_hand = (split_track_by_bars(song[track_index]) for track_index in (0, 1))
    assert len(right_hand) == len(left_hand)
    return right_hand, left_hand

def save_measures_as_midi(right, left, tokenizer, path):
    """Decode per-measure tokens for both hands back to MIDI and write it to ``path``.

    Each hand's measures are flattened into one token list. The pair is
    passed to ``tokenizer`` together with PIANO_INSTRUMENTS, and the
    returned object's ``dump`` method is called with ``path``.
    """
    def flatten(hand_measures):
        # Concatenate all measures of one hand into a single token list.
        return [tok for bar in hand_measures for tok in bar]

    decoded = tokenizer([flatten(right), flatten(left)], PIANO_INSTRUMENTS)
    decoded.dump(path)

def _fit_measure(tokens, max_measure_length, pad_token, hand, measure_index):
    """Return a new list of exactly ``max_measure_length`` tokens (truncated or padded); input is untouched."""
    if len(tokens) > max_measure_length:
        print(f"Max measure length of {max_measure_length} exceeded in {hand} hand on measure {measure_index}")
    # list(...) normalises tuple/array tokens so the later right+left join concatenates
    # instead of failing (tuple + list) or adding elementwise (arrays).
    fitted = [list(token) for token in tokens[:max_measure_length]]
    fitted.extend(list(pad_token) for _ in range(max_measure_length - len(fitted)))
    return fitted


def batchify_song(measures, batch_size, max_measure_length, pad_token=(0, 0, 0, 0, 0),
                  pad_last=False, device=None):
    """Turn a two-hand song, split into measures, into a tensor of measure batches.

    Each measure of each hand is truncated or padded to ``max_measure_length``
    tokens. Truncation prints a warning. At every position the right-hand
    token and the left-hand token are concatenated into one feature row.
    Measures are then grouped into batches of ``batch_size``.

    Args:
        measures: ``(right, left)`` pair, each a list of measures. A measure is a
            sequence of tokens; a token is a sequence of numbers.
        batch_size: number of measures per batch (must be >= 1).
        max_measure_length: tokens per measure after truncation/padding (must be >= 1).
        pad_token: token used for padding. It must have the same length as the real
            tokens, otherwise the rows are ragged and tensor construction fails.
            Defaults to five zeros (NOTE(review): assumed to match CPWord's token width).
        pad_last: if True, a trailing batch that is not full is padded with
            all-padding measures and kept. If False (the original behaviour),
            that batch is kept unpadded only when the song has no full batch;
            otherwise its measures are dropped and a notice is printed.
        device: device for the result; ``None`` means the module-level DEVICE.

    Returns:
        A ``torch.Tensor`` of shape
        ``(num_batches, batch_size, max_measure_length, 2 * token_len)``.
        With ``pad_last=False`` and fewer measures than ``batch_size``, dim 1
        is the number of measures instead. With no measures at all the shape
        is ``(1, 0)``.

    Raises:
        AssertionError: the two hands have different numbers of measures.
        ValueError: ``batch_size`` or ``max_measure_length`` is < 1.
    """
    right, left = measures
    assert len(right) == len(left)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if max_measure_length < 1:
        raise ValueError(f"max_measure_length must be >= 1, got {max_measure_length}")
    if device is None:
        device = DEVICE

    measure_data = []
    for index, (right_measure, left_measure) in enumerate(zip(right, left)):
        right_fitted = _fit_measure(right_measure, max_measure_length, pad_token, "right", index)
        left_fitted = _fit_measure(left_measure, max_measure_length, pad_token, "left", index)
        measure_data.append([r_tok + l_tok for r_tok, l_tok in zip(right_fitted, left_fitted)])

    full_count = len(measure_data) // batch_size
    batches = [measure_data[i * batch_size:(i + 1) * batch_size] for i in range(full_count)]
    remainder = measure_data[full_count * batch_size:]

    if remainder and pad_last:
        # Fill the last batch with measures made entirely of padding so the tensor stays rectangular.
        fillers = [
            [list(pad_token) + list(pad_token) for _ in range(max_measure_length)]
            for _ in range(batch_size - len(remainder))
        ]
        batches.append(remainder + fillers)
    elif not batches:
        # Short song (or no measures): keep the lone partial batch unpadded, as before.
        batches.append(remainder)
    elif remainder:
        print(f"Dropping {len(remainder)} trailing measure(s) that do not fill a batch of {batch_size}; "
              f"pass pad_last=True to keep them")

    return torch.Tensor(batches).to(device)
        

In [281]:
# Load the sample song (relative path, resolved against the kernel's working directory)
# and build a CPWord tokenizer with its default constructor arguments.
midi = MidiFile('Feliz_Navidad_easy_piano.mid')
tokenizer = CPWord()

In [282]:
# Tokenize the whole song, then cut both hands into bars.
song_tokens = tokenizer(midi)
right, left = split_song_by_bars(song_tokens)

In [283]:
# Round-trip check: decode the split measures back to a MIDI file.
save_measures_as_midi(right, left, tokenizer, path="test.mid")

In [284]:
batchify_song((right, left), batch_size=10, max_measure_length=32).shape

torch.Size([3, 10, 32, 10])