# Make a scikit-learn classifier to distinguish samples

If you have heard ANY music made in the last 40 years, you have heard the sounds of these two classic drum machines:
## [Roland TR-808](https://en.wikipedia.org/wiki/Roland_TR-808)

<img src="images/tr808.jpg" width="800" />

## [Roland TR-909](https://en.wikipedia.org/wiki/Roland_TR-909)

<img src="images/tr909.png" width="800" />

## Background
These two machines, designed and marketed to replace the drummer in your band, were initially met with such disdain that their values plummeted from their original list prices of around 900 USD at release to around 50 USD on the secondary market only a few short years later. These cheap machines were now available to young people all over the United States and became a staple of the hip-hop community. The synthetic drum sounds of the TR-808 and TR-909 have inspired many clones over the years and are still used and highly sought after to this day.

I've been investigating sound synthesis recently and came across a website, [drumkito](https://www.drumkito.com/), that has some great sample packs based on iconic drum machines from over the years.

Armed with this library of data, I will attempt to create a classifier that can distinguish which type of sound (kick drum, high tom, etc.) any given sample sounds like.

This tool could be applied in several ways, but here we will use it to help organize an ad hoc sample library for use with the ipytone JupyterLab-based drum pad that I'll be posting soon.


From the text file found in this download of the samples: 
https://www.drumkito.com/sample-packs/roland-tr-909-sample-pack/

---------------------

Sample Identification
---------------------
```
Instrument (first letter)       Settings (in order)             Directory

b       bass drum               t=tune, a=attack, d=decay       \bassdm         
s       snare drum              t=tune, t=tone, s=snappy        \snaredm
l       low tom                 t=tune, d=decay                 \lowtomdm
m       mid tom                 t=tune, d=decay                 \midtomdm
h       high tom                t=tune, d=decay                 \hitomdm
rim     rimshot                 #=velocity level                \rimshot
hand    handclap                #=velocity level                \handclap
hhc     closed high hat         d=decay                         \closedhh
hho     open high hat           d=decay                         \openhh
csh     crash cymbal            t=tune                          \crshcym
ride    ride cymbal             t=tune                          \ridecym
clop    closed->open hh         #=combination number            \misc
opcl    open->closed hh         #=combination number            \misc

```

In [2]:
import os
import numpy as np
import pandas as pd

import ipywidgets as widgets
from scipy.io import wavfile

import matplotlib.pyplot as plt

In [3]:
sample_dir = "Roland TR-909/."

sample_list = os.listdir(sample_dir)

sample_names = [i.split('.')[0] for i in sample_list]

In [6]:
sample_labels = """b       bass_drum
s       snare_drum
l       low_tom
m       mid_tom
ht       high_tom
rim     rimshot
hand    handclap
hhc     closed_high_hat
hho     open_high_hat
csh     crash_cymbal
ride    ride_cymbal
clop    closed_open_hh
opcl    open_closed_hh"""

sample_label_dict = dict([i.split() for i in sample_labels.split('\n')])

In [7]:
new_keys = [i.upper() for i in sample_label_dict.keys()]

sample_label_dict = dict(zip(new_keys, sample_label_dict.values()))

sample_label_dict

{'B': 'bass_drum',
 'S': 'snare_drum',
 'L': 'low_tom',
 'M': 'mid_tom',
 'HT': 'high_tom',
 'RIM': 'rimshot',
 'HAND': 'handclap',
 'HHC': 'closed_high_hat',
 'HHO': 'open_high_hat',
 'CSH': 'crash_cymbal',
 'RIDE': 'ride_cymbal',
 'CLOP': 'closed_open_hh',
 'OPCL': 'open_closed_hh'}

# Link the sample file name to the label

In [10]:
sample_df = pd.DataFrame(sample_names[:-1], columns=["sample_name"])

sample_label_df = pd.DataFrame([sample_label_dict.keys(), sample_label_dict.values()]).T

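# Create a function to check the file title and link the label
def label_check(sample_name):
    # Match on the file-name prefix, longest key first. A plain substring
    # test would mislabel names such as "CSH..." (contains "S") as snare_drum.
    for key in sorted(sample_label_dict, key=len, reverse=True):
        if sample_name.upper().startswith(key):
            return sample_label_dict[key]
    raise ValueError(f"No label prefix matches {sample_name!r}")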

sample_df['label'] = sample_df.sample_name.apply(label_check)

sample_df.head()

| | sample_name | label |
|---|---|---|
| 0 | BT0A0A7 | bass_drum |
| 1 | BT0A0D0 | bass_drum |
| 2 | BT0A0D3 | bass_drum |
| 3 | BT0A0DA | bass_drum |
| 4 | BT0AAD0 | bass_drum |


# Create a data frame of each wavfile's length and data points

In [11]:
wav_df = sample_df.sample_name.apply(lambda x: wavfile.read(os.path.join(sample_dir, x + ".WAV"))[1])

wav_sample_df = pd.DataFrame([wav_df.apply(len), wav_df]).T

wav_sample_df.columns = ['length', 'data']

wav_sample_df.head()

| | length | data |
|---|---|---|
| 0 | 12554 | [18770, 17990, 26116, 0, 16727, 17750, 28006, ... |
| 1 | 4489 | [18770, 17990, 9686, 0, 16727, 17750, 28006, 8... |
| 2 | 7569 | [18770, 17990, 15316, 0, 16727, 17750, 28006, ... |
| 3 | 21992 | [0, 0, 295, 445, 613, 821, 1135, 1482, 1635, 1... |
| 4 | 4792 | [0, 0, 733, 1301, 1851, 1947, 1786, 1888, 2292... |
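
One thing worth noticing in the output above: the leading values in rows 0 to 2 (18770, 17990, then 16727, 17750, 28006) are the little-endian 16-bit encodings of the ASCII text `RIFF`, `WAVE`, and `fm`. That suggests some files may carry a second WAV header inside their data, which is then returned as if it were audio. A minimal sketch for flagging such rows, assuming 16-bit little-endian samples (`looks_like_riff_header` is a hypothetical helper name, not part of the notebook):

```python
import numpy as np

def looks_like_riff_header(samples):
    """Return True if the first two 16-bit samples spell the ASCII tag 'RIFF'."""
    # Assumption: samples are 16-bit PCM; reinterpret them as little-endian bytes
    head = np.asarray(samples[:2], dtype="<i2").tobytes()
    return head == b"RIFF"

# Values copied from rows 0 and 3 of the output above
row_0 = [18770, 17990, 26116, 0, 16727, 17750, 28006]
row_3 = [0, 0, 295, 445, 613, 821, 1135]

flag_0 = looks_like_riff_header(row_0)
flag_3 = looks_like_riff_header(row_3)
print(flag_0, flag_3)
```

Applied to the real data, `wav_sample_df.data.apply(looks_like_riff_header).sum()` would count how many files are affected.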


# Now we need to find a way to make our samples comparable

- One way to do this, and the simplest that I thought of on the fly, is to make them all the same length by appending zeros
- Note: `scipy.io.wavfile.read` returns the sample rate alongside the data, but the cell above keeps only the data, so the rate is never checked
- We are just going to assume that every file has the same sample rate, so we only need to make them the same length
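
The same-sample-rate assumption can be checked by keeping the first element of the `(rate, data)` tuple that `scipy.io.wavfile.read` returns. A minimal sketch, which writes two short synthetic WAV files to a temporary directory so that it runs on its own (the file names and rates are made up for illustration); on the real library, `paths` would be the list of `.WAV` files:

```python
import os
import tempfile

import numpy as np
from scipy.io import wavfile

def sample_rates(paths):
    """Map each WAV path to the sample rate stored in its header."""
    return {path: wavfile.read(path)[0] for path in paths}

# Synthetic stand-ins for real samples: 0.1 s of silence at two different rates
tmp_dir = tempfile.mkdtemp()
paths = []
for name, rate in [("a.wav", 44100), ("b.wav", 22050)]:
    path = os.path.join(tmp_dir, name)
    wavfile.write(path, rate, np.zeros(rate // 10, dtype=np.int16))
    paths.append(path)

rates = sample_rates(paths)
distinct_rates = set(rates.values())
# More than one distinct value means the files are not directly comparable
print(distinct_rates)
```

If more than one rate shows up, padding to a common length is not enough on its own, because the same number of values would then cover different durations.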

In [13]:
max_sample_len = wav_sample_df.data.apply(len).max()

def normalize_by_adding_zeros(array, target_len=max_sample_len):
    """Zero-pad (or truncate) a sample array to target_len values."""
    # The target length is computed once above instead of on every call
    if len(array) < target_len:
        difference = target_len - len(array)
        return np.append(array, np.zeros(difference))
    return array[:target_len]

normalized_data = wav_sample_df.data.apply(normalize_by_adding_zeros)

In [14]:
# Check to see if we normalized all of the data
sum(normalized_data.apply(len) == max_sample_len)

160

In [15]:
data = pd.DataFrame(normalized_data.tolist())

# Create our classifier

In [16]:
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import GradientBoostingClassifier

X_train, X_test, y_train, y_test = train_test_split(data, sample_df.label, test_size=0.2, random_state=42)

# Create and train a classifier
clf = DecisionTreeClassifier()

# clf = GradientBoostingClassifier(n_estimators=100,
#                                  learning_rate=1.0,
#                                  max_depth=1,
#                                  random_state=0)

clf.fit(X_train, y_train)
clf.score(X_test, y_test)

0.65625

In [17]:
# Make predictions on the test set
y_pred = clf.predict(X_test)
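
A single accuracy number hides which drum types get confused with each other, so a per-class breakdown of the predictions can help. The sketch below uses small random stand-in data (`X_demo`, `y_demo`, and the choice of three classes are placeholders, not the real samples) so that it runs on its own; with the notebook's variables, the last two calls would take `y_test` and `y_pred`:

```python
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Placeholder data: 90 random "samples" of 20 values each, 3 classes
rng = np.random.default_rng(0)
labels = ["bass_drum", "snare_drum", "rimshot"]
y_demo = np.repeat(labels, 30)
X_demo = rng.normal(size=(90, 20)) + np.repeat(np.arange(3), 30)[:, None]

X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42, stratify=y_demo
)
demo_clf = DecisionTreeClassifier(random_state=0).fit(X_tr, y_tr)
demo_pred = demo_clf.predict(X_te)

# Rows are true classes, columns are predicted classes, in the order of `labels`
cm = confusion_matrix(y_te, demo_pred, labels=labels)
print(cm)
print(classification_report(y_te, demo_pred, labels=labels, zero_division=0))
```

Off-diagonal entries in the matrix show which pairs of classes are being mixed up, which is more actionable than the overall score.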


# Let's try the 909 classifier on the 808

In [22]:
import glob

sample_dir_808 = "Roland TR-808/*/*.WAV"
sample_files_808 = glob.glob(sample_dir_808)

# Let's pick a file that I know is a bass drum
bd_808 = wavfile.read(sample_files_808[5])[1]

In [23]:
clf.predict([normalize_by_adding_zeros(bd_808)])

array(['snare_drum'], dtype=object)
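
One possible contributor to the miss, offered as an untested guess: `wavfile.read` returns a 2-D array of shape `(n_samples, n_channels)` for multi-channel files, and `np.append` flattens its inputs, so a stereo file would be padded as one long interleaved array. Whether any of the 808 files are stereo has not been checked here. A minimal sketch of mixing down to mono before padding (`to_mono` is a hypothetical helper, not part of the notebook):

```python
import numpy as np

def to_mono(samples):
    """Average the channels of a (n_samples, n_channels) array; pass 1-D input through."""
    samples = np.asarray(samples)
    if samples.ndim == 2:
        return samples.mean(axis=1)
    return samples

# Made-up stereo clip: 4 frames, 2 channels
stereo = np.array([[100, 300], [0, 0], [-50, 50], [10, 30]], dtype=np.int16)
mono = to_mono(stereo)
print(mono.shape)
```

The result could then be passed through `normalize_by_adding_zeros` as before.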

# Conclusions:
- The classifier trained on the Roland 909 data failed to accurately predict a sample from the Roland 808 sample set
- There are a couple of things that we can do to improve our classifier

## How can we make better classifications?
- We need to optimize this classifier
    - With the default `DecisionTreeClassifier` settings, the score has ranged from 0.60 to 0.75 across runs (no `random_state` is set on the classifier, so each fit can differ)
    - We can try using hyperopt or optuna to find the best method and hyperparameters for making a classifier
    - We also need to look into the normalization method and the differences between the sample files
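
As a starting point for the tuning step, here is a minimal sketch that swaps in scikit-learn's own `GridSearchCV` in place of hyperopt or optuna, so that nothing beyond the libraries already used is needed. It runs on random placeholder data (`X_demo`, `y_demo`) with an arbitrary, untuned parameter grid; neither comes from the notebook. Scoring by stratified cross-validation instead of a single split should also make the result less sensitive to which files land in the test set:

```python
import numpy as np
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

# Placeholder data standing in for `data` and `sample_df.label`
rng = np.random.default_rng(0)
y_demo = np.repeat(["bass_drum", "snare_drum", "rimshot"], 30)
X_demo = rng.normal(size=(90, 20)) + np.repeat(np.arange(3), 30)[:, None]

# Arbitrary example grid; the values are assumptions, not recommendations
param_grid = {"max_depth": [2, 4, None], "min_samples_leaf": [1, 3]}

search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid,
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
    scoring="accuracy",
)
search.fit(X_demo, y_demo)

print(search.best_params_)
print(round(search.best_score_, 3))
```

With only 160 samples spread over 13 classes, some classes may have too few examples for 5 folds, so `n_splits` might need to be lowered on the real data.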