In [None]:
# Save a yaml config that controls the training process
import yaml
import os

config = {"train_dir": "trained_models/alexa_model"}

# Every entry of config["features"] describes one data source. These keys control how its spectrograms are
# sampled, labelled, and truncated:
#   sampling_weight:     relative weight for drawing a spectrogram from this set when a batch is built
#   penalty_weight:      weight applied to incorrect predictions on samples from this set
#   truth:               True if the set holds positive (wake word) samples, False for negative samples
#   truncation_strategy: how spectrograms longer than needed for training are shortened
#       - random:         keep a randomly chosen portion of the spectrogram (useful for long negative samples)
#       - truncate_start: remove the start of the spectrogram
#       - truncate_end:   remove the end of the spectrogram
#       - split:          split the spectrogram into separate spectrograms offset by 100 ms (ambient sets only)


def _clip_feature_set(
    input_directory,
    max_clip_duration_s,
    min_clip_duration_s,
    background_noise_probability,
    rir_probability,
    background_max_snr_db,
    penalty_weight,
    truth,
):
    """Return a "clips" feature entry whose WAV clips are augmented on the fly while training.

    The nested dictionaries hold the arguments for the Clips, Augmentation, and
    SpectrogramGeneration classes. Only the values that differ between the clip
    sets are parameters; everything else is shared.

    Each call builds brand-new dicts and lists, so no nested object is shared
    between entries. yaml.dump would otherwise write anchors/aliases for shared objects.
    """
    return {
        "clips_settings": {
            "input_directory": input_directory,
            "file_pattern": "*.wav",
            "max_clip_duration_s": max_clip_duration_s,
            "min_clip_duration_s": min_clip_duration_s,
        },
        "augmentation_settings": {
            "impulse_paths": ["./mit_rirs"],
            "background_paths": ["./audioset_16k", "./fma"],
            "augmentation_probabilities": {
                "SevenBandParametricEQ": 0.0,
                "TanhDistortion": 0.0,
                "PitchShift": 0.0,
                "BandStopFilter": 0.0,
                "AddColorNoise": 0.25,
                "AddBackgroundNoise": background_noise_probability,
                "Gain": 1.0,
                "RIR": rir_probability,
            },
            "augmentation_duration_s": 3.99,
            "max_jitter_s": 0.2,
            "min_jitter_s": 0.1,
            "background_min_snr_db": -10,
            "background_max_snr_db": background_max_snr_db,
        },
        "spectrogram_generation_settings": {"slide_frames": 5},
        "truncation_strategy": "truncate_start",
        "sampling_weight": 0.5,
        "penalty_weight": penalty_weight,
        "truth": truth,
        "type": "clips",
    }


def _mmap_feature_set(features_dir, sampling_weight, truth, truncation_strategy):
    """Return an "mmap" feature entry that reads precomputed spectrogram features from features_dir.

    features_dir must contain at least one of these five root folders, each holding
    ragged mmap folders whose names end in "mmap":
        training/  testing/  testing_ambient/  validation/  validation_ambient/
    """
    return {
        "features_dir": features_dir,
        "sampling_weight": sampling_weight,
        "penalty_weight": 1,
        "truth": truth,
        "truncation_strategy": truncation_strategy,
        "type": "mmap",
    }


config["features"] = [
    # Positive (wake word) clips, augmented on the fly.
    _clip_feature_set(
        "generated_samples/positive/training",
        max_clip_duration_s=1.29,
        min_clip_duration_s=0.5,
        background_noise_probability=1.0,
        rir_probability=0.25,
        background_max_snr_db=5,
        penalty_weight=1,
        truth=True,
    ),
    # Negative clips, augmented on the fly; mistakes on them are penalized less.
    _clip_feature_set(
        "generated_samples/negative/training",
        max_clip_duration_s=3.69,
        min_clip_duration_s=None,
        background_noise_probability=0.9,
        rir_probability=0.33,
        background_max_snr_db=0,
        penalty_weight=0.33,
        truth=False,
    ),
    # Precomputed positive features. A sampling_weight of 0.0 presumably keeps them out of training batches
    # (evaluation-only) — TODO confirm.
    _mmap_feature_set("generated_sets", sampling_weight=0.0, truth=True, truncation_strategy="truncate_start"),
    # Precomputed negative background sets, cropped at random positions.
    _mmap_feature_set(
        "negative_datasets/speech_background", sampling_weight=5.0, truth=False, truncation_strategy="random"
    ),
    _mmap_feature_set(
        "negative_datasets/dinner_party_background", sampling_weight=5.0, truth=False, truncation_strategy="random"
    ),
    _mmap_feature_set(
        "negative_datasets/no_speech_background", sampling_weight=2, truth=False, truncation_strategy="random"
    ),
]

# Training runs in phases: "training_steps" gives the number of steps in each phase, and the other list-valued
# settings below are matched to those phases by position.
# NOTE(review): several lists have fewer entries than "training_steps"; check how microwakeword extends a
# shorter list across the remaining phases.
#
# Checkpoint selection: the best weights are chosen first by minimizing "minimization_metric" below
# "target_minimization". Once that target is met, the weights with the highest "maximization_metric" win.
# Set "minimization_metric" to None to select on "maximization_metric" alone.
# Available metrics:
#   - "loss": cross-entropy loss on the validation set
#   - "accuracy": accuracy on the validation set
#   - "recall": recall on the validation set
#   - "precision": precision on the validation set
#   - "false_positive_rate": false positive rate on the validation set
#   - "false_negative_rate": false negative rate on the validation set
#   - "ambient_false_positives": number of false positives on the split validation_ambient set
#   - "ambient_false_positives_per_hour": estimated false positives per hour on the split validation_ambient set
#   - "average_viable_recall": recall averaged over false-accepts-per-hour rates between 0 and 2.0
config.update(
    {
        "training_steps": [20000, 20000, 20000, 20000],
        # Penalty weights for incorrect predictions of each class, per training phase.
        "positive_class_weight": [1],
        "negative_class_weight": [1, 1.5, 2, 3],
        # Adam optimizer learning rate, per training phase.
        "learning_rates": [0.001, 0.0005, 0.0002, 0.0001],
        "batch_size": 128,
        # SpecAugment masking parameters, per training phase.
        "time_mask_max_size": [0],
        "time_mask_count": [0],
        "freq_mask_max_size": [7],
        "freq_mask_count": [2],
        # Evaluate on the validation sets every this many steps.
        "eval_step_interval": 500,
        # Duration (ms) of the last layer before pooling or a fully connected layer.
        "clip_duration_ms": 590,
        "target_minimization": 0.0,
        "minimization_metric": None,  # None disables minimization; only maximization_metric is used
        "maximization_metric": "average_viable_recall",
    }
)

# Write the configuration to the YAML file that the training commands read via --training_config.
# yaml.dump writes to the stream and returns None, so there is nothing to keep from the call.
with open("training_parameters.yaml", "w") as config_file:
    yaml.dump(config, config_file)

# Sanity check: the written file must load back to exactly the config built above.
# This catches a bad write before a long training run starts.
with open("training_parameters.yaml") as config_file:
    assert yaml.safe_load(config_file) == config, "training_parameters.yaml does not round-trip to config"

In [None]:
# Train a MixedNet model. This produces great models and is very fast on the device while using small amounts of memory.
#
# Options before the `mixednet` subcommand control the run; the options after it describe the MixedNet layers.
#   --training_config       the YAML file written by the configuration cell (training_parameters.yaml)
#   --train 1               run training
#   --restore_checkpoint 1  presumably resumes from a checkpoint already in config["train_dir"] — TODO confirm
#   --test_* flags          presumably choose which model variants are evaluated; here only the streaming
#                           TFLite model and its quantized version are enabled
#   --use_weights           presumably selects the "best_weights" checkpoint picked by the
#                           minimization/maximization metrics in the config — TODO confirm
# Each comma-separated architecture list has four entries, presumably one per block — TODO confirm.
# NOTE(review): `!python` runs the first `python` on the shell PATH, which is not necessarily this kernel's
# interpreter. `!{sys.executable}` would pin it to the kernel.
!python -m microwakeword.model_train_eval \
--training_config='training_parameters.yaml' \
--train 1 \
--restore_checkpoint 1 \
--test_tf_nonstreaming 0 \
--test_tflite_nonstreaming 0 \
--test_tflite_streaming 1 \
--test_tflite_streaming_quantized 1 \
--use_weights "best_weights" \
mixednet \
--pointwise_filters "48,48,48,48" \
--repeat_in_block  "1, 1, 1, 1" \
--mixconv_kernel_sizes '[5], [9], [13], [21]' \
--residual_connection "0,0,0,0" \
--first_conv_filters 32

In [None]:
# Train an Inception model. It can produce good models, but it is slow at inference on the device and uses more memory, so you can't run multiple models at once. You may need to increase config["clip_duration_ms"] so that the first layer's time dimension is at least 75
!python -m microwakeword.model_train_eval \
--training_config='training_parameters.yaml' \
--train 0 \
--restore_checkpoint 1 \
--test_tf_nonstreaming 0 \
--test_tflite_nonstreaming 0 \
--test_tflite_streaming 0 \
--test_tflite_streaming_quantized 1 \
inception \
--cnn1_filters '32' \
--cnn1_kernel_sizes '5' \
--cnn1_subspectral_groups '1' \
--cnn2_filters1 '24,24,24' \
--cnn2_filters2 '32,64,96' \
--cnn2_kernel_sizes '3,5,5' \
--cnn2_subspectral_groups '1,1,1' \
--cnn2_dilation '1,1,1' \
--dropout 0.8