In [13]:
# - Imports
import warnings
warnings.filterwarnings('ignore')
import nengo
import nengo_dl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython import display
import os
from nengo_extras.plot_spikes import plot_spikes

# - Simmba imports
from SIMMBA import BaseModel
from SIMMBA.experiments.HeySnipsDEMAND import HeySnipsDEMAND
from SIMMBA import BatchResult

from rockpool.layers import ButterMelFilter
from rockpool.timeseries import TSContinuous

In [14]:
def weight_init(shape):
    '''Convenience function for randomly initializing weights'''
    weights = np.random.uniform(-0.05, 0.05, size=shape)
    return weights

class HeySnipsNetworkNengo(BaseModel):
    def __init__(self,
                 labels,
                 num_neurons,
                 tau_slow,
                 num_val,
                 num_test,
                 num_epochs,
                 threshold,
                 eta,
                 fs=16000.,
                 verbose=0,
                 node_id=1,
                 name="Snips Nengo",
                 version="1.0"):
        
        super(HeySnipsNetworkNengo, self).__init__(name,version)

        self.verbose = verbose
        self.node_id = node_id
        self.fs = fs
        self.dt = 0.001

        self.num_val = num_val
        self.num_test = num_test

        self.num_epochs = num_epochs
        self.num_neurons = num_neurons
        self.threshold = threshold
        self.num_channels = 16
        max_rate = 250
        amplitude = 1 / max_rate
        tau_mem = 0.05


        self.num_targets = len(labels)

        # - Everything is stored in base_path/Resources/hey-snips/
        self.base_path = "/home/julian/Documents/nengo-samples/"
        self.node_prefix = "node_"+str(self.node_id)+str(int(np.abs(np.random.randn()*1e10)))

        self.lyr_filt = ButterMelFilter(fs=fs,
                                num_filters=self.num_channels,
                                cutoff_fs=400.,
                                filter_width=2.,
                                num_workers=4,
                                name='filter')


        # - Create Network
        model_path_nengo_net = os.path.join(self.base_path,"Resources/x")

        if(os.path.exists(model_path_nengo_net)):
            # - Load network
            assert(False)
            # - Need to : self.best_model = self.net
            print("Loaded pretrained network from %s" % model_path_nengo_net)
        else:
            
            # - Create network here
            # - Tau_rc is the membrane TC, tau_ref is the refractory period
            lifs = nengo.LIF(tau_rc=tau_mem, tau_ref=0.00, amplitude=amplitude)
            # - Network connectivity, only one recurrently connected ensemble
            with nengo.Network() as net:
                net.config[nengo.Connection].synapse = nengo.synapses.Lowpass(tau=tau_slow)
                net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([max_rate])
                net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])

                inp = nengo.Node(np.zeros(self.num_channels))
                ens = nengo.Ensemble(n_neurons=self.num_neurons, dimensions=1, neuron_type=lifs)
                out = nengo.Node(size_in=1)

                conn_a = nengo.Connection(
                    inp, ens.neurons, transform=weight_init(shape=(self.num_neurons, self.num_channels)))

                conn_rec = nengo.Connection(
                    ens.neurons, ens.neurons, transform=weight_init(shape=(self.num_neurons, self.num_neurons)))

                conn_b = nengo.Connection(
                    ens.neurons, out, transform=weight_init(shape=(1, self.num_neurons)) / tau_mem)

                self.probe_out = nengo.Probe(out, synapse=0.01)
                self.probe_spikes = nengo.Probe(ens.neurons)
                self.best_model = net
                self.net = net


    def save(self, fn):
        return


    def train(self, data_loader, fn_metrics):

        self.best_model = self.net

        for epoch in range(self.num_epochs):

            for batch_id, [batch, train_logger] in enumerate(data_loader.train_set()):

                audio = np.vstack([s[0][0] for s in batch])
                filtered = np.stack([s[0][1] for s in batch])
                tgt_signals = np.vstack([s[2] for s in batch])
                
                print(audio.shape)
                print(filtered.shape)

                train_logger.add_predictions(pred_labels=np.ones(50), pred_target_signals=[[0]])
                fn_metrics('train', train_logger)


            yield {"train_loss": 0.0}


            # Validate at the end of the epoch
            val_acc = self.perform_validation_set(data_loader=data_loader, fn_metrics=fn_metrics)


    def perform_validation_set(self, data_loader, fn_metrics):

        
        for batch_id, [batch, val_logger] in enumerate(data_loader.val_set()):
            if(batch_id >= self.num_val):
                break

            val_logger.add_predictions(pred_labels=[0], pred_target_signals=[[0]])
            fn_metrics('val', val_logger)


        return 0.0


    def test(self, data_loader, fn_metrics):

        for batch_id, [batch, test_logger] in enumerate(data_loader.test_set()):

            if batch_id > self.num_test:
                break
            

            test_logger.add_predictions(pred_labels=[0], pred_target_signals=[[0]])
            fn_metrics('test', test_logger)


In [15]:
batch_size = 50
balance_ratio = 1.0
snr = 10.
percentage_data = 0.1

experiment = HeySnipsDEMAND(batch_size=batch_size,
                            percentage=percentage_data,
                            snr=snr,
                            randomize_after_epoch=True,
                            downsample=1000,
                            is_tracking=False,
                            one_hot=False)

num_train_batches = int(np.ceil(experiment.num_train_samples / batch_size))
num_val_batches = int(np.ceil(experiment.num_val_samples / batch_size))
num_test_batches = int(np.ceil(experiment.num_test_samples / batch_size))

model = HeySnipsNetworkNengo(labels=experiment._data_loader.used_labels,
                            num_neurons=1024,
                            tau_slow=0.07,
                            num_val=500,
                            num_test=1000,
                            num_epochs=100,
                            threshold=0.7,
                            eta=0.0001,
                            verbose=0,
                            node_id=0)

experiment.set_model(model)
experiment.set_config({'num_train_batches': num_train_batches,
                       'num_val_batches': num_val_batches,
                       'num_test_batches': num_test_batches,
                       'batch size': batch_size,
                       'percentage data': percentage_data,
                       'snr': snr,
                       'balance_ratio': balance_ratio})
experiment.start()

print("experiment done",flush=True)
print(f"Accuracy score: {experiment.acc_scores}",flush=True)
print("confusion matrix",flush=True)
print(experiment.cm)

[INFO] - HeySnips -  Running with tracking deactivated
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(50, 80000)
(50, 4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990, 16)
(4990

Process ForkPoolWorker-47:
Process ForkPoolWorker-45:
Process ForkPoolWorker-46:
Traceback (most recent call last):


KeyboardInterrupt: 

Traceback (most recent call last):
Process ForkPoolWorker-48:
  File "/home/julian/anaconda3/envs/nengo/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/julian/anaconda3/envs/nengo/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/julian/anaconda3/envs/nengo/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/julian/anaconda3/envs/nengo/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/home/julian/anaconda3/envs/nengo/lib/python3.6/multiprocessing/pool.py", line 108, in worker
    task = get()
  File "/home/julian/anaconda3/envs/nengo/lib/python3.6/multiprocessing/pool.py", line 108, in worker
    task = get()
Traceback (most recent call last):
  File "/home/julian/anaconda3/envs/nengo/lib/python3.6/multiprocessing/queues.py", line