In [1]:
import numpy as np
import matplotlib.pyplot as plt
import mnist

from ml_genn import Connection, Network, Population
from ml_genn.callbacks import Checkpoint, SpikeRecorder, VarRecorder
from ml_genn.compilers import EventPropCompiler, InferenceCompiler
from ml_genn.connectivity import Dense,FixedProbability
from ml_genn.initializers import Normal
from ml_genn.neurons import LeakyIntegrate, LeakyIntegrateFire, SpikeInput
from ml_genn.optimisers import Adam
from ml_genn.serialisers import Numpy
from ml_genn.synapses import Exponential
from tonic.datasets import SHD
from tonic import transforms

from time import perf_counter
from ml_genn.utils.data import (calc_latest_spike_time, calc_max_spikes,
                                preprocess_tonic_spikes)

from ml_genn.compilers.event_prop_compiler import default_params


In [2]:
# --- Timing / channel configuration ------------------------------------------
# Spikes are binned into ``sample_T`` bins per second of recording
# (alternative value previously in this cell: 64).
sample_T = 1000
shd_channels = 700
# Set below shd_channels (previous alternative: 128) to spatially downsample.
net_channels = 700
# mlGeNN works in units of ms, so one bin of 1/sample_T s lasts this many ms.
net_dt = 1000 / sample_T

# --- Model / training hyperparameters ----------------------------------------
NUM_HIDDEN = 256
BATCH_SIZE = 32
NUM_EPOCHS = 100
EXAMPLE_TIME = 20.0  # NOTE(review): not referenced by any cell in this notebook
KERNEL_PROFILING = True

# Dataset splits processed below, in this order.
SPLITS = ("train", "test")

# Spatially downsample shd_channels -> net_channels
# (a factor of 1.0, the current setting, leaves channels unchanged).
transform = transforms.Downsample(
    time_factor=1,
    spatial_factor=net_channels / shd_channels,
)

# Get SHD dataset (stored under ./data)
dataset = {
    which: SHD(save_to="./data", train=(which == "train"), transform=transform)
    for which in SPLITS
}

# Preprocess: convert every tonic event array into mlGeNN's spike format,
# binned at net_dt ms, and collect the matching labels.
spikes = {}
labels = {}
for which in SPLITS:
    split = dataset[which]
    spikes[which] = []
    labels[which] = []
    for i in range(len(split)):
        events, label = split[i]
        spikes[which].append(
            preprocess_tonic_spikes(events, split.ordering, split.sensor_size,
                                    dt=net_dt, histogram_thresh=1))
        labels[which].append(label)

# Largest per-example spike count and latest spike time (ms) for each split.
# Later cells use these to size the input spike buffer and the simulation length.
max_spikes = {}
latest_spike_time = {}
for which in SPLITS:
    max_spikes[which] = calc_max_spikes(spikes[which])
    latest_spike_time[which] = calc_latest_spike_time(spikes[which])
    print(f"Max spikes {which} {max_spikes[which]}, latest spike time {which} {latest_spike_time[which]}")

# Number of input and output neurons, taken from the dataset metadata
# (the same for the train and test splits). No rounding is applied.
num_input = int(np.prod(dataset["train"].sensor_size))
num_output = len(dataset["train"].classes)



Max spikes train 14917, latest spike time train 1369.0
Max spikes test 16257, latest spike time test 1169.0


In [3]:
# Number of preprocessed examples in the train and test splits, one per line.
print(len(spikes["train"]), len(spikes["test"]), sep="\n")

8156
2264


In [4]:
serialiser = Numpy("shd_checkpoints")
network = Network(default_params)

# This network object (and max_example_timesteps) is reused for test-set
# inference later, so size the input spike buffer and the simulation length
# for the larger of the two splits rather than for the training split alone.
# Otherwise a test batch could overflow the buffer, or a long test example
# could be truncated.
max_spikes_per_example = max(max_spikes[which] for which in ("train", "test"))

with network:
    # Populations
    # NOTE: ``input`` shadows the builtin; the name is kept because later cells
    # refer to this population by it.
    input = Population(SpikeInput(max_spikes=BATCH_SIZE * max_spikes_per_example),
                       num_input)
    hidden = Population(LeakyIntegrateFire(v_thresh=1.0, tau_mem=20.0,
                                           tau_refrac=None),
                        NUM_HIDDEN)
    output = Population(LeakyIntegrate(tau_mem=20.0, readout="avg_var_exp_weight"),
                        num_output)

    # Connections: input -> hidden, recurrent hidden -> hidden, hidden -> output
    Connection(input, hidden, Dense(Normal(mean=0.03, sd=0.01)),
               Exponential(5.0))
    Connection(hidden, hidden, Dense(Normal(mean=0.0, sd=0.02)),
               Exponential(5.0))
    Connection(hidden, output, Dense(Normal(mean=0.0, sd=0.03)),
               Exponential(5.0))

# Simulation length in timesteps: long enough for the latest spike of any split.
max_example_timesteps = int(np.ceil(
    max(latest_spike_time[which] for which in ("train", "test")) / net_dt))

compiler = EventPropCompiler(example_timesteps=max_example_timesteps,
                             losses="sparse_categorical_crossentropy",
                             reg_lambda_upper=4e-09, reg_lambda_lower=4e-09,
                             reg_nu_upper=14, max_spikes=1500,
                             optimiser=Adam(0.001), batch_size=BATCH_SIZE,
                             kernel_profiling=KERNEL_PROFILING)
compiled_net = compiler.compile(network)

with compiled_net:
    # Train on the preprocessed training spikes, with a Checkpoint callback
    start_time = perf_counter()
    callbacks = ["batch_progress_bar", Checkpoint(serialiser)]
    metrics, _ = compiled_net.train({input: spikes["train"]},
                                    {output: labels["train"]},
                                    num_epochs=NUM_EPOCHS, shuffle=True,
                                    callbacks=callbacks)
    # Saved under the same key that the inference cell passes to network.load
    compiled_net.save_connectivity((NUM_EPOCHS - 1,), serialiser)

    end_time = perf_counter()
    # This is the accuracy reported by train(), i.e. on the training split
    print(f"Accuracy = {100 * metrics[output].result}%")
    print(f"Time = {end_time - start_time}s")

    if KERNEL_PROFILING:
        genn_model = compiled_net.genn_model
        print(f"Neuron update time = {genn_model.neuron_update_time}")
        print(f"Presynaptic update time = {genn_model.presynaptic_update_time}")
        # (printed label, custom update name) pairs queried via get_custom_update_time
        for update_label, update_name in (("Gradient batch reduce", "GradientBatchReduce"),
                                          ("Gradient learn", "GradientLearn"),
                                          ("Reset", "Reset"),
                                          ("Softmax1", "BatchSoftmax1"),
                                          ("Softmax2", "BatchSoftmax2"),
                                          ("Softmax3", "BatchSoftmax3")):
            print(f"{update_label} time = {genn_model.get_custom_update_time(update_name)}")







make: Entering directory '/its/home/tn41/spyx/research/paper/EventPropCompiler_CODE'






make: Leaving directory '/its/home/tn41/spyx/research/paper/EventPropCompiler_CODE'


  0%|          | 0/255 [00:00<?, ?it/s]

Accuracy = 85.42177538008828%
Time = 1306.8711126269773s
Neuron update time = 281.534813520676
Presynaptic update time = 446.8579314324223
Gradient batch reduce time = 7.168805803238494
Gradient learn time = 0.47280346669059126
Reset time = 0.24571251416058992
Softmax1 time = 0.15397123452222825
Softmax2 time = 0.14953955355187823
Softmax3 time = 0.14467980824406626


In [5]:
# Restore the network parameters stored under the final-epoch key during training.
network.load((NUM_EPOCHS - 1,), serialiser)

# Build an inference-only model that uses the same batch size and simulation length.
compiler = InferenceCompiler(batch_size=BATCH_SIZE,
                             evaluate_timesteps=max_example_timesteps,
                             reset_in_syn_between_batches=True)
compiled_net = compiler.compile(network)

with compiled_net:
    # Score the trained network on the held-out test split.
    test_inputs = {input: spikes["test"]}
    test_targets = {output: labels["test"]}
    start_time = perf_counter()
    metrics, _ = compiled_net.evaluate(test_inputs, test_targets)
    end_time = perf_counter()
    print(f"Accuracy = {100 * metrics[output].result}%")
    print(f"Time = {end_time - start_time}s")





make: Entering directory '/its/home/tn41/spyx/research/paper/InferenceCompiler_CODE'





make: Leaving directory '/its/home/tn41/spyx/research/paper/InferenceCompiler_CODE'


  0%|          | 0/71 [00:00<?, ?it/s]

Accuracy = 68.28621908127208%
Time = 1.2334860870032571s
