In [1]:
import pretty_midi
import music21
import numpy as np
import pandas as pd
import os

# Compatibility shim for pretty_midi, which still references the `np.int` alias that
# NumPy deprecated (1.20) and removed (1.24). Only restore it when it is missing, so an
# older NumPy's own alias is left untouched.
if not hasattr(np, "int"):
    np.int = np.int64

In [4]:
clean_base_path = "../data/clean/"
raw_base_path = "../data/raw/"

# tags.csv holds one metadata row per MIDI file; its midi_filename entries are joined onto raw_base_path to locate each file.
metadata = pd.read_csv(os.path.join(clean_base_path, "tags.csv"))
midi_files = metadata["midi_filename"].map(lambda name: os.path.join(raw_base_path, name)).tolist()

In [5]:
# Parse every MIDI file up front. Later cells pair these objects with metadata rows
# by position, so the order must match `midi_files` (and therefore `metadata`).
all_midi_data = [pretty_midi.PrettyMIDI(path) for path in midi_files]

In [6]:
# Function to compute the pitch histogram
def compute_pitch_histogram(midi_data, include_drums=False):
    """
    Count notes per piano key for a MIDI file.

    Parameters
    ----------
    midi_data : pretty_midi.PrettyMIDI
        Parsed MIDI file (anything exposing ``.instruments`` whose items have
        ``.is_drum`` and ``.notes`` with an integer ``.pitch``).
    include_drums : bool, default False
        Drum-track "pitches" encode percussion sounds rather than musical pitch,
        so they are skipped unless explicitly requested.

    Returns
    -------
    dict
        ``{"key_0": count, ..., "key_87": count}`` (float counts), where ``key_0``
        is MIDI pitch 21 (A0) and ``key_87`` is MIDI pitch 108 (C8). Notes outside
        the 88-key piano range are counted internally but not returned.
    """
    # Count over all 128 MIDI pitches first so out-of-range notes don't need special casing.
    pitch_counts = np.zeros(128)

    for instrument in midi_data.instruments:
        if instrument.is_drum and not include_drums:
            continue
        for note in instrument.notes:
            pitch_counts[note.pitch] += 1

    lowest_piano_pitch = 21  # A0
    piano_counts = pitch_counts[lowest_piano_pitch:lowest_piano_pitch + 88]  # A0..C8 inclusive

    # Key the counts by piano-key index (key_0 .. key_87).
    return {f"key_{i}": piano_counts[i] for i in range(88)}

In [7]:
# Function to compute pitch-related features
def compute_pitch_range(midi_data, include_drums=False):
    """
    Compute the pitch range (highest minus lowest MIDI pitch) of a MIDI file.

    Parameters
    ----------
    midi_data : pretty_midi.PrettyMIDI
        Parsed MIDI file (anything exposing ``.instruments`` whose items have
        ``.is_drum`` and ``.notes`` with an integer ``.pitch``).
    include_drums : bool, default False
        Drum-track "pitches" encode percussion sounds, not musical pitch, so they
        are ignored unless explicitly requested.

    Returns
    -------
    int
        ``max(pitch) - min(pitch)`` over the considered notes, or 0 when there are
        no such notes (instead of raising on an empty sequence).
    """
    pitches = [
        note.pitch
        for instrument in midi_data.instruments
        if include_drums or not instrument.is_drum
        for note in instrument.notes
    ]
    if not pitches:
        return 0
    return max(pitches) - min(pitches)

In [8]:
# Function to compute rhythm-related features
def compute_rhythm_features(midi_data):
    """
    Return ``(tempo, avg_note_duration)`` for a MIDI file.

    ``tempo`` is whatever ``midi_data.estimate_tempo()`` reports; ``avg_note_duration``
    is the arithmetic mean of ``note.end - note.start`` across every note of every
    instrument.
    """
    estimated_tempo = midi_data.estimate_tempo()

    durations = []
    for track in midi_data.instruments:
        durations.extend(n.end - n.start for n in track.notes)

    return estimated_tempo, np.mean(durations)

In [9]:
# Function to compute dynamic-related features
def compute_dynamic_features(midi_data):
    """
    Compute average velocity and velocity range of a MIDI file.

    Parameters
    ----------
    midi_data : pretty_midi.PrettyMIDI
        Parsed MIDI file (anything exposing ``.instruments`` whose items have
        ``.notes`` with a numeric ``.velocity``).

    Returns
    -------
    tuple
        ``(avg_velocity, velocity_range)``, i.e. the mean of all note velocities and
        ``max - min`` of them. With no notes the mean is undefined and is returned as
        NaN with a range of 0, instead of warning and then raising on ``max([])``.
    """
    velocities = [note.velocity for instrument in midi_data.instruments for note in instrument.notes]
    if not velocities:
        return np.nan, 0

    avg_velocity = np.mean(velocities)
    velocity_range = max(velocities) - min(velocities)
    return avg_velocity, velocity_range

In [10]:
def fetch_key_from_midi(midi_file):
    """
    Estimate the musical key of a MIDI file using music21.

    Parameters
    ----------
    midi_file : str
        Path to a MIDI file; it is parsed with ``music21.converter.parse``.

    Returns
    -------
    The result of music21's ``analyze('key')`` on the parsed stream. This is a key
    estimated from the note content (presumably a ``music21.key.Key``), not a key
    signature read from the file.

    NOTE(review): not called by any of the visible cells.
    """
    parsed_stream = music21.converter.parse(midi_file)
    return parsed_stream.analyze('key')

In [11]:
# Master function to extract all MIDI features
def extract_midi_features(midi_data):
    """
    Gather all per-file features into one flat dict.

    Key order: the 88 ``key_*`` histogram counts, then ``pitch_range``, ``tempo``,
    ``avg_note_duration``, ``avg_velocity`` and ``velocity_range``.
    """
    histogram = compute_pitch_histogram(midi_data)
    pitch_range = compute_pitch_range(midi_data)
    tempo, avg_note_duration = compute_rhythm_features(midi_data)
    avg_velocity, velocity_range = compute_dynamic_features(midi_data)

    return {
        **histogram,
        "pitch_range": pitch_range,
        "tempo": tempo,
        "avg_note_duration": avg_note_duration,
        "avg_velocity": avg_velocity,
        "velocity_range": velocity_range,
    }

In [12]:
# One feature dict per parsed MIDI file, in the same order as the metadata rows.
feature_records = [extract_midi_features(midi_data) for midi_data in all_midi_data]
assert len(feature_records) == len(metadata), "expected exactly one parsed MIDI file per metadata row"

# Build all feature columns at once and attach them to the metadata. Assigning row by row
# with .loc inserts ~90 new columns one at a time, which is slow and fragments the frame.
features = pd.concat(
    [metadata, pd.DataFrame(feature_records, index=metadata.index)],
    axis=1,
)

features.head()

Unnamed: 0,split,midi_filename,duration,tags,key_0,key_1,key_2,key_3,key_4,key_5,...,key_83,key_84,key_85,key_86,key_87,pitch_range,tempo,avg_note_duration,avg_velocity,velocity_range
0,train,2017/MIDI-Unprocessed_066_PIANO066_MID--AUDIO-...,464.649433,modern,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,75.0,195.593838,0.371628,67.903488,112.0
1,train,2004/MIDI-Unprocessed_XP_21_R1_2004_01_ORIG_MI...,872.640588,post-romantic,0.0,2.0,1.0,3.0,1.0,6.0,...,0.0,0.0,0.0,0.0,0.0,81.0,202.21236,0.291267,62.954085,102.0
2,validation,2006/MIDI-Unprocessed_17_R1_2006_01-06_ORIG_MI...,397.857508,romantic,0.0,0.0,0.0,4.0,11.0,1.0,...,1.0,13.0,0.0,7.0,0.0,83.0,204.545455,0.13026,65.008211,119.0
3,validation,2009/MIDI-Unprocessed_07_R1_2009_04-05_ORIG_MI...,400.557826,romantic,0.0,0.0,2.0,0.0,2.0,0.0,...,0.0,0.0,0.0,0.0,0.0,69.0,199.021019,0.379225,60.342821,108.0
4,train,2013/ORIG-MIDI_03_7_8_13_Group__MID--AUDIO_19_...,563.904351,romantic,0.0,0.0,1.0,1.0,2.0,4.0,...,0.0,0.0,0.0,0.0,0.0,77.0,197.028327,0.252819,70.214286,110.0


In [13]:
# clean_base_path is defined in the data-loading cell; it is not redefined here to keep one source of truth.
output_file = os.path.join(clean_base_path, "features.csv")

# Drop the raw-file path column before export. Use a returned copy rather than inplace=True.
# errors="ignore" keeps the cell safe to re-run once the column is already gone.
features = features.drop(columns=["midi_filename"], errors="ignore")
features.to_csv(output_file, index=False)

features.head()

Unnamed: 0,split,duration,tags,key_0,key_1,key_2,key_3,key_4,key_5,key_6,...,key_83,key_84,key_85,key_86,key_87,pitch_range,tempo,avg_note_duration,avg_velocity,velocity_range
0,train,464.649433,modern,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,75.0,195.593838,0.371628,67.903488,112.0
1,train,872.640588,post-romantic,0.0,2.0,1.0,3.0,1.0,6.0,17.0,...,0.0,0.0,0.0,0.0,0.0,81.0,202.21236,0.291267,62.954085,102.0
2,validation,397.857508,romantic,0.0,0.0,0.0,4.0,11.0,1.0,0.0,...,1.0,13.0,0.0,7.0,0.0,83.0,204.545455,0.13026,65.008211,119.0
3,validation,400.557826,romantic,0.0,0.0,2.0,0.0,2.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,69.0,199.021019,0.379225,60.342821,108.0
4,train,563.904351,romantic,0.0,0.0,1.0,1.0,2.0,4.0,21.0,...,0.0,0.0,0.0,0.0,0.0,77.0,197.028327,0.252819,70.214286,110.0
