## Load the trained CNN (LeNet-5)

In [None]:
%matplotlib inline
from CNN.lenet import LeNet5
import torch
from torchvision import datasets, transforms
import quartz
from quartz import layers
import numpy as np

In [None]:
# Restore the trained LeNet-5 weights and switch the network to inference mode.
# map_location forces every tensor onto the CPU, so a GPU-saved checkpoint still loads here.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load("CNN/lenet.pth", map_location=torch.device('cpu'))
model = LeNet5(10)  # presumably 10 = number of output classes; TODO confirm in CNN/lenet.py
model.load_state_dict(state_dict)
model.eval()

In [None]:
def get_weights_biases(model):
    """Split a model's parameters into separate weight and bias arrays.

    The parameters yielded by ``model.parameters()`` are assumed to alternate
    (weight, bias, weight, bias, ...). Even positions are collected as weights
    and odd positions as biases.

    Args:
        model: a ``torch.nn.Module`` whose parameters follow that ordering.

    Returns:
        A ``(weights, biases)`` tuple of two lists of numpy arrays, each
        detached from the autograd graph.
    """
    weights, biases = [], []
    for position, tensor in enumerate(model.parameters()):
        array = tensor.detach().numpy()
        if position % 2 == 0:
            weights.append(array)
        else:
            biases.append(array)
    return weights, biases

# Per-layer numpy weights and biases of the trained network, in parameter order.
weights, biases = get_weights_biases(model=model)

## Build the Loihi (quartz) model

In [None]:
# Conversion / simulation settings.
t_max = 2**9               # passed as t_max to the spike encoder and to run_on_loihi
run_time = 40*t_max        # total run length handed to run_on_loihi
input_dims = (1,32,32,2)   # NOTE(review): meaning of the trailing 2 is not visible here; confirm against quartz.layers.InputLayer
weight_e = 500             # NOTE(review): magic constant forwarded to every layer; document its meaning
weight_acc = 128           # NOTE(review): magic constant forwarded to every layer; document its meaning
weight_args = {'weight_e':weight_e, 'weight_acc':weight_acc}  # shared by every layer below
pool_kernel_size = [2,2]

# Bind every layer to a name so later cells can reach individual layers:
# `l8.run_on_loihi(...)`, `l0.n_all_connections()` and `l1.n_all_connections()`
# further down would otherwise raise NameError on a fresh kernel.
# NOTE(review): assumes quartz.Network keeps these layer objects (not copies); confirm.
l0 = layers.InputLayer(dims=input_dims, **weight_args)
l1 = layers.Conv2D(weights=weights[0], biases=biases[0], **weight_args)
l2 = layers.MaxPool2D(kernel_size=pool_kernel_size, **weight_args)
l3 = layers.Conv2D(weights=weights[1], biases=biases[1], **weight_args)
l4 = layers.MaxPool2D(kernel_size=pool_kernel_size, **weight_args)
l5 = layers.Conv2D(weights=weights[2], biases=biases[2], **weight_args)
l6 = layers.FullyConnected(weights=weights[3], biases=biases[3], **weight_args)
l7 = layers.FullyConnected(weights=weights[4], biases=biases[4], split_output=False, **weight_args)
l8 = layers.MonitorLayer(**weight_args)

loihi_model = quartz.Network([l0, l1, l2, l3, l4, l5, l6, l7, l8])

In [None]:
# Display the network's repr as a quick sanity check of the layer stack.
loihi_model

In [None]:
# A leftover `%debug` magic was removed from this cell. During Restart & Run All it either
# finds no traceback or reopens a stale one and blocks on an interactive pdb prompt.
# To inspect a failure post-mortem, run `%debug` by hand right after the failing cell.

## Load data

In [None]:
# MNIST test split, resized to 32x32 (matching input_dims) and converted to tensors.
# download=False: the dataset must already exist under ./CNN/data.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
test_set = datasets.MNIST(
    './CNN/data',
    train=False,
    download=False,
    transform=transform,
)
# batch_size=1 is the DataLoader default; it is spelled out because the next cell consumes one image.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1)

In [None]:
# Take the first test image (a batch of one) and turn its pixel values into quartz spike input.
first_batch = next(iter(test_loader))
sample, target = first_batch
pixel_values = sample.squeeze().detach().numpy().flatten()
# NOTE(review): presumably encodes each value as a spike time within t_max; confirm in quartz.utils.
inputs = quartz.utils.decode_values_into_spike_input(pixel_values, t_max)

In [None]:
# Run the encoded sample on Loihi hardware.
# Requires `l8` to be bound to the network's output layer before this cell runs.
output_values, spike_times = l8.run_on_loihi(
    run_time,
    t_max=t_max,
    input_spike_list=inputs,
    partition="nahuku32",
    num_chips=3,
    plot=False,
)

In [None]:
# Connection count reported by layer `l0` (must be bound before this cell runs).
# NOTE(review): presumably the number of synaptic connections involving this layer; confirm in quartz.
l0.n_all_connections()

In [None]:
# Connection count reported by layer `l1` (must be bound before this cell runs).
# NOTE(review): presumably the number of synaptic connections involving this layer; confirm in quartz.
l1.n_all_connections()