<div align="center">

  <img src="https://raw.githubusercontent.com/MasterPhooey/MicroWakeWord-Trainer-Docker/refs/heads/main/mmw.png" alt="MicroWakeWord Trainer Logo" width="100" />

  <h1>MicroWakeWord Trainer Docker</h1>

</div>

This notebook steps you through training a basic microWakeWord model. It is intended as a **starting point** for advanced users. You should use Python 3.11.

**The generated model will most likely not be suitable for everyday use: it may be difficult to trigger, or it may falsely activate too often. Expect to experiment with many different settings to obtain a decent model!**

In the comment at the start of certain blocks, I note some specific settings to consider modifying.

This notebook runs on Google Colab, but training there is extremely slow compared to training on a local GPU. If you must use Colab, be sure to change the runtime type to a GPU. Even then, it is still slow!

At the end of this notebook, you will be able to download a tflite file. To use this in ESPHome, you need to write a model manifest JSON file. See the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for the details and the [model repo](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2) for examples.

In [None]:
# Initial Setup and Data Preparation
# This cell runs the main data preparation script.
# It ensures all necessary repositories and datasets are downloaded and processed into /data.
import os
import sys

print("Starting data preparation...")
# Assuming prepare_local_data.py is in /data, and this notebook runs from /data as well.
# The --data-dir /data ensures the script uses the correct base path inside Docker.
prepare_script_path = "/data/prepare_local_data.py"
if not os.path.exists(prepare_script_path):
    # Fallback if script is in parent dir relative to /data (e.g. /prepare_local_data.py)
    prepare_script_path = "../prepare_local_data.py"

if os.path.exists(prepare_script_path):
    print(f"Executing: !python {prepare_script_path} --data-dir /data")
    !python {prepare_script_path} --data-dir /data
    print("Data preparation script finished.")
else:
    # Stop here: every later cell depends on the prepared data.
    raise FileNotFoundError("prepare_local_data.py not found in /data or its parent directory. Cannot prepare data.")

# Ensure piper-sample-generator is in sys.path if it's used as a collection of scripts
piper_path = "/data/piper-sample-generator"
if piper_path not in sys.path:
    sys.path.append(piper_path)
    print(f"Added {piper_path} to sys.path")
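
The `!python` shell escape above does not stop the notebook when the script fails. A minimal sketch of an alternative, assuming you would rather have the cell raise on a non-zero exit code, uses the standard library's `subprocess.run` with `check=True`. The helper name `run_checked` is hypothetical and not part of this repository.

```python
import subprocess
import sys


def run_checked(args):
    """Run a command and raise CalledProcessError if it exits non-zero."""
    completed = subprocess.run(args, check=True)
    return completed.returncode


# Trivial stand-in command. In this notebook the arguments would be something like
# [sys.executable, "/data/prepare_local_data.py", "--data-dir", "/data"].
run_checked([sys.executable, "-c", "print('preparation ok')"])
```

With this approach a failed download surfaces as an exception in the cell instead of a log line that is easy to miss.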

In [None]:
# Generates 1 sample of the target word for manual verification.
target_word = 'khum_puter'  # Phonetic spellings may produce better samples

import os
import sys
from IPython.display import Audio, display

piper_script_path = "/data/piper-sample-generator/generate_samples.py"
output_sample_dir = "/data/generated_samples_test"
os.makedirs(output_sample_dir, exist_ok=True)

if not os.path.exists(piper_script_path):
    print(f"ERROR: Piper sample generator script not found at {piper_script_path}. Check data preparation.")
else:
    print(f"Generating test sample for '{target_word}'...")
    !"{sys.executable}" {piper_script_path} "{target_word}" \
    --max-samples 1 \
    --batch-size 1 \
    --output-dir {output_sample_dir}

    # Play the generated audio sample
    audio_path = os.path.join(output_sample_dir, "0.wav")
    if os.path.exists(audio_path):
        print(f"Playing test sample: {audio_path}")
        display(Audio(audio_path, autoplay=True))
    else:
        print(f"Audio file not found at {audio_path}. Sample generation might have failed.")

In [None]:
# Generates a larger amount of wake word samples.
# Start here when trying to improve your model.
# See https://github.com/rhasspy/piper-sample-generator for the full set of
# parameters. In particular, experiment with noise-scales and noise-scale-ws,
# generating negative samples similar to the wake word, and generating many more
# wake word samples, possibly with different phonetic pronunciations.
import os
import sys

piper_script_path = "/data/piper-sample-generator/generate_samples.py"
output_ww_dir = "/data/generated_samples_ww"
os.makedirs(output_ww_dir, exist_ok=True)

if not os.path.exists(piper_script_path):
    print(f"ERROR: Piper sample generator script not found at {piper_script_path}. Check data preparation.")
else:
    print(f"Generating wake word samples for '{target_word}'...")
    !"{sys.executable}" {piper_script_path} "{target_word}" \
    --max-samples 1000 \
    --batch-size 100 \
    --output-dir {output_ww_dir}
    print(f"Wake word samples generated in {output_ww_dir}")
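
Before moving on, it can help to confirm how many clips were actually written. The sketch below counts the `.wav` files in a directory and sums their durations with the standard library's `wave` module. It assumes the generator writes uncompressed PCM WAV files, and `summarize_wavs` is a hypothetical helper name.

```python
import glob
import os
import wave


def summarize_wavs(directory):
    """Return (file_count, total_seconds) for the .wav files in a directory."""
    count = 0
    total_seconds = 0.0
    for path in sorted(glob.glob(os.path.join(directory, "*.wav"))):
        with wave.open(path, "rb") as wav_file:
            total_seconds += wav_file.getnframes() / wav_file.getframerate()
        count += 1
    return count, total_seconds


# Only report if the output directory from the cell above exists.
if os.path.isdir("/data/generated_samples_ww"):
    n_clips, seconds = summarize_wavs("/data/generated_samples_ww")
    print(f"{n_clips} clips, {seconds:.1f} s total")
```

A count far below `--max-samples` suggests the generation step stopped early.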

## Augmentation Data Setup
The `prepare_local_data.py` script should have downloaded and processed all necessary augmentation data (MIT RIR, Audioset, FMA) into subdirectories within `/data`.
The following cells will set up the augmentation using these pre-prepared datasets.

In [None]:
# Sets up the augmentations.
# To improve your model, experiment with these settings and use more sources of
# background clips.

from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration
import os

# Paths to pre-prepared data in /data directory
generated_samples_path = "/data/generated_samples_ww" # From previous cell
mit_rirs_path = "/data/mit_rirs"
fma_16k_path = "/data/fma_16k"
audioset_16k_path = "/data/audioset_16k"

if not os.path.exists(generated_samples_path):
    print(f"ERROR: Wake word samples not found at {generated_samples_path}. Please run the previous cell.")
if not os.path.exists(mit_rirs_path) or not os.path.exists(fma_16k_path) or not os.path.exists(audioset_16k_path):
    print("ERROR: One or more augmentation data directories (MIT RIRs, FMA, Audioset) not found in /data. Please ensure prepare_local_data.py ran successfully.")

clips = Clips(input_directory=generated_samples_path,
              file_pattern='*.wav',
              max_clip_duration_s=None,
              remove_silence=False,
              random_split_seed=10,
              split_count=0.1,
             )

augmenter = Augmentation(augmentation_duration_s=3.2,
                         augmentation_probabilities = {
                                "SevenBandParametricEQ": 0.1,
                                "TanhDistortion": 0.1,
                                "PitchShift": 0.1,
                                "BandStopFilter": 0.1,
                                "AddColorNoise": 0.1,
                                "AddBackgroundNoise": 0.75,
                                "Gain": 1.0,
                                "RIR": 0.5,
                            },
                         impulse_paths = [mit_rirs_path],
                         background_paths = [fma_16k_path, audioset_16k_path],
                         background_min_snr_db = -5,
                         background_max_snr_db = 10,
                         min_jitter_s = 0.195,
                         max_jitter_s = 0.205,
                        )
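
The `background_min_snr_db` and `background_max_snr_db` settings bound the signal-to-noise ratio used when background audio is mixed in. As a rough illustration of what those numbers mean, the sketch below computes the gain that would bring a noise signal to a target SNR, using the textbook definition of SNR as a power ratio in decibels. It is not how `microwakeword` implements the mix internally, and the function names are hypothetical.

```python
import math


def mean_power(samples):
    """Average squared amplitude of a list of samples."""
    return sum(s * s for s in samples) / len(samples)


def noise_gain_for_snr(signal, noise, snr_db):
    """Gain to apply to `noise` so that signal power / noise power equals snr_db."""
    target_noise_power = mean_power(signal) / (10 ** (snr_db / 10))
    return math.sqrt(target_noise_power / mean_power(noise))


# One second of synthetic audio at 16 kHz; the second tone stands in for background noise.
signal = [math.sin(2 * math.pi * 440 * n / 16000) for n in range(16000)]
noise = [math.sin(2 * math.pi * 97 * n / 16000) for n in range(16000)]

for snr_db in (-5, 10):
    gain = noise_gain_for_snr(signal, noise, snr_db)
    print(f"SNR {snr_db:+d} dB -> background scaled by {gain:.2f}")
```

At a negative SNR the background carries more power than the wake word, so lowering `background_min_snr_db` makes the hardest training examples harder.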

In [None]:
# Augment a random clip and play it back to verify it works well
from IPython.display import Audio, display
from microwakeword.audio.audio_utils import save_clip
import os

augmented_clip_path = "/data/augmented_clip_test.wav"

try:
    random_clip = clips.get_random_clip()
    augmented_clip = augmenter.augment_clip(random_clip)
    save_clip(augmented_clip, augmented_clip_path)
    print(f"Playing augmented test clip: {augmented_clip_path}")
    display(Audio(augmented_clip_path, autoplay=True))
except Exception as e:
    print(f"Error during test augmentation: {e}. Check if previous cells ran successfully and data paths are correct.")

In [None]:
# Augment samples and save the training, validation, and testing sets.
# Validating and testing samples generated the same way can make the model
# benchmark better than it performs in real-world use. Use real samples or TTS
# samples generated with a different TTS engine to potentially get more accurate
# benchmarks.
import os
from mmap_ninja.ragged import RaggedMmap

output_dir_augmented_features = '/data/generated_augmented_features'
os.makedirs(output_dir_augmented_features, exist_ok=True)

splits = ["training", "validation", "testing"]
for split in splits:
  out_dir_split = os.path.join(output_dir_augmented_features, split)
  os.makedirs(out_dir_split, exist_ok=True)

  split_name = "train"
  repetition = 2

  spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=10,    # Uses the same spectrogram repeatedly, just shifted over by one frame. This simulates the streaming inferences while training/validating in nonstreaming mode.
                                     step_ms=10,
                                     )
  if split == "validation":
    split_name = "validation"
    repetition = 1
  elif split == "testing":
    split_name = "test"
    repetition = 1
    spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=1,    # The testing set uses the streaming version of the model, so no artificial repetition is necessary
                                     step_ms=10,
                                     )

  print(f"Generating augmented features for {split_name} set...")
  try:
    RaggedMmap.from_generator(
        out_dir=os.path.join(out_dir_split, 'wakeword_mmap'),
        sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),
        batch_size=100,
        verbose=True,
    )
    print(f"Finished generating features for {split_name} set.")
  except Exception as e:
    print(f"Error generating features for {split_name} set: {e}")

## Negative Datasets Setup
The `prepare_local_data.py` script should have downloaded and extracted pre-generated negative spectrogram features into `/data/negative_datasets`.
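
According to the comments in the training configuration cell, each feature directory holds split folders (`training`, `validation`, `testing`, and the `_ambient` variants) that contain folders whose names end in `mmap`. A minimal sketch for checking that layout follows. The helper `summarize_feature_dir` is hypothetical, and the dataset names are the ones referenced by the training configuration in this notebook.

```python
import os

SPLITS = ["training", "validation", "validation_ambient", "testing", "testing_ambient"]


def summarize_feature_dir(features_dir):
    """Map each split folder that exists to the mmap folders found inside it."""
    summary = {}
    for split in SPLITS:
        split_dir = os.path.join(features_dir, split)
        if os.path.isdir(split_dir):
            summary[split] = sorted(
                name
                for name in os.listdir(split_dir)
                if name.endswith("mmap") and os.path.isdir(os.path.join(split_dir, name))
            )
    return summary


for name in ["speech", "dinner_party", "no_speech", "dinner_party_eval"]:
    path = os.path.join("/data/negative_datasets", name)
    print(name, summarize_feature_dir(path))
```

An empty dictionary for a dataset means none of the expected split folders were found, which usually points back to the data preparation step.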

In [None]:
# Save a yaml config that controls the training process
# These hyperparameters can make a huge difference in model quality.
# Experiment with sampling and penalty weights and increasing the number of
# training steps.
import yaml
import os

config = {}

config["window_step_ms"] = 10

config["train_dir"] = (
    "/data/trained_models/wakeword" # Path inside Docker
)
os.makedirs(config["train_dir"], exist_ok=True)

# Each feature_dir should have at least one of the following folders with this structure:
#  training/
#    ragged_mmap_folders_ending_in_mmap
#  testing/
#    ragged_mmap_folders_ending_in_mmap
#  testing_ambient/
#    ragged_mmap_folders_ending_in_mmap
#  validation/
#    ragged_mmap_folders_ending_in_mmap
#  validation_ambient/
#    ragged_mmap_folders_ending_in_mmap
#
#  sampling_weight: Weight for choosing a spectrogram from this set in the batch
#  penalty_weight: Penalizing weight for incorrect predictions from this set
#  truth: Boolean whether this set has positive samples or negative samples
#  truncation_strategy: How spectrograms in the set are truncated if they are longer than necessary for training
#       - random: choose a random portion of the entire spectrogram - useful for long negative samples
#       - truncate_start: remove the start of the spectrogram
#       - truncate_end: remove the end of the spectrogram
#       - split: Split the longer spectrogram into separate spectrograms offset by 100 ms. Only for ambient sets

config["features"] = [
    {
        "features_dir": "/data/generated_augmented_features", # Path inside Docker
        "sampling_weight": 2.0,
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        "features_dir": "/data/negative_datasets/speech", # Path inside Docker
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "/data/negative_datasets/dinner_party", # Path inside Docker
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "/data/negative_datasets/no_speech", # Path inside Docker
        "sampling_weight": 5.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    { # Only used for validation and testing
        "features_dir": "/data/negative_datasets/dinner_party_eval", # Path inside Docker
        "sampling_weight": 0.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "split",
        "type": "mmap",
    },
]

# Number of training steps in each iteration - various other settings are configured as lists that correspond to different steps
config["training_steps"] = [10000]

# Penalizing weight for incorrect class predictions - lists that correspond to training steps
config["positive_class_weight"] = [1]
config["negative_class_weight"] = [20]

config["learning_rates"] = [
    0.001,
]  # Learning rates for Adam optimizer - list that corresponds to training steps
config["batch_size"] = 128

config["time_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["time_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps
config["freq_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["freq_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps

config["eval_step_interval"] = (
    500  # Test the validation sets after every this many steps
)
config["clip_duration_ms"] = (
    1500  # Maximum length of wake word that the streaming model will accept
)

# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization
# Once the target has been met, it chooses the maximum of the maximization metric. Set 'minimization_metric' to None to only maximize
# Available metrics:
#   - "loss" - cross entropy error on validation set
#   - "accuracy" - accuracy of validation set
#   - "recall" - recall of validation set
#   - "precision" - precision of validation set
#   - "false_positive_rate" - false positive rate of validation set
#   - "false_negative_rate" - false negative rate of validation set
#   - "ambient_false_positives" - count of false positives from the split validation_ambient set
#   - "ambient_false_positives_per_hour" - estimated number of false positives per hour on the split validation_ambient set
config["target_minimization"] = 0.9
config["minimization_metric"] = None  # Set to None to disable

config["maximization_metric"] = "average_viable_recall"

training_params_path = "/data/training_parameters.yaml"
with open(training_params_path, "w") as file:
    yaml.dump(config, file)
print(f"Training parameters saved to {training_params_path}")
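
To get a feel for the `sampling_weight` values, one can normalize them into shares. The sketch below assumes each weight acts as a relative probability of drawing from that set, which is an interpretation of the comment above and not a statement about the exact sampling logic in `microwakeword`. The helper `sampling_shares` is hypothetical.

```python
def sampling_shares(features):
    """Normalize sampling_weight values so that they sum to 1."""
    total = sum(f["sampling_weight"] for f in features)
    return {f["features_dir"]: f["sampling_weight"] / total for f in features}


# Same directories and weights as the config above; other keys omitted for brevity.
features = [
    {"features_dir": "/data/generated_augmented_features", "sampling_weight": 2.0},
    {"features_dir": "/data/negative_datasets/speech", "sampling_weight": 10.0},
    {"features_dir": "/data/negative_datasets/dinner_party", "sampling_weight": 10.0},
    {"features_dir": "/data/negative_datasets/no_speech", "sampling_weight": 5.0},
    {"features_dir": "/data/negative_datasets/dinner_party_eval", "sampling_weight": 0.0},
]

shares = sampling_shares(features)
for features_dir, share in shares.items():
    print(f"{share:6.1%}  {features_dir}")
```

Under that assumption, positive samples make up 2 of every 27 draws, so raising the first weight is one way to show the model more wake word examples per batch.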

In [None]:
# Trains a model. When finished, it will quantize and convert the model to a
# streaming version suitable for on-device detection.
# It will resume if stopped, but it will start over at the configured training
# steps in the yaml file.
# Change --train to 0 to only convert and test the best-weighted model.
# On Google Colab, it doesn't print the mini-batch results, so it may appear
# stuck for several minutes! Additionally, it is very slow compared to training
# on a local GPU.
import os
import sys

# LD_LIBRARY_PATH might be needed if base TF image doesn't set it up for all custom ops, but usually it's fine.
# os.environ['LD_LIBRARY_PATH'] = "/usr/lib/x86_64-linux-gnu:" + os.environ.get('LD_LIBRARY_PATH', '')

training_params_path = "/data/training_parameters.yaml"
print(f"Starting model training using config: {training_params_path}")

!"{sys.executable}" -m microwakeword.model_train_eval \
    --training_config='{training_params_path}' \
    --train 1 \
    --restore_checkpoint 1 \
    --test_tf_nonstreaming 0 \
    --test_tflite_nonstreaming 0 \
    --test_tflite_nonstreaming_quantized 0 \
    --test_tflite_streaming 0 \
    --test_tflite_streaming_quantized 1 \
    --use_weights "best_weights" \
    mixednet \
    --pointwise_filters "64,64,64,64" \
    --repeat_in_block  "1, 1, 1, 1" \
    --mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \
    --residual_connection "0,0,0,0" \
    --first_conv_filters 32 \
    --first_conv_kernel_size 5 \
    --stride 3

print("Model training/evaluation finished.")
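
The `mixednet` arguments above each list four comma-separated entries, one per block. When experimenting with the architecture, it is easy to change one list and forget the others. The sketch below counts the entries in each argument so a mismatch is caught before training starts. It assumes all four arguments should describe the same number of blocks, and it is not how `microwakeword` parses them. The helper `count_blocks` is hypothetical.

```python
import ast


def count_blocks(pointwise_filters, repeat_in_block, mixconv_kernel_sizes, residual_connection):
    """Return the number of entries described by each per-block argument."""
    return {
        "pointwise_filters": len(pointwise_filters.split(",")),
        "repeat_in_block": len(repeat_in_block.split(",")),
        # "[5], [7,11]" is wrapped in brackets and read as the literal [[5], [7, 11]].
        "mixconv_kernel_sizes": len(ast.literal_eval(f"[{mixconv_kernel_sizes}]")),
        "residual_connection": len(residual_connection.split(",")),
    }


counts = count_blocks("64,64,64,64", "1, 1, 1, 1", "[5], [7,11], [9,15], [23]", "0,0,0,0")
print(counts)
if len(set(counts.values())) != 1:
    print("WARNING: per-block arguments disagree on the number of blocks")
```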

In [None]:
import shutil
import json
import os
from IPython.display import FileLink, display

# target_word should be defined in an earlier cell (e.g., where test sample is generated)
# If not, define it here or ensure it's passed correctly.
# For example: target_word = 'khum_puter' 

# Define the source path and desired download location for the TFLite file
source_tflite_path = "/data/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
destination_tflite_path = "/data/stream_state_internal_quant.tflite" # Will be accessible from host via mapped /data volume

if os.path.exists(source_tflite_path):
    shutil.copy(source_tflite_path, destination_tflite_path)
    print(f"Copied TFLite model to {destination_tflite_path}")
else:
    print(f"ERROR: Trained TFLite model not found at {source_tflite_path}")

# Define the JSON file content
json_data = {
    "type": "micro",
    "wake_word": target_word,  # Using the target_word variable
    "author": "master phooey", # Or your name/handle
    "website": "https://github.com/kiwina/MicroWakeWord-Trainer-Docker", # Updated to kiwina fork
    "model": "stream_state_internal_quant.tflite", # Relative path for use with ESPHome
    "trained_languages": ["en"],
    "version": 2, # Increment if you retrain and improve
    "micro": {
        "probability_cutoff": 0.97, # Adjust based on testing
        "sliding_window_size": 5,
        "feature_step_size": 10,
        "tensor_arena_size": 30000, # Adjust based on model needs
        "minimum_esphome_version": "2024.7.0"
    }
}

# Define the JSON file path
destination_json_path = "/data/stream_state_internal_quant.json"

# Write the JSON file
with open(destination_json_path, "w") as json_file:
    json.dump(json_data, json_file, indent=2)
print(f"Created JSON metadata at {destination_json_path}")

# Generate download links for both files (if running in a Jupyter environment that supports this)
print("\nAccess your files in the 'microwakeword-trainer-data' directory on your host machine.")
if os.path.exists(destination_tflite_path):
    print("TFLite Model Link (for Jupyter environments):")
    display(FileLink(destination_tflite_path))
if os.path.exists(destination_json_path):
    print("\nJSON Metadata Link (for Jupyter environments):")
    display(FileLink(destination_json_path))
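
If you edit the manifest by hand later, a quick check that no field was dropped can save a failed ESPHome build. The sketch below only verifies that the keys written by the cell above are present. It is not an official ESPHome schema, and `missing_manifest_keys` is a hypothetical helper.

```python
REQUIRED_TOP_LEVEL = [
    "type", "wake_word", "author", "website",
    "model", "trained_languages", "version", "micro",
]
REQUIRED_MICRO = [
    "probability_cutoff", "sliding_window_size", "feature_step_size",
    "tensor_arena_size", "minimum_esphome_version",
]


def missing_manifest_keys(manifest):
    """List the keys written by the cell above that are absent from a manifest dict."""
    missing = [key for key in REQUIRED_TOP_LEVEL if key not in manifest]
    micro = manifest.get("micro", {})
    missing += [f"micro.{key}" for key in REQUIRED_MICRO if key not in micro]
    return missing


# A deliberately incomplete manifest to show what the check reports.
partial = {
    "type": "micro",
    "wake_word": "khum_puter",
    "model": "stream_state_internal_quant.tflite",
    "micro": {"probability_cutoff": 0.97},
}
print(missing_manifest_keys(partial))
```

To check the generated file, load `/data/stream_state_internal_quant.json` with `json.load` and pass the resulting dictionary to the same function.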