<div align="center">
  <img src="https://raw.githubusercontent.com/MasterPhooey/MicroWakeWord-Trainer-Docker/refs/heads/main/mmw.png" alt="MicroWakeWord Trainer Logo" width="100" />
  <h1>MicroWakeWord Trainer Docker</h1>
</div>

This notebook steps you through training a robust microWakeWord model. It is intended as a **starting point** for users looking to create a high-performance wake word detection model. This notebook is optimized for Python 3.10.

**The model generated from this notebook is designed for practical use, but achieving optimal performance will require experimentation with various settings and datasets. The provided scripts and configurations aim to give you a strong foundation to build upon.**

Throughout the notebook, you will find comments suggesting specific settings to modify and experiment with to enhance your model's performance.

By the end of this notebook, you will have:
- A trained TensorFlow Lite model ready for deployment.
- A JSON manifest file to integrate the model with ESPHome.

To use the generated model in ESPHome, refer to the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for integration details. You can also explore example configurations in the [model repository](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2).

In [None]:
# Installs microWakeWord. Be sure to restart the session after this is finished.
import platform
import sys
import os

if platform.system() == "Darwin":
    # `pymicro-features` is installed from a fork to support building on macOS
    !"{sys.executable}" -m pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version' --root-user-action=ignore

# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter
!"{sys.executable}" -m pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f' --root-user-action=ignore

# Clone the microWakeWord repository
repo_path = "./microWakeWord"
if not os.path.exists(repo_path):
    print("Cloning microWakeWord repository...")
    !git clone https://github.com/kahrendt/microWakeWord.git {repo_path}

# Ensure the repository exists before attempting to install
if os.path.exists(repo_path):
    print("Installing microWakeWord...")
    !"{sys.executable}" -m pip install -e {repo_path} --root-user-action=ignore
else:
    print(f"Repository not found at {repo_path}. Cloning might have failed.")

In [None]:
# Generates 1 sample of the target word for manual verification.

target_word = 'hey_norman'  # Phonetic spellings may produce better samples

import os
import sys
import platform

from IPython.display import Audio

# Ensure the repository is cloned correctly
if not os.path.exists("./piper-sample-generator"):
    if platform.system() == "Darwin":
        !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator
    else:
        !git clone https://github.com/rhasspy/piper-sample-generator

# Download the required model
if not os.path.exists("piper-sample-generator/models/en_US-libritts_r-medium.pt"):
    !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'

# Install Python dependencies
!"{sys.executable}" -m pip install torch torchaudio piper-phonemize-cross==1.2.1

# Ensure the repository path is in sys.path
if "piper-sample-generator/" not in sys.path:
    sys.path.append("piper-sample-generator/")

# Generate sample
!"{sys.executable}" piper-sample-generator/generate_samples.py "{target_word}" \
--max-samples 1 \
--batch-size 1 \
--output-dir generated_samples

# Play the generated audio sample
audio_path = "generated_samples/0.wav"
if os.path.exists(audio_path):
    display(Audio(audio_path, autoplay=True))
else:
    print(f"Audio file not found at {audio_path}")
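
Besides listening to the sample, it can help to check its basic properties before generating thousands more. The following is a minimal sketch that uses only the standard library `wave` module. `describe_wav` is a hypothetical helper name (it is not part of piper-sample-generator), and the sketch assumes the generated file is an uncompressed PCM WAV.

```python
import wave
from pathlib import Path


def describe_wav(path):
    """Return (sample_rate_hz, channels, duration_s) for a PCM WAV file."""
    with wave.open(str(path), "rb") as wav_file:
        sample_rate = wav_file.getframerate()
        channels = wav_file.getnchannels()
        duration_s = wav_file.getnframes() / sample_rate
    return sample_rate, channels, duration_s


sample_path = Path("generated_samples/0.wav")
if sample_path.exists():
    rate, channels, duration = describe_wav(sample_path)
    print(f"{sample_path}: {rate} Hz, {channels} channel(s), {duration:.2f} s")
else:
    print(f"{sample_path} not found; run the generation cell first.")
```

A sample that is much longer or shorter than the spoken phrase is usually a hint to try a different phonetic spelling.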

In [None]:
# Generates a larger amount of wake word samples.
# Start here when trying to improve your model.
# See https://github.com/rhasspy/piper-sample-generator for the full set of
# parameters. In particular, experiment with noise-scales and noise-scale-ws,
# generating negative samples similar to the wake word, and generating many more
# wake word samples, possibly with different phonetic pronunciations.

!"{sys.executable}" piper-sample-generator/generate_samples.py "{target_word}" \
--max-samples 10000 \
--batch-size 100 \
--output-dir generated_samples
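
One way to act on the suggestion about different phonetic pronunciations is to split the sample budget across several spellings. The sketch below only builds the command lines and prints them. The spellings are hypothetical placeholders, `build_generate_command` is an illustrative helper, and only the flags already used in this notebook (`--max-samples`, `--batch-size`, `--output-dir`) are assumed to exist.

```python
import sys

# Hypothetical alternative spellings of the wake word; replace with your own.
phonetic_variants = ["hey_norman", "hay_norman", "hey_normen"]


def build_generate_command(word, max_samples, batch_size, output_dir):
    """Build the argument list for one generate_samples.py run."""
    return [
        sys.executable,
        "piper-sample-generator/generate_samples.py",
        word,
        "--max-samples", str(max_samples),
        "--batch-size", str(batch_size),
        "--output-dir", output_dir,
    ]


# Split the 10000-sample budget evenly across the variants.
samples_per_variant = 10000 // len(phonetic_variants)
commands = [
    build_generate_command(word, samples_per_variant, 100, f"generated_samples/{word}")
    for word in phonetic_variants
]

for command in commands:
    print(" ".join(command))
    # To actually run it: import subprocess; subprocess.run(command, check=True)
```

Each variant writes to its own subdirectory here so that runs do not overwrite each other. The later `Clips` cell reads `./generated_samples` with the pattern `*.wav`, so files written to subdirectories would need to be moved or the input settings adjusted.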

In [None]:
# Downloads audio data for augmentation. This can be slow!
# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024
#
# **Important note!** The data downloaded here has a mixture of different
# licenses and usage restrictions. As such, any custom models trained with this
# data should be considered as appropriate for **non-commercial** personal use only.

import os
import scipy.io.wavfile
import numpy as np
import soundfile as sf
from pathlib import Path
from tqdm import tqdm
import requests
import tarfile
import zipfile
from datasets import load_dataset

# Function to download and process RIR dataset
def download_rir_dataset(dataset_name, output_dir, split="train"):
    output_dir = Path(output_dir)
    if not output_dir.exists():
        output_dir.mkdir(parents=True, exist_ok=True)
        try:
            rir_dataset = load_dataset(dataset_name, split=split, streaming=True)
            print(f"Downloading {dataset_name} to {output_dir}...")
            for row in tqdm(rir_dataset):
                name = Path(row['audio']['path']).name
                scipy.io.wavfile.write(
                    output_dir / name,
                    16000,
                    (row['audio']['array'] * 32767).astype(np.int16)
                )
            print(f"Finished downloading {dataset_name} to {output_dir}.\n")
        except Exception as e:
            print(f"Error downloading {dataset_name}: {e}")
    else:
        print(f"{output_dir} already exists. Skipping download.\n")

# Download MIT RIRs
download_rir_dataset(
    "davidscripka/MIT_environmental_impulse_responses",
    "./mit_rirs"
)

# Function to download files
def download_file(url, output_path):
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Fail early instead of saving an HTTP error page to disk
    total_size = int(response.headers.get('content-length', 0))
    with open(output_path, "wb") as f, tqdm(
        desc=f"Downloading {output_path.name}",
        total=total_size,
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for chunk in response.iter_content(chunk_size=1024):
            f.write(chunk)
            bar.update(len(chunk))
    print(f"Downloaded {output_path}")

# Function to extract .tar files
def extract_tar(file_path, extract_dir):
    with tarfile.open(file_path, "r") as tar:
        tar.extractall(path=extract_dir)
    print(f"Extracted {file_path} to {extract_dir}")

# Function to download and extract ZIP files
def download_and_extract_zip(url, extract_to):
    file_name = url.split("/")[-1]
    local_path = Path(extract_to) / file_name
    download_file(url, local_path)
    with zipfile.ZipFile(local_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    print(f"Extracted {file_name} to {extract_to}")

# Directories
raw_dir = Path("./audioset_raw")
processed_dir = Path("./audioset_16k")
raw_dir.mkdir(exist_ok=True)
processed_dir.mkdir(exist_ok=True)

# Full-scale dataset download links
dataset_links = [
    f"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar"
    for i in range(10)  # Adjust for additional parts
]

# Step 1: Download all parts of the dataset
print("Downloading datasets...")
for link in dataset_links:
    file_name = link.split("/")[-1]
    output_path = raw_dir / file_name
    if not output_path.exists():
        download_file(link, output_path)

# Step 2: Extract all .tar files
print("Extracting datasets...")
for file_path in raw_dir.glob("*.tar"):
    extract_dir = raw_dir / file_path.stem
    if not extract_dir.exists():
        extract_tar(file_path, extract_dir)

# Step 3: Convert audio files to 16kHz WAV
audio_files = list(Path(raw_dir).glob("**/*.flac"))
print(f"Number of FLAC files found: {len(audio_files)}")
if not audio_files:
    raise FileNotFoundError("No .flac files found in the raw directories. Check your dataset extraction.")

print("Converting audio files to 16kHz WAV...")
corrupted_files = []
resampled_files = []

for file_path in tqdm(audio_files, desc="Processing audio files"):
    try:
        # Read the .flac file
        data, samplerate = sf.read(file_path)

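        # Downmix multi-channel audio to mono; np.interp only accepts 1-D input
        if data.ndim > 1:
            data = data.mean(axis=1)

        # Check and resample if needed (linear interpolation)
        if samplerate != 16000:
            resampled_files.append(str(file_path))
            data = np.interp(
                np.linspace(0, len(data), int(len(data) * 16000 / samplerate), endpoint=False),
                np.arange(len(data)),
                data,
            )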

        # Convert and save as WAV
        output_path = processed_dir / file_path.name.replace(".flac", ".wav")
        scipy.io.wavfile.write(
            output_path,
            16000,
            (data * 32767).astype(np.int16),
        )
    except Exception as e:
        corrupted_files.append(str(file_path))

# Log corrupted files
if corrupted_files:
    with open("corrupted_files.log", "w") as log_file:
        log_file.writelines(f"{file}\n" for file in corrupted_files)

# Log resampled files
if resampled_files:
    with open("resampled_files.log", "w") as log_file:
        log_file.writelines(f"{file}\n" for file in resampled_files)

print(f"Audio conversion complete! {len(corrupted_files)} files corrupted and logged.")
print(f"{len(resampled_files)} files resampled and logged.")

# Process fma_xs dataset
fma_raw_dir = Path("./fma")
fma_processed_dir = Path("./fma_16k")  # Separate directory for fma_xs processed files
fma_raw_dir.mkdir(exist_ok=True)
fma_processed_dir.mkdir(exist_ok=True)

fma_link = "https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip"
fma_zip_path = fma_raw_dir / "fma_xs.zip"

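if not fma_zip_path.exists():
    print("Downloading fma_xs dataset...")
    download_file(fma_link, fma_zip_path)

# Extract the archive downloaded above instead of downloading it a second time
print("Extracting fma_xs dataset...")
with zipfile.ZipFile(fma_zip_path, "r") as zip_ref:
    zip_ref.extractall(fma_raw_dir)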

fma_audio_files = list(fma_raw_dir.glob("**/*.mp3"))
print(f"Number of MP3 files found: {len(fma_audio_files)}")
if not fma_audio_files:
    raise FileNotFoundError("No .mp3 files found in the fma directory. Check your dataset extraction.")

print("Converting fma_xs files to 16kHz WAV...")
for file_path in tqdm(fma_audio_files, desc="Processing fma_xs files"):
    try:
        # Read the .mp3 file
        data, samplerate = sf.read(file_path)

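        # Downmix multi-channel audio to mono; np.interp only accepts 1-D input
        if data.ndim > 1:
            data = data.mean(axis=1)

        # Check and resample if needed (linear interpolation)
        if samplerate != 16000:
            data = np.interp(
                np.linspace(0, len(data), int(len(data) * 16000 / samplerate), endpoint=False),
                np.arange(len(data)),
                data,
            )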

        # Convert and save as WAV
        output_path = fma_processed_dir / file_path.name.replace(".mp3", ".wav")
        scipy.io.wavfile.write(
            output_path,
            16000,
            (data * 32767).astype(np.int16),
        )
    except Exception as e:
        corrupted_files.append(str(file_path))

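# Log corrupted files from fma_xs only; corrupted_files also holds the
# AudioSet (.flac) entries collected earlier in this cell
fma_corrupted_files = [file for file in corrupted_files if file.endswith(".mp3")]
if fma_corrupted_files:
    with open("fma_corrupted_files.log", "w") as log_file:
        log_file.writelines(f"{file}\n" for file in fma_corrupted_files)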

print("fma_xs processing complete!")
print("Full-scale dataset preparation complete!")
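
The conversion loops above resample with `np.interp`, which is plain linear interpolation, and then scale to 16-bit integers. As an illustration of that arithmetic, here is a minimal standard-library sketch. `linear_resample` and `to_int16` are hypothetical helper names, the input is assumed to be mono floats in the range [-1.0, 1.0], and the clipping in `to_int16` is an addition that the cell above does not perform.

```python
def linear_resample(samples, source_rate, target_rate=16000):
    """Resample a mono signal by linear interpolation between neighbours."""
    if source_rate == target_rate or not samples:
        return list(samples)
    target_length = int(len(samples) * target_rate / source_rate)
    step = source_rate / target_rate  # source samples advanced per output sample
    last_index = len(samples) - 1
    resampled = []
    for i in range(target_length):
        position = i * step
        left = min(int(position), last_index)
        right = min(left + 1, last_index)
        fraction = position - left
        resampled.append(samples[left] * (1 - fraction) + samples[right] * fraction)
    return resampled


def to_int16(samples):
    """Scale floats to 16-bit integers, clipping values outside [-1.0, 1.0]."""
    return [max(-32768, min(32767, int(s * 32767))) for s in samples]


ramp = [0.0, 0.25, 0.5, 0.75]          # 4 samples at a pretend 32 kHz
halved = linear_resample(ramp, 32000)  # 2 samples at 16 kHz
print(halved, to_int16(halved))
```

Linear interpolation applies no low-pass filter before downsampling, so high-frequency content can alias. For background noise used in augmentation this may be acceptable, but a filtered resampler is worth considering if the converted clips sound harsh.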

In [None]:
# Sets up the augmentations.
# To improve your model, experiment with these settings and use more sources of
# background clips.

import os
from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

def validate_directories(paths):
    for path in paths:
        if not os.path.exists(path):
            print(f"Error: Directory {path} does not exist. Please ensure preprocessing is complete.")
            return False
    return True

# Paths to augmentation data (impulse responses and background audio)
impulse_paths = ['mit_rirs']
background_paths = ['fma_16k', 'audioset_16k']

if not validate_directories(impulse_paths + background_paths):
    raise ValueError("One or more required directories are missing.")

clips = Clips(
    input_directory='./generated_samples',
    file_pattern='*.wav',
    max_clip_duration_s=5,
    remove_silence=True,
    random_split_seed=10,
    split_count=0.1,
)

augmenter = Augmentation(
    augmentation_duration_s=3.2,
    augmentation_probabilities={
        "SevenBandParametricEQ": 0.1,
        "TanhDistortion": 0.05,
        "PitchShift": 0.15,
        "BandStopFilter": 0.1,
        "AddColorNoise": 0.1,
        "AddBackgroundNoise": 0.9,
        "Gain": 0.8,
        "RIR": 0.7,
    },
    impulse_paths=impulse_paths,
    background_paths=background_paths,
    background_min_snr_db=-10,
    background_max_snr_db=15,
    min_jitter_s=0.15,
    max_jitter_s=0.25,
)
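
When tuning `augmentation_probabilities`, it can be useful to see how heavily a typical clip ends up being processed. The sketch below assumes each augmentation is applied independently with its listed probability; that is an assumption about how the augmenter behaves, not something this notebook verifies.

```python
import math

# Same values as in the cell above.
augmentation_probabilities = {
    "SevenBandParametricEQ": 0.1,
    "TanhDistortion": 0.05,
    "PitchShift": 0.15,
    "BandStopFilter": 0.1,
    "AddColorNoise": 0.1,
    "AddBackgroundNoise": 0.9,
    "Gain": 0.8,
    "RIR": 0.7,
}

# Expected number of augmentations applied to one clip (sum of probabilities).
expected_count = sum(augmentation_probabilities.values())

# Probability that a clip passes through with no augmentation at all.
probability_untouched = math.prod(1 - p for p in augmentation_probabilities.values())

print(f"Expected augmentations per clip: {expected_count:.2f}")
print(f"Clips left unaugmented: {probability_untouched:.2%}")
```

Under that independence assumption, the settings above give about 2.9 augmentations per clip, with well under 1% of clips left untouched.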


In [None]:
# Augment a random clip and play it back to verify it works well
from pathlib import Path
from IPython.display import Audio
from microwakeword.audio.audio_utils import save_clip

# Ensure output directory exists
output_dir = Path('./augmented_clips')
output_dir.mkdir(exist_ok=True)

try:
    # Get a random clip and apply augmentation
    random_clip = clips.get_random_clip()
    augmented_clip = augmenter.augment_clip(random_clip)
    
    # Save augmented clip to file
    output_file = output_dir / 'augmented_clip.wav'
    save_clip(augmented_clip, output_file)
    print(f"Augmented clip saved to {output_file}")
    
    # Playback augmented clip
    display(Audio(str(output_file), autoplay=True))
except Exception as e:
    print(f"Error during augmentation or playback: {e}")

In [None]:
# Augment samples and save the training, validation, and testing sets.
# Validating and testing samples generated the same way can make the model
# benchmark better than it performs in real-world use. Use real samples or TTS
# samples generated with a different TTS engine to potentially get more accurate
# benchmarks.

import os
from mmap_ninja.ragged import RaggedMmap
from microwakeword.audio.spectrograms import SpectrogramGeneration

# Output directory for augmented features
output_dir = 'generated_augmented_features'
os.makedirs(output_dir, exist_ok=True)

# Configuration for each split
split_config = {
    "training": {"name": "train", "repetition": 2, "slide_frames": 10},
    "validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
    "testing": {"name": "test", "repetition": 1, "slide_frames": 1},
}

# Generate augmented features for each split
for split, config in split_config.items():
    out_dir = os.path.join(output_dir, split)
    os.makedirs(out_dir, exist_ok=True)
    print(f"Processing {split} set...")

    try:
        # Spectrogram generation configuration
        spectrograms = SpectrogramGeneration(
            clips=clips,
            augmenter=augmenter,
            slide_frames=config["slide_frames"],
            step_ms=10,  # Can parameterize this if needed
        )

        # Generate and save spectrogram features
        RaggedMmap.from_generator(
            out_dir=os.path.join(out_dir, 'wakeword_mmap'),
            sample_generator=spectrograms.spectrogram_generator(
                split=config["name"], repeat=config["repetition"]
            ),
            batch_size=100,  # Can parameterize this if needed
            verbose=True,
        )
        print(f"Completed processing {split} set. Output saved to {out_dir}")
    except Exception as e:
        print(f"Error processing {split} set: {e}")

In [None]:
# Downloads pre-generated spectrogram features (made for microWakeWord in
# particular) for various negative datasets. This can be slow!

import os
import requests
import zipfile
from pathlib import Path
from tqdm import tqdm

# Function to download a file with progress bar
def download_file(url, output_path):
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Fail early instead of saving an HTTP error page to disk
    total_size = int(response.headers.get('content-length', 0))
    with open(output_path, "wb") as f, tqdm(
        desc=f"Downloading {output_path.name}",
        total=total_size,
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for chunk in response.iter_content(chunk_size=1024):
            f.write(chunk)
            bar.update(len(chunk))
    print(f"Downloaded: {output_path}")

# Function to extract ZIP files
def extract_zip(zip_path, extract_to):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    print(f"Extracted: {zip_path} to {extract_to}")

# Directory for negative datasets
output_dir = Path('./negative_datasets')
output_dir.mkdir(exist_ok=True)

# Negative dataset URLs
link_root = "https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/"
filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']

# Download and extract files
for fname in filenames:
    link = link_root + fname
    zip_path = output_dir / fname

    # Download only if the file doesn't already exist
    if not zip_path.exists():
        try:
            download_file(link, zip_path)
        except Exception as e:
            print(f"Error downloading {fname}: {e}")
            continue

    # Extract the ZIP file
    try:
        extract_zip(zip_path, output_dir)
    except Exception as e:
        print(f"Error extracting {fname}: {e}")


In [None]:
# Save a yaml config that controls the training process
# These hyperparameters can make a huge difference in model quality.
# Experiment with sampling and penalty weights and increasing the number of
# training steps.

import yaml
import os

config = {}

config["window_step_ms"] = 10

config["train_dir"] = "trained_models/wakeword"

config["features"] = [
    {
        "features_dir": "generated_augmented_features",
        "sampling_weight": 5.0,  # Increased
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/speech",
        "sampling_weight": 8.0,  # Adjusted
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/dinner_party",
        "sampling_weight": 8.0,  # Adjusted
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/no_speech",
        "sampling_weight": 5.0,  # Balanced
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/dinner_party_eval",
        "sampling_weight": 0.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "split",
        "type": "mmap",
    },
]

config["training_steps"] = [20000]  # Increased
config["positive_class_weight"] = [1]
config["negative_class_weight"] = [15]  # Adjusted
config["learning_rates"] = [0.0005]  # Adjusted
config["batch_size"] = 128

config["time_mask_max_size"] = [30]  # Enabled SpecAugment
config["time_mask_count"] = [2]
config["freq_mask_max_size"] = [15]
config["freq_mask_count"] = [2]

config["eval_step_interval"] = 1000  # Adjusted
config["clip_duration_ms"] = 2000  # Increased

config["target_minimization"] = 0.9
config["minimization_metric"] = "false_positive_rate"  # Updated
config["maximization_metric"] = "average_viable_recall"

with open("training_parameters.yaml", "w") as file:
    documents = yaml.dump(config, file)
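
To get a feel for the sampling weights before changing them, it helps to look at each feature set's share of the total. The sketch below assumes the weights are relative, so only their ratios matter, and that one feature frame is produced per `window_step_ms`. Both are assumptions about the trainer rather than behaviour confirmed in this notebook.

```python
# Sampling weights copied from the configuration above.
sampling_weights = {
    "generated_augmented_features": 5.0,
    "negative_datasets/speech": 8.0,
    "negative_datasets/dinner_party": 8.0,
    "negative_datasets/no_speech": 5.0,
    "negative_datasets/dinner_party_eval": 0.0,
}

total_weight = sum(sampling_weights.values())
shares = {name: weight / total_weight for name, weight in sampling_weights.items()}

for name, share in shares.items():
    print(f"{name:40s} {share:6.1%}")

# Number of feature frames covered by one training clip.
clip_duration_ms = 2000
window_step_ms = 10
frames_per_clip = clip_duration_ms // window_step_ms
print(f"Frames per clip: {frames_per_clip}")
```

Under these assumptions the wake word features make up 5 of 26 weight units, a little under one fifth of the samples drawn, and the evaluation-only set with weight 0.0 is never drawn during training.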

In [None]:
# Trains a model. When finished, it will quantize and convert the model to a
# streaming version suitable for on-device detection.
# It will resume if stopped, but it will start over at the configured training
# steps in the yaml file.
# Change --train 0 to only convert and test the best-weighted model.
# On Google Colab, it doesn't print the mini-batch results, so it may appear
# stuck for several minutes! Additionally, it is very slow compared to training
# on a local GPU.

import os
import sys

# Ensure the library path is correctly set
os.environ['LD_LIBRARY_PATH'] = "/usr/lib/x86_64-linux-gnu:" + os.environ.get('LD_LIBRARY_PATH', '')

# Training command with optimized settings
!"{sys.executable}" -m microwakeword.model_train_eval \
--training_config='training_parameters.yaml' \
--train 1 \
--restore_checkpoint 1 \
--test_tf_nonstreaming 1 \
--test_tflite_nonstreaming 1 \
--test_tflite_nonstreaming_quantized 1 \
--test_tflite_streaming 1 \
--test_tflite_streaming_quantized 1 \
--use_weights "best_weights" \
mixednet \
--pointwise_filters "64,96,128,160" \
--repeat_in_block "2,2,3,3" \
--mixconv_kernel_sizes '[5], [7,11], [9,15], [17,23]' \
--residual_connection "1,1,1,0" \
--first_conv_filters 48 \
--first_conv_kernel_size 7 \
--stride 2 
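
The `mixednet` arguments above each list four values, which reads as one entry per block. When experimenting with the architecture it is easy to change one list and forget another, so the sketch below parses the same strings and checks that the lengths agree. The parsing helpers are hypothetical and the one-entry-per-block rule is an assumption; this is not how microWakeWord itself reads the flags.

```python
import ast

# Argument strings copied from the training command above.
pointwise_filters = "64,96,128,160"
repeat_in_block = "2,2,3,3"
mixconv_kernel_sizes = "[5], [7,11], [9,15], [17,23]"
residual_connection = "1,1,1,0"


def parse_int_list(text):
    """Parse a comma-separated string such as "2,2,3,3" into a list of ints."""
    return [int(item) for item in text.split(",")]


def parse_nested_list(text):
    """Parse a string such as "[5], [7,11]" into a list of lists."""
    return [list(group) for group in ast.literal_eval(f"[{text}]")]


blocks = {
    "pointwise_filters": parse_int_list(pointwise_filters),
    "repeat_in_block": parse_int_list(repeat_in_block),
    "mixconv_kernel_sizes": parse_nested_list(mixconv_kernel_sizes),
    "residual_connection": parse_int_list(residual_connection),
}

lengths = {name: len(values) for name, values in blocks.items()}
if len(set(lengths.values())) != 1:
    raise ValueError(f"Per-block arguments differ in length: {lengths}")

for index in range(lengths["pointwise_filters"]):
    summary = {name: values[index] for name, values in blocks.items()}
    print(f"Block {index + 1}: {summary}")
```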


In [None]:
import shutil
import json
from IPython.display import FileLink, HTML

# Copy the TFLite model file to the working directory
source_path = "trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
destination_path = "./stream_state_internal_quant.tflite"
shutil.copy(source_path, destination_path)

# Define the JSON metadata for the model
json_data = {
    "type": "micro",
    "wake_word": "hey_norman",  # Adjust based on your target wake word
    "author": "master phooey",
    "website": "https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker",
    "model": "stream_state_internal_quant.tflite",
    "trained_languages": ["en"],
    "version": 2,
    "micro": {
        "probability_cutoff": 0.97,  # Threshold for wake word detection
        "sliding_window_size": 5,  # Frames averaged for predictions
        "feature_step_size": 10,
        "tensor_arena_size": 30000,  # Memory allocation for TensorFlow Lite model
        "minimum_esphome_version": "2024.7.0"
    }
}

# Write the metadata to a JSON file
json_path = "./stream_state_internal_quant.json"
with open(json_path, "w") as json_file:
    json.dump(json_data, json_file, indent=2)

# Generate download links with styled HTML
tflite_link = FileLink(destination_path)
json_link = FileLink(json_path)

html_content = f"""
<h3 style="color:orange;">Your files are ready for download:</h3>
<ul>
    <li><b><a href="{tflite_link.url}" target="_blank" style="color:orange;">TFLite Model: stream_state_internal_quant.tflite</a></b></li>
    <li><b><a href="{json_link.url}" target="_blank" style="color:orange;">JSON Metadata: stream_state_internal_quant.json</a></b></li>
</ul>
<p style="font-size:12px; color:gray;">Click the links to download the files. Ensure the files are moved to the correct directory for deployment.</p>
"""

display(HTML(html_content))
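
Before copying the files to an ESPHome project, a quick sanity check of the manifest can catch hand-editing mistakes. The sketch below is a hedged example: `check_manifest` is a hypothetical helper, and the key lists simply mirror the fields written by the cell above. They are not an official ESPHome schema.

```python
import json
from pathlib import Path

# Keys mirrored from the manifest written above (an assumption, not a schema).
REQUIRED_TOP_LEVEL = (
    "type", "wake_word", "author", "website",
    "model", "trained_languages", "version", "micro",
)
REQUIRED_MICRO = (
    "probability_cutoff", "sliding_window_size", "feature_step_size",
    "tensor_arena_size", "minimum_esphome_version",
)


def check_manifest(manifest):
    """Return a list of human-readable problems; an empty list means none found."""
    problems = [f"missing key: {key}" for key in REQUIRED_TOP_LEVEL if key not in manifest]
    micro = manifest.get("micro", {})
    problems += [f"missing micro key: {key}" for key in REQUIRED_MICRO if key not in micro]
    cutoff = micro.get("probability_cutoff")
    if cutoff is not None and not 0.0 < cutoff < 1.0:
        problems.append(f"probability_cutoff {cutoff} is outside (0, 1)")
    return problems


manifest_path = Path("stream_state_internal_quant.json")
if manifest_path.exists():
    problems = check_manifest(json.loads(manifest_path.read_text()))
    print(problems or "Manifest matches the fields written by this notebook.")
else:
    print(f"{manifest_path} not found; run the cell above first.")
```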