# Training a microWakeWord Model

This notebook steps you through training a basic mWW model. It is intended as a starting point for advanced users.

**The model generated will most likely not be usable for everyday use, you will most likely have to experiment with many different settings to obtain a decent model!**

In the comment at the start of certain blocks, I note some specific settings to consider modifying.

This runs on Google Colab, but is extremely slow compared to training on a local GPU. If you must use Colab, be sure to Change the runtime type to a GPU. Even then, it still slow!

In [7]:
# Installs microWakeWord. If running on Google Colab, be sure to restart the runtime after this is finished.

!git clone -b november-update https://github.com/kahrendt/microWakeWord
!pip install -e .

fatal: destination path 'microWakeWord' already exists and is not an empty directory.
/content/microWakeWord
Obtaining file:///content/microWakeWord
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: microwakeword
  Building editable for microwakeword (pyproject.toml) ... [?25l[?25hdone
  Created wheel for microwakeword: filename=microwakeword-0.0.1-0.editable-py3-none-any.whl size=10833 sha256=024c4ba496b0cf5550b939536471e05ff7cfc42f6d5d3b3adafcf197426081ff
  Stored in directory: /tmp/pip-ephem-wheel-cache-waw9s9bb/wheels/55/b3/ef/cf7f00aa13967f85dca43ecbb53f3c04f2fba8e4f749467771
Successfully built microwakeword
Installing collected packages: microwakeword
  Attempting uninstall: microwakeword
    Found existing installation: microwakeword 0

In [1]:
# Generates 1 sample of the target word for manual verification.

target_word = 'khum_puter'  # Phonetic spellings may produce better samples

import os
import sys
from IPython.display import Audio
if not os.path.exists("./piper-sample-generator"):
    !git clone https://github.com/rhasspy/piper-sample-generator
    !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'

    # Install system dependencies
    !pip install piper-phonemize
    !pip install webrtcvad

    if "piper-sample-generator/" not in sys.path:
        sys.path.append("piper-sample-generator/")

!python3 piper-sample-generator/generate_samples.py {target_word} \
--max-samples 1 \
--batch-size 1 \
--output-dir generated_samples

Audio("generated_samples/0.wav", autoplay=True)

Cloning into 'piper-sample-generator'...
remote: Enumerating objects: 124, done.[K
remote: Counting objects: 100% (37/37), done.[K
remote: Compressing objects: 100% (15/15), done.[K
remote: Total 124 (delta 28), reused 22 (delta 22), pack-reused 87 (from 1)[K
Receiving objects: 100% (124/124), 1.03 MiB | 19.52 MiB/s, done.
Resolving deltas: 100% (51/51), done.
--2024-11-22 15:11:28--  https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/642029941/73f4af3c-7cf8-4547-a7b9-3bd29e7f3c33?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20241122%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20241122T151128Z&X-Amz-Expires=300&X-Amz-Signature=26d3763c81df58b805ac

In [2]:
# Generates a larger amount of wake word samples.
# Start here when trying to improve your model.
# See https://github.com/rhasspy/piper-sample-generator for the full set of
# parameters. In particular, experiment with noise-scales and noise-scale-ws
# along with generating negative sample similar to the wake word

!python3 piper-sample-generator/generate_samples.py {target_word} \
--max-samples 2000 \
--batch-size 100 \
--output-dir generated_samples

DEBUG:__main__:Loading /content/piper-sample-generator/models/en_US-libritts_r-medium.pt
  torch_model = torch.load(model_path)
INFO:__main__:Successfully loaded the model
DEBUG:__main__:CUDA available, using GPU
DEBUG:__main__:Batch 1/20 complete
DEBUG:__main__:Batch 2/20 complete
DEBUG:__main__:Batch 3/20 complete
DEBUG:__main__:Batch 4/20 complete
DEBUG:__main__:Batch 5/20 complete
DEBUG:__main__:Batch 6/20 complete
DEBUG:__main__:Batch 7/20 complete
DEBUG:__main__:Batch 8/20 complete
DEBUG:__main__:Batch 9/20 complete
DEBUG:__main__:Batch 10/20 complete
DEBUG:__main__:Batch 11/20 complete
DEBUG:__main__:Batch 12/20 complete
DEBUG:__main__:Batch 13/20 complete
DEBUG:__main__:Batch 14/20 complete
DEBUG:__main__:Batch 15/20 complete
DEBUG:__main__:Batch 16/20 complete
DEBUG:__main__:Batch 17/20 complete
DEBUG:__main__:Batch 18/20 complete
DEBUG:__main__:Batch 19/20 complete
DEBUG:__main__:Batch 20/20 complete
INFO:__main__:Done


In [3]:
# Downloads audio data for augmentation
# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024
#
# **Important note!** The data downloaded here has a mixture of difference
# licenses and usage restrictions. As such, any custom models trained with this
# data should be considered as appropriate for **non-commercial** personal use only.


import datasets
import scipy
import os

import numpy as np

from pathlib import Path
from tqdm import tqdm

## Download MIR RIR data

output_dir = "./mit_rirs"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    rir_dataset = datasets.load_dataset("davidscripka/MIT_environmental_impulse_responses", split="train", streaming=True)
    # Save clips to 16-bit PCM wav files
    for row in tqdm(rir_dataset):
        name = row['audio']['path'].split('/')[-1]
        scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))

## Download noise and background audio

# Audioset Dataset (https://research.google.com/audioset/dataset/index.html)
# Download one part of the audioset .tar files, extract, and convert to 16khz
# For full-scale training, it's recommended to download the entire dataset from
# https://huggingface.co/datasets/agkphysics/AudioSet, and
# even potentially combine it with other background noise datasets (e.g., FSD50k, Freesound, etc.)

if not os.path.exists("audioset"):
    os.mkdir("audioset")

    fname = "bal_train09.tar"
    out_dir = f"audioset/{fname}"
    link = "https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/" + fname
    !wget -O {out_dir} {link}
    !cd audioset && tar -xf bal_train09.tar

    output_dir = "./audioset_16k"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # Save clips to 16-bit PCM wav files
    audioset_dataset = datasets.Dataset.from_dict({"audio": [str(i) for i in Path("audioset/audio").glob("**/*.flac")]})
    audioset_dataset = audioset_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))
    for row in tqdm(audioset_dataset):
        name = row['audio']['path'].split('/')[-1].replace(".flac", ".wav")
        scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))

# Free Music Archive dataset
# https://github.com/mdeff/fma

output_dir = "./fma"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    fname = "fma_xs.zip"
    link = "https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/" + fname
    out_dir = f"fma/{fname}"
    !wget -O {out_dir} {link}
    !cd {output_dir} && unzip -q {fname}

    output_dir = "./fma_16k"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # Save clips to 16-bit PCM wav files
    fma_dataset = datasets.Dataset.from_dict({"audio": [str(i) for i in Path("fma/fma_small").glob("**/*.mp3")]})
    audioset_dataset = audioset_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))
    for row in tqdm(audioset_dataset):
        name = row['audio']['path'].split('/')[-1].replace(".flac", ".wav")
        scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/936 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/270 [00:00<?, ?it/s]

270it [02:44,  1.64it/s]


--2024-11-22 15:15:16--  https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train09.tar
Resolving huggingface.co (huggingface.co)... 13.35.210.61, 13.35.210.77, 13.35.210.114, ...
Connecting to huggingface.co (huggingface.co)|13.35.210.61|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.hf.co/repos/be/e0/bee0f1da8c82b29f9518546133bb1ea7f58db75f059d5dc36fc0c28464d5386c/9ab88550e77c8289acf71d19ae3f254fadad16f3cea305ef08a16949a928acf9?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27bal_train09.tar%3B+filename%3D%22bal_train09.tar%22%3B&response-content-type=application%2Fx-tar&Expires=1732547716&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMjU0NzcxNn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9iZS9lMC9iZWUwZjFkYThjODJiMjlmOTUxODU0NjEzM2JiMWVhN2Y1OGRiNzVmMDU5ZDVkYzM2ZmMwYzI4NDY0ZDUzODZjLzlhYjg4NTUwZTc3YzgyODlhY2Y3MWQxOWFlM2YyNTRmYWRhZDE2ZjNjZWEzMDVlZj

100%|██████████| 685/685 [00:37<00:00, 18.05it/s]


--2024-11-22 15:16:04--  https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip
Resolving huggingface.co (huggingface.co)... 13.35.210.77, 13.35.210.61, 13.35.210.114, ...
Connecting to huggingface.co (huggingface.co)|13.35.210.77|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.hf.co/repos/b6/94/b6949420efe7987c075219703fbd5649e8e30f4aa4783eb019be14dfc9e7f52e/e5876c13cfb0f7ef668327c75d1c40bc4a2ed9d5b8e62ce383d093319c9ff663?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27fma_xs.zip%3B+filename%3D%22fma_xs.zip%22%3B&response-content-type=application%2Fzip&Expires=1732547764&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMjU0Nzc2NH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zL2I2Lzk0L2I2OTQ5NDIwZWZlNzk4N2MwNzUyMTk3MDNmYmQ1NjQ5ZThlMzBmNGFhNDc4M2ViMDE5YmUxNGRmYzllN2Y1MmUvZTU4NzZjMTNjZmIwZjdlZjY2ODMyN2M3NWQxYzQwYmM0YTJlZDlkNWI4ZTYyY2UzODNkMDkzMzE5

100%|██████████| 685/685 [00:37<00:00, 18.30it/s]


In [4]:
# Sets up the augmentations.
# To improve your model, expeirment with these settings and use more sources of
# background clips.

from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

clips = Clips(input_directory='generated_samples',
              file_pattern='*.wav',
              max_clip_duration_s=None,
              remove_silence=False,
              random_split_seed=10,
              split_count=0.1,
              )
augmenter = Augmentation(augmentation_duration_s=3.2,
                         augmentation_probabilities = {
                                "SevenBandParametricEQ": 0.1,
                                "TanhDistortion": 0.1,
                                "PitchShift": 0.1,
                                "BandStopFilter": 0.1,
                                "AddColorNoise": 0.1,
                                "AddBackgroundNoise": 0.75,
                                "Gain": 1.0,
                                "RIR": 0.5,
                            },
                         impulse_paths = ['mit_rirs'],
                         background_paths = ['fma_16k', 'audioset_16k'],
                         background_min_snr_db = -5,
                         background_max_snr_db = 10,
                         min_jitter_s = 0.195,
                         max_jitter_s = 0.205,
                         )


In [2]:
# Augment a random clip and play it back to verify it works well

import IPython
from microwakeword.audio.audio_utils import save_clip

random_clip = clips.get_random_clip()
augmented_clip = augmenter.augment_clip(random_clip)
save_clip(augmented_clip, 'augmented_clip.wav')

IPython.display.display(IPython.display.Audio("augmented_clip.wav"))

In [5]:
# Augmented samples and save them the training, validation, and testing sets

import os
from mmap_ninja.ragged import RaggedMmap

output_dir = 'generated_augmented_features'

if not os.path.exists(output_dir):
    os.mkdir(output_dir)

splits = ["training", "validation", "testing"]
for split in splits:
  out_dir = os.path.join(output_dir, split)
  if not os.path.exists(out_dir):
      os.mkdir(out_dir)


  split_name = "train"
  repitition = 2

  spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=10,    # Uses the same spectrogram repeatedly, just shifted over by one frame. This simulates the streaming inferences while training/validating in nonstreaming mode.
                                     step_ms=10,
                                     )
  if split == "validation":
    split_name = "validation"
    repitition = 1
  elif split == "testing":
    split_name = "test"
    repitition = 1
    spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=1,    # The testing set uses the streaming version of the model, so no artificial repetition is necessary
                                     step_ms=10,
                                     )

  RaggedMmap.from_generator(
      out_dir=os.path.join(out_dir, 'wakeword_mmap'),
      sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repitition),
      batch_size=100,
      verbose=True,
  )

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [None]:
# Downloads pregenerated spectrogram features (made for microWakeWord in particular) for various negative datasets

output_dir = './negative_datasets'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    link_root = "https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/"
    filenames = ['dinner_party.zip', 'no_speech.zip', 'speech.zip']
    for fname in filenames:
        link = link_root + fname

        zip_path = f"negative_datasets/{fname}"
        !wget -O {zip_path} {link}
        !unzip -q {zip_path} -d {output_dir}

In [1]:
# Save a yaml config that controls the training process
# These hyperparamters can make a huge different in model quality.
# Experiment with sampling and penalty weights and increase the number of
# training steps.

import yaml
import os

config = {}

config["window_step_ms"] = 10

config["train_dir"] = (
    "trained_models/wakeword"
)


# Each feature_dir should have at least one of the following folders with this structure:
#  training/
#    ragged_mmap_folders_ending_in_mmap
#  testing/
#    ragged_mmap_folders_ending_in_mmap
#  testing_ambient/
#    ragged_mmap_folders_ending_in_mmap
#  validation/
#    ragged_mmap_folders_ending_in_mmap
#  validation_ambient/
#    ragged_mmap_folders_ending_in_mmap
#
#  sampling_weight: Weight for choosing a spectrogram from this set in the batch
#  penalty_weight: Penalizing weight for incorrect predictions from this set
#  truth: Boolean whether this set has positive samples or negative samples
#  truncation_strategy = If spectrograms in the set are longer than necessary for training, how are they truncated
#       - random: choose a random portion of the entire spectrogram - useful for long negative samples
#       - truncate_start: remove the start of the spectrogram
#       - truncate_end: remove the end of the spectrogram
#       - split: Split the longer spectrogram into separate spectrograms offset by 100 ms. Only for ambient sets

config["features"] = [
    {
        "features_dir": "generated_augmented_features",
        "sampling_weight": 2.0,
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/speech",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/dinner_party",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/no_speech",
        "sampling_weight": 5.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
]

# Number of training steps in each iteration - various other settings are configured as lists that corresponds to different steps
config["training_steps"] = [10000]

# Penalizing weight for incorrect class predictions - lists that correspond to training steps
config["positive_class_weight"] = [1]
config["negative_class_weight"] = [20]

config["learning_rates"] = [
    0.001,
]  # Learning rates for Adam optimizer - list that corresponds to training steps
config["batch_size"] = 128

config["time_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["time_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps
config["freq_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["freq_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps

config["eval_step_interval"] = (
    500  # Test the validation sets after every this many steps
)
config["clip_duration_ms"] = (
    1500  # Maximum length of wake word that the streaming model will accept
)

# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization
# Once the target has been met, it chooses the maximum of the maximization metric. Set 'minimization_metric' to None to only maximize
# Available metrics:
#   - "loss" - cross entropy error on validation set
#   - "accuracy" - accuracy of validation set
#   - "recall" - recall of validation set
#   - "precision" - precision of validation set
#   - "false_positive_rate" - false positive rate of validation set
#   - "false_negative_rate" - false negative rate of validation set
#   - "ambient_false_positives" - count of false positives from the split validation_ambient set
#   - "ambient_false_positives_per_hour" - estimated number of false positives per hour on the split validation_ambient set
config["target_minimization"] = 0.9
config["minimization_metric"] = None  # Set to None to disable

config["maximization_metric"] = "average_viable_recall"

with open(os.path.join("training_parameters.yaml"), "w") as file:
    documents = yaml.dump(config, file)

In [3]:
# Trains a model. When finished, it will quantize and convert the model to a
# streaming version suitable for on device detection.
# It will resume if stopped, but it will start over at the configured training
# steps in the yaml file.
# Change --train 0 to only convert and test the best-weighted model.
# On Google colab, it doesn't print the mini-batch results, so it may appear
# stuck for several minutes! Additionally, it is very slow compared to training
# on a local GPU.

!python -m microwakeword.model_train_eval \
--training_config='training_parameters.yaml' \
--train 0 \
--restore_checkpoint 1 \
--test_tf_nonstreaming 0 \
--test_tflite_nonstreaming 0 \
--test_tflite_nonstreaming_quantized 0 \
--test_tflite_streaming 0 \
--test_tflite_streaming_quantized 1 \
--use_weights "best_weights" \
mixednet \
--pointwise_filters "64,64,64,64" \
--repeat_in_block  "1, 1, 1, 1" \
--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \
--residual_connection "0,0,0,0" \
--first_conv_filters 32 \
--first_conv_kernel_size 5 \
--stride 3

2024-11-22 16:28:39.936736: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-22 16:28:39.937791: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-22 16:28:40.098366: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-22 16:28:40.418295: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
INFO:absl:Loading and analyzing data sets.
20

In [6]:
# Downloads the model file. To use on the device, you need to write a Model JSON
# file. See https://esphome.io/components/micro_wake_word for the documentation
# and https://github.com/esphome/micro-wake-word-models/tree/main/models/v2 for
# examples. Adjust the probability threshold based on the test results obtained
# after training is finished. You may also need to increase the Tensor arena
# model size if the model fails to load.

from google.colab import files

files.download(f"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>