In [1]:
import numpy as np
import matplotlib.pyplot as plt
import keras_spiking
import tensorflow as tf

In [2]:
# Load MNIST (not CIFAR 10): 28x28 images with 10 classes, matching the
# input_shape=(28, 28) and Dense(10) output used by the models below.
(
    (train_images, train_labels),
    (test_images, test_labels),
) = tf.keras.datasets.mnist.load_data()

In [3]:
# Normalize pixel values so they lie between 0 and 1.
# Cast to float32 first: dividing the raw integer images by 255.0 would yield
# float64, doubling the memory of the n_steps-times tiled sequences built later.
# Keras layers compute in float32 by default, so no precision is actually used.
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0

In [6]:
from tensorflow.keras.callbacks import EarlyStopping

# NOTE(review): `es` is not passed as a callback to any `fit` call in the visible
# code, so it currently has no effect on training.
es = EarlyStopping(mode='min', monitor='val_loss', verbose=1)

# Hidden-layer widths to try: each inner list describes one model, with one
# hidden Dense layer per entry.
topologies = [[64, 64], [32], [32, 32], [64], [128]]

def generate_non_spiking_model(topology):
    """Build a plain (non-spiking) MLP classifier for 28x28 inputs.

    Parameters
    ----------
    topology : iterable of int
        Width of each hidden ReLU ``Dense`` layer, in order.

    Returns
    -------
    tf.keras.Sequential
        Uncompiled model: Flatten -> hidden Dense(ReLU) layers -> Dense(10).
        The last layer has no activation, i.e. it emits logits.
    """
    flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
    hidden = [tf.keras.layers.Dense(width, activation="relu") for width in topology]
    logits = tf.keras.layers.Dense(10)
    return tf.keras.Sequential([flatten, *hidden, logits])

def generate_spiking_model(topology, dt=0.01):
    """Build a spiking MLP classifier for sequences of 28x28 frames.

    Each time step's frame is flattened and passed through the hidden ``Dense``
    layers, which are applied per time step via ``TimeDistributed``. Each hidden
    layer is followed by a ``keras_spiking.SpikingActivation("relu")`` with
    ``spiking_aware_training=True``. The activations are averaged over the time
    axis, and a final non-spiking ``Dense(10)`` produces logits.

    Parameters
    ----------
    topology : iterable of int
        Width of each hidden layer, in order.
    dt : float, default 0.01
        Time step passed to every ``SpikingActivation`` (previously hard-coded
        to 0.01; the default keeps the old behaviour).

    Returns
    -------
    tf.keras.Sequential
        Uncompiled model whose input shape is ``(batch, timesteps, 28, 28)``.
    """
    layers = [
        # (batch, timesteps, 28, 28) -> (batch, timesteps, 784)
        tf.keras.layers.Reshape((-1, 28 * 28), input_shape=(None, 28, 28)),
    ]
    for hidden_layer_size in topology:
        layers.append(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_layer_size)))
        layers.append(keras_spiking.SpikingActivation("relu", dt=dt, spiking_aware_training=True))
    # Average over the time axis: (batch, timesteps, units) -> (batch, units).
    layers.append(tf.keras.layers.GlobalAveragePooling1D())
    layers.append(tf.keras.layers.Dense(10))
    return tf.keras.Sequential(layers)

def train(input_model, train_x, test_x, train_y=None, test_y=None, num_epochs=20):
    """Compile and train ``input_model`` one epoch at a time, keeping the best snapshot.

    After every epoch the model is evaluated on ``(test_x, test_y)``. Whenever the
    test accuracy improves, a copy of the model *including its current weights*
    is stored.

    NOTE(review): choosing the "best" epoch by test accuracy leaks test-set
    information into model selection; a separate validation split would give an
    unbiased final estimate.

    Parameters
    ----------
    input_model : tf.keras.Model
        Model to compile and train; it is modified in place.
    train_x, test_x : array-like
        Training and test inputs.
    train_y, test_y : array-like, optional
        Integer class labels (as required by ``SparseCategoricalCrossentropy``).
        They default to the notebook-level ``train_labels`` / ``test_labels`` for
        backward compatibility.
    num_epochs : int, default 20
        Number of single-epoch ``fit`` calls.

    Returns
    -------
    tuple
        ``(best_test_acc, best_epoch_model)``: the highest test accuracy seen
        and a copy of the model at that epoch. Both are ``None`` when
        ``num_epochs`` is 0.
    """
    if train_y is None:
        train_y = train_labels
    if test_y is None:
        test_y = test_labels

    input_model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    best_test_acc = None
    best_epoch_model = None
    for _ in range(num_epochs):
        input_model.fit(train_x, train_y, epochs=1)
        _, test_acc = input_model.evaluate(test_x, test_y, verbose=2)
        if best_test_acc is None or test_acc > best_test_acc:
            # clone_model copies only the architecture (with freshly initialised
            # weights), so the trained weights must be copied over explicitly.
            best_epoch_model = tf.keras.models.clone_model(input_model)
            best_epoch_model.set_weights(input_model.get_weights())
            best_test_acc = test_acc
    return best_test_acc, best_epoch_model

In [7]:
# Present every static image for `n_steps` timesteps by repeating it along a new
# time axis (axis 1): (N, H, W) -> (N, n_steps, H, W).
n_steps = 10
train_sequences, test_sequences = (
    np.tile(images[:, np.newaxis], (1, n_steps, 1, 1))
    for images in (train_images, test_images)
)

In [None]:
# Train one spiking model per topology. Both dicts are keyed by str(topology),
# e.g. "[64, 64]": `models` keeps the model returned by train(), and
# `test_accuracies` keeps the accuracy train() returns alongside it, instead of
# discarding it and transcribing results by hand.
models = dict()
test_accuracies = dict()
for topology in topologies:
    key = str(topology)
    print(key)
    spiking_model = generate_spiking_model(topology)
    test_acc, best_model = train(spiking_model, train_sequences, test_sequences)
    models[key] = best_model
    test_accuracies[key] = test_acc

[64, 64]
313/313 - 3s - loss: 0.6753 - accuracy: 0.8029
313/313 - 3s - loss: 0.4664 - accuracy: 0.8578
313/313 - 3s - loss: 0.3843 - accuracy: 0.8802
313/313 - 3s - loss: 0.3235 - accuracy: 0.9025
313/313 - 3s - loss: 0.3297 - accuracy: 0.8983
313/313 - 3s - loss: 0.2839 - accuracy: 0.9112
313/313 - 3s - loss: 0.2650 - accuracy: 0.9190
313/313 - 3s - loss: 0.2484 - accuracy: 0.9253
313/313 - 3s - loss: 0.2738 - accuracy: 0.9171
313/313 - 3s - loss: 0.2284 - accuracy: 0.9309
313/313 - 3s - loss: 0.2372 - accuracy: 0.9260

In [None]:
# Test accuracies per hidden-layer topology, written as literals (presumably
# percentages transcribed from earlier runs — confirm).
# NOTE(review): these keys ("64,64") do not match the str(topology) keys used for
# `models` ("[64, 64]"), and "128,128" is not among `topologies`.
reported_test_accuracy_pct = {
    "128,128": 94.20,
    "64,64": 93.47,
    "32": 92.51,
    "32,32": 91.08,
    "64": 94.53,
    "128": 94.63,
}
# Alias for the original, non-descriptive name so existing references keep working.
asdf = reported_test_accuracy_pct