# Introduction

This notebook demonstrates how to train custom openWakeWord models using pre-defined datasets and an automated process for dataset generation and training. While not guaranteed to always produce the best performing model, the methods shown in this notebook often produce baseline models with relatively strong performance.

Manual data preparation and model training (e.g., see the [training models](training_models.ipynb) notebook) remains an option for when full control over the model development process is needed.

At a high level, the automatic training process takes advantage of several techniques to try to produce a good model, including:

- Early-stopping and checkpoint averaging (similar to [stochastic weight averaging](https://arxiv.org/abs/1803.05407)) to search for the best models found during training, according to the validation data
- Variable learning rates with cosine decay and multiple cycles
- Adaptive batch construction to focus on only high-loss examples when the model begins to converge, combined with gradient accumulation to ensure that batch sizes are still large enough for stable training
- Cyclical weight schedules for negative examples to help the model reduce false-positive rates

See the contents of the `train.py` file for more details.
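
To make two of the techniques above more concrete, here is a minimal sketch of a cosine learning-rate schedule that restarts every cycle, and of element-wise checkpoint averaging. This is an illustration only and is not taken from `train.py`: the function names, the toy checkpoint format (plain dictionaries of lists), and the example values are all assumptions, and the actual script may differ (for example by adding warm-up or selecting checkpoints by validation score).

```python
import math


def cyclic_cosine_lr(step, steps_per_cycle, max_lr, min_lr=0.0):
    """Learning rate at `step` for a cosine schedule that restarts every cycle."""
    position = (step % steps_per_cycle) / steps_per_cycle  # progress within the cycle, in [0, 1)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * position))


def average_checkpoints(checkpoints):
    """Element-wise mean of several checkpoints (dicts mapping name -> list of floats)."""
    n = len(checkpoints)
    return {
        name: [sum(values) / n for values in zip(*(ckpt[name] for ckpt in checkpoints))]
        for name in checkpoints[0]
    }


# The rate starts each cycle at max_lr, decays toward min_lr, then restarts.
for step in (0, 500, 999, 1000):
    print(step, round(cyclic_cosine_lr(step, steps_per_cycle=1000, max_lr=1e-3), 6))

# Averaging two toy checkpoints that each hold a single two-element weight vector.
print(average_checkpoints([{"w": [1.0, 3.0]}, {"w": [3.0, 5.0]}]))
```

In a real training loop the averaged values would be model weights from the best checkpoints rather than toy lists.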

# Environment Setup

To begin, we'll need to install the requirements for training custom models: in particular, a relatively recent version of PyTorch and a custom fork of the [piper-sample-generator](https://github.com/dscripka/piper-sample-generator) library, which generates synthetic examples for the custom model.

**Important Note!** Currently, automated model training is only supported on Linux systems due to the requirements of the text-to-speech library used for synthetic sample generation (Piper). It may be possible to use Piper on Windows/Mac systems, but that has not (yet) been tested.

In [None]:
## Environment setup

# install piper-sample-generator (currently only supports linux systems)
!git clone https://github.com/rhasspy/piper-sample-generator
!wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'
%pip install piper-phonemize
%pip install webrtcvad

# install openwakeword (full installation to support training)
!git clone https://github.com/dscripka/openwakeword
%pip install -e ./openwakeword
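# Note: no directory change is needed here; `!cd` only affects a temporary subshell,
# and later cells reference ./openwakeword by relative path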

# install other dependencies
%pip install mutagen==1.47.0
%pip install torchinfo==1.8.0
%pip install torchmetrics==1.2.0
%pip install speechbrain==0.5.14
%pip install audiomentations==0.33.0
%pip install torch-audiomentations==0.11.0
%pip install acoustics==0.2.6
%pip install "tensorflow-cpu>=2.12.0"  # quoted so the shell does not treat ">" as a redirect
%pip install tensorflow_probability==0.16.0
%pip install onnx_tf==1.10.0
%pip install pronouncing==0.2.0
%pip install datasets==2.14.6 pyarrow==14.0.1
%pip install deep-phonemizer==0.0.19

# Download required models (workaround for Colab)
import os
os.makedirs("./openwakeword/openwakeword/resources/models", exist_ok=True)
!wget https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/embedding_model.onnx -O ./openwakeword/openwakeword/resources/models/embedding_model.onnx
!wget https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/embedding_model.tflite -O ./openwakeword/openwakeword/resources/models/embedding_model.tflite
!wget https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/melspectrogram.onnx -O ./openwakeword/openwakeword/resources/models/melspectrogram.onnx
!wget https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/melspectrogram.tflite -O ./openwakeword/openwakeword/resources/models/melspectrogram.tflite


In [None]:
# This cell is no longer needed but kept for reference if you need to force-clean again
# !conda remove -y --force pyarrow
# %pip uninstall -y pyarrow
# %pip install pyarrow==14.0.1 datasets

In [None]:
# Debugging
import sys
print("python:", sys.executable)

try:
    import pyarrow
    print("pyarrow version:", pyarrow.__version__)
    print("pyarrow path:", pyarrow.__file__)
except ImportError as e:
    print("ERROR: pyarrow not installed:", e)
    print("Please install pyarrow: pip install 'pyarrow>=14'")
except Exception as e:
    print("ERROR loading pyarrow:", type(e).__name__, str(e))

try:
    import datasets
    print("datasets version:", datasets.__version__)
except ImportError as e:
    error_msg = str(e)
    if "HfFolder" in error_msg or "cannot import name" in error_msg:
        print("ERROR: Version incompatibility between datasets and huggingface_hub")
        print("datasets 2.14.6 requires huggingface_hub < 0.20.0, but newer versions removed HfFolder")
        print("Solution: pip install --force-reinstall 'datasets>=2.20.0,<3.0' 'huggingface_hub>=0.24.0'")
    else:
        print("ERROR: datasets not installed:", e)
        print("Please install datasets: pip install 'datasets>=2.20.0'")
except TypeError as e:
    if "NoneType" in str(e) or "packaging" in str(e).lower():
        print("ERROR: Version compatibility issue between datasets and packaging")
        print("Try: pip install 'datasets>=2.20.0,<3.0' 'packaging>=24.0' --force-reinstall")
        print("Or: pip install 'datasets==2.20.0' 'packaging>=24.0' --force-reinstall")
    else:
        print("ERROR loading datasets:", type(e).__name__, str(e))
except Exception as e:
    print("ERROR loading datasets:", type(e).__name__, str(e))
    print("Full error:", e)

In [None]:
# Imports

import os
import numpy as np
import torch
import sys
from pathlib import Path
import uuid
import yaml
import datasets
import scipy.io.wavfile  # import the submodule explicitly; it is used below to write WAV files
from tqdm import tqdm


# Download Data

When training new openWakeWord models using the automated procedure, five specific types of data are required:

1) Synthetic examples of the target word/phrase generated with text-to-speech models

2) Synthetic examples of adversarial words/phrases generated with text-to-speech models

3) Room impulse responses and noise/background audio data to augment the synthetic examples and make them more realistic

4) Generic "negative" audio data that is very unlikely to contain examples of the target word/phrase in the context where the model should detect it. This data can be the original audio data, or precomputed openWakeWord features ready for model training.

5) Validation data to use for early-stopping when training the model.

For the purposes of this notebook, all five of these sources will either be generated manually or obtained from HuggingFace, thanks to their excellent `datasets` library and extremely generous hosting policy. Also note that while only a portion of some datasets is downloaded here, for the best possible performance it is recommended to download the entire dataset and keep a local copy for future training runs.
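
Several of the download steps in this notebook store audio as 16-bit PCM WAV files by multiplying floating-point samples by 32767 and casting to `int16`. That works when every sample lies in [-1.0, 1.0], but values outside that range can overflow during the cast. The sketch below shows one way to guard against this by clipping first; `float_to_pcm16` is a hypothetical helper introduced for illustration and is not part of openWakeWord.

```python
import numpy as np


def float_to_pcm16(audio):
    """Scale float audio in [-1.0, 1.0] to int16, clipping anything outside that range."""
    clipped = np.clip(np.asarray(audio, dtype=np.float64), -1.0, 1.0)
    return (clipped * 32767).astype(np.int16)


# The last two samples fall outside [-1.0, 1.0] and are clipped before scaling.
samples = np.array([0.0, 0.5, 1.0, 1.5, -2.0])
print(float_to_pcm16(samples))
```

The result can be passed to `scipy.io.wavfile.write` in place of the inline `(array * 32767).astype(np.int16)` expression.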

In [None]:
# Install/upgrade required packages for HuggingFace datasets
# Note: Run this cell once, then restart kernel if needed
# Using compatible version ranges to avoid NumPy 2.x, packaging, and HfFolder issues
# datasets 2.20.0+ works with huggingface_hub 0.24.0+ (HfFolder was removed in 0.20.0)

# If using conda (recommended for conda environments):
# !conda install -y -c conda-forge "numpy<2" "pyarrow>=14" "packaging>=24"
# !conda install -y -c huggingface "datasets>=2.20.0,<3.0" "huggingface_hub>=0.24.0"
# %pip install "fsspec>=2024.6.0" "aiohttp" "soundfile"

# If using pip (works in any environment, but can conflict with conda):
%pip install -q -U "numpy<2" "datasets>=2.20.0,<3.0" "huggingface_hub>=0.24.0" "fsspec>=2024.6.0" "pyarrow>=14" "aiohttp" "soundfile" "packaging>=24.0"

In [None]:
# This cell is kept for reference - use the pip install cell above instead
# If you encounter issues, you can uncomment and run this to force reinstall:
# %pip uninstall -y huggingface-hub huggingface_hub datasets pyarrow
# %pip install --no-cache-dir --force-reinstall "huggingface_hub>=0.24.0" "datasets>=2.20.0" "fsspec>=2024.6.0" "pyarrow>=14" "packaging"
# Then restart kernel

In [None]:
# Download room impulse responses collected by MIT
# https://mcdermottlab.mit.edu/Reverb/IR_Survey.html

output_dir = "./mit_rirs"
os.makedirs(output_dir, exist_ok=True)  # create the folder if it does not already exist
rir_dataset = datasets.load_dataset("davidscripka/MIT_environmental_impulse_responses", split="train", streaming=True)

# Save clips to 16-bit PCM wav files
for row in tqdm(rir_dataset):
    name = row['audio']['path'].split('/')[-1]
    scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))

In [None]:
from pathlib import Path
import itertools
import numpy as np
import scipy.io.wavfile
from tqdm.auto import tqdm

import datasets
from datasets import load_dataset

# -------------------------
# Download noise/background audio
# -------------------------

# AudioSet (parquet shard on HF)
# Using direct URL to parquet file from HuggingFace Hub
repo_id = "agkphysics/AudioSet"
shard_id = "09"
# Construct the direct URL to the parquet file
parquet_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/data/bal_train/{shard_id}.parquet"

output_dir = Path("audioset_16k")
output_dir.mkdir(parents=True, exist_ok=True)

try:
    # Try loading parquet file directly using URL
    print(f"Loading AudioSet parquet file from: {parquet_url}")
    audioset_dataset = load_dataset(
        "parquet",
        data_files=parquet_url,
        split="train",
        streaming=True,
    )
    
    # Find the audio column (it may be named "audio", but we detect it safely)
    audio_col = None
    for name, feat in audioset_dataset.features.items():
        if isinstance(feat, datasets.Audio):
            audio_col = name
            break
    if audio_col is None:
        # sometimes it's stored as a dict-like column; try the common default
        if "audio" in audioset_dataset.features:
            audio_col = "audio"
        else:
            # Try to inspect first row to find audio column
            first_row = next(iter(audioset_dataset))
            for key in first_row.keys():
                if isinstance(first_row[key], dict) and "array" in first_row[key]:
                    audio_col = key
                    break
            if audio_col is None:
                raise ValueError(f"No audio column found. Columns: {list(audioset_dataset.features.keys())}")
    
    print(f"Found audio column: {audio_col}")
    
    # Cast/Decode audio to 16kHz
    audioset_dataset = audioset_dataset.cast_column(audio_col, datasets.Audio(sampling_rate=16000))
    
    # IMPORTANT: the shard can still be big; limit how many files you convert if needed
    max_items = 2000  # change or set to None to do all (can take a while + lots of disk)
    iterator = audioset_dataset if max_items is None else itertools.islice(audioset_dataset, max_items)
    
    for i, row in enumerate(tqdm(iterator, total=(max_items or None), desc="AudioSet -> 16k wav")):
        audio = row[audio_col]
        # After cast_column, audio should be a dict with "array" and "path" keys
        if isinstance(audio, dict):
            audio_array = audio.get("array")
            audio_path = audio.get("path", "")
            if audio_array is None:
                print(f"Warning: No audio array found in row {i}, skipping...")
                continue
        else:
            # Fallback: if it's not a dict, assume it's already the array
            audio_array = audio
            audio_path = ""
            
        stem = Path(audio_path).stem if audio_path else f"audioset_{i:06d}"
        out_path = output_dir / f"{stem}.wav"
        scipy.io.wavfile.write(out_path, 16000, (audio_array * 32767).astype(np.int16))
        
except Exception as e:
    print(f"Error loading AudioSet: {e}")
    print("Trying alternative method...")
    # Alternative: try loading the dataset directly if it's available
    try:
        # Streaming mode does not support split slicing, so request the full "train" split
        audioset_dataset = load_dataset(repo_id, split="train", streaming=True)
        audioset_dataset = audioset_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))
        max_items = 2000
        iterator = itertools.islice(audioset_dataset, max_items)
        for i, row in enumerate(tqdm(iterator, total=max_items, desc="AudioSet -> 16k wav")):
            audio = row["audio"]
            stem = Path(audio.get("path", "")).stem if audio.get("path") else f"audioset_{i:06d}"
            out_path = output_dir / f"{stem}.wav"
            scipy.io.wavfile.write(out_path, 16000, (audio["array"] * 32767).astype(np.int16))
    except Exception as e2:
        print(f"Alternative method also failed: {e2}")
        print("Skipping AudioSet download. You may need to download manually or check dataset availability.")


# -------------------------
# Free Music Archive (FMA) small
# -------------------------
fma_dir = Path("fma")
fma_dir.mkdir(parents=True, exist_ok=True)

try:
    print("Loading FMA dataset...")
    fma_dataset = load_dataset("rudraml/fma", name="small", split="train", streaming=True)
    fma_dataset = fma_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))
    
    n_hours = 1  # FMA small clips are ~30s each
    n_items = n_hours * 3600 // 30
    
    for i, row in enumerate(tqdm(itertools.islice(fma_dataset, n_items), total=n_items, desc="FMA -> 16k wav")):
        audio = row["audio"]
        stem = Path(audio.get("path", "")).stem if audio.get("path") else f"fma_{i:06d}"
        out_path = fma_dir / f"{stem}.wav"
        scipy.io.wavfile.write(out_path, 16000, (audio["array"] * 32767).astype(np.int16))
except Exception as e:
    print(f"Error loading FMA dataset: {e}")
    print("Skipping FMA download. You may need to check dataset availability.")



In [None]:
# Download pre-computed openWakeWord features for training and validation

# training set (~2,000 hours from the ACAV100M Dataset)
# See https://huggingface.co/datasets/davidscripka/openwakeword_features for more information
!wget https://huggingface.co/datasets/davidscripka/openwakeword_features/resolve/main/openwakeword_features_ACAV100M_2000_hrs_16bit.npy

# validation set for false positive rate estimation (~11 hours)
!wget https://huggingface.co/datasets/davidscripka/openwakeword_features/resolve/main/validation_set_features.npy

# Define Training Configuration

For automated model training, openWakeWord uses a specially designed training script and a [YAML](https://yaml.org/) configuration file that defines all of the information required for training a new wake word/phrase detection model.

It is strongly recommended that you review [the example config file](../examples/custom_model.yml), as each value is fully documented there. For the purposes of this notebook, we'll read in the YAML file to modify certain configuration parameters before saving a new YAML file for training our example model. Specifically:

- We'll train a detection model for the phrase "how do you wanna do this"
- We'll generate 70,000 positive and negative examples
- We'll generate 2,000 validation positive and negative examples for early stopping
- The model will be trained for 50,000 steps (larger datasets will benefit from longer training)
- We'll set the target accuracy and target recall to 0.5

On the topic of target metrics, there are no specific guidelines about what these metrics should be in practice, and you will need to conduct testing in your target deployment environment to establish good thresholds. However, from very limited testing, the default values in the config file (accuracy >= 0.7, recall >= 0.5, false-positive rate <= 0.2 per hour) seem to produce models with reasonable performance.
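
For reference, the sketch below computes these three metrics from raw counts using their standard definitions and checks them against the default thresholds quoted above. It is an illustration under assumptions: the function names, parameter names, and example counts are hypothetical, and the exact way `train.py` measures each metric may differ.

```python
def accuracy(tp, tn, fp, fn):
    """Fraction of all validation clips that were classified correctly."""
    return (tp + tn) / (tp + tn + fp + fn)


def recall(tp, fn):
    """Fraction of true wake word clips that the model detected."""
    return tp / (tp + fn)


def false_positives_per_hour(false_activations, hours_of_audio):
    """False activations per hour of audio that contains no wake word."""
    return false_activations / hours_of_audio


def meets_targets(acc, rec, fp_per_hour, min_accuracy=0.7, min_recall=0.5, max_fp_per_hour=0.2):
    """Check metrics against the default thresholds quoted in the text."""
    return acc >= min_accuracy and rec >= min_recall and fp_per_hour <= max_fp_per_hour


# Hypothetical counts: 1,000 positive and 1,000 negative validation clips,
# plus 2 false activations over 11 hours of negative audio.
acc = accuracy(tp=900, tn=950, fp=50, fn=100)
rec = recall(tp=900, fn=100)
fp_rate = false_positives_per_hour(false_activations=2, hours_of_audio=11)
print(acc, rec, round(fp_rate, 3), meets_targets(acc, rec, fp_rate))
```

Note that the false-positive rate is normalized by hours of audio rather than by clip count, which is why a dedicated negative validation set of known duration is useful.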


In [None]:
# Import yaml
import yaml

In [None]:
# Load default YAML config file for training
config_path = Path('..') / 'examples' / 'custom_model.yml'
config = yaml.safe_load(config_path.read_text())
config



In [None]:
# Modify values in the config and save a new version

config["target_phrase"] = ["how do you wanna do this"]
config["model_name"] = config["target_phrase"][0].replace(" ", "_")
config["n_samples"] = 70000
config["n_samples_val"] = 2000
config["steps"] = 50000
config["target_accuracy"] = 0.5
config["target_recall"] = 0.5

config["background_paths"] = ['./audioset_16k', './fma']  # multiple background datasets are supported
config["false_positive_validation_data_path"] = "validation_set_features.npy"
config["feature_data_files"] = {"ACAV100M_sample": "openwakeword_features_ACAV100M_2000_hrs_16bit.npy"}

output_dir = Path(config["output_dir"])
output_dir.mkdir(parents=True, exist_ok=True)

with open('my_model.yaml', 'w') as file:
    yaml.dump(config, file)



# Train the Model

With the data downloaded and training configuration set, we can now start training the model. We'll do this in parts to better illustrate the sequence, but you can also execute every step at once for a fully automated process.
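
As a rough sketch of the fully automated variant, the helper below builds a single command line that passes all three step flags used in this notebook. Combining `--generate_clips`, `--augment_clips`, and `--train_model` in one invocation is an assumption based on the statement above; `build_train_command` is a hypothetical helper, and the script and config paths are simply the ones this notebook uses.

```python
import shlex
import sys


def build_train_command(config_path, generate=True, augment=True, train=True,
                        script="openwakeword/openwakeword/train.py"):
    """Build the argument list for the training script with the selected steps enabled."""
    command = [sys.executable, script, "--training_config", config_path]
    if generate:
        command.append("--generate_clips")
    if augment:
        command.append("--augment_clips")
    if train:
        command.append("--train_model")
    return command


command = build_train_command("my_model.yaml")
print(shlex.join(command))

# To actually run it (this can take a long time):
# import subprocess
# subprocess.run(command, check=True)
```

Running the steps separately, as in the cells that follow, makes it easier to retry a single stage if it fails.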

In [None]:
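# Install PyTorch with CUDA 12.1 support (only relevant in conda environments)
%conda install -y -c pytorch -c nvidia pytorch torchaudio pytorch-cuda=12.1
# NOTE: this command names a channel (`piper`) but no package to install, so it is left commented out
# %conda install -y -c piper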

In [None]:
# Step 1: Generate synthetic clips
# For the number of clips we are using, this should take ~10 minutes on a free Google Colab instance with a T4 GPU
# If generation fails, you can simply run this command again as it will continue generating until the
# number of files meets the targets specified in the config file

!{sys.executable} openwakeword/openwakeword/train.py --training_config my_model.yaml --generate_clips

In [None]:
# Step 2: Augment the generated clips

!{sys.executable} openwakeword/openwakeword/train.py --training_config my_model.yaml --augment_clips

In [None]:
# Step 3: Train model

!{sys.executable} openwakeword/openwakeword/train.py --training_config my_model.yaml --train_model

In [None]:
# Step 4 (Optional): On Google Colab, sometimes the .tflite model isn't saved correctly
# If so, run this cell to retry

# Manually save to tflite as this doesn't work right in colab
def convert_onnx_to_tflite(onnx_model_path, output_path):
    """Converts an ONNX version of an openwakeword model to the TensorFlow tflite format."""
    # imports
    import onnx
    import logging
    import tempfile
    from onnx_tf.backend import prepare
    import tensorflow as tf

    # Convert to tflite from onnx model
    onnx_model = onnx.load(onnx_model_path)
    tf_rep = prepare(onnx_model, device="CPU")
    with tempfile.TemporaryDirectory() as tmp_dir:
        tf_rep.export_graph(os.path.join(tmp_dir, "tf_model"))
        converter = tf.lite.TFLiteConverter.from_saved_model(os.path.join(tmp_dir, "tf_model"))
        tflite_model = converter.convert()

        logging.info(f"####\nSaving tflite model to '{output_path}'")
        with open(output_path, 'wb') as f:
            f.write(tflite_model)

    return None

convert_onnx_to_tflite(f"my_custom_model/{config['model_name']}.onnx", f"my_custom_model/{config['model_name']}.tflite")


After the model finishes training, the auto training script will automatically convert it to ONNX and tflite versions, saving them as `my_custom_model/<model_name>.onnx/tflite` in the present working directory, where `<model_name>` is defined in the YAML training config file. Either version can be used as normal with `openwakeword`. I recommend testing them with the [`detect_from_microphone.py`](https://github.com/dscripka/openWakeWord/blob/main/examples/detect_from_microphone.py) example script to see how the model performs!
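
As a quick sanity check before trying the microphone script, the sketch below locates the exported files and scores a single frame of silence with the ONNX model. The helper names are hypothetical; `score_silence` assumes that `openwakeword` and `numpy` are installed and that the `Model` class accepts `wakeword_models` and `inference_framework` arguments, so treat it as a starting point rather than a definitive recipe.

```python
from pathlib import Path


def model_artifact_paths(output_dir, model_name):
    """Return the ONNX and tflite paths described in the text above."""
    base = Path(output_dir)
    return {"onnx": base / f"{model_name}.onnx", "tflite": base / f"{model_name}.tflite"}


def score_silence(onnx_path):
    """Load a trained model and score one 80 ms frame of silence (requires openwakeword)."""
    import numpy as np
    from openwakeword.model import Model  # imported lazily so the helper above works without it

    model = Model(wakeword_models=[str(onnx_path)], inference_framework="onnx")
    frame = np.zeros(1280, dtype=np.int16)  # 1280 samples = 80 ms of 16 kHz audio
    return model.predict(frame)  # expected to map each model name to a score


paths = model_artifact_paths("my_custom_model", "how_do_you_wanna_do_this")
print(paths["onnx"], paths["onnx"].exists())

# Once training has finished:
# print(score_silence(paths["onnx"]))
```

A score near zero on silence is expected; a consistently high score would suggest a problem with the exported model.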