In [6]:
# Boilerplate imports

import numpy as np

import torch
from torch.utils.data import DataLoader

from tqdm.autonotebook import trange
import time

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.manual_seed(seed)
torch.use_deterministic_algorithms(True)

import matplotlib.pyplot as plt

In [7]:
# Load NIR graph and convert to a Rockpool torch module

import nir
from rockpool.nn.modules.torch.nir import from_nir

nirgraph = nir.read('braille_noDelay_noBias_subtract_subgraph.nir')
net = from_nir(nirgraph)
print(net)

GraphExecutor(
  (fc1): LinearTorch()
  (fc2): LinearTorch()
  (input): Identity()
  (lif1): LIFTorch()
  (lif2): LIFTorch()
  (output): Identity()
)


In [8]:
# Impose a common `dt` on the LIF layers
dt = 1e-4
net.lif1.dt = dt
net.lif2.dt = dt

In [9]:
# Rockpool imports for deployment
from rockpool.devices.xylo.syns61201 import mapper, config_from_specification, XyloSim, XyloSamna
from rockpool.transform.quantize_methods import channel_quantize

# - Map network to Xylo architecture
spec = mapper(net.as_graph())

# - Post-training quantization of network
Qspec = spec
Qspec.update(channel_quantize(**Qspec))
Qspec.pop('mapped_graph')

# - Produce and validate a bitstream for Xylo architecture
config, is_valid, msg = config_from_specification(**Qspec)

if not is_valid:
    print(msg)

In [10]:
# - Enumerat and connect to a Xylo HDK
from rockpool.devices.xylo import find_xylo_hdks

hdks, support_mods, _ = find_xylo_hdks()

# - Use a connected XYlo HDK, or use the bit-accurate simulator
if len(hdks) > 0:
    hdk = hdks[0]
    evolve_args = {"read_timeout": 40., "record_power": True}
    XyloModule = lambda config: support_mods[0].XyloSamna(hdk, config, dt = dt)
else:
    hdk = None
    evolve_args = {}
    XyloModule = lambda config: XyloSim.from_config(config, dt = dt)

The connected Xylo HDK contains a Xylo Audio v2 (SYNS61201). Importing `rockpool.devices.xylo.syns61201`


In [11]:
# - Configure Xylo with the network bitstream
mod = XyloModule(config)
print(mod)

XyloSamna  with shape (12, 40, 7)




In [12]:
### Specify test data
test_data_path = "data/ds_test.pt"
ds_test = torch.load(test_data_path)
letter_written = ["Space", "A", "E", "I", "O", "U", "Y"]


In [13]:
### RUN TESTS
n_samples = len(ds_test)  # dataset size: 140
predicted_labels = []
actual_labels = []

for i in trange(n_samples):
    single_sample = next(iter(DataLoader(ds_test, batch_size=1, shuffle=True)))
    sample = single_sample[0].numpy()[0].astype(int)  # shape: (256, 12)

    output, _, rec_dict = mod(sample, record = False, **evolve_args)
    n_output_spikes = np.sum(output, axis=0)

    predicted_label = int(np.argmax(n_output_spikes))
    actual_label = int(single_sample[1])
    # print("Predicted Label:", predicted_label)
    # print("Actual Label:   ", actual_label)
    predicted_labels.append(predicted_label)
    actual_labels.append(actual_label)

predicted_labels = np.array(predicted_labels)
actual_labels = np.array(actual_labels)
n_correct = np.count_nonzero(predicted_labels == actual_labels)
print(f"n_correct {n_correct} out of {n_samples} ({n_correct / n_samples * 100.}%)")

100%|██████████| 140/140 [00:21<00:00,  6.47it/s]

n_correct 120 out of 140 (85.71428571428571%)





In [14]:
# Save accuracy results and internal network activity
np.save('Xylo_accuracy_subtract.npy', n_correct / n_samples)

test_sample = torch.load("data/ds_test.pt")[0][0]
_, _, rec_dict = mod(sample, record = True, **evolve_args)

np.save('Xylo_activity_subtract.npy', rec_dict['Spikes'])

In [15]:
# Perform live power measurement on Xylo HDK
if len(hdks) > 0:
    clk = support_mods[0].xa2_devkit_utils.set_xylo_core_clock_freq(hdk, 6.25)
    print(f'Xylo clock freq: {clk} MHz')

    NT = 100
    start = time.time()
    _, _, rec_dict = mod(np.repeat(sample, NT, axis = 0), record = False, **evolve_args)
    inf_time = time.time() - start
else:
    print(f"No Xylo HDK, skipping power measurement")

Xylo clock freq: 6.25 MHz


In [16]:
if len(hdks) > 0:
    print(f"Inference total logic power: {np.mean(rec_dict['logic_power'] * 1e6):.2f} uW")
    print(f"Inference time: {inf_time:.2f}s; per sample {inf_time / NT:.3f}s")
    print(f"Inference total energy per sample: {np.mean(rec_dict['logic_power'] * 1e6) * inf_time / NT:.2f} uJ")

Inference total logic power: 277.09 uW
Inference time: 8.59s; per sample 0.086s
Inference total energy per sample: 23.80 uJ
