In [1]:
import numpy as np
import matplotlib.pyplot as plt
import keras_spiking
import tensorflow as tf

In [2]:
# Load MNIST (the loader called below is tf.keras.datasets.mnist, not CIFAR-10)
(
    (train_images, train_labels),
    (test_images, test_labels),
) = tf.keras.datasets.mnist.load_data()

In [3]:
# Rescale pixel intensities by 1/255 so values lie between 0 and 1
# (assumes the raw pixels are in the 0-255 range).
# NOTE: this overwrites the raw arrays, so re-running the cell divides again.
train_images, test_images = train_images / 255.0, test_images / 255.0

In [4]:
from tensorflow.keras.callbacks import EarlyStopping

# Early-stopping callback watching the validation loss (lower is better).
# It only has an effect when handed to Model.fit together with validation data.
es = EarlyStopping(
    monitor='val_loss',
    mode='min',
    verbose=1,
)

# Hidden-layer layouts to compare. For each base width w the sweep contains
# a single layer [w], two equal layers [w, w], and a tapered pair [2w, w].
topologies = [
    hidden_sizes
    for width in (32, 64, 128)
    for hidden_sizes in ([width], [width, width], [2 * width, width])
]

def generate_non_spiking_model(topology):
    """Build a conventional (non-spiking) dense classifier.

    Args:
        topology: iterable of hidden-layer widths, ordered from input to output.

    Returns:
        An uncompiled ``tf.keras.Sequential`` made of a ``Flatten`` layer
        declared with ``input_shape=(28, 28)``, one ReLU ``Dense`` layer per
        entry of ``topology``, and a final 10-unit ``Dense`` layer with no
        activation.
    """
    input_layer = tf.keras.layers.Flatten(input_shape=(28, 28))
    hidden_stack = [
        tf.keras.layers.Dense(width, activation="relu") for width in topology
    ]
    output_layer = tf.keras.layers.Dense(10)
    return tf.keras.Sequential([input_layer, *hidden_stack, output_layer])

def generate_spiking_model(topology):
    """Build a spiking dense classifier that consumes image sequences.

    The model reshapes inputs declared as ``(None, 28, 28)`` per sample (a
    variable-length time axis followed by a 28x28 image) into
    ``(-1, 28 * 28)``, applies one time-distributed ``Dense`` layer followed
    by a ``keras_spiking.SpikingActivation("relu")`` per entry of
    ``topology``, averages over the time axis, and ends with a 10-unit
    ``Dense`` layer with no activation.

    Args:
        topology: iterable of hidden-layer widths, ordered from input to output.

    Returns:
        An uncompiled ``tf.keras.Sequential``.
    """
    layers = [tf.keras.layers.Reshape((-1, 28 * 28), input_shape=(None, 28, 28))]
    for hidden_layer_size in topology:
        layers.append(
            tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_layer_size))
        )
        # spiking_aware_training=False is kept from the original configuration.
        layers.append(
            keras_spiking.SpikingActivation("relu", spiking_aware_training=False)
        )
    # Pool over time exactly once, after every spiking layer. Pooling inside
    # the loop removed the time axis after the first hidden layer, so the next
    # TimeDistributed/SpikingActivation pair received a 2-D tensor and
    # multi-layer topologies could not be built.
    layers.append(tf.keras.layers.GlobalAveragePooling1D())
    layers.append(tf.keras.layers.Dense(10))
    return tf.keras.Sequential(layers)

def train(input_model, train_x, test_x, epochs=10, validation_split=0.1):
    """Compile, fit and evaluate ``input_model``; return its test accuracy.

    Labels come from the notebook-level ``train_labels`` / ``test_labels``
    arrays and early stopping from the notebook-level ``es`` callback.

    Args:
        input_model: Keras model to train (compiled here with Adam and
            sparse categorical cross-entropy on logits).
        train_x: training inputs aligned with ``train_labels``.
        test_x: test inputs aligned with ``test_labels``.
        epochs: maximum number of training epochs.
        validation_split: fraction of the training data held out by
            ``fit`` for validation. Pass ``0`` to train on everything
            without early stopping.

    Returns:
        The accuracy reported by ``input_model.evaluate`` on the test data.
    """
    input_model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    # The early-stopping callback belongs to fit, not evaluate: it monitors
    # 'val_loss', which only exists when fit is given validation data.
    # Without a validation split there is nothing to monitor, so skip it.
    callbacks = [es] if validation_split else []
    input_model.fit(
        train_x,
        train_labels,
        epochs=epochs,
        validation_split=validation_split,
        callbacks=callbacks,
    )

    _, test_acc = input_model.evaluate(test_x, test_labels, verbose=2)

    print("\nTest accuracy:", test_acc)
    return test_acc

In [5]:
# Number of timesteps each static image is presented for.
n_steps = 10

# Insert a time axis right after the batch axis and repeat every image
# n_steps times along it.
train_sequences = np.repeat(train_images[:, np.newaxis], n_steps, axis=1)
test_sequences = np.repeat(test_images[:, np.newaxis], n_steps, axis=1)

# Spiking network with a single hidden layer of 128 units.
spiking_mod_1 = generate_spiking_model([128])
train(spiking_mod_1, train_sequences, test_sequences)
