# Imports

In [None]:
import datetime
import os
from time import time as t

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.network import Network, load
from bindsnet.learning import PostPre, WeightDependentPostPre
from bindsnet.network.monitors import Monitor, NetworkMonitor
from bindsnet.network.nodes import AdaptiveLIFNodes, Input
from bindsnet.network.topology import LocalConnection, Connection
from bindsnet.analysis.plotting import (
    plot_input,
    plot_spikes,
    plot_conv2d_weights,
    plot_voltages,
)

from IPython.display import clear_output

from LC_SNN import LC_SNN

%matplotlib inline

# Loading and encoding MNIST dataset

In [None]:
time_max = 30  # simulation length passed to network.run for every sample
dt = 1
intensity = 127.5  # ToTensor() output is multiplied by this before Poisson encoding

# No CenterCrop: everything downstream assumes full 28x28 MNIST digits. That includes the
# (1, 28, 28) Input layer, the conv_size formula and the view(28, 28) reshapes in the
# testing cell. A 20x20 crop would not match those shapes.
train_dataset = MNIST(
    PoissonEncoder(time=time_max, dt=dt),
    None,
    "MNIST",
    download=False,
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
    ),
)

# Building network

In [None]:
# Hyperparameters
n_filters = 25
kernel_size = 12
stride = 4
padding = 0  # NOTE(review): only used in conv_size; it is not passed to LocalConnection
# Side length of the locally connected output map for a 28x28 input.
conv_size = (28 - kernel_size + 2 * padding) // stride + 1
per_class = int((n_filters * conv_size * conv_size) / 10)
tc_trace = 20.  # grid search check
tc_decay = 20.
thresh = -52
refrac = 5

wmin = 0
wmax = 1

# Network
network = Network(learning=True)

input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

# The keyword names follow the bindsnet AdaptiveLIFNodes signature (`thresh`, `tc_trace`,
# `refrac`). The previous misspellings (`thres`, `trace_tc`) were not applied to the layer.
output_layer = AdaptiveLIFNodes(
    n=n_filters * conv_size * conv_size,
    shape=(n_filters, conv_size, conv_size),
    traces=True,
    thresh=thresh,
    tc_trace=tc_trace,
    refrac=refrac,
    tc_decay=tc_decay,
    theta_plus=0.05,
    tc_theta_decay=1e6,
)


connection_XY = LocalConnection(
    input_layer,
    output_layer,
    n_filters=n_filters,
    kernel_size=kernel_size,
    stride=stride,
    update_rule=PostPre,
    norm=1/2, #1/(kernel_size ** 2),#0.4 * kernel_size ** 2,  # norm constant - check
    nu=[1e-4, 1e-2],
    wmin=wmin,
    wmax=wmax)

# Competitive connections: w[f1, i, j, f2, k, l] = -100 when f1 != f2 and (i, j) == (k, l).
# Each neuron inhibits the neurons of every other filter at the same location.
other_filter = 1 - torch.eye(n_filters)  # 1 where f1 != f2
same_location = torch.eye(conv_size * conv_size).view(
    conv_size, conv_size, conv_size, conv_size)  # 1 where (i, j) == (k, l)
w = -100.0 * other_filter[:, None, None, :, None, None] * same_location[None, :, :, None, :, :]
# Flatten to a (target.n, target.n) matrix. bindsnet's Connection multiplies flattened spikes
# by `w` (see bindsnet's conv_mnist example, which flattens the same tensor).
n_neurons = n_filters * conv_size * conv_size
w = w.view(n_neurons, n_neurons)

connection_YY = Connection(output_layer, output_layer, w=w)

network.add_layer(input_layer, name='X')
network.add_layer(output_layer, name='Y')

network.add_connection(connection_XY, source='X', target='Y')
network.add_connection(connection_YY, source='Y', target='Y')

# NetworkMonitor picks up the network's layer and connection names when it is constructed,
# so create it only after they have been added. `time` bounds the recording to one sample
# instead of letting it grow for the whole run.
GlobalMonitor = NetworkMonitor(network, state_vars=('v', 's', 'w'), time=time_max)
network.add_monitor(GlobalMonitor, name='Network')

# Iterate the layer dict directly (insertion order X, Y) instead of a set, so monitor
# registration order is deterministic.
spikes = {}
for layer in network.layers:
    spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time_max)
    network.add_monitor(spikes[layer], name="%s_spikes" % layer)

voltages = {}
for layer in network.layers:
    if layer == "X":
        continue  # the Input layer has no membrane voltage to record
    voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time_max)
    network.add_monitor(voltages[layer], name="%s_voltages" % layer)

# Training

In [None]:
visualize = False  # NOTE(review): not read in this cell
n_train = 1  # number of passes over the training set
for epoch in range(n_train):
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True)

    for batch in tqdm(train_dataloader):
        inpts = {"X": batch["encoded_image"].transpose(0, 1)}

        network.run(inpts=inpts, time=time_max, input_time_dim=1)
        # Reset state variables after every sample, not once per epoch.
        # Otherwise activity left over from one image carries into the next presentation.
        network.reset_()

# Saving network

In [None]:
# Timestamped file name such as network_2020-01-31_12-34-56 (no ':' so it is Windows-safe).
# strftime always includes the seconds. The old str(...)[:-7] slice cut them off whenever
# microsecond == 0, because str() then omits the fraction.
network.save(datetime.datetime.today().strftime('network_%Y-%m-%d_%H-%M-%S'))

# Freezing network (disabling learning)

In [None]:
# Turn learning off before evaluation; the trailing `;` suppresses any displayed return value.
learning_mode = False
network.train(learning_mode);

# Loading network

In [None]:
# NOTE(review): 'default' is not the timestamped name written by the "Saving network" cell.
# Point this at the file you actually want to evaluate.
saved_network_path = 'default'
network = load(saved_network_path, learning=False)

# Testing network

In [None]:
# NOTE(review): this iterates over `train_dataset`. Build an MNIST(..., train=False)
# dataset here if held-out testing is intended.
test_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True)

n_output = n_filters * conv_size * conv_size  # number of neurons in layer Y

for batch in tqdm(test_dataloader):
    clear_output(wait=True)

    # Processing
    inpts = {"X": batch["encoded_image"].transpose(0, 1)}
    label = batch["label"]
    network.run(inpts=inpts, time=time_max, input_time_dim=1)

    # Visualization.
    # Read weights and monitors through the *current* `network` object. Once the
    # "Loading network" cell rebinds `network`, the module-level `connection_XY`,
    # `connection_YY`, `spikes` and `voltages` still point at the previous network.
    image = batch["image"].view(28, 28)
    inpt = inpts["X"].view(time_max, 784).sum(0).view(28, 28)  # input spikes summed over time
    weights_XY = network.connections[("X", "Y")].w.detach().cpu()
    weights_YY = network.connections[("Y", "Y")].w.detach().cpu()
    monitors = network.monitors

    _spikes = {
        "X": monitors["X_spikes"].get("s").view(time_max, -1),
        "Y": monitors["Y_spikes"].get("s").view(time_max, -1),
    }
    _voltages = {"Y": monitors["Y_voltages"].get("v").view(time_max, -1)}

    plot_input(image, inpt, label=label)
    plot_spikes(_spikes)

    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 10))
    # Tile the XY weights into a single image: tile (r, c) is the 28x28 weight map of output
    # neuron r * conv_size**2 + c. Each row is one filter, assuming neurons are ordered
    # filter-major like the (n_filters, conv_size, conv_size) layer shape.
    weights_to_display = (
        weights_XY.reshape(28, 28, n_filters, conv_size * conv_size)
        .permute(2, 0, 3, 1)
        .reshape(n_filters * 28, conv_size * conv_size * 28)
    )
    im1 = ax1.imshow(weights_to_display.numpy())
    im2 = ax2.imshow(weights_YY.reshape(n_output, n_output).numpy())
    f.colorbar(im1, ax=ax1)
    f.colorbar(im2, ax=ax2)
    ax1.set_title('XY weights')
    ax2.set_title('YY weights')
    plot_voltages(_voltages)

    # Render every open figure for this sample now, rather than at the end of the cell.
    # Then close them all so figures do not pile up across iterations.
    plt.show()
    plt.close("all")

    network.reset_()  # Reset state variables between samples
    

In [None]:
# Enlarged view of the tiled XY weights produced by the testing loop.
fig, ax = plt.subplots(figsize=(25, 25))
ax.imshow(weights_to_display.numpy())

In [None]:
from LC_SNN import LC_SNN
def grid_search(norm_min=0.01, norm_max=1, n=10):
    """Train one LC_SNN per `norm` value and save each network plus an image of its XY weights.

    Args:
        norm_min: smallest norm constant to try.
        norm_max: largest norm constant to try.
        n: number of evenly spaced norm values in [norm_min, norm_max] (inclusive).

    Side effects:
        Writes ``LC_SNN_norm=<norm>`` and ``weights_XY_norm=<norm>.png`` into the
        ``gridsearch`` directory, creating the directory if needed. Prints each network.
    """
    os.makedirs('gridsearch', exist_ok=True)
    for norm in np.linspace(norm_min, norm_max, n):  # honour `n` (was hard-coded to 10)
        net = LC_SNN(norm=norm)
        print(net)
        network_to_save = net.train(n_iter=1)
        network_to_save.save(f'gridsearch//LC_SNN_norm={norm}')  # Saving network for later
        f = plt.figure(figsize=(20, 20), dpi=500)
        plt.imshow(net.weights_XY)
        f.savefig(f'gridsearch//weights_XY_norm={norm}.png')  # Saving weights_XY
        # Each figure is 10000x10000 px; free it before the next iteration.
        plt.close(f)

In [None]:
# Run the grid search with its default range, spelled out for readability.
grid_search(norm_min=0.01, norm_max=1, n=10)

In [None]:
# Scratch: inspect 10 evenly spaced values between 0.01 and 100.
np.linspace(start=0.01, stop=100, num=10)

In [None]:
from LC_SNN import LC_SNN
# Build an LC_SNN configured with n_iter=10000 (keyword handled by the local LC_SNN module).
net = LC_SNN(n_iter=10_000)

In [None]:

import plotly.graph_objs as go

In [1]:
from LC_SNN import LC_SNN, load_LC_SNN
# Build an LC_SNN with the constructor's default arguments (defined in the local LC_SNN module).
net = LC_SNN()

In [49]:
# Train for 1000 iterations with plot=True and debug=True; the trailing `;` hides the return value.
# The saved output below shows this run was stopped manually (KeyboardInterrupt).
net.train(n_iter=1000, plot=True, debug=True);

  2%|█▏                                                                              | 15/1000 [00:05<03:43,  4.40it/s]


KeyboardInterrupt: 

In [31]:
import torch
# Sum Y-layer spikes over dim 0 (presumably time — TODO confirm the layout in LC_SNN).
sum_output = torch.sum(net._spikes['Y'], dim=0)

In [42]:
# Scratch: project the spike-count vector through a random 11-row matrix.
x = torch.rand(11, sum_output.shape[0])
res = x @ sum_output.to(x)  # cast counts to x's dtype/device before the matmul

In [47]:
# Display the per-neuron spike totals computed above.
sum_output

tensor([0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        4, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4,
        0, 0, 0, 0, 0, 0, 0, 0, 0])

In [None]:
# Number of non-zero entries (recorded spikes) in the Y-layer spike tensor.
len(net._spikes['Y'].nonzero())

In [None]:
# Shuffled loader over the LC_SNN instance's training dataset, one sample per batch.
train_dataloader = torch.utils.data.DataLoader(
    dataset=net.train_dataset,
    shuffle=True,
    batch_size=1,
)
        

In [None]:
# MNIST samples are dicts keyed "image", "label", "encoded_image", ... The cells above read
# batch["label"]; there is no "labels" key. This assumes net.train_dataset is the same
# bindsnet MNIST type — confirm in LC_SNN.
train_dataloader.dataset[0]['label']

In [None]:
# Indices of every recorded Y-layer spike in the 'Y_spikes' monitor.
torch.nonzero(net.network.monitors['Y_spikes'].get('s'))

In [None]:
# Load a previously saved LC_SNN from 'networks//rofl' via the LC_SNN module's loader.
educated_rofl = load_LC_SNN('networks//rofl')

In [None]:
# Render the loaded network's visualization. visualize() presumably returns a figure-like
# object with .show() — TODO confirm in LC_SNN.
educated_rofl.visualize().show()