In [None]:
import numpy as np
import os
import pretty_midi
import mir_eval
from piano_transcription_inference import PianoTranscription, sample_rate, load_audio
import torch.cuda

# Root of the MAPS dataset. Set the MAPS_PATH environment variable to point at a
# different location instead of editing the notebook; the default keeps the
# original machine-specific path.
path_maps = os.environ.get("MAPS_PATH", "C:/Users/amarmore/Desktop/Audio samples/MAPS")
# Transcriptions produced by the ByteDance model are cached here, one sub-folder per piano.
path_outputs = f"{path_maps}/outputs_ByteDance"

def load_ref_in_arrays(ref_path, time_limit = 30):
    """Load a tab-separated ground-truth note list into mir_eval-style arrays.

    The first line of the file is a header and is skipped. Every following
    non-blank line is read as ``onset<TAB>offset<TAB>pitch``.

    Parameters
    ----------
    ref_path : str
        Path to the reference ``.txt`` file.
    time_limit : float or None, optional
        Keep only notes whose onset is strictly below this value (same time
        unit as the file, presumably seconds). ``None`` keeps every note.

    Returns
    -------
    ref_intervals : np.ndarray, shape (n, 2)
        ``[onset, offset]`` per kept note. Always 2-D, even when no note is
        kept, because mir_eval rejects intervals that are not ``(n, 2)``.
    ref_pitches : np.ndarray, shape (n,)
        Pitch values exactly as written in the file (presumably MIDI note
        numbers for MAPS, i.e. NOT Hz).
    """
    ref_intervals = []
    ref_pitches = []

    with open(ref_path) as f:
        next(f, None)  # discard the title/legend line of the ground truth
        for line in f:
            if not line.strip():  # blank (or whitespace-only) line
                continue
            fields = line.rstrip("\n").split("\t")
            onset = float(fields[0])
            # Notes starting at or after time_limit lie outside the cropped excerpt.
            keep = time_limit is None or onset < time_limit
            if not keep:
                continue
            ref_intervals.append([onset, float(fields[1])])
            ref_pitches.append(float(fields[2]))

    return (np.array(ref_intervals, dtype=float).reshape(-1, 2),
            np.array(ref_pitches, dtype=float))

def run_estimates_piano(piano_name, device = None):
    """Transcribe every ``.wav`` file of one MAPS piano with the ByteDance model.

    Each ``{path_maps}/{piano_name}/MUS/<stem>.wav`` is loaded as mono audio at
    the library's ``sample_rate``. Its transcription is written to
    ``{path_outputs}/{piano_name}/<stem>.mid``, creating that folder if needed
    and overwriting any existing file.

    Parameters
    ----------
    piano_name : str
        Name of the MAPS piano sub-folder (e.g. ``"AkPnBcht"``).
    device : str or None, optional
        ``'cuda'`` or ``'cpu'``. ``None`` (default) uses ``'cuda'`` when
        ``torch.cuda.is_available()`` and ``'cpu'`` otherwise.
    """
    mus_dir = f"{path_maps}/{piano_name}/MUS"
    out_dir = f"{path_outputs}/{piano_name}"
    wav_files = [name for name in os.listdir(mus_dir) if name.endswith(".wav")]

    if wav_files:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # The writer does not create missing folders, so make sure the target exists.
        os.makedirs(out_dir, exist_ok=True)
        # Build the transcriptor once per call rather than once per file
        # (construction presumably loads the model checkpoint).
        transcriptor = PianoTranscription(device=device)

        for name in wav_files:
            (audio, _) = load_audio(f"{mus_dir}/{name}", sr=sample_rate, mono=True)
            stem = os.path.splitext(name)[0]
            # Transcribe and write out to MIDI file
            transcriptor.transcribe(audio, f"{out_dir}/{stem}.mid")
    print("Done.")
    
def _midi_to_hz(midi_pitches):
    """Convert MIDI note numbers to frequencies in Hz (MIDI 69 = A4 = 440 Hz)."""
    return 440.0 * 2.0 ** ((np.asarray(midi_pitches, dtype=float) - 69.0) / 12.0)


def _load_est_from_midi(midi_path, time_limit):
    """Read all notes of all instruments of a MIDI file as (intervals, MIDI pitches).

    Only notes whose start is strictly below ``time_limit`` are kept (all notes
    when ``time_limit`` is None). Intervals are always shaped ``(n, 2)``.
    """
    est_intervals = []
    est_pitches = []
    for instrument in pretty_midi.PrettyMIDI(midi_path).instruments:
        for note in instrument.notes:
            if time_limit is None or note.start < time_limit:
                est_intervals.append([note.start, note.end])
                est_pitches.append(note.pitch)
    return (np.array(est_intervals, dtype=float).reshape(-1, 2),
            np.array(est_pitches, dtype=float))


def _note_accuracy(n_matched, n_ref, n_est):
    """Note-level accuracy TP / (TP + FP + FN) from exact counts; 0 when there are no notes."""
    n_false_pos = n_est - n_matched
    n_false_neg = n_ref - n_matched
    denominator = n_matched + n_false_pos + n_false_neg
    return n_matched / denominator if denominator else 0.0


def compute_score_this_piano(piano_name, time_limit = 30):
    """Score the ByteDance transcriptions of one MAPS piano against the ground truth.

    For every ``.wav`` in ``{path_maps}/{piano_name}/MUS`` the reference
    ``<stem>.txt`` is read with ``load_ref_in_arrays`` and the cached
    transcription ``{path_outputs}/{piano_name}/<stem>.mid`` is read with
    pretty_midi. If that MIDI file is missing, ``run_estimates_piano`` is
    called, which transcribes every file of the piano.

    Matching is onset-only (``offset_ratio=None``), with a 0.05 onset
    tolerance and a 50-cent pitch tolerance. MIDI pitches are converted to Hz
    first, because mir_eval measures pitch distance in cents between
    frequencies.

    Parameters
    ----------
    piano_name : str
        Name of the MAPS piano sub-folder.
    time_limit : float or None, optional
        Only notes with onset strictly below this value are evaluated
        (``None`` = whole pieces).

    Returns
    -------
    scores : list of 3 floats
        Mean precision, recall and F-measure over the files.
    accuracy : float
        Mean of TP / (TP + FP + FN) over the files.

    Raises
    ------
    ValueError
        If no ``.wav`` file is found for this piano.
    """
    mus_dir = f"{path_maps}/{piano_name}/MUS"
    scores = []
    accs = []

    for file in os.listdir(mus_dir):
        if not file.endswith(".wav"):
            continue
        stem = os.path.splitext(file)[0]
        ref_intervals, ref_pitches = load_ref_in_arrays(f"{mus_dir}/{stem}.txt", time_limit = time_limit)

        midi_path = f"{path_outputs}/{piano_name}/{stem}.mid"
        if not os.path.exists(midi_path):  # transcriptions not computed yet
            run_estimates_piano(piano_name)
        est_intervals, est_pitches = _load_est_from_midi(midi_path, time_limit)

        # mir_eval expects pitches in Hz (tolerance in cents), not MIDI numbers.
        ref_hz = _midi_to_hz(ref_pitches)
        est_hz = _midi_to_hz(est_pitches)

        (prec, rec, f_mes, _) = mir_eval.transcription.precision_recall_f1_overlap(
            ref_intervals, ref_hz, est_intervals, est_hz,
            onset_tolerance=0.05, pitch_tolerance=50.0, offset_ratio=None, strict=False, beta=1.0)
        scores.append([prec, rec, f_mes])

        if len(ref_intervals) and len(est_intervals):
            n_matched = len(mir_eval.transcription.match_notes(
                ref_intervals, ref_hz, est_intervals, est_hz,
                onset_tolerance=0.05, pitch_tolerance=50.0, offset_ratio=None))
        else:
            n_matched = 0
        accs.append(_note_accuracy(n_matched, len(ref_intervals), len(est_intervals)))

    if not scores:
        raise ValueError(f"No .wav file found in {mus_dir}")
    return list(np.mean(np.array(scores), axis=0)), np.mean(accs)

In [None]:
# Evaluation settings: pianos of the MAPS dataset and onset cut-offs to score.
# A time limit of None evaluates the full pieces (labelled "Without limit").
list_pianos = ["AkPnBcht", "AkPnBsdf", "AkPnCGdD", "AkPnStgb", "ENSTDkAm", "ENSTDkCl", "SptkBGAm", "SptkBGCl", "StbgTGd2"]
time_limits = [30]  # e.g. [30, None] to also score the whole pieces

# (piano, time_limit) -> ([prec, recall, f_measure], accuracy), kept for later cells.
results = {}
for piano in list_pianos:
    print(piano)
    for time_limit in time_limits:
        scores, acc = compute_score_this_piano(piano, time_limit = time_limit)
        results[(piano, time_limit)] = (scores, acc)
        label = "Without limit" if time_limit is None else f"{time_limit}sec"
        print(f"{label} - Prec: {round(scores[0],4)}, Recall: {round(scores[1],4)}, F measure: {round(scores[2],4)}, accuracy: {round(acc,4)}")