In [1]:
from collections import defaultdict

import pretty_midi
import mir_eval

import numpy as np

def midi_to_notes(midi, rate=10):
    """Group the notes of *midi* by ``(program, is_drum)``.

    Args:
        midi: Object exposing ``instruments``; each instrument exposes
            ``program``, ``is_drum`` and ``notes``, and each note exposes
            ``start``, ``end``, ``pitch`` and ``velocity``.
        rate: Unused; kept so existing callers keep working.

    Returns:
        A ``defaultdict(list)`` mapping ``(program, is_drum)`` to a list of
        ``(start, end, pitch, velocity)`` tuples. Instruments that share a
        key are merged in file order; instruments without notes add no key.
    """
    grouped = defaultdict(list)
    keyed_notes = (
        (
            (instrument.program, instrument.is_drum),
            (note.start, note.end, note.pitch, note.velocity),
        )
        for instrument in midi.instruments
        for note in instrument.notes
    )
    for key, note_tuple in keyed_notes:
        grouped[key].append(note_tuple)
    return grouped


def notes_to_mir_eval(notes):
    """Convert note tuples into the arrays ``mir_eval.transcription`` expects.

    Args:
        notes: Sequence of ``(start, end, pitch, velocity)`` tuples, where
            ``pitch`` is a MIDI note number. Velocity is ignored.

    Returns:
        A tuple ``(intervals, freqs)``: ``intervals`` is a float32 array of
        shape ``(n, 2)`` holding ``(start, end)`` rows, and ``freqs`` is an
        array of length ``n`` with each pitch converted to Hz
        (MIDI note 69 maps to 440 Hz).
    """
    # reshape keeps the (n, 2) layout even when `notes` is empty; without it
    # an empty list would produce a 1-D array of shape (0,).
    intervals = np.array(
        [(start, end) for start, end, _, _ in notes], dtype=np.float32
    ).reshape(-1, 2)
    pitches = np.array([pitch for _, _, pitch, _ in notes], dtype=np.float32)
    # Equal-temperament conversion from MIDI note number to frequency.
    freqs = np.power(2.0, (pitches - 69) / 12) * 440.0
    return intervals, freqs


def _vis_output_path(path):
    """Return *path* with ``.vis`` inserted before its file extension."""
    import os.path

    # splitext only touches the final extension, so a '.mid' elsewhere in the
    # path is left alone and the result always differs from the input path.
    root, ext = os.path.splitext(path)
    return f'{root}.vis{ext}'


def _append_notes(instrument, notes):
    """Append ``(start, end, pitch, velocity)`` tuples to *instrument* as notes."""
    instrument.notes.extend(
        pretty_midi.Note(start=start, end=end, pitch=pitch, velocity=velocity)
        for start, end, pitch, velocity in notes
    )


def write_vis_midi(
    prmt_path,
    cont_path=None,
    dedupe_prompt_in_continuation=True,
    show_drums=True,
    separate_drums=True):
    """Write a MIDI file that separates prompt notes from continuation notes.

    The output holds a 'prompt' instrument (program 0) and a 'continuation'
    instrument (program 1); with ``separate_drums`` it also holds drum
    counterparts of both. Every prompt note goes to the prompt side. Each
    continuation note that ``mir_eval.transcription.match_notes`` pairs with
    a prompt note of the same ``(program, is_drum)`` key is dropped, and the
    rest go to the continuation side.

    Args:
        prmt_path: Path of the prompt MIDI file.
        cont_path: Path of the continuation MIDI file, or ``None`` to
            visualize the prompt alone.
        dedupe_prompt_in_continuation: When true, run a second matching pass
            over the continuation notes left after the first pass.
        show_drums: When false, skip notes from drum instruments.
        separate_drums: When true, route drum notes to dedicated drum
            instruments instead of the melodic ones.

    Raises:
        AssertionError: If a ``(program, is_drum)`` key present in the prompt
            is missing from the continuation.

    The result is written next to the continuation file (or next to the
    prompt when ``cont_path`` is ``None``) with ``.vis`` inserted before the
    extension. One line is printed when the first pass leaves prompt notes
    unmatched, and one when the second pass finds further matches.
    """
    # Load and parse
    prmt = midi_to_notes(pretty_midi.PrettyMIDI(prmt_path))
    if cont_path is None:
        cont = {}
    else:
        cont = midi_to_notes(pretty_midi.PrettyMIDI(cont_path))
        for ins in prmt:
            assert ins in cont, f'{ins} is in {prmt_path} but not in {cont_path}'

    # Create visualization
    vis_midi = pretty_midi.PrettyMIDI()
    prmt_instrument = pretty_midi.Instrument(0, is_drum=False, name='prompt')
    cont_instrument = pretty_midi.Instrument(1, is_drum=False, name='continuation')
    vis_midi.instruments = [prmt_instrument, cont_instrument]
    if separate_drums:
        prmt_drums_instrument = pretty_midi.Instrument(0, is_drum=True, name='prompt_drums')
        cont_drums_instrument = pretty_midi.Instrument(1, is_drum=True, name='continuation_drums')
        vis_midi.instruments.extend([prmt_drums_instrument, cont_drums_instrument])

    # Add prompt notes
    for ins, ins_notes_prmt in prmt.items():
        if ins[1] and not show_drums:
            continue
        ins_out = prmt_drums_instrument if separate_drums and ins[1] else prmt_instrument
        _append_notes(ins_out, ins_notes_prmt)

    # Add continuation notes
    for ins, ins_notes_cont in cont.items():
        if ins[1] and not show_drums:
            continue
        ins_out = cont_drums_instrument if separate_drums and ins[1] else cont_instrument

        # Align continuation notes to prompt notes w/ tolerance
        ins_notes_prmt = prmt.get(ins, [])
        num_passes = 1 + int(dedupe_prompt_in_continuation)
        for pass_idx in range(num_passes):
            if not (ins_notes_prmt and ins_notes_cont):
                continue
            # offset_ratio=None makes match_notes ignore note offsets.
            alignment = mir_eval.transcription.match_notes(
                *notes_to_mir_eval(ins_notes_prmt),
                *notes_to_mir_eval(ins_notes_cont),
                onset_tolerance=0.05,
                pitch_tolerance=0.01,
                offset_ratio=None)

            # Warn about notes in prompt not being found in continuation
            if pass_idx == 0 and len(alignment) != len(ins_notes_prmt):
                print('alignn', cont_path, ins, len(ins_notes_prmt) - len(alignment))
            if pass_idx == 1 and len(alignment) > 0:
                print('dedupe', cont_path, ins, len(alignment))

            # Remove prompt notes; the second element of each alignment pair
            # indexes into the continuation notes passed to match_notes.
            matched = {pair[1] for pair in alignment}
            ins_notes_cont = [
                note for idx, note in enumerate(ins_notes_cont)
                if idx not in matched
            ]

        # Output unmatched notes
        _append_notes(ins_out, ins_notes_cont)

    # Write
    source_path = prmt_path if cont_path is None else cont_path
    vis_midi.write(_vis_output_path(source_path))
    

"""
write_vis_midi(
    'harmony/0-conditional.mid',
    None)

write_vis_midi(
    'harmony/0-conditional.mid',
    'harmony/system0/0-clip-v0.mid')
"""

import pathlib
import glob
from tqdm.notebook import tqdm

# For each study, visualize every prompt on its own, then once more against
# each system continuation whose file name starts with the same prefix.
for study in ('harmony', 'prompt'):
    prompt_paths = sorted(glob.glob(f'{study}/*.mid'))
    for prmt_path in tqdm(prompt_paths):
        # Visualizations written by this script match the glob too; skip them.
        if '.vis.mid' in prmt_path:
            continue
        write_vis_midi(prmt_path)

        # The text before the first '-' in the stem ties a prompt to its
        # continuations.
        i = pathlib.Path(prmt_path).stem.split('-')[0]
        cont_paths = [
            p for p in sorted(glob.glob(f'{study}/system*/{i}-*.mid'))
            if '.vis.mid' not in p
        ]
        for cont_path in cont_paths:
            # The dedupe pass is enabled only for paths containing 'system0'.
            write_vis_midi(
                prmt_path,
                cont_path,
                dedupe_prompt_in_continuation='system0' in cont_path)

  0%|          | 0/100 [00:00<?, ?it/s]

dedupe harmony/system0/0-clip-v0.mid (18, False) 144
dedupe harmony/system0/1-clip-v0.mid (52, False) 23
alignn harmony/system0/10-clip-v0.mid (49, False) 28
dedupe harmony/system0/10-clip-v0.mid (49, False) 55
alignn harmony/system2/10-clip-v0.mid (49, False) 5
alignn harmony/system0/11-clip-v0.mid (67, False) 7
dedupe harmony/system0/11-clip-v0.mid (67, False) 19
alignn harmony/system0/12-clip-v0.mid (49, False) 2
dedupe harmony/system0/12-clip-v0.mid (49, False) 27
dedupe harmony/system0/13-clip-v0.mid (21, False) 116
dedupe harmony/system0/14-clip-v0.mid (18, False) 42
alignn harmony/system0/15-clip-v0.mid (48, False) 1
dedupe harmony/system0/15-clip-v0.mid (48, False) 18
alignn harmony/system0/16-clip-v0.mid (73, False) 3
dedupe harmony/system0/16-clip-v0.mid (73, False) 50
alignn harmony/system0/17-clip-v0.mid (12, False) 2
dedupe harmony/system0/17-clip-v0.mid (12, False) 106
dedupe harmony/system0/18-clip-v0.mid (105, False) 160
alignn harmony/system0/19-clip-v0.mid (40, False)

  0%|          | 0/100 [00:00<?, ?it/s]

alignn prompt/system0/24-clip.mid (0, True) 6
alignn prompt/system0/24-clip.mid (65, False) 14
alignn prompt/system0/24-clip.mid (25, False) 48
dedupe prompt/system0/24-clip.mid (25, False) 2
