### Unsupervised Learning with Self-Organizing Spiking Neural Networks
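Paper: [Hazan et al., 2018 — arXiv:1807.09374](https://arxiv.org/pdf/1807.09374.pdf). This notebook trains the model on MNIST with BindsNET.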

In [1]:
import os
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

In [2]:
import bindsnet as bn

In [3]:
# Experiment configuration.
seed = 0                      # seeds torch's RNGs so runs (e.g. Poisson spike encoding) are repeatable
torch.manual_seed(seed)

intensity = 64                # pixel values (0..1 after ToTensor) are multiplied by this before Poisson encoding
n_neurons = 900               # output neurons, laid out on an int(sqrt(n)) x int(sqrt(n)) grid
epochs = 10
time = 250                    # simulation steps per sample (encoder uses dt = 1)
batch_size = 100
# Evaluate / re-assign labels about every 1000 samples and strengthen inhibition about every 500 samples.
# max(1, ...) keeps both intervals >= 1 so the `%` checks in the training loop never divide by zero.
test_interval = max(1, 1000 // batch_size)
update_inhibition_weights = max(1, 500 // batch_size)
# Initial lateral weight = start_inhib + normalized_distance * max_inhib,
# i.e. it ranges from start_inhib (distance 0) to start_inhib + max_inhib (farthest pair).
start_inhib = 10
max_inhib = -40.0

In [4]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
device  # display which device was selected

'cuda'

In [6]:
# Two-layer network: Poisson-encoded MNIST input 'X' feeding a plastic output layer 'Y'.
network = bn.network.Network(batch_size=batch_size)

n_inputs = 28 * 28   # one input neuron per MNIST pixel
init_weight = 0.1    # starting value of every X -> Y synapse

input_layer = bn.network.nodes.Input(n=n_inputs, shape=(1, 28, 28), traces=True, tc_trace=20.0)
output_layer = bn.network.nodes.DiehlAndCookNodes(n_neurons, traces=True, reset=-60.0)

for layer_name, layer in (('X', input_layer), ('Y', output_layer)):
    network.add_layer(layer, name=layer_name)

# Feed-forward weights learned with PostPre STDP, bounded to [0, 1].
# `norm` is tied to the initial weights (presumably BindsNET rescales each output neuron's
# incoming weights to sum to `norm`, i.e. to their initial total).
# NOTE(review): all weights start identical; symmetry between output neurons is broken only by
# the lateral connection and spike noise — confirm a random init is not intended.
input_output_conn = bn.network.topology.Connection(
    source=input_layer,
    target=output_layer,
    w=init_weight * torch.ones(n_inputs, n_neurons),
    update_rule=bn.learning.PostPre,
    nu=[1e-4, 1e-2],
    reduction=torch.sum,
    wmin=0.0,
    wmax=1.0,
    norm=n_inputs * init_weight,
)
network.add_connection(input_output_conn, source='X', target='Y')

In [7]:
# Lateral Y -> Y connection whose weights depend on the distance between neurons on a 2-D grid.
# Neuron k is placed at grid position (k // n_sqrt, k % n_sqrt).
n_sqrt = int(np.sqrt(n_neurons))
grid_idx = torch.arange(n_neurons)
grid_row = torch.div(grid_idx, n_sqrt, rounding_mode='floor')
grid_col = grid_idx % n_sqrt

# Squared Euclidean distance between every pair of grid positions (0 on the diagonal),
# computed by broadcasting instead of an n_neurons^2 Python loop. Both the row AND the
# column offset of the pair contribute.
grid_dist2 = (
    (grid_row[:, None] - grid_row[None, :]) ** 2
    + (grid_col[:, None] - grid_col[None, :]) ** 2
).float()

# Normalize to [0, 1]; the clamp avoids 0/0 -> NaN when every distance is 0 (n_neurons == 1).
# The 4th root makes the weight change fastest for nearby neurons.
w = torch.pow(grid_dist2 / grid_dist2.max().clamp(min=1.0), 0.25)
# Map to the range start_inhib (distance 0) .. start_inhib + max_inhib (farthest pair).
# NOTE(review): the diagonal (self-connection) also becomes start_inhib, and with a positive
# start_inhib the nearest neighbours get positive weights — confirm this is intended.
w = start_inhib + w * max_inhib

recurrent_output_conn = bn.network.topology.Connection(source=output_layer, target=output_layer, w=w)
network.add_connection(recurrent_output_conn, source='Y', target='Y')

# 1 everywhere except the diagonal: used when decrementing the recurrent weights,
# so self-connections are left unchanged.
weights_mask = (1.0 - torch.eye(n_neurons)).to(device)

In [8]:
# Move the network's tensors onto the selected device, then print BindsNET's text summary.
network.to(device)
network_summary = bn.analysis.visualization.summary(network)
print(network_summary)

         NETWORK SUMMARY
         batch size:100
    ··········································
    Layer: 'X' (trainable)
   784 neurons (1, 28, 28)
       ·connected to 'Y' by 705,600 synapses
    ··········································
    Layer: 'Y' (trainable)
   900 neurons [900]
       ·connected to 'Y' by 810,000 synapses
Total neurons: 1,684 (1,684 trainable)
Total synapses weights: 1,515,600 (1,515,600 trainable)


In [9]:
# MNIST training images: ToTensor gives pixels in [0, 1], which are scaled by `intensity`
# and then Poisson-encoded into `time` steps.
dataset = bn.datasets.MNIST(
    bn.encoding.PoissonEncoder(time=time, dt=1),
    None,  # presumably the label encoder; labels are used unencoded
    root=os.path.join('data', 'MNIST'),
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: intensity * x)]),
)

# drop_last: the network and the per-batch training buffers assume exactly `batch_size`
# samples per batch, so a trailing partial batch would not fit.
# pin_memory only helps (and only avoids a warning) when batches are copied to a CUDA device.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    drop_last=True,
    pin_memory=torch.cuda.is_available(),
)

In [10]:
# Record the output layer's spike variable 's' for the whole simulation window.
out_spikes = bn.network.monitors.Monitor(
    network.layers['Y'],
    state_vars=['s'],
    time=time,
)
network.add_monitor(out_spikes, name='out_spikes')

In [11]:
# Rolling buffers covering the last `test_interval` batches, used for evaluation and label assignment.
# Spikes are binary, so store them as bool: 4x smaller than float32 for
# test_interval * batch_size * time * n_neurons elements (225M with 10 * 100 * 250 * 900).
spike_record = torch.zeros(test_interval * batch_size, time, n_neurons, dtype=torch.bool)
# Class indices are integers; keep them integral so comparisons with predictions are exact.
labels = torch.zeros(test_interval * batch_size, dtype=torch.long)
accuracy = .0  # exponentially smoothed evaluation accuracy

# Label-assignment state for the 10 MNIST classes; -1 means "neuron not assigned yet".
assignments = -torch.ones(n_neurons)
proportions = torch.zeros(n_neurons, 10)
rates = torch.zeros(n_neurons, 10)

In [None]:
for epoch in range(epochs):
    for step, batch in enumerate(tqdm(dataloader, desc='MNIST Pass', leave=False)):
        # Move the time axis first, as the network expects
        # (presumably encoded images arrive as (batch, time, 1, 28, 28)).
        # non_blocking pairs with the DataLoader's pinned memory on CUDA.
        inputs = {'X': batch['encoded_image'].permute(1, 0, 2, 3, 4).to(device, non_blocking=True)}

        network.reset_state_variables()
        network.run(inputs=inputs, time=time)

        # Store this batch's labels and output spikes (as batch, time, neuron) in the rolling buffers.
        s = (step % test_interval) * batch_size
        labels[s:s + batch_size] = batch['label']
        spike_record[s:s + batch_size] = out_spikes.get('s').permute(1, 0, 2)

        # Periodically strengthen lateral inhibition between distinct neurons (diagonal masked out).
        # NOTE(review): every 10th update subtracts 50 instead of 0.5, and nothing clamps the
        # weights at `max_inhib` — confirm this schedule is intended.
        if step > 0 and step % update_inhibition_weights == 0:
            big_step = step % (update_inhibition_weights * 10) == 0
            network.Y_to_Y.w -= weights_mask * (50 if big_step else 0.5)

        if step % test_interval == test_interval - 1:
            # Predict with the assignments learned from the PREVIOUS window, then re-assign below.
            y_pred = bn.evaluation.proportion_weighting(
                spikes=spike_record,
                assignments=assignments,
                proportions=proportions,
                n_labels=10,
            )
            window_accuracy = (labels == y_pred).float().mean().item()
            accuracy = 0.8 * accuracy + 0.2 * window_accuracy
            print(f"\n epoch = {epoch}, smoothed accuracy = {100. * accuracy:.2f}")

            # `labels` is already a tensor; pass it directly instead of copy-constructing it.
            assignments, proportions, rates = bn.evaluation.assign_labels(
                spikes=spike_record,
                labels=labels,
                n_labels=10,
                rates=rates,
            )

In [15]:
# Persist the trained network's state_dict (parameters and persistent buffers).
# makedirs(exist_ok=True) replaces the racy exists()/mkdir() pair and also creates missing parents.
model_dir = "./models"
os.makedirs(model_dir, exist_ok=True)
torch.save(network.state_dict(), os.path.join(model_dir, "test.pt"))