In [None]:
import glob
import json
import os
from pathlib import Path

from midistral.audio_analysis import (
    download_models,
    get_chords,
    get_mood_and_genre,
)
from midistral.midi_utils import (
    convert_midi_to_ogg,
    get_duration,
    get_instruments,
    get_key,
    get_tempo,
    get_time_signature,
)
from tqdm import tqdm

In [None]:
# Fetch the models exposed by midistral.audio_analysis before running any analysis
# (presumably the ones get_chords / get_mood_and_genre rely on — TODO confirm).
download_models()

In [None]:
# All locations are derived from the kernel's current working directory, which is
# treated as the notebooks folder; "data" and "output" are its siblings.
NOTEBOOKS_FOLDER = Path.cwd()
DATA_FOLDER = NOTEBOOKS_FOLDER.parent.joinpath("data")
OUTPUT_FOLDER = NOTEBOOKS_FOLDER.parent.joinpath("output")

# Name of the dataset to annotate; it selects the input folder and names the outputs.
DATASET_NAME = "vgm"

# Input: MIDI files of the selected dataset.
MIDI_FOLDER_TO_PROCESS = DATA_FOLDER.joinpath(DATASET_NAME)

# Outputs: per-dataset rendered-audio folder and the JSONL annotation file.
TMP_WAV_FOLDER = OUTPUT_FOLDER.joinpath("tmp_wav", DATASET_NAME)
ANNOTATION_OUTPUT_PATH = OUTPUT_FOLDER.joinpath(f"annotations_{DATASET_NAME}_output.jsonl")

# parents=True also creates OUTPUT_FOLDER, which the annotation file is written into.
TMP_WAV_FOLDER.mkdir(parents=True, exist_ok=True)

In [None]:
# Annotate every *.mid file found (recursively) under MIDI_FOLDER_TO_PROCESS and
# write one JSON object per line to ANNOTATION_OUTPUT_PATH. The file is opened in
# "w" mode, so re-running this cell rewrites the annotations from scratch.
#
# The list is sorted because glob.glob makes no ordering guarantee; sorting keeps
# the row order of the output file reproducible between runs.
mid_l = sorted(glob.glob(f"{MIDI_FOLDER_TO_PROCESS}/**/*.mid", recursive=True))
with ANNOTATION_OUTPUT_PATH.open("w", encoding="utf8") as f:
    for mid in tqdm(mid_l):
        mid_p = Path(mid)

        # generate audio wav
        # The rendered file mirrors the MIDI's sub-path inside the dataset folder.
        # Naming it after the stem alone would make two files with the same name in
        # different sub-folders share one cache entry, and the second one would be
        # analysed with the first one's audio.
        audio_wav_file = TMP_WAV_FOLDER / mid_p.relative_to(
            MIDI_FOLDER_TO_PROCESS
        ).with_suffix(".wav")
        audio_wav_file.parent.mkdir(parents=True, exist_ok=True)
        # Rendering is skipped when the file already exists (acts as a cache).
        if not audio_wav_file.exists():
            convert_midi_to_ogg(mid_p, audio_wav_file)

        # analyse midi or wav file
        # Chords, mood and genre are computed from the rendered audio ...
        chords_out, chord_summary = get_chords(audio_wav_file)
        mood_tags, mood_cs, genre_tags, genre_cs = get_mood_and_genre(audio_wav_file)
        # ... while key, time signature, tempo, duration and instruments are read
        # from the MIDI file itself.
        key = get_key(mid_p)
        time_signature = get_time_signature(mid_p)
        bpm = get_tempo(mid_p)
        dur = get_duration(mid_p)
        instrument_numbers_sorted, instrument_summary = get_instruments(mid_p)

        # log analysis results
        row = {
            "location": str(mid_p),
            # only the first two genre tags (and their scores) are kept
            "genre": genre_tags[:2],
            "genre_prob": genre_cs[:2],
            "mood": mood_tags,
            "mood_prob": mood_cs,
            "key": key,
            "time_signature": time_signature,
            "tempo": bpm,
            # "tempo_word", "duration_word" and "chord_summary_occurence" are written
            # as placeholders here (presumably filled in by a later step — TODO confirm).
            "tempo_word": "",
            "duration": dur,
            "duration_word": "",
            "chord_summary": chord_summary,
            "chord_summary_occurence": 0,
            "instrument_summary": instrument_summary,
            "instrument_numbers_sorted": instrument_numbers_sorted,
            # each chords_out entry is indexed as [0] and [1]; going by the key names
            # these are the chord label and its timestamp — verify against get_chords
            "all_chords": [e[0] for e in chords_out],
            "all_chords_timestamps": [e[1] for e in chords_out],
        }
        f.write(json.dumps(row) + "\n")
        # flush after every row so finished annotations reach the file even if a
        # later MIDI file makes the loop fail
        f.flush()
