# Setup

In [None]:
import os
import subprocess
import sys

from IPython import get_ipython

# True when the string form of the active IPython shell object mentions
# "google.colab" (i.e. presumably the notebook is running on Google Colab).
# The setup cell below uses it to decide whether the project must be cloned.
IS_COLAB = "google.colab" in str(get_ipython())

In [None]:
if IS_COLAB:
    # On Colab the project is not checked out: clone it into the working
    # directory (once) and install it in editable mode.
    module_dir = "./magic-packet"
    if not os.path.exists(module_dir):
        # check=True makes a failed clone stop here instead of surfacing
        # later as a confusing import error.
        subprocess.run(
            [
                "git",
                "clone",
                "-q",
                "https://github.com/jjgp/magic-packet.git",
                module_dir,
            ],
            check=True,
        )

    # Invoke pip through the running interpreter so the package lands in the
    # environment this kernel actually uses; output stays captured (quiet),
    # but a non-zero exit status now raises CalledProcessError.
    subprocess.run(
        [sys.executable, "-m", "pip", "-q", "install", "-e", module_dir],
        capture_output=True,
        check=True,
    )

    # Make the checkout importable in the already-running kernel.
    content_dir = "/content/magic-packet/"
    if content_dir not in sys.path:
        sys.path.insert(0, content_dir)
else:
    # Local run: the package is expected one directory above the notebook.
    # Guard against piling up duplicate entries when the cell is re-run.
    repo_root = os.path.abspath("..")
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from IPython import display

import magicpacket.dataset.mini_speech_commands  # noqa: F401
from magicpacket.dataset import features
from magicpacket.models import simple_audio_model

In [None]:
# Seed both NumPy's and TensorFlow's global random generators with the same
# value so repeated runs of the experiment are reproducible.
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)

# Dataset

In [None]:
# Load "mini_speech_commands" (presumably registered with TFDS by the
# magicpacket.dataset.mini_speech_commands import above — confirm) and slice
# its "train" split into 80% training, 10% validation and 10% test data.
# with_info=True additionally returns the dataset metadata as `ds_info`.
(train_ds, val_ds, test_ds), ds_info = tfds.load(
    "mini_speech_commands",
    split=["train[:80%]", "train[80%:90%]", "train[90%:]"],
    shuffle_files=True,
    with_info=True,
)

In [None]:
def plot_from_ds(ds, ds_info, rows=3, cols=3):
    """Plot the first ``rows * cols`` waveforms of ``ds`` in a grid.

    Each subplot shows one normalized waveform, titled with its label name.

    Args:
        ds: Dataset supporting ``take`` whose elements are mappings with
            "audio" and "label" entries.
        ds_info: Dataset metadata; ``ds_info.features["label"].names`` is
            indexed by the label to obtain the subplot title.
        rows: Number of subplot rows.
        cols: Number of subplot columns.
    """
    n = rows * cols
    # squeeze=False keeps `axes` two-dimensional even when rows or cols is 1;
    # otherwise matplotlib returns a 1-D array (or a bare Axes) and grid
    # indexing fails.
    _, axes = plt.subplots(rows, cols, figsize=(10, 12), squeeze=False)

    # `axes.flat` walks the grid row by row, which matches the order in which
    # examples are drawn from the dataset.
    for ax, example in zip(axes.flat, ds.take(n)):
        audio, label = example["audio"], example["label"]
        normalized = features.normalize(audio)
        ax.plot(normalized.numpy())
        ax.set_yticks(np.arange(-1.2, 1.2, 0.2))
        ax.set_title(ds_info.features["label"].names[label])
    plt.show()

In [None]:
# Preview training waveforms using the function's default grid size.
plot_from_ds(train_ds, ds_info)

# Feature extraction example

In [None]:
def plot_spectrogram(spectrogram, ax):
    """Draw the log of ``spectrogram`` on ``ax`` as a colour mesh.

    Args:
        spectrogram: 2-D array, or a 3-D array whose trailing axis has size
            one (that axis is squeezed away before plotting).
        ax: Object providing ``pcolormesh`` (e.g. a matplotlib Axes).
    """
    rank = len(spectrogram.shape)
    if rank > 2:
        assert rank == 3
        spectrogram = np.squeeze(spectrogram, axis=-1)

    # Transpose so the input's first axis runs along the x-axis (columns),
    # and offset by machine epsilon so zeros do not produce log(0).
    epsilon = np.finfo(float).eps
    log_spec = np.log(spectrogram.T + epsilon)

    n_rows = log_spec.shape[0]
    n_cols = log_spec.shape[1]
    # x coordinates are spread evenly from 0 to the total element count of
    # the spectrogram, one per column; y coordinates are plain row indices.
    x_coords = np.linspace(0, np.size(spectrogram), num=n_cols, dtype=int)
    y_coords = range(n_rows)
    ax.pcolormesh(x_coords, y_coords, log_spec)

In [None]:
# Take a single training example and derive what the following cells use:
# its label name, the normalized waveform and the MFCC features.
for example in train_ds.take(1):
    audio = example["audio"]
    label = example["label"]
    wavename = ds_info.features["label"].names[label]
    # features.normalize presumably scales the waveform to the range [-1, 1]
    waveform = features.normalize(audio)
    mfcc = features.mfcc(waveform)

print("Name:", wavename)
print("Waveform shape:", waveform.shape)
print("Audio playback")
# The player is told to interpret the samples at a rate of 16000 per second.
display.display(display.Audio(waveform, rate=16000))

In [None]:
# Three stacked panels for the example selected above: raw waveform,
# spectrogram and MFCC.
_, axes = plt.subplots(3, figsize=(10, 12))

# Panel 1: waveform amplitude against sample index.
timescale = np.arange(waveform.shape[0])
axes[0].plot(timescale, waveform.numpy())
axes[0].set_title(wavename)
# The x-axis is fixed to 0..16000 samples, matching the rate=16000 used for
# playback (i.e. presumably one second of audio — confirm).
axes[0].set_xlim([0, 16000])

# Panel 2: spectrogram of the waveform, drawn on a log scale.
spectrogram = features.spectrogram(waveform)
plot_spectrogram(spectrogram.numpy(), axes[1])
axes[1].set_title("spectrogram")

# Panel 3: MFCCs computed from the spectrogram (passed as S=) instead of
# from the raw waveform.
mfcc = features.mfcc(S=spectrogram)
height, width = mfcc.shape
# X spreads one coordinate per column from 0 to the total element count of
# `mfcc`; Y is the row index.
# NOTE(review): unlike plot_spectrogram, the matrix is not transposed here —
# confirm which axis is meant to represent time.
X = np.linspace(0, np.size(mfcc), num=width, dtype=int)
Y = range(height)
axes[2].pcolormesh(X, Y, mfcc.numpy())
axes[2].set_title("mfcc")

plt.show()

# Build and train model

## Preprocess datasets

In [None]:
def get_mfcc_and_label(example):
    """Convert one dataset example into an ``(mfcc, label)`` pair.

    The audio is normalized, turned into MFCC features and given a trailing
    `channels` axis, so that the result can be used as image-like input for
    convolution layers (which expect the shape
    (`batch_size`, `height`, `width`, `channels`)).

    Args:
        example: Mapping with "audio" and "label" entries.

    Returns:
        Tuple of the MFCC features (with the extra trailing axis) and the
        unchanged label.
    """
    waveform = features.normalize(example["audio"])
    coefficients = features.mfcc(waveform)
    return coefficients[..., tf.newaxis], example["label"]


def preprocess_dataset(ds):
    """Map every example of ``ds`` to an ``(mfcc, label)`` pair.

    The mapping uses ``get_mfcc_and_label`` and lets tf.data choose the
    degree of parallelism (``tf.data.AUTOTUNE``).
    """
    parallelism = tf.data.AUTOTUNE
    return ds.map(get_mfcc_and_label, num_parallel_calls=parallelism)


# Replace each split with its preprocessed (mfcc, label) counterpart.
train_ds, val_ds, test_ds = map(preprocess_dataset, (train_ds, val_ds, test_ds))

## Model

In [None]:
# Peek at one (mfcc, label) element of train_ds to learn the per-example
# input shape the model has to accept.
for mfcc, _ in train_ds.take(1):
    input_shape = mfcc.shape
print("Input shape:", input_shape)

# One output class per label name listed in the dataset metadata.
n_labels = len(ds_info.features["label"].names)
# NOTE(review): the training dataset itself is handed to the model builder —
# presumably to adapt a normalization layer; confirm in simple_audio_model.
model = simple_audio_model(train_ds, input_shape, n_labels)
model.summary()

In [None]:
# Configure training: Adam with its default settings, sparse categorical
# cross-entropy on integer labels, and accuracy as the reported metric.
# from_logits=True tells the loss that the model outputs are raw logits —
# assumes simple_audio_model's final layer has no softmax; TODO confirm.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

## Fit

In [None]:
# Group the training and validation examples into batches of 64.
batch_size = 64
train_ds, val_ds = (split.batch(batch_size) for split in (train_ds, val_ds))

In [None]:
# Cache the batched elements after the first pass and let tf.data pick the
# prefetch buffer size, for both the training and the validation pipeline.
train_ds = train_ds.cache()
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.cache()
val_ds = val_ds.prefetch(tf.data.AUTOTUNE)

In [None]:
# Train for at most EPOCHS epochs, validating on val_ds after each one.
EPOCHS = 10
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    # `callbacks` is documented as a list of callbacks, so pass one
    # explicitly. EarlyStopping ends training once its monitored metric has
    # not improved for `patience` consecutive epochs.
    callbacks=[tf.keras.callbacks.EarlyStopping(verbose=1, patience=2)],
)

In [None]:
# Plot training and validation loss per epoch. Each curve gets the epoch
# numbers as its explicit x values and its own label, so the legend cannot
# drift out of sync with the plotted lines.
metrics = history.history
plt.plot(history.epoch, metrics["loss"], label="loss")
plt.plot(history.epoch, metrics["val_loss"], label="val_loss")
plt.legend()
plt.show()

## Evaluate

In [None]:
# TODO: use test_ds to evaluate the model