In [None]:
from pathlib import Path
from pprint import pprint
from tempfile import TemporaryDirectory
from loguru import logger

from mido import MidiFile

from yoshimidi import player
from yoshimidi.data import midi_parsing, token_parsing, track_parsing


In [None]:
# Round-trip check: pick one MIDI file from the raw dataset, play it, convert one
# of its tracks into yoshimidi tokens, rebuild a MIDI file from those tokens, and
# play the reconstruction so the original and the round-tripped version can be
# compared by ear.

# Position of the file to use within the *sorted* list of .mid files found
# under the dataset directory.
PATH_INDEX = 18
# Which track of the chosen MIDI file to round-trip.
TRACK_INDEX = 2
# If not None, only the tokens of this channel (by position in
# ym_track.channels iteration order) are reconstructed; None keeps all channels.
CHANNEL_INDEX = None

# Path.rglob yields files in filesystem order, which is unspecified and differs
# between machines/filesystems. Sort so PATH_INDEX picks the same file everywhere.
midi_path = sorted(Path("../out/dataset/dataset_raw/lmd_full/").rglob("*.mid"))[PATH_INDEX]
player.play(midi_path)

midi_file = MidiFile(midi_path)
logger.info(f"num tracks: {len(midi_file.tracks)}")
logger.info(f"track num messages: {[len(track) for track in midi_file.tracks]}")
midi_track = midi_file.tracks[TRACK_INDEX]
ym_track = track_parsing.from_midi(midi_track, ticks_per_beat=midi_file.ticks_per_beat)
# from_midi can evidently return None (the assert guards against it); presumably
# that happens for tracks it cannot or will not parse. TODO: confirm the condition.
assert ym_track is not None
# One token sequence per channel, in the channels mapping's iteration order.
ym_tokens = [token_parsing.from_channel(channel) for channel in ym_track.channels.values()]
if CHANNEL_INDEX is not None:
    # Keep a single-element list so from_tokens still receives a list of channels.
    ym_tokens = [ym_tokens[CHANNEL_INDEX]]

ym_track_recons = track_parsing.from_tokens(ym_tokens)
midi_file_recons = midi_parsing.from_tracks([ym_track_recons])

# Write the reconstruction to a throwaway file and play it while the temporary
# directory still exists. NOTE(review): this assumes player.play blocks until
# playback finishes. If it returns early, the file may be deleted while in use.
with TemporaryDirectory() as temp:
    midi_path_recons = Path(temp) / "recons.mid"
    midi_file_recons.save(midi_path_recons)
    player.play(midi_path_recons)