<a href="https://colab.research.google.com/github/carson-edmonds/AAI-511_Final_Project/blob/main/MSAAI_511_TEAM_3_Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MSAAI 511 Neural Networks and Deep Learning
# University of San Diego
# Summer 2023 Section 02 Final Project Report
# Professor: Dr. Mirsardar Esnaeilli
## Final Project Team 3
## Topic: Music Genre and Composer Classification Using Deep Learning

## Project Team GitHub: https://github.com/carson-edmonds/AAI-511_Final_Project
## August 14, 2023

# 1. Data Pre-processing:

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the Drive dataset paths used below resolve (Colab only).
drive.mount('/content/drive')

In [None]:
#@title Converting .midi files:
import os
import glob

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import string
from tqdm import tqdm
# Seed NumPy's legacy global RNG so np.random-based steps below are reproducible.
np.random.seed(42)  # makes the randomness deterministic
# Render matplotlib figures inline under each cell.
%matplotlib inline
# Notebook-wide plot defaults: wide figures with a grid.
plt.rcParams['figure.figsize'] = (15, 5)
plt.rcParams['axes.grid'] = True

!pip install mido --quiet
import mido
import string

def msg2dict(msg):
    """Parse the text form of a mido message into a field dict.

    Returns ``[fields, on_]``:
      * ``fields`` always has 'time'; when the message is a note message it also
        has 'note' and 'velocity'. Each value is read from the LAST ``key=...``
        token in ``msg``, with all punctuation stripped before ``int()``.
      * ``on_`` is True if 'note_on' occurs in ``msg``, False if 'note_off'
        occurs, otherwise None (plain substring checks on the text).
    """
    drop_punctuation = str.maketrans({ch: None for ch in string.punctuation})

    def read_int(key):
        # e.g. "time=0>" -> "0>" -> "0"; a trailing '>' of meta messages is removed
        token = msg[msg.rfind(key):].split(' ')[0]
        return int(token.split('=')[1].translate(drop_punctuation))

    if 'note_on' in msg:
        on_ = True
    elif 'note_off' in msg:
        on_ = False
    else:
        on_ = None

    fields = {'time': read_int('time')}
    if on_ is not None:
        fields['note'] = read_int('note')
        fields['velocity'] = read_int('velocity')
    return [fields, on_]

def switch_note(last_state, note, velocity, on_=True):
    """Return a copy of an 88-key state with a single key updated.

    Args:
        last_state: current 88-element state, or None to start from all zeros.
            It is copied (via ``.copy()``), never modified.
        note: MIDI note number; only 21..108 (the 88 piano keys) are applied,
            anything else leaves the copied state unchanged.
        velocity: value stored for the key when ``on_`` is truthy.
        on_: True for note-on; when falsy the key is set to 0.
    """
    if last_state is None:
        state = [0] * 88
    else:
        state = last_state.copy()
    # outside the piano range: nothing to change
    if not (21 <= note <= 108):
        return state
    state[note - 21] = velocity if on_ else 0
    return state

def get_new_state(new_msg, last_state):
    """Apply one MIDI message to an 88-key state.

    Returns ``[state_after_message, delta_time_of_message]``. Messages that
    msg2dict does not classify as note_on/note_off return ``last_state``
    itself (same object, not a copy).
    """
    parsed, on_ = msg2dict(str(new_msg))
    if on_ is None:
        next_state = last_state
    else:
        next_state = switch_note(last_state, note=parsed['note'],
                                 velocity=parsed['velocity'], on_=on_)
    return [next_state, parsed['time']]

def track2seq(track):
    """Render one MIDI track as a list of 88-key velocity frames, one per tick.

    A message's delta time is the number of ticks the state *before* that
    message lasted, so the previous state is repeated ``delta`` times and the
    message is then applied. The first message's own delta time is not
    rendered, and the state after the final message is not emitted.

    Args:
        track: a sequence of MIDI messages (e.g. a ``mido.MidiTrack``).

    Returns:
        list of 88-element lists (one value per piano key, MIDI notes 21..108);
        an empty list when the track has no messages.
    """
    # piano has 88 notes, corresponding to note id 21 to 108, any note out of the id range will be ignored
    result = []
    if len(track) == 0:
        # nothing to render (previously raised IndexError on track[0])
        return result
    # get_new_state stringifies the message itself; the first delta time is discarded
    last_state, _ = get_new_state(track[0], [0] * 88)
    for msg in track[1:]:
        new_state, delta = get_new_state(msg, last_state)
        if delta > 0:
            result += [last_state] * delta
        last_state = new_state
    return result

def mid2arry(mid, min_msg_pct=0.1):
    """Convert a parsed MIDI file into a single (time_steps, 88) piano-roll array.

    Every track with more than ``min_msg_pct`` times the longest track's message
    count is rendered with ``track2seq``. The per-track rolls are zero-padded to
    a common length and merged by taking the element-wise maximum. Leading and
    trailing all-silent frames are then trimmed.

    Args:
        mid: parsed MIDI file exposing ``.tracks`` (e.g. ``mido.MidiFile``).
        min_msg_pct: fraction of the longest track's message count that a track
            must exceed to be included.

    Returns:
        np.ndarray of shape (n_steps, 88). n_steps is 0 when there are no
        usable tracks or no frame has any non-zero value.
    """
    empty = np.zeros((0, 88), dtype=int)
    tracks_len = [len(tr) for tr in mid.tracks]
    if not tracks_len:
        return empty
    min_n_msg = max(tracks_len) * min_msg_pct
    # convert each sufficiently long track to a nested list of 88-wide frames
    all_arys = [track2seq(tr) for tr in mid.tracks if len(tr) > min_n_msg]
    max_len = max((len(ary) for ary in all_arys), default=0)
    if max_len == 0:
        return empty
    # make all nested lists the same length by padding with silent frames
    padded = [ary + [[0] * 88] * (max_len - len(ary)) for ary in all_arys]
    # merge tracks: keep the largest value per (time step, key)
    merged = np.array(padded).max(axis=0)
    # trim: remove consecutive 0s in the beginning and at the end
    sounding = np.where(merged.sum(axis=1) > 0)[0]
    if sounding.size == 0:
        return empty
    # +1 so the last sounding frame is kept (the old slice end excluded it)
    return merged[sounding[0]: sounding[-1] + 1]


# The wrapper function to load a MIDI file and extract features
def extract_features(file_path):
    """Load ``file_path`` with mido (clip=True) and return mid2arry's piano roll for it."""
    return mid2arry(mido.MidiFile(file_path, clip=True))

In [None]:
# Specify your directory containing MIDI folders
main_dir = '/content/drive/MyDrive/Colab Notebooks/Data/Composer_Dataset/NN_midi_files_extended/train/'

# Prepare lists to store filenames and lengths
filenamelist = []
lengths = []

# Collect every MIDI path first. Sorting makes the DataFrame row order
# reproducible instead of depending on os.walk's filesystem order, and lets a
# single progress bar cover all files rather than one bar per sub-directory.
midi_paths = sorted(
    os.path.join(dirpath, filename)
    for dirpath, _, filenames in os.walk(main_dir)
    for filename in filenames
    # accept both common MIDI extensions, case-insensitively
    if filename.lower().endswith(('.mid', '.midi'))
)

for full_file_path in tqdm(midi_paths):
    mid = mido.MidiFile(full_file_path, clip=True)
    # Compute the total number of messages in all tracks
    total_msgs = sum(len(track) for track in mid.tracks)
    filenamelist.append(os.path.basename(full_file_path))
    lengths.append(total_msgs)

# Create a DataFrame
df = pd.DataFrame({
    'filename': filenamelist,
    'length': lengths
})

In [None]:
#@title Processing Feature Extracted Data:
import pandas as pd
from sklearn.preprocessing import LabelEncoder

#read csv files
def csv_read(path):
  """Load the CSV at ``path`` into a DataFrame, print its (rows, cols) shape, and return it."""
  frame = pd.read_csv(path)
  print("Shape of the dataframe (row, col):", frame.shape, "\r\n")
  return frame

#encode object datatype data with labelencoder
def df_encode(df):
  """Label-encode every object-dtype column of ``df`` in place.

  One LabelEncoder's ``fit_transform`` is applied column by column, so it is
  refitted on each column and every column gets its own integer codes.
  Non-object columns are left untouched.

  Args:
    df: DataFrame to encode; it is modified in place.

  Returns:
    The same DataFrame object.
  """
  obj_cols = list(df.select_dtypes(include='object'))
  if not obj_cols:
    # nothing to encode; avoids applying the encoder to a column-less frame
    return df
  df[obj_cols] = df[obj_cols].apply(LabelEncoder().fit_transform)
  return df

#reading all data subsets and encoding for model use
def csv_to_df(train_path, val_path, test_path):
  """Read the train/val/test feature CSVs and label-encode them consistently.

  The three splits are concatenated before encoding, so every object column
  uses one shared mapping: a given category string (e.g. a composer name) gets
  the same integer code in all three splits. Encoding each split separately
  derives codes from that split's own categories. The codes then disagree
  whenever the splits do not contain exactly the same categories.

  Args:
    train_path, val_path, test_path: paths of the CSV file for each split.

  Returns:
    (train_df, val_df, test_df), each with a fresh 0..n-1 index and with
    object columns replaced by integer codes.
  """
  print("Train set:")
  train_df = csv_read(train_path)
  print("Val set:")
  val_df = csv_read(val_path)
  print("Test set:")
  test_df = csv_read(test_path)

  # assumes the three CSVs share the same columns; otherwise concat fills the gaps with NaN
  n_train, n_val = len(train_df), len(val_df)
  combined = df_encode(pd.concat([train_df, val_df, test_df], ignore_index=True))
  train_df = combined.iloc[:n_train].reset_index(drop=True)
  val_df = combined.iloc[n_train:n_train + n_val].reset_index(drop=True)
  test_df = combined.iloc[n_train + n_val:].reset_index(drop=True)
  return train_df, val_df, test_df

# Feature-extracted CSVs, one per split (train, val, test).
# Edit these paths to match your Google Drive layout.
train_path, val_path, test_path = (
    '/content/drive/MyDrive/AAI 511/Final Project/midi_features_traindataset_raw.csv',
    '/content/drive/MyDrive/AAI 511/Final Project/midi_features_valdataset_raw.csv',
    '/content/drive/MyDrive/AAI 511/Final Project/midi_features_testdataset_raw.csv',
)

train_df, val_df, test_df = csv_to_df(train_path, val_path, test_path)

In [None]:
#@title Exploratory Data Analysis with prettymidi library:
#prettymidi can be used for music data analysis
!pip install pretty_midi --quiet
import pretty_midi
import librosa.display

# Example piece used for the pretty_midi EDA plots below.
# NOTE(review): this path is under the 'AAI 511/Final Project' Drive folder, while the
# length scan earlier uses 'Colab Notebooks/Data' — confirm both locations exist.
pm = pretty_midi.PrettyMIDI('/content/drive/MyDrive/AAI 511/Final Project/Composer_Dataset/NN_midi_files_extended/train/bartok/bartok396.mid')

#plot Piano roll of song
def plot_piano_roll(pm, start_pitch, end_pitch, fs=100):
    """Draw rows start_pitch..end_pitch-1 of ``pm``'s piano roll with librosa.

    The roll is sampled at ``fs`` frames per second. Passing hop_length=1 with
    sr=fs makes each column one 1/fs-second frame on the time axis, and fmin
    anchors the note axis at ``start_pitch``.
    """
    roll = pm.get_piano_roll(fs)
    lowest_hz = pretty_midi.note_number_to_hz(start_pitch)
    librosa.display.specshow(roll[start_pitch:end_pitch],
                             hop_length=1, sr=fs, x_axis='time', y_axis='cqt_note',
                             fmin=lowest_hz)

plt.figure(figsize=(12, 4))
plot_piano_roll(pm, 24, 84)
plt.title('Piano roll (MIDI pitches 24-83)')
plt.show()

# Plot a pitch class distribution on its own figure. Without a new figure the
# bars were drawn onto the piano-roll axes above. normalize=True makes the bars
# sum to 1, matching the 'Proportion' label (the default is unnormalized).
plt.figure()
plt.bar(np.arange(12), pm.get_pitch_class_histogram(normalize=True))
plt.xticks(np.arange(12), ['C', '', 'D', '', 'E', 'F', '', 'G', '', 'A', '', 'B'])
plt.xlabel('Note')
plt.ylabel('Proportion')
plt.title('Pitch class distribution')
plt.show()

In [None]:
#@title Exploratory Data Analysis with music21 library:
#Music21 can be used for music data analysis
import music21
from music21 import *

#open/read a midi file
def open_midi(midi_path, remove_drums):
    """Read a MIDI file with music21 and convert it to a music21 stream.

    Args:
        midi_path: path to the MIDI file.
        remove_drums: when truthy, drop every track event whose ``channel`` is
            10 before conversion. Channel 10 is presumably the General MIDI
            percussion channel under music21's 1-16 channel numbering — confirm.

    Returns:
        The stream produced by ``midi.translate.midiFileToStream``.
    """
    mf = midi.MidiFile()
    mf.open(midi_path)
    try:
        mf.read()
    finally:
        # close the file handle even if parsing fails
        mf.close()
    if remove_drums:
        for track in mf.tracks:
            track.events = [ev for ev in track.events if ev.channel != 10]

    return midi.translate.midiFileToStream(mf)

# Load one Bartok piece, with channel-10 events removed, as the example for the music21 EDA below.
# (The former bare `base_midi` line mid-cell displayed nothing, so it was removed.)
base_midi = open_midi('/content/drive/MyDrive/AAI 511/Final Project/Composer_Dataset/NN_midi_files_extended/train/bartok/bartok396.mid', True)

#list instruments within song
def list_instruments(midi):
    """Print a header followed by the ``partName`` of every part in a score.

    Args:
        midi: a music21 score, or any object whose ``parts.stream()`` yields
            items with a ``partName`` attribute. Inside this function the name
            shadows the ``music21.midi`` module.
    """
    part_stream = midi.parts.stream()
    print("List of instruments found on MIDI file:")
    for part in part_stream:
        print(part.partName)

# Print the part names of the example piece loaded above.
list_instruments(base_midi)


import matplotlib.lines as mlines

#extract notes from song
def extract_notes(midi_part):
    """Collect pitch numbers from every Note and Chord in ``midi_part``.

    Returns ``(pitches, owners)`` as two parallel lists. ``pitches`` holds each
    pitch's ``ps`` value clamped at 0.0, and ``owners`` holds the Note or Chord
    it came from. A chord therefore contributes one entry per pitch, all
    pointing to the same chord. Other note-like elements are skipped.
    """
    pitches = []
    owners = []
    for element in midi_part.flat.notes:
        if isinstance(element, note.Note):
            element_pitches = [element.pitch]
        elif isinstance(element, chord.Chord):
            element_pitches = element.pitches
        else:
            continue
        for p in element_pitches:
            pitches.append(max(0.0, p.ps))
            owners.append(element)

    return pitches, owners

def print_parts_countour(midi):
    """Scatter-plot every part's notes: offset on x, pitch number on y.

    Each part gets its own scatter series; parts with no pitches are skipped.
    Faint red horizontal lines are drawn at C1..C9 wherever that C lies
    strictly between the lowest and highest plotted pitch.
    """
    fig = plt.figure(figsize=(12, 5))
    ax = fig.add_subplot(1, 1, 1)
    lowest_pitch = pitch.Pitch('C10').ps
    highest_pitch = 0
    last_offset = 0

    # Plotting notes, one series per part.
    for part_idx in range(len(midi.parts)):
        part_notes = midi.parts[part_idx].flat.notes
        pitches_y, owners = extract_notes(part_notes)
        if not pitches_y:
            continue

        offsets_x = [element.offset for element in owners]
        ax.scatter(offsets_x, pitches_y, alpha=0.6, s=7)

        lowest_pitch = min(lowest_pitch, min(pitches_y))
        highest_pitch = max(highest_pitch, max(pitches_y))
        last_offset = max(last_offset, max(offsets_x))

    # Octave guide lines at each C inside the plotted pitch range.
    for octave in range(1, 10):
        c_pitch = pitch.Pitch('C{0}'.format(octave)).ps
        if lowest_pitch < c_pitch < highest_pitch:
            ax.add_line(mlines.Line2D([0, last_offset], [c_pitch, c_pitch], color='red', alpha=0.1))

    plt.ylabel("Note index (each octave has 12 notes)")
    plt.xlabel("Number of quarter notes (beats)")
    plt.title('Voices motion approximation, each color is a different instrument, red lines show each octave')
    plt.show()

# Contour plot of the opening of the piece only: measures numbered 0 through 6.
# NOTE(review): the intent is "first 6 measures"; confirm whether music21's
# measures(0, 6) includes measure 6 and whether the piece has a measure 0.
print_parts_countour(base_midi.measures(0, 6))

# Histogram of pitch-class counts over the whole piece.
base_midi.plot('histogram', 'pitchClass', 'count')

# Merge all parts into chords (chordify) and wrap the result in a new Score at
# offset 0 so it can be plotted with print_parts_countour.
temp_midi_chords = base_midi.chordify()
temp_midi = stream.Score()
temp_midi.insert(0, temp_midi_chords)

# Printing merged tracks.
print_parts_countour(temp_midi)

# Dump the chordified content of measures 0 through 1 as text.
temp_midi_chords.measures(0, 1).show("text")

In [None]:
#@title Exploratory Data Analysis with dataprep.eda:
!pip install dataprep --quiet
from dataprep.eda import create_report
# Generate dataprep's automated EDA report for the label-encoded training features loaded above.
create_report(train_df)

# 2. Feature Extraction:

# 3. Model Building:

# 4. Model Training:

# 5. Model Evaluation:

# 6. Model Optimization:

# 7. Conclusion and Discussion