In [None]:
!pip install pretty_midi==0.2.9

In [1]:
import json

# Load per-song sync data and reference (sixteenth-quantized) transcriptions,
# both keyed by RWC tag. JSON is UTF-8 by specification, so decode explicitly
# rather than relying on the platform's locale encoding.
with open('ryy08/sync.json', 'r', encoding='utf-8') as f:
  RWC_TAG_TO_SYNC = json.load(f)

with open('ryy08/ref.json', 'r', encoding='utf-8') as f:
  RWC_TAG_TO_REF_SIXTEENTHS = json.load(f)

In [2]:
import numpy as np
import pretty_midi
from scipy.interpolate import interp1d

def _extrapolating_linear_interp1d(a, b):
    return interp1d(a, b, kind="linear", fill_value="extrapolate")


def create_beat_to_time_fn(beats, times):
    """Return a callable that converts beat positions into times.

    ``beats[k]`` is paired with ``times[k]``; positions between the given
    beats are linearly interpolated, and positions beyond them are linearly
    extrapolated.
    """
    beat_to_time = _extrapolating_linear_interp1d(beats, times)
    return beat_to_time


def create_time_to_beat_fn(beats, times):
    """Return a callable that converts times into (fractional) beat positions.

    The inverse of ``create_beat_to_time_fn``: ``times[k]`` maps to
    ``beats[k]``, with linear interpolation between the given points and
    linear extrapolation beyond them.
    """
    time_to_beat = _extrapolating_linear_interp1d(times, beats)
    return time_to_beat


def fix_sync(midi, sync_ms):
    """Return a new MIDI whose notes are those of ``midi`` shifted by ``sync_ms``.

    Only the notes of ``midi.instruments[0]`` are copied, into a fresh
    single-instrument (program 0) ``pretty_midi.PrettyMIDI``. Other instruments
    and all non-note events are not carried over. ``midi`` itself is not
    modified.

    Onsets are clamped at 0 so no note starts before the beginning of the file.
    Offsets are clamped to be no earlier than the (clamped) onset, so a negative
    shift can never yield a note that ends before it starts.

    Args:
        midi: Source ``pretty_midi.PrettyMIDI``.
        sync_ms: Shift in milliseconds; positive values move notes later.

    Returns:
        A new ``pretty_midi.PrettyMIDI`` with the shifted notes.
    """
    offset_s = sync_ms / 1000
    midi_sync = pretty_midi.PrettyMIDI()
    midi_sync.instruments = [pretty_midi.Instrument(0)]
    for n in midi.instruments[0].notes:
        start = max(n.start + offset_s, 0)
        note = pretty_midi.Note(
            start=start,
            # Keep end >= start when a negative shift pushes a note before t=0.
            end=max(n.end + offset_s, start),
            velocity=n.velocity,
            pitch=n.pitch,
        )
        midi_sync.instruments[0].notes.append(note)
    return midi_sync


def quantize_midi_to_sixteenths(midi, beats, times):
    """Quantize the first instrument of ``midi`` to a monophonic sixteenth-note grid.

    ``beats[k]`` is paired with ``times[k]``. Note times are mapped to fractional
    beats by linear interpolation (extrapolating beyond either end), multiplied
    by 4, and rounded with Python's ``round`` (ties go to the even integer).
    The grid has ``num_sixteenths = 4 * (len(beats) - 1)`` slots.

    Notes are processed in onset order:

    * A note is dropped if its quantized onset is outside ``[0, num_sixteenths)``.
    * Its quantized offset is forced to be at least one sixteenth after the
      onset.
    * It is then truncated at the quantized onset of the next note in onset
      order (whether or not that note is kept) and at the end of the grid.
    * A note left with zero duration is dropped. This happens when it shares its
      quantized onset with the following note.

    The result is therefore monophonic.

    Args:
        midi: ``pretty_midi.PrettyMIDI``; only ``midi.instruments[0]`` is read.
        beats: Beat positions, one per entry of ``times``. ``len(beats) - 1``
            must be a multiple of 4.
        times: Times of those beats, in the same units as the note times.

    Returns:
        List of ``(onset, duration, pitch)`` tuples in sixteenth-note units,
        with strictly increasing, non-overlapping onsets.

    Raises:
        AssertionError: If ``beats``/``times`` lengths differ, the grid length is
            not a multiple of four beats, or the monophonic invariant fails.
    """
    assert len(beats) == len(times)
    num_beats = len(beats) - 1
    assert num_beats % 4 == 0
    num_sixteenths = num_beats * 4

    time_to_beat_fn = create_time_to_beat_fn(beats, times)

    def _quantize_sixteenth(t):
        # interp1d returns a 0-d array for a scalar query; convert before rounding.
        return round(float(time_to_beat_fn(t) * 4))

    notes = sorted(midi.instruments[0].notes, key=lambda n: n.start)
    # Quantize each onset exactly once; note i's offset is clipped by onset i+1.
    onsets = [_quantize_sixteenth(n.start) for n in notes]

    notes_quantized = []
    for i, (n, s) in enumerate(zip(notes, onsets)):
        if not 0 <= s < num_sixteenths:
            continue

        e = max(_quantize_sixteenth(n.end), s + 1)
        if i + 1 < len(notes):
            e = min(onsets[i + 1], e)
        e = min(e, num_sixteenths)
        assert e >= s

        d = e - s
        if d == 0:
            continue

        notes_quantized.append((s, d, n.pitch))

    # Sanity check: onsets strictly increase and no note overlaps the next.
    last_s = float('-inf')
    last_e = float('-inf')
    for s, d, p in notes_quantized:
        e = s + d
        assert s > last_s
        assert s >= last_e
        assert d > 0
        last_s = s
        last_e = e

    return notes_quantized


def onset_pitch_metrics(ref_sixteenths, est_midi, beats):
    """Score an unquantized transcription against a sixteenth-quantized reference.

    The estimate is quantized to the sixteenth-note grid defined by ``beats``
    and compared onset-by-onset with the reference. An estimated note is a true
    positive when the reference has a note at the same quantized onset whose
    pitch equals the estimated pitch shifted by a whole number of octaves. The
    shift in [-10, 10] octaves that maximizes F1 is used, making the score
    invariant to a global octave error.

    Args:
        ref_sixteenths: Iterable of ``(onset, duration, pitch)`` triples in
            sixteenth-note units; onsets must be unique.
        est_midi: ``pretty_midi.PrettyMIDI``; only its first instrument is scored.
        beats: Times of successive beats, in the same units as the MIDI note
            times. The number of entries minus one must be a multiple of 4
            (asserted during quantization).

    Returns:
        ``(f1, precision, recall)`` at the best octave shift. A ratio whose
        denominator is zero (no estimated notes on the grid, or an empty
        reference) is reported as 0 rather than raising ``ZeroDivisionError``.
    """
    # The reference arrives already quantized; only the estimate is quantized
    # here, with beat k of the grid located at time beats[k].
    est_sixteenths = quantize_midi_to_sixteenths(est_midi, list(range(len(beats))), beats)

    # Onset -> pitch lookups; the asserts reject duplicate onsets.
    ref_onset_to_pitch = {o: p for o, _, p in ref_sixteenths}
    assert len(ref_onset_to_pitch) == len(ref_sixteenths)
    est_onset_to_pitch = {o: p for o, _, p in est_sixteenths}
    assert len(est_onset_to_pitch) == len(est_sixteenths)

    precision_denominator = len(est_onset_to_pitch)
    recall_denominator = len(ref_onset_to_pitch)

    octave_f1_scores = []
    octave_pr_scores = []
    for octave in range(-10, 11):
        true_positives = sum(
            int((ep + octave * 12) == ref_onset_to_pitch.get(eo))
            for eo, ep in est_onset_to_pitch.items()
        )
        # An empty estimate or reference scores 0 instead of crashing the run.
        p = true_positives / precision_denominator if precision_denominator else 0.0
        r = true_positives / recall_denominator if recall_denominator else 0.0
        if p + r > 0:
            f = 2 * ((p * r) / (p + r))
        else:
            f = 0
        octave_f1_scores.append(f)
        octave_pr_scores.append((p, r))

    # Denominators are shared across octaves, so equal F1 implies equal
    # precision/recall; argmax's first-index tie-break cannot change the result.
    best_octave_idx = np.argmax(octave_f1_scores)
    f1 = octave_f1_scores[best_octave_idx]
    p, r = octave_pr_scores[best_octave_idx]

    return f1, p, r

In [3]:
from collections import defaultdict
import pretty_midi

# Offset (ms) added to the 'ryy' estimates via fix_sync, compensating for
# latency introduced by Onsets and Frames.
RYY_OAF_SYNC_ERROR_MS = 31

# Score every method's unquantized estimate on every reference song.
method_to_metrics = defaultdict(list)
for rwc_tag, ref_sixteenths in RWC_TAG_TO_REF_SIXTEENTHS.items():
    _, _, beats = RWC_TAG_TO_SYNC[rwc_tag]
    for method in ['ryy', 'mel', 'ssh', 'ssj']:
        midi_path = f'ryy08/est_unquant_midi/{method}_{rwc_tag}.mid'
        est = pretty_midi.PrettyMIDI(midi_path)
        if method == 'ryy':
            est = fix_sync(est, RYY_OAF_SYNC_ERROR_MS)
        metrics = onset_pitch_metrics(ref_sixteenths, est, beats)
        method_to_metrics[method].append(metrics)

# Print the per-method mean F1, precision and recall across songs.
for method, metrics in method_to_metrics.items():
    f1_values, p_values, r_values = zip(*metrics)
    f1 = np.mean(f1_values)
    p = np.mean(p_values)
    r = np.mean(r_values)
    print(method, f1, p, r)

ryy 0.4760328819774705 0.47900457777551503 0.49472290693071175
mel 0.27673425734254886 0.28386365399857477 0.2894663409916158
ssh 0.5866153659559636 0.6347876689011206 0.5697072579915735
ssj 0.7428765428606062 0.74844514579763 0.7508046124376337
