In [1]:
import numpy as np
import matplotlib.pyplot as plt

from ml_genn import Connection, Network, Population
from ml_genn.callbacks import Checkpoint, SpikeRecorder, VarRecorder
from ml_genn.compilers import EventPropCompiler, InferenceCompiler
from ml_genn.connectivity import Dense,FixedProbability
from ml_genn.initializers import Normal
from ml_genn.neurons import LeakyIntegrate, LeakyIntegrateFire, SpikeInput
from ml_genn.optimisers import Adam
from ml_genn.serialisers import Numpy
from ml_genn.synapses import Exponential
from tonic.datasets import SHD
from tonic import transforms

from time import perf_counter
from ml_genn.utils.data import (calc_latest_spike_time, calc_max_spikes,
                                preprocess_tonic_spikes)

from ml_genn.compilers.event_prop_compiler import default_params


In [2]:
sample_T = 256
shd_channels = 700
net_channels = 128
# note that mlGeNN works in units of ms
net_dt = 1000/sample_T


NUM_EPOCHS = 100
EXAMPLE_TIME = 20.0



transform = transforms.Downsample(
        time_factor=1,
        spatial_factor=net_channels / shd_channels
    )

# Get SHD dataset
dataset= {}
dataset["train"] = SHD(save_to='./data', train=True, transform=transform)
dataset["test"] = SHD(save_to='./data', train=False, transform=transform)

# Preprocess
spikes= {}
labels= {}
for which in ["train","test"]:
    spikes[which] = []
    labels[which] = []
    for i in range(len(dataset[which])):
        events, label = dataset[which][i]
        spikes[which].append(preprocess_tonic_spikes(events, dataset[which].ordering,
                                              dataset[which].sensor_size, dt=net_dt, histogram_thresh=1))
        labels[which].append(label)

# Determine max spikes and latest spike time
max_spikes = {}
latest_spike_time = {}
for which in ["train","test"]:
    max_spikes[which] = calc_max_spikes(spikes[which])
    latest_spike_time[which] = calc_latest_spike_time(spikes[which])

for which in ["train","test"]:
    print(f"Max spikes {which} {max_spikes[which]}, latest spike time {which} {latest_spike_time[which]}")

# Get number of input and output neurons from dataset 
# and round up outputs to power-of-two
# these are the same for train and test
num_input = int(np.prod(dataset["train"].sensor_size))
num_output = len(dataset["train"].classes)



Downloading https://zenkelab.org/datasets/shd_train.h5.zip to ./data/SHD/shd_train.h5.zip


  0%|          | 0/130863613 [00:00<?, ?it/s]

Extracting ./data/SHD/shd_train.h5.zip to ./data/SHD
Downloading https://zenkelab.org/datasets/shd_test.h5.zip to ./data/SHD/shd_test.h5.zip


  0%|          | 0/38141465 [00:00<?, ?it/s]

Extracting ./data/SHD/shd_test.h5.zip to ./data/SHD
Max spikes train 6335, latest spike time train 1367.1875
Max spikes test 6797, latest spike time test 1167.96875


In [4]:
serialiser = Numpy("shd_checkpoints")

def benchmark(num_trials, num_hidden, batch_size):

    NUM_HIDDEN = num_hidden
    BATCH_SIZE = batch_size
    
    times = []
    for i in range(num_trials):
        
    
        network = Network(default_params)
        with network:
            # Populations
            input = Population(SpikeInput(max_spikes=BATCH_SIZE * max_spikes["train"]),
                               num_input)
            hidden1 = Population(LeakyIntegrateFire(v_thresh=1.0, tau_mem=20.0,
                                                   tau_refrac=None),
                                NUM_HIDDEN)
            hidden2 = Population(LeakyIntegrateFire(v_thresh=1.0, tau_mem=20.0,
                                                   tau_refrac=None),
                                NUM_HIDDEN)
            output = Population(LeakyIntegrate(tau_mem=20.0, readout="avg_var_exp_weight"),
                                num_output)
        
            # Connections
            Connection(input, hidden1, Dense(Normal(mean=0.03, sd=0.01)),
                       Exponential(5.0))
            Connection(hidden1, hidden2, Dense(Normal(mean=0.0, sd=0.02)),
                       Exponential(5.0))
            Connection(hidden2, hidden1, Dense(Normal(mean=0.0, sd=0.02)),
                       Exponential(5.0))
            Connection(hidden2, output, Dense(Normal(mean=0.0, sd=0.03)),
                       Exponential(5.0))
        
        max_example_timesteps = int(np.ceil(latest_spike_time["train"] / net_dt))
        
        compiler = EventPropCompiler(example_timesteps=max_example_timesteps,
                                     losses="sparse_categorical_crossentropy",
                                     reg_lambda_upper=4e-09, reg_lambda_lower=4e-09, 
                                     reg_nu_upper=14, max_spikes=1500, 
                                     optimiser=Adam(0.001), batch_size=BATCH_SIZE, 
                                     kernel_profiling=KERNEL_PROFILING)
        
        compiled_net = compiler.compile(network)

        with compiled_net:
            # Evaluate model on numpy dataset
            start_time = perf_counter()
            callbacks = ["batch_progress_bar", Checkpoint(serialiser)]
            metrics, _  = compiled_net.train({input: spikes["train"]},
                                             {output: labels["train"]},
                                             num_epochs=NUM_EPOCHS, shuffle=True,
                                             callbacks=callbacks)
            compiled_net.save_connectivity((NUM_EPOCHS - 1,), serialiser)
        
            end_time = perf_counter()
            print(f"Accuracy = {100 * metrics[output].result}%")
            print(f"Total Time = {end_time - start_time}s")
        
            kernel_time = np.sum([compiled_net.genn_model.neuron_update_time,
                compiled_net.genn_model.presynaptic_update_time,
                compiled_net.genn_model.get_custom_update_time('GradientBatchReduce'),
                compiled_net.genn_model.get_custom_update_time('GradientLearn'),
                compiled_net.genn_model.get_custom_update_time('Reset'),
                compiled_net.genn_model.get_custom_update_time('BatchSoftmax1'),
                compiled_net.genn_model.get_custom_update_time('BatchSoftmax2'),
                compiled_net.genn_model.get_custom_update_time('BatchSoftmax3')])
            
            print(f"Kernel Time = {kernel_time}s")
            times.append(kernel_time)

            

    print(np.mean(times), np.std(times))

In [10]:
benchmark(5, 128, 64)







make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/128 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 215.60070894099772s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/128 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 218.50213174149394s






make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/128 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 214.57843165285885s






make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/128 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 213.85770628787577s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/128 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 216.04864346794784s
215.71752441823483 1.5895772665521737


In [11]:
benchmark(5, 128, 128)







make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/64 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 158.73156289756298s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/64 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 160.5546283069998s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/64 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 159.54116126336157s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/64 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 162.65104128792882s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/64 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 162.64246773160994s
160.82417229749262 1.5963354266028158


In [12]:
benchmark(5, 128, 256)







make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/32 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 123.82656937651336s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/32 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 124.93158670514822s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/32 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 123.78426889330149s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/32 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 122.86032863892615s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/32 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 121.87257836386561s
123.45506639555097 1.028054701231875


In [13]:
benchmark(5, 512, 64)







make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/128 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 732.3393701408058s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/128 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 733.7367905993015s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/128 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 734.5784377194941s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/128 [00:00<?, ?it/s]

Accuracy = 8.288376655223148%
Time = 521.0588604919612s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/128 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 732.967384962365s
690.9361687827856 84.94196719429978


In [14]:
benchmark(5, 512, 128)







make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/64 [00:00<?, ?it/s]

Accuracy = 8.374203040706227%
Time = 402.4214499145746s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/64 [00:00<?, ?it/s]

Accuracy = 8.128984796468858%
Time = 399.5823627784848s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/64 [00:00<?, ?it/s]

Accuracy = 7.736635605689063%
Time = 393.9364239759743s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/64 [00:00<?, ?it/s]

Accuracy = 8.055419323197647%
Time = 398.9263070188463s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/64 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 497.083568720147s
418.3900224816054 39.44163356765095


In [15]:
benchmark(5, 512, 256)







make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/32 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 368.60786990635097s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/32 [00:00<?, ?it/s]

Accuracy = 7.4668955370279555%
Time = 319.77090207859874s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/32 [00:00<?, ?it/s]

Accuracy = 4.977930358018637%
Time = 368.3075149767101s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/32 [00:00<?, ?it/s]

Accuracy = 7.233938205002452%
Time = 319.8355849850923s
make: Entering directory '/opt/genn/EventPropCompiler_CODE'






make: Leaving directory '/opt/genn/EventPropCompiler_CODE'


  0%|          | 0/32 [00:00<?, ?it/s]

Accuracy = 7.344286414909269%
Time = 320.85293639265s
339.4749616678804 23.66760392974065
