# Imports

In [2]:
import datetime
from time import time as t

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from IPython.display import clear_output
from torchvision import transforms
from tqdm import tqdm

from bindsnet.analysis.plotting import (
    plot_input,
    plot_spikes,
    plot_conv2d_weights,
    plot_voltages,
)
from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.learning import PostPre, WeightDependentPostPre
from bindsnet.network import Network, load
from bindsnet.network.monitors import Monitor, NetworkMonitor
from bindsnet.network.nodes import AdaptiveLIFNodes, Input
from bindsnet.network.topology import LocalConnection, Connection

from LC_SNN import LC_SNN

%matplotlib inline

# Loading and encoding MNIST dataset

In [None]:
# Encoding settings: `time_max` and `dt` are handed to PoissonEncoder below;
# `intensity` multiplies the tensor produced by ToTensor().
time_max = 30
dt = 1
intensity = 127.5

# Training split of MNIST (train=True) constructed with a PoissonEncoder as its
# first argument. download=False, so this call does not request a download.
# NOTE(review): the positional `None` and "MNIST" are presumably the label
# encoder and the dataset root directory - confirm against bindsnet.datasets.
train_dataset = MNIST(
    PoissonEncoder(time=time_max, dt=dt),
    None,
    "MNIST",
    download=False,
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
    )
)

# Building network

In [None]:
# Hyperparameters
n_filters = 25
kernel_size = 12
stride = 4
padding = 0
# Side length of the output feature map for a 28x28 input image.
conv_size = int((28 - kernel_size + 2 * padding) / stride) + 1
per_class = int((n_filters * conv_size * conv_size) / 10)
tc_trace = 20.  # grid search check
tc_decay = 20.
thresh = -52
refrac = 5

wmin = 0
wmax = 1

# Network
network = Network(learning=True)
# NOTE(review): this monitor is constructed while the network is still empty.
# Confirm that NetworkMonitor discovers layers/connections lazily; if it
# snapshots them at construction time it should be built after add_connection.
GlobalMonitor = NetworkMonitor(network, state_vars=('v', 's', 'w'))


input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

# `thresh` was previously passed under the misspelled keyword `thres`, and
# `refrac` was defined above but never forwarded; both are now passed explicitly
# so that editing the hyperparameters actually changes the layer.
# NOTE(review): `trace_tc` is kept as written - confirm the keyword name
# against the installed bindsnet version (some releases call it `tc_trace`).
output_layer = AdaptiveLIFNodes(
    n=n_filters * conv_size * conv_size,
    shape=(n_filters, conv_size, conv_size),
    traces=True,
    thresh=thresh,
    refrac=refrac,
    trace_tc=tc_trace,
    tc_decay=tc_decay,
    theta_plus=0.05,
    tc_theta_decay=1e6)


connection_XY = LocalConnection(
    input_layer,
    output_layer,
    n_filters=n_filters,
    kernel_size=kernel_size,
    stride=stride,
    update_rule=PostPre,
    norm=1/2,  # norm constant - check (alternatives tried: 1/(kernel_size ** 2), 0.4 * kernel_size ** 2)
    nu=[1e-4, 1e-2],
    wmin=wmin,
    wmax=wmax)

# Competitive connections: every filter inhibits (-100) every *other* filter at
# the same (i, j) location; all remaining entries stay 0.
inhibition = torch.zeros(n_filters, n_filters)
inhibition[torch.eye(n_filters) == 0] = -100.0  # off-diagonal pairs only
w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size)
for i in range(conv_size):
    for j in range(conv_size):
        # One (n_filters x n_filters) slice per spatial location instead of
        # writing each scalar individually.
        w[:, i, j, :, i, j] = inhibition

connection_YY = Connection(output_layer, output_layer, w=w)

network.add_layer(input_layer, name='X')
network.add_layer(output_layer, name='Y')

network.add_connection(connection_XY, source='X', target='Y')
network.add_connection(connection_YY, source='Y', target='Y')

network.add_monitor(GlobalMonitor, name='Network')

# Spike monitors for every layer, registered as "<layer>_spikes".
# Iterate the dict directly so the registration order is deterministic.
spikes = {}
for layer in network.layers:
    spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time_max)
    network.add_monitor(spikes[layer], name="%s_spikes" % layer)
print('GlobalMonitor.state_vars:', GlobalMonitor.state_vars)

# Voltage monitors for every layer except the input layer "X",
# registered as "<layer>_voltages".
voltages = {}
for layer in set(network.layers) - {"X"}:
    voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time_max)
    network.add_monitor(voltages[layer], name="%s_voltages" % layer)

# Training

In [None]:
visualize = False
n_train = 1  # number of passes over the training set
for epoch in range(n_train):
    # A fresh loader per epoch so every pass sees a new shuffle.
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True)

    for batch in tqdm(train_dataloader):
        # Swap the first two axes of the encoded sample before feeding layer "X".
        # (The earlier un-transposed assignment was dead code and is removed.)
        inpts = {"X": batch["encoded_image"].transpose(0, 1)}

        network.run(inpts=inpts, time=time_max, input_time_dim=1)
        # Reset state variables after *every* sample so that one image's
        # state does not carry over into the next; previously this only
        # ran once per epoch.
        network.reset_()

# Saving network

In [None]:
# Save under a timestamped name such as network_2019-10-01_12-30-45.
# strftime replaces the old "slice 7 chars off str(datetime)" trick, which
# truncated the name whenever the microsecond field happened to be 0
# (str() then omits the fractional part entirely).
network.save(f"network_{datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S')}")

# Locking network

In [None]:
# Switch the network out of training mode; the trailing `;` suppresses the
# cell's output repr.
network.train(False);

# Loading network

In [None]:
# Rebinds `network` to an object loaded from the file 'default' with learning off.
# NOTE(review): the save cell writes a timestamped file name, not 'default' -
# confirm which checkpoint is meant to be loaded here.
# NOTE(review): names bound in the build cell (connection_XY, connection_YY,
# spikes, voltages, GlobalMonitor) still refer to the previous network object
# after this assignment.
network = load('default', learning=False)

# Testing network

In [None]:
# Run a handful of samples through the network and plot, for each one:
# the input, the spikes, the X->Y and Y->Y weights, and the Y voltages.
n_test_samples = 5
grid = 25  # the 625 output neurons are drawn as a grid x grid mosaic of 28x28 tiles
cmap = plt.cm.jet

# NOTE(review): samples are drawn from `train_dataset`, so this is not a
# held-out evaluation.
test_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True)

# zip with range() stops after n_test_samples without loading an extra batch.
for _, batch in tqdm(zip(range(n_test_samples), test_dataloader), total=n_test_samples):
    # Processing
    inpts = {"X": batch["encoded_image"].transpose(0, 1)}
    label = batch["label"]
    network.run(inpts=inpts, time=time_max, input_time_dim=1)

    # Visualization
    image = batch["image"].view(28, 28)
    inpt = inpts["X"].view(time_max, 784).sum(0).view(28, 28)

    # Read weights and recordings from the *current* `network` object rather
    # than from globals created in the build cell: after the "Loading network"
    # cell rebinds `network`, those globals belong to the old object and would
    # show stale data.
    weights_XY = network.connections[("X", "Y")].w
    weights_YY = network.connections[("Y", "Y")].w

    _spikes = {
        name: network.monitors["%s_spikes" % name].get("s").view(time_max, -1)
        for name in ("X", "Y")
    }
    _voltages = {"Y": network.monitors["Y_voltages"].get("v").view(time_max, -1)}

    plot_input(image, inpt, label=label)
    # Was `plot_spikes_display`, which is not defined or imported anywhere.
    plot_spikes(_spikes)

    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 10))
    # Tile the per-neuron 28x28 weight maps into one image:
    # result[row*28 + r, col*28 + c] = weights_XY[r*28 + c, row*grid + col].
    weights_to_display = (
        weights_XY.reshape(28, 28, grid, grid)
        .permute(2, 0, 3, 1)
        .reshape(28 * grid, 28 * grid)
    )
    im1 = ax1.imshow(weights_to_display.numpy(), cmap=cmap)
    im2 = ax2.imshow(weights_YY.reshape(grid * grid, grid * grid).numpy(), cmap=cmap)
    f.colorbar(im1, ax=ax1)
    f.colorbar(im2, ax=ax2)
    ax1.set_title('XY weights')
    ax2.set_title('YY weights')
    plt.show()

    plot_voltages(_voltages)

    # Reset state variables after each sample so samples do not influence
    # one another (previously this ran only once, after the loop).
    network.reset_()
    

# Gridsearch for best parameters

In [None]:
def gridsearch(n_norm=2, n_comp=2, n_iter=100):
    """Grid-search `norm` and `competitive_weight` for LC_SNN.

    For every pair on an `n_norm` x `n_comp` grid (norm in [0.1, 0.4],
    competitive weight in [-70, -13]) a fresh LC_SNN is trained for `n_iter`
    iterations, saved under ``gridsearch//`` and scored with
    ``net.accuracy(n_iter)``. The full accuracy table is saved with
    ``torch.save`` at the end.

    Parameters
    ----------
    n_norm : int
        Number of `norm` values to try.
    n_comp : int
        Number of `competitive_weight` values to try.
    n_iter : int
        Iterations passed to both ``net.train`` and ``net.accuracy``.

    Returns
    -------
    tuple
        ``(max_acc, best_norm, best_comp_weight)`` for the best-scoring pair.
        If no accuracy exceeds 0 the initial values ``(0, 0, 0)`` are returned.

    Notes
    -----
    The ``gridsearch`` directory is not created here; it must already exist.
    """
    norms = np.linspace(0.1, 0.4, n_norm)
    comp_weights = np.linspace(-70, -13, n_comp)
    accs = torch.zeros(n_norm, n_comp)
    max_acc = 0
    best_norm = 0
    best_comp_weight = 0
    for i, norm in enumerate(norms):
        for j, competitive_weight in enumerate(comp_weights):
            clear_output(wait=True)
            print(f'Current parameters: norm={norm}, comp_weight={competitive_weight}')
            net = LC_SNN(norm=norm, competitive_weight=competitive_weight)
            net.train(n_iter=n_iter)
            net.network.save(f'gridsearch//network_norm={norm}_comp={competitive_weight}_n_iter={n_iter}')
            acc = net.accuracy(n_iter)
            accs[i, j] = acc
            if acc > max_acc:
                max_acc = acc
                best_norm = norm
                # Was misspelled `best_comp_weght`, so the best value was never kept.
                best_comp_weight = competitive_weight
    # `{n_iter}` was missing its braces, so the literal text "n_iter" ended up in the file name.
    torch.save(accs, f'gridsearch//accs_n_norm={n_norm}_n_comp={n_comp}_n_iter={n_iter}')
    # Return the best weight, not whichever value the loop happened to end on.
    return max_acc, best_norm, best_comp_weight

In [None]:
# Full sweep: 5 norm values x 10 competitive weights, each trained and scored
# with n_iter=60000. NOTE(review): this trains 50 networks - expect a very long run.
gridsearch(n_norm=5, n_comp=10, n_iter=60000)

In [None]:
from sklearn.base import BaseEstimator
from sklearn.model_selection import RandomizedSearchCV
import numpy as np
from LC_SNN import LC_SNN
from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from torchvision import transforms
from tqdm import tqdm
import torch
class Wtf:
    """Placeholder estimator exposing a scikit-learn-style interface.

    It does no learning: `fit` only prints and `score` returns a random number.
    Its purpose is to exercise search utilities that call
    `get_params` / `set_params` / `fit` / `score`.

    Unlike the first draft, the constructor arguments are stored on the
    instance, so `get_params` reports the real values and `set_params`
    updates them in place.
    """

    def __init__(self, norm=1, competitive_weight=1, n_iter=1):
        print('initing')
        self.norm = norm
        self.competitive_weight = competitive_weight
        self.n_iter = n_iter

    def fit(self, X, y, n_iter=None):
        """Pretend to fit; `n_iter` is optional so `fit(X, y)` works. Returns self."""
        print('fitting')
        return self

    def score(self, X, y):
        """Return a random score in [0, 1); the inputs are ignored."""
        print('predicting')
        return np.random.random()

    def get_params(self, **args):
        """Return the current hyperparameters; extra keywords (e.g. `deep`) are ignored."""
        return {'norm': self.norm,
                'competitive_weight': self.competitive_weight,
                'n_iter': self.n_iter
               }

    def set_params(self, norm=None, competitive_weight=None, n_iter=None):
        """Update the given hyperparameters in place and return self.

        Arguments left as None are not changed, so a subset may be passed.
        """
        if norm is not None:
            self.norm = norm
        if competitive_weight is not None:
            self.competitive_weight = competitive_weight
        if n_iter is not None:
            self.n_iter = n_iter
        return self

    def __repr__(self):
        return 'WTF class'

In [None]:
# NOTE(review): the settings and dataset construction below repeat the
# "Loading and encoding MNIST dataset" cell verbatim and re-bind the same
# global names (time_max, dt, intensity, train_dataset).
time_max = 30
dt = 1
intensity = 127.5

train_dataset = MNIST(
    PoissonEncoder(time=time_max, dt=dt),
    None,
    "MNIST",
    download=False,
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
        )
    )

# Scratch check: print the shape of the "label" entry of the first batch only
# (the loop breaks after one iteration). The loader wraps `train_dataset`.
test_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True)
for batch in test_dataloader:
    print(batch['label'].shape)
    break

In [None]:
# Load the tab-separated dump and pull out the two columns of interest.
data = pd.read_csv('encoded_MNIST.csv', sep='\t')
X_numpy = data['encoded_image'].values
y_numpy = data['label'].values

# X: one tensor per encoded image; y: first element of each label, as a string.
X = [torch.tensor(encoded) for encoded in X_numpy]
y = [str(raw_label[0]) for raw_label in y_numpy]

In [None]:
# Scratch cell: display the current value of the global `stride`.
stride

In [None]:
# NOTE: this cell re-binds kernel_size / conv_size / stride from the ints used
# in the "Building network" cell to 2-tuples.
kernel_size = (12, 12)
conv_size = (5, 5)
stride = (4, 4)
# NOTE(review): `shape` supplies the row length of the flattened source grid;
# [625, 625] is kept as written, but the input images elsewhere in this
# notebook are 28x28 - confirm the intended value.
shape = [625, 625]

# locations[k1, k2, c1, c2] is the flat index of kernel element (k1, k2) for the
# window at output position (c1, c2):
#     row = c1 * stride[0] + k1,  col = c2 * stride[1] + k2,
#     location = row * shape[1] + col
# The second axis now uses its own size/stride (index 1) rather than reusing
# index 0; results are unchanged for the square values above. Broadcasting
# replaces the four nested Python loops.
_k1 = torch.arange(kernel_size[0]).view(-1, 1, 1, 1)
_k2 = torch.arange(kernel_size[1]).view(1, -1, 1, 1)
_c1 = torch.arange(conv_size[0]).view(1, 1, -1, 1)
_c2 = torch.arange(conv_size[1]).view(1, 1, 1, -1)
locations = ((_c1 * stride[0] + _k1) * shape[1] + _c2 * stride[1] + _k2).long()

In [None]:
from LC_SNN import LC_SNN
# Construct LC_SNN with load=True, then load the checkpoint named 'network'.
# NOTE(review): a later cell builds LC_SNN() without load=True before calling
# .load('network') - confirm which form is intended.
net = LC_SNN(load=True)
net.load('network')
# Wrap the model's X->Y weights in a tensor for the plotting cells below.
# assumes net.weights_XY is array-like or already a tensor - TODO confirm
weights = torch.tensor(net.weights_XY)

In [None]:
# Copy 12x12 patches out of `weights` into a 25x25 mosaic (300x300 pixels):
# tile (i, j) occupies rows 12*i..12*i+11 and columns 12*j..12*j+11.
# NOTE(review): for i, j in range(25) both (i//5)//5 and (j//5)//5 are always 0,
# so the two offset terms contribute nothing; and the source column index uses
# 28*i, not j, so every tile in row i reads the same patch. This looks
# unintended - verify the indexing against the layout of net.weights_XY.
weights_formatted = torch.zeros(25*12, 25*12)
for i in range(25):
    for j in range(25):
        for k1 in range(12):
            for k2 in range(12):
                weights_formatted[12*i + k1, 12*j + k2] = weights[28*i + k1 + 4*((i//5)//5), 28*i + k2 + (2*((j//5)//5))]

In [None]:
# Show the 300x300 mosaic built in `weights_formatted` with a colour bar.
plt.figure(figsize=(15, 15), dpi=200)
plt.imshow(weights_formatted.numpy(), cmap='YlOrBr')
plt.colorbar()

In [None]:
# Show the raw `weights` tensor as an image with a colour bar, for comparison
# with the mosaic view.
plt.figure(figsize=(15, 15))
plt.imshow(weights.numpy(), cmap='YlOrBr')
plt.colorbar()

In [3]:
# Build an LC_SNN with default arguments and load the checkpoint named 'network'.
est = LC_SNN()
est.load('network')

In [6]:
# Calibrate per-neuron top classes over 10000 iterations; the value returned by
# the call is displayed as the cell output. The recorded run below took
# about 3 h 38 min.
est.calibrate_top_classes(10000)

Calibrating top classes for each neuron...


100%|██████████████████████████████████████████████████████████████████████████| 10000/10000 [3:38:03<00:00, 15.91it/s]


(array([[7, 0, 3, ..., 8, 6, 6],
        [0, 8, 2, ..., 2, 0, 0],
        [8, 5, 8, ..., 5, 2, 2],
        ...,
        [None, 4, 1, ..., 4, 9, 8],
        [1, 1, None, ..., None, None, None],
        [6, None, 7, ..., 7, 1, 7]], dtype=object),
 tensor([[1.8756e-02, 6.5153e-02, 5.9230e-03,  ..., 5.9230e-03, 6.4166e-02,
          3.7512e-02],
         [0.0000e+00, 8.7489e-04, 8.7489e-04,  ..., 1.7498e-03, 0.0000e+00,
          1.7498e-03],
         [3.0738e-03, 8.1967e-03, 4.9180e-02,  ..., 3.5861e-02, 1.3320e-02,
          3.0738e-02],
         ...,
         [6.9444e-03, 3.8690e-02, 8.9286e-03,  ..., 1.0516e-01, 4.9603e-03,
          9.9206e-04],
         [5.9406e-03, 2.1782e-02, 1.9802e-03,  ..., 4.9505e-03, 9.9010e-04,
          1.9802e-03],
         [1.0000e-04, 1.0000e-04, 1.0000e-04,  ..., 1.0000e-04, 1.0000e-04,
          1.0000e-04]]))