In [1]:
import numpy as np
import matplotlib.pyplot as plt
import keras_spiking
import tensorflow as tf

In [2]:
# Load CIFAR 10
(
    (train_images, train_labels),
    (test_images, test_labels),
) = tf.keras.datasets.mnist.load_data()

In [3]:
# normalize images so values are between 0 and 1
train_images = train_images / 255.0
test_images = test_images / 255.0

In [4]:
from tensorflow.keras.callbacks import EarlyStopping

es = EarlyStopping(monitor='val_loss', mode='min', verbose=1)

topologies = [
    [32],
    [32, 32],
    [64],
    [64, 64],
    [128],
    [128, 128]
]

def generate_non_spiking_model(topology):
    layers = [tf.keras.layers.Flatten(input_shape=(28, 28))]
    for hidden_layer_size in topology:
        layers.append(tf.keras.layers.Dense(hidden_layer_size, activation="relu"))
    layers.append(tf.keras.layers.Dense(10))
    model = tf.keras.Sequential(layers)
    return model

def generate_spiking_model(topology):
    layers = [tf.keras.layers.Reshape((-1, 28 * 28), input_shape=(None, 28, 28))]
    for hidden_layer_size in topology:
        layers.append(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_layer_size)))
        layers.append(keras_spiking.SpikingActivation("relu", spiking_aware_training=False))
        layers.append(tf.keras.layers.GlobalAveragePooling1D(),)
    layers.append(tf.keras.layers.Dense(10))
    model = tf.keras.Sequential(layers)
    return model

def train(input_model, train_x, test_x):
    input_model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    num_epochs = 10
    best_test_acc = 0
    for epoch in range(num_epochs):
        input_model.fit(train_x, train_labels, epochs=1)
        _, test_acc = input_model.evaluate(test_x, test_labels, verbose=2)
        if test_acc > best_test_acc:
            best_test_acc = test_acc

In [5]:
for topology in topologies:
    spiking_model = generate_spiking_model([128])
n_steps = 10
train_sequences = np.tile(train_images[:, None], (1, n_steps, 1, 1))
test_sequences = np.tile(test_images[:, None], (1, n_steps, 1, 1))
hist = train(spiking_mod_1, train_sequences, test_sequences)
print(hist)


313/313 - 2s - loss: 10.7269 - accuracy: 0.1810

Test accuracy: 0.1809999942779541 , Epoch: 1
313/313 - 2s - loss: 11.5206 - accuracy: 0.1757

Test accuracy: 0.17569999396800995 , Epoch: 2
313/313 - 2s - loss: 12.1336 - accuracy: 0.1729

Test accuracy: 0.1729000061750412 , Epoch: 3
313/313 - 2s - loss: 13.2952 - accuracy: 0.1695

Test accuracy: 0.16949999332427979 , Epoch: 4
313/313 - 2s - loss: 14.2442 - accuracy: 0.1725

Test accuracy: 0.17249999940395355 , Epoch: 5
313/313 - 2s - loss: 15.8159 - accuracy: 0.1787

Test accuracy: 0.17870000004768372 , Epoch: 6
313/313 - 2s - loss: 16.2201 - accuracy: 0.1803

Test accuracy: 0.18029999732971191 , Epoch: 7
313/313 - 2s - loss: 17.5342 - accuracy: 0.1852

Test accuracy: 0.18520000576972961 , Epoch: 8
313/313 - 2s - loss: 17.6922 - accuracy: 0.1773

Test accuracy: 0.17730000615119934 , Epoch: 9
313/313 - 2s - loss: 18.8027 - accuracy: 0.1874

Test accuracy: 0.1873999983072281 , Epoch: 10


NameError: name 'hist' is not defined