In [1]:
# - Numpy
import numpy as np
import torch
from rockpool.nn.modules import LinearTorch, LIFTorch
from rockpool.parameters import Constant
from rockpool.nn.combinators import Sequential
# - Matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['figure.figsize'] = [12, 4]
plt.rcParams['figure.dpi'] = 300

# - Rockpool time-series handling
from rockpool import TSEvent, TSContinuous

# - Pretty printing
try:
    from rich import print
except:
    pass

# - Display images
from IPython.display import Image

# - Disable warnings
import warnings
warnings.filterwarnings('ignore')
from rockpool.nn.networks.wavesense import WaveSenseNet
from rockpool.transform import quantize_methods as q
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from pathlib import Path
from tqdm.asyncio import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Dataset(Dataset):
    def __init__(self,root_pos,root_neg):
        self.sample = []
        pos_dir = Path(root_pos)
        neg_dir = Path(root_neg)
        for i in sorted(pos_dir.rglob('*.npy')):
            if (str(i.parts[-1][9:16]) != 'trail1_'):
                array = np.load(str(i),allow_pickle=True)
                tensor = torch.from_numpy(array.T)
                tensor = torch.tensor(tensor,dtype=torch.float)
                condititon = [tensor,torch.tensor(1)]
                self.sample.append(condititon)
                condititon = []
        for i in sorted(neg_dir.rglob('*.npy')):
            if (str(i.parts[-1][9:16]) != 'trail1_'):
                array = np.load(str(i))
                tensor = torch.from_numpy(array.T)
                tensor = torch.tensor(tensor,dtype=torch.float)
                condititon = [tensor,torch.tensor(0)]
                self.sample.append(condititon)
                condititon = []
                
            
    def __getitem__(self,idx):
        data = self.sample[idx][0]
        label = self.sample[idx][1]
        return data,label
    
    def __len__(self):
        return len(self.sample)
    
# - Location of the spike-encoded test set; edit DATA_ROOT to run elsewhere.
DATA_ROOT = Path('/home/ruixing/workspace/chbtar/chb/data/test_data/spike')
# - Seed for the shuffled evaluation order (reproducible runs)
SEED = 0

dataset_test = Dataset(str(DATA_ROOT / 'pos'), str(DATA_ROOT / 'neg'))

# - A missing or mistyped directory makes `rglob` yield nothing; a shuffled
#   DataLoader over an empty dataset then fails with an unrelated
#   "num_samples" error, so fail early with a clear message instead.
if len(dataset_test) == 0:
    raise FileNotFoundError(f'No test samples found under {DATA_ROOT}')

spiking_test_dataloader = DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [3]:
# - WaveSense architecture and neuron hyper-parameters; these must match the
#   values the checkpoint loaded below was trained with.
dilations = [2, 32]
n_out_neurons, n_inp_neurons, n_neurons = 2, 4, 32
kernel_size = 2
tau_mem, base_tau_syn, tau_lp = 0.002, 0.002, 0.01
threshold = 0.6
dt = 0.001

# NOTE(review): residual and skip widths are both `n_neurons` (32 + 32 = 64);
#   the Xylo config check reported fanouts of 64 against a limit of 63, which
#   presumably stems from these sizes — confirm (changing them requires retraining).
wavesense_kwargs = dict(
    dilations=dilations,
    n_classes=n_out_neurons,
    n_channels_in=n_inp_neurons,
    n_channels_res=n_neurons,
    n_channels_skip=n_neurons,
    n_hidden=n_neurons,
    kernel_size=kernel_size,
    bias=Constant(0.0),
    smooth_output=True,
    tau_mem=Constant(tau_mem),
    base_tau_syn=base_tau_syn,
    tau_lp=tau_lp,
    threshold=Constant(threshold),
    neuron_model=LIFTorch,
    dt=dt,
)
net = WaveSenseNet(**wavesense_kwargs)

# - Restore the trained parameters
MODEL_PATH = '/home/ruixing/workspace/chbtar/chb/models/SNN_model_Isyn.pth'
net.load(MODEL_PATH)

In [4]:
# - Import the Xylo HDK detection function
from rockpool.devices.xylo import find_xylo_hdks

# - Detect connected HDKs; the returned lists are index-aligned, so entry i of
#   `support_modules` belongs to entry i of `connected_hdks`.
connected_hdks, support_modules, chip_versions = find_xylo_hdks()

found_xylo = len(connected_hdks) > 0

# - Abort with an explicit exception: unlike `assert`, a `raise` is not
#   stripped when Python runs with optimisations (`-O`).
if not found_xylo:
    raise RuntimeError('This tutorial requires a connected Xylo HDK to run.')

# - Use the first HDK found together with its support package
hdk = connected_hdks[0]
x = support_modules[0]


The connected Xylo HDK contains a Xylo Audio v2 (SYNS61201). Importing `rockpool.devices.xylo.syns61201`


In [5]:
# - Map the network graph onto Xylo hardware resources, keeping float weights
graph = net.as_graph()
spec = x.mapper(graph, weight_dtype='float')

# - Quantise the float specification and merge the quantised fields back in
quantized_fields = q.global_quantize(**spec)
spec.update(quantized_fields)

In [6]:
# - Use rockpool.devices.xylo.config_from_specification to build the HW config
config, is_valid, msg = x.config_from_specification(**spec)

# - Report an invalid configuration here, with the validator's reasons,
#   instead of ignoring `is_valid`/`msg` and failing later inside deployment.
#   NOTE(review): the recorded failure ("Reservoir neuron N fanout must be in
#   [0,63]. Actual: 64") needs network/spec changes, not a retry.
if not is_valid:
    raise ValueError(f'Invalid configuration for the Xylo HDK: {msg}')

# - Use rockpool.devices.xylo.XyloSamna to deploy to the HDK found earlier
modSamna = x.XyloSamna(hdk, config, dt=dt)
print(modSamna)

ValueError: Invalid configuration for the Xylo HDK: Reservoir neuron 68 fanout must be in [0,63]. Actual: 64
Reservoir neuron 76 fanout must be in [0,63]. Actual: 64
Reservoir neuron 79 fanout must be in [0,63]. Actual: 64
Reservoir neuron 86 fanout must be in [0,63]. Actual: 64
Reservoir neuron 91 fanout must be in [0,63]. Actual: 64
Reservoir neuron 94 fanout must be in [0,63]. Actual: 64
Reservoir neuron 160 fanout must be in [0,63]. Actual: 64
Reservoir neuron 162 fanout must be in [0,63]. Actual: 64
Reservoir neuron 168 fanout must be in [0,63]. Actual: 64
Reservoir neuron 174 fanout must be in [0,63]. Actual: 64
Reservoir neuron 179 fanout must be in [0,63]. Actual: 64
Reservoir neuron 182 fanout must be in [0,63]. Actual: 64
Reservoir neuron 183 fanout must be in [0,63]. Actual: 64
Reservoir neuron 184 fanout must be in [0,63]. Actual: 64
Reservoir neuron 189 fanout must be in [0,63]. Actual: 64


In [7]:
# - Input encoding / readout settings
#   NOTE(review): values are truncated to int *before* scaling, so fractional
#   inputs in (0, 1) become 0 — confirm the stored spike arrays are integer counts.
INPUT_GAIN = 20         # multiplier applied to each input value
MAX_INPUT_EVENTS = 15   # cap after scaling (values are clipped to [0, 15])
READ_TIMEOUT_S = 20     # seconds to wait for the HDK to return results
VERBOSE = False         # set True to print per-sample peaks / prediction / label

n_0 = 0
n_1 = 0
consequence_list = []
for data, label in tqdm(spiking_test_dataloader, colour='yellow'):
    # - Reset the deployed HDK module between samples: it is the module evolved
    #   below; resetting the simulated `net` had no effect on this loop.
    #   NOTE(review): confirm `XyloSamna.reset_state` clears on-chip neuron state.
    modSamna.reset_state()

    # - Flatten the batch dimension (batch_size=1) to (time, n_inp_neurons)
    #   instead of hard-coding 500 time steps.
    sample = data.reshape(-1, n_inp_neurons).numpy()
    sample = (sample.astype(int) * INPUT_GAIN).clip(0, MAX_INPUT_EVENTS)

    output, state, recordings = modSamna(sample, record=True, read_timeout=READ_TIMEOUT_S)

    # - Prediction = output channel with the highest peak membrane potential
    out = recordings['Vmem_out'].squeeze()
    peaks = out.max(0)
    result = peaks.argmax()
    if VERBOSE:
        print('peaks:', peaks)
        print('result:', result)
        print('label:', label)

    if result.item() == 0:
        n_0 += 1
    if result.item() == 1:
        n_1 += 1
    consequence_list.append(result == label.item())

# - Guard against dividing by zero when nothing was evaluated
if not consequence_list:
    raise RuntimeError('No samples were evaluated; check the test dataloader.')
acc = sum(consequence_list) / len(consequence_list)
print(f'accuracy:{acc}')
print(f'number of zero:{n_0}，number of one:{n_1}')

  0%|          | 0/395 [00:20<?, ?it/s]


TimeoutError: Processing didn't finish for 20s. Read 0 events