In [149]:
import pretty_midi
import joblib
import glob
import os
import pandas as pd

def compute_list_average(l):
    """
    Given a list of numbers, compute the average.
    
    Parameters
    ----------
    l : list
        List of numbers.
    
    Returns
    -------
    average : float
        Average of the numbers in the list.
    """
    return sum(l) / len(l)

def categorize_midi_instrument(program_number):
    """
    Given a MIDI instrument program number, categorize it into a high-level
    instrument family.
    
    Parameters
    ----------
    program_number : int
        MIDI instrument program number.
    
    Returns
    -------
    instrument_family : str
        Name of the instrument family.
    """
    # See http://www.midi.org/techspecs/gm1sound.php

    if 0 <= program_number <= 7:
        return 'Piano'
    elif 8 <= program_number <= 15:
        return 'Chromatic Percussion'
    elif 16 <= program_number <= 23:
        return 'Organ'
    elif 24 <= program_number <= 31:
        return 'Guitar'
    elif 32 <= program_number <= 39:
        return 'Bass'
    elif 40 <= program_number <= 47:
        return 'Strings'
    elif 48 <= program_number <= 55:
        return 'Ensemble'
    elif 56 <= program_number <= 63:
        return 'Brass'
    elif 64 <= program_number <= 71:
        return 'Reed'
    elif 72 <= program_number <= 79:
        return 'Pipe'
    elif 80 <= program_number <= 87:
        return 'Synth Lead'
    elif 88 <= program_number <= 95:
        return 'Synth Pad'
    elif 96 <= program_number <= 103:
        return 'Synth Effects'
    elif 104 <= program_number <= 111:
        return 'Ethnic'
    elif 112 <= program_number <= 119:
        return 'Percussive'
    elif 120 <= program_number <= 127:
        return 'Sound Effects'


def compute_statistics(midi_file):
    """
    Given a path to a MIDI file, compute a dictionary of statistics about it
    
    Parameters
    ----------
    midi_file : str
        Path to a MIDI file.
    
    Returns
    -------
    statistics : dict
        Dictionary reporting the values for different events in the file.
    """
    # Some MIDI files will raise Exceptions on loading, if they are invalid.
    # We just skip those.
    try:
        pm = pretty_midi.PrettyMIDI(midi_file)
        # Extract informative events from the MIDI file
        statistics = {
            # track md5 hash name without extension
            'track_name': os.path.basename(midi_file).split('.')[0],
            # instruments
            'n_instruments': len(pm.instruments),
            'n_unique_instruments': len(set([i.program for i in pm.instruments])),
            'instruments': ', '.join([str(i.program) for i in pm.instruments]),
            'instrument_families': ', '.join(set([categorize_midi_instrument(i.program) for i in pm.instruments])),
            'number_of_instrument_families': len(set([categorize_midi_instrument(i.program) for i in pm.instruments])),
            # notes
            'n_notes': sum([len(i.notes) for i in pm.instruments]),
            'n_unique_notes': len(set([n.pitch for i in pm.instruments for n in i.notes])),
            'average_n_unique_notes_per_instrument': compute_list_average([len(set([n.pitch for n in i.notes])) for i in pm.instruments]),
            'average_note_duration': compute_list_average([n.end - n.start for i in pm.instruments for n in i.notes]),
            'average_note_velocity': compute_list_average([n.velocity for i in pm.instruments for n in i.notes]),
            'average_note_pitch': compute_list_average([n.pitch for i in pm.instruments for n in i.notes]),
            'range_of_note_pitches': (max([n.pitch for i in pm.instruments for n in i.notes]) - min([n.pitch for i in pm.instruments for n in i.notes])),
            'average_range_of_note_pitches_per_instrument': compute_list_average([max([n.pitch for n in i.notes]) - (min([n.pitch for n in i.notes])) for i in pm.instruments]),
            'number_of_note_pitch_classes': len(set([n.pitch % 12 for i in pm.instruments for n in i.notes])),
            'average_number_of_note_pitch_classes_per_instrument': compute_list_average([len(set([n.pitch % 12 for n in i.notes])) for i in pm.instruments]),
            'number_of_octaves': len(set([n.pitch // 12 for i in pm.instruments for n in i.notes])),
            'average_number_of_octaves_per_instrument': compute_list_average([len(set([n.pitch // 12 for n in i.notes])) for i in pm.instruments]),
            'number_of_notes_per_second': len([n for i in pm.instruments for n in i.notes]) / pm.get_end_time(),
            'shortest_note_length': min([n.end - n.start for i in pm.instruments for n in i.notes]),
            'longest_note_length': max([n.end - n.start for i in pm.instruments for n in i.notes]),
            # key signatures
            'main_key_signature': [k.key_number for k in pm.key_signature_changes][0], # hacky
            'n_key_changes': len(pm.key_signature_changes),
            # tempo
            'n_tempo_changes': len(pm.get_tempo_changes()[1]),
            'tempo_estimate': round(pm.estimate_tempo()), # weird results
            # time signatures
            'main_time_signature': [str(ts.numerator) + '/' + str(ts.denominator) for ts in pm.time_signature_changes][0], # hacky
            'n_time_signature_changes': len(pm.time_signature_changes),
            # track length
            'track_length_in_seconds': round(pm.get_end_time()),
            # lyrics
            'lyrics_nb_words': len([l.text for l in pm.lyrics]),
            'lyrics_unique_words': len(set([l.text for l in pm.lyrics])),
            'lyrics_bool': len(pm.lyrics) > 0,
        }
        # statistics = pd.DataFrame(statistics, index=[0])
        return statistics
    # Silently ignore exceptions for a clean presentation (sorry Python!)
    except Exception as e:
        pass




In [147]:
midi_file = '../data/lmd_full/0/0af5af6d9785c93d65215031077bead3.mid'
statistics = compute_statistics(midi_file)
df = pd.DataFrame(statistics, index=[0])
print(statistics)
display(df)

{'track_name': '0af5af6d9785c93d65215031077bead3', 'n_instruments': 10, 'n_unique_instruments': 8, 'instruments': '64, 56, 65, 56, 66, 57, 67, 58, 0, 0', 'instrument_families': 'Brass, Piano, Reed', 'number_of_instrument_families': 3, 'n_notes': 1666, 'n_unique_notes': 41, 'average_n_unique_notes_per_instrument': 12.1, 'average_note_duration': 0.3788140256102438, 'average_note_velocity': 81.35894357743098, 'average_note_pitch': 63.278511404561826, 'range_of_note_pitches': 51, 'average_range_of_note_pitches_per_instrument': 14.9, 'number_of_note_pitch_classes': 12, 'average_number_of_note_pitch_classes_per_instrument': 8.8, 'number_of_octaves': 6, 'average_number_of_octaves_per_instrument': 2.2, 'number_of_notes_per_second': 12.526315789473685, 'shortest_note_length': 0.020833333333328596, 'longest_note_length': 1.5, 'main_key_signature': 7, 'n_key_changes': 1, 'n_tempo_changes': 5, 'tempo_estimate': 187, 'main_time_signature': '1/4', 'n_time_signature_changes': 18, 'track_length_in_sec

Unnamed: 0,track_name,n_instruments,n_unique_instruments,instruments,instrument_families,number_of_instrument_families,n_notes,n_unique_notes,average_n_unique_notes_per_instrument,average_note_duration,...,main_key_signature,n_key_changes,n_tempo_changes,tempo_estimate,main_time_signature,n_time_signature_changes,track_length_in_seconds,lyrics_nb_words,lyrics_unique_words,lyrics_bool
0,0af5af6d9785c93d65215031077bead3,10,8,"64, 56, 65, 56, 66, 57, 67, 58, 0, 0","Brass, Piano, Reed",3,1666,41,12.1,0.378814,...,7,1,5,187,1/4,18,133,0,0,False


In [150]:
# Compute statistics for one midi file
midi_files = [
    '../data/lmd_full/0/0af5af6d9785c93d65215031077bead3.mid',
    '../data/lmd_full/0/07706096906421577e96b9252f590306.mid',
    '../data/lmd_full/1/1a0d67356a1c4b35c5103774f4cd0f1a.mid',
    '../data/lmd_full/1/1a0cf078518aa7d3c9713b6c0a354a68.mid',
    '../data/lmd_full/1/1a1ad63728ea30834f67b3a39dd7c83c.mid'
]

for f in midi_files:
    statistics = compute_statistics(f)
    print(statistics)

{'track_name': '0af5af6d9785c93d65215031077bead3', 'n_instruments': 10, 'n_unique_instruments': 8, 'instruments': '64, 56, 65, 56, 66, 57, 67, 58, 0, 0', 'instrument_families': 'Brass, Piano, Reed', 'number_of_instrument_families': 3, 'n_notes': 1666, 'n_unique_notes': 41, 'average_n_unique_notes_per_instrument': 12.1, 'average_note_duration': 0.3788140256102438, 'average_note_velocity': 81.35894357743098, 'average_note_pitch': 63.278511404561826, 'range_of_note_pitches': 51, 'average_range_of_note_pitches_per_instrument': 14.9, 'number_of_note_pitch_classes': 12, 'average_number_of_note_pitch_classes_per_instrument': 8.8, 'number_of_octaves': 6, 'average_number_of_octaves_per_instrument': 2.2, 'number_of_notes_per_second': 12.526315789473685, 'shortest_note_length': 0.020833333333328596, 'longest_note_length': 1.5, 'main_key_signature': 7, 'n_key_changes': 1, 'n_tempo_changes': 5, 'tempo_estimate': 187, 'main_time_signature': '1/4', 'n_time_signature_changes': 18, 'track_length_in_sec

In [155]:
# Compute statistics about every file in our collection in parallel using joblib
# We do things in parallel because there are tons so it would otherwise take too long!
statistics = joblib.Parallel(n_jobs=-1, verbose=1)(
    joblib.delayed(compute_statistics)(midi_file)
    # for midi_file in glob.glob(os.path.join('..', 'data', 'lmd_full', '*', '*.mid')))
    for midi_file in glob.glob(os.path.join('..', 'data', 'lmd_full', '0', '*.mid')))
# When an error occurred, None will be returned; filter those out.
statistics = [s for s in statistics if s is not None]

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    3.9s
[Parallel(n_jobs=-1)]: Done 184 tasks      | elapsed:    7.8s
[Parallel(n_jobs=-1)]: Done 668 tasks      | elapsed:   15.9s
[Parallel(n_jobs=-1)]: Done 1368 tasks      | elapsed:   25.9s
[Parallel(n_jobs=-1)]: Done 2268 tasks      | elapsed:   39.9s
[Parallel(n_jobs=-1)]: Done 3368 tasks      | elapsed:  1.0min
[Parallel(n_jobs=-1)]: Done 4594 tasks      | elapsed:  1.6min
[Parallel(n_jobs=-1)]: Done 5664 tasks      | elapsed:  2.1min
[Parallel(n_jobs=-1)]: Done 7364 tasks      | elapsed:  2.7min
[Parallel(n_jobs=-1)]: Done 9264 tasks      | elapsed:  3.2min
[Parallel(n_jobs=-1)]: Done 11117 out of 11132 | elapsed:  3.9min remaining:    0.3s
[Parallel(n_jobs=-1)]: Done 11132 out of 11132 | elapsed:  3.9min finished


In [158]:
df = pd.DataFrame(statistics)
# export df to csv
df.to_csv('statistics.csv', index=False)