# Chroma histogram extraction from MIDI files

This notebook uses the MIDIFileProcessor class to extract chroma histograms from provided MIDI files.

It is set up to expect the cleaned subset of the Lakh MIDI dataset: https://www.kaggle.com/datasets/imsparsh/lakh-midi-clean

Extracted chroma histograms are saved as CSV files in the `chords_datasets` directory, located in the same directory as this notebook.

In [None]:
import numpy as np
import pandas as pd
import pretty_midi as pm
import os
import time
import joblib
import matplotlib.pyplot as plt
from midi_file_processor import MIDIFileProcessor
from datetime import timedelta

#### Load key classifier, define dataset path, excluded file names and dataframe columns, initialise processor object

In [None]:
# Root of the cleaned Lakh MIDI dataset; its entries are treated as artist folders.
dataset_root_dir = "../datasets/Lakh_clean_MIDI/"

# File names intended to be ignored (not referenced by the code visible in this notebook).
EXCLUDED = [
    ".DS_Store",
    "midiindx.htm",
    ".gitattributes",
    "LICENSE",
    "README.md",
]

# Output columns: the melody chroma followed by the twelve chroma bins "0".."11".
DF_COLUMNS = ["melody_chroma", *map(str, range(12))]

# NOTE: joblib.load unpickles the file, so only load model files from a trusted source.
key_classifier = joblib.load(
    "./key_signature_classifier/key_classification_svc_model_2023-04-09_13-47-19.pkl"
)

# The processor is constructed with the key classifier and reused for every MIDI file.
processor = MIDIFileProcessor(key_classifier)

#### Specify which artist folders to process using index

In [None]:
# Extract chroma histograms for a slice of the artist folders and save them as one CSV.
start_time = time.time()

# os.listdir returns entries in arbitrary order; sorting makes the index range below
# select the same artist folders on every run and on every machine.
artist_folders = sorted(os.listdir(dataset_root_dir))
artist_folders_count = len(artist_folders)

# Half-open index range [AF_START_IDX, AF_END_IDX) of artist folders to process.
AF_START_IDX = 0
AF_END_IDX   = artist_folders_count

midi_files_processed = 0

# Per-file frames are collected here and concatenated once after the loop;
# concatenating inside the loop re-copies all accumulated rows for every file.
chords_frames = []

af_subsection = artist_folders[AF_START_IDX:AF_END_IDX]

for af_idx, artist_folder in enumerate(af_subsection):
    artist_folder_path = os.path.join(dataset_root_dir, artist_folder)
    # Skip non-directories
    if not os.path.isdir(artist_folder_path):
        continue
    print(f"Artist folder {af_idx+AF_START_IDX+1} of {artist_folders_count}: {artist_folder}")

    artist_midi_files = os.listdir(artist_folder_path)
    artist_midi_files_count = len(artist_midi_files)

    # Loop through files in artist folder
    for mf_idx, midi_file_name in enumerate(artist_midi_files):
        # Only files with a .mid extension (case-insensitive) are processed
        if not midi_file_name.lower().endswith(".mid"):
            continue
        try:
            print(f"  MIDI file {mf_idx+1} of {artist_midi_files_count}: {midi_file_name}")
            midi_file_path = os.path.join(artist_folder_path, midi_file_name)
            midi_file = pm.PrettyMIDI(midi_file_path)

            # Find melody instrument
            melody_instrument = processor.get_melody_instrument(midi_file)

            # Get key signatures
            key_signatures = processor.get_key_signatures(midi_file)

            # Get chords and keep the resulting frame for the final concatenation
            midi_file_chords_array = processor.get_chords_as_array(midi_file, melody_instrument, key_signatures)
            midi_file_chords_df = pd.DataFrame.from_records(midi_file_chords_array, columns=DF_COLUMNS, coerce_float=True)
            # Empty frames contribute no rows, so they are not collected
            if not midi_file_chords_df.empty:
                chords_frames.append(midi_file_chords_df)

            # Update counter of total files processed
            midi_files_processed += 1

        except Exception as e:
            # Best effort: a file that cannot be parsed or processed is reported and skipped
            print(f"    Error processing {midi_file_name}: {e.__class__}, skipping file.")

    print("-----")

# Build the final frame in one pass; fall back to an empty frame with the expected
# columns when nothing was extracted so that later cells still find `chords_df`.
if chords_frames:
    chords_df = pd.concat(chords_frames, ignore_index=True)
else:
    chords_df = pd.DataFrame(columns=DF_COLUMNS)

print(f"COMPLETED dataset processing from artist folder {AF_START_IDX}-{AF_END_IDX-1} of {artist_folders_count} in {timedelta(seconds=(round(time.time() - start_time, 3)))}")
print(f"Total MIDI files processed: {midi_files_processed}")

# Save section to file
csv_filepath = f"./chords_datasets/chords_dataset_idx-{str(AF_START_IDX).rjust(4, '0')}-{str(AF_END_IDX-1).rjust(4, '0')}.csv"
# Create the output directory if needed so the save cannot fail after a long extraction run
os.makedirs(os.path.dirname(csv_filepath), exist_ok=True)
chords_df.to_csv(csv_filepath, index=False)
print(f"Chords saved at: {csv_filepath}")

#### View dataset inline

In [None]:
# Leaving the DataFrame as the cell's final expression makes the notebook render it as output.
chords_df

#### Plot quantities of each chroma pitch in the extracted chroma histograms dataset

In [None]:
# Sum every chroma column over the whole dataset (all columns except the melody column).
chroma_sums = []
for col_name, series in chords_df.items():
    if col_name == "melody_chroma":
        continue
    col_total = series.sum()  # computed once, reused for both the printout and the plot
    print(f"{col_name}:\t{col_total}")
    chroma_sums.append(col_total)

# One label per chroma column, in column order ("T" presumably stands for the tonic - confirm).
x_label = ["T", "m2", "M2", "m3", "M3", "P4", "a4", "P5", "m6", "M6", "m7", "M7"]

# Normalise the totals to proportions; an empty dataset sums to zero, so guard the
# division instead of plotting NaN values.
chroma_totals = np.array(chroma_sums, dtype=float)
grand_total = chroma_totals.sum()
if grand_total > 0:
    chroma_proportions = chroma_totals / grand_total
else:
    chroma_proportions = np.zeros_like(chroma_totals)

fig, ax = plt.subplots()
ax.bar(x_label, chroma_proportions)
ax.set_xlabel("Chroma pitch")
ax.set_ylabel("Proportion of summed chroma values")
ax.set_title("Distribution of chroma pitches in the extracted dataset")
plt.show()