In [1]:
import numpy as np

import torch
from torch.utils.data import DataLoader, TensorDataset

from tqdm.autonotebook import trange
import time

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.manual_seed(seed)
torch.use_deterministic_algorithms(True)

import matplotlib.pyplot as plt

  from tqdm.autonotebook import trange


In [2]:
import nir
from rockpool.nn.modules.torch.nir import from_nir, to_nir

nirgraph = nir.read('braille_subtract_subgraph.nir')
net = from_nir(nirgraph)

  torch.tensor(weight) if weight is not None else None,
  super().register_parameter(key, nn.Parameter(torch.tensor(value.data)))
  dt=torch.min(torch.tensor(_to_tensor(lif_node.tau_mem / (1+lif_node.r)))).item(),
  dt=torch.min(torch.tensor(_to_tensor(node.tau_mem / (1+node.r)))).item(),


In [3]:
net

GraphExecutor(
  (fc1): LinearTorch()
  (fc2): LinearTorch()
  (input): Identity()
  (lif1): LIFTorch()
  (lif2): LIFTorch()
  (output): Identity()
)

In [4]:
# Impose common `dt`
dt = 1e-4

net.lif1.dt = dt
net.lif2.dt = dt

In [5]:
from rockpool.devices.xylo.syns61201 import mapper, config_from_specification, XyloSim, XyloSamna
from rockpool.transform.quantize_methods import channel_quantize

spec = mapper(net.as_graph())

Found weights LinearWeights "LinearTorch__6424411920" with 12 input nodes -> 55 output nodes with biases. Set `has_bias = False` for this module .
Found weights LinearWeights "LinearTorch__4615411312" with 55 input nodes -> 7 output nodes with biases. Set `has_bias = False` for this module .


In [6]:
Qspec = spec
Qspec.update(channel_quantize(**Qspec))
Qspec.pop('mapped_graph')

GraphHolder "GraphExecutor_6424757920" with 12 input nodes -> 7 output nodes

In [7]:
net.lif1.tau_mem

Parameter containing:
tensor([0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007,
        0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007,
        0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007,
        0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007,
        0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007,
        0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007, 0.0007,
        0.0007], requires_grad=True)

In [8]:

config, is_valid, msg = config_from_specification(**Qspec)

if not is_valid:
    print(msg)

In [9]:
from rockpool.devices.xylo import find_xylo_hdks

hdks, support_mods, _ = find_xylo_hdks()

if len(hdks) > 0:
    hdk = hdks[0]
    evolve_args = {"read_timeout": 40., "record_power": True}
    XyloModule = lambda config: support_mods[0].XyloSamna(hdk, config, dt = dt)
else:
    hdk = None
    evolve_args = {}
    XyloModule = lambda config: XyloSim.from_config(config, dt = dt)

The connected Xylo HDK contains a Xylo Audio v2 (SYNS61201). Importing `rockpool.devices.xylo.syns61201`


In [10]:
mod = XyloModule(config)
mod



XyloSamna  with shape (12, 55, 7)

In [11]:
### TEST DATA
test_data_path = "data/ds_test.pt"
ds_test = torch.load(test_data_path)
letter_written = ["Space", "A", "E", "I", "O", "U", "Y"]


In [12]:

### RUN TESTS
n_samples = len(ds_test)  # dataset size: 140
predicted_labels = []
actual_labels = []

for i in trange(n_samples):
    single_sample = next(iter(DataLoader(ds_test, batch_size=1, shuffle=True)))
    sample = single_sample[0].numpy()[0].astype(int)  # shape: (256, 12)

    output, _, rec_dict = mod(sample, record = False, **evolve_args)
    n_output_spikes = np.sum(output, axis=0)

    predicted_label = int(np.argmax(n_output_spikes))
    actual_label = int(single_sample[1])
    # print("Predicted Label:", predicted_label)
    # print("Actual Label:   ", actual_label)
    predicted_labels.append(predicted_label)
    actual_labels.append(actual_label)

predicted_labels = np.array(predicted_labels)
actual_labels = np.array(actual_labels)
n_correct = np.count_nonzero(predicted_labels == actual_labels)
print(f"n_correct {n_correct} out of {n_samples} ({n_correct / n_samples * 100.}%)")

100%|██████████| 140/140 [00:22<00:00,  6.36it/s]

n_correct 58 out of 140 (41.42857142857143%)





In [13]:
# Save results
np.save('Xylo_accuracy_subtract.npy', n_correct / n_samples)

test_sample = torch.load("data/ds_test.pt")[0][0]
_, _, rec_dict = mod(sample, record = True, **evolve_args)

np.save('Xylo_activity_subtract.npy', rec_dict['Spikes'])

In [14]:
# Power measurement
clk = support_mods[0].xa2_devkit_utils.set_xylo_core_clock_freq(hdk, 6.25)
print(f'Xylo clock freq: {clk} MHz')

NT = 10
start = time.time()
_, _, rec_dict = mod(np.repeat(sample, NT, axis = 0), record = False, **evolve_args)
inf_time = time.time() - start


Xylo clock freq: 6.25 MHz


In [16]:
print(f"Inference logic power: {np.mean(rec_dict['logic_power'] * 1e6)} uW")
print(f"Inference time: {inf_time}s")
print(f"Energy per sample: {np.mean(rec_dict['logic_power'] * 1e6) * inf_time / NT} uJ")

Inference logic power: 278.62548828125006 uW
Inference time: 1.7315008640289307s
Energy per sample: 48.24402736994672 uJ
