## CIFAR10 classification

In [None]:
from CNN.cifar_model import ConvNet, ConvBNReLU, Bottleneck, ConvPool
import torch
import torch.nn as nn
import time
from torchvision import datasets, transforms
import quartz
from quartz import layers
import numpy as np
# Print floats in fixed-point rather than scientific notation so the spike-time
# arrays inspected in this notebook are easier to read.
np.set_printoptions(suppress=True)

In [None]:
n_classes = 10
model = ConvNet(n_classes)

# Load the trained weights onto the CPU and switch to inference mode;
# fuse_conv_bn_eval below requires both modules to be in eval mode.
checkpoint_path = "CNN/cifar-convnet.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
model.eval()

# Flatten the network into its Conv2d / MaxPool2d / BatchNorm2d / Linear leaf
# layers (in module order), folding every BatchNorm2d that comes right after a
# Conv2d into that convolution's weight and bias.
# NOTE(review): only these four layer types are tracked, so any activation
# sitting between a Conv2d and a BatchNorm2d would be ignored by the fold check
# - confirm the architecture is always Conv -> BN.
# NOTE(review): the result holds no activations or Flatten, so it is presumably
# used only as a container of (folded) weights, not run end-to-end.
kept_types = (nn.Conv2d, nn.MaxPool2d, nn.BatchNorm2d, nn.Linear)
new_layers = []
previous_module = None
for module in filter(lambda m: isinstance(m, kept_types), model.modules()):
    follows_conv = isinstance(previous_module, nn.Conv2d)
    if follows_conv and isinstance(module, nn.BatchNorm2d):
        # Replace the conv just appended with its BN-folded equivalent.
        new_layers[-1] = torch.nn.utils.fuse_conv_bn_eval(previous_module, module)
    else:
        new_layers.append(module)
    previous_module = module
folded_model = nn.Sequential(*new_layers)

In [None]:
# Display the flattened network with Conv2d+BatchNorm2d pairs folded together.
folded_model

In [None]:
# First layer of the folded network, kept for interactive inspection
# (presumably a Conv2d). Indexing the Sequential directly is equivalent to
# list(folded_model.modules())[1]: modules() yields the container itself first,
# and none of the kept leaf layers have submodules.
conv_layer = folded_model[0]

## Build the Loihi model

In [None]:
# Simulation / encoding parameters.
t_max = 2**7                  # handed to quartz.Network and to every run
input_dims = (3, 32, 32)      # presumably (channels, height, width) of one image
pool_kernel_size = [2, 2]     # NOTE(review): not used anywhere in this notebook
steps_per_image = 7 * t_max   # length of each image's time window in a batched run
batch_size = 10

# Convert every (BN-folded) Conv2d into a quartz Conv2D layer, in network order.
# NOTE(review): MaxPool2d and Linear layers of folded_model are not converted,
# so the quartz network is the input layer followed by the convolutions only -
# confirm this matches the intended architecture.
conv_modules = [m for m in folded_model.modules() if isinstance(m, nn.Conv2d)]
conv_layers = [
    layers.Conv2D(
        weights=m.weight.detach().numpy(),
        biases=m.bias.detach().numpy(),
        padding=m.padding,
        groups=m.groups,
    )
    for m in conv_modules
]
loihi_layers = [layers.InputLayer(dims=input_dims), *conv_layers]

loihi_model = quartz.Network(t_max, loihi_layers)

In [None]:
# Display the quartz network assembled from the folded convolution layers.
loihi_model

In [None]:
#loihi_model.check_block_delays(t_max, 2**3)
#loihi_model.print_core_layout(redo=True)

## Load test data

In [None]:
# CIFAR10 test split as tensors scaled to [0, 1]. This must be the dataset the
# ConvNet was trained on (cifar-convnet.pth) and must match the quartz input
# layer dims (3, 32, 32); MNIST images (1x28x28) do not fit that input.
# NOTE(review): no transforms.Normalize is applied - confirm the model was
# trained on unnormalized inputs.
# download=False: the CIFAR10 files must already be present under ./CNN/data.
transform = transforms.Compose([transforms.ToTensor()])
test_set = datasets.CIFAR10('./CNN/data', train=False, download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [None]:
# Explicit iterator over the test loader so batches can also be pulled one at a
# time with next() in later cells; the evaluation loop consumes the same iterator.
test_loader_iter = iter(test_loader)

In [None]:
print(time.strftime("Started on %a, %d %b %Y %H:%M:%S", time.gmtime()))
start_time = time.time()

# Number of batches to run on the chip; set to None to evaluate every
# remaining batch of the test set.
max_batches = 1

errors = []
for batch_idx, (inputs, target) in enumerate(test_loader_iter):
    loihi_output = loihi_model(inputs.numpy(), t_max, steps_per_image=steps_per_image,
                               partition='nahuku32_2h', logging=False, profiling=False)

    # The images of a batch run back to back: image k owns the time window
    # (k * steps_per_image, (k + 1) * steps_per_image), bounds exclusive.
    # For each of the first n_classes recorded entries (treated as output
    # neurons, one per class) take the earliest spike inside each window.
    # A neuron that stays silent in a window gets +inf instead of raising
    # IndexError. The last batch may be smaller than batch_size, so size the
    # array by the actual batch length.
    # NOTE(review): confirm the exclusive lower bound matches quartz's
    # time-step numbering.
    n_samples = len(target)
    firsts = np.full((n_classes, n_samples), np.inf)
    for i, (_, values) in enumerate(sorted(loihi_model.data[1].items())[:n_classes]):
        for iteration in range(n_samples):
            lower = iteration * steps_per_image
            upper = (iteration + 1) * steps_per_image
            in_window = values[(values > lower) & (values < upper)]
            if in_window.size:
                firsts[i, iteration] = in_window.min()

    # Time-to-first-spike decoding: the class whose neuron fires first wins.
    # Samples where no output neuron fired at all are marked -1, so they count
    # as errors instead of silently being predicted as class 0 by argmin.
    any_spike = np.isfinite(firsts).any(axis=0)
    loihi_results = np.where(any_spike, np.argmin(firsts, axis=0), -1)
    positives = int(np.count_nonzero(loihi_results == target.numpy()))
    error = 100 * (1 - positives / n_samples)
    errors.append(error)
    print("Correctly detected {} out of {}: {}% error".format(positives, n_samples, error))

    if max_batches is not None and batch_idx + 1 >= max_batches:
        break

print("--- %s seconds ---" % (time.time() - start_time))

In [None]:
# Per-batch error percentages appended by the evaluation loop.
errors

In [None]:
# Pull the next test batch by hand; this advances the same iterator the
# evaluation loop consumes.
inputs, target = next(test_loader_iter)

In [None]:
# Spike times of the last entry iterated over in the evaluation loop's
# first-spike computation.
values

In [None]:
# Ground-truth labels of the most recently loaded batch.
target

In [None]:
# Raw data recorded by the last run; the evaluation loop reads spike times from
# loihi_model.data[1]. NOTE(review): the meaning of the other entries isn't
# visible in this notebook.
loihi_model.data

In [None]:
# First-spike times of the evaluated batch: row i = i-th sorted entry of
# loihi_model.data[1] (treated as class i), column = sample within the batch.
firsts


In [None]:
# Spike-time entries used as the class outputs. Sliced by n_classes to match the
# evaluation loop, which decodes one entry per class.
sorted(loihi_model.data[1].items())[:n_classes]

In [None]:
#loihi_output.rawPowerTimeStamps

In [None]:
# Average of the per-batch error percentages (equal to the overall error rate
# only when every batch has the same size).
np.mean(errors)

In [None]:
# Append this run's list of per-batch error percentages as one line of
# results.txt. The encoding is explicit so the file does not depend on the
# platform's default locale.
with open("results.txt", "a", encoding="utf-8") as results_file:
    results_file.write(f"{errors}\n")

In [None]:
# Earliest first-spike time across all output neurons and samples of the batch.
firsts.min()

In [None]:
# Maximum of the value returned by the last loihi_model(...) call.
# NOTE(review): the structure of loihi_output isn't visible here - confirm it is
# array-like before relying on this.
np.max(loihi_output)

In [None]:
# IPython magic (notebook-only, not plain Python): post-mortem debugger for the
# last exception.
%debug