In [22]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import nengo
import nengo_dl
from tensorflow.keras.utils import to_categorical

#sources:
#https://www.kaggle.com/code/vtu5118/cifar-10-using-vgg16
#https://towardsdatascience.com/creating-vgg-from-scratch-using-tensorflow-a998a5640155

In [23]:
vgg16_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',
                    include_top=False,
                    classes=10,
                    input_shape=(32,32,3)# input: 32x32 images with 3 channels -> (32, 32, 3) tensors.
                   )

model = tf.keras.models.Sequential()

# Add vgg16 layers
for layer in vgg16_model.layers:
    # Replace max pool with avg pool (nengo doesnt support max pooling)
    # if isinstance(layer, tf.keras.layers.MaxPooling2D):
    #     model.add(tf.keras.layers.AveragePooling2D(layer.pool_size, layer.strides, layer.padding, layer.data_format))
    #     continue
    model.add(layer)


# Fully connected layers
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation='relu', name='hidden1'))
model.add(tf.keras.layers.Dropout(0.4))
model.add(tf.keras.layers.Dense(256, activation='relu', name='hidden2'))
model.add(tf.keras.layers.Dropout(0.4))
model.add(tf.keras.layers.Dense(10, activation='softmax', name='predictions'))

model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 block1_conv1 (Conv2D)       (None, 32, 32, 64)        1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 32, 32, 64)        36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 16, 16, 64)        0         
                                                                 
 block2_conv1 (Conv2D)       (None, 16, 16, 128)       73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 16, 16, 128)       147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 8, 8, 128)         0         
                                                                 
 block3_conv1 (Conv2D)       (None, 8, 8, 256)        

In [24]:
# Load CIFAR10 data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# One hot encode labels
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

# Data normalization
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train  /= 255
x_test /= 255

In [25]:
# Shuffle data before splitting into validation set

p = np.random.permutation(len(x_train))
print(x_train.shape)
x_train, y_train = x_train[p], y_train[p]
print(x_train.shape)

# 20% validation, 80% training
val_split = 0.2
num_val = int(val_split * len(x_train))

# Split into train, validation, and test sets
x_val = x_train[:num_val]
y_val = y_train[:num_val]
print(x_val.shape)
print(y_val.shape)

x_train = x_train[num_val:]
y_train = y_train[num_val:]
print(x_train.shape)
print(y_train.shape)

(50000, 32, 32, 3)
(50000, 32, 32, 3)
(10000, 32, 32, 3)
(10000, 10)
(40000, 32, 32, 3)
(40000, 10)


In [26]:
# For saving model weights
checkpoint_path = "training/vgg16.weights.h5"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

In [27]:
# Train model with tf
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)
batch_size_tf = 128
epochs = 100

model.compile(
    optimizer=optimizer,
    # optimizer='adam',
    loss='categorical_crossentropy',
    metrics=[
        'accuracy'
    ]
)

# def lr_scheduler(epoch):
#     return 0.001 * (0.5 ** (epoch // 20))

# reduce_lr = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)

# aug = tf.keras.preprocessing.image.ImageDataGenerator(
#     rotation_range=20,
#     zoom_range=0.15,
#     width_shift_range=0.2,
#     height_shift_range=0.2,
#     shear_range=0.15,
#     horizontal_flip=True,
#     fill_mode="nearest")

# model.fit(aug.flow(x_train,y_train, batch_size=batch_size_tf),
#           batch_size=batch_size_tf,
#           epochs=epochs,
#           callbacks=[reduce_lr, cp_callback],
#           validation_data=(x_val, y_val))

In [28]:
model.load_weights(checkpoint_path)

loss, acc = model.evaluate(x_test, y_test)
# 88% acc with tf
print("Test Acc: {:5.2f}%, Test Loss: {:5.2f}".format(100 * acc, loss))

Test Acc: 88.21%, Test Loss:  0.40


In [29]:
# Add time dimension to data for SNN
def add_time_dimension(arr):
    return np.reshape(arr, (arr.shape[0], 1, -1))

X_train_t = add_time_dimension(x_train)
Y_train_t = add_time_dimension(y_train)
X_val_t = add_time_dimension(x_val)
Y_val_t = add_time_dimension(y_val)
X_test_t = add_time_dimension(x_test)
Y_test_t = add_time_dimension(y_test)

In [126]:
converter = nengo_dl.Converter(
    model,
    swap_activations={
        tf.keras.activations.relu: nengo.SpikingRectifiedLinear()
    },
    scale_firing_rates=20,
    synapse=0.1,
    inference_only=True,
    # max_to_avg_pool=True
)



In [127]:
# Source: https://r-gaurav.github.io/2021/03/07/Spiking-Neural-Nets-for-Image-Classification-in-Nengo-DL.html
# Tile the test images n_steps times.
def get_nengo_compatible_test_data_generator(x, y, batch_size=100, n_steps=30):
  """
  Returns a test data generator of tiled (i.e. repeated) images.

  Args:
    batch_size <int>: Number of data elements in each batch.
    n_steps <int>: Number of timesteps for which the test data has to
                   be repeated.
  """
  num_images = x.shape[0]
  # Flatten the images
  reshaped_x = x.reshape((num_images, 1, -1))
  # Tile/Repeat them for `n_steps` times.
  tiled_x = np.tile(reshaped_x, (1, n_steps, 1))

  for i in range(0, num_images, batch_size):
    yield (tiled_x[i:i+batch_size], y[i:i+batch_size])

In [128]:
model_input = list(converter.inputs.keys())[0]
model_output = list(converter.outputs.keys())[0]

model_layers = list(converter.layers.keys())
conv1 = model_layers[1]
penltmt_layer = model_layers[-3]

# Get the probes for Input, first Conv, and the Output layers.
ndl_mdl_inpt = converter.inputs[model_input] # Input layer is Layer 0.
ndl_mdl_otpt = converter.outputs[model_output] # Output layer is last.

with converter.net:
  nengo_dl.configure_settings(stateful=False, keep_history=False) # Optimize simulation speed.
  # Probe for the first Conv layer.
  first_conv_probe = nengo.Probe(converter.layers[conv1])
  # Probe for penultimate dense layer.
  penltmt_dense_probe = nengo.Probe(converter.layers[[penltmt_layer]])

In [143]:
n_steps = 20 # Number of timesteps
batch_size = 4
collect_spikes_output = True

ndl_mdl_spikes = [] # To store the spike outputs of the first Conv layer and the
                # penultimate dense layer whose probes we defined earlier.
ndl_mdl_otpt_cls_probs = [] # To store the true class labels and the temporal
                            # class-probabilities output of the model.

num_samples = 16
x_test_slice = X_test_t[:num_samples]
y_test_slice = Y_test_t[:num_samples]

test_batches = get_nengo_compatible_test_data_generator(x_test_slice, y_test_slice,
    batch_size=batch_size, n_steps = n_steps)

ndl_mdl_inpt = converter.inputs[model_input] # Input layer is Layer 0.
ndl_mdl_otpt = converter.outputs[model_output] # Output layer is last.

num_batches = int(x_test_slice.shape[0] / batch_size)

# tiled_test_images = np.tile(x_test_slice, (1, n_steps, 1))

# Run the simulation.
# with nengo_dl.Simulator(converter.net) as sim:
#     data = sim.predict({ndl_mdl_inpt:tiled_test_images})

# predictions = np.argmax(data[ndl_mdl_otpt][:, -1], axis=-1)
# accuracy = (predictions ==y_test_slice[:num_samples, 0, 0]).mean()
# print(f"Test accuracy: {100 * accuracy:.2f}%")
with nengo_dl.Simulator(converter.net, minibatch_size=batch_size, progress_bar=False) as sim:
    # Predict on each batch.
    print(f"Starting inference on {num_batches} batches...")
    for i, batch in enumerate(test_batches):
        if len(batch[0]) < batch_size:
            print(f"Batch {i} too small ({len(batch)} < {batch_size}). Skipping...")
            continue
        print(f"Running inference on batch {(i + 1)}/{num_batches}...")
        sim_data = sim.predict_on_batch({ndl_mdl_inpt: batch[0]})
        # print(batch[1], sim_data[ndl_mdl_otpt])
        ndl_mdl_otpt_cls_probs.append((batch[1][-1], sim_data[ndl_mdl_otpt][-1]))
        # for y_true, y_pred in zip(batch[1], sim_data[ndl_mdl_otpt]):
        #     # Note that y_true is an array of shape (10,) and y_pred is a matrix of
        #     # shape (n_steps, 10) where 10 is the number of classes in CIFAR-10 dataset.
        #     ndl_mdl_otpt_cls_probs.append((y_true, y_pred))
        
        # Collect the spikes if required.
        if collect_spikes_output:
            for i in range(batch_size): # Collecting spikes for each image in first batch.
                ndl_mdl_spikes.append({
                    first_conv_probe.obj.ensemble.label: sim_data[first_conv_probe][i],
                    penltmt_dense_probe.obj.ensemble.label: sim_data[penltmt_dense_probe][i]
                })
            # Not collecting the spikes for rest batches to save memory.
            collect_spikes_output = False


Starting inference on 4 batches...
Running inference on batch 1/4...
Running inference on batch 2/4...
Running inference on batch 3/4...
Running inference on batch 4/4...


In [147]:
acc = 0
temporal_cls_probs = [] # To store the temporal class-probabilities of each test image.
for y_true, y_pred in ndl_mdl_otpt_cls_probs:
  # Pick the spiking network's last time-step output, therefore -1 in y_pred.
  temporal_cls_probs.append(y_pred)
  print(y_true.shape, y_pred.shape)
  if np.argmax(y_true) == np.argmax(y_pred):
    acc += 1

print("Spiking network prediction accuracy: %0.4f %%" % (acc * 100/ x_test_slice.shape[0]))

(1, 10) (1, 10)
(1, 10) (1, 10)
(1, 10) (1, 10)
(1, 10) (1, 10)
Spiking network prediction accuracy: 0.0000 %
