In [1]:
!pip install snntoolbox
!pip install onnx
!pip install onnxruntime

Collecting snntoolbox
  Downloading snntoolbox-0.6.0-py2.py3-none-any.whl (203 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m203.9/203.9 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: snntoolbox
Successfully installed snntoolbox-0.6.0
Collecting onnx
  Downloading onnx-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.7/15.7 MB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: onnx
Successfully installed onnx-1.15.0
Collecting onnxruntime
  Downloading onnxruntime-1.16.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.4/6.4 MB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0

In [2]:
"""End-to-end example for SNN Toolbox.

This script sets up a small CNN using PyTorch, trains it for one epoch on
MNIST, stores model and dataset in a temporary folder on disk, creates a
configuration file for SNN toolbox, and finally calls the main function of SNN
toolbox to convert the trained ANN to an SNN and run it using INI simulator.
"""

import os
import shutil
import inspect
import time

import numpy as np
import torch
import torch.nn as nn
from tensorflow.keras import backend
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

from snntoolbox.bin.run import main
from snntoolbox.utils.utils import import_configparser
from tests.parsing.models.pytorch import Model


# Pytorch to Keras parser needs image_data_format == channel_first.
backend.set_image_data_format('channels_first')


In [3]:
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [30]:

# WORKING DIRECTORY #
#####################

# Define path where model and output files will be stored.
# The user is responsible for cleaning up this temporary directory.
path_wd = '/content/drive/MyDrive/Dissertation/project_code/ann_models/'

# GET DATASET #
###############

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize input so we can train ANN with it.
# Will be converted back to integers for SNN layer.
x_train = x_train / 255
x_test = x_test / 255

# Add a channel dimension.
axis = 1 if backend.image_data_format() == 'channels_first' else -1
x_train = np.expand_dims(x_train, axis)
x_test = np.expand_dims(x_test, axis)

# One-hot encode target vectors before saving for use in toolbox.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Save dataset so SNN toolbox can find it.
np.savez_compressed(os.path.join(path_wd, 'x_test'), x_test)
np.savez_compressed(os.path.join(path_wd, 'y_test'), y_test)
# SNN toolbox will not do any training, but we save a subset of the training
# set so the toolbox can use it when normalizing the network parameters.
np.savez_compressed(os.path.join(path_wd, 'x_norm'), x_train[::10])

# Pytorch doesn't support one-hot labels, so we undo it for training the ANN.
y_train = np.argmax(y_train, 1)
y_test = np.argmax(y_test, 1)

In [31]:
class PytorchDataset(torch.utils.data.Dataset):
    def __init__(self, data, target, transform=None):
        self.data = torch.from_numpy(data).float()
        self.target = torch.from_numpy(target).long()
        self.transform = transform

    def __getitem__(self, index):
        x = self.data[index]

        if self.transform:
            x = self.transform(x)

        return x, self.target[index]

    def __len__(self):
        return len(self.data)


trainset = torch.utils.data.DataLoader(PytorchDataset(x_train, y_train),
                                       batch_size=64)
testset = torch.utils.data.DataLoader(PytorchDataset(x_test, y_test),
                                      batch_size=64)

In [6]:
# CREATE ANN #
##############

# This section creates a CNN using pytorch, and trains it with backpropagation.
# There are no spikes involved at this point.

# Create pytorch model from definition in separate script.
model = Model()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Train model with backprop.
acc = 0
for epoch in range(3):
    for i, (xx, y) in enumerate(trainset):
        optimizer.zero_grad()
        outputs = model(xx)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

    total = 0
    correct = 0
    with torch.no_grad():
        for xx, y in testset:
            outputs = model(xx)
            _, predicted = torch.max(outputs.data, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    acc = correct / total

print("Test accuracy: {:.2%}".format(acc)) #86.37%

Test accuracy: 96.04%


In [21]:
# Store weights so SNN Toolbox can find them.
model_name = 'cnn_model'
#torch.save(model.state_dict(), os.path.join(path_wd, model_name + '.pth'))

#torch.load(model.state_dict(), os.path.join(path_wd, model_name + '.pth'))

# model.load_state_dict(torch.load('/content/drive/MyDrive/Dissertation/project_code/ann_models/cnn_model.pkl'))

torch.save(model.state_dict(), '/content/drive/MyDrive/Dissertation/project_code/ann_models/cnn_model.pkl')


In [32]:
# SNN TOOLBOX CONFIGURATION #
#############################

# Create a config file with experimental setup for SNN Toolbox.
configparser = import_configparser()
config = configparser.ConfigParser()

config['paths'] = {
    'path_wd': path_wd,             # Path to model.
    'dataset_path': path_wd,        # Path to dataset.
    'filename_ann': model_name      # Name of input model.
}

config['tools'] = {
    'evaluate_ann': True,           # Test ANN on dataset before conversion.
    'normalize': True               # Normalize weights for full dynamic range.
}

config['simulation'] = {
    'simulator': 'INI',             # Chooses execution backend of SNN toolbox.
    'duration': 50,                 # Number of time steps to run each sample.
    'num_to_test': 100,             # How many test samples to run.
    'batch_size': 50,               # Batch size for simulation.
    'keras_backend': 'tensorflow'   # Which keras backend to use.
}

config['input'] = {
    'model_lib': 'pytorch'          # Input model is defined in pytorch.
}

config['output'] = {
    'plot_vars': {                  # Various plots (slows down simulation).
        'spiketrains',              # Leave section empty to turn off plots.
        'spikerates',
        'activations',
        'correlation',
        'v_mem',
        'error_t'}
}

In [33]:
# Store config file.
config_filepath = os.path.join(path_wd, 'config')
with open(config_filepath, 'w') as configfile:
    config.write(configfile)

In [34]:
# Need to copy model definition over to ``path_wd`` (needs to be in same dir as
# the weights saved above).
source_path = inspect.getfile(Model)
shutil.copyfile(source_path, os.path.join(path_wd, model_name + '.py'))

'/content/drive/MyDrive/Dissertation/project_code/ann_models/cnn_model.py'

In [36]:
# !pip install onnx
# !pip install onnxruntime
!pip install onnx2keras

Collecting onnx2keras
  Downloading onnx2keras-0.0.24.tar.gz (20 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: onnx2keras
  Building wheel for onnx2keras (setup.py) ... [?25l[?25hdone
  Created wheel for onnx2keras: filename=onnx2keras-0.0.24-py3-none-any.whl size=24577 sha256=181dfa994e29acdd8068a39cca80b12c653d941fbde5409960ebb3c680a160ef
  Stored in directory: /root/.cache/pip/wheels/a1/fb/c9/349c27912022d104c7dd5f5d272595c33b1b959c4468d5e784
Successfully built onnx2keras
Installing collected packages: onnx2keras
Successfully installed onnx2keras-0.0.24


In [37]:
# RUN SNN TOOLBOX #
###################

main(config_filepath)

Initializing INI simulator...

Loading data set from '.npz' files in /content/drive/MyDrive/Dissertation/project_code/ann_models.

Error(s) in loading state_dict for Model:
	Missing key(s) in state_dict: "trunk.0.weight", "trunk.0.bias", "branch1.0.weight", "branch1.0.bias", "branch2.0.weight", "branch2.0.bias", "head.0.weight", "head.0.bias", "classifier.1.weight", "classifier.1.bias". 
	Unexpected key(s) in state_dict: "network.0.weight", "network.0.bias", "network.1.weight", "network.1.bias", "network.1.running_mean", "network.1.running_var", "network.1.num_batches_tracked", "network.4.weight", "network.4.bias", "network.5.weight", "network.5.bias", "network.5.running_mean", "network.5.running_var", "network.5.num_batches_tracked", "network.8.weight", "network.8.bias", "network.9.weight", "network.9.bias", "network.9.running_mean", "network.9.running_var", "network.9.num_batches_tracked", "network.13.weight", "network.13.bias". 
Pytorch model was successfully ported to ONNX.


ValueError: ignored