In [345]:
import pretty_midi
import joblib
import glob
import os
import pandas as pd

from pretty_midi import program_to_instrument_name

def compute_list_average(l):
    """Return the arithmetic mean of the numbers in ``l``.

    Args:
        l: A sized collection of numbers (the callers in this file pass lists).

    Returns:
        ``sum(l) / len(l)``, or ``None`` when ``l`` is empty. Returning
        ``None`` instead of raising ``ZeroDivisionError`` matches the
        "no data -> None" convention of the other statistics helpers.
    """
    if len(l) == 0:
        return None
    return sum(l) / len(l)

def categorize_midi_instrument(program_number):
    """Map a MIDI program number to the name of its instrument family.

    Program numbers 0-127 are split into sixteen consecutive groups of
    eight (0-7, 8-15, ..., 120-127); each group has one family name.

    Args:
        program_number: The program number to classify.

    Returns:
        The family name as a string, or ``None`` when ``program_number``
        does not fall inside any of the sixteen ranges.
    """
    # Family names in range order: entry i covers programs 8*i .. 8*i + 7.
    families = (
        'Piano',
        'Chromatic Percussion',
        'Organ',
        'Guitar',
        'Bass',
        'Strings',
        'Ensemble',
        'Brass',
        'Reed',
        'Pipe',
        'Synth Lead',
        'Synth Pad',
        'Synth Effects',
        'Ethnic',
        'Percussive',
        'Sound Effects',
    )
    for index, family in enumerate(families):
        lowest = 8 * index
        if lowest <= program_number <= lowest + 7:
            return family
    return None

def track_name(midi_file):
    """Return the file name of ``midi_file`` up to its first dot.

    Args:
        midi_file: Path to a file.

    Returns:
        The base name of the path with everything from the first ``.``
        onwards removed (so ``b.tar.gz`` becomes ``b``).
    """
    file_name = os.path.basename(midi_file)
    stem, _, _ = file_name.partition('.')
    return stem

def n_instruments(pm):
    """Return how many instruments ``pm`` contains.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The length of ``pm.instruments``, or ``None`` when it is empty.
    """
    return len(pm.instruments) if pm.instruments else None

def n_unique_instruments(pm):
    """Return the number of distinct program numbers used in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The count of distinct ``instrument.program`` values, or ``None``
        when ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    programs = {instrument.program for instrument in pm.instruments}
    return len(programs)

def instrument_names(pm):
    """Return the distinct instrument names used in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        A one-element list wrapping the list of distinct names (derived
        from a set, so no particular order), or ``None`` when
        ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    unique_names = {
        program_to_instrument_name(instrument.program)
        for instrument in pm.instruments
    }
    # NOTE(review): the extra outer list presumably keeps the names in a
    # single cell when the statistics dict is passed to
    # pd.DataFrame(..., index=[0]) -- confirm before removing it.
    return [list(unique_names)]

def instrument_families(pm):
    """Return the distinct instrument families used in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        A one-element list wrapping the list of distinct family names as
        produced by ``categorize_midi_instrument`` (derived from a set, so
        no particular order), or ``None`` when ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    unique_families = {
        categorize_midi_instrument(instrument.program)
        for instrument in pm.instruments
    }
    return [list(unique_families)]

def number_of_instrument_families(pm):
    """Return how many distinct instrument families ``pm`` uses.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The number of distinct values ``categorize_midi_instrument`` yields
        for the instruments' programs, or ``None`` when ``pm.instruments``
        is empty.
    """
    if not pm.instruments:
        return None
    unique_families = {
        categorize_midi_instrument(instrument.program)
        for instrument in pm.instruments
    }
    return len(unique_families)

def number_of_notes(pm):
    """Return the total number of notes across all instruments of ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The sum of ``len(instrument.notes)`` over all instruments, or
        ``None`` when ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    total = 0
    for instrument in pm.instruments:
        total += len(instrument.notes)
    return total

def number_of_unique_notes(pm):
    """Return the number of distinct pitches played anywhere in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The count of distinct ``note.pitch`` values over all instruments,
        or ``None`` when ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    pitches = set()
    for instrument in pm.instruments:
        for note in instrument.notes:
            pitches.add(note.pitch)
    return len(pitches)

def avg_number_of_unique_notes_per_instrument(pm):
    """Return the mean number of distinct pitches per instrument.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The average, over all instruments, of each instrument's count of
        distinct ``note.pitch`` values, or ``None`` when ``pm.instruments``
        is empty.
    """
    if not pm.instruments:
        return None
    distinct_pitch_counts = [
        len({note.pitch for note in instrument.notes})
        for instrument in pm.instruments
    ]
    return compute_list_average(distinct_pitch_counts)

def average_note_duration(pm):
    """Return the mean note duration over every note in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The average of ``note.end - note.start`` over all notes of all
        instruments, or ``None`` when ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    durations = []
    for instrument in pm.instruments:
        for note in instrument.notes:
            durations.append(note.end - note.start)
    return compute_list_average(durations)

def average_note_velocity(pm):
    """Return the mean velocity over every note in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The average of ``note.velocity`` over all notes of all instruments,
        or ``None`` when ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    velocities = []
    for instrument in pm.instruments:
        for note in instrument.notes:
            velocities.append(note.velocity)
    return compute_list_average(velocities)

def average_note_pitch(pm):
    """Return the mean pitch over every note in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The average of ``note.pitch`` over all notes of all instruments,
        or ``None`` when ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    pitches = []
    for instrument in pm.instruments:
        for note in instrument.notes:
            pitches.append(note.pitch)
    return compute_list_average(pitches)

def range_of_note_pitches(pm):
    """Return the spread between the highest and lowest pitch in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        ``max(pitch) - min(pitch)`` over all notes of all instruments, or
        ``None`` when there are no instruments or none of them has a note.
    """
    if not pm.instruments:
        return None
    # Collect the pitches once instead of rebuilding the list for max and min.
    pitches = [note.pitch for instrument in pm.instruments for note in instrument.notes]
    # max()/min() raise ValueError on an empty sequence, so instruments that
    # carry no notes at all are reported as "no data".
    if not pitches:
        return None
    return max(pitches) - min(pitches)

def average_range_of_note_pitches_per_instrument(pm):
    """Return the mean per-instrument pitch spread of ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The average of ``max(pitch) - min(pitch)`` computed separately for
        each instrument that has at least one note, or ``None`` when there
        are no instruments or no instrument has any note.
    """
    if not pm.instruments:
        return None
    ranges = []
    for instrument in pm.instruments:
        pitches = [note.pitch for note in instrument.notes]
        # An instrument without notes has no pitch range; calling max()/min()
        # on it would raise ValueError, so it is left out of the average.
        if pitches:
            ranges.append(max(pitches) - min(pitches))
    if not ranges:
        return None
    return compute_list_average(ranges)

def number_of_note_pitch_classes(pm):
    """Return the number of distinct pitch classes (``pitch % 12``) in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The count of distinct ``note.pitch % 12`` values over all notes of
        all instruments, or ``None`` when ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    pitch_classes = set()
    for instrument in pm.instruments:
        for note in instrument.notes:
            pitch_classes.add(note.pitch % 12)
    return len(pitch_classes)

def average_number_of_note_pitch_classes_per_instrument(pm):
    """Return the mean number of distinct pitch classes per instrument.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The average, over all instruments, of each instrument's count of
        distinct ``note.pitch % 12`` values, or ``None`` when
        ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    pitch_class_counts = [
        len({note.pitch % 12 for note in instrument.notes})
        for instrument in pm.instruments
    ]
    return compute_list_average(pitch_class_counts)

def number_of_octaves(pm):
    """Return the number of distinct octaves (``pitch // 12``) used in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The count of distinct ``note.pitch // 12`` values over all notes of
        all instruments, or ``None`` when ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    octaves = set()
    for instrument in pm.instruments:
        for note in instrument.notes:
            octaves.add(note.pitch // 12)
    return len(octaves)

def average_number_of_octaves_per_instrument(pm):
    """Return the mean number of distinct octaves per instrument.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The average, over all instruments, of each instrument's count of
        distinct ``note.pitch // 12`` values, or ``None`` when
        ``pm.instruments`` is empty.
    """
    if not pm.instruments:
        return None
    octave_counts = [
        len({note.pitch // 12 for note in instrument.notes})
        for instrument in pm.instruments
    ]
    return compute_list_average(octave_counts)

def number_of_notes_per_second(pm):
    """Return the note density of ``pm``: total notes divided by end time.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The total note count divided by ``pm.get_end_time()``, or ``None``
        when there are no instruments or the end time is 0 (the division
        would otherwise raise ``ZeroDivisionError``).
    """
    if not pm.instruments:
        return None
    end_time = pm.get_end_time()
    if end_time == 0:
        return None
    # Counting via len() avoids materialising one flat list of every note.
    note_count = sum(len(instrument.notes) for instrument in pm.instruments)
    return note_count / end_time

def shortest_note_length(pm):
    """Return the duration of the shortest note in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The smallest ``note.end - note.start`` over all notes of all
        instruments, or ``None`` when there are no instruments or none of
        them has a note.
    """
    if not pm.instruments:
        return None
    durations = (
        note.end - note.start
        for instrument in pm.instruments
        for note in instrument.notes
    )
    # default=None avoids the ValueError min() raises on an empty sequence.
    return min(durations, default=None)

def longest_note_length(pm):
    """Return the duration of the longest note in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The largest ``note.end - note.start`` over all notes of all
        instruments, or ``None`` when there are no instruments or none of
        them has a note.
    """
    if not pm.instruments:
        return None
    durations = (
        note.end - note.start
        for instrument in pm.instruments
        for note in instrument.notes
    )
    # default=None avoids the ValueError max() raises on an empty sequence.
    return max(durations, default=None)

def main_key_signature(pm):
    """Return the key number of the first key signature in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        ``key_number`` of the first entry of ``pm.key_signature_changes``,
        or ``None`` when that list is empty. The first entry is simply
        treated as the "main" key.
    """
    key_changes = pm.key_signature_changes
    if not key_changes:
        return None
    first_change = key_changes[0]
    return first_change.key_number

def n_key_changes(pm):
    """Return the number of key signature entries in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The length of ``pm.key_signature_changes``, or ``None`` (not 0)
        when that list is empty.
    """
    key_changes = pm.key_signature_changes
    return len(key_changes) if key_changes else None

def n_tempo_changes(pm):
    """Return the number of tempo change events in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The number of entries in the tempo-change times sequence returned
        by ``pm.get_tempo_changes()``.
    """
    # get_tempo_changes() returns a (times, tempi) pair. Taking len() of the
    # pair itself always yields 2 regardless of the file, so count the
    # entries of the times sequence instead.
    tempo_change_times, _ = pm.get_tempo_changes()
    return len(tempo_change_times)

def average_tempo(pm):
    """Return the rounded tempo estimate of ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        ``round(pm.estimate_tempo())``, or ``None`` when estimating or
        rounding raises any ``Exception`` (best-effort by design).
    """
    try:
        estimate = pm.estimate_tempo()
        rounded = round(estimate)
    except Exception:
        return None
    return rounded

def tempo_changes(pm):
    """Return the result of ``pm.get_tempo_changes()`` wrapped in two lists.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        ``[[changes]]`` where ``changes`` is whatever
        ``pm.get_tempo_changes()`` returned.
    """
    changes = pm.get_tempo_changes()
    return [[changes]]

def main_time_signature(pm):
    """Return the first time signature of ``pm`` as ``'numerator/denominator'``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The first entry of ``pm.time_signature_changes`` formatted as a
        string such as ``'4/4'``, or ``None`` when that list is empty. The
        first entry is simply treated as the "main" time signature.
    """
    signatures = pm.time_signature_changes
    if not signatures:
        return None
    labels = ['/'.join((str(ts.numerator), str(ts.denominator))) for ts in signatures]
    return labels[0]

def n_time_signature_changes(pm):
    """Return the number of time signature entries in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The length of ``pm.time_signature_changes``, or ``None`` (not 0)
        when that list is empty.
    """
    signatures = pm.time_signature_changes
    return len(signatures) if signatures else None

def all_time_signatures(pm):
    """Return every time signature of ``pm`` as ``'numerator/denominator'``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        A one-element list wrapping the list of formatted signatures, in
        the order of ``pm.time_signature_changes``, or ``None`` when that
        list is empty.
    """
    signatures = pm.time_signature_changes
    if not signatures:
        return None
    labels = []
    for ts in signatures:
        labels.append('/'.join((str(ts.numerator), str(ts.denominator))))
    return [labels]

def four_to_the_floor(pm):
    """Tell whether ``pm`` has exactly one time signature and it is 4/4.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        ``True`` when ``pm.time_signature_changes`` holds a single entry
        that formats to ``'4/4'``, ``False`` for any other non-empty list,
        and ``None`` when the list is empty.
    """
    signatures = pm.time_signature_changes
    if not signatures:
        return None
    labels = ['/'.join((str(ts.numerator), str(ts.denominator))) for ts in signatures]
    # Only a lone '4/4' qualifies; other meters (including '2/4') and
    # repeated '4/4' entries do not.
    return labels == ['4/4']

def track_length_in_seconds(pm):
    """Return the end time of ``pm`` as reported by ``pm.get_end_time()``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The value of ``pm.get_end_time()``, unchanged.
    """
    end_time = pm.get_end_time()
    return end_time

def lyrics_number_of_words(pm):
    """Return the number of lyric entries in ``pm``.

    Each entry of ``pm.lyrics`` counts once, whatever its text contains.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The number of lyric entries, or ``None`` when ``pm.lyrics`` is empty.
    """
    if not pm.lyrics:
        return None
    texts = [lyric.text for lyric in pm.lyrics]
    return len(texts)

def lyrics_number_of_unique_words(pm):
    """Return the number of distinct lyric texts in ``pm``.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        The count of distinct ``text`` values among the entries of
        ``pm.lyrics``, or ``None`` when ``pm.lyrics`` is empty.
    """
    if not pm.lyrics:
        return None
    distinct_texts = {lyric.text for lyric in pm.lyrics}
    return len(distinct_texts)

def lyrics_boolean(pm):
    """Tell whether ``pm`` contains any lyrics.

    Args:
        pm: The ``pretty_midi.PrettyMIDI`` object built in ``compute_statistics``.

    Returns:
        ``True`` when ``pm.lyrics`` is non-empty, ``False`` otherwise.
    """
    return bool(pm.lyrics)


def compute_statistics(midi_file):
    """Load one MIDI file and collect all of its statistics in a dict.

    Args:
        midi_file: Path to the MIDI file to analyse.

    Returns:
        A dict mapping each statistic's column name to its value, with the
        file's name (without extension) under ``'md5'``; or ``None`` when
        ``pretty_midi.PrettyMIDI`` raises any ``Exception`` while loading
        the file.
    """
    try:
        pm = pretty_midi.PrettyMIDI(midi_file)
    except Exception:
        return None

    # (column name, extractor) pairs, listed in output column order.
    extractors = (
        # instruments
        ('n_instruments', n_instruments),
        ('n_unique_instruments', n_unique_instruments),
        ('instrument_names', instrument_names),
        ('instrument_families', instrument_families),
        ('number_of_instrument_families', number_of_instrument_families),
        # notes
        ('n_notes', number_of_notes),
        ('n_unique_notes', number_of_unique_notes),
        ('average_n_unique_notes_per_instrument', avg_number_of_unique_notes_per_instrument),
        ('average_note_duration', average_note_duration),
        ('average_note_velocity', average_note_velocity),
        ('average_note_pitch', average_note_pitch),
        ('range_of_note_pitches', range_of_note_pitches),
        ('average_range_of_note_pitches_per_instrument', average_range_of_note_pitches_per_instrument),
        ('number_of_note_pitch_classes', number_of_note_pitch_classes),
        ('average_number_of_note_pitch_classes_per_instrument', average_number_of_note_pitch_classes_per_instrument),
        ('number_of_octaves', number_of_octaves),
        ('average_number_of_octaves_per_instrument', average_number_of_octaves_per_instrument),
        ('number_of_notes_per_second', number_of_notes_per_second),
        ('shortest_note_length', shortest_note_length),
        ('longest_note_length', longest_note_length),
        # key signatures (the "main" key is a hack: just the first entry)
        ('main_key_signature', main_key_signature),
        ('n_key_changes', n_key_changes),
        # tempo (the estimate is a hack: a single rounded value)
        ('n_tempo_changes', n_tempo_changes),
        ('tempo_estimate', average_tempo),
        # time signatures (the "main" one is a hack: just the first entry)
        ('main_time_signature', main_time_signature),
        ('all_time_signatures', all_time_signatures),
        ('four_to_the_floor', four_to_the_floor),
        ('n_time_signature_changes', n_time_signature_changes),
        # track length
        ('track_length_in_seconds', track_length_in_seconds),
        # lyrics
        ('lyrics_nb_words', lyrics_number_of_words),
        ('lyrics_unique_words', lyrics_number_of_unique_words),
        ('lyrics_bool', lyrics_boolean),
    )

    # The file name without extension serves as the track's md5 identifier.
    statistics = {'md5': track_name(midi_file)}
    for column, extractor in extractors:
        statistics[column] = extractor(pm)
    return statistics

In [347]:
# Smoke test: compute the statistics of a single MIDI file and look at them
# both as a raw dict and as the first row of a DataFrame.
# Note: compute_statistics returns None when the file cannot be loaded;
# that case is not handled in this cell.
midi_file = '../data/lmd_full/0/0af5af6d9785c93d65215031077bead3.mid'
statistics = compute_statistics(midi_file)
df = pd.DataFrame(statistics, index=[0])
print(statistics)
# Last expression of the cell: the notebook displays row 0.
df.iloc[0]

{'md5': '0af5af6d9785c93d65215031077bead3', 'n_instruments': 10, 'n_unique_instruments': 8, 'instrument_names': [['Baritone Sax', 'Trumpet', 'Alto Sax', 'Soprano Sax', 'Acoustic Grand Piano', 'Tenor Sax', 'Tuba', 'Trombone']], 'instrument_families': [['Piano', 'Reed', 'Brass']], 'number_of_instrument_families': 3, 'n_notes': 1666, 'n_unique_notes': 41, 'average_n_unique_notes_per_instrument': 12.1, 'average_note_duration': 0.3788140256102438, 'average_note_velocity': 81.35894357743098, 'average_note_pitch': 63.278511404561826, 'range_of_note_pitches': 51, 'average_range_of_note_pitches_per_instrument': 14.9, 'number_of_note_pitch_classes': 12, 'average_number_of_note_pitch_classes_per_instrument': 8.8, 'number_of_octaves': 6, 'average_number_of_octaves_per_instrument': 2.2, 'number_of_notes_per_second': 12.526315789473685, 'shortest_note_length': 0.020833333333328596, 'longest_note_length': 1.5, 'main_key_signature': 7, 'n_key_changes': 1, 'n_tempo_changes': 2, 'tempo_estimate': 187, '

md5                                                                     0af5af6d9785c93d65215031077bead3
n_instruments                                                                                         10
n_unique_instruments                                                                                   8
instrument_names                                       [Baritone Sax, Trumpet, Alto Sax, Soprano Sax,...
instrument_families                                                                 [Piano, Reed, Brass]
number_of_instrument_families                                                                          3
n_notes                                                                                             1666
n_unique_notes                                                                                        41
average_n_unique_notes_per_instrument                                                               12.1
average_note_duration                                  

In [348]:
# Load one example file directly with pretty_midi to inspect its beats.
midi_file = '../data/lmd_full/0/0af5af6d9785c93d65215031077bead3.mid'
pm = pretty_midi.PrettyMIDI(midi_file)

# get all beat locations
beat_times = pm.get_beats()
print(beat_times)

[  0.     0.5    1.     1.5    2.     2.5    3.     3.5    4.     4.5
   5.     5.5    6.     6.5    7.     7.5    8.     8.5    9.     9.5
  10.    10.5   11.    11.5   12.    12.5   13.    13.5   14.    14.5
  15.    15.5   16.    16.5   17.    17.5   18.    18.5   19.    19.5
  20.    20.5   21.    21.5   22.    22.5   23.    23.5   24.    24.5
  25.    25.5   26.    26.5   27.    27.5   28.    28.5   29.    29.5
  30.    30.5   31.    31.5   32.    32.5   33.    33.5   34.    34.5
  35.    35.5   36.    36.5   37.    37.5   38.    38.5   39.    39.5
  40.    40.5   41.    41.5   42.    42.5   43.    43.5   44.    44.5
  45.    45.5   46.    46.5   47.    47.5   48.    48.5   48.75  49.
  49.25  49.5   49.75  50.    50.25  50.5   51.    51.5   52.    52.5
  53.    53.5   54.    54.5   55.    55.5   56.    56.5   57.    57.5
  58.    58.5   59.    59.5   60.    60.5   61.    61.5   62.    62.5
  63.    63.5   64.    66.    66.5   67.    67.5   68.    68.5   69.
  69.5   70.    70.5  

In [349]:
# Compute and print statistics for a handful of sample MIDI files
midi_files = [
    '../data/lmd_full/0/0af5af6d9785c93d65215031077bead3.mid',
    '../data/lmd_full/0/07706096906421577e96b9252f590306.mid',
    '../data/lmd_full/1/1a0d67356a1c4b35c5103774f4cd0f1a.mid',
    '../data/lmd_full/1/1a0cf078518aa7d3c9713b6c0a354a68.mid',
    '../data/lmd_full/1/1a1ad63728ea30834f67b3a39dd7c83c.mid'
]

# compute_statistics returns None for a file that cannot be loaded, so a
# printed "None" marks a failed file.
for f in midi_files:
    statistics = compute_statistics(f)
    print(statistics)

{'md5': '0af5af6d9785c93d65215031077bead3', 'n_instruments': 10, 'n_unique_instruments': 8, 'instrument_names': [['Baritone Sax', 'Trumpet', 'Alto Sax', 'Soprano Sax', 'Acoustic Grand Piano', 'Tenor Sax', 'Tuba', 'Trombone']], 'instrument_families': [['Piano', 'Reed', 'Brass']], 'number_of_instrument_families': 3, 'n_notes': 1666, 'n_unique_notes': 41, 'average_n_unique_notes_per_instrument': 12.1, 'average_note_duration': 0.3788140256102438, 'average_note_velocity': 81.35894357743098, 'average_note_pitch': 63.278511404561826, 'range_of_note_pitches': 51, 'average_range_of_note_pitches_per_instrument': 14.9, 'number_of_note_pitch_classes': 12, 'average_number_of_note_pitch_classes_per_instrument': 8.8, 'number_of_octaves': 6, 'average_number_of_octaves_per_instrument': 2.2, 'number_of_notes_per_second': 12.526315789473685, 'shortest_note_length': 0.020833333333328596, 'longest_note_length': 1.5, 'main_key_signature': 7, 'n_key_changes': 1, 'n_tempo_changes': 2, 'tempo_estimate': 187, '

In [350]:
# Folders of MIDI files to analyse and, in matching order, the CSV file that
# receives each folder's statistics (the two lists are zipped together below).
input_paths = ['../data/music_picks/good_music_artists', '../data/music_picks/electronic_artists']
output_paths = ['../data/music_picks/statistics_good_music.csv', '../data/music_picks/statistics_electronic_artists.csv']

In [351]:
# silence RuntimeWarning messages for the batch run below
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Each input folder is processed on its own and written to its paired CSV.
for input_path, output_path in zip(input_paths, output_paths):
    # Compute statistics for every '*.mid' file directly inside input_path,
    # in parallel using joblib.
    # We do things in parallel because there are tons of files, so it would
    # otherwise take too long!
    statistics = joblib.Parallel(n_jobs=-1, verbose=1)(
        joblib.delayed(compute_statistics)(midi_file)
        for midi_file in glob.glob(os.path.join(input_path, '*.mid')))
    # compute_statistics returns None when a file could not be loaded;
    # filter those out.
    statistics = [s for s in statistics if s is not None]

    # export df to csv (one row per successfully loaded file)
    df = pd.DataFrame(statistics)
    df.to_csv(output_path, index=False)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    2.5s
[Parallel(n_jobs=-1)]: Done 488 tasks      | elapsed:    8.6s
[Parallel(n_jobs=-1)]: Done 1488 tasks      | elapsed:   22.8s
[Parallel(n_jobs=-1)]: Done 2888 tasks      | elapsed:   42.7s
[Parallel(n_jobs=-1)]: Done 4583 out of 4583 | elapsed:  1.1min finished
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  52 tasks      | elapsed:    1.1s
[Parallel(n_jobs=-1)]: Done 352 tasks      | elapsed:    7.3s
[Parallel(n_jobs=-1)]: Done 852 tasks      | elapsed:   17.0s
[Parallel(n_jobs=-1)]: Done 1325 out of 1325 | elapsed:   26.4s finished


In [352]:
# add artist and title information
# load scraped artist genre
def enrich_stats(df_stats, df_artist):
    """Attach artist, title and the artist's dominant genre to the statistics.

    Args:
        df_stats: DataFrame of per-track statistics with an ``md5`` column.
        df_artist: DataFrame with ``md5``, ``artist``, ``title`` and
            ``genre`` columns.

    Returns:
        ``df_stats`` merged on ``md5`` with one ``artist``/``title``/``genre``
        record per track. Both merges use pandas' default join, so tracks
        without a match in ``df_artist`` are not part of the result.
    """
    # How often each (artist, genre) pair occurs in the artist table.
    genre_counts = (
        df_artist.groupby(['artist', 'genre'])
        .size()
        .reset_index(name='counts')
    )

    # Sorting by descending count and keeping the first row per artist
    # leaves each artist's most frequent genre.
    top_genre = (
        genre_counts.sort_values(['artist', 'counts'], ascending=False)
        .drop_duplicates(['artist'])
        .reset_index(drop=True)
    )

    # One (md5, artist, title) record per track.
    tracks = df_artist[['md5', 'artist', 'title']].drop_duplicates(['md5'])

    # Give every track its artist's dominant genre, then join onto the stats.
    tracks_with_genre = tracks.merge(top_genre.drop('counts', axis=1), on='artist')
    return df_stats.merge(tracks_with_genre, on='md5')

In [353]:
# load the per-collection statistics written by the batch run, plus the
# scraped artist/genre table
df_stats_good = pd.read_csv(output_paths[0])
df_stats_electronic = pd.read_csv(output_paths[1])
df_artist = pd.read_csv('../data/mmd_scraped_artist_genre.csv')

# create new df with artist and title information
# (enrich_stats also adds each artist's most frequent genre)
df_stats_electronic = enrich_stats(df_stats_electronic, df_artist)
df_stats_good = enrich_stats(df_stats_good, df_artist)

# add column with source info so rows stay traceable after concatenation
df_stats_electronic['source'] = 'electronic'
df_stats_good['source'] = 'good'

# concatenate dataframes (electronic rows first, fresh 0..n-1 index)
df_stats = pd.concat([df_stats_electronic, df_stats_good], ignore_index=True)

# export df to csv
df_stats.to_csv('../data/music_picks/model_statistics.csv', index=False)

In [354]:
# reload the merged statistics from disk and count the tracks per genre
df_stats = pd.read_csv('../data/music_picks/model_statistics.csv')
df_stats.genre.value_counts()

rock                 2083
pop                  1751
alternative-indie    1261
blues                 285
metal                 121
dance-eletric          94
classic modern rb      63
best of british        45
pride playlist         29
fusion                 26
hits of 2011 2020      22
bluegrass              19
disco                   7
Name: genre, dtype: int64

In [355]:
# calculate percentage of null values per column of df_stats
# (null count per column divided by the number of rows, times 100)
df_stats.isnull().sum() / df_stats.shape[0] * 100

md5                                                     0.000000
n_instruments                                           0.000000
n_unique_instruments                                    0.000000
instrument_names                                        0.000000
instrument_families                                     0.000000
number_of_instrument_families                           0.000000
n_notes                                                 0.000000
n_unique_notes                                          0.000000
average_n_unique_notes_per_instrument                   0.000000
average_note_duration                                   0.000000
average_note_velocity                                   0.000000
average_note_pitch                                      0.000000
range_of_note_pitches                                   0.000000
average_range_of_note_pitches_per_instrument            0.000000
number_of_note_pitch_classes                            0.000000
average_number_of_note_pi

In [356]:
# Inspect one example row (position 10) of the merged statistics.
# NOTE(review): after the CSV round trip the list-valued columns look like
# they come back as strings (e.g. "[['Acoustic Grand Piano', ...") -- confirm
# and parse them before using them as lists.
df_stats.iloc[10]

md5                                                                     98d8d1ea8abc0d662b926edc1a78e550
n_instruments                                                                                         10
n_unique_instruments                                                                                   7
instrument_names                                       [['Acoustic Grand Piano', 'Lead 2 (sawtooth)',...
instrument_families                                    [['Synth Lead', 'Pipe', 'Guitar', 'Piano', 'Ba...
number_of_instrument_families                                                                          5
n_notes                                                                                            12629
n_unique_notes                                                                                        56
average_n_unique_notes_per_instrument                                                               16.2
average_note_duration                                  