# Training a microWakeWord Model

This notebook steps you through training a basic microWakeWord model. It is intended as a **starting point** for advanced users. You should use Python 3.10.

**The model generated will most likely not be usable for everyday use; it may be difficult to trigger or falsely activates too frequently. You will most likely have to experiment with many different settings to obtain a decent model!**

In the comment at the start of certain blocks, I note some specific settings to consider modifying.

This runs on Google Colab, but is extremely slow compared to training on a local GPU. If you must use Colab, be sure to Change the runtime type to a GPU. Even then, it still slow!

At the end of this notebook, you will be able to download a tflite file. To use this in ESPHome, you need to write a model manifest JSON file. See the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for the details and the [model repo](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2) for examples.

In [1]:
# Installs microWakeWord. Be sure to restart the session after this is finished.
import platform

if platform.system() == "Darwin":
    # `pymicro-features` is installed from a fork to support building on macOS
    !pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version'

# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter
!pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f'

!git clone https://github.com/kahrendt/microWakeWord
!pip install -e ./microWakeWord

Collecting git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f
  Cloning https://github.com/whatsnowplaying/audio-metadata (to revision d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f) to /tmp/pip-req-build-c8nckgqg
  Running command git clone --filter=blob:none --quiet https://github.com/whatsnowplaying/audio-metadata /tmp/pip-req-build-c8nckgqg
  Running command git rev-parse -q --verify 'sha^d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f'
  Running command git fetch -q https://github.com/whatsnowplaying/audio-metadata d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f
  Running command git checkout -q d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f
  Resolved https://github.com/whatsnowplaying/audio-metadata to commit d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
fatal: destination path 'microWakeWord' alre

In [2]:
import os
from IPython.display import Audio, display

# 1. Configuration
target_word = "രാഘവാ"
model_name = "ml_IN-meera-medium"
model_path_hf = "ml/ml_IN/meera/medium"

# Download Logic (Standard)
if not os.path.exists("./piper_standalone"):
    !wget -q https://github.com/rhasspy/piper/releases/download/v1.2.0/piper_amd64.tar.gz
    !tar -xf piper_amd64.tar.gz
    !mv piper piper_standalone

!mkdir -p models
!wget -q -L -O models/{model_name}.onnx "https://huggingface.co/rhasspy/piper-voices/resolve/main/{model_path_hf}/{model_name}.onnx"
!wget -q -L -O models/{model_name}.onnx.json "https://huggingface.co/rhasspy/piper-voices/resolve/main/{model_path_hf}/{model_name}.onnx.json"

target_word = "രാഘവാ"
model_name = "ml_IN-arjun-medium"
model_path_hf = "ml/ml_IN/arjun/medium"
!wget -q -L -O models/{model_name}.onnx "https://huggingface.co/rhasspy/piper-voices/resolve/main/{model_path_hf}/{model_name}.onnx"
!wget -q -L -O models/{model_name}.onnx.json "https://huggingface.co/rhasspy/piper-voices/resolve/main/{model_path_hf}/{model_name}.onnx.json"



# 2. Generate variations with SLOWER speeds
# 1.0 is normal. We are testing 1.1 to 1.5 to slow it down.
print(f"Generating slower Malayalam variations for: {target_word}")
os.makedirs("malayalam_slow_test", exist_ok=True)

# Testing different 'slowness' levels
slow_speeds = [1.0, 1.15, 1.25, 1.35, 1.45]

for i, speed in enumerate(slow_speeds):
    output_file = f"malayalam_slow_test/slow_{speed}.wav"
    # Using --length_scale to slow down the voice
    !echo '{target_word}' | ./piper_standalone/piper \
        --model models/{model_name}.onnx \
        --length_scale {speed} \
        --output_file {output_file}

# 3. Playback
print("\n--- Malayalam Audio Results (Slowed Down) ---")
for speed in slow_speeds:
    file_path = f"malayalam_slow_test/slow_{speed}.wav"
    if os.path.exists(file_path):
        print(f"Length Scale: {speed} ({'Normal' if speed==1.0 else 'Slower'})")
        display(Audio(file_path))

Generating slower Malayalam variations for: രാഘവാ
[2026-01-14 07:15:53.657] [piper] [[32minfo[m] Loaded voice in 0.28397395 second(s)
[2026-01-14 07:15:53.657] [piper] [[32minfo[m] Initialized piper
malayalam_slow_test/slow_1.0.wav
[2026-01-14 07:15:53.840] [piper] [[32minfo[m] Real-time factor: 0.13903193625532673 (infer=0.177557339 sec, audio=1.2770975056689342 sec)
[2026-01-14 07:15:53.840] [piper] [[32minfo[m] Terminated piper
[2026-01-14 07:15:54.277] [piper] [[32minfo[m] Loaded voice in 0.299771774 second(s)
[2026-01-14 07:15:54.278] [piper] [[32minfo[m] Initialized piper
malayalam_slow_test/slow_1.15.wav
[2026-01-14 07:15:54.479] [piper] [[32minfo[m] Real-time factor: 0.14350187300052963 (infer=0.196594312 sec, audio=1.3699773242630386 sec)
[2026-01-14 07:15:54.479] [piper] [[32minfo[m] Terminated piper
[2026-01-14 07:15:54.863] [piper] [[32minfo[m] Loaded voice in 0.278372288 second(s)
[2026-01-14 07:15:54.864] [piper] [[32minfo[m] Initialized piper
malayala

Length Scale: 1.15 (Slower)


Length Scale: 1.25 (Slower)


Length Scale: 1.35 (Slower)


Length Scale: 1.45 (Slower)


In [4]:
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

# Configuration
target_word = "രാഘവാ"
models = {
    "meera": "models/ml_IN-meera-medium.onnx",
    "arjun": "models/ml_IN-arjun-medium.onnx"
}
samples_per_model = 500
max_workers = 8  # Number of parallel processes - adjust based on your CPU

# Variations
length_scales = [1.2, 1.3, 1.4, 1.5, 1.6]
noise_scales = [0.5, 0.667, 0.8, 1.0]

# Piper executable
piper_exe = "./piper_standalone/piper.exe" if os.name == 'nt' else "./piper_standalone/piper"

# Progress tracking
progress_lock = Lock()
progress_counters = {}

def generate_sample(model_name, model_path, sample_idx, output_dir):
    """Generate a single sample"""
    length = length_scales[sample_idx % len(length_scales)]
    noise = noise_scales[sample_idx % len(noise_scales)]
    output_file = f"{output_dir}/{sample_idx}.wav"

    try:
        # Use subprocess for better control
        cmd = f'echo {target_word} | "{piper_exe}" --model "{model_path}" --length_scale {length} --noise_scale {noise} --output_file "{output_file}"'

        result = subprocess.run(
            cmd,
            shell=True,
            capture_output=True,
            text=True,
            timeout=30
        )

        if result.returncode == 0:
            # Update progress
            with progress_lock:
                progress_counters[model_name] += 1
                current = progress_counters[model_name]
                if current % 50 == 0:
                    print(f"  [{model_name}] Progress: {current}/{samples_per_model}")
            return True
        else:
            print(f"⚠️ [{model_name}] Sample {sample_idx} failed: {result.stderr}")
            return False

    except Exception as e:
        print(f"⚠️ [{model_name}] Exception at sample {sample_idx}: {e}")
        return False

# Main generation loop
print(f"🚀 Starting parallel generation with {max_workers} workers...")

for model_name, model_path in models.items():
    if not os.path.exists(model_path):
        print(f"⚠️ Skipping {model_name}: Model not found")
        continue

    print(f"\n🎙️ Generating {samples_per_model} samples for {model_name}...")

    output_dir = f"generated_samples/{model_name}"
    os.makedirs(output_dir, exist_ok=True)

    # Initialize progress counter
    progress_counters[model_name] = 0

    # Generate samples in parallel
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []

        for i in range(samples_per_model):
            future = executor.submit(
                generate_sample,
                model_name,
                model_path,
                i,
                output_dir
            )
            futures.append(future)

        # Wait for all to complete
        success_count = 0
        for future in as_completed(futures):
            if future.result():
                success_count += 1

    print(f"✅ [{model_name}] Completed: {success_count}/{samples_per_model} samples")

print(f"\n✅ All samples generated in ./generated_samples/")
print(f"📊 Total samples: {sum(progress_counters.values())}")

🚀 Starting parallel generation with 8 workers...

🎙️ Generating 500 samples for meera...
  [meera] Progress: 50/500
  [meera] Progress: 100/500
  [meera] Progress: 150/500
  [meera] Progress: 200/500
  [meera] Progress: 250/500
  [meera] Progress: 300/500
  [meera] Progress: 350/500
  [meera] Progress: 400/500
  [meera] Progress: 450/500
  [meera] Progress: 500/500
✅ [meera] Completed: 500/500 samples

🎙️ Generating 500 samples for arjun...
  [arjun] Progress: 50/500
  [arjun] Progress: 100/500
  [arjun] Progress: 150/500
  [arjun] Progress: 200/500
  [arjun] Progress: 250/500
  [arjun] Progress: 300/500
  [arjun] Progress: 350/500
  [arjun] Progress: 400/500
  [arjun] Progress: 450/500
  [arjun] Progress: 500/500
✅ [arjun] Completed: 500/500 samples

✅ All samples generated in ./generated_samples/
📊 Total samples: 1000


In [7]:
# Downloads audio data for augmentation. This can be slow!
# **Google Colab optimized with working download links**

# Install required audio libraries
import sys
!{sys.executable} -m pip install -q soundfile librosa audioread ffmpeg-python

import datasets
import scipy
import os
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
import urllib.request
import tarfile
import zipfile
import soundfile as sf

# Progress callback for downloads
def download_with_progress(url, output_path):
    """Download with progress bar"""
    print(f"📥 Starting download from {url.split('/')[-1]}...")

    class DownloadProgressBar(tqdm):
        def update_to(self, b=1, bsize=1, tsize=None):
            if tsize is not None:
                self.total = tsize
            self.update(b * bsize - self.n)

    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t:
        urllib.request.urlretrieve(url, output_path, reporthook=t.update_to)
    print("✅ Download complete!")

## Download MIT RIR data
output_dir = "./mit_rirs"
if not os.path.exists(output_dir):
    os.makedirs(output_dir, exist_ok=True)
    print("📥 Loading MIT RIR dataset...")

    try:
        # Use non-streaming mode for reliability
        rir_dataset = datasets.load_dataset(
            "davidscripka/MIT_environmental_impulse_responses",
            split="train[:500]",  # Limit to 500 samples
            trust_remote_code=True
        )

        # Save clips to 16-bit PCM wav files
        count = 0
        for row in tqdm(rir_dataset, desc="Processing RIR files"):
            try:
                name = row['audio']['path'].split('/')[-1]
                audio_array = row['audio']['array']
                sample_rate = row['audio']['sampling_rate']

                # Resample to 16kHz if needed
                if sample_rate != 16000:
                    import librosa
                    audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)

                scipy.io.wavfile.write(
                    os.path.join(output_dir, name),
                    16000,
                    (audio_array * 32767).astype(np.int16)
                )
                count += 1
            except Exception as e:
                continue

        print(f"✅ MIT RIR dataset complete: {count} files")
    except Exception as e:
        print(f"⚠️ MIT RIR dataset download failed: {e}")
        print("Continuing without MIT RIR data...")
else:
    print("✅ MIT RIRs already exist, skipping")

## Download FMA using direct file processing instead of datasets library
if not os.path.exists("fma_16k"):
    os.makedirs("fma", exist_ok=True)
    fname = "fma_xs.zip"

    # Alternative: Use smaller FMA dataset from different source
    link = "https://os.unil.cloud.switch.ch/fma/fma_small.zip"
    out_path = os.path.join("fma", "fma_small.zip")

    print(f"📥 Downloading FMA dataset (~7.2GB - this will take time on Colab)...")
    print("⚠️ This is a large download. Consider using pre-generated features instead.")

    try:
        download_with_progress(link, out_path)

        # Extract
        print("📦 Extracting FMA (this may take 5-10 minutes)...")
        with zipfile.ZipFile(out_path, 'r') as zip_ref:
            zip_ref.extractall("fma")

        # Convert to 16kHz using librosa directly
        output_dir = "./fma_16k"
        os.makedirs(output_dir, exist_ok=True)

        mp3_files = list(Path("fma/fma_small").glob("**/*.mp3"))[:1000]  # Limit to 1000 files
        print(f"Converting {len(mp3_files)} FMA files to 16kHz using librosa...")

        import librosa
        for idx, mp3_file in enumerate(tqdm(mp3_files, desc="Converting FMA to 16kHz")):
            try:
                # Load with librosa (handles mp3 without torchcodec)
                audio, sr = librosa.load(str(mp3_file), sr=16000, mono=True)

                # Save as wav
                name = mp3_file.stem + ".wav"
                scipy.io.wavfile.write(
                    os.path.join(output_dir, name),
                    16000,
                    (audio * 32767).astype(np.int16)
                )
            except Exception as e:
                if idx % 100 == 0:
                    print(f"⚠️ Error at file {idx}: {e}")
                continue

        print(f"✅ FMA conversion complete!")

    except Exception as e:
        print(f"⚠️ FMA download/processing failed: {e}")
        print("Continuing without FMA data...")
else:
    print("✅ FMA already exists, skipping")

## Alternative: Download pre-generated background noise from other sources
# Using ESC-50 environmental sounds as alternative to audioset
output_dir = "./esc50"
if not os.path.exists(output_dir):
    print("📥 Downloading ESC-50 environmental sounds (alternative to AudioSet)...")

    try:
        link = "https://github.com/karolpiczak/ESC-50/archive/master.zip"
        out_path = "esc50.zip"

        download_with_progress(link, out_path)

        print("📦 Extracting ESC-50...")
        with zipfile.ZipFile(out_path, 'r') as zip_ref:
            zip_ref.extractall(".")

        os.rename("ESC-50-master", "esc50")

        # Convert to 16kHz
        output_dir = "./esc50_16k"
        os.makedirs(output_dir, exist_ok=True)

        wav_files = list(Path("esc50/audio").glob("*.wav"))
        print(f"Converting {len(wav_files)} ESC-50 files to 16kHz...")

        import librosa
        for idx, wav_file in enumerate(tqdm(wav_files, desc="Converting ESC-50")):
            try:
                audio, sr = librosa.load(str(wav_file), sr=16000, mono=True)
                name = wav_file.name
                scipy.io.wavfile.write(
                    os.path.join(output_dir, name),
                    16000,
                    (audio * 32767).astype(np.int16)
                )
            except Exception as e:
                continue

        print(f"✅ ESC-50 complete!")

    except Exception as e:
        print(f"⚠️ ESC-50 download failed: {e}")
else:
    print("✅ ESC-50 already exists, skipping")

print("\n✅ Augmentation audio data download complete!")
print("📊 Summary:")
mit_count = len(list(Path('./mit_rirs').glob('*.wav'))) if os.path.exists('./mit_rirs') else 0
esc50_count = len(list(Path('./esc50_16k').glob('*.wav'))) if os.path.exists('./esc50_16k') else 0
fma_count = len(list(Path('./fma_16k').glob('*.wav'))) if os.path.exists('./fma_16k') else 0

print(f"  - MIT RIRs: {mit_count} files")
print(f"  - ESC-50 (16kHz): {esc50_count} files")
print(f"  - FMA (16kHz): {fma_count} files")
print(f"  - Total: {mit_count + esc50_count + fma_count} files")

if mit_count + esc50_count + fma_count == 0:
    print("\n⚠️ WARNING: No background audio downloaded!")
    print("Consider using pre-downloaded datasets or smaller alternatives.")

✅ MIT RIRs already exist, skipping
✅ FMA already exists, skipping
📥 Downloading ESC-50 environmental sounds (alternative to AudioSet)...
📥 Starting download from master.zip...


master.zip: 0.00B [00:00, ?B/s]

✅ Download complete!
📦 Extracting ESC-50...
Converting 2000 ESC-50 files to 16kHz...


Converting ESC-50:   0%|          | 0/2000 [00:00<?, ?it/s]

✅ ESC-50 complete!

✅ Augmentation audio data download complete!
📊 Summary:
  - MIT RIRs: 0 files
  - ESC-50 (16kHz): 2000 files
  - FMA (16kHz): 0 files
  - Total: 2000 files


In [6]:
# Sets up the augmentations.
# To improve your model, experiment with these settings and use more sources of
# background clips.
import os
from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

clips = Clips(input_directory='generated_samples',
              file_pattern='*.wav',
              max_clip_duration_s=None,
              remove_silence=False,
              random_split_seed=10,
              split_count=0.1,
              )

# Updated augmentation settings for Colab with downloaded datasets
augmenter = Augmentation(augmentation_duration_s=3.2,
                         augmentation_probabilities = {
                                "SevenBandParametricEQ": 0.1,
                                "TanhDistortion": 0.1,
                                "PitchShift": 0.1,
                                "BandStopFilter": 0.1,
                                "AddColorNoise": 0.1,
                                "AddBackgroundNoise": 0.75,
                                "Gain": 1.0,
                                "RIR": 0.5,
                            },
                         impulse_paths = [],  # Empty - no MIT RIR files downloaded
                         # Updated to use ESC-50 instead of audioset (and FMA if available)
                         background_paths = ['esc50_16k', 'fma_16k'] if os.path.exists('fma_16k') else ['esc50_16k'],
                         background_min_snr_db = -5,
                         background_max_snr_db = 10,
                         min_jitter_s = 0.195,
                         max_jitter_s = 0.205,
                         )

print("✅ Augmentation setup complete!")

✅ Augmentation setup complete!


In [9]:
import os
from pathlib import Path

# Check if directory exists
if os.path.exists('generated_samples'):
    print("✅ generated_samples directory exists")

    # Check for WAV files recursively
    wav_files = list(Path('generated_samples').rglob('*.wav'))
    print(f"📊 Found {len(wav_files)} WAV files")

    if wav_files:
        print("\n📂 Sample file locations:")
        for f in wav_files[:5]:  # Show first 5
            print(f"  - {f}")
    else:
        print("⚠️ No WAV files found in generated_samples")
else:
    print("❌ generated_samples directory does NOT exist")
    print("\n💡 You need to run the sample generation cell first!")

✅ generated_samples directory exists
📊 Found 1000 WAV files

📂 Sample file locations:
  - generated_samples/meera/96.wav
  - generated_samples/meera/139.wav
  - generated_samples/meera/92.wav
  - generated_samples/meera/167.wav
  - generated_samples/meera/392.wav


In [10]:
# Augment samples and save the training, validation, and testing sets.
# Validating and testing samples generated the same way can make the model
# benchmark better than it performs in real-word use. Use real samples or TTS
# samples generated with a different TTS engine to potentially get more accurate
# benchmarks.

import os
from mmap_ninja.ragged import RaggedMmap

output_dir = 'generated_augmented_features'

if not os.path.exists(output_dir):
    os.mkdir(output_dir)

splits = ["training", "validation", "testing"]
for split in splits:
  out_dir = os.path.join(output_dir, split)
  if not os.path.exists(out_dir):
      os.mkdir(out_dir)


  split_name = "train"
  repetition = 2

  spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=10,    # Uses the same spectrogram repeatedly, just shifted over by one frame. This simulates the streaming inferences while training/validating in nonstreaming mode.
                                     step_ms=10,
                                     )
  if split == "validation":
    split_name = "validation"
    repetition = 1
  elif split == "testing":
    split_name = "test"
    repetition = 1
    spectrograms = SpectrogramGeneration(clips=clips,
                                     augmenter=augmenter,
                                     slide_frames=1,    # The testing set uses the streaming version of the model, so no artificial repetition is necessary
                                     step_ms=10,
                                     )

  RaggedMmap.from_generator(
      out_dir=os.path.join(out_dir, 'wakeword_mmap'),
      sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),
      batch_size=100,
      verbose=True,
  )

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [8]:
# Downloads pre-generated spectrogram features (made for microWakeWord in
# particular) for various negative datasets. This can be slow!

output_dir = './negative_datasets'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    link_root = "https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/"
    filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']
    for fname in filenames:
        link = link_root + fname

        zip_path = f"negative_datasets/{fname}"
        !wget -O {zip_path} {link}
        !unzip -q {zip_path} -d {output_dir}

--2026-01-14 07:20:52--  https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/dinner_party.zip
Resolving huggingface.co (huggingface.co)... 3.165.160.61, 3.165.160.11, 3.165.160.12, ...
Connecting to huggingface.co (huggingface.co)|3.165.160.61|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://us.gcp.cdn.hf.co/xet-bridge-us/65e327bc1445a768ed343b8c/228d7e72cd5fdc4e6e57da36b88a4c227d34cb8dc44041078b4c4b65dc75848d?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27dinner_party.zip%3B+filename%3D%22dinner_party.zip%22%3B&response-content-type=application%2Fzip&Expires=1768378852&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiRXBvY2hUaW1lIjoxNzY4Mzc4ODUyfX0sIlJlc291cmNlIjoiaHR0cHM6Ly91cy5nY3AuY2RuLmhmLmNvL3hldC1icmlkZ2UtdXMvNjVlMzI3YmMxNDQ1YTc2OGVkMzQzYjhjLzIyOGQ3ZTcyY2Q1ZmRjNGU2ZTU3ZGEzNmI4OGE0YzIyN2QzNGNiOGRjNDQwNDEwNzhiNGM0YjY1ZGM3NTg0OGRcXD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSomcmVzcG9uc2UtY29udGVudC1

In [None]:
# Save a yaml config that controls the training process
# These hyperparamters can make a huge different in model quality.
# Experiment with sampling and penalty weights and increasing the number of
# training steps.

import yaml
import os

config = {}

config["window_step_ms"] = 10

config["train_dir"] = (
    "trained_models/wakeword"
)


# Each feature_dir should have at least one of the following folders with this structure:
#  training/
#    ragged_mmap_folders_ending_in_mmap
#  testing/
#    ragged_mmap_folders_ending_in_mmap
#  testing_ambient/
#    ragged_mmap_folders_ending_in_mmap
#  validation/
#    ragged_mmap_folders_ending_in_mmap
#  validation_ambient/
#    ragged_mmap_folders_ending_in_mmap
#
#  sampling_weight: Weight for choosing a spectrogram from this set in the batch
#  penalty_weight: Penalizing weight for incorrect predictions from this set
#  truth: Boolean whether this set has positive samples or negative samples
#  truncation_strategy = If spectrograms in the set are longer than necessary for training, how are they truncated
#       - random: choose a random portion of the entire spectrogram - useful for long negative samples
#       - truncate_start: remove the start of the spectrogram
#       - truncate_end: remove the end of the spectrogram
#       - split: Split the longer spectrogram into separate spectrograms offset by 100 ms. Only for ambient sets

config["features"] = [
    {
        "features_dir": "generated_augmented_features",
        "sampling_weight": 2.0,
        "penalty_weight": 1.0,
        "truth": True,
        "truncation_strategy": "truncate_start",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/speech",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/dinner_party",
        "sampling_weight": 10.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    {
        "features_dir": "negative_datasets/no_speech",
        "sampling_weight": 5.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "random",
        "type": "mmap",
    },
    { # Only used for validation and testing
        "features_dir": "negative_datasets/dinner_party_eval",
        "sampling_weight": 0.0,
        "penalty_weight": 1.0,
        "truth": False,
        "truncation_strategy": "split",
        "type": "mmap",
    },
]

# Number of training steps in each iteration - various other settings are configured as lists that corresponds to different steps
config["training_steps"] = [10000]

# Penalizing weight for incorrect class predictions - lists that correspond to training steps
config["positive_class_weight"] = [1]
config["negative_class_weight"] = [20]

config["learning_rates"] = [
    0.001,
]  # Learning rates for Adam optimizer - list that corresponds to training steps
config["batch_size"] = 128

config["time_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["time_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps
config["freq_mask_max_size"] = [
    0
]  # SpecAugment - list that corresponds to training steps
config["freq_mask_count"] = [0]  # SpecAugment - list that corresponds to training steps

config["eval_step_interval"] = (
    500  # Test the validation sets after every this many steps
)
config["clip_duration_ms"] = (
    1500  # Maximum length of wake word that the streaming model will accept
)

# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization
# Once the target has been met, it chooses the maximum of the maximization metric. Set 'minimization_metric' to None to only maximize
# Available metrics:
#   - "loss" - cross entropy error on validation set
#   - "accuracy" - accuracy of validation set
#   - "recall" - recall of validation set
#   - "precision" - precision of validation set
#   - "false_positive_rate" - false positive rate of validation set
#   - "false_negative_rate" - false negative rate of validation set
#   - "ambient_false_positives" - count of false positives from the split validation_ambient set
#   - "ambient_false_positives_per_hour" - estimated number of false positives per hour on the split validation_ambient set
config["target_minimization"] = 0.9
config["minimization_metric"] = None  # Set to None to disable

config["maximization_metric"] = "average_viable_recall"

with open(os.path.join("training_parameters.yaml"), "w") as file:
    documents = yaml.dump(config, file)

In [None]:
# Trains a model. When finished, it will quantize and convert the model to a
# streaming version suitable for on-device detection.
# It will resume if stopped, but it will start over at the configured training
# steps in the yaml file.
# Change --train 0 to only convert and test the best-weighted model.
# On Google colab, it doesn't print the mini-batch results, so it may appear
# stuck for several minutes! Additionally, it is very slow compared to training
# on a local GPU.

!python -m microwakeword.model_train_eval \
--training_config='training_parameters.yaml' \
--train 1 \
--restore_checkpoint 1 \
--test_tf_nonstreaming 0 \
--test_tflite_nonstreaming 0 \
--test_tflite_nonstreaming_quantized 0 \
--test_tflite_streaming 0 \
--test_tflite_streaming_quantized 1 \
--use_weights "best_weights" \
mixednet \
--pointwise_filters "64,64,64,64" \
--repeat_in_block  "1, 1, 1, 1" \
--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \
--residual_connection "0,0,0,0" \
--first_conv_filters 32 \
--first_conv_kernel_size 5 \
--stride 3

In [None]:
# Downloads the tflite model file. To use on the device, you need to write a
# Model JSON file. See https://esphome.io/components/micro_wake_word for the
# documentation and
# https://github.com/esphome/micro-wake-word-models/tree/main/models/v2 for
# examples. Adjust the probability threshold based on the test results obtained
# after training is finished. You may also need to increase the Tensor arena
# model size if the model fails to load.

from google.colab import files

files.download(f"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite")