Before running the code, set the working directory to the **src** folder.

In [5]:
from train_model import train_classifier

# Training Classification Models

This notebook provides the interface for training (and optionally saving) kick and drum classification models. Models are trained by calling the **train_classifier()** function, which returns the trained model together with its accuracy and takes the following parameters:

*   **model_name** (str): User-defined name of the model.
*   **model_save** (bool): Saves the model if True. Default is True.
*   **overwrite** (bool): Overwrites an existing model with the same name if True. Default is False.
*   **drum_path** (str): File path to the drum audio samples used for training. Default is '../data/raw/drum'.
*   **kick_path** (str): File path to the kick audio samples used for training. Default is '../data/raw/kick'.
*   **x_train_labs** (list): Optional list of matches from the data folders on which to train the model; the remaining matches in each folder are held out for testing. Default is None, meaning the model is trained and tested on random partitions.
*   **x_percent** (float): Fraction of each 0.4-second audio clip used to train the model (e.g. 0.5 keeps 0.2 seconds). Default is 1.0.
*   **type** (str): Section of each audio clip to retain when **x_percent** is less than 1.0. Options are "start", "center", "end", or "random".
*   **noise_factor** (float): Factor controlling the artificial noise added to the raw audio. Default is 0.0.
*   **verbose** (bool): Prints output if True. Default is False.
*   **epochs** (int): Number of epochs to train the model. Default is 3.
*   **batch_size** (int): Batch size used to train the model. Default is 32.
*   **validation_split** (float): Fraction of the training data held out for validation. Default is 0.2.
*   **seed** (int): Seed for reproducibility. Default is 1.
*   **plot_cm** (bool): Plots the confusion matrix if True. Default is False.
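The **x_percent**, **type**, and **noise_factor** options can be pictured with the minimal NumPy sketch below. It is not the project's implementation: `crop_clip` and `add_noise` are hypothetical helpers, the `section` argument stands in for **type**, Gaussian noise scaled by the noise factor is an assumed noise model, and the 22,050 Hz sample rate in the usage lines is an assumption.

```python
import numpy as np

CLIP_SECONDS = 0.4  # clip length used in this notebook


def crop_clip(audio, x_percent=1.0, section="start", seed=1):
    """Hypothetical helper: keep a fraction `x_percent` of a 1-D audio clip.

    `section` mirrors the notebook's `type` option: "start", "center",
    "end" or "random".
    """
    audio = np.asarray(audio)
    n = len(audio)
    keep = int(round(n * x_percent))  # number of samples to retain
    if keep >= n:
        return audio  # x_percent = 1.0 keeps the whole clip
    if section == "start":
        begin = 0
    elif section == "center":
        begin = (n - keep) // 2
    elif section == "end":
        begin = n - keep
    elif section == "random":
        rng = np.random.default_rng(seed)  # seeded for reproducibility
        begin = int(rng.integers(0, n - keep + 1))
    else:
        raise ValueError(f"unknown section: {section!r}")
    return audio[begin:begin + keep]


def add_noise(audio, noise_factor=0.0, seed=1):
    """Hypothetical helper: add Gaussian noise scaled by `noise_factor` (assumed noise model)."""
    audio = np.asarray(audio, dtype=float)
    rng = np.random.default_rng(seed)
    return audio + noise_factor * rng.standard_normal(audio.shape)


# Usage, assuming a 22,050 Hz sample rate (hypothetical)
sr = 22050
clip = np.arange(round(CLIP_SECONDS * sr))  # 8820 samples = 0.4 s
first_part = crop_clip(clip, x_percent=0.6, section="start")
print(len(first_part) / sr)  # → 0.24 (duration in seconds)
```

Under these assumptions, `x_percent = 0.6` with the "start" section keeps the first 0.24 seconds of a 0.4-second clip.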

## Base Model

The following code trains the baseline model for this project. As reported in the output below, the model is saved as "**base.h5**" in the **../models** folder.

In [4]:
model, accuracy = train_classifier(model_name = "base")

Test accuracy: 0.9881305694580078, Test loss: 0.07220670580863953
Model successfully saved: ../models/base.h5

## Further Illustrative Examples

The following code gives further examples of training configurations. These models are for illustration only and are not saved.

In [None]:
# Train model using the first 0.24 seconds (60%) of each 0.4-second audio clip
model, accuracy = train_classifier(model_name = "s24", model_save = False, x_percent = 0.6, type = "start")

# Train model using 10 epochs and batch size of 64
model, accuracy = train_classifier(model_name = "e10_b64", model_save = False, epochs = 10, batch_size = 64)

# Train model using only "Motherwell Drums" and "Southampton vs. Chelsea Kicks"
model, accuracy = train_classifier(model_name = "DrumMoth_KickSFCCFC", model_save = False, x_train_labs = ['Motherwell_Far_L_Drum_.wav', 'SFC_CFC_Kick_.wav'])
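The last example restricts training to named matches through **x_train_labs**. As a rough illustration of that partitioning rule (not the project's code), the sketch below assumes files are selected by exact file-name match; `split_by_match`, its `test_fraction` default for the random split, and the second file name are hypothetical.

```python
import random


def split_by_match(file_names, train_labs=None, test_fraction=0.2, seed=1):
    """Hypothetical helper illustrating the `x_train_labs` idea.

    If `train_labs` is given, files whose names appear in it are used for
    training and every other file is held out for testing. Otherwise the
    files are shuffled and split at random; `test_fraction` is an assumed value.
    """
    if train_labs is None:
        shuffled = list(file_names)
        random.Random(seed).shuffle(shuffled)  # seeded random partition
        n_test = int(len(shuffled) * test_fraction)
        return shuffled[n_test:], shuffled[:n_test]
    wanted = set(train_labs)
    train = [f for f in file_names if f in wanted]
    test = [f for f in file_names if f not in wanted]
    return train, test


# Usage with hypothetical file names from a drum folder
drum_files = ["Motherwell_Far_L_Drum_.wav", "Other_Match_Drum_.wav"]
train, test = split_by_match(drum_files, train_labs=["Motherwell_Far_L_Drum_.wav"])
print(train, test)  # ['Motherwell_Far_L_Drum_.wav'] ['Other_Match_Drum_.wav']
```

Holding out whole matches, rather than random samples, tests whether a model generalises to recordings it has never heard, which a random partition across all files cannot show.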