## BindsNET network setup
### Following the Diehl & Cook method (Institute of Neuroinformatics, University of Zurich and ETH Zurich)

```python
%pip install bindsnet
%pip install torch
```

### Libraries

```python
import os
import torch
import numpy as np

from torchvision import transforms

from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.models import DiehlAndCook2015
```

### Parameters

```python
interval_time = 250  # simulation time per sample, in ms
dt = 1.0  # simulation time step, in ms
intensity = 128  # maximum Poisson firing rate of an input neuron (white pixel), in Hz
n_train = 60000  # number of training samples
n_neurons = 100  # number of excitatory (and inhibitory) neurons
```
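As a quick sanity check on these values (a worked example, assuming a pixel's Poisson rate scales linearly with its normalized value, as the `x * intensity` transform used when loading the data implies): each sample spans `interval_time / dt` simulation steps, and the expected spike count of a Poisson source is its rate times the duration in seconds.

```python
interval_time = 250  # ms per sample
dt = 1.0             # ms per simulation step
intensity = 128      # Hz, firing rate of a pixel with normalized value 1.0

# Number of simulation steps per sample
n_steps = int(interval_time / dt)
# Expected spike count of a Poisson source = rate [Hz] * duration [s]
max_expected_spikes = intensity * interval_time / 1000        # white pixel
mid_gray_spikes = 0.5 * intensity * interval_time / 1000      # pixel value 0.5

print(n_steps, max_expected_spikes, mid_gray_spikes)  # 250 32.0 16.0
```

So even the brightest pixel contributes only about 32 spikes per 250 ms presentation, and black pixels (value 0) contribute none.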

### Load MNIST dataset

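```python
# Load MNIST. ToTensor() scales each image to [0, 1], the Lambda multiplies it by
# `intensity` to obtain firing rates in Hz, and PoissonEncoder turns those rates into
# a spike train of shape [time, 1, 28, 28], with time = interval_time / dt steps.
dataset = MNIST(
    image_encoder=PoissonEncoder(time=interval_time, dt=dt),
    label_encoder=None,
    root=os.path.join("..", "data", "MNIST"),
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
    ),
)
```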
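To illustrate what the encoding step does to a single pixel, here is a plain-Python sketch that uses a Bernoulli approximation of a Poisson process: in each step of length `dt` the input spikes with probability `rate * dt / 1000`. This is an assumption for illustration only, not BindsNET's exact sampling procedure, and `poisson_spike_train` is a hypothetical helper.

```python
import random

def poisson_spike_train(rate_hz, time_ms, dt_ms=1.0, rng=None):
    """Hypothetical helper: Bernoulli approximation of a Poisson spike train.

    In each step of length dt_ms the input spikes with probability
    rate_hz * dt_ms / 1000, which is a good approximation while that value is small.
    """
    if rng is None:
        rng = random.Random()
    p = rate_hz * dt_ms / 1000.0
    n_steps = int(time_ms / dt_ms)
    return [1 if rng.random() < p else 0 for _ in range(n_steps)]

pixel = 1.0               # normalized value of a white pixel after ToTensor()
rate = pixel * 128        # Hz, mirrors transforms.Lambda(lambda x: x * intensity)
train = poisson_spike_train(rate, time_ms=250, dt_ms=1.0, rng=random.Random(0))
print(len(train), sum(train))  # 250 steps; about 32 spikes on average
```

Applying the same procedure independently to all 784 pixels yields one `[time, 1, 28, 28]` spike tensor per image.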

### Build the Diehl & Cook network

Based on the article *Unsupervised learning of digit recognition using spike-timing-dependent plasticity* by Peter U. Diehl and Matthew Cook, Institute of Neuroinformatics, ETH Zurich and University of Zurich, Zurich, Switzerland.

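An input image is encoded into spike trains that are sent to all *n* neurons of the excitatory layer. The input-to-excitatory synapses start with random weights and are trained with spike-timing-dependent plasticity (STDP): a synapse is strengthened when the excitatory neuron fires shortly after a spike arrives on that synapse, and weakened when the order is reversed. This is the Hebbian principle often summarized as "cells that fire together, wire together".

Each excitatory neuron is connected one-to-one to a single inhibitory leaky integrate-and-fire (LIF) neuron, which in turn connects to all excitatory neurons except the one it receives input from. When an excitatory neuron fires, its inhibitory partner therefore suppresses all other excitatory neurons. This lateral inhibition makes the excitatory neurons compete, so that different neurons learn different input patterns.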
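A minimal sketch of trace-based STDP on a single synapse, assuming the convention used for `nu` in the network constructor (one learning rate for presynaptic events, one for postsynaptic events): a presynaptic spike depresses the weight in proportion to a decaying trace of recent postsynaptic spikes, and a postsynaptic spike potentiates it in proportion to a trace of recent presynaptic spikes. The trace time constant, weight bounds, learning rates and the `stdp_step` / `run_pair` helpers are illustrative assumptions, not BindsNET's implementation.

```python
import math

def stdp_step(w, pre_spike, post_spike, x_pre, x_post,
              nu=(1e-4, 1e-2), tc_trace=20.0, dt=1.0, wmin=0.0, wmax=1.0):
    """One time step of pair-based STDP on a single synapse (illustrative sketch).

    x_pre / x_post are exponentially decaying traces of recent spikes.
    nu[0] scales depression on presynaptic spikes, nu[1] potentiation on postsynaptic spikes.
    """
    decay = math.exp(-dt / tc_trace)
    x_pre = x_pre * decay + (1.0 if pre_spike else 0.0)
    x_post = x_post * decay + (1.0 if post_spike else 0.0)
    if pre_spike:   # input arrives after recent output spikes: weaken
        w -= nu[0] * x_post
    if post_spike:  # output fires after recent input spikes: strengthen
        w += nu[1] * x_pre
    return min(max(w, wmin), wmax), x_pre, x_post

def run_pair(pre_t, post_t, steps=20, w=0.5):
    """Apply one pre spike and one post spike at the given steps; return the final weight."""
    x_pre = x_post = 0.0
    for t in range(steps):
        w, x_pre, x_post = stdp_step(w, t == pre_t, t == post_t, x_pre, x_post)
    return w

print(run_pair(pre_t=0, post_t=5) > 0.5)  # True: pre before post potentiates
print(run_pair(pre_t=5, post_t=0) < 0.5)  # True: post before pre depresses
```

The closer the two spikes are in time, the larger the trace value at the second spike and hence the larger the weight change.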

![diehlcook](img/diehl_cook2.png)

```python
# Build the Diehl & Cook network
network = DiehlAndCook2015(
    n_inpt=784,  # number of input neurons, one per pixel of a 28x28 image
    n_neurons=n_neurons,  # number of excitatory neurons (and of inhibitory neurons)
    exc=22.5,  # strength of the excitatory -> inhibitory synapses
    inh=17.5,  # strength of the inhibitory -> excitatory synapses
    dt=dt,  # simulation time step, in ms
    nu=[1e-10, 1e-3],  # learning rates for pre- and post-synaptic events, respectively
    inpt_shape=(1, 28, 28),  # shape of one input sample
)
```
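The `exc` and `inh` arguments set the fixed weights of the inhibition circuit. The plain-Python sketch below builds the connectivity pattern described above, assuming a one-to-one excitatory-to-inhibitory mapping with weight `exc` and an all-but-self inhibitory-to-excitatory mapping with weight `-inh`. If BindsNET constructs these matrices this way, the result should resemble what `network.Ae_to_Ai.w` and `network.Ai_to_Ae.w` show below; `inhibition_weights` is a hypothetical helper, not part of BindsNET.

```python
def inhibition_weights(n_neurons, exc=22.5, inh=17.5):
    """Build the two fixed weight matrices of the lateral-inhibition circuit.

    exc_to_inh[i][j]: excitatory neuron i -> inhibitory neuron j (one-to-one).
    inh_to_exc[i][j]: inhibitory neuron i -> excitatory neuron j (all except i).
    """
    exc_to_inh = [[exc if i == j else 0.0 for j in range(n_neurons)]
                  for i in range(n_neurons)]
    inh_to_exc = [[0.0 if i == j else -inh for j in range(n_neurons)]
                  for i in range(n_neurons)]
    return exc_to_inh, inh_to_exc

e2i, i2e = inhibition_weights(4)
for row in i2e:
    print(row)
# [0.0, -17.5, -17.5, -17.5]
# [-17.5, 0.0, -17.5, -17.5]
# [-17.5, -17.5, 0.0, -17.5]
# [-17.5, -17.5, -17.5, 0.0]
```

The zero diagonal of the inhibitory matrix is what keeps a neuron from inhibiting itself through its own inhibitory partner.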

### The network elements

##### Connection between input and excitatory layer

```python
# Architecture
network.X_to_Ae
```

```python
# Weights
print(f"Weights of the connection between input and excitatory layer: {network.X_to_Ae.w}")
print(f"Shape of the weights: {network.X_to_Ae.w.shape}")
```
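`X_to_Ae.w` has one row per input pixel and one column per excitatory neuron, so column `j` holds the 784 input weights of neuron `j`. Reshaping that column to 28×28 gives the neuron's learned receptive field, which is what weight plots of this network typically visualize. With the real tensor this is `network.X_to_Ae.w[:, j].reshape(28, 28)`; the plain-Python sketch below (`receptive_field` is a hypothetical helper) shows the same indexing on a nested list, assuming the row-major flattening of the `(1, 28, 28)` input.

```python
def receptive_field(w, j, side=28):
    """Return column j of a (side*side, n_neurons) weight matrix as a side x side grid.

    w is a nested list: w[pixel][neuron]. Pixel index p maps to row p // side,
    column p % side, i.e. the row-major layout of a flattened image.
    """
    column = [w[p][j] for p in range(side * side)]
    return [column[r * side:(r + 1) * side] for r in range(side)]

# Toy example: 2 neurons over a 3x3 "image"; neuron 1 prefers the centre pixel.
w = [[0.0, 0.0] for _ in range(9)]
w[4][1] = 1.0
print(receptive_field(w, 1, side=3))  # [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
```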

##### The one-to-one connection between the excitatory and inhibitory layer

```python
network.Ae_to_Ai.w
```

##### The connection between the inhibitory and the excitatory layer

```python
network.Ai_to_Ae.w
```

### Running the network

##### Unsupervised

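```python
# Create a dataloader to iterate over and batch the data
batch_size = 100
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

for (i, d) in enumerate(dataloader):
    if i * batch_size >= n_train:
        break
    # Get the next batch of encoded images: [batch, time, 1, 28, 28]
    image = d["encoded_image"]
    # The network expects [time, batch, ...], so swap the first two dimensions
    inputs = {"X": image.transpose(0, 1)}
    # Run the network on the input, then reset the neuron state for the next batch
    network.run(inputs=inputs, time=interval_time)
    network.reset_state_variables()
```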
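Conceptually, each call to `network.run` advances every layer one `dt` at a time: input spikes drive the excitatory potentials, and an excitatory spike recruits its inhibitory partner, which suppresses the other excitatory neurons. The sketch below is a heavily simplified plain-Python illustration of that competition (simple leaky integrate-and-fire neurons with constant input drive, inhibition applied on the step after a spike, no adaptive thresholds and no learning). All constants and the `simulate` helper are assumptions for illustration, not BindsNET's neuron models.

```python
import math

def simulate(currents, steps=100, dt=1.0, tau=20.0, v_thresh=1.0, inh=1.0):
    """Simplified LIF layer with all-but-self lateral inhibition (illustrative only).

    currents[i] is a constant input drive added to neuron i's potential each step.
    A spike on one step lowers every *other* neuron's potential by `inh` on the next step.
    Returns the spike count of each neuron.
    """
    n = len(currents)
    v = [0.0] * n                 # membrane potentials (resting potential 0)
    spiked = [False] * n          # which neurons spiked on the previous step
    counts = [0] * n
    decay = math.exp(-dt / tau)   # leak factor per step
    for _ in range(steps):
        n_spiked = sum(spiked)
        new_spiked = [False] * n
        for i in range(n):
            inhibition = inh * (n_spiked - int(spiked[i]))  # exclude self
            v[i] = v[i] * decay + currents[i] - inhibition
            if v[i] >= v_thresh:
                new_spiked[i] = True
                counts[i] += 1
                v[i] = 0.0        # reset to rest after a spike
        spiked = new_spiked
    return counts

print(simulate([0.12, 0.10, 0.05]))           # [9, 0, 0]: only the most driven neuron fires
print(simulate([0.12, 0.10, 0.05], inh=0.0))  # without inhibition the weaker neurons fire too
```

With inhibition switched on, the neuron that reaches threshold first keeps pushing its competitors back down, which is the winner-take-all behaviour that lets STDP specialize each excitatory neuron on different inputs.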

##### Result

![diehlcookresult](img/diehl_cook_result.png)

##### Supervised

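```python
# Supervised parameters
n_clamp = 1  # number of excitatory neurons forced to spike per sample
n_classes = 10  # number of digit classes
per_class = n_neurons // n_classes  # excitatory neurons assigned to each class
```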
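With `n_neurons = 100` and `n_classes = 10`, each digit class owns a block of `per_class = 10` consecutive excitatory neurons: class `c` owns neurons `c * per_class` to `c * per_class + per_class - 1`. During supervised training, `n_clamp` of those neurons are chosen at random and forced to spike while the sample is presented. A stdlib sketch of that index arithmetic, where `clamp_indices` is a hypothetical helper mirroring the `np.random.choice(..., replace=False)` expression in the training loop:

```python
import random

def clamp_indices(label, per_class=10, n_clamp=1, rng=random):
    """Pick n_clamp distinct excitatory neurons from the block owned by `label`."""
    offsets = rng.sample(range(per_class), n_clamp)  # sampling without replacement
    return [per_class * label + o for o in offsets]

rng = random.Random(42)
idx = clamp_indices(3, per_class=10, n_clamp=2, rng=rng)
print(idx)  # two distinct neurons in the range 30..39
```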

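```python
# The clamp below selects neurons for a single label, so present one sample at a time
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

for (i, d) in enumerate(dataloader):
    if i >= n_train:
        break
    # Get the next input sample: [1, time, 1, 28, 28] -> [time, 1, 1, 28, 28]
    image = d["encoded_image"]
    label = d["label"]
    inputs = {"X": image.transpose(0, 1)}
    # Clamp (force to spike) n_clamp randomly chosen neurons from this label's block
    choice = np.random.choice(per_class, size=n_clamp, replace=False)
    clamp = {"Ae": per_class * label.long() + torch.as_tensor(choice, dtype=torch.long)}
    network.run(inputs=inputs, time=interval_time, clamp=clamp)
    network.reset_state_variables()
```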

#### Reference

For more examples, including monitoring and visualization, see the BindsNET GitHub repository: https://github.com/BindsNET/bindsnet