# Keyword Spotting on Spiking Neural Networks

This is a test Jupyter notebook with implementations of various spiking neural networks (SNNs) for keyword spotting (KWS). The networks developed here will eventually be implemented in an embedded hardware accelerator targeting low-power computing, e.g. for hearing aids or similar devices. SNNs are well suited to such applications because of their sparse activity and because the application has relatively low throughput requirements.

The notebook uses a combination of the NumPy, PyTorch, and BindsNET libraries for this task. The dataset is a subset of relevant keywords from Google's Speech Commands dataset, see [here](https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html).

The hardware accelerator is available in this repository under `./src/main/scala/neuroproc`. It is written in Chisel, an open-source hardware construction language embedded in Scala, see [here](https://github.com/chipsalliance/chisel3). Its design is based on [work by Anthon Riber](https://github.com/Thonner/NeuromorphicProcessor), and it is being developed as part of a master's thesis at the Institute of Mathematics and Computer Science, DTU Compute, Technical University of Denmark.

In [None]:
#%pip install --upgrade pip
#%pip install bindsnet seaborn

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set(style='whitegrid')
from kwsonsnn.utils import download
download('data')

## Dataset

The first step in this process is loading and understanding the dataset. Additionally, the data may need various kinds of preprocessing before it can be used in the models (e.g. an FFT or similar, see [this](https://towardsdatascience.com/speech-classification-using-neural-networks-the-basics-e5b08d6928b7)). The cell above should have downloaded the dataset, unless it was already available.

In [None]:
from kwsonsnn.dataset import SpeechCommandsDataset
train_data = SpeechCommandsDataset('data')

Let us plot a sample of the raw, unprocessed data.

In [None]:
plt.figure(figsize=(7,4))
data = train_data[0]
audio, label = data['audio'], data['label']
sns.lineplot(x=np.arange(len(audio)), y=audio)
plt.xlim(0, len(audio))
plt.ylim(min(audio)*1.05, max(audio)*1.05)
plt.title(f'Keyword: {label}')
plt.xlabel('Sample')
plt.ylabel('Amplitude')
plt.tight_layout()
plt.show()

Next, let us apply the desired preprocessing steps and plot the resulting data. First, an FFT over short frames of the audio transforms it into a plottable spectrogram. Second, the frequency spectra are converted into Mel-frequency cepstral coefficients (MFCCs).
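As a rough illustration of what this pipeline involves, below is a minimal NumPy sketch of a standard MFCC computation: framing, windowing, power spectrum, mel filterbank, log, and DCT. It is not the implementation behind `process_data()`. The 16 kHz sample rate matches the Speech Commands recordings, while the 25 ms frames, 10 ms hop, 512-point FFT, Hamming window, HTK-style mel scale, 40 mel filters, and 13 coefficients are all assumptions chosen for illustration.

```python
import numpy as np

def hz_to_mel(f):
    # HTK-style mel scale (an assumption; other variants exist)
    return 2595.0 * np.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def mel_filterbank(n_mels, n_fft, sr):
    """Triangular filters evenly spaced on the mel scale, shape (n_mels, n_fft//2 + 1)."""
    mel_points = np.linspace(hz_to_mel(0.0), hz_to_mel(sr / 2), n_mels + 2)
    bins = np.floor((n_fft + 1) * mel_to_hz(mel_points) / sr).astype(int)
    fb = np.zeros((n_mels, n_fft // 2 + 1))
    for i in range(1, n_mels + 1):
        left, center, right = bins[i - 1], bins[i], bins[i + 1]
        for k in range(left, center):  # rising edge of the triangle
            fb[i - 1, k] = (k - left) / max(center - left, 1)
        for k in range(center, right):  # falling edge of the triangle
            fb[i - 1, k] = (right - k) / max(right - center, 1)
    return fb

def dct_matrix(n_out, n_in):
    """First n_out rows of the orthonormal DCT-II matrix."""
    n = np.arange(n_in)
    k = np.arange(n_out)[:, None]
    mat = np.cos(np.pi / n_in * (n + 0.5) * k)
    mat[0] *= 1.0 / np.sqrt(2.0)
    return mat * np.sqrt(2.0 / n_in)

def mfcc(signal, sr=16000, frame_len=400, hop=160, n_fft=512, n_mels=40, n_mfcc=13):
    """Return an (n_frames, n_mfcc) array of MFCCs for a 1-D audio signal."""
    signal = np.asarray(signal, dtype=float)
    if len(signal) < frame_len:  # pad very short clips to one full frame
        signal = np.pad(signal, (0, frame_len - len(signal)))
    n_frames = 1 + (len(signal) - frame_len) // hop
    # Index matrix selecting each frame's samples: (n_frames, frame_len)
    idx = np.arange(frame_len)[None, :] + hop * np.arange(n_frames)[:, None]
    frames = signal[idx] * np.hamming(frame_len)
    # Power spectrum of each frame (zero-padded to n_fft)
    power = np.abs(np.fft.rfft(frames, n=n_fft)) ** 2 / n_fft
    mel_energy = power @ mel_filterbank(n_mels, n_fft, sr).T
    log_mel = np.log(mel_energy + 1e-10)  # small offset avoids log(0)
    return log_mel @ dct_matrix(n_mfcc, n_mels).T
```

For a one-second clip at 16 kHz, `mfcc(audio)` yields 98 frames of 13 coefficients, i.e. the same time-by-coefficient layout that the heatmap below plots (transposed).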

In [None]:
train_data.process_data()

And let us plot some processed data.

In [None]:
# Fetch some data
data = train_data[0]
frames, label = data['audio'], data['label']

plt.figure(figsize=(5,4))
ax = sns.heatmap(frames.T)
ax.invert_yaxis()
plt.xlabel('Time')
plt.ylabel('MFC Coefficients')
plt.tight_layout()
plt.show()

## Encoding

Next, we explore how the data fed into the network is encoded. In general, power consumption is highly correlated with spike activity, so we should aim for as few spikes per time step as possible. The original work uses rate-based encoding for training and evaluation: direct rate encoding for simulation and training, and indirect rate-period encoding for efficient storage in the accelerator.
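To make the distinction between direct and indirect rate encoding concrete, here is a minimal NumPy sketch. It is not the code behind `kwsonsnn.encode.RateEncoder` or `RatePeriod`; the linear normalisation, the Bernoulli spike generation, and the `max_rate` cap of 0.2 spikes per step are assumptions. The direct form materialises a full `(time, ...)` spike train, whereas the period form stores a single integer per input from which a regular spike train can be regenerated on the fly.

```python
import numpy as np

def normalise(x, max_rate):
    """Map inputs linearly to spike probabilities in [0, max_rate] (assumed scheme)."""
    x = np.asarray(x, dtype=float)
    span = x.max() - x.min()
    if span == 0:
        return np.zeros_like(x)
    return (x - x.min()) / span * max_rate

def rate_encode(x, time, max_rate=0.2, rng=None):
    """Direct rate coding: Bernoulli spike train of shape (time, *x.shape)."""
    rng = np.random.default_rng() if rng is None else rng
    p = normalise(x, max_rate)
    return rng.random((time,) + p.shape) < p

def rate_period(x, max_rate=0.2):
    """Indirect form: one integer period per input (0 means 'never spikes')."""
    p = normalise(x, max_rate)
    period = np.zeros(p.shape, dtype=int)
    nz = p > 0
    period[nz] = np.round(1.0 / p[nz]).astype(int)
    return period

def expand_period(period, time):
    """Regenerate a regular spike train: input i spikes every period[i] steps."""
    period = np.asarray(period)
    t = np.arange(1, time + 1).reshape((time,) + (1,) * period.ndim)
    return (period > 0) & (t % np.maximum(period, 1) == 0)
```

Storing `rate_period(frames)` costs one integer per input instead of `time` bits per input, which is why the period form suits on-chip storage.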

This project instead aims to reduce spike activity by using a timing-based encoding. The change mainly affects the input layer, since activity in the subsequent layers depends less on how the data is encoded.

First, we encode the data with rate encoding.

In [None]:
# Direct rate encoding
from kwsonsnn.encode import RateEncoder
enc = RateEncoder(500)
print(enc(frames))

In [None]:
# Indirect rate encoding
from kwsonsnn.encode import RatePeriod
enc = RatePeriod(500)
print(enc(frames))

Notice that the periods computed above are typically short relative to the 500-time-step input phase. An input with period `p` spikes roughly `500 / p` times, so the input phase sees a large amount of spike activity, which is unlikely to be reflected in subsequent layers.
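The arithmetic can be checked with a small helper. This is a sketch under the assumption that an input with period `p > 0` spikes at steps `p, 2p, ...` up to the end of the input phase, and that a period of 0 means the input is silent; the example periods are made up for illustration and are not taken from the dataset.

```python
import numpy as np

def spikes_from_periods(period, time):
    """Total spike count when input i spikes every period[i] steps (0 = silent)."""
    period = np.asarray(period)
    active = period > 0
    return int(np.sum(time // period[active]))

# Hypothetical periods for four inputs over a 500-step input phase
periods = np.array([2, 5, 10, 0])
print(spikes_from_periods(periods, 500))  # → 400 (250 + 100 + 50 + 0)
```

Even these four inputs produce 400 spikes, whereas an encoding that lets each input spike at most once would produce at most 4.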

Let us instead consider rank-order encoding.

In [None]:
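# Direct rank-order encoding (BindsNET)
from bindsnet.encoding.encoders import RankOrderEncoder
enc = RankOrderEncoder(500)
# BindsNET's rank_order() requires non-negative input, but MFCCs can be
# negative, so encode a shifted float copy of the frames instead
x = torch.as_tensor(frames, dtype=torch.float)
spikes = enc(x - x.min())
print(spikes)
print(spikes.sum(dim=(2, 1)))  # number of input spikes per time step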

In [None]:
# Indirect rank-order encoding
from kwsonsnn.encode import RankOrderPeriod
enc = RankOrderPeriod(500)
print(enc(frames))

With rank-order encoding, at most one input neuron spikes in each of the 500 time steps, which caps the input phase at 500 spikes. Compared with rate encoding, the total number of input spikes is thus reduced by roughly an order of magnitude, which should yield significant power savings.
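The property described above can be illustrated with a minimal sketch of rank-order encoding: inputs are sorted by decreasing value, the r-th ranked input spikes at step r, and inputs ranked beyond the last time step never spike. This is an assumed scheme written to match the description, not the code behind `kwsonsnn.encode.RankOrderPeriod` or BindsNET's `RankOrderEncoder`.

```python
import numpy as np

def rank_order_encode(x, time):
    """Sketch: the r-th largest input spikes at step r (0-based); at most one spike
    per step and per input. Returns a (time, *x.shape) boolean array."""
    x = np.asarray(x, dtype=float)
    flat = x.ravel()
    order = np.argsort(-flat, kind='stable')  # input indices, largest value first
    spikes = np.zeros((time, flat.size), dtype=bool)
    n = min(time, flat.size)  # inputs ranked beyond `time` stay silent
    spikes[np.arange(n), order[:n]] = True
    return spikes.reshape((time,) + x.shape)

s = rank_order_encode(np.array([0.1, 0.9, 0.5]), 5)
print(s.sum(axis=0))  # each input spikes once: [1 1 1]
```

With more inputs than time steps, for example an 800-element frame array over 500 steps, exactly 500 spikes are emitted, one per step, regardless of the input values.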

## Defining the model

The next step is to define a suitable model and prepare it for keyword spotting, which is a supervised learning task. The following cell checks whether a pretrained model exists in `./pretrained/network.pt`. If not, it first runs the training script (training is kept in a separate file for better command-line functionality). Please ensure that the network specifications in the two files are identical.

In [None]:
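import os
import subprocess
import sys

# Train a network first if no pretrained weights are available
if not os.path.isfile('./pretrained/network.pt'):
    subprocess.run([sys.executable, 'train.py'], check=True)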

In [None]:
from kwsonsnn.model import ShowCaseNet
from kwsonsnn.utils import get_default_net

# Construct network
network = get_default_net()

# Load pre-trained network
network.load_state_dict(torch.load('./pretrained/network.pt'))
network.eval()
print(network)