In [1]:
import numpy as np
import matplotlib.pyplot as plt
import keras_spiking
import tensorflow as tf

In [2]:
# Load CIFAR 10
(
    (train_images, train_labels),
    (test_images, test_labels),
) = tf.keras.datasets.mnist.load_data()

In [3]:
# normalize images so values are between 0 and 1
train_images = train_images / 255.0
test_images = test_images / 255.0

In [4]:
from tensorflow.keras.callbacks import EarlyStopping

es = EarlyStopping(monitor='val_loss', mode='min', verbose=1)

topologies = [
    [64, 64],
    [32],
    [32, 32],
    [64],
    [128],
]

def generate_non_spiking_model(topology):
    layers = [tf.keras.layers.Flatten(input_shape=(28, 28))]
    for hidden_layer_size in topology:
        layers.append(tf.keras.layers.Dense(hidden_layer_size, activation="relu"))
    layers.append(tf.keras.layers.Dense(10))
    model = tf.keras.Sequential(layers)
    return model

def generate_spiking_model(topology):
    layers = [tf.keras.layers.Reshape((-1, 28 * 28), input_shape=(None, 28, 28))]
    for hidden_layer_size in topology:
        layers.append(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_layer_size)))
        layers.append(keras_spiking.SpikingActivation("relu", dt=0.01, spiking_aware_training=True))
    layers.append(tf.keras.layers.GlobalAveragePooling1D(),)
    layers.append(tf.keras.layers.Dense(10))
    model = tf.keras.Sequential(layers)
    return model

def train(input_model, train_x, test_x):
    input_model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    
    num_epochs = 40
    best_test_acc = 0
    for num_epoch in range(num_epochs):
        input_model.fit(train_x, train_labels, epochs=1)
        _, test_acc = input_model.evaluate(test_x, test_labels, verbose=2)
        if test_acc > best_test_acc:
            best_epoch_model = tf.keras.models.clone_model(input_model)
            best_test_acc = test_acc
    return test_acc, best_epoch_model

In [5]:
n_steps = 10
train_sequences = np.tile(train_images[:, None], (1, n_steps, 1, 1))
test_sequences = np.tile(test_images[:, None], (1, n_steps, 1, 1))
models = dict()

In [6]:
for topology in topologies:
    print(str(topology))
    spiking_model = generate_spiking_model(topology)
    test_acc, best_model = train(spiking_model, train_sequences, test_sequences)
    models[str(topology)] = best_model
    print(test_acc)
    energy = keras_spiking.ModelEnergy(best_model, example_data=np.ones((1, 28, 28)))
    summ = energy.summary_string(
        columns=("name", "energy cpu", "energy loihi"),
        timesteps_per_inference=10,
        print_warnings=False,
        dt=0.01
    )
    loihi_energy = float(summ.split(':')[-1].strip())
    print(loihi_energy)

[64, 64]
313/313 - 3s - loss: 0.6894 - accuracy: 0.7988
313/313 - 3s - loss: 0.4736 - accuracy: 0.8491
313/313 - 3s - loss: 0.3954 - accuracy: 0.8813
313/313 - 3s - loss: 0.3398 - accuracy: 0.8940
313/313 - 3s - loss: 0.2974 - accuracy: 0.9099
313/313 - 3s - loss: 0.2765 - accuracy: 0.9154
313/313 - 3s - loss: 0.2566 - accuracy: 0.9224
313/313 - 3s - loss: 0.2410 - accuracy: 0.9258
313/313 - 3s - loss: 0.2290 - accuracy: 0.9290
313/313 - 3s - loss: 0.2174 - accuracy: 0.9342
313/313 - 3s - loss: 0.2314 - accuracy: 0.9256
313/313 - 3s - loss: 0.2299 - accuracy: 0.9316
313/313 - 3s - loss: 0.2065 - accuracy: 0.9408
313/313 - 3s - loss: 0.2027 - accuracy: 0.9403
313/313 - 3s - loss: 0.2165 - accuracy: 0.9329
313/313 - 3s - loss: 0.1974 - accuracy: 0.9429
313/313 - 3s - loss: 0.1787 - accuracy: 0.9470
313/313 - 3s - loss: 0.1755 - accuracy: 0.9466
313/313 - 3s - loss: 0.1754 - accuracy: 0.9492
313/313 - 3s - loss: 0.1618 - accuracy: 0.9521
313/313 - 3s - loss: 0.1917 - accuracy: 0.9413
313/

313/313 - 2s - loss: 0.1870 - accuracy: 0.9430
313/313 - 2s - loss: 0.1834 - accuracy: 0.9451
313/313 - 2s - loss: 0.1754 - accuracy: 0.9464
313/313 - 2s - loss: 0.1684 - accuracy: 0.9489
313/313 - 2s - loss: 0.1857 - accuracy: 0.9455
313/313 - 2s - loss: 0.1713 - accuracy: 0.9493
313/313 - 2s - loss: 0.1737 - accuracy: 0.9498
313/313 - 2s - loss: 0.1526 - accuracy: 0.9546
313/313 - 2s - loss: 0.1604 - accuracy: 0.9512
313/313 - 2s - loss: 0.1731 - accuracy: 0.9468
313/313 - 2s - loss: 0.1637 - accuracy: 0.9504
313/313 - 2s - loss: 0.1583 - accuracy: 0.9540
313/313 - 2s - loss: 0.1575 - accuracy: 0.9518
313/313 - 2s - loss: 0.1584 - accuracy: 0.9536
313/313 - 2s - loss: 0.1761 - accuracy: 0.9510
313/313 - 2s - loss: 0.1589 - accuracy: 0.9536
313/313 - 2s - loss: 0.1501 - accuracy: 0.9550
313/313 - 2s - loss: 0.1810 - accuracy: 0.9484
313/313 - 2s - loss: 0.1607 - accuracy: 0.9536
313/313 - 2s - loss: 0.1527 - accuracy: 0.9561
313/313 - 2s - loss: 0.1660 - accuracy: 0.9527
313/313 - 2s 

313/313 - 3s - loss: 0.1978 - accuracy: 0.9387
313/313 - 3s - loss: 0.1929 - accuracy: 0.9399
313/313 - 3s - loss: 0.1740 - accuracy: 0.9512
313/313 - 3s - loss: 0.1831 - accuracy: 0.9457
313/313 - 3s - loss: 0.1726 - accuracy: 0.9482
313/313 - 3s - loss: 0.1861 - accuracy: 0.9438
313/313 - 3s - loss: 0.1684 - accuracy: 0.9490
313/313 - 3s - loss: 0.1861 - accuracy: 0.9479
313/313 - 3s - loss: 0.1722 - accuracy: 0.9497
0.9496999979019165
1.2e-07
[64]
313/313 - 2s - loss: 0.5741 - accuracy: 0.8240
313/313 - 2s - loss: 0.4254 - accuracy: 0.8681
313/313 - 2s - loss: 0.3386 - accuracy: 0.8958
313/313 - 2s - loss: 0.2893 - accuracy: 0.9140
313/313 - 2s - loss: 0.2566 - accuracy: 0.9226
313/313 - 2s - loss: 0.2384 - accuracy: 0.9285
313/313 - 2s - loss: 0.2279 - accuracy: 0.9318
313/313 - 2s - loss: 0.2080 - accuracy: 0.9350
313/313 - 2s - loss: 0.2259 - accuracy: 0.9331
313/313 - 2s - loss: 0.2040 - accuracy: 0.9389
313/313 - 2s - loss: 0.1840 - accuracy: 0.9422
313/313 - 2s - loss: 0.1761 



KeyboardInterrupt: 

In [None]:
asdf = {
    "64,64": 94.83,
    "32": 94.65,
    "32, 32": 93.60,
    "64": 95.97,
    "128": 96.12
}

In [None]:
topologies_2 = 