In [294]:
import pandas as pd
import numpy as np
import mido
from collections import defaultdict
import math

# Note Class

In [331]:
class Note:
    def __init__(self, pitch, instrument, start, duration, dict_key):
        """
        parameters:
            - pitch: MIDI pitch (int)
            - instrument
            - start: start time (int)
            - end: end time (int)
        
        fields:
            - interval: wrt = (pitch-key) % 12
            - tags: set of strings, flexible labeling
        """
        self.pitch = pitch # MIDI pitch
        self.start = start
        self.duration = duration
        self.instrument = instrument
        self.dict_key = dict_key

        # self.interval = pitch 
        self.key = None
        self.melody = True
        self.chord = False

        self.track_number = None
        
        # motif, repetition, rigid/rubato
        self.tags = set()

    def add_tag(self, tag: str):
        self.tags.add(tag)

    def set_key_and_interval(self, key):
        self.key = key
        self.interval = abs(self.pitch-key) % 12
    
    def set_melody(self, is_melody):
        self.melody = is_melody
    
    def set_chord(self, is_chord):
        self.chord = is_chord

    def set_track_number(self, track_number: int):
        self.track_number = track_number
    
    def get_track_number(self):
        return self.track_number
    
    def get_start_time(self):
        return self.start
    
    def get_duration(self):
        return self.duration

    def get_dict_key(self):
        return self.dict_key
    
    def __str__(self):
        return f"Note(pitch={self.pitch}, start={self.start}, duration={self.duration}, instrument={self.instrument}, dict_key={self.dict_key}, track_number={self.track_number})"


# Pre-Processing Functions + Variables

In [3]:
# configuration variables from original repo 

MAX_TIME_IN_SECONDS = 100          # exclude very long training sequences
MAX_DURATION_IN_SECONDS = 10       # maximum duration of a note
TIME_RESOLUTION = 100 # ticks per second

MAX_PITCH = 128                    # 128 MIDI pitches
MAX_INSTR = 129                    # 129 MIDI instruments (128 + drums)
MAX_NOTE = MAX_PITCH*MAX_INSTR     # note = pitch x instrument

MAX_TIME = TIME_RESOLUTION*MAX_TIME_IN_SECONDS
MAX_DUR = TIME_RESOLUTION*MAX_DURATION_IN_SECONDS

EVENT_OFFSET = 0
TIME_OFFSET = EVENT_OFFSET
DUR_OFFSET = TIME_OFFSET + MAX_TIME
NOTE_OFFSET = DUR_OFFSET + MAX_DUR
REST = NOTE_OFFSET + MAX_NOTE

CONTROL_OFFSET = NOTE_OFFSET + MAX_NOTE + 1
ATIME_OFFSET = CONTROL_OFFSET + 0
ADUR_OFFSET = ATIME_OFFSET + MAX_TIME
ANOTE_OFFSET = ADUR_OFFSET + MAX_DUR

# the special block
SPECIAL_OFFSET = ANOTE_OFFSET + MAX_NOTE
SEPARATOR = SPECIAL_OFFSET

In [443]:
### modified repo code to parse MIDI file for tokens

def midi_to_compound(midifile, debug=False):
    """
    parameters:
        - midifile: MIDI file (all tracks)

    returns:
        - list of tokens in the pattern: (start time, duration, MIDI note, instrument, velocity, channel)
        - list of keys of notes: (instrument, note, channel, onset_time_in_ticks, duration_ticks)
            - correspond 1:6 with the notes in the tokens
    """
    if type(midifile) == str:
        midi = mido.MidiFile(midifile)
    else:
        midi = midifile

    tokens = []
    note_idx = 0
    open_notes = defaultdict(list)
    closed_notes = []

    time = 0
    instruments = defaultdict(int) # default to code 0 = piano
    tempo = 500000 # default tempo: 500000 microseconds per beat

    for message in midi:
        time += message.time
        # print("MESSAGE", message, round(TIME_RESOLUTION*time))

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        if message.type == 'program_change':
            instruments[message.channel] = message.program
        elif message.type in ['note_on', 'note_off']:
            # special case: channel 9 is drums!
            instr = 128 if message.channel == 9 else instruments[message.channel]

            if message.type == 'note_on' and message.velocity > 0: # onset
                # time quantization
                time_in_ticks = round(TIME_RESOLUTION*time)
                # print(time, time_in_ticks)

                # Our compound word is: (time, duration, note, instr, velocity)
                tokens.append(time_in_ticks) # 5ms resolution
                tokens.append(-1) # placeholder (we'll fill this in later)
                tokens.append(message.note)
                tokens.append(instr)
                tokens.append(message.velocity)

                tokens.append(message.channel) # CHANNEL INFO FOR KEY

                open_notes[(instr,message.note,message.channel)].append((note_idx, time))
                note_idx += 1
            else: # offset
                try:
                    open_idx, onset_time = open_notes[(instr,message.note,message.channel)].pop(0)
                except IndexError:
                    if debug:
                        print('WARNING: ignoring bad offset')
                else:
                    duration_ticks = round(TIME_RESOLUTION*(time-onset_time))
                    # tokens[5*open_idx + 1] = duration_ticks
                    tokens[6*open_idx + 1] = duration_ticks # ACCOUNT FOR CHANNEL
                    # print("onset time:", onset_time, "duration_ticks:", duration_ticks)
                    onset_time_in_ticks = round(TIME_RESOLUTION*onset_time)

                    # closed_notes[(instr,message.note, message.channel, onset_time_in_ticks, duration_ticks)].append("")
                    closed_notes.append((instr,message.note, message.channel, onset_time_in_ticks, duration_ticks))

                    #del open_notes[(instr,message.note,message.channel)]
        elif message.type == 'set_tempo':
            tempo = message.tempo
        elif message.type == 'time_signature':
            # print('TIME SIGNATURE', message)
            pass # we use real time
        elif message.type in ['aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific']:
            pass # we don't attempt to model these
        elif message.type == 'control_change':
            pass # this includes pedal and per-track volume: ignore for now
        elif message.type in ['track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
                              'copyright', 'marker', 'instrument_name', 'cue_marker',
                              'device_name', 'sequence_number']:
            pass # possibly useful metadata but ignore for now
        elif message.type == 'channel_prefix':
            pass # relatively common, but can we ignore this?
        elif message.type in ['midi_port', 'smpte_offset', 'sysex']:
            pass # I have no idea what this is
        else:
            if debug:
                print('UNHANDLED MESSAGE', message.type, message)
        # print(tokens)

    unclosed_count = 0
    for _,v in open_notes.items():
        unclosed_count += len(v)

    if debug and unclosed_count > 0:
        print(f'WARNING: {unclosed_count} unclosed notes')
        print('  ', midifile)

    return tokens, closed_notes


In [205]:
def midi_tracks_to_dict(track_midifile, track_num, debug=False):
    """
    parameters:
        - track_midifile: MIDI file (single track)
        - track_num: the track number (int)

    returns a dictionary of notes with:
        - key = (instrument, note, channel, onset_time_in_ticks, duration_ticks)
        - value = track number
    """
    if type(track_midifile) == str:
        midi = mido.MidiFile(track_midifile)
    else:
        midi = track_midifile

    tokens = []
    note_idx = 0
    open_notes = defaultdict(list)
    closed_notes = defaultdict(int)

    time = 0
    instruments = defaultdict(int) # default to code 0 = piano
    tempo = 500000 # default tempo: 500000 microseconds per beat
    for message in midi:
        time += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        if message.type == 'program_change':
            instruments[message.channel] = message.program
        elif message.type in ['note_on', 'note_off']:
            # special case: channel 9 is drums!
            instr = 128 if message.channel == 9 else instruments[message.channel]

            if message.type == 'note_on' and message.velocity > 0: # onset
                # time quantization
                time_in_ticks = round(TIME_RESOLUTION*time)

                # Our compound word is: (time, duration, note, instr, velocity)
                tokens.append(time_in_ticks) # 5ms resolution
                tokens.append(-1) # placeholder (we'll fill this in later)
                tokens.append(message.note)
                tokens.append(instr)
                tokens.append(message.velocity)

                open_notes[(instr,message.note,message.channel)].append((note_idx, time))
                note_idx += 1
            else: # offset
                try:
                    open_idx, onset_time = open_notes[(instr,message.note,message.channel)].pop(0)
                except IndexError:
                    if debug:
                        print('WARNING: ignoring bad offset')
                else:
                    duration_ticks = round(TIME_RESOLUTION*(time-onset_time))
                    tokens[5*open_idx + 1] = duration_ticks
                    # if duration_ticks == 36: duration_ticks = 18

                    # print("onset_time", onset_time, "duration_ticks", duration_ticks)
                    onset_time_in_ticks = round(TIME_RESOLUTION*onset_time)

                    closed_notes[(instr,message.note,message.channel, onset_time_in_ticks, duration_ticks)] = track_num

                    #del open_notes[(instr,message.note,message.channel)]
        elif message.type == 'set_tempo':
            tempo = message.tempo
        elif message.type == 'time_signature':
            # print('TIME SIGNATURE', message)
            pass # we use real time
        elif message.type in ['aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific']:
            pass # we don't attempt to model these
        elif message.type == 'control_change':
            pass # this includes pedal and per-track volume: ignore for now
        elif message.type in ['track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
                            'copyright', 'marker', 'instrument_name', 'cue_marker',
                            'device_name', 'sequence_number']:
            pass # possibly useful metadata but ignore for now
        elif message.type == 'channel_prefix':
            pass # relatively common, but can we ignore this?
        elif message.type in ['midi_port', 'smpte_offset', 'sysex']:
            pass # I have no idea what this is
        else:
            if debug:
                print('UNHANDLED MESSAGE', message.type, message)

    unclosed_count = 0
    for _,v in open_notes.items():
        unclosed_count += len(v)

    if debug and unclosed_count > 0:
        print(f'WARNING: {unclosed_count} unclosed notes')
        print('  ', track_midifile)

    return closed_notes


In [6]:
def compound_to_events_and_notes(tokens, stats=False):
    """
    parameters:
        - tokens: list of tokens in the pattern (start time, duration, MIDI note, instrument, velocity, channel)
            - should be multiple of 6

    returns:
        - events: tokens for input to model (start_time, duration_token, note_token)
            - removes velocity and channel
            - duration_token capped at MAX_DUR, if unknown set to 250ms
            - note_token = NOTE_OFFSET + (MAX_PITCH*instr + note)
        - notes: list of Note objects
            - sets pitch, instrument, start time, duration, dict_key
    """

    assert len(tokens) % 6 == 0
    tokens = tokens.copy()

    # remove velocities
    del tokens[4::6]

    # combine (note, instrument)
    assert all(-1 <= tok < 2**7 for tok in tokens[2::5]) # check that values valid
    assert all(-1 <= tok < 129 for tok in tokens[3::5]) # check that values valid

    notes = []

    for start, dur, note, instr, channel in zip(tokens[0::5], tokens[1::5], tokens[2::5], tokens[3::5], tokens[4::5]):
        if note == -1:
            notes.append(SEPARATOR)
        else:
            cur_duration = TIME_RESOLUTION//4 if dur == -1 else min(dur, MAX_DUR-1)
            # don't readjust duration by dur_offset (not necessary?)
            # (instr,message.note, message.channel, onset_time, duration_ticks)
            dict_key = (instr, note, channel, start, dur)
            cur_note = Note(pitch=note, instrument=instr, start=start, duration=cur_duration, dict_key=dict_key)
            notes.append(cur_note)
    
    del tokens[4::5]

    # set to separator if note is -1
    # otherwise set the note to MAX_PITCH*instr + note
    tokens[2::4] = [SEPARATOR if note == -1 else MAX_PITCH*instr + note
                    for note, instr in zip(tokens[2::4],tokens[3::4])]
    tokens[2::4] = [NOTE_OFFSET + tok for tok in tokens[2::4]]
    del tokens[3::4]

    # max duration cutoff and set unknown durations to 250ms
    truncations = sum([1 for tok in tokens[1::3] if tok >= MAX_DUR])
    tokens[1::3] = [TIME_RESOLUTION//4 if tok == -1 else min(tok, MAX_DUR-1)
                    for tok in tokens[1::3]]
    tokens[1::3] = [DUR_OFFSET + tok for tok in tokens[1::3]]

    assert min(tokens[0::3]) >= 0
    tokens[0::3] = [TIME_OFFSET + tok for tok in tokens[0::3]]

    assert len(tokens) % 3 == 0

    if stats:
        return tokens, truncations

    return tokens, notes

In [66]:
def set_note_track_number(notes, note_track_dict):
    """
    parameters:
        - notes: list of Note objects
        - note_track_dict: dictionary with
            - key = (instrument, note, channel, onset_time_in_ticks, duration_ticks)
            - value = track number

    sets the track number for each Note object in notes
    """
    for note in notes:
        key = note.get_dict_key()
        if key in note_track_dict:
            note.set_track_number(note_track_dict[key])
        else:
            print("ERROR: key not found in note_track_dict:", key)

In [70]:
def get_track_notes(notes):
    """
    parameters:
        - notes: list of Note objects (MUST have track_number field set)
    
    since notes are in order, it should add them to each dictionary entry list in time order

    returns a dictionary with:
        - key = track number
        - value = list of Note objects in that track
    """
    track_notes = defaultdict(list)

    for note in notes:
        track_num = note.get_track_number()
        if track_num is None:
            print("ERROR: note does not have track number set:", note)
        else:
            track_notes[note.get_track_number()].append(note)

    return track_notes

# Pre-Processing Code

`midi_to_compound(midifile)`
- input a MIDI file with **all** tracks
- outputs tokens (to turn into events) and a list of keys for all notes

`midi_tracks_to_dict(midifile, track_num)`
- input a MIDI file with a **single** track
- outputs a dictionary of keys corresponding to note in track, value = track number

`compound_to_events_and_notes(tokens)`
- input the tokens from `midi_to_compound`
- outputs events (the actual tokens for model) and a list of Note objects corresponding to each note (in same order)

`set_note_track_number(notes, note_track_dict)`
- input the list of Note objects from `compound_to_events_and_notes` and dictionary matching note to track number from `midi_tracks_to_dict`
- modifies each Note object to store the correct track number
- does not output anything

`get_track_notes(notes)`
- input a list of Note objects AFTER calling `set_note_track_number`
- outputs a dictionary with key as track number and value as list of Note objects in that track

### To process a MIDI file:
1. pass MIDI with **all tracks** into `isolate_midi_tracks(midifile)`. this splits up the file into separate tracks (for tagging later)
2. pass MIDI with **all tracks** into `midi_to_compound(midifile)` to get tokenized rep `(start time, duration, MIDI note, instrument, velocity, channel)`, as well as a list of dictionary key reps for all the notes
3. take tokens from `midi_to_compound(midifile)` output and put into `compound_to_events_and_notes(tokens)` to get events (model input) and list of Note objects corresponding to the parsed notes
4. pass MIDI for **each track** into `midi_tracks_to_dict(midifile, track_num)` with respective track number to get a dictionary with key=note key and val=track number. combine all dictionary outputs so we have a big dictionary mapping all note keys to respective track number
5. pass in list of Note objects from `midi_to_compound(midifile)` and combined dictionary from step 4 into `set_note_track_number(notes, note_track_dict)` to set the correct track number field for each Note object.
6. pass in list of Notes **after setting track numbers** into `get_track_notes(notes)` to get a dictionary mapping each track number to a list of Note objects in that track (for more efficient acccess)



In [437]:
def process_midi_file(midifile, track_midifiles):
    """
    processes a given MIDI file by creating Note objects corresponding to the MIDI notes, as well as labelling Notes by track
    parameters: 
        - midifile: MIDI file path with all tracks
        - track_midifiles: list of MIDI file paths for each track

    returns:
        - all_note_keys: list of all the keys of notes (instrument, note, channel, onset_time_in_ticks, duration_ticks) in MIDI order
        - notes: list of all Note objects in the model token order
        - notes_per_track: dictionary with key = track number, value = list of Note objects in that track (in notes)
    """
    midi = mido.MidiFile(midifile)
    
    # tokens = to be input into compound_to_events_and_notes
    # all_note_keys = list of keys of notes (instrument, note, channel, onset_time_in_ticks, duration_ticks)
    tokens, all_note_keys = midi_to_compound(midi)

    # events = tokens for model input
    # notes = list of Note objects
    events, notes = compound_to_events_and_notes(tokens)

    # store track number for every note key
    note_track_dict = {}
    for i, track_midifile in enumerate(track_midifiles):
        track_note_dict = midi_tracks_to_dict(track_midifile, i)
        # combine into larger dictionary mapping notes to track number
        note_track_dict.update(track_note_dict)
        
    # DEBUGGING CODE
    # keys_from_tracks = list(note_track_dict.keys())
    # sorted_all_note_keys = sorted(all_note_keys)
    # sorted_keys_from_tracks = sorted(keys_from_tracks)
    # print("all note keys sorted", sorted_all_note_keys)
    # print("all track note keys sorted", sorted_keys_from_tracks)
    # print(len(all_note_keys),len(keys_from_tracks))
    
    # set the track number field for each Note object
    set_note_track_number(notes, note_track_dict)

    # get dictionary mapping track number to the list of Note objects in that track
    notes_per_track = get_track_notes(notes)

    return all_note_keys, notes, notes_per_track

In [None]:
def isolate_midi_tracks(midifile):
    """
    given a MIDI file with multiple tracks, saves each track into a separate MIDI file (preserving tempo, etc.)
    parameters: 
        - midifile: MIDI file path with all tracks

    returns:
        - all_track_midis: list of all track MIDI file paths
    """
    midi = mido.MidiFile(midifile)

    all_track_midis = []
    prefix = midifile[:midifile.index(".mid")]

    metadata = []

    for msg in midi.tracks[0]:
        # if msg.is_meta:
        #     metadata.append(msg.copy())
        if msg.type in ['time_signature', 'key_signature', 'set_tempo']:
            metadata.append(msg.copy())
    
    new_mid = mido.MidiFile()

    for i, track in enumerate(midi.tracks):

        new_mid = mido.MidiFile()
        # set ticks per beat to be the same
        new_mid.ticks_per_beat = midi.ticks_per_beat
        new_mid.tracks.append(metadata.copy())
        new_mid.tracks[0] += track
        # for message in new_mid:
        #     print(message)

        filename = f'{prefix}_track{i}.mid'
        all_track_midis.append(filename)
        new_mid.save(filename)
    
    return all_track_midis

In [None]:
def make_csv_from_midi(midifile):
    """
    FOR DEBUGGING: iterate through entire MIDI and store all info in CSV file
    parameters: 
        - midifile: input MIDI file path
    
    saves MIDI file as a CSV with columns "start", "start_ticks", "end", "duration_ticks", "instr", "note", "channel", "key", "final_key"

    returns:
        - nothing
    """

    if type(midifile) == str:
        midi = mido.MidiFile(midifile)
    else:
        midi = midifile

    df = pd.DataFrame(columns=["start", "start_ticks", "end", "duration_ticks", "instr", "note", "channel", "key", "final_key"])
    print(df)

    tokens = []
    note_idx = 0
    open_notes = defaultdict(list)
    closed_notes = []

    time = 0
    instruments = defaultdict(int) # default to code 0 = piano
    tempo = 500000 # default tempo: 500000 microseconds per beat

    for message in midi:
        time += message.time
        # print("MESSAGE", message, round(TIME_RESOLUTION*time))

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        if message.type == 'program_change':
            instruments[message.channel] = message.program
        elif message.type in ['note_on', 'note_off']:
            # special case: channel 9 is drums!
            instr = 128 if message.channel == 9 else instruments[message.channel]

            if message.type == 'note_on' and message.velocity > 0: # onset
                # time quantization
                time_in_ticks = round(TIME_RESOLUTION*time)

                # Our compound word is: (time, duration, note, instr, velocity)
                tokens.append(time_in_ticks) # 5ms resolution
                tokens.append(-1) # placeholder (we'll fill this in later)
                tokens.append(message.note)
                tokens.append(instr)
                tokens.append(message.velocity)

                tokens.append(message.channel) # CHANNEL INFO FOR KEY

                key = (instr,message.note,message.channel)
                open_notes[key].append((note_idx, time))
                print(note_idx)
                print([time, time_in_ticks, -1, -1, instr, message.note, message.channel, key, -1])
                df.loc[note_idx] = [time, time_in_ticks, -1.0, -1, instr, message.note, message.channel, key, "-1"]

                note_idx += 1

            else: # offset
                try:
                    open_idx, onset_time = open_notes[(instr,message.note,message.channel)].pop(0)
                except IndexError:
                    if debug:
                        print('WARNING: ignoring bad offset')
                else:
                    duration_ticks = round(TIME_RESOLUTION*(time-onset_time))
                    # tokens[5*open_idx + 1] = duration_ticks
                    tokens[6*open_idx + 1] = duration_ticks # ACCOUNT FOR CHANNEL
                    # print("onset time:", onset_time, "duration_ticks:", duration_ticks)
                    onset_time_in_ticks = round(TIME_RESOLUTION*onset_time)

                    closed_notes.append((instr,message.note, message.channel, onset_time_in_ticks, duration_ticks))
                    # print(type(time))
                    df.loc[open_idx, "end"] = time
                    df.loc[open_idx, "duration_ticks"] = duration_ticks
                    df.loc[open_idx, "final_key"] = str((instr,message.note, message.channel, onset_time_in_ticks, duration_ticks))

        elif message.type == 'set_tempo':
            tempo = message.tempo
        elif message.type == 'time_signature':
            # print('TIME SIGNATURE', message)
            pass # we use real time
        elif message.type in ['aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific']:
            pass # we don't attempt to model these
        elif message.type == 'control_change':
            pass # this includes pedal and per-track volume: ignore for now
        elif message.type in ['track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
                              'copyright', 'marker', 'instrument_name', 'cue_marker',
                              'device_name', 'sequence_number']:
            pass # possibly useful metadata but ignore for now
        elif message.type == 'channel_prefix':
            pass # relatively common, but can we ignore this?
        elif message.type in ['midi_port', 'smpte_offset', 'sysex']:
            pass # I have no idea what this is
        else:
            if debug:
                print('UNHANDLED MESSAGE', message.type, message)
        # print(tokens)

    # print(df)
    df.to_csv(midifile[:midifile.index(".mid")]+"_notes.csv")
    return 

# Labeling Helper Functions

In [216]:
def modify_notes_by_index_range(notes, start, end, func):
    """
    modifies the notes to do [func] in an index range (inclusive)
    parameters:
        - notes: list of all Note objects
        - start: starting index (inclusive)
        - end: ending index (inclusive)
        - func: function corresponding to Note class function
    """
    for i in range(start, end+1):
        func(notes[i])
    return

In [451]:
def modify_notes_by_time_range(notes, start, end, func, in_ticks=False):
    """
    modifies the notes within a time range
    time in the Notes is stored by a converted time_in_ticks (time_in_ticks = round(TIME_RESOLUTION*time))
    set in_ticks to be true if start, end are in ticks, else will convert the Note time to seconds
    parameters:
        - notes: list of all Note objects
        - start: starting time (inclusive), in seconds if in_ticks=False else in ticks
        - end: ending time (inclusive), in seconds if in_ticks=False else in ticks
        - func: function corresponding to Note class function
        - in_ticks: whether or not start, end are in ticks. default is False (start, end in seconds)
    """
    # if start, end in seconds, then convert start, end to ticks
    if not in_ticks:
        start = round(TIME_RESOLUTION*start)
        end = round(TIME_RESOLUTION*end)
    
    for note in notes:
        note_start = note.get_start_time()
        if note_start >= start and note_start <= end:
            func(note)
        elif note_start > end:
            break
    return
    

In [218]:
def modify_notes_by_track(notes_per_track, track_num, func):
    """
    modifies all notes in a track to do some [func]
    parameters:
        - notes_per_track: dictionary mapping track number to a list of Note objects in that track
        - track_num: the track we want to modify
        - func: function corresponding to Note class function
    """
    notes = notes_per_track[track_num]
    for note in notes:
        func(note)
    return

In [457]:
def modify_notes_in_list(note_list, func):
    """
    modifies all notes in a given list to do some [func]
    parameters:
        - note_list: list of Note objects
        - func: function corresponding to Note class function
    """
    for note in note_list:
        func(note)
    return

# Labeling Code w/ MIDI Input

list of modification functions in Note class:
- `add_tag(tag: str)`: tag is unspecified string (manually set), ex: motif, repetition, rigid/rubato
- `set_key_and_interval(key)`: key is MIDI note number, sets key=key, interval=(pitch-key) % 12
- `set_melody(is_melody)`: is_melody is boolean
- `set_chord(is_chord)`: is_chord is boolean

use these modification functions in the `func` param for the labeling code above
- ex: `func = lambda x: x.set_key_and_interval(60)`

In [332]:
repetition1_midis = isolate_midi_tracks("repetition1.mid")

all_note_keys, notes, notes_per_track = process_midi_file("repetition1.mid", repetition1_midis)

# set key (and interval) for all notes
modify_notes_by_index_range(notes, 0, len(notes)-1, lambda x: x.set_key_and_interval(62))

# 2 bars repeated
# tag first 2 bars as motif 1
# tag last 2 bars as motif 2
# tag track 0 as melody
# tag track 2 and 3 as repetitive + accompaniment
# tag groups of rhythm (like 3/8ths + )

# does it generate something similar to the first motif + second motif
# 

# for note in notes:
#     print(note)


# Analysis Code

In [None]:
# load attention weight matrices
# tokens in pattern of (start_token, duration_token, note_token)
# tokens in order of MIDI file

# RESCALE ATTENTION WEIGHTS
# z-score, scale by uniform attention, etc.

# iterate through each attention matrix
# mark which notes it pays the most attention to
# sum up values? how do we analyze

# maybe first start by seeing which fields appear to be most correlated with higher attention on each matrix
# "X matrix is most active during notes with Y characteristic"
# and then we can narrow down our hypotheses to be "X matrix correlates to X characteristic"

### Load Attention Matrices

In [394]:
def load_attn_matrices(filename, filepath_prefix):
    """
    parameters:
        - filename: midi file name
        - filepath_prefix: path prefix to the attention matrix file

    returns:
        - attention_matrices = list of pandas dataframes corresponding to each attention matrix
    """

    filename = filename[:filename.index(".mid")] if ".mid" in filename else filename

    attention_heads = []
    for i in range(0, 12):
        cur_filename = f'{filename}_head{i}.npy'
        cur_filepath = f'{filepath_prefix}/{cur_filename}'
        cur_matrix = np.load(cur_filepath)

        matrix_df = pd.DataFrame(cur_matrix)
        matrix_df = matrix_df.iloc[1:, 1:] # remove first row+col (corresponds to CLS token)
        attention_heads.append(matrix_df)
    return attention_heads

### Create Rescaled Attention Matrices

1. **rescale wrt uniform attention**: each entry in the attention matrices = `attn_val / (1/(# tokens attended to so far))`, which basically corresponds to uniform attention in a row. since the attention matrices are diagonal, this helps maintain a more consistent value across different rows where the model is paying attention to different numbers of tokens.
2. **compute Z score of attention row**: each entry in the attention matrices = `# std deviations away from the avg attention of that row`. this does a similar thing but now we are able to assess average attentions. 

**NOTE:** can also maybe calculate the distribution/std dev of attentions per row? this might tell us if an attention head is very focused on specific notes vs focused on a lot of notes at once

In [292]:
def rescale_attn_matrix_unif(attention_matrices):
    """
    create rescaled matrices wrt uniform attention
    each entry is (attn value) / (1/(num tokens attended to so far))
    which corresponds to uniform attention in that row
    
    parameters:
        - attention_matrices: list of pandas dataframes corresponding to each attention matrix

    returns:
        - rescaled_attention_matrices: list of pandas dataframes corresponding to each rescaled attention matrix
    """
    rescaled_attention_matrices = []

    for attn_matrix in attention_matrices:
        rescaled_matrix = attn_matrix.copy()

        for i in range(rescaled_matrix.shape[0]):
            unif_value = 1 / (i+1)
            rescaled_matrix.iloc[i,:] = rescaled_matrix.iloc[i,:] / unif_value
        rescaled_attention_matrices.append(rescaled_matrix)
    
    return rescaled_attention_matrices

In [290]:
def rescale_attn_matrix_zscore(attention_matrices):
    """
    create Z-score matrices (per row)
    each entry is how much attention is given to a token wrt mean attention for that row
    ex: token 2 attends to token 1 with attention that is 2 std above mean attention for that row 
    NOTE: upper right triangle left as 0

    parameters:
        - attention_matrices: list of pandas dataframes corresponding to each attention matrix

    returns:
        - zscore_attention_matrices: list of pandas dataframes corresponding to each Z-score attention matrix
    """
    zscore_attn_matrices = []
    attn_matrices_stats = []

    for attn_head in attention_matrices:
        copy_attn_head = attn_head.copy()
        stats_df = pd.DataFrame(index=range(attn_head.shape[0]), columns=['mean', 'std'])

        for i in range(attn_head.shape[0]):
            # get mean and std dev for that row (up to col i+1 bc only for tokens attended to so far)
            row_mean = attn_head.iloc[i,:i+1].mean()
            row_std = attn_head.iloc[i,:i+1].std()

            # if row_std = 0 or nan
            if row_std == 0 or np.isnan(row_std):
                # if all attention values are the same, set to 0
                copy_attn_head.iloc[i,:] = 0
            else:
                # calculate z-score
                copy_attn_head.iloc[i,:i+1] = (copy_attn_head.iloc[i,:i+1] - row_mean) / row_std

            # populate mean and std dev dataframe
            stats_df.loc[i, 'mean'] = row_mean
            stats_df.loc[i, 'std'] = row_std

        zscore_attn_matrices.append(copy_attn_head)
        attn_matrices_stats.append(stats_df)

    return zscore_attn_matrices, attn_matrices_stats

### Top K Most Attended to Tokens per Attention Head

this is done by:
1. taking rescaled attention matrix (z-score)
2. averaging the z-score for every single token (theoretically should not be skewed by order / # tokens attended to)
3. return top K highest average tokens

In [None]:
def get_top_k_zscore_tokens(zscore_attn_matrices, k, num_input_tokens):
    """
    for each attention matrix, get the top k tokens with highest z-score attention for each row

    parameters:
        - zscore_attn_matrices: list of pandas dataframes corresponding to each Z-score attention matrix
        - k: top k tokens to get
    returns:
        - top_k_tokens_per_matrix: list of lists of lists
            - outer list: per attention matrix
            - middle list: per row
            - inner list: top k token indices with highest z-score attention

    """
    top_k_tokens_per_matrix = []
    # num_tokens = len(zscore_attn_matrices[0])

    for matrix in zscore_attn_matrices:
        cur_matrix_index_sums = []
        # sum up each matrix column
        for i in range(matrix.shape[1]):
            # don't consider generated tokens
            # if i >= num_input_tokens: break
            
            # # sum up starting at row i to ignore upper triangle
            # col_sum = matrix.iloc[i:,i].sum()
            # # get average (divide by # times the token is attended to)
            # col_avg = col_sum / matrix.iloc[i:,i].shape[0]

            # col_med = matrix.iloc[i:,i].median()
            col_avg = matrix.iloc[i:,i].mean()

            cur_matrix_index_sums.append((i, col_avg.item()))
            # cur_matrix_index_sums.append((i, col_med.item()))

        # sort by column average
        cur_matrix_index_sums = sorted(cur_matrix_index_sums, key=lambda x: x[1])

        top_k_tokens_per_matrix.append(cur_matrix_index_sums[-k:]) # get top k tokens
    return top_k_tokens_per_matrix

In [None]:
# load attention weight matrices
# split up into every three (start, duration, note)


# can find top k neighbor for each token, bottom k neighbor for each token
# run statistics on each group to determine if there is a trait that is more common in top k vs. bottom k
# compare across different samples!

# can we use traits to predict the attention?



In [None]:
# split up into hypotheses

# Hypothesis 1: some attention head will be looking more at 3rd interval + 7th interval to determine the key of the piece


# Hypothesis 2: some attention head will be looking more at notes in the melody (vs. accompaniment)


# Hypothesis 3:

# "attention sync"
# the model uses the first token to store summary information
# ignore first token in each attention matrix


In [428]:
repetition1 = mido.MidiFile("repetition1.mid")
# print(repetition1)
repetition1_track1 = mido.MidiFile(repetition1_midis[0])
time = 0
for msg in repetition1_track1:
    time += msg.time
    print(msg, "TIME: ", time)
# print(time)

MetaMessage('time_signature', numerator=4, denominator=4, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=0) TIME:  0
MetaMessage('key_signature', key='D', time=0) TIME:  0
MetaMessage('set_tempo', tempo=714286, time=0) TIME:  0
MetaMessage('track_name', name='Piano', time=0) TIME:  0
MetaMessage('time_signature', numerator=4, denominator=4, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=0) TIME:  0
MetaMessage('key_signature', key='D', time=0) TIME:  0
MetaMessage('set_tempo', tempo=714286, time=0) TIME:  0
control_change channel=2 control=121 value=0 time=0 TIME:  0
control_change channel=2 control=100 value=0 time=0 TIME:  0
control_change channel=2 control=101 value=0 time=0 TIME:  0
control_change channel=2 control=6 value=12 time=0 TIME:  0
control_change channel=2 control=100 value=127 time=0 TIME:  0
control_change channel=2 control=101 value=127 time=0 TIME:  0
program_change channel=2 program=0 time=0 TIME:  0
control_change channel=2 control=7 value=100 ti

In [458]:
# LOAD MIDI FILE + PROCESS NOTES

repetition1_midis = isolate_midi_tracks("repetition1.mid")

all_note_keys, notes, notes_per_track = process_midi_file("repetition1.mid", repetition1_midis)

# set key (and interval) for all notes
modify_notes_by_index_range(notes, 0, len(notes)-1, lambda x: x.set_key_and_interval(62))

# TAG NOTES
# first 2 bars are motif 1
modify_notes_by_time_range(notes, 0, 5.714288000000003, lambda x: x.add_tag("motif1"), in_ticks=False)

# last 2 bars are motif 2
modify_notes_by_time_range(notes, 5.714288000000003, 11.428576000000006, lambda x: x.add_tag("motif2"), in_ticks=False)

# track 0 is melody
modify_notes_by_track(notes_per_track, 0, lambda x: x.add_tag("melody"))

# track 2 + track 3 as repetitive + accompaniment
modify_notes_by_track(notes_per_track, 2, lambda x: x.add_tag("repetitive_notes"))
modify_notes_by_track(notes_per_track, 2, lambda x: x.add_tag("accompaniment"))

modify_notes_by_track(notes_per_track, 3, lambda x: x.add_tag("repetitive_notes"))
modify_notes_by_track(notes_per_track, 3, lambda x: x.add_tag("accompaniment"))

# tag groups of rhythm (like 3/8ths + )

# for note in notes:
#     print(note.pitch)
#     print(note.start)
#     print(note.tags)

In [None]:
# LOAD + RESCALE ATTENTION MATRICES

repetition1_attn_matrices = load_attn_matrices("repetition1.mid", "repetition1_heads")
repetition1_unif_attn_matrices = rescale_attn_matrix_unif(repetition1_attn_matrices)
repetition1_zscore_attn_matrices, repetition1_attn_stats = rescale_attn_matrix_zscore(repetition1_attn_matrices)
# print(repetition1_attn_stats[0].head)

# STORE TOP K ATTENTION TOKENS PER MATRIX

repetition1_topks = get_top_k_zscore_tokens(repetition1_zscore_attn_matrices, k=10, num_input_tokens=3*len(notes))


In [465]:
# head 0 is accompaniment + repetitive 8th notes (tracks 2 and 3)
# head 2 is all duration tokens (more melody focused)
# head 6 is all note tokens
test = [repetition1_zscore_attn_matrices[3]]
top_k = get_top_k_zscore_tokens(test, 20, 387)
token_type = ["start time", "duration", "note"]

for item in top_k[0]:
    index, score = item
    print(item)
    note_index = math.floor(index/3)
    # print(note_index)
    # print(index%3)
    # print(notes[note_index])
    cur_note = notes[note_index]
    
    print("   pitch: ", cur_note.pitch, "\n   start: ", cur_note.start, "\n   duration: ", cur_note.duration, "\n   interval: ", cur_note.interval, "\n   token type: ", token_type[index%3], "\n   TAGS: ", cur_note.tags)



(239, 0.2836686372756958)
   pitch:  62 
   start:  679 
   duration:  18 
   interval:  0 
   token type:  note 
   TAGS:  {'motif2', 'accompaniment', 'repetitive_notes'}
(280, 0.3046455979347229)
   pitch:  59 
   start:  786 
   duration:  71 
   interval:  3 
   token type:  duration 
   TAGS:  {'motif2'}
(295, 0.3269103467464447)
   pitch:  61 
   start:  857 
   duration:  24 
   interval:  1 
   token type:  duration 
   TAGS:  {'motif2', 'melody'}
(175, 0.33267125487327576)
   pitch:  58 
   start:  500 
   duration:  53 
   interval:  4 
   token type:  duration 
   TAGS:  {'melody', 'motif1'}
(224, 0.34401729702949524)
   pitch:  62 
   start:  631 
   duration:  12 
   interval:  0 
   token type:  note 
   TAGS:  {'motif2', 'melody'}
(373, 0.34809422492980957)
   pitch:  58 
   start:  1071 
   duration:  71 
   interval:  4 
   token type:  duration 
   TAGS:  {'motif2', 'melody'}
(193, 0.3501931130886078)
   pitch:  62 
   start:  571 
   duration:  24 
   interval:  0 
 

In [None]:
# head 0 is accompaniment + repetitive 8th notes (tracks 2 and 3)
# head 2 is all duration tokens (more melody focused)
# head 6 is all note tokens
test = [repetition1_zscore_attn_matrices[3]]
top_k = get_top_k_zscore_tokens(test, 20, 387)
token_type = ["start time", "duration", "note"]

for item in top_k[0]:
    index, score = item
    print(item)
    note_index = math.floor(index/3)
    # print(note_index)
    # print(index%3)
    # print(notes[note_index])
    cur_note = notes[note_index]
    
    print("   pitch: ", cur_note.pitch, "\n   start: ", cur_note.start, "\n   duration: ", cur_note.duration, "\n   interval: ", cur_note.interval, "\n   token type: ", token_type[index%3], "\n   TAGS: ", cur_note.tags)

