## CIFAR-10 classification with a quartz (Loihi) spiking network

In [None]:
from CNN.cifar_model import MobileNet
import torch
import torch.nn as nn
import time
from torchvision import datasets, transforms
import quartz
from quartz import layers
import numpy as np
# Print numpy floats in fixed-point rather than scientific notation when arrays are
# echoed in this notebook.
np.set_printoptions(suppress=True)

In [None]:
# Build the PyTorch network and restore its trained weights onto the CPU.
# NOTE(review): the constructor argument 10 is presumably the number of CIFAR-10 classes — confirm.
model = MobileNet(10)
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted files
# (newer torch versions accept weights_only=True to restrict this).
checkpoint_state = torch.load("CNN/cifar-convnet-pickle2.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint_state)
# eval() returns the module itself. Binding the result keeps the notebook from echoing it.
# Eval mode is also required by fuse_conv_bn_eval in the folding step.
capture = model.eval()

In [None]:
# Flatten the network into a Sequential containing only the layer types the converter
# understands. Every BatchNorm2d that directly follows a Conv2d is folded into that conv.
# NOTE(review): model.modules() walks registered submodules depth-first in registration
# order and yields each module object once. This matches the forward pass only if MobileNet
# registers layers in execution order and never reuses one instance (e.g. a shared ReLU)
# in several places — confirm against CNN/cifar_model.py.
previous_module = None
new_layers = []
for module in model.modules():
    if not isinstance(module, (nn.Conv2d, nn.MaxPool2d, nn.BatchNorm2d, nn.Linear, nn.ReLU6, nn.ReLU, nn.Flatten)):
        continue  # containers and unsupported layer types are skipped
    if isinstance(previous_module, nn.Conv2d) and isinstance(module, nn.BatchNorm2d):
        # The conv was the last entry appended, so swap it for the fused conv+BN equivalent.
        new_layers[-1] = torch.nn.utils.fuse_conv_bn_eval(previous_module, module)
    else:
        new_layers.append(module)
    previous_module = module

folded_model = nn.Sequential(*new_layers)

In [None]:
# The folded Sequential's direct children, in order. All kept layer types are leaf modules,
# so this equals list(folded_model.modules())[1:] (which only adds the container itself first).
layer_list = list(folded_model.children())

## Convert the folded layers into a quartz network

In [None]:
import warnings

t_max = 2**7  # passed as the first argument of quartz.Network
input_dims = (3,32,32)  # (channels, height, width) of one CIFAR-10 image after ToTensor
pool_kernel_size = [2,2]  # NOTE(review): not referenced in this part of the notebook

# Translate the folded PyTorch layers into quartz layers, in order. ReLU/ReLU6 modules are
# not emitted as layers of their own. Instead, the preceding weight layer is flagged as
# rectifying. nn.Flatten is skipped as well.
loihi_layers = []
for idx, layer in enumerate(layer_list):
    # NOTE(review): ReLU6 is mapped to plain rectification, so its clip at 6 is not carried
    # over — confirm this matches quartz's neuron model.
    rectification = idx < len(layer_list)-1 and isinstance(layer_list[idx+1], (nn.ReLU6, nn.ReLU))
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        weights = layer.weight.detach().numpy()
        # A layer built with bias=False that was not fused with a BatchNorm has bias None.
        # Use zeros (one per output channel/feature = weights.shape[0]) instead of crashing.
        if layer.bias is not None:
            biases = layer.bias.detach().numpy()
        else:
            biases = np.zeros(weights.shape[0], dtype=weights.dtype)
        if isinstance(layer, nn.Conv2d):
            loihi_layers.append(layers.Conv2D(weights=weights, biases=biases, stride=layer.stride, padding=layer.padding, groups=layer.groups, rectifying=rectification))
        else:
            loihi_layers.append(layers.Dense(weights=weights, biases=biases, rectifying=rectification))
    elif isinstance(layer, nn.MaxPool2d):
        loihi_layers.append(layers.MaxPool2D(kernel_size=layer.kernel_size, stride=layer.stride))
    elif isinstance(layer, nn.BatchNorm2d):
        # Only a BatchNorm2d directly after a Conv2d is folded earlier. Any other one has no
        # quartz counterpart and is dropped, which changes the network output — surface it.
        warnings.warn("Unfolded BatchNorm2d at position {} is dropped from the quartz network.".format(idx))

loihi_layers = [layers.InputLayer(dims=input_dims)] + loihi_layers
loihi_layers

In [None]:
# Number of loihi_layers entries (InputLayer included) to deploy. 3 keeps the input layer
# plus the first two converted layers, which looks like a debugging/profiling subset.
# Set to None to deploy the whole network (loihi_layers[:None] is the full list).
# NOTE(review): the classification branch of the evaluation loop compares argmin over the
# output first-spike times with CIFAR-10 labels, which is only meaningful for the full network.
num_deployed_layers = 3
loihi_model = quartz.Network(t_max, loihi_layers[:num_deployed_layers], verbose=True)

In [None]:
# Echo the network object (its repr) as this cell's output for inspection.
loihi_model

## Load the CIFAR-10 test set

In [None]:
batch_size = 10

# Only ToTensor is applied: images become float tensors scaled to [0, 1], without
# mean/std normalisation. NOTE(review): presumably matches how the model was trained.
transform = transforms.Compose([
    transforms.ToTensor(),
])
# The data must already be present under ./CNN/data (download is disabled).
test_set = datasets.CIFAR10(
    root='./CNN/data',
    train=False,
    download=False,
    transform=transform,
)
# Fixed order so batches are reproducible across runs.
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Explicit iterator over the test batches. The evaluation loop consumes it, so re-running
# that cell without re-running this one continues from the next unread batch (or does
# nothing once the iterator is exhausted).
test_loader_iter = iter(test_loader)

In [None]:
print(time.strftime("Started on %a, %d %b %Y %H:%M:%S", time.gmtime()))
start_time = time.time()

profiling = True # change this flag for classification or power benchmarks!
# profiling=True : run only the next batch on partition 'nahuku32_2h' with profiling enabled
#                  and keep the returned object in `energy_probe`.
# profiling=False: run every remaining test batch on partition 'loihi_2h' and record the
#                  per-batch error rate and mean first-spike time.

errors = []            # per-batch classification error, in percent
avg_first_spikes = []  # per-batch mean of each image's earliest output spike time
for b, (inputs, target) in enumerate(test_loader_iter):
    if not profiling:
        loihi_output_values = loihi_model(inputs.detach().numpy(), partition='loihi_2h', logging=True, profiling=profiling)
        # Some outputs may spike several times, so classify by each output's first spike
        # rather than by the output values. first_spikes is indexed [output, image]:
        # argmin over axis 0 yields one class per image.
        earliest_spikes = np.min(loihi_model.first_spikes, axis=0)
        # Image i's spike times are offset by i * steps_per_image (presumably the images run
        # back-to-back). Remove that offset before averaging.
        spike_offsets = np.arange(len(earliest_spikes)) * loihi_model.steps_per_image
        avg_first_spikes.append(np.mean(earliest_spikes - spike_offsets))
        print("Average first spike: " + str(avg_first_spikes[-1]))
        classification_results = np.argmin(loihi_model.first_spikes, axis=0)
        # Plain int keeps the stored error values as ordinary Python floats.
        positives = int(np.sum(classification_results == target.numpy()))
        errors.append(100*(1-positives/len(target)))
        print("Correctly detected {} out of {}: {}% error".format(positives, len(target), str(errors[-1])))
    else:
        energy_probe = loihi_model(inputs.detach().numpy(), partition='nahuku32_2h', logging=True, profiling=profiling)
        break  # profiling mode processes a single batch only

    print("Batch {} finished within {} seconds.".format(b, time.time() - start_time))
    start_time = time.time()

In [None]:
# Mean per-batch classification error in percent. In profiling mode the evaluation loop
# never appends to `errors`, so report NaN explicitly instead of dividing 0.0 by zero.
np.sum(errors) / len(errors) if errors else float('nan')

In [None]:
# Append this run's per-batch error list (its str() form) as one line of results.txt.
with open("results.txt", mode="a") as results_file:
    results_file.write(f"{errors}\n")

In [None]:
# IPython magic: open the post-mortem debugger on the most recent exception
# (a debugging aid, not part of the evaluation steps above).
%debug