## MNIST classification on Loihi

In [None]:
from CNN.mnist_model import ConvNet
from CNN.utils import get_weights_biases
import torch
import time
from torchvision import datasets, transforms
import quartz
from quartz import layers
import numpy as np

## Load the pre-trained ANN model, then extract and rescale its parameters

In [None]:
# Rebuild the ANN architecture and restore its trained weights, mapped onto the CPU.
n_classes = 10
checkpoint_path = "CNN/models/mnist-convnet.pth"
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model = ConvNet(n_classes)
model.load_state_dict(state_dict)
# Put the model into evaluation mode; as the cell's last expression the returned
# module is also echoed by the notebook.
model.eval()

In [None]:
weights, biases = get_weights_biases(model)

# Per-layer scaling factors, found by normalizing the weights and then rescaling
# by the activations of each layer. Details are in the mnist-analysis notebook.
scaling_factors = [1.2470399948327828, 0.9052967932549607, 1.4555583482919765, 0.2525526552185343]

# Scale each layer's weights and biases jointly by the same factor
# (weights first, then biases, via item assignment into each sequence).
for idx, factor in enumerate(scaling_factors):
    for params in (weights, biases):
        params[idx] *= factor

## Build the Quartz model with parameters from the ANN model

In [None]:
# Network settings: t_max is handed to quartz.Network as its first argument;
# input_dims matches one MNIST image tensor (1 channel, 28 x 28 pixels).
t_max = 2 ** 4
input_dims = (1, 28, 28)
# A single pooling window object shared by both MaxPool2D layers.
pool_kernel_size = [2, 2]

# Assemble the layers in order from the rescaled ANN parameters:
# input -> (conv -> max-pool) x 2 -> conv -> dense.
network_layers = [layers.InputLayer(dims=input_dims)]
for stage in range(2):
    network_layers.append(layers.Conv2D(weights=weights[stage], biases=biases[stage]))
    network_layers.append(layers.MaxPool2D(kernel_size=pool_kernel_size))
network_layers.append(layers.Conv2D(weights=weights[2], biases=biases[2]))
network_layers.append(layers.Dense(weights=weights[3], biases=biases[3]))

loihi_model = quartz.Network(t_max, verbose=True, layers=network_layers)

# Cores requested per layer, one entry for each of the 7 layers above
# (the input layer gets 0); passed to loihi_model(...) when running.
n_cores_per_layer = [0, 5, 3, 3, 2, 2, 2]

In [None]:
# Echo the network object so the notebook displays its representation.
loihi_model

In [None]:
# Outgoing connection count of every layer in the network.
n_conns_per_layer = [layer.n_outgoing_connections() for layer in loihi_model.layers]
# Pair layer k's outgoing connections with the core count of layer k + 1
# (presumably the layer those connections feed into) and show the integer
# connections-per-core ratio; zip stops at the shorter list.
receiving_cores = n_cores_per_layer[1:]
[conns // cores for conns, cores in zip(n_conns_per_layer, receiving_cores)]

## Load test data

In [None]:
# MNIST test split read from the local ./CNN/data directory (no download),
# with each image converted to a tensor.
to_tensor = transforms.ToTensor()
transform = transforms.Compose([to_tensor])
test_set = datasets.MNIST(root='./CNN/data', train=False, download=False, transform=transform)

## Classify images

In [None]:
# Run the MNIST test set through the Loihi model batch by batch, reporting the
# classification error and first-spike latency per batch and for the whole set.
start_time = time.time()
batch_size = 2500
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
test_loader_iter = iter(test_loader)

# Per-batch statistics; `errors` is also appended to mnist-results.txt further below.
errors = []
avg_first_spikes = []
# Running totals for the test-set summary. Averaging the per-batch values instead
# would over-weight the final batch whenever len(test_set) is not a multiple of
# batch_size.
n_images = 0
n_correct = 0
first_spike_sum = 0.0
for b, (inputs, target) in enumerate(test_loader_iter):
    loihi_output_values = loihi_model(inputs.detach().numpy(), n_cores_per_layer=n_cores_per_layer, partition='loihi_2h', logging=True)
    # Some outputs might spike multiple times, so instead of the output values we
    # rely on the first spikes of every batch. first_spikes appears to be
    # (n_outputs, n_images): argmin over axis 0 is compared to per-image targets.
    first_spikes = np.asarray(loihi_model.first_spikes)
    # Earliest output spike of each image, minus i * steps_per_image for image i,
    # making it relative to that image's own time window.
    earliest_spikes = np.min(first_spikes, axis=0)
    relative_first_spikes = earliest_spikes - np.arange(len(earliest_spikes)) * loihi_model.steps_per_image
    avg_first_spikes.append(float(np.mean(relative_first_spikes)))
    print("Average first spike: " + str(avg_first_spikes[-1]))
    # The predicted class is the output that spiked first.
    classification_results = np.argmin(first_spikes, axis=0)
    positives = int(np.sum(classification_results == target.numpy()))
    errors.append(100*(1-positives/len(target)))
    print("Correctly detected {} out of {}: {}% error".format(positives, len(target), str(errors[-1])))
    print("Batch {} finished within {} seconds.".format(b+1, time.time() - start_time))
    start_time = time.time()

    n_images += len(target)
    n_correct += positives
    first_spike_sum += float(np.sum(relative_first_spikes))

if n_images == 0:
    print("Test set is empty; nothing was classified.")
else:
    # Image-weighted summaries (identical to the per-batch mean when all batches are equal-sized).
    print("Average first spike for test set: {}".format(first_spike_sum / n_images))
    print("Accuracy error for test set: {}".format(100 * (1 - n_correct / n_images)))

## Power benchmarks

In [None]:
# Draw one shuffled batch of 3000 test images and run it with profiling enabled.
batch_size = 3000
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=True)
inputs, targets = next(iter(test_loader))
benchmark_images = inputs.detach().numpy()
# Energy probes are currently only supported on the nahuku32 partition.
run_options = dict(n_cores_per_layer=n_cores_per_layer, partition='nahuku32', logging=True, profiling=True)
energy_probe = loihi_model(benchmark_images, **run_options)

In [None]:
# Print compartments_on_core (presumably one count per core) laid out 8 per row.
cores_grid = loihi_model.compartments_on_core.reshape(-1, 8)
print(cores_grid)

In [None]:
# Show the first 20 entries of compartments_on_core (the notebook echoes the last expression).
loihi_model.compartments_on_core[:20]

In [None]:
# Append this run's per-batch error percentages as a single line to the results log.
with open("mnist-results.txt", "a") as results_file:
    print(errors, file=results_file)