<a href="https://colab.research.google.com/github/monatis/trax-samples/blob/main/trax_speech_commands.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Audio Classification with Speech Commands Dataset in Trax
By [M. Yusuf Sarıgöz](https://github.com/monatis), Google Developer Expert on Machine Learning

## Introduction
[Trax](https://github.com/google/trax) is an end-to-end deep learning library focusing on clear code and speed. I loved its simple yet powerful API with minimal boilerplate. If you are new to Trax, you may want to check the [official examples](https://github.com/google/trax/tree/master/trax/examples) and [API docs with tutorials](https://trax-ml.readthedocs.io/en/latest/) first. Most of the official examples are in either the image or the text domain, so I want to provide an example in the audio domain.

## Overview
This is roughly a re-implementation of the [Simple Audio Recognition: Recognizing Keywords](https://www.tensorflow.org/tutorials/audio/simple_audio) TensorFlow tutorial, so you may want to refer to it whenever a concept or code snippet is not explained here.

## Setup
First, we will install the `trax` and `pydub` libraries. `pydub` will be used to decode audio later on.

In [None]:
!pip install -U trax pydub

## TPU Initialization
Any Trax code can be accelerated on either GPU or TPU. To achieve this, Trax can rely on `tensorflow-numpy` or `jax`, which is the default. We can connect to and initialize TPU as the accelerator of `jax` by running this cell.

In [None]:
# Run this cell to set TPU in Colab
import os
import jax
import requests
# Run this to get the TPU address.
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

## Imports and Hyperparameters
In the following cells, we will import the required modules from the `trax` package and then set some hyperparameters like `batch_size` and `num_epochs`.

In [None]:
import trax
from trax import fastmath
from trax.fastmath import numpy as jnp

In [None]:
batch_size = 256 # we can use a large batch as it runs on a TPU V2-8
input_size = 32 # we will extract STFT features from the audio signal and then resize the feature matrix to this size
num_epochs = 10
num_items = 85515 # number of train samples in the Speech Commands dataset.
num_steps = int(num_items // batch_size * num_epochs) # training loop will run this many steps
output_dir = '/content/output' # trainer will save model checkpoints to this directory

# remove output directory if you restart the notebook
!rm -rf {output_dir}
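
As a quick sanity check of the step count, the arithmetic can be worked out by hand. The sketch below repeats the values from the hyperparameter cell; `steps_per_epoch` is a name introduced here for illustration only.

```python
batch_size = 256
num_epochs = 10
num_items = 85515

# integer division: the last partial batch of each epoch is not counted
steps_per_epoch = num_items // batch_size
num_steps = steps_per_epoch * num_epochs

print(steps_per_epoch)  # 334
print(num_steps)        # 3340
```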

## Data Preparation
We will use the `Serial` combinator from the `trax.data` module to prepare the `Speech Commands` dataset as inputs to our model, together with a custom function that extracts STFT features from raw audio signals. Any function in a `trax.data.Serial` stack is expected to accept a Python generator, apply zero or more transformations to each item from that generator, and yield the transformed items as a generator again.

In [None]:
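import librosa
import numpy as np
import cv2

def stft(gen):
  """Yields (STFT magnitude features, label) for each (audio, label) item."""
  for item in gen:
    wave = np.asfortranarray(item[0], dtype=np.float32)
    # fix the length to 16000 samples: truncate longer clips, zero-pad shorter ones at the end
    wave = wave[:16000]
    wave = np.pad(wave, (0, 16000 - wave.shape[0]))
    feat = np.abs(librosa.stft(wave, n_fft=255, hop_length=128)) # magnitude spectrogram
    feat = cv2.resize(feat, (input_size, input_size)) # resize to input_size x input_size
    feat = feat.reshape(input_size, input_size, 1) # add a channel axis
    yield (feat, item[1])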
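
To make the fixed-length step and the resulting feature shape concrete, here is a small NumPy-only sketch. `pad_or_trim` and `stft_shape` are hypothetical helpers written for this note, not part of Trax or librosa, and the frame-count formula assumes librosa's default centered framing.

```python
import numpy as np

def pad_or_trim(wave, target_len=16000):
  """Returns `wave` cut or zero-padded at the end to exactly `target_len` samples."""
  wave = np.asarray(wave, dtype=np.float32)[:target_len]
  return np.pad(wave, (0, target_len - wave.shape[0]))

def stft_shape(n_samples, n_fft=255, hop_length=128):
  """Expected (frequency bins, frames) of an STFT, assuming centered frames."""
  return (1 + n_fft // 2, 1 + n_samples // hop_length)

short_clip = np.ones(12000)
long_clip = np.ones(20000)
print(pad_or_trim(short_clip).shape)  # (16000,)
print(pad_or_trim(long_clip).shape)   # (16000,)
print(stft_shape(16000))              # (128, 126)
```

Under these assumptions, each clip yields a feature matrix of roughly 128 x 126, which the `stft` function then resizes to `input_size` x `input_size`.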

Next, we will create the training and evaluation dataset pipelines with our function above. Be aware that this will download a large dataset from `tensorflow-datasets` (yes, you have access to a large collection of datasets through `trax.data.TFDS`!).

In [None]:
train_ds = trax.data.Serial(
    trax.data.TFDS('speech_commands', 'commands', keys=['audio', 'label'], train=True), # load train dataset from tensorflow-datasets
    trax.data.Shuffle(512), # shuffle dataset in a buffer sized 512
    stft, # extract stft features
    trax.data.inputs.Batch(batch_size) # batch dataset to feed into our model
)

eval_ds = trax.data.Serial(
    trax.data.TFDS('speech_commands', 'commands', keys=['audio', 'label'], train=False), # load eval dataset from tensorflow-datasets
    stft, # extract features 
    trax.data.inputs.Batch(batch_size) # batch dataset
)
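
Conceptually, a pipeline like the ones above is a chain of generator functions, each consuming the output of the previous one. The following plain-Python sketch mimics that idea; `chain`, `double`, and `batch` are toy stand-ins defined here, not Trax's implementation.

```python
def chain(*fns):
  """Composes generator functions left to right into a single one."""
  def pipeline(gen):
    for fn in fns:
      gen = fn(gen)
    return gen
  return pipeline

def double(gen):
  """Toy per-item transformation, playing the role of a feature extractor."""
  for x, label in gen:
    yield (2 * x, label)

def batch(size):
  """Returns a generator function that groups items into lists of `size`."""
  def _batch(gen):
    buf = []
    for item in gen:
      buf.append(item)
      if len(buf) == size:
        yield buf
        buf = []
  return _batch

source = ((i, i % 2) for i in range(5))
pipeline = chain(double, batch(2))
batches = list(pipeline(source))
print(batches)  # [[(0, 0), (2, 1)], [(4, 0), (6, 1)]]
```

Note that this toy `batch` silently drops the incomplete final batch; that is a simplification of the sketch, not a statement about how Trax batches data.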

## Model Building
We will build our model with the `Serial` combinator. There are other combinators as well, but they will be the subject of another sample notebook. The model starts with a batch normalization layer, followed by two convolutional layers with 32 and 64 filters, a max pooling layer, and a fully connected layer with 128 units. ReLU is used as the non-linear activation after each convolution, and two dropout layers are introduced to reduce overfitting. Finally, a fully connected output layer with 12 units (the number of classes in the dataset) produces the class scores.

In [None]:
from trax import layers as tl # the conventional alias for Trax layers
model = tl.Serial(
    tl.BatchNorm(),
    tl.Conv(32, (3, 3)),
    tl.Relu(),
    tl.Conv(64, (3, 3)),
    tl.Relu(),
    tl.MaxPool(),
    tl.Dropout(0.1),
    tl.Flatten(),
    tl.Dense(128),
    tl.Dropout(0.2),
    tl.Dense(12)
)
print(model)
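
It can help to trace how the spatial size of the features changes through the stack. The sketch below is back-of-the-envelope arithmetic, assuming the convolutions use no padding and a stride of 1 and that pooling uses a 2x2 window; `valid_out` is a hypothetical helper, and the pooling stride is left as a parameter because the text does not state which one `tl.MaxPool()` uses by default.

```python
def valid_out(size, window, stride=1):
  """Output length along one axis for a window applied without padding."""
  return (size - window) // stride + 1

size = 32                   # input_size
size = valid_out(size, 3)   # after the first 3x3 convolution
size = valid_out(size, 3)   # after the second 3x3 convolution
print(size)                 # 28

# The pooling stride is an assumption; check the tl.MaxPool documentation
# for the default in your Trax version.
for stride in (1, 2):
  pooled = valid_out(size, 2, stride)
  # flattened feature count with 64 channels
  print(stride, pooled, pooled * pooled * 64)
```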

## Training
Trax lets us define tasks for training and evaluation and provides a loop that runs these tasks for a specified number of steps. It also saves model checkpoints to an output directory and logs metrics for TensorBoard under the same directory.

In [None]:
from trax.supervised import training
train_task = training.TrainTask(
    labeled_data=train_ds(), # training dataset we created in the data preparation section
    loss_layer=tl.CategoryCrossEntropy(), # yes, losses are ordinary layers
    optimizer=trax.optimizers.Adam(0.01), # we use Adam optimizer with a learning rate of 0.01
    n_steps_per_checkpoint=200 # save checkpoints and log metrics every 200 steps
)

eval_task = training.EvalTask(
    labeled_data=eval_ds(), # evaluation dataset
    metrics=[tl.CategoryCrossEntropy()], # evaluation metrics. you can also add tl.Accuracy() to this list
    n_eval_batches=20 # run each evaluation on 20 batches
)
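
For intuition about the loss values that will show up in the logs, here is a NumPy sketch of the standard categorical cross-entropy computed from raw class scores. It illustrates the textbook definition only; `category_cross_entropy` is a helper written for this note, and whether `tl.CategoryCrossEntropy` matches it in every detail is an assumption.

```python
import numpy as np

def category_cross_entropy(logits, labels):
  """Mean of -log softmax(logits)[label] over a batch."""
  logits = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
  log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
  return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.zeros((4, 12))        # a model that is undecided among 12 classes
labels = np.array([0, 3, 7, 11])
loss = category_cross_entropy(logits, labels)
print(np.isclose(loss, np.log(12)))  # True
```

Under this definition, a model that cannot yet tell the 12 classes apart scores about ln(12), roughly 2.48, so a loss well below that value suggests the model is learning.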

Finally, create the loop and run it.

In [None]:
training_loop = training.Loop(
    model,
    train_task,
    eval_tasks=[eval_task],
    output_dir=output_dir # save checkpoints and TensorBoard logs in this directory
)
# run the loop. the first step might be slower, but subsequent ones will be much faster
training_loop.run(num_steps) 

Congrats! Training is complete, and the model was saved in `output_dir`.

In [None]:
!ls -l {output_dir}