# Few-Shot KWS

This notebook heavily borrows from the work done here:
- [github.com/harvard-edge/multilingual_kws](https://github.com/harvard-edge/multilingual_kws)
- [multilingual_kws_intro_tutorial.ipynb](https://colab.research.google.com/github/harvard-edge/multilingual_kws/blob/main/multilingual_kws_intro_tutorial.ipynb#scrollTo=rK2Bow1THEvp)



In [None]:
# sox provides the `soxi` command used further down to inspect a recorded WAV file.
%shell apt-get -qq install sox

In [None]:
# NOTE(review): `samplerate` is not imported directly anywhere in this notebook; presumably
# it is a resampling backend dependency — confirm, and pin a version for reproducibility.
%pip install samplerate

In [None]:
# Clone the multilingual_kws repository; its `embedding.input_data` module supplies the
# spectrogram (`file2spec`) and dataset (`AudioDataset`) helpers used throughout.
%shell git clone https://github.com/harvard-edge/multilingual_kws/

In [None]:
import json
import logging
import os
import shutil
import subprocess
import sys

# Make the cloned repository importable; must run before the multilingual_kws import below.
sys.path.append("/content/multilingual_kws/")

import absl
import librosa
import matplotlib.pyplot as plt
import numpy as np
import soundfile
import tensorflow as tf
from absl import logging as absl_logging

# Only emit ERROR-level (and above) messages from absl logging.
absl_logging.set_verbosity(absl.logging.ERROR)

from pathlib import Path

from google.colab.output import eval_js
from IPython.display import HTML, Audio, display
from multilingual_kws.embedding import input_data

In [None]:
# Notebook configuration.
SAMPLES_DIR = "/content/samples"  # recordings are written here as <n>.wav
SAMPLE_RATE = 16000  # Hz; browser recordings are resampled to this rate before saving
KEYWORD = "tiempo"  # the target keyword the classifier is fine-tuned to detect

# Downloads



In [None]:
# (archive URL, extraction directory) pairs:
#   - Speech Commands v0.02: background noise for training and non-target words for evaluation
#   - the pretrained multilingual embedding model
#   - the "unknown" filler clips listed in unknown_files.txt
assets = [
    (
        "http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz",
        "/content/speech_commands",
    ),
    (
        "https://github.com/harvard-edge/multilingual_kws/releases/download/v0.1-alpha/multilingual_context_73_0.8011.tar.gz",  # noqa
        "/content/embedding_model",
    ),
    (
        "https://github.com/harvard-edge/multilingual_kws/releases/download/v0.1-alpha/unknown_files.tar.gz",  # noqa
        "/content/unknown_files",
    ),
]

# Fetch and extract each archive into its own directory.
for archive_url, extract_dir in assets:
    tf.keras.utils.get_file(
        origin=archive_url,
        cache_subdir=extract_dir,
        untar=True,
    )

# Samples

Record about 20 samples of the *KEYWORD* defined above. The first 3 are used for visualization, the first 5 for fine-tuning, and all remaining samples for testing.

In [None]:
# Browser recorder widget (Record / Play / Keep / Done).
# The embedded JavaScript defines `getSample()`. Once the user presses Keep or Done, it
# resolves to a JSON string {done, sampleRate, sample}:
#   - `sample` is about 1 s of mono audio from the analyser's 8-bit time-domain data,
#     scaled to [-1, 1).
#   - `sample` is null when Done is pressed without a new recording.
#   - `sampleRate` is the browser AudioContext rate.
SAMPLES_HTML = HTML(
    """
<script>
const audioCtx = new (window.AudioContext || window.webkitAudioContext)();
const doneBtn = document.getElementById("done-btn");
const keepBtn = document.getElementById("keep-btn");
const keptLbl = document.getElementById("kept-lbl");
const playBtn = document.getElementById("play-btn");
const recordBtn = document.getElementById("record-btn");
const sampleRate = audioCtx.sampleRate;

let done, keep, numSamples = 0, sample;
const promise = () =>
    new Promise((resolve) => {
        done = (isDone = true) => {
            doneBtn.disabled = true;
            keepBtn.disabled = true;
            playBtn.disabled = true;
            recordBtn.disabled = true;
            if (sample)
                keptLbl.innerHTML = ++numSamples;
            resolve(isDone);
        };
        keep = () => done(false);
    });

const getSample = async () => {
    doneBtn.disabled = numSamples === 0;
    recordBtn.disabled = false;
    const done = await promise();
    const result = JSON.stringify({
        done,
        sampleRate,
        sample
    });
    keepBtn.disabled = true;
    playBtn.disabled = true;
    recordBtn.disabled = true;
    doneBtn.disabled = true;
    sample = null;
    return result;
};

const captureAudio = (analyser, duration) => {
    const fftSize = analyser.fftSize;
    let intervalID, numIntervals = Math.floor(sampleRate * duration / fftSize);
    const timeDomainData = new Uint8Array(fftSize);
    const timeDomainDataQueue = [];

    return new Promise(resolve => {
        const getByteTimeDomainData = () => {
            analyser.getByteTimeDomainData(timeDomainData);
            timeDomainData.forEach(byte => timeDomainDataQueue.push(byte / 128 - 1));
            if (--numIntervals === 0) {
                clearInterval(intervalID);
                resolve(timeDomainDataQueue);
            }
        };

        intervalID = setInterval(getByteTimeDomainData, (fftSize / sampleRate) * 1e3);
    });
};

const play = () => {
    const buffer = audioCtx.createBuffer(1, sample.length, sampleRate);
    const buffering = buffer.getChannelData(0);
    sample.forEach((value, index) => buffering[index] = value);

    const source = audioCtx.createBufferSource();
    source.buffer = buffer;
    source.connect(audioCtx.destination);
    source.start(0);
};

const record = async () => {
    keepBtn.disabled = true;
    playBtn.disabled = true;
    recordBtn.disabled = true;
    try {
        // A context created without a user gesture may start "suspended", in which case the
        // analyser only reports silence. This click is a user gesture, so resume it here.
        if (audioCtx.state === "suspended")
            await audioCtx.resume();
        const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
        const source = audioCtx.createMediaStreamSource(stream);
        const analyser = source.context.createAnalyser();
        analyser.fftSize = 2048;
        analyser.smoothingTimeConstant = 0;
        source.connect(analyser);

        try {
            sample = await captureAudio(analyser, 1);
        } finally {
            // Always release the microphone, even if capturing fails.
            analyser.disconnect();
            source.disconnect();
            stream.getTracks().forEach((track) => track.stop());
        }

        doneBtn.disabled = false;
        keepBtn.disabled = false;
        playBtn.disabled = false;
    } catch (err) {
        // e.g. microphone permission denied: report it and restore the previous sample's buttons.
        console.error("Recording failed:", err);
        keepBtn.disabled = !sample;
        playBtn.disabled = !sample;
    } finally {
        // Never leave the widget stuck with every button disabled.
        recordBtn.disabled = false;
    }
};
</script>

<div>
    <button id="play-btn" disabled=true onclick="play()">Play</button>
    <button id="record-btn" onclick="record()">Record</button>
    <button id="keep-btn" disabled=true onclick="keep()">Keep</button>
    <button id="done-btn" disabled=true onclick="done()">Done</button>
    <div>
        <div style="display:inline-block">
            <p>Samples: </p>
        </div>
        <div style="display:inline-block">
            <label id="kept-lbl">0</label>
        </div>
    </div>
</div>
"""
)

In [None]:
def get_samples():
    """Record keyword utterances in the browser and save them as WAV files.

    Steps:
      1. Delete and recreate ``SAMPLES_DIR``.
      2. Display the recorder widget.
      3. Repeatedly await the page's ``getSample()``. Each non-empty clip is resampled
         from the browser's rate to ``SAMPLE_RATE`` and written as 16-bit PCM to
         ``SAMPLES_DIR/<n>.wav``, where n counts up from 0.

    Returns once the page reports that Done was pressed.
    """
    shutil.rmtree(SAMPLES_DIR, ignore_errors=True)
    os.mkdir(SAMPLES_DIR)
    display(SAMPLES_HTML)

    index = 0
    while True:
        payload = json.loads(eval_js("getSample()"))
        browser_rate = payload["sampleRate"]
        clip = payload["sample"]
        if clip:
            waveform = np.array(clip, dtype=np.float32).reshape(-1)
            resampled_wave = librosa.resample(
                waveform,
                orig_sr=browser_rate,
                target_sr=SAMPLE_RATE,
                res_type="kaiser_fast",
                fix=True,
            )
            out_path = f"{SAMPLES_DIR}/{index}.wav"
            soundfile.write(out_path, resampled_wave, SAMPLE_RATE, "PCM_16")
        if payload["done"]:
            return
        index += 1

In [None]:
# Run the interactive recorder; this returns only after Done is pressed in the widget.
get_samples()

In [None]:
# Recordings in the order they were made. Sort numerically by stem: a plain lexicographic
# sort would place "10.wav" before "2.wav", so samples[:3] / samples[:5] would not be the
# first recordings.
samples = sorted(Path(SAMPLES_DIR).glob("*.wav"), key=lambda wav_path: int(wav_path.stem))

In [None]:
# Listen back to the first three recordings.
display(*(Audio(str(wav_path)) for wav_path in samples[:3]))

In [None]:
# Spectrograms of the first three recordings. `settings` is reused later for feature extraction.
settings = input_data.standard_microspeech_model_settings(label_count=1)
fig, axes = plt.subplots(ncols=3)
for wav_path, ax in zip(samples[:3], axes):
    spectrogram = input_data.file2spec(settings, str(wav_path))  # PosixPath not supported
    ax.imshow(spectrogram.numpy())
    # Title with the file name; `parts[2:]` is a tuple and rendered as "('samples', '0.wav')".
    ax.set_title(wav_path.name)
fig.set_size_inches(10, 5)

In [None]:
# Show soxi's summary of the first recording to sanity-check the WAV format get_samples wrote.
soxi_report = subprocess.run(["soxi", samples[0]], check=True, stdout=subprocess.PIPE)
print(soxi_report.stdout.decode("utf8"))

# Model

In [None]:
# Load the pretrained multilingual model downloaded above and expose the output of its
# "dense_2" layer as a frozen feature extractor.
tf_logger = tf.get_logger()
previous_log_level = tf_logger.level
tf_logger.setLevel(logging.ERROR)  # hide load-time warnings
try:
    # Absolute path, matching the download location (does not depend on the cwd).
    base_model = tf.keras.models.load_model(
        "/content/embedding_model/multilingual_context_73_0.8011"
    )
finally:
    # Restore the level that was active before, even if loading fails.
    tf_logger.setLevel(previous_log_level)

embedding = tf.keras.models.Model(
    name="embedding_model",
    inputs=base_model.inputs,
    outputs=base_model.get_layer(name="dense_2").output,
)
# Freeze the embedding; only the new classification head is trained.
embedding.trainable = False

In [None]:
# Embed the first recording with the frozen feature extractor and plot the vector.
sample_fpath = str(samples[0])
print("Filepath:", sample_fpath)

spectrogram = input_data.file2spec(settings, sample_fpath)
print("Spectrogram shape", spectrogram.shape)

# Add a leading batch axis and a trailing channel axis (1x49x40x1 per the model's input).
model_input = spectrogram[tf.newaxis, :, :, tf.newaxis]
feature_vec = embedding.predict(model_input)
print("Feature vector shape:", feature_vec.shape)

fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(feature_vec[0])

In [None]:
# Few-shot classifier: frozen embedding -> small tanh layer -> softmax over the classes.
CATEGORIES = 3  # silence + unknown + target_keyword
hidden_layer = tf.keras.layers.Dense(units=18, activation="tanh")
output_layer = tf.keras.layers.Dense(units=CATEGORIES, activation="softmax")
model = tf.keras.models.Sequential([embedding, hidden_layer, output_layer])
model.summary()

In [None]:
LEARNING_RATE = 0.001

# The head ends in softmax, so the loss receives probabilities (from_logits=False);
# the sparse variant takes integer class ids as labels.
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

# Dataset

In [None]:
# Fine-tuning data: the target KEYWORD, "unknown" filler clips, and Speech Commands
# background noise, with augmentation configured via SpecAugParams.
model_settings = input_data.standard_microspeech_model_settings(3)

unknown_files_txt = "/content/unknown_files/unknown_files.txt"
with open(unknown_files_txt) as fh:
    # One path per line, relative to the extraction directory.
    unknown_files = [
        "/content/unknown_files/" + relative_path
        for relative_path in fh.read().splitlines()
    ]

audio_dataset = input_data.AudioDataset(
    model_settings=model_settings,
    commands=[KEYWORD],
    background_data_dir="/content/speech_commands/_background_noise_/",
    unknown_files=unknown_files,
    # NOTE(review): presumably the share of examples drawn from unknown_files — confirm.
    unknown_percentage=50.0,
    spec_aug_params=input_data.SpecAugParams(percentage=80),
)

# Training

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 64
EPOCHS = 4

# Only the first five recordings are used for fine-tuning; the rest are held out for testing.
five_samples = [f"{SAMPLES_DIR}/{wav_path.name}" for wav_path in samples[:5]]
init_train_ds = audio_dataset.init_single_target(AUTOTUNE, five_samples, is_training=True)
train_ds = (
    init_train_ds
    .shuffle(buffer_size=1000)
    .repeat()  # endless stream; epoch length is set via steps_per_epoch in fit()
    .batch(BATCH_SIZE)
)

In [None]:
# train_ds repeats forever, so Keras needs an explicit epoch length. This is the number of
# batches per epoch and is deliberately independent of BATCH_SIZE (the samples per batch).
STEPS_PER_EPOCH = 64
history = model.fit(train_ds, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS)

In [None]:
# Plot the per-epoch training curves. Read `history.history` into a new name instead of
# rebinding `history`, so this cell can be re-run without failing.
metrics = history.history

fig, ax = plt.subplots()
ax.set(ylabel="Loss", xlabel="Epoch", ylim=(0, 2))  # Keras History records one value per epoch
ax.plot(metrics["loss"])

fig, ax = plt.subplots()
ax.set(ylabel="Accuracy", xlabel="Epoch", ylim=(0, 1))
ax.plot(metrics["accuracy"])

In [None]:
# Save the fine-tuned classifier (frozen embedding + trained head) to the working directory;
# the .h5 extension selects the HDF5 format.
model.save("fewshotkws.h5")

# Evaluation

In [None]:
# Held-out evaluation. Every recording after the first five is an utterance of KEYWORD,
# so accuracy is the fraction predicted as the target class.
# Class ids: 0 = silence/background noise, 1 = unknown keyword, 2 = target.
test_samples = [f"{SAMPLES_DIR}/{sample.name}" for sample in samples[5:]]
if not test_samples:
    # Without this guard, predict() on an empty array / the division below would fail.
    print(
        f"No test recordings: found {len(samples)} sample(s), and the first 5 are used "
        "for training. Record more samples to evaluate."
    )
else:
    test_spectrograms = np.array([input_data.file2spec(settings, f) for f in test_samples])
    predictions = model.predict(test_spectrograms)
    categorical_predictions = np.argmax(predictions, axis=1)
    accuracy = np.mean(categorical_predictions == 2)
    print(f"Test accuracy on testset: {accuracy:0.2f}")

In [None]:
# Estimate how often non-target words are classified as "unknown" (class 1). The words come
# from Speech Commands, excluding KEYWORD and the background-noise folder.
# Class ids: 0 = silence/background noise, 1 = unknown keyword, 2 = target.
# Predictions of class 0 also count as misses under this metric.
SPEECH_COMMANDS_DIR = Path("/content/speech_commands")  # absolute, matching the download location
NUM_NON_TARGET = 1000

non_target_examples = []
# Sort directory and file listings: their order is filesystem-dependent, and the seeded
# subsample below is only reproducible over a deterministic ordering.
for word_dir in sorted(SPEECH_COMMANDS_DIR.iterdir()):
    if not word_dir.is_dir() or word_dir.name in (KEYWORD, "_background_noise_"):
        continue
    non_target_examples.extend(sorted(word_dir.glob("*.wav")))

# Downsample the list to speed things up, never asking for more files than exist.
rng = np.random.RandomState(42)
chosen = rng.choice(
    len(non_target_examples), min(NUM_NON_TARGET, len(non_target_examples)), replace=False
)
non_target_examples = [non_target_examples[i] for i in chosen]
print("Number of non-target examples", len(non_target_examples))

if non_target_examples:
    non_target_spectrograms = np.array(
        [input_data.file2spec(settings, str(f)) for f in non_target_examples]
    )
    predictions = model.predict(non_target_spectrograms)
    categorical_predictions = np.argmax(predictions, axis=1)
    accuracy = np.mean(categorical_predictions == 1)
    print(f"Estimated accuracy on non-target samples: {accuracy:0.2f}")