In [1]:
# https://www.nengo.ai/nengo-examples/loihi/mnist-convnet.html

In [2]:
import nengo
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# Raw spike-train array stored as a NumPy .npy file next to the notebook.
SPIKE_TRAIN_FILE = "spike_trains_trainNEW.npy"
train = np.load(SPIKE_TRAIN_FILE)

In [4]:
# Inspect the array layout; the recorded output is (2464, 40, 1681).
# Presumably axis 0 indexes samples and axis 1 the 40 frequency bands -- TODO confirm.
# TODO: axis 2 has 1681 = 41 * 41 entries; confirm what it encodes.
# TODO: decide how multiple samples should be fed to the network; the model
#       cell passes the `input` subset to nengo.processes.PresentInput, which
#       presents one sample at a time rather than all of them at once.
np.shape(train)

(2464, 40, 1681)

In [5]:
# Take the first 50 entries along axis 0 (a view; all trailing axes kept).
# NOTE(review): `input` shadows the Python builtin of the same name; renaming
# it requires updating every cell that reads it.
input = train[:50]

In [6]:
# Hyperparameters taken from the reference paper.
# n_neurons: neurons per nengo.Ensemble; f_maps: number of convolution filters.
n_neurons, f_maps = 40, 50
# NOTE(review): not referenced by the model cell, whose kernel_size is [6*40].
window_size = (6, 40)

# Mean and standard deviation of the Gaussian that initialises the conv kernel.
mean, std = 0.8, 0.05

threshold = 23
# NOTE(review): this preset Config is only created here; no visible cell enters
# it as a context (`with model, thresh_config:`), so it configures nothing.
# The preset also expects `threshold` in represented-value units (ensemble
# radius), so the paper's value of 23 may not transfer directly -- confirm.
thresh_config = nengo.presets.ThresholdingEnsembles(threshold)

In [None]:
# Build the network: spike-train input -> convolution transform -> `pre`
# ensemble -> BCM-learning connection -> `post` ensemble, then simulate it.
PRESENTATION_TIME = 0.1  # seconds each sample is shown by PresentInput
SIM_TIME = 60.0  # simulated seconds; PresentInput loops over the samples
# Probing every input dimension at every timestep would store input[0].size
# floats per step (67240 for (40, 1681) samples -> tens of GB over 60 s).
N_PROBED_INPUT_DIMS = 10
SEED = 0  # fixed seed so ensemble parameters are reproducible across runs

model = nengo.Network(seed=SEED)

with model:
    # Each sample along axis 0 of `input` is flattened and presented in turn.
    input_layer = nengo.Node(
        nengo.processes.PresentInput(input, PRESENTATION_TIME),
        size_out=input[0].size,
    )

    # With channels_last=True the last axis of input_shape is the channel axis.
    # NOTE(review): kernel_size is 1-D ([6*40] == [240]) while `window_size`
    # (6, 40) is unused -- confirm which kernel was intended.
    transform = nengo.Convolution(
        n_filters=f_maps,
        input_shape=input.shape[1:],
        kernel_size=[6 * 40],
        strides=[1],
        padding="same",
        channels_last=True,
        init=nengo.dists.Gaussian(mean, std),
    )

    # Size the ensembles from the transform output rather than a hard-coded
    # 2000, so changing f_maps or the input shape cannot desync them.
    # NOTE(review): n_neurons is far below this dimensionality.
    pre = nengo.Ensemble(n_neurons, dimensions=transform.size_out)
    post = nengo.Ensemble(n_neurons, dimensions=transform.size_out)

    # Apply the convolution as the transform on the input connection.
    conv_conn = nengo.Connection(input_layer, pre, transform=transform)

    # Plastic pre -> post connection with a full weight matrix for BCM.
    learn_conn = nengo.Connection(
        pre,
        post,
        learning_rule_type=nengo.BCM(learning_rate=5e-10),  # TODO: tune
        solver=nengo.solvers.LstsqL2(weights=True),
    )

    # A slice keeps the input probe small while staying time-aligned with
    # sim.trange() (no sample_every), like the other probes.
    input_probe = nengo.Probe(input_layer[:N_PROBED_INPUT_DIMS])
    pre_probe = nengo.Probe(pre, synapse=0.01)
    post_probe = nengo.Probe(post, synapse=0.01)

# Build and run outside the network context; probe data stays readable via
# `sim.data` after the simulator is closed.
with nengo.Simulator(model) as sim:
    sim.run(time_in_seconds=SIM_TIME)

In [None]:
# Plotting
def _plot_panel(ax, t, data, name, n_dims):
    """Draw the first `n_dims` columns of a (timesteps, dims) probe array on `ax`."""
    shown = data[:, :n_dims]
    for dim in range(shown.shape[1]):
        ax.plot(t, shown[:, dim], label=f"{name}[{dim}]")
    ax.set_ylabel(name)
    ax.legend(loc="best")


def plots(start_ix=None, end_ix=None, n_dims=5):
    """Plot the probed input, `pre` and `post` signals against simulation time.

    Reads the module-level `sim`, `input_probe`, `pre_probe` and `post_probe`
    created by the model cell.

    Parameters
    ----------
    start_ix, end_ix : int or None
        Timestep indices delimiting the plotted window (Python slice semantics).
    n_dims : int
        Number of leading dimensions drawn per panel; the probes can be
        thousands of dimensions wide, which is unreadable and very slow.

    Returns
    -------
    matplotlib.figure.Figure
        The three-panel figure.
    """
    window = slice(start_ix, end_ix)
    t = sim.trange()[window]
    fig, axes = plt.subplots(3, 1, figsize=(12, 12), sharex=True)
    panels = (("Input", input_probe), ("Pre", pre_probe), ("Post", post_probe))
    for ax, (name, probe) in zip(axes, panels):
        # Probe data is (timesteps, dims): slice time on axis 0 so it lines up
        # with `t`; transposing first would slice the dimension axis instead.
        _plot_panel(ax, t, sim.data[probe][window], name, n_dims)
    axes[-1].set_xlabel("Time (s)")
    return fig


plots();