# Imports

In [1]:
import torch
import matplotlib.pyplot as plt
from torchvision import transforms

from time import time as t
from tqdm import tqdm

from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.network import Network
from bindsnet.learning import PostPre, WeightDependentPostPre
from bindsnet.network.monitors import Monitor, NetworkMonitor
from bindsnet.network.nodes import AdaptiveLIFNodes, Input
from bindsnet.network.topology import LocalConnection, Connection
from bindsnet.analysis.plotting import (
    plot_input,
    plot_spikes,
    plot_conv2d_weights,
    plot_voltages,
)

from IPython.display import clear_output

# Loading and encoding MNIST dataset

In [2]:
# Simulation settings; `time` is reused by the monitor and training cells.
time = 30          # `time` argument of the Poisson encoder (and of network.run later)
dt = 1             # `dt` argument of the Poisson encoder
intensity = 127.5  # multiplier applied to each image tensor before encoding

# Per-image preprocessing: convert to a tensor, then scale by `intensity`.
image_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
)

# Encoder passed as the first positional argument of the dataset wrapper.
image_encoder = PoissonEncoder(time=time, dt=dt)

# MNIST training split read from the local "MNIST" directory; download=False
# means the files must already be present there.
train_dataset = MNIST(
    image_encoder,
    None,  # second positional argument left as None, as in the original call
    "MNIST",
    download=False,
    train=True,
    transform=image_transform,
)

# Building network

In [3]:
# Hyperparameters
# -- locally connected layer geometry
n_filters = 25
kernel_size = 12
stride = 4
padding = 0

# Output positions per filter along one side of a 28-pixel-wide input.
conv_size = (28 - kernel_size + 2 * padding) // stride + 1
# Output neurons divided by 10, rounded down.
per_class = (n_filters * conv_size * conv_size) // 10

# -- neuron constants
tc_trace = 20.0  # TODO: grid search check
tc_decay = 20.0
thresh = -52
refrac = 5

# -- weight bounds for the feed-forward connection
wmin, wmax = 0, 1

# Network
network = Network()
# NOTE(review): this monitor is constructed while `network` still has no layers
# or connections. If NetworkMonitor snapshots the layer/connection names at
# construction time (TODO confirm for the installed BindsNET version), it will
# record nothing. Building it after the layers and connections are added, with
# an explicit `time=`, would make it effective, at a significant memory cost
# for the 'w' recordings.
GlobalMonitor = NetworkMonitor(network, state_vars=('v', 's', 'w'))

# Input layer: 784 units arranged as one 28x28 channel, with traces enabled.
input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

# Output layer: one adaptive-threshold LIF neuron per (filter, row, column).
# The keywords below were previously spelled `thres=` and `trace_tc=`, which do
# not match the `thresh` / `tc_trace` parameter names and were presumably
# swallowed by **kwargs, so the hyperparameter values never reached the layer.
# `refrac` was defined with the hyperparameters but not passed at all.
output_layer = AdaptiveLIFNodes(
    n=n_filters * conv_size * conv_size,
    shape=(n_filters, conv_size, conv_size),
    traces=True,
    thresh=thresh,
    tc_trace=tc_trace,
    refrac=refrac,
    tc_decay=tc_decay,
    theta_plus=0.05,
    tc_theta_decay=1e6)


# Feed-forward, locally connected X -> Y weights trained with the PostPre rule.
connection_XY = LocalConnection(
    input_layer,
    output_layer,
    n_filters=n_filters,
    kernel_size=kernel_size,
    stride=stride,
    update_rule=PostPre,
    norm=0.4 * kernel_size ** 2,  # TODO: norm constant - check
    nu=[1e-4, 1e-2],
    wmin=wmin,
    wmax=wmax)

# competitive connections
# Lateral inhibition: the neuron at (filter a, row i, col j) sends a weight of
# -100.0 to the neuron at the same (i, j) in every *other* filter b != a; all
# remaining entries stay 0.
different_filter = ~torch.eye(n_filters, dtype=torch.bool)
same_position = torch.eye(conv_size, dtype=torch.bool)
# Axes of the mask: (src filter, src row, src col, dst filter, dst row, dst col).
inhibited = (
    different_filter[:, None, None, :, None, None]
    & same_position[None, :, None, None, :, None]
    & same_position[None, None, :, None, None, :]
)
w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size)
w[inhibited] = -100.0

# Hand the weights over as a 2-D (source neurons, target neurons) matrix.
# NOTE(review): the 6-D tensor was previously passed as-is; a dense Connection
# is expected to multiply flattened spikes by a 2-D weight matrix, so the
# flattened form should work regardless of whether the installed BindsNET
# version reshapes `w` internally — confirm.
n_output = n_filters * conv_size * conv_size
w = w.view(n_output, n_output)

connection_competitive = Connection(output_layer, output_layer, w=w)

# Register both layers under the names used by the connections and input dicts.
for layer_name, nodes in (('X', input_layer), ('Y', output_layer)):
    network.add_layer(nodes, name=layer_name)

# Feed-forward X -> Y and recurrent Y -> Y connections.
network.add_connection(connection_XY, source='X', target='Y')
network.add_connection(connection_competitive, source='Y', target='Y')

network.add_monitor(GlobalMonitor, name='Network')

# One spike ("s") monitor per registered layer, keyed by layer name.
spikes = {
    layer: Monitor(network.layers[layer], state_vars=["s"], time=time)
    for layer in set(network.layers)
}
for layer, monitor in spikes.items():
    network.add_monitor(monitor, name="%s_spikes" % layer)

# One voltage ("v") monitor per layer, skipping the input layer "X".
voltages = {
    layer: Monitor(network.layers[layer], state_vars=["v"], time=time)
    for layer in set(network.layers) - {"X"}
}
for layer, monitor in voltages.items():
    network.add_monitor(monitor, name="%s_voltages" % layer)

# Training

In [None]:
n_epochs = 1

# The loader is the same for every epoch, so build it once; with shuffle=True a
# new sample order is still drawn each time it is iterated.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True)

for epoch in range(n_epochs):
    for batch in tqdm(train_dataloader):
        # Swap the first two axes of the encoded batch before passing it in
        # (presumably batch-first -> time-first; confirm against the encoder output).
        inpts = {"X": batch["encoded_image"].transpose(0, 1)}

        network.run(inpts=inpts, time=time, input_time_dim=1)

        # Reset state variables after every sample so one image's state does
        # not carry over into the next. This call previously sat at epoch
        # level and therefore ran only once per pass over the dataset.
        network.reset_()

 10%|█         | 6189/60000 [1:29:39<3:06:56,  4.80it/s]      

In [None]:
# `sample` is not defined by any earlier cell, so this cell raises NameError on
# a fresh kernel. NOTE(review): assumed to be a single dataset item — confirm.
sample = train_dataset[0]
# First index-0 slice along each of the leading three axes of the encoded image.
sample['encoded_image'][0][0][0].numpy()

In [None]:
# Scratch cell: displays the identity of the string literal 'a'. Nothing else
# in the visible notebook uses the result.
id('a')

In [None]:
# Scratch cell: IPython `?` help lookup on the DataLoader from the training
# cell. Interactive aid only; requires that cell to have been run first.
train_dataloader?