In [None]:
import os
import pathlib
import sys

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
sys.path.insert(0, os.path.abspath(".."))

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from IPython import display

from magicpacket.dataset import features
from magicpacket.models import simple_audio_model

In [None]:
# Fix the NumPy and TensorFlow global seeds so that repeated runs of the
# notebook start from the same random state.
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)

In [None]:
# Directory holding the dataset (one sub-directory per label), given relative
# to the notebook's working directory.
DATASET_PATH = "../data/mini_speech_commands"

# Dataset

## Download dataset and get labels

In [None]:
data_dir = pathlib.Path(DATASET_PATH)
origin = "http://storage.googleapis.com/download.tensorflow.org/data/mini_speech_commands.zip"  # noqa: E501
if not data_dir.exists():
    # Download next to DATASET_PATH: get_file stores the archive under
    # <cache_dir>/<cache_subdir>, so both are derived from data_dir instead of
    # being hard-coded relative to the current directory.
    # NOTE(review): the extraction target of `extract=True` depends on the
    # installed Keras version - confirm it ends up in `data_dir`.
    tf.keras.utils.get_file(
        "mini_speech_commands.zip",
        origin=origin,
        extract=True,
        cache_dir=str(data_dir.parent.parent),
        cache_subdir=data_dir.parent.name,
    )

# Every entry of the dataset directory except the README is a label.
# Sorting makes the label -> id mapping independent of the order in which the
# file system happens to list the directory.
labels = np.array(sorted(tf.io.gfile.listdir(str(data_dir))))
labels = labels[labels != "README.md"]
print("labels:", labels)

## Dataset methods

In [None]:
def get_datasets(path, splits):
    """Shuffle the files below ``path`` and split them into datasets.

    Args:
        path: Root directory of the dataset (``str`` or ``os.PathLike``).
            Every entry matching ``<path>/*/*`` becomes one example.
        splits: Iterable of fractions of the total number of files, one per
            dataset to create, e.g. ``(0.8, 0.1, 0.1)``.

    Returns:
        A list with one ``tf.data.Dataset`` of file paths per entry of
        ``splits``, in the same order. The datasets are consecutive,
        non-overlapping slices of the shuffled file list.
    """
    # Sort before shuffling so that the (seeded) shuffle does not depend on
    # the order in which the file system returns the matches.
    file_paths = sorted(tf.io.gfile.glob(os.fspath(path) + "/*/*"))
    file_paths = tf.random.shuffle(file_paths)
    n_rows = len(file_paths)

    datasets = []
    split_start = 0
    for split in splits:
        # Split sizes are truncated to whole files, so a few trailing files
        # may end up in no dataset at all.
        split_end = split_start + int(split * n_rows)
        datasets.append(
            tf.data.Dataset.from_tensor_slices(file_paths[split_start:split_end])
        )
        split_start = split_end
    return datasets


def get_label(file_path):
    """Return the second-to-last component of ``file_path``.

    The path is split on ``os.path.sep``; for a ``<label>/<file>`` layout the
    returned component is the label directory name.
    """
    path_parts = tf.strings.split(input=file_path, sep=os.path.sep)
    return path_parts[-2]


def get_waveform(file_path):
    """Read the WAV file at ``file_path`` and return its samples.

    The trailing (channel) axis of the decoded audio is squeezed away; the
    sample rate reported by the decoder is discarded.
    """
    wav_bytes = tf.io.read_file(file_path)
    # The decoded samples are normalized to the range [-1, 1].
    samples, _sample_rate = tf.audio.decode_wav(contents=wav_bytes)
    return tf.squeeze(samples, axis=-1)


def get_waveform_and_label(file_path):
    """Return the ``(waveform, label)`` pair for the audio file at ``file_path``."""
    waveform = get_waveform(file_path)
    label = get_label(file_path)
    return waveform, label


def plot_from_ds(ds, rows=3, cols=3):
    """Plot the first ``rows * cols`` waveforms of ``ds`` in a grid.

    Args:
        ds: Dataset yielding ``(audio, label)`` pairs. ``audio.numpy()`` is
            plotted and ``label.numpy()`` must be UTF-8 encoded bytes.
        rows: Number of rows of the subplot grid.
        cols: Number of columns of the subplot grid.
    """
    n = rows * cols
    # squeeze=False keeps `axes` two-dimensional even for a single row or
    # column, so the [row][col] indexing below is always valid.
    _, axes = plt.subplots(rows, cols, figsize=(10, 12), squeeze=False)

    for i, (audio, label) in enumerate(ds.take(n)):
        r, c = divmod(i, cols)
        ax = axes[r][c]
        ax.plot(audio.numpy())
        ax.set_yticks(np.arange(-1.2, 1.2, 0.2))
        ax.set_title(label.numpy().decode("utf-8"))
    plt.show()

## Read dataset into `tf.data.Dataset`

In [None]:
# Split the shuffled file paths 80% / 10% / 10% into train, validation and test.
train_ds, val_ds, test_ds = get_datasets(DATASET_PATH, (0.8, 0.1, 0.1))

# Decoded (waveform, label) view of the training files, used for the example
# plots below.
waveform_ds = train_ds.map(
    get_waveform_and_label,
    num_parallel_calls=tf.data.AUTOTUNE,
)

## Example waveforms

In [None]:
# Show a 3 x 3 grid of training waveforms together with their labels.
plot_from_ds(waveform_ds, rows=3, cols=3)

# Feature Extraction

## Example MFCC

In [None]:
# Take a single (waveform, label) example from the training waveforms.
waveform, label = next(iter(waveform_ds.take(1)))
label = label.numpy().decode("utf-8")
mfcc = features.mfcc(waveform)

print("Label:", label)
print("Waveform shape:", waveform.shape)
print("MFCC shape:", mfcc.shape)
print("Audio playback")
# The sample rate is hard-coded: get_waveform discards the rate reported by
# the WAV decoder, so 16 kHz is an assumption about the dataset.
display.display(display.Audio(waveform, rate=16000))

In [None]:
def plot_spectrogram(spectrogram, ax):
    """Draw ``spectrogram`` on ``ax`` as a log-scaled colour mesh.

    Args:
        spectrogram: 2-D array, or 3-D array whose last axis has length 1
            (it is squeezed away before plotting).
        ax: Axes-like object providing ``pcolormesh(X, Y, C)``.
    """
    n_dims = len(spectrogram.shape)
    if n_dims > 2:
        assert n_dims == 3
        spectrogram = np.squeeze(spectrogram, axis=-1)

    # Transpose so that the first input axis runs along the x-axis (columns)
    # and take the log of the values; the machine epsilon keeps log(0) out.
    eps = np.finfo(float).eps
    log_spec = np.log(spectrogram.T + eps)

    n_rows = log_spec.shape[0]
    n_cols = log_spec.shape[1]
    # One integer x coordinate per column, spread over [0, number of values].
    x_coords = np.linspace(0, np.size(spectrogram), num=n_cols, dtype=int)
    y_coords = range(n_rows)
    ax.pcolormesh(x_coords, y_coords, log_spec)

In [None]:
# Three stacked panels for the example clip: waveform, spectrogram and MFCC.
_, axes = plt.subplots(nrows=3, figsize=(10, 12))
wave_ax, spec_ax, mfcc_ax = axes

# Panel 1: raw samples against their sample index.
timescale = np.arange(waveform.shape[0])
wave_ax.plot(timescale, waveform.numpy())
wave_ax.set_title(label)
wave_ax.set_xlim([0, 16000])

# Panel 2: spectrogram, drawn with the helper defined above.
spectrogram = features.spectrogram(waveform)
plot_spectrogram(spectrogram.numpy(), spec_ax)
spec_ax.set_title("spectrogram")

# Panel 3: MFCC computed from the spectrogram, drawn as a colour mesh.
mfcc = features.mfcc(S=spectrogram)
height, width = mfcc.shape
Y = range(height)
X = np.linspace(0, np.size(mfcc), num=width, dtype=int)
mfcc_ax.pcolormesh(X, Y, mfcc.numpy())
mfcc_ax.set_title("mfcc")

plt.show()

# Build and train model

## Preprocess datasets

In [None]:
def get_mfcc_and_label_id(file_path):
    """Map an audio file path to an ``(mfcc, label_id)`` pair.

    ``label_id`` is the index of the file's label within the module-level
    ``labels`` array.
    """
    waveform, label = get_waveform_and_label(file_path)

    # Append a `channels` axis so the MFCC can be fed as image-like input to
    # convolution layers, which expect
    # (`batch_size`, `height`, `width`, `channels`).
    mfcc = features.mfcc(waveform)[..., tf.newaxis]

    # `labels` is the global label array; argmax picks the matching position.
    is_match = label == labels
    label_id = tf.argmax(is_match)
    return mfcc, label_id


def preprocess_dataset(ds):
    """Turn a dataset of file paths into a dataset of ``(mfcc, label_id)`` pairs."""
    mapped = ds.map(
        map_func=get_mfcc_and_label_id,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return mapped


# Replace each file-path dataset by its (mfcc, label_id) counterpart.
train_ds, val_ds, test_ds = [
    preprocess_dataset(ds) for ds in (train_ds, val_ds, test_ds)
]

## Model

In [None]:
# Peek at one preprocessed training example to learn the per-example shape.
for mfcc, _ in train_ds.take(1):
    input_shape = mfcc.shape
print("Input shape:", input_shape)

# One output class per label.
n_classes = len(labels)
model = simple_audio_model(train_ds, input_shape, n_classes)
model.summary()

In [None]:
# Adam optimizer with sparse categorical cross-entropy computed on logits;
# accuracy is tracked as the only metric.
adam = tf.keras.optimizers.Adam()
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=adam, loss=cross_entropy, metrics=["accuracy"])

## Fit

In [None]:
# Group the training and validation examples into batches of 64.
batch_size = 64
train_ds, val_ds = train_ds.batch(batch_size), val_ds.batch(batch_size)

In [None]:
# Cache the batched datasets and prefetch upcoming batches.
train_ds, val_ds = [
    ds.cache().prefetch(tf.data.AUTOTUNE) for ds in (train_ds, val_ds)
]

In [None]:
EPOCHS = 10

# Stop early once the monitored quantity has not improved for 2 epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(verbose=1, patience=2)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    # `callbacks` is documented as a list of callback instances.
    callbacks=[early_stopping],
)

In [None]:
metrics = history.history

# Each curve gets its own plot() call so that both are explicitly drawn
# against the epoch index.
plt.plot(history.epoch, metrics["loss"], label="loss")
plt.plot(history.epoch, metrics["val_loss"], label="val_loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Evaluation

In [None]:
# Materialise the (unbatched) test set as two NumPy arrays: one with the
# model inputs and one with the integer label ids.
test_audio, test_labels = [], []

for audio, label in test_ds:
    test_audio.append(audio.numpy())
    test_labels.append(label.numpy())

test_audio, test_labels = np.array(test_audio), np.array(test_labels)

In [None]:
# Predicted class = index of the largest model output for each example.
y_pred = np.argmax(model.predict(test_audio), axis=1)
y_true = test_labels

# Fraction of correct predictions, computed as a vectorised mean instead of a
# Python-level sum over the boolean array.
test_acc = np.mean(y_pred == y_true)
print(f"Test set accuracy: {test_acc:.0%}")

In [None]:
# Fix the matrix size to the number of labels so that rows/columns always line
# up with the tick labels, even if some class never occurs in the test set.
confusion_mtx = tf.math.confusion_matrix(y_true, y_pred, num_classes=len(labels))
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_mtx, xticklabels=labels, yticklabels=labels, annot=True, fmt="g")
plt.xlabel("Prediction")
plt.ylabel("Label")
plt.show()

In [None]:
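# TODO: disabled single-file inference example. It is stale: it refers to
# `commands` (the notebook defines `labels`) and passes a list of files to
# get_datasets, which expects a directory path.
# sample_file = data_dir/'no/01bb6a2a_nohash_0.wav'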

# sample_ds = preprocess_dataset(get_datasets([str(sample_file)], (1))

# for mfcc, label in sample_ds.batch(1):
#     prediction = model(mfcc)
#     plt.bar(commands, tf.nn.softmax(prediction[0]))
#     plt.title(f'Predictions for "{commands[label[0]]}"')
#     plt.show()