In [30]:
from collections import defaultdict

import pretty_midi
import mir_eval

import numpy as np

def midi_to_notes(midi, rate=10):
    """Group the notes of a MIDI object by instrument identity.

    Args:
        midi: Object exposing ``instruments``; each instrument provides
            ``program``, ``is_drum`` and ``notes``, and each note provides
            ``start``, ``end``, ``pitch`` and ``velocity``.
        rate: Accepted for interface compatibility; not used.

    Returns:
        A ``defaultdict(list)`` mapping ``(program, is_drum)`` to a list of
        ``(start, end, pitch, velocity)`` tuples. Instruments that share the
        same key are merged in file order. An instrument without notes does
        not create a key.
    """
    grouped = defaultdict(list)
    for instrument in midi.instruments:
        rows = [
            (note.start, note.end, note.pitch, note.velocity)
            for note in instrument.notes
        ]
        # Only touch the mapping when there is something to add, so that
        # note-less instruments never show up as keys.
        if rows:
            grouped[(instrument.program, instrument.is_drum)].extend(rows)
    return grouped


def notes_to_mir_eval(notes):
    """Convert note tuples into the arrays passed to ``mir_eval`` note matching.

    Args:
        notes: Sequence of ``(start, end, pitch, velocity)`` tuples, where
            ``pitch`` is a MIDI note number.

    Returns:
        A pair ``(intervals, freqs)``:
        ``intervals`` is a float32 array of shape ``(n, 2)`` holding
        ``(start, end)`` per note, and ``freqs`` is a float32 array of shape
        ``(n,)`` holding each pitch converted to Hz (MIDI note 69 = 440 Hz).
    """
    # reshape(-1, 2) keeps the (n, 2) layout even when ``notes`` is empty;
    # without it an empty list would produce a 1-D array of shape (0,).
    ref_intervals = np.array(
        [(s, e) for s, e, _, _ in notes], dtype=np.float32).reshape(-1, 2)
    ref_pitches = np.array([p for _, _, p, _ in notes], dtype=np.float32)
    # Equal-temperament conversion from MIDI note number to frequency.
    ref_freqs = np.power(2.0, (ref_pitches - 69) / 12) * 440.0
    return ref_intervals, ref_freqs


def _append_notes(instrument, notes):
    """Append ``(start, end, pitch, velocity)`` tuples to ``instrument`` as notes."""
    instrument.notes.extend(
        pretty_midi.Note(start=n[0], end=n[1], pitch=n[2], velocity=n[3])
        for n in notes)


def _vis_out_path(path):
    """Return ``path`` with ``.vis`` inserted before its file extension."""
    import os

    # Only the real extension is rewritten. A plain str.replace('.mid', ...)
    # would also rewrite '.mid' inside directory names and would leave the
    # path unchanged (overwriting the input) for extensions such as '.MID'.
    root, ext = os.path.splitext(path)
    return f'{root}.vis{ext}'


def write_vis_midi(
    prmt_path,
    cont_path=None,
    show_drums=True,
    separate_drums=True):
    """Write a MIDI file that separates prompt notes from continuation notes.

    All prompt notes are copied onto a 'prompt' track (program 0). If a
    continuation file is given, each of its notes that does not match a
    prompt note of the same ``(program, is_drum)`` instrument is copied onto
    a 'continuation' track (program 1). Notes are matched with
    ``mir_eval.transcription.match_notes`` using an onset tolerance of 0.05,
    a pitch tolerance of 1.0, and offsets ignored.

    The result is written next to the continuation file (or next to the
    prompt file when there is no continuation) with ``.vis`` inserted before
    the extension, e.g. ``0-clip-v0.mid`` -> ``0-clip-v0.vis.mid``.

    Args:
        prmt_path: Path of the prompt MIDI file.
        cont_path: Path of the continuation MIDI file, or ``None`` to
            visualize the prompt alone.
        show_drums: If false, drum instruments are left out entirely.
        separate_drums: If true, drum notes go to dedicated drum tracks
            ('prompt_drums' / 'continuation_drums') instead of being mixed
            into the pitched tracks.

    Raises:
        AssertionError: If an instrument of the prompt is absent from the
            continuation, or if two alignment entries point at the same
            continuation note.
    """
    # Load and parse
    prmt = midi_to_notes(pretty_midi.PrettyMIDI(prmt_path))
    if cont_path is None:
        cont = {}
    else:
        cont = midi_to_notes(pretty_midi.PrettyMIDI(cont_path))
        for i in prmt:
            assert i in cont, (
                f'{cont_path} is missing prompt instrument {i}')

    # Create visualization
    vis_midi = pretty_midi.PrettyMIDI()
    prmt_instrument = pretty_midi.Instrument(0, is_drum=False, name='prompt')
    cont_instrument = pretty_midi.Instrument(1, is_drum=False, name='continuation')
    vis_midi.instruments = [prmt_instrument, cont_instrument]
    # The drum tracks only exist when requested; they are never dereferenced
    # otherwise because every use below is guarded by ``separate_drums``.
    prmt_drums_instrument = cont_drums_instrument = None
    if separate_drums:
        prmt_drums_instrument = pretty_midi.Instrument(0, is_drum=True, name='prompt_drums')
        cont_drums_instrument = pretty_midi.Instrument(1, is_drum=True, name='continuation_drums')
        vis_midi.instruments.extend([prmt_drums_instrument, cont_drums_instrument])

    # Add prompt notes
    for i, i_notes_prmt in prmt.items():
        if i[1] and not show_drums:
            continue
        i_out = prmt_drums_instrument if separate_drums and i[1] else prmt_instrument
        _append_notes(i_out, i_notes_prmt)

    # Add continuation notes
    for i, i_notes_cont in cont.items():
        if i[1] and not show_drums:
            continue
        i_out = cont_drums_instrument if separate_drums and i[1] else cont_instrument

        # Align continuation notes to prompt notes w/ tolerance
        i_notes_prmt = prmt.get(i, [])
        if not i_notes_prmt:
            alignment = []
        else:
            alignment = mir_eval.transcription.match_notes(
                *notes_to_mir_eval(i_notes_prmt),
                *notes_to_mir_eval(i_notes_cont),
                onset_tolerance=0.05,
                pitch_tolerance=1.0,
                offset_ratio=None)

        # Warn about notes in prompt not being found in continuation
        if len(alignment) != len(i_notes_prmt):
            print(cont_path, i, len(i_notes_prmt) - len(alignment))

        # Output unmatched notes: the second element of each alignment pair
        # indexes the continuation note that matched a prompt note.
        aligned_notes = {a[1] for a in alignment}
        assert len(aligned_notes) == len(alignment)
        _append_notes(
            i_out,
            [n for j, n in enumerate(i_notes_cont) if j not in aligned_notes])

    # Write
    vis_midi.write(_vis_out_path(prmt_path if cont_path is None else cont_path))
    

"""
write_vis_midi(
    'harmony/0-conditional.mid',
    None)

write_vis_midi(
    'harmony/0-conditional.mid',
    'harmony/system0/0-clip-v0.mid')
"""

import pathlib
import glob
from tqdm.notebook import tqdm

# Render a visualization for every prompt MIDI file of each study, then one
# for each continuation that shares the prompt's leading id.
for study in ('harmony', 'prompt'):
    prmt_paths = sorted(glob.glob(f'{study}/*.mid'))
    for prmt_path in tqdm(prmt_paths):
        # Files containing '.vis.mid' are outputs of this script, not inputs.
        if '.vis.mid' not in prmt_path:
            write_vis_midi(prmt_path)
            # The id is the part of the file stem before the first '-',
            # e.g. '0' for '0-conditional.mid'.
            i = pathlib.Path(prmt_path).stem.split('-')[0]
            cont_paths = sorted(glob.glob(f'{study}/system*/{i}-*.mid'))
            for cont_path in cont_paths:
                if '.vis.mid' not in cont_path:
                    write_vis_midi(prmt_path, cont_path)

  0%|          | 0/100 [00:00<?, ?it/s]

harmony/system0/10-clip-v0.mid (49, False) 28
harmony/system2/10-clip-v0.mid (49, False) 5
harmony/system0/11-clip-v0.mid (67, False) 7
harmony/system0/12-clip-v0.mid (49, False) 2
harmony/system0/15-clip-v0.mid (48, False) 1
harmony/system0/16-clip-v0.mid (73, False) 3
harmony/system0/17-clip-v0.mid (12, False) 2
harmony/system0/19-clip-v0.mid (40, False) 1
harmony/system0/20-clip-v0.mid (20, False) 1
harmony/system0/24-clip-v0.mid (48, False) 2
harmony/system0/28-clip-v0.mid (24, False) 1
harmony/system0/29-clip-v0.mid (48, False) 1
harmony/system0/3-clip-v0.mid (73, False) 1
harmony/system0/31-clip-v0.mid (56, False) 2
harmony/system0/32-clip-v0.mid (18, False) 1
harmony/system0/4-clip-v0.mid (68, False) 2
harmony/system0/40-clip-v0.mid (3, False) 4
harmony/system0/44-clip-v0.mid (53, False) 5
harmony/system0/46-clip-v0.mid (56, False) 3
harmony/system0/47-clip-v0.mid (88, False) 5
harmony/system0/48-clip-v0.mid (16, False) 26
harmony/system0/5-clip-v0.mid (48, False) 1
harmony/syst

  0%|          | 0/100 [00:00<?, ?it/s]

prompt/system0/24-clip.mid (0, True) 6
prompt/system0/24-clip.mid (65, False) 14
prompt/system0/24-clip.mid (25, False) 48


In [None]:
ra