In [None]:
# !/usr/bin/python3 -m pip install --upgrade pip
# !pip install -e ..
# !pip install nvidia-cudnn-cu11

In [None]:
import glob
import json
import os
from pathlib import Path
from multiprocessing import Pool

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # or any {'0', '1', '2'}

import tensorflow as tf

from midistral.audio_analysis import (
    download_models,
    get_chords,
    get_mood_and_genre,
)
from midistral.midi_utils import (
    convert_midi_to_ogg,
    get_duration,
    get_duration_caps,
    get_instruments,
    get_key,
    get_tempo,
    get_tempo_caps,
    get_time_signature,
)
from tqdm import tqdm


In [None]:
# Run the download_models helper imported from midistral.audio_analysis.
# NOTE(review): presumably this fetches the model files needed later by
# get_chords / get_mood_and_genre — confirm against midistral.audio_analysis.
download_models()

In [None]:
# Import and immediately run the MidiCaps extraction helper.
# NOTE(review): the configuration cell below sets DATASET_NAME = "irishman";
# confirm that this MidiCaps extraction step is actually required here.
from midistral.prepare_dataset import extract_midicaps_files

extract_midicaps_files()

In [None]:
# Folder layout, resolved from the kernel's current working directory:
# "output" and "data" are siblings of the folder the notebook runs from.
NOTEBOOKS_FOLDER = Path.cwd()
OUTPUT_FOLDER = NOTEBOOKS_FOLDER.parent.joinpath("output")
DATA_FOLDER = NOTEBOOKS_FOLDER.parent.joinpath("data")

# Dataset being annotated; every derived path below is keyed on this name.
DATASET_NAME = "irishman"
ANNOTATION_OUTPUT_PATH = OUTPUT_FOLDER.joinpath(f"annotations_{DATASET_NAME}_output.jsonl")
TMP_WAV_FOLDER = OUTPUT_FOLDER.joinpath("tmp_wav", DATASET_NAME)
MIDI_FOLDER_TO_PROCESS = DATA_FOLDER.joinpath(DATASET_NAME)

# Create the scratch folder for rendered audio (and OUTPUT_FOLDER with it);
# re-running the cell is harmless because existing folders are accepted.
TMP_WAV_FOLDER.mkdir(parents=True, exist_ok=True)

In [None]:
# Recursively collect every .mid file below the dataset folder.
# glob returns matches in a filesystem-dependent order, so the list is sorted
# to make the processing order (and the row order of the annotation file)
# reproducible from one run to the next.
mid_l = sorted(glob.glob(f"{MIDI_FOLDER_TO_PROCESS}/**/*.mid", recursive=True))

def get_relative_path(mid: str, base=None) -> str:
    """Return ``mid`` expressed relative to the data folder, with ``/`` separators.

    Only a leading ``base`` prefix is removed; a later repetition of the same
    folder string inside the path is left untouched.

    Args:
        mid: Path of a MIDI file, as a string.
        base: Folder to strip from the front of ``mid`` (``str`` or ``Path``).
            Defaults to the notebook-level ``DATA_FOLDER``.

    Returns:
        The portion of ``mid`` below ``base`` joined with ``/``, or ``mid``
        unchanged when it is not located under ``base``.
    """
    root = Path(DATA_FOLDER if base is None else base)
    try:
        # as_posix() keeps "/" as the separator on every platform, which is
        # what the rest of the notebook expects when it flattens this path.
        return Path(mid).relative_to(root).as_posix()
    except ValueError:
        # Not under the base folder: keep the input as-is.
        return mid

def get_filename_from_full_path(mid: str) -> str:
    """Build a flat file stem from the path of a MIDI file.

    The path is first made relative with ``get_relative_path``; then every
    ``/`` becomes ``_``, every ``.mid`` occurrence is removed, and leading and
    trailing underscores are stripped.
    """
    flattened = get_relative_path(mid).replace('/', '_')
    without_extension = flattened.replace(".mid", "")
    return without_extension.strip("_")

# Pair each MIDI file with its target audio path in TMP_WAV_FOLDER; the file
# stem comes from get_filename_from_full_path and the suffix is ".wav".
mid_wav_l = [
    (Path(mid), TMP_WAV_FOLDER / f"{get_filename_from_full_path(mid)}.wav")
    for mid in mid_l
]

# Render all files with a pool of 5 worker processes. starmap unpacks each
# (midi_path, audio_path) pair into the two arguments of convert_midi_to_ogg.
# NOTE(review): the helper is named "..._to_ogg" but is handed ".wav" targets
# here — confirm the format it actually writes.
with Pool(processes=5) as pool:
    pool.starmap(convert_midi_to_ogg, mid_wav_l)

# Only files whose duration lies within [MIN_DURATION, MAX_DURATION] are
# annotated; everything else is skipped.
# NOTE(review): the unit is whatever get_duration returns — presumably
# seconds; confirm against midistral.midi_utils.
MIN_DURATION = 5
MAX_DURATION = 60

# Write one JSON object per line (JSONL). Opening with "w" starts from an
# empty file on every run.
with ANNOTATION_OUTPUT_PATH.open("w", encoding="utf8") as f:
    for mid in tqdm(mid_l):
        mid_p = Path(mid)
        # Audio rendering of this MIDI file, as named by the conversion step.
        audio_wav_file = TMP_WAV_FOLDER / f"{get_filename_from_full_path(mid)}.wav"

        try:
            # Filter on duration first so skipped files cost no further work.
            dur = get_duration(mid_p)
            if dur < MIN_DURATION or dur > MAX_DURATION:
                continue
            duration_word = get_duration_caps(dur)

            # Audio-based analysis.
            chords_out, chord_summary, chord_summary_occurence = get_chords(audio_wav_file)
            mood_tags, mood_cs, genre_tags, genre_cs = get_mood_and_genre(audio_wav_file)

            # MIDI-based analysis.
            key = get_key(mid_p)
            time_signature = get_time_signature(mid_p)
            tempo = get_tempo(mid_p)
            tempo_word = get_tempo_caps(tempo)
            instrument_numbers_sorted, instrument_summary = get_instruments(mid_p)

        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            # Report which file failed so it can be investigated afterwards.
            print(f"{mid}: {type(e).__name__}: {e}")
            continue

        # Assemble the annotation record for this file.
        row = {
            "location": get_relative_path(str(mid_p)),
            # Only the first two genre entries (and their scores) are kept.
            "genre": genre_tags[:2],
            "genre_prob": genre_cs[:2],
            "mood": mood_tags,
            "mood_prob": mood_cs,
            "key": key,
            "time_signature": time_signature,
            "tempo": tempo,
            "tempo_word": tempo_word,
            "duration": dur,
            "duration_word": duration_word,
            "chord_summary": chord_summary,
            "chord_summary_occurence": chord_summary_occurence,
            "instrument_summary": instrument_summary,
            "instrument_numbers_sorted": instrument_numbers_sorted,
            # Each chords_out entry is split into its first and second element.
            "all_chords": [chord[0] for chord in chords_out],
            "all_chords_timestamps": [chord[1] for chord in chords_out],
        }
        f.write(json.dumps(row) + "\n")
        # Flush after every row so completed work is on disk if the run stops.
        f.flush()
