In [4]:
import samna
import torch
import pickle
import tqdm
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
from IPython.display import HTML
from torch.utils.data import DataLoader, Dataset, random_split

### Step 1

Load and visualize the data.
- Difference of AER and speck spikes

Use a small data sample to show speck spikes

In [2]:
def events_to_frames(events, time_interval = 10000):
    frames = []
    index = 0
    current_frame = torch.zeros(2, 128, 128)
    for (t, y, x, p) in tqdm.tqdm(events.int()):
        if t // time_interval > index:
            frames.append(current_frame.clone())
            current_frame.fill_(0)
            index += 1
        current_frame[p, x, y] = +1
    return torch.stack(frames)

ship_events = torch.load('ship_tensor.pt')
apple_events = torch.load('apple_tensor.pt')
car_events = torch.load('car_tensor.pt')

ship = events_to_frames(ship_events[:1000000])
apple = events_to_frames(apple_events[:1000000])
car = events_to_frames(car_events[:1000000])


100%|██████████████████████████████| 1000000/1000000 [00:40<00:00, 24543.84it/s]
100%|██████████████████████████████| 1000000/1000000 [00:36<00:00, 27596.55it/s]
100%|██████████████████████████████| 1000000/1000000 [00:34<00:00, 29288.16it/s]


In [5]:
def animate_frames(frames, figure=None, interval: int = 20, **kwargs):
    if figure is None:
        figure, _ = plt.subplots(**kwargs)
    ax = figure.gca()

    image = ax.imshow(frames[0])  # .T)
    ax.set_axis_off()

    def animate(index):
        image.set_data(frames[index])  # .T)
        return image

    anim = FuncAnimation(figure, animate, frames=len(frames), interval=interval)
    video = anim.to_html5_video()
    html = HTML(video)
    display(html)
    plt.tight_layout()
    plt.close()



def events_to_frames(frames, polarity: bool = True):
    if len(frames.shape) == 3:
        frames = frames.unsqueeze(-1).repeat(1, 1, 1, 3)
    else:
        if not polarity:
            frames = frames.abs().sum(-1)
        elif polarity:
            frames = torch.concat([frames, torch.zeros(frames.shape[0], 1, *frames.shape[2:], device=frames.device)], dim=1).movedim(1, -1)
    frames = ((frames / frames.max()) * 255).int().clip(0, 255)
    return frames


#visualize 2 seconds of the events
#change the interval parameter

animate_frames(events_to_frames(car, polarity=True), interval = 10)
# print(len(ship))

### Step 2

##### Split in data points
split the recording in 1sec recs

##### Split in train test
use train-test split (maybe not dataloader - maybe yes if time)

In [34]:
class SpeckDataset(Dataset):
    def __init__(self, frames, targets, transform=None, target_transform=None):
        self.targets = targets
        self.frames = frames
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        frames = self.frames[idx]
        label = self.targets[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return frames, label



c = car[:100*(car.shape[0]//100)].reshape((car.shape[0]//100, 100, *car.shape[1:]))
s = ship[:100*(ship.shape[0]//100)].reshape((ship.shape[0]//100, 100, *ship.shape[1:]))
a = apple[:100*(apple.shape[0]//100)].reshape((apple.shape[0]//100, 100, *apple.shape[1:]))

c_t = torch.zeros(c.shape[0])
s_t = torch.zeros(s.shape[0])+1
a_t = torch.zeros(a.shape[0])+2


data = torch.cat((c, s, a), dim=0)
targets = torch.cat((c_t, s_t, a_t), dim=0)


train_perc = 0.8
samples = data.shape[0]
batch_size = 10


dataset = SpeckDataset(data, targets)
trainset, testset = random_split(dataset, [train_perc, 1-train_perc])

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=True)

### Step 3

Train SNN in dynapCNN - try NIR

In [35]:
import sinabs.layers as sl
from torch import nn
from sinabs.activation.surrogate_gradient_fn import PeriodicExponential

# just replace the ReLU layer with the sl.IAFSqueeze


snn_bptt = nn.Sequential(
    # [2, 128, 128] -> [4, 64, 64]
    nn.Conv2d(in_channels=2, out_channels=4, kernel_size=(5, 5), padding=(2, 2), stride=(2, 2), bias=False),
    sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, surrogate_grad_fn=PeriodicExponential()),
    # [4, 64, 64] -> [8, 32, 32]
    nn.Conv2d(in_channels=4, out_channels=8, kernel_size=(3, 3), padding=(1, 1), bias=False),
    sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, surrogate_grad_fn=PeriodicExponential()),
    nn.AvgPool2d(2, 2),
    # [8 * 32 * 32] -> [16, 16, 16]
    nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2),  bias=False),
    sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, surrogate_grad_fn=PeriodicExponential()),
    # [16, 16, 16] -> [32, 8, 8] 
    nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(5, 5), padding=(2, 2), stride=(2, 2),  bias=False),
    sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, surrogate_grad_fn=PeriodicExponential()),
    # [32, 8, 8] -> [32, 2, 2]
    nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(7, 7), padding=(3, 3), stride=(4, 4),  bias=False),
    sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, surrogate_grad_fn=PeriodicExponential()),

    # [32, 2, 2] -> [10]
    nn.Flatten(),
    nn.Linear(128, 3, bias=False),
    sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, surrogate_grad_fn=PeriodicExponential()),
)

# init the model weights
for layer in snn_bptt.modules():
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_normal_(layer.weight.data)
        
        
        
try:
    from sinabs.exodus import conversion
    snn_bptt = conversion.sinabs_to_exodus(snn_bptt)
except ImportError:
    print("Sinabs-exodus is not intalled.")

snn_bptt

Sinabs-exodus is not intalled.


Sequential(
  (0): Conv2d(2, 4, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
  (1): IAFSqueeze(spike_threshold=Parameter containing:
  tensor(1.), min_v_mem=Parameter containing:
  tensor(-1.), batch_size=10, num_timesteps=-1)
  (2): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (3): IAFSqueeze(spike_threshold=Parameter containing:
  tensor(1.), min_v_mem=Parameter containing:
  tensor(-1.), batch_size=10, num_timesteps=-1)
  (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (5): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (6): IAFSqueeze(spike_threshold=Parameter containing:
  tensor(1.), min_v_mem=Parameter containing:
  tensor(-1.), batch_size=10, num_timesteps=-1)
  (7): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
  (8): IAFSqueeze(spike_threshold=Parameter containing:
  tensor(1.), min_v_mem=Parameter containing:
  tensor(-1.), batch_size=10, num_timesteps=-1)


In [36]:
from torch.optim import SGD
from torch.nn import CrossEntropyLoss


lr = 0.004
optimizer = SGD(params=snn_bptt.parameters(), lr=lr)
criterion = CrossEntropyLoss()
epochs = 10
n_time_steps = 100

for e in range(epochs):

    # train
    for data, label in tqdm.tqdm(trainloader):
        # reshape the input from [Batch, Time, Channel, Height, Width] into [Batch*Time, Channel, Height, Width]
        data = data.reshape(-1, 2, 128, 128)
        # forward
        optimizer.zero_grad()
        output = snn_bptt(data)
        # reshape the output from [Batch*Time,num_classes] into [Batch, Time, num_classes]
        b = label.shape[0]
        output = output.reshape(b, n_time_steps, -1)
        # accumulate all time-steps output for final prediction
        output = output.sum(dim=1)
        print(output)
        print(label)
        loss = criterion(output, label.long())/torch.sigmoid(0.1*output.sum()+1e-4)
        print(loss)
        # backward
        loss.backward()
        optimizer.step()
        # detach the neuron states and activations from current computation graph(necessary)
        for layer in snn_bptt.modules():
            if isinstance(layer, sl.StatefulLayer):
                for name, buffer in layer.named_buffers():
                    buffer.detach_()
        
        
    # validate
    correct_predictions = []
    with torch.no_grad():
        for data, label in tqdm.tqdm(testloader):
            # reshape the input from [Batch, Time, Channel, Height, Width] into [Batch*Time, Channel, Height, Width]
            data = data.reshape(-1, 2, 128, 128)
            # forward
            output = snn_bptt(data)
            # reshape the output from [Batch*Time,num_classes] into [Batch, Time, num_classes]
            b = label.shape[0]
            output = output.reshape(b, n_time_steps, -1)
            # accumulate all time-steps output for final prediction
            output = output.sum(dim=1)
            # calculate accuracy
            pred = output.argmax(dim=1, keepdim=True)
            # compute the total correct predictions
            correct_predictions.append(pred.eq(label.view_as(pred)))
            # set progressing bar
            
    
        correct_predictions = torch.cat(correct_predictions)
        print(f"Epoch {e} - BPTT accuracy: {correct_predictions.sum().item()/(len(correct_predictions))*100}%")

  0%|                                                     | 0/5 [00:00<?, ?it/s]

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], grad_fn=<SumBackward1>)
tensor([0., 1., 0., 2., 0., 1., 2., 1., 1., 1.])
tensor(2.1971, grad_fn=<DivBackward0>)


 20%|█████████                                    | 1/5 [00:20<01:21, 20.27s/it]

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], grad_fn=<SumBackward1>)
tensor([1., 2., 1., 1., 2., 1., 0., 0., 1., 1.])
tensor(2.1971, grad_fn=<DivBackward0>)


 40%|██████████████████                           | 2/5 [00:39<00:59, 19.87s/it]

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], grad_fn=<SumBackward1>)
tensor([2., 0., 2., 2., 2., 0., 0., 2., 1., 1.])
tensor(2.1971, grad_fn=<DivBackward0>)


 60%|███████████████████████████                  | 3/5 [00:54<00:34, 17.38s/it]

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], grad_fn=<SumBackward1>)
tensor([1., 2., 1., 1., 2., 1., 1., 1., 1., 2.])
tensor(2.1971, grad_fn=<DivBackward0>)


 80%|████████████████████████████████████         | 4/5 [01:07<00:15, 15.90s/it]

tensor([[0., 0., 0.]], grad_fn=<SumBackward1>)
tensor([0.])
tensor(2.1971, grad_fn=<DivBackward0>)


100%|█████████████████████████████████████████████| 5/5 [01:08<00:00, 13.66s/it]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.66it/s]


Epoch 0 - BPTT accuracy: 22.22222222222222%


  0%|                                                     | 0/5 [00:00<?, ?it/s]

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], grad_fn=<SumBackward1>)
tensor([0., 2., 1., 1., 1., 0., 1., 0., 0., 2.])
tensor(2.1971, grad_fn=<DivBackward0>)


 20%|█████████                                    | 1/5 [00:10<00:41, 10.32s/it]

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], grad_fn=<SumBackward1>)
tensor([1., 1., 2., 0., 1., 1., 2., 1., 1., 2.])
tensor(2.1971, grad_fn=<DivBackward0>)


 40%|██████████████████                           | 2/5 [00:25<00:38, 12.96s/it]

tensor([[0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 0., 0.],
        [1., 1., 0.]], grad_fn=<SumBackward1>)
tensor([1., 2., 1., 2., 2., 2., 2., 0., 1., 2.])
tensor(1.7905, grad_fn=<DivBackward0>)


 60%|███████████████████████████                  | 3/5 [00:42<00:29, 14.86s/it]

tensor([[1., 1., 0.],
        [1., 2., 0.],
        [1., 1., 1.],
        [0., 1., 1.],
        [0., 1., 1.],
        [1., 1., 0.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 0., 0.],
        [0., 0., 0.]], grad_fn=<SumBackward1>)
tensor([0., 1., 2., 1., 1., 0., 0., 1., 1., 1.])
tensor(1.1002, grad_fn=<DivBackward0>)


 80%|████████████████████████████████████         | 4/5 [00:55<00:14, 14.35s/it]

tensor([[0., 1., 0.]], grad_fn=<SumBackward1>)
tensor([1.])
tensor(1.0504, grad_fn=<DivBackward0>)


100%|█████████████████████████████████████████████| 5/5 [00:56<00:00, 11.24s/it]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.45it/s]


Epoch 1 - BPTT accuracy: 11.11111111111111%


  0%|                                                     | 0/5 [00:00<?, ?it/s]

tensor([[2., 2., 0.],
        [0., 0., 1.],
        [2., 2., 0.],
        [1., 1., 0.],
        [2., 3., 1.],
        [1., 1., 1.],
        [1., 2., 1.],
        [1., 2., 1.],
        [1., 2., 1.],
        [1., 1., 1.]], grad_fn=<SumBackward1>)
tensor([1., 1., 1., 1., 0., 2., 2., 1., 2., 0.])
tensor(1.1528, grad_fn=<DivBackward0>)


 20%|█████████                                    | 1/5 [00:14<00:58, 14.58s/it]

tensor([[0., 1., 0.],
        [2., 1., 0.],
        [1., 1., 1.],
        [0., 1., 1.],
        [1., 0., 1.],
        [1., 1., 1.],
        [2., 2., 1.],
        [1., 0., 0.],
        [1., 1., 1.],
        [1., 1., 0.]], grad_fn=<SumBackward1>)
tensor([0., 1., 2., 1., 1., 2., 0., 2., 2., 1.])
tensor(1.3260, grad_fn=<DivBackward0>)


 40%|██████████████████                           | 2/5 [00:33<00:51, 17.32s/it]

tensor([[1., 1., 1.],
        [0., 1., 0.],
        [1., 1., 1.],
        [2., 2., 0.],
        [1., 2., 0.],
        [1., 1., 0.],
        [2., 2., 0.],
        [2., 2., 1.],
        [2., 1., 1.],
        [1., 1., 1.]], grad_fn=<SumBackward1>)
tensor([1., 2., 1., 1., 2., 1., 0., 0., 2., 1.])
tensor(1.2539, grad_fn=<DivBackward0>)


 60%|███████████████████████████                  | 3/5 [00:49<00:33, 16.76s/it]

tensor([[1., 0., 1.],
        [3., 2., 1.],
        [1., 1., 1.],
        [1., 0., 1.],
        [1., 0., 1.],
        [2., 2., 1.],
        [2., 2., 2.],
        [2., 2., 3.],
        [1., 3., 1.],
        [1., 1., 0.]], grad_fn=<SumBackward1>)
tensor([1., 2., 1., 1., 1., 0., 0., 2., 0., 1.])
tensor(1.4975, grad_fn=<DivBackward0>)


 60%|███████████████████████████                  | 3/5 [01:03<00:42, 21.20s/it]


KeyboardInterrupt: 

### Step 4

Load model in speck chip

In [37]:
from sinabs.backend.dynapcnn import DynapcnnNetwork

dynapcnn = DynapcnnNetwork(snn=snn_bptt, input_shape=(2, 128, 128), discretize=True, dvs_input=True)
devkit_name = "speck2fmodule"

# use the `to` method of DynapcnnNetwork to deploy the SNN to the devkit
dynapcnn.to(device=devkit_name, chip_layers_ordering="auto")
print(f"The SNN is deployed on the core: {dynapcnn.chip_layers_ordering}")


Network is valid
The SNN is deployed on the core: [0, 1, 2, 3, 5, 4]


### Run inference

Run in real time - use trained data

In [38]:
import time

while True:
    output_events = dynapcnn.samna_output_buffer.get_events()
    neuron_index = [each.feature for each in output_events]
    if neuron_index != []:
        print(neuron_index)
    time.sleep(0.1)

[0]
[1]
[0, 1]
[2]
[0]
[2]
[1]
[0]
[1]
[1]
[0]
[1]
[0]
[2]
[1, 0, 2]
[1]
[0]
[1]
[2]
[0]
[1]
[2]
[0]
[1]
[2]
[1]
[0]
[1, 2]
[0]
[1]
[0]
[2]
[1]
[1, 0]
[0, 2]
[2, 1]
[0, 1]
[0]
[2]
[1]
[0]
[1]
[2]
[0]
[1, 2]
[0]
[2, 1]
[0]
[1]
[0]
[1]
[2]
[1]
[0]
[2]
[1, 0]
[2]
[1]
[2]
[0, 1]
[2]
[0]
[1]
[1]
[2]
[1, 0]
[0]
[2]
[1]
[2]
[1]
[0]
[1, 0]
[1]
[0, 2]
[1]
[0]
[2, 1]
[1, 2]
[0]
[1, 0]
[2]
[1]
[0]
[2]
[0]
[1]
[1]
[0, 1]
[2, 1]
[0, 2]
[1]
[2]
[0]
[1]


KeyboardInterrupt: 