## CNN: load the trained PyTorch model and extract its weights

In [None]:
from CNN.convnet import ConvNet
import torch
import time
from torchvision import datasets, transforms
import quartz
from quartz import layers
import numpy as np
np.set_printoptions(suppress=True)  # print small floats in fixed-point rather than scientific notation

In [None]:
# Restore the trained PyTorch CNN onto the CPU and put it in eval (inference) mode.
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted files
# (torch >= 1.13 accepts weights_only=True to restrict what is unpickled).
n_classes = 10
model = ConvNet(n_classes)
state_dict = torch.load("CNN/quartz-convnet-6c.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

In [None]:
def get_weights_biases(model):
    """Extract weights and biases of every other parametrised layer as numpy arrays.

    Assumes ``model.parameters()`` yields a (weight, bias) pair per layer, in
    layer order, and that the layers to keep alternate with layers to skip
    (presumably conv/linear layers interleaved with normalisation layers --
    TODO confirm against ``ConvNet``).

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        (weights, biases): two lists of numpy arrays, one entry per kept layer.
        The arrays are independent copies. Modifying them (for example, in-place
        rescaling) does not alter ``model``'s parameters. A plain
        ``Tensor.numpy()`` would share memory with the parameter tensor.
    """
    parameters = list(model.parameters())
    # parameters[::2] -> weight tensors, parameters[1::2] -> bias tensors
    # (given one weight + one bias per layer); the second [::2] keeps every
    # other layer.
    weights = [weight.detach().cpu().numpy().copy() for weight in parameters[::2][::2]]
    biases = [bias.detach().cpu().numpy().copy() for bias in parameters[1::2][::2]]
    return weights, biases

# weights[i] / biases[i] are numpy arrays for the i-th PyTorch layer that is
# mapped onto the spiking network built below.
weights, biases = get_weights_biases(model)

In [None]:
# Largest entry of each extracted weight tensor (before any rescaling).
[w.max() for w in weights]

In [None]:
# Smallest entry of each extracted weight tensor (before any rescaling).
[w.min() for w in weights]

In [None]:
# Rescale selected layers before mapping them onto the spiking network.
# NOTE(review): the rationale for these factors is not recorded here -- presumably
# to keep those layers' activations within the spiking network's range; confirm.
#
# Start from freshly extracted parameters and divide out of place. Then
# (a) re-running this cell does not compound the scaling, and (b) the PyTorch
# model's own parameters are not modified through numpy arrays that may share
# their memory.
weights, biases = get_weights_biases(model)
layer_scales = {1: 1.5, 2: 1.5, -1: 3}  # layer index -> divisor, applied in this order
for layer_idx, divisor in layer_scales.items():
    weights[layer_idx] = weights[layer_idx] / divisor
    biases[layer_idx] = biases[layer_idx] / divisor

## Build the Loihi (quartz) model

In [None]:
# Timing / batching configuration for the spiking simulation.
t_max = 2**6  # NOTE(review): presumably the per-layer encoding window in timesteps -- confirm
input_dims = (1, 28, 28)  # shape of one MNIST image tensor
pool_kernel_size = [2, 2]
# NOTE(review): the factor 6 presumably gives one t_max window per layer of the
# 6-layer network below (input, 2x conv-pool, conv, dense, monitor) -- confirm.
steps_per_image = 6 * t_max
batch_size = 5000  # images per DataLoader batch, i.e. per loihi_model(...) call

# weights[i]/biases[i] feed the i-th weighted layer below. ConvPool2D is used
# instead of the separate Conv2D + MaxPool2D layers of an earlier variant.
snn_layers = [
    layers.InputLayer(dims=input_dims),
    layers.ConvPool2D(weights=weights[0], biases=biases[0], pool_kernel_size=pool_kernel_size),
    layers.ConvPool2D(weights=weights[1], biases=biases[1], pool_kernel_size=pool_kernel_size),
    layers.Conv2D(weights=weights[2], biases=biases[2]),
    # Output layer is not rectified; its output is presumably recorded by MonitorLayer.
    layers.Dense(weights=weights[3], biases=biases[3], rectifying=False),
    layers.MonitorLayer(),
]
loihi_model = quartz.Network(snn_layers)

In [None]:
# Display the network object's repr as a quick structural check.
loihi_model


In [None]:
#loihi_model.check_block_delays(t_max, 2**3)
#loihi_model.print_core_layout(redo=True)

## Load the MNIST test set

In [None]:
# MNIST images are already 28x28, matching `input_dims`, so only tensor
# conversion is needed (an earlier variant resized to 32x32).
transform = transforms.Compose([transforms.ToTensor()])
# download=False: the dataset must already exist under ./CNN/data.
test_set = datasets.MNIST(
    './CNN/data',
    train=False,
    download=False,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [None]:
# One-shot iterator over the test batches; once exhausted it yields nothing,
# so re-run this cell to recreate it before iterating again.
test_loader_iter = iter(test_loader)

In [None]:
print(time.strftime("Started on %a, %d %b %Y %H:%M:%S", time.gmtime()))
start_time = time.time()

# Percentage of misclassified images, one entry per batch.
errors = []
# Iterate the DataLoader itself so every run of this cell is a fresh pass over
# the test set. Looping over a pre-built iter(...) would silently do nothing on a
# re-run once that iterator is exhausted.
for inputs, target in test_loader:
    # Actual number of images in this batch. The last batch is smaller than
    # `batch_size` whenever the dataset size is not a multiple of it.
    n_images = len(target)
    loihi_output = loihi_model(inputs.numpy(), t_max, steps_per_image=steps_per_image, partition='loihi_2h',
                               logging=False, profiling=False)
    # firsts[c, k]: earliest spike time of output neuron c inside image k's time
    # window, or inf if it did not spike there (so argmin never picks it). The
    # prediction is the neuron that spikes first.
    firsts = np.full((n_classes, n_images), np.inf)
    for i, (key, values) in enumerate(sorted(loihi_model.data[1].items())):
        spike_times = np.asarray(values)
        for iteration in range(n_images):
            # Image k occupies the open interval (k*steps_per_image, (k+1)*steps_per_image).
            # NOTE(review): both bounds are exclusive, so a spike exactly on a
            # boundary is ignored -- confirm against quartz's timestep convention.
            in_window = spike_times[(spike_times > iteration * steps_per_image)
                                    & (spike_times < (iteration + 1) * steps_per_image)]
            if in_window.size:
                firsts[i, iteration] = in_window.min()
    loihi_results = np.argmin(firsts, axis=0)
    labels = target.numpy()
    # Plain Python int, so `error` is a plain float and `errors` reprs cleanly
    # when written to results.txt.
    positives = int(np.sum(loihi_results == labels))
    negatives = loihi_results != labels  # mask of misclassified images (kept for inspection)
    error = 100*(1-positives/len(target))
    errors.append(error)
    print("Correctly detected {} out of {}: {}% error".format(positives, len(target), error))
    # break  # uncomment to evaluate only the first batch

print("--- %s seconds ---" % (time.time() - start_time))

In [None]:
# Per-batch error rates (%) from the evaluation loop.
errors

In [None]:
# Append this run's per-batch error list as a single line of results.txt.
with open("results.txt", mode="a") as results_file:
    print(errors, file=results_file)

In [None]:
# Smallest entry of `firsts` for the last evaluated batch (overwritten per batch).
firsts.min()

In [None]:
# Largest value in the object returned by loihi_model(...) for the last batch.
# NOTE(review): its meaning is defined by quartz -- confirm what it holds.
np.max(loihi_output)

In [None]:
# Post-mortem debugging helper left over from development. It is disabled so that
# "Run All" cannot drop into an interactive pdb session after a failure.
# Uncomment to inspect the most recent exception:
# %debug