In [10]:
from pathlib import Path
from tempfile import TemporaryDirectory

import torch
from loguru import logger
from mido import MidiFile

from yoshimidi import player
from yoshimidi.data.parse import midi_parsing, one_hot_parsing, token_parsing, track_parsing

In [14]:
# Round-trip check for a single MIDI track:
#   MIDI file -> yoshimidi track -> per-channel tokens -> rebuilt track -> MIDI file -> playback.

# Selection knobs: position of the file in the dataset listing, position of the
# track inside that file, and (optionally) position of a single channel to keep.
PATH_INDEX = 1
TRACK_INDEX = 11
CHANNEL_INDEX = None  # None keeps every channel of the selected track.

# NOTE(review): rglob yields paths in whatever order the filesystem reports them,
# so PATH_INDEX presumably is not stable across machines - confirm if that matters.
dataset_root = Path("../out/dataset/01_raw/lmd_full/")
midi_path = list(dataset_root.rglob("*.mid"))[PATH_INDEX]
# To listen to the untouched source file first, run: player.play(midi_path)

midi_file = MidiFile(midi_path)
logger.info(f"num tracks: {len(midi_file.tracks)}")
messages_per_track = list(map(len, midi_file.tracks))
logger.info(f"track num messages: {messages_per_track}")

# Convert the selected mido track into the project's own track representation.
midi_track = midi_file.tracks[TRACK_INDEX]
ym_track = track_parsing.from_midi(
    midi_track,
    ticks_per_beat=midi_file.ticks_per_beat,
)
# NOTE(review): from_midi presumably returns None for tracks it cannot parse - confirm.
assert ym_track is not None

# One token sequence per channel of the track, optionally narrowed to one channel.
ym_tokens = list(map(token_parsing.from_channel, ym_track.channels.values()))
ym_tokens = ym_tokens if CHANNEL_INDEX is None else [ym_tokens[CHANNEL_INDEX]]

# One-hot encode the first remaining channel on the CPU as float32.
ym_one_hot = one_hot_parsing.from_tokens(
    ym_tokens[0],
    device=torch.device("cpu"),
    dtype=torch.float32,
)

# Rebuild a track from the tokens and wrap it in a single-track MIDI file.
ym_track_recons = track_parsing.from_tokens(ym_tokens)
midi_file_recons = midi_parsing.from_tracks([ym_track_recons])

# Save the reconstruction to a throwaway directory and play it; the directory
# (and the file in it) is removed when the with-block exits.
with TemporaryDirectory() as temp:
    midi_path_recons = Path(temp, "recons.mid")
    midi_file_recons.save(midi_path_recons)
    player.play(midi_path_recons)

2023-07-31 17:20:26.980 | INFO     | __main__:<module>:9 - num tracks: 13
2023-07-31 17:20:26.980 | INFO     | __main__:<module>:10 - track num messages: [2, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 311, 47]
2023-07-31 17:20:26.993 | INFO     | yoshimidi.player:play:12 - Playing /var/folders/7w/66fh_d3s5hb0br9f7wtqww1h0000gn/T/tmpns8x_7yp/recons.mid
2023-07-31 17:20:29.292 | INFO     | yoshimidi.player:play:19 - Finishing due to interrupt
