## MNIST classification on Loihi

In [None]:
from CNN.mnist_model import ConvNet
from CNN.utils import get_weights_biases
import torch
import time
from torchvision import datasets, transforms
import quartz
from quartz import layers
import numpy as np

## Load pre-trained ANN model and inspect parameters

In [None]:
# Instantiate the ANN and restore its trained parameters on the CPU.
n_classes = 10
checkpoint_path = "CNN/models/mnist-convnet.pth"

model = ConvNet(n_classes)
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
# Switch to inference mode; as the cell's last expression this also displays the model.
model.eval()

In [None]:
# Pull the per-layer parameters out of the ANN and print each layer's
# largest and smallest weight to get a feel for the value ranges.
weights, biases = get_weights_biases(model)
print([layer_weights.max() for layer_weights in weights])
print([layer_weights.min() for layer_weights in weights])

In [None]:
# The last layer's activity is necessarily a bit higher because it produces
# logits, so its parameters are scaled down by a factor of 3.
# NOTE: the division is applied on every execution, so re-running this cell
# scales the parameters down again.
for layer_params in (weights, biases):
    layer_params[3] /= 3

## Build the Quartz model with parameters from the ANN model

In [None]:
# Settings shared by the network definition and the data loader.
t_max = 2**7  # first positional argument of quartz.Network
input_dims = (1,28,28)
pool_kernel_size = [2,2]
batch_size = 100  # used by the test DataLoader further down

# Layer stack: conv -> pool -> conv -> pool -> conv -> dense, with every
# weighted layer taking its parameters from the matching ANN layer.
layer_stack = [
    layers.InputLayer(dims=input_dims),
    layers.Conv2D(weights=weights[0], biases=biases[0]),
    layers.MaxPool2D(kernel_size=pool_kernel_size),
    layers.Conv2D(weights=weights[1], biases=biases[1]),
    layers.MaxPool2D(kernel_size=pool_kernel_size),
    layers.Conv2D(weights=weights[2], biases=biases[2]),
    layers.Dense(weights=weights[3], biases=biases[3]),
]
loihi_model = quartz.Network(t_max, verbose=True, layers=layer_stack)

# One entry per layer of the stack above (7 layers, 7 entries).
# NOTE(review): presumably the number of Loihi cores assigned to each layer — confirm against quartz.
n_cores_per_layer = [0,4,1,2,1,1,1]

In [None]:
# Bare expression as the cell's last statement: the notebook displays the network's repr.
loihi_model

## Load test data and classify or run the power benchmark

In [None]:
# MNIST test split (train=False); ToTensor is the only preprocessing step.
transform = transforms.Compose([
    transforms.ToTensor(),
])
test_set = datasets.MNIST(
    './CNN/data',
    train=False,
    download=True,
    transform=transform,
)
# Fixed order (shuffle=False) so batch indices are reproducible between runs.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
)
# A single explicit iterator, so later cells can advance it and continue from there.
test_loader_iter = iter(test_loader)

In [None]:
# Use this to skip ahead to a certain batch in case the script failed midway.
# for i in range(6):
#     inputs, target = next(test_loader_iter)

In [None]:
# Run the Quartz network on Loihi. Depending on `profiling`, either classify
# test batches (recording the error rate and first-spike latency per batch)
# or run a single batch for the power benchmark.
print(time.strftime("Started on %a, %d %b %Y %H:%M:%S", time.gmtime()))  # gmtime => timestamp is UTC
start_time = time.time()

profiling = False # change this flag for classification or power benchmarks!
# Number of batches to classify before stopping; None processes everything
# left in test_loader_iter. At least one batch is always processed.
max_batches = 1

errors = []            # classification error per batch, in percent
avg_first_spikes = []  # mean first-spike time per batch, see below
for b, (inputs, target) in enumerate(test_loader_iter):
    if profiling:
        # Power benchmark: a single batch is enough, so stop right after it.
        # NOTE(review): unlike the classification call below, this call does not pass
        # n_cores_per_layer and targets a different partition — confirm this is intended.
        energy_probe = loihi_model(inputs.detach().numpy(), partition='nahuku32_2h', logging=True, profiling=profiling)
        break

    loihi_output_values = loihi_model(inputs.detach().numpy(), n_cores_per_layer, partition='loihi_2h', logging=True, profiling=profiling)
    # some of the outputs might spike multiple times so instead of the output values, we rely on the first spikes for every batch
    # Earliest spike per column of first_spikes; the columns are compared against
    # `target` below, so there is one column per image in the batch.
    earliest_spikes = np.min(loihi_model.first_spikes, axis=0)
    # Subtract i*steps_per_image from the i-th image's spike time before averaging.
    # NOTE(review): presumably this removes the offset of images presented back to back — confirm.
    avg_first_spikes.append(np.mean([spike_time-i*loihi_model.steps_per_image for i, spike_time in enumerate(earliest_spikes)]))
    print("Average first spike: " + str(avg_first_spikes[-1]))
    # The predicted class is the row index of the earliest spike in each column.
    classification_results = np.argmin(loihi_model.first_spikes, axis=0)
    positives = sum(classification_results == target.numpy())
    errors.append(100*(1-positives/len(target)))
    print("Correctly detected {} out of {}: {}% error".format(positives, len(target), str(errors[-1])))

    print("Batch {} finished within {} seconds.".format(b, time.time() - start_time))
    start_time = time.time()
    # Checked at the end of the iteration so that no extra batch is drawn from
    # the shared iterator once the limit is reached.
    if max_batches is not None and b + 1 >= max_batches:
        break

In [None]:
# Print compartments_on_core reshaped into rows of 8 values each.
# NOTE(review): presumably one entry per core — confirm against quartz.Network.
print(loihi_model.compartments_on_core.reshape(-1,8))

In [None]:
# Show the first 20 entries of compartments_on_core as the cell's output.
loihi_model.compartments_on_core[:20]

In [None]:
# Mean of the per-batch average first-spike times collected by the classification loop.
np.mean(avg_first_spikes)

In [None]:
# Mean classification error (in percent) over all processed batches.
np.mean(errors)

In [None]:
# Append this run's per-batch error list as one line to the results file;
# append mode keeps the lines written by earlier runs.
with open("mnist-results.txt", "a") as results_file:
    record = "{}\n".format(errors)
    results_file.write(record)