## Model Training

We'll attempt to train a multi-class support vector machine on our feature-engineered data.

### Retrieve feature-engineered data into pandas dataframe

In [1]:
import config
import mysql
import mysql.connector
from mysql.connector import errorcode

# DB FUNCTIONS

def connect(db_name=None):
    if db_name:
        cnx = mysql.connector.connect(
            host = config.rds_host,
            user = config.rds_user,
            passwd = config.rds_password,
            database = db_name)
    else:
        cnx = mysql.connector.connect(
            host = config.rds_host,
            user = config.rds_user,
            passwd = config.rds_password)
    cursor = cnx.cursor()
    return cursor, cnx


In [2]:
import pandas as pd

cur, cnx = connect('instruments')
cur.execute('''SELECT * FROM harmonic_content''')

df = pd.DataFrame(cur.fetchall())
df.columns = [i[0] for i in cur.description]
cnx.close()

In [3]:
df

Unnamed: 0,sample_id,frequency_id,amplitude
0,1,0,0.003238
1,1,1,0.003516
2,1,2,0.003653
3,1,3,0.004583
4,1,4,0.005378
...,...,...,...
4786171,1847,2043,0.002954
4786172,1847,2044,0.002266
4786173,1847,2045,0.002954
4786174,1847,2046,0.001999


In [4]:
data = df.pivot(index='sample_id', columns='frequency_id', values='amplitude')
data

frequency_id,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,0.003238,0.003516,0.003653,0.004583,0.005378,0.005542,0.005779,0.005876,0.006885,0.006476,...,0.076767,0.073233,0.073390,0.071249,0.067423,0.074807,0.071793,0.074314,0.065015,0.060249
2,0.002859,0.004412,0.006860,0.010949,0.013922,0.015170,0.017266,0.020821,0.021021,0.021095,...,0.060817,0.060192,0.063314,0.061040,0.062658,0.058267,0.060641,0.061230,0.057176,0.052379
3,0.470771,0.314912,0.120817,0.095654,0.085359,0.060197,0.037835,0.032554,0.027049,0.026489,...,0.003026,0.002823,0.002957,0.003481,0.003591,0.003186,0.003226,0.003548,0.003635,0.003283
4,0.005541,0.009706,0.013444,0.023435,0.032274,0.034287,0.033264,0.040268,0.040223,0.034025,...,0.041233,0.033820,0.032133,0.029678,0.030990,0.035940,0.036911,0.036216,0.038078,0.038489
5,0.003521,0.003417,0.003128,0.003775,0.003662,0.004113,0.004573,0.005411,0.006391,0.005843,...,0.042432,0.043850,0.041825,0.037201,0.036231,0.039358,0.046034,0.049473,0.047220,0.040782
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2333,0.005842,0.006696,0.005504,0.008009,0.010848,0.008933,0.005405,0.005450,0.004055,0.004462,...,0.000296,0.000293,0.000254,0.000225,0.000219,0.000230,0.000214,0.000199,0.000189,0.000159
2334,0.011724,0.013896,0.011100,0.013689,0.022415,0.020919,0.010713,0.008141,0.006795,0.008993,...,0.000432,0.000410,0.000409,0.000413,0.000424,0.000351,0.000316,0.000306,0.000334,0.000327
2335,0.014983,0.020384,0.016735,0.018972,0.027308,0.021950,0.012783,0.012019,0.008091,0.008059,...,0.000304,0.000357,0.000451,0.000693,0.000934,0.000752,0.000527,0.000734,0.000921,0.000812
2336,0.010699,0.013480,0.009601,0.013410,0.019354,0.015173,0.010070,0.008354,0.005919,0.008271,...,0.000284,0.000300,0.000307,0.000298,0.000252,0.000240,0.000230,0.000376,0.000763,0.000884


In [5]:
import pandas as pd

cur, cnx = connect('instruments')
cur.execute('''SELECT * FROM samples''')

sample_info = pd.DataFrame(cur.fetchall())
sample_info.columns = [i[0] for i in cur.description]
cnx.close()

In [6]:
sample_info.set_index('sample_id', inplace=True)
sample_info

Unnamed: 0_level_0,instrument_name,note,expression,source,file_extension
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,Flute,Gb6,vib,Iowa2012,aif
2,Flute,Bb5,vib,Iowa2012,aif
3,Flute,G4,vib,Iowa2012,aif
4,Flute,Eb5,vib,Iowa2012,aif
5,Flute,A6,vib,Iowa2012,aif
...,...,...,...,...,...
2333,Crotale,Gb6,,Iowa2012,aif
2334,Crotale,A7,,Iowa2012,aif
2335,Crotale,E6,,Iowa2012,aif
2336,Crotale,Ab6,,Iowa2012,aif


In [7]:
data2 = data.join(sample_info)
data2

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,2043,2044,2045,2046,2047,instrument_name,note,expression,source,file_extension
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,0.003238,0.003516,0.003653,0.004583,0.005378,0.005542,0.005779,0.005876,0.006885,0.006476,...,0.074807,0.071793,0.074314,0.065015,0.060249,Flute,Gb6,vib,Iowa2012,aif
2,0.002859,0.004412,0.006860,0.010949,0.013922,0.015170,0.017266,0.020821,0.021021,0.021095,...,0.058267,0.060641,0.061230,0.057176,0.052379,Flute,Bb5,vib,Iowa2012,aif
3,0.470771,0.314912,0.120817,0.095654,0.085359,0.060197,0.037835,0.032554,0.027049,0.026489,...,0.003186,0.003226,0.003548,0.003635,0.003283,Flute,G4,vib,Iowa2012,aif
4,0.005541,0.009706,0.013444,0.023435,0.032274,0.034287,0.033264,0.040268,0.040223,0.034025,...,0.035940,0.036911,0.036216,0.038078,0.038489,Flute,Eb5,vib,Iowa2012,aif
5,0.003521,0.003417,0.003128,0.003775,0.003662,0.004113,0.004573,0.005411,0.006391,0.005843,...,0.039358,0.046034,0.049473,0.047220,0.040782,Flute,A6,vib,Iowa2012,aif
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2333,0.005842,0.006696,0.005504,0.008009,0.010848,0.008933,0.005405,0.005450,0.004055,0.004462,...,0.000230,0.000214,0.000199,0.000189,0.000159,Crotale,Gb6,,Iowa2012,aif
2334,0.011724,0.013896,0.011100,0.013689,0.022415,0.020919,0.010713,0.008141,0.006795,0.008993,...,0.000351,0.000316,0.000306,0.000334,0.000327,Crotale,A7,,Iowa2012,aif
2335,0.014983,0.020384,0.016735,0.018972,0.027308,0.021950,0.012783,0.012019,0.008091,0.008059,...,0.000752,0.000527,0.000734,0.000921,0.000812,Crotale,E6,,Iowa2012,aif
2336,0.010699,0.013480,0.009601,0.013410,0.019354,0.015173,0.010070,0.008354,0.005919,0.008271,...,0.000240,0.000230,0.000376,0.000763,0.000884,Crotale,Ab6,,Iowa2012,aif


In [10]:
import matplotlib.pyplot as plt
%matplotlib inline  
from sklearn import svm
from sklearn.model_selection import train_test_split

import numpy as np

In [11]:
from sklearn.ensemble import RandomForestClassifier as RFC

In [69]:
rfc = RFC(criterion='gini', n_estimators=50, min_impurity_decrease=0.1, max_depth=4, max_features=3)

In [80]:
from sklearn.svm import SVC

In [81]:
svc = SVC(kernel='rbf', gamma='scale')

In [70]:
X = data2[list(range(2047))]
y = data2['instrument_name']

In [95]:
from sklearn import decomposition

np.random.seed(5)

pca = decomposition.PCA(n_components=3)
pca.fit(X)
X2 = pca.transform(X)

In [96]:
X_train, X_test, y_train, y_test = train_test_split(X2, y, test_size=0.2)

In [97]:
rfc.fit(X_train, y_train)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
                       max_depth=4, max_features=3, max_leaf_nodes=None,
                       min_impurity_decrease=0.1, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=50,
                       n_jobs=None, oob_score=False, random_state=None,
                       verbose=0, warm_start=False)

In [98]:
svc.fit(X_train, y_train)

SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='scale', kernel='rbf',
    max_iter=-1, probability=False, random_state=None, shrinking=True,
    tol=0.001, verbose=False)

In [99]:
y_pred = svc.predict(X_test)

In [74]:
y_pred = rfc.predict(X_test)

In [75]:
y_train

sample_id
1301       Bass
1532    Marimba
1463       Bass
420      SopSax
1335       Bass
         ...   
1375       Bass
1620    Marimba
548     Trumpet
1442       Bass
343     Bassoon
Name: instrument_name, Length: 1869, dtype: object

In [100]:
results = list(zip(list(y_pred), list(y_train)))

In [101]:
len([x for x in results if x[0]==x[1]]) / len(results)

0.13034188034188035

In [102]:
results

[('Cello', 'Violin'),
 ('Marimba', 'Vibraphone'),
 ('Bass', 'BbClarinet'),
 ('Marimba', 'BassClarinet'),
 ('Bass', 'Marimba'),
 ('Marimba', 'Flute'),
 ('Xylophone', 'Tuba'),
 ('Marimba', 'bells'),
 ('Marimba', 'SopSax'),
 ('Marimba', 'Violin'),
 ('Marimba', 'Violin'),
 ('AltoSax', 'Vibraphone'),
 ('Marimba', 'Flute'),
 ('Xylophone', 'BbClarinet'),
 ('AltoSax', 'Marimba'),
 ('Marimba', 'Bass'),
 ('Marimba', 'Xylophone'),
 ('Tuba', 'Tuba'),
 ('Bass', 'Viola'),
 ('Marimba', 'Trumpet'),
 ('Marimba', 'Marimba'),
 ('Vibraphone', 'Viola'),
 ('Marimba', 'Bass'),
 ('Marimba', 'Vibraphone'),
 ('Bass', 'Bass'),
 ('Marimba', 'Vibraphone'),
 ('Marimba', 'BbClarinet'),
 ('Xylophone', 'Vibraphone'),
 ('Vibraphone', 'Bass'),
 ('Marimba', 'Marimba'),
 ('Vibraphone', 'Viola'),
 ('Vibraphone', 'Bass'),
 ('Marimba', 'Oboe'),
 ('Bass', 'Xylophone'),
 ('Marimba', 'Viola'),
 ('Marimba', 'Vibraphone'),
 ('Marimba', 'Trumpet'),
 ('Bass', 'Violin'),
 ('Marimba', 'Violin'),
 ('Cello', 'Violin'),
 ('Viola', 'Viol