## BindsNet Poisson Encoder

In [None]:
# Install the notebook's dependencies into the kernel's environment.
# NOTE(review): versions are not pinned — consider pinning them for reproducibility.
%pip install bindsnet
%pip install torch

### Libraries

In [None]:
import os
import copy
import numpy as np
import torch
import pandas as pd

from matplotlib import pyplot as plt

from bindsnet.encoding import PoissonEncoder
from bindsnet.datasets import MNIST
from torchvision import transforms

# Default canvas size (width, height in inches) for every figure in this notebook.
plt.rcParams["figure.figsize"] = (20, 10)

### Parameters

In [None]:
# Length of the encoding window for one image, in ms.
interval_time = 250
# Length of one interval (time step) inside that window, in ms.
dt = 1.0
# Factor the pixel values are multiplied by before encoding: the maximum
# firing rate of the input-layer Poisson spikes, in Hz.
intensity = 128

### Load MNIST dataset

In [None]:
# Build the MNIST dataset. Images are converted to tensors, multiplied by
# `intensity`, and passed to a Poisson encoder configured with the encoding
# window (`interval_time`) and step (`dt`) defined above. Labels are not encoded.
spike_encoder = PoissonEncoder(time=interval_time, dt=dt)
scale_pixels = transforms.Lambda(lambda x: x * intensity)
image_pipeline = transforms.Compose([transforms.ToTensor(), scale_pixels])
mnist_root = os.path.join("..", "data", "MNIST")

dataset = MNIST(
    image_encoder=spike_encoder,
    label_encoder=None,
    root=mnist_root,
    download=True,
    transform=image_pipeline,
)

### Examples

In [None]:
# Every dataset item is a dict; list the fields it exposes.
sample_keys = dataset[0].keys()
print(f"This dataset is dict with the following keys: {sample_keys}")

In [None]:
# Look the sample up once: each `dataset[...]` access builds the item again
# (presumably including a fresh spike encoding — see BindsNET's MNIST wrapper),
# so reading both "image" and "label" from a single lookup avoids redundant work.
first_sample = dataset[0]
first_image = np.array(first_sample["image"], dtype='float')
print(f"The first element is the digit: {first_sample['label']}")
# Arrange the pixel values as a 28x28 grid for display.
pixels = first_image.reshape((28, 28))
plt.imshow(pixels, cmap='gray')
plt.show()

##### In pixel grey levels

In [None]:
# Grey levels of the first image as stored in the dataset item
# (after the transform, i.e. pixel values multiplied by `intensity`).
dataset[0]["image"]

In [None]:
# Dimensions of the image tensor.
dataset[0]["image"].shape

In [None]:
# Inspect a single image row as an example: row index 9 of the first (only
# indexed) channel, i.e. the 10th row when counting from 1 (indices are zero-based).
dataset[0]["image"][0][9]

##### Encode in Spikes

In [None]:
# Keep a copy of the spike-encoded version of the first image so that the
# following cells all inspect the same tensor.
encoded_image_1 = copy.deepcopy(dataset[0]["encoded_image"])

In [None]:
# Dimensions of the spike tensor. The cells below index it as
# [time step][0][row] and use shape[2] as the number of image rows.
encoded_image_1.shape

In [None]:
# Example of a spike tensor slice: time-step index 2, row index 8
# (zero-based, i.e. the 3rd time step and the 9th image row).
# NOTE(review): the grey-level cell above inspects row index 9, not 8 — confirm
# which row is intended if the two outputs are meant to be compared.
encoded_image_1[2][0][8]

This tensor is created by the BindsNET Poisson encoder. For every pixel, the encoder generates spike times over the encoding window, drawn from a Poisson distribution whose rate is based on the intensity of that pixel. The window is divided into intervals of length $\delta$; whenever a spike is generated within an interval, the value for that interval in the output tensor becomes 1.

In [None]:
# Count the spikes generated for every image row over the whole encoding window.
# The tensor is indexed as [time step][channel][row][column]: selecting channel 0
# and summing over the time and column axes yields one total per row in a single
# vectorized reduction, instead of one Python-level `torch.sum` call per
# (row, time step) pair. Reducing over the tensor's own time axis also keeps the
# count correct for any value of `dt`.
spikes_per_row = encoded_image_1[:, 0].sum(dim=(0, 2))

rows = list(range(encoded_image_1.shape[2]))
spikes = [int(count) for count in spikes_per_row.tolist()]  # plain Python ints for pandas
for r, sum_row in zip(rows, spikes):
    # Report the plain integer rather than a `tensor(...)` repr.
    print(f"The number of spike on row {r} are: {sum_row}")

In [None]:
# Bar chart of the per-row spike totals (`rows` / `spikes`) computed in the previous cell.
row_spikes = pd.DataFrame({"image_row": rows, "number_spikes": spikes})
fig, ax = plt.subplots()
ax.bar("image_row", "number_spikes", data=row_spikes)
ax.set_xlabel("Image row")
ax.set_ylabel("Number of spikes")
ax.set_title("Number of generated spikes per image row")
plt.show()

##### Example 2

In [None]:
# Look the sample up once: each `dataset[...]` access builds the item again
# (presumably including a fresh spike encoding — see BindsNET's MNIST wrapper),
# so reading both "image" and "label" from a single lookup avoids redundant work.
second_sample = dataset[1]
second_image = np.array(second_sample["image"], dtype='float')
print(f"The second element is the digit: {second_sample['label']}")
# Arrange the pixel values as a 28x28 grid for display.
pixels = second_image.reshape((28, 28))
plt.imshow(pixels, cmap='gray')
plt.show()

In [None]:
# Spike-encode the second image and count the spikes generated for every row.
encoded_image_2 = copy.deepcopy(dataset[1]["encoded_image"])

# The tensor is indexed as [time step][channel][row][column]. The previous
# version looped over `range(interval_time)`, which only matches the number of
# time steps when dt == 1; reducing over the tensor's own time axis is correct
# for any `dt`. Channel 0 is selected and the time and column axes are summed,
# giving one total per row in a single vectorized call.
spikes_per_row = encoded_image_2[:, 0].sum(dim=(0, 2))

rows = list(range(encoded_image_2.shape[2]))
spikes = [int(count) for count in spikes_per_row.tolist()]  # plain Python ints for pandas
for r, sum_row in zip(rows, spikes):
    # Report the plain integer rather than a `tensor(...)` repr.
    print(f"The number of spike on row {r} are: {sum_row}")

In [None]:
# Bar chart of the per-row spike totals (`rows` / `spikes`) for the second image.
row_spikes = pd.DataFrame({"image_row": rows, "number_spikes": spikes})
fig, ax = plt.subplots()
ax.bar("image_row", "number_spikes", data=row_spikes)
ax.set(
    xlabel="Image row",
    ylabel="Number of spikes",
    title="Number of generated spikes per image row",
)
plt.show()