# Keyword Spotting on Spiking Neural Networks

This is a test Jupyter notebook with implementations of various spiking neural networks (SNNs) for keyword spotting (KWS). The networks developed will eventually be implemented in an embedded hardware accelerator targeting low-power computing for, e.g., hearing aids. SNNs are well suited to such applications because of their sparse activity and the applications' relatively low throughput requirements.

The notebook uses a combination of the NumPy, PyTorch, and BindsNET libraries for this task. The dataset considered is a subset of relevant keywords from Google's Speech Commands dataset, see [here](https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html).

The hardware accelerator is available in this repository under `./src/main/scala/neuroproc`. The accelerator is written in the Chisel language, which is an open-source HDL within Scala, see [here](https://github.com/chipsalliance/chisel3). Its design is based on [work by Anthon Riber](https://github.com/Thonner/NeuromorphicProcessor) and is developed as part of a master's thesis at the Institute of Mathematics and Computer Science, DTU Compute, Technical University of Denmark.

In [None]:
#%pip install --upgrade pip
#%pip install bindsnet seaborn

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set(style='whitegrid')
from kwsonsnn.utils import download
download('data')

## Dataset

The first step in this process is loading and understanding the dataset. Additionally, the data may need various kinds of preprocessing before it can be used in the models (e.g. FFT or similar, see [this](https://towardsdatascience.com/speech-classification-using-neural-networks-the-basics-e5b08d6928b7)). The box above should have downloaded the dataset, provided it was not already available.

In [None]:
from kwsonsnn.dataset import SpeechCommandsDataset
kws = ['up', 'down', 'left', 'right', 'on', 'off', 'yes', 'no', 'go', 'stop']
train_data = SpeechCommandsDataset('data', kws=kws)

Let us plot some random non-preprocessed data for each keyword. Unfortunately, there is no better way to fetch these than sequentially because of the data shuffling and use of sets in `SpeechCommandsDataset`.

In [None]:
kwsS = set(kws)
indices = []
for i in range(len(train_data)):
    if len(kwsS) == 0:
        break
    label = train_data[i]['label']
    if label in kwsS:
        indices.append(i)
        kwsS.remove(label)

In [None]:
rows = len(indices) // 4 + 1
cols = min(4, len(indices))
# squeeze=False keeps axs two-dimensional even when there is only one row
fig, axs = plt.subplots(rows, cols, figsize=(14,7), squeeze=False)
for i in range(rows):
    for j in range(min(cols, len(indices)-4*i)):
        data = train_data[indices[i*4+j]]
        audio, label = data['audio'], data['label']
        sns.lineplot(ax=axs[i, j], x=np.arange(len(audio)), y=audio)
        axs[i, j].set_xlim(0, len(audio))
        axs[i, j].set_ylim(-1, 1)
        axs[i, j].set_title(f'Keyword: {label}')
plt.suptitle('Original audio signals')
plt.tight_layout()
plt.show()

Now, consider four different utterances of the same keyword, taken as the first four matches in the shuffled dataset.

In [None]:
kw = kws[0]
same = []
for i in range(len(train_data)):
    if len(same) == 4:
        break
    label = train_data[i]['label']
    if kw == label:
        same.append(i)

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(14,4))
for i in range(4):
    data = train_data[same[i]]
    audio, label = data['audio'], data['label']
    sns.lineplot(ax=axs[i], x=np.arange(len(audio)), y=audio)
    axs[i].set_xlim(0, len(audio))
    axs[i].set_ylim(-1, 1)
plt.suptitle(f'Different versions of keyword "{label}"')
plt.tight_layout()
plt.show()

Next, let us apply the preprocessing steps and plot the resulting data. First, a short-time FFT over frames of the signal transforms it into a spectrogram that can be plotted. Second, the frequency spectra are converted into Mel-frequency cepstral coefficients.
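As a rough illustration of the first step, the sketch below splits a signal into overlapping frames and takes the FFT of each one using plain NumPy. It is not the code behind `SpeechCommandsDataset.process_data`. The helper names, the Hann window, the 16 kHz sample rate, and the synthetic test tone are all assumptions made for the example, and the Mel filterbank and cepstral steps are left out.

```python
import numpy as np

def frame_signal(audio, frame_length, frame_stride):
    """Split a 1-D signal into overlapping frames, one frame per row."""
    n_frames = 1 + (len(audio) - frame_length) // frame_stride
    starts = frame_stride * np.arange(n_frames)[:, None]
    offsets = np.arange(frame_length)[None, :]
    return audio[starts + offsets]

def power_spectrogram(audio, frame_length=400, frame_stride=160):
    """Window each frame and take the squared magnitude of its FFT."""
    frames = frame_signal(audio, frame_length, frame_stride)
    frames = frames * np.hanning(frame_length)  # assumed window choice
    return np.abs(np.fft.rfft(frames, axis=1)) ** 2

# Synthetic stand-in for a recording: 1 s of a 1 kHz tone at 16 kHz (assumed)
sample_rate = 16000
t = np.arange(sample_rate) / sample_rate
audio = np.sin(2 * np.pi * 1000 * t)

spec = power_spectrogram(audio)  # 25 ms frames, 10 ms stride at 16 kHz
peak_bin = int(spec.mean(axis=0).argmax())
print(spec.shape)                    # (98, 201)
print(peak_bin * sample_rate / 400)  # 1000.0
```

Each row of `spec` is one time frame and each column one frequency bin, so plotting `spec.T` as a heatmap puts time on the horizontal axis.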

Various works consider a varying number of frames extracted from the audio signals. For example:
- BindsNET's own spoken MNIST implementation uses `frame_length=25ms` and `frame_stride=10ms` (see [here](https://github.com/BindsNET/bindsnet/blob/master/bindsnet/datasets/spoken_mnist.py))
- Hello Edge from Arm uses `frame_length=40ms` and `frame_stride=20ms` (see [here](https://arxiv.org/abs/1711.07128))
- Benchmarking KWS Efficiency ... from Applied Brain Research uses 390-dimension frames with `frame_stride=10ms` (see [here](https://arxiv.org/abs/1812.01739))
- Low-Power Low-Latency KWS ... from Arm, Intel Labs, Applied Brain Research etc. uses 390-dimension frames with `frame_stride=10ms` (see [here](https://arxiv.org/abs/2009.08921))
- Max-Pooling Loss Training ... from Amazon and Google uses `frame_length=25ms` and `frame_stride=10ms` (see [here](https://arxiv.org/abs/1705.02411))
- A Dataset and Taxonomy ... from NYU uses `frame_length=23.2ms` and `frame_stride=11.6ms` (see [here](https://dl.acm.org/doi/10.1145/2647868.2655045))
- FastGRNN ... from Microsoft uses `frame_length=25ms` and `frame_stride=10ms` (see [here](https://arxiv.org/pdf/1901.02358.pdf))
- Efficient KWS using Dilated ... from Snips uses `frame_length=25ms` and `frame_stride=10ms` (see [here](https://ieeexplore.ieee.org/document/8683474))
- Deep Residual Learning ... uses `frame_length=30ms` and `frame_stride=10ms` (see [here](https://arxiv.org/abs/1710.10361))

Overall, we see that a frame length of approximately 25 ms with a correspondingly shorter stride (i.e., some overlap) is typical for KWS applications. Overlapping frames limit the risk of missing important information through poor frame "placement". Common to all these works is that they use significantly more frames than are available to this project unless the network in the accelerator is scaled up. Currently, supporting only 22x22 input images limits its performance significantly, as a much larger stride must be used.
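To make the frame budget concrete, the following sketch counts how many frames fit in a clip for a given frame length and stride. It assumes 1 s clips sampled at 16 kHz, no padding, and that a 22x22 input corresponds to 22 frames. The last of these is an assumption about this project's preprocessing, not something confirmed by the code above. The helper names are hypothetical.

```python
def num_frames(n_samples, frame_length, frame_stride):
    """Number of complete frames that fit in the signal (no padding)."""
    return 1 + (n_samples - frame_length) // frame_stride

sample_rate = 16000      # assumed: 16 kHz audio
n_samples = sample_rate  # assumed: 1 s clips

def ms(duration_ms):
    """Convert a duration in milliseconds to a number of samples."""
    return int(round(duration_ms * sample_rate / 1000))

typical = num_frames(n_samples, ms(25), ms(10))     # 25 ms frames, 10 ms stride
hello_edge = num_frames(n_samples, ms(40), ms(20))  # 40 ms frames, 20 ms stride
print(typical, hello_edge)  # 98 49

# Stride needed to span the whole clip with only 22 frames of 25 ms each
target_frames = 22  # assumed to match the 22x22 input
stride = (n_samples - ms(25)) // (target_frames - 1)
print(stride, stride * 1000 / sample_rate)    # 742 46.375
print(num_frames(n_samples, ms(25), stride))  # 22
```

Under these assumptions the stride grows to roughly 46 ms, which exceeds the 25 ms frame length, so consecutive frames would leave gaps rather than overlap.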

In [None]:
train_data.process_data()

Let us plot some of the processed data, reusing the signals from before.

In [None]:
# squeeze=False keeps axs two-dimensional even when there is only one row
fig, axs = plt.subplots(rows, cols, figsize=(14,7), squeeze=False)
for i in range(rows):
    for j in range(min(cols, len(indices)-4*i)):
        data = train_data[indices[i*4+j]]
        frames, label = data['audio'], data['label']
        sns.heatmap(frames.T, ax=axs[i, j])
        axs[i, j].invert_yaxis()
        axs[i, j].set_title(f'Keyword: {kws[int(label.item())]}')
plt.suptitle('Mel-spectrograms of audio signals')
plt.tight_layout()
plt.show()

Returning to the four versions of the same keyword:

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(14,4))
for i in range(4):
    data = train_data[same[i]]
    frames, label = data['audio'], data['label']
    sns.heatmap(frames.T, ax=axs[i])
    axs[i].invert_yaxis()
plt.suptitle(f'Different versions of keyword "{kw}"')
plt.tight_layout()
plt.show()

The four utterances of the same keyword differ considerably both before and after preprocessing, primarily in where the word is placed in time.

## Encoding

Next, we shall explore the encoding of the data fed into the network. Generally, we observe that power consumption is highly correlated with spike activity, so we should aim for as few spikes per time step as possible. The original work implements training and evaluation with rate-based encoding: direct rate encoding for simulation and training, and indirect rate-period encoding for efficient storage in the accelerator.

This project, however, sets out to reduce spike activity by using a timing-based encoding instead. The change affects only the input to the network, since the activity in the subsequent layers is less dependent on the data encoding.

First, we encode the data with rate encoding.

In [None]:
frames = train_data[0]['audio'] * 128 # use the same scaling factor as in train.py
frames.shape

In [None]:
# Direct rate encoding
from kwsonsnn.encode import RateEncoder
enc = RateEncoder(500)
print(enc(frames))
print(enc(frames).sum())

In [None]:
# Indirect rate encoding
from kwsonsnn.encode import RatePeriod
enc = RatePeriod(500)
print(enc(frames))

Notice how the periods calculated above are typically rather short relative to the 500 time steps of inputting spikes. This means that a massive amount of spike activity is seen in the input phase, which will likely not be reflected in subsequent layers.
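The relationship between intensities, spike trains, and periods can be sketched as follows. This is a generic illustration in NumPy, not the implementation of `RateEncoder` or `RatePeriod` in `kwsonsnn`. The Bernoulli spike model, the full-scale value of 128, the rounding rule, and the example intensities are all assumptions.

```python
import numpy as np

def rate_encode(values, time, max_value=128.0, rng=None):
    """Direct rate coding: per-step spike probability is values / max_value."""
    rng = np.random.default_rng(0) if rng is None else rng
    prob = np.clip(values / max_value, 0.0, 1.0)
    return rng.random((time,) + values.shape) < prob

def rate_period(values, time, max_value=128.0):
    """Indirect rate coding: store one inter-spike period per input."""
    # Inputs below 1/time of full scale get the longest period, `time`
    rate = np.clip(values / max_value, 1.0 / time, 1.0)
    return np.round(1.0 / rate).astype(np.int64)

values = np.array([1.0, 16.0, 64.0, 128.0])  # hypothetical input intensities
spikes = rate_encode(values, time=500)
periods = rate_period(values, time=500)

print(spikes.shape)               # (500, 4)
print(periods.tolist())           # [128, 8, 2, 1]
print((500 // periods).tolist())  # [3, 62, 250, 500]
```

In this sketch an input at a quarter of full scale already spikes about 125 times in 500 steps, which illustrates how short periods translate into heavy input activity.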

Let us instead consider rank-order encoding.

In [None]:
# Direct rank-order encoding
from kwsonsnn.encode import RankOrderDirect
enc = RankOrderDirect(500)
print(enc(frames))
print(enc(frames).sum())

In [None]:
# Indirect rank-order encoding
from kwsonsnn.encode import RankOrderPeriod
enc = RankOrderPeriod(500)
print(enc(frames))

When using rank-order encoding, at most one input neuron spikes in each of the 500 time steps. This means that the total number of spikes generated by input spike trains is reduced by roughly an order of magnitude with significant power savings to follow.
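A minimal sketch of the idea, assuming that the strongest input fires first and that each input fires at most once: the input with rank r spikes at time step r. The actual `RankOrderDirect` encoder may assign spike times differently. The function name and the random 22x22 input are hypothetical.

```python
import numpy as np

def rank_order_encode(values, time):
    """Each input spikes at most once; stronger inputs spike earlier."""
    flat = values.ravel()
    order = np.argsort(-flat, kind="stable")  # strongest input first
    n = min(time, flat.size)                  # inputs ranked beyond `time` stay silent
    spikes = np.zeros((time, flat.size), dtype=bool)
    spikes[np.arange(n), order[:n]] = True    # input with rank r fires at step r
    return spikes.reshape((time,) + values.shape)

rng = np.random.default_rng(0)
frames = rng.random((22, 22)) * 128  # hypothetical 22x22 input
spikes = rank_order_encode(frames, time=500)

print(spikes.shape)                        # (500, 22, 22)
print(int(spikes.sum()))                   # 484
print(int(spikes.sum(axis=(1, 2)).max()))  # 1
```

With 484 inputs and 500 time steps, every input gets its own step, so the total spike count equals the number of inputs regardless of their intensities.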

## Defining the model

The next step is to define a suitable model and prepare it for keyword spotting, which is a supervised learning task. The following box checks that a pretrained model exists in `./pretrained/network.pt`. If not, it runs the training script first (training is kept in a separate file for better command line functionality). Please ensure that the network specifications in the two files are identical.

In [None]:
import os
# Train a network only if no pretrained weights are available
if not os.path.isfile('./pretrained/network.pt'):
    os.system('python train.py')

In [None]:
from kwsonsnn.model import ShowCaseNet
from kwsonsnn.utils import get_default_net

# Construct network
network = get_default_net()

# Load pre-trained network
network.load_state_dict(torch.load('./pretrained/network.pt'))
network.eval()
print(network)

## Old code snippets

This section is not meant to be executed, but rather looked at for inspiration.

In [None]:
# OLD CODE FOR A SINGLE PLOT
#plt.figure(figsize=(7,4))
#data = train_data[0]
#audio, label = data['audio'], data['label']
#sns.lineplot(x=np.arange(len(audio)), y=audio)
#plt.xlim(0, len(audio))
#plt.ylim(min(audio)*1.05, max(audio)*1.05)
#plt.title(f'Keyword: {label}')
#plt.xlabel('Sample')
#plt.ylabel('Amplitude')
#plt.tight_layout()
#plt.show()

In [None]:
# OLD CODE FOR A SINGLE PLOT
#data = train_data[0]
#frames, label = data['audio'], data['label']
#plt.figure(figsize=(5,4))
#ax = sns.heatmap(frames.T)
#ax.invert_yaxis()
#plt.xlabel('Time')
#plt.ylabel('MFC Coefficients')
#plt.tight_layout()
#plt.show()