In [1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import keras_spiking

# Fix the global NumPy and TensorFlow random seeds so that weight
# initialisation and other random draws repeat from run to run.
np.random.seed(seed=0)
tf.random.set_seed(seed=0)

In [10]:
# Load the MNIST handwritten-digit dataset: 28x28 greyscale images, each
# labelled with the integer digit it shows.
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.mnist.load_data()
)

# Pixel intensities range over 0..255; rescale them to [0, 1].
# Cast to float32 (the dtype Keras computes in) first. Dividing the integer
# images by 255.0 directly would give float64, doubling the memory of every
# array derived from them, notably the time-tiled sequences built later.
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0

# One class per digit, named by the digit itself: "0", "1", ..., "9".
class_names = [str(digit) for digit in range(10)]

num_classes = len(class_names)


In [11]:
# Append a trailing channel axis, (N, 28, 28) -> (N, 28, 28, 1), so the arrays
# match the single-channel (28, 28, 1) input shape used by the Conv2D models.
train_images = train_images[..., np.newaxis]
test_images = test_images[..., np.newaxis]

In [12]:
# Baseline (non-spiking) CNN: three ReLU conv layers with max-pooling between
# them, a dense ReLU hidden layer, and a linear output with one unit per class.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Conv2D(32, (3, 3), activation="relu", input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation="relu"),
        # No softmax here: outputs are raw logits, and `train` compiles with
        # SparseCategoricalCrossentropy(from_logits=True).
        tf.keras.layers.Dense(num_classes),
    ]
)


def train(input_model, train_x, test_x, train_y=None, test_y=None, epochs=10):
    """Compile, fit and evaluate a classifier, printing its test accuracy.

    Parameters
    ----------
    input_model : tf.keras.Model
        Classifier whose outputs are unnormalised logits (the loss below is
        built with ``from_logits=True``).
    train_x, test_x : array-like
        Training and test inputs, in whatever layout ``input_model`` expects.
    train_y, test_y : array-like, optional
        Integer class labels for ``train_x`` / ``test_x``. When omitted they
        fall back to the notebook-level ``train_labels`` / ``test_labels``,
        looked up at call time (the original behaviour).
    epochs : int, optional
        Number of passes over the training data (default 10).

    Returns
    -------
    float
        Accuracy of ``input_model`` on ``(test_x, test_y)``.
    """
    # Make the label dependency explicit instead of silently reading globals.
    if train_y is None:
        train_y = train_labels
    if test_y is None:
        test_y = test_labels

    input_model.compile(
        optimizer="adam",
        # Integer labels + logit outputs -> sparse CE that applies softmax itself.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    input_model.fit(train_x, train_y, epochs=epochs)

    # evaluate() returns [loss, accuracy] for the single "accuracy" metric.
    _, test_acc = input_model.evaluate(test_x, test_y, verbose=2)

    print("\nTest accuracy:", test_acc)
    return test_acc


train(input_model=model, train_x=train_images, test_x=test_images)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
313/313 - 1s - loss: 0.0336 - accuracy: 0.9923

Test accuracy: 0.9922999739646912


In [13]:
# Spiking neurons integrate input over time, so each static image is shown for
# `n_steps` consecutive timesteps: (N, 28, 28, 1) -> (N, n_steps, 28, 28, 1).
n_steps = 10

# The tiled arrays are n_steps times larger than the images they come from.
# Make sure they are float32 (the dtype Keras computes in) rather than
# float64, which integer images divided by 255.0 would be. For the 60,000
# MNIST training images that is ~1.9 GB instead of ~3.8 GB.
# `copy=False` skips the conversion when the images are already float32.
train_sequences = np.tile(
    train_images[:, None].astype(np.float32, copy=False), (1, n_steps, 1, 1, 1)
)
test_sequences = np.tile(
    test_images[:, None].astype(np.float32, copy=False), (1, n_steps, 1, 1, 1)
)

In [None]:
# keras_spiking simulation timestep shared by every SpikingActivation below.
# NOTE(review): with dt=0.01 and only n_steps=10 timesteps, each image is
# simulated for 0.1 s. Per the keras_spiking docs, a neuron then needs a
# firing rate (activation value) of at least 10 to emit even one spike.
# If accuracy is poor, consider a larger dt or more timesteps.
spike_dt = 0.01

# Spiking counterpart of the baseline CNN. Conv/pool/dense layers are wrapped
# in TimeDistributed so they run independently at every timestep, and each
# ReLU is replaced by a keras_spiking SpikingActivation trained with
# spiking_aware_training=True.
spikeaware_model = tf.keras.Sequential(
    [
        # The input shape must be given to the outermost (TimeDistributed)
        # layer; an input_shape on the wrapped Conv2D is ignored by
        # Sequential. `None` leaves the number of timesteps free.
        tf.keras.layers.TimeDistributed(
            tf.keras.layers.Conv2D(32, (3, 3)), input_shape=(None, 28, 28, 1)
        ),
        keras_spiking.SpikingActivation("relu", dt=spike_dt, spiking_aware_training=True),
        tf.keras.layers.TimeDistributed(tf.keras.layers.MaxPooling2D((2, 2))),
        tf.keras.layers.TimeDistributed(tf.keras.layers.Conv2D(64, (3, 3))),
        keras_spiking.SpikingActivation("relu", dt=spike_dt, spiking_aware_training=True),
        tf.keras.layers.TimeDistributed(tf.keras.layers.MaxPooling2D((2, 2))),
        tf.keras.layers.TimeDistributed(tf.keras.layers.Conv2D(64, (3, 3))),
        # Flatten each timestep's feature map (28 -> 26 -> 13 -> 11 -> 5 -> 3):
        # (batch, time, 3, 3, 64) -> (batch, time, 576). The target shape
        # excludes the batch axis.
        tf.keras.layers.Reshape((-1, 3 * 3 * 64)),
        keras_spiking.SpikingActivation("relu", dt=spike_dt, spiking_aware_training=True),
        tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(64)),
        keras_spiking.SpikingActivation("relu", dt=spike_dt, spiking_aware_training=True),
        # Average the spiking activity over the time axis, then map it to one
        # logit per class (the loss in `train` uses from_logits=True).
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(num_classes),
    ]
)

# Fit and evaluate with the same `train` routine as the non-spiking baseline;
# the only difference is that the inputs are the time-tiled image sequences.
train(input_model=spikeaware_model, train_x=train_sequences, test_x=test_sequences)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10