In [1]:
from pathlib import Path
from time import time as t

import torch
from torchvision import transforms
from tqdm import tqdm

from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.learning import PostPre
from bindsnet.network import Network
from bindsnet.network.monitors import Monitor
from bindsnet.network.nodes import DiehlAndCookNodes, Input
from bindsnet.network.topology import Connection, Conv2dConnection

In [2]:
# --- Run size ------------------------------------------------------------------
n_epochs = 1
n_train = 60_000
n_test = 10_000
batch_size = 1

# --- Convolutional layer geometry ----------------------------------------------
kernel_size = 16
stride = 4
padding = 0
n_filters = 25

# --- Simulation / input encoding -----------------------------------------------
time = 50  # simulation length passed to network.run and the Poisson encoder
dt = 1.0  # time step passed to the Poisson encoder
intensity = 128.0  # pixel-value multiplier applied in the dataset transform

# --- Logging and mode flags ----------------------------------------------------
progress_interval = 10
update_interval = 250
train = True
plot = True

In [3]:
# Pick the compute device; fall back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.manual_seed seeds the default CPU generator *and* every CUDA device.
# Seeding only CUDA (torch.cuda.manual_seed_all) left the CPU generator
# unseeded. DataLoader shuffling draws from that generator, and presumably so
# does the CPU-side Poisson encoding of the dataset. Runs were therefore not
# reproducible.
torch.manual_seed(0)

# When not training, set update_interval to the full test-set size.
if not train:
    update_interval = n_test

# Side length of each convolutional feature map for a 28x28 input:
# floor((28 - kernel_size + 2 * padding) / stride) + 1.
# Floor division (not int() truncation) keeps an oversized kernel from silently
# producing conv_size == 1; such configurations are rejected explicitly.
conv_size = (28 - kernel_size + 2 * padding) // stride + 1
if conv_size < 1:
    raise ValueError(
        f"kernel_size={kernel_size}, stride={stride}, padding={padding} "
        "produce an empty feature map for 28x28 inputs"
    )
# Conv-layer neurons per MNIST class (10 classes).
per_class = int((n_filters * conv_size * conv_size) / 10)

In [4]:
# Build network.
network = Network()
# Input layer "X": a single 28x28 image channel.
input_layer = Input(n=1 * 28 * 28, shape=(1, 28, 28), traces=True)

# Convolutional layer "Y": n_filters feature maps of conv_size x conv_size.
conv_layer = DiehlAndCookNodes(
    n=n_filters * conv_size * conv_size,
    shape=(n_filters, conv_size, conv_size),
    traces=True,
)

# Plastic X -> Y convolution trained with the PostPre learning rule.
# `padding` must be forwarded: conv_size was computed with it, so omitting it
# makes the connection's output shape disagree with conv_layer.shape whenever
# padding != 0.
conv_conn = Conv2dConnection(
    input_layer,
    conv_layer,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    update_rule=PostPre,
    norm=0.4 * kernel_size**2,
    nu=[1e-4, 1e-2],
    wmax=1.0,
)

# Lateral inhibition: each neuron connects with weight -100 to the neurons of
# every *other* filter at the same spatial position (i, j). There are no
# self-connections and no connections between different positions.
w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size)
for fltr1 in range(n_filters):
    for fltr2 in range(n_filters):
        if fltr1 != fltr2:
            for i in range(conv_size):
                for j in range(conv_size):
                    w[fltr1, i, j, fltr2, i, j] = -100.0

# Flatten to a square (n_neurons x n_neurons) matrix for the dense Y -> Y Connection.
w = w.view(n_filters * conv_size * conv_size, n_filters * conv_size * conv_size)
recurrent_conn = Connection(conv_layer, conv_layer, w=w)

network.add_layer(input_layer, name="X")
network.add_layer(conv_layer, name="Y")
network.add_connection(conv_conn, source="X", target="Y")
network.add_connection(recurrent_conn, source="Y", target="Y")

# Membrane-voltage recording for the convolutional layer "Y" (the only
# non-input layer in this network).
voltage_monitor = Monitor(network.layers["Y"], ["v"], time=time)
network.add_monitor(voltage_monitor, name="output_voltage")
# Use the device selected earlier instead of hard-coding "cuda", so the notebook
# also runs on CPU-only machines. The returned module is the cell's displayed output.
network.to(device)

Network(
  (X): Input()
  (Y): DiehlAndCookNodes()
  (X_to_Y): Conv2dConnection(
    (source): Input()
    (target): DiehlAndCookNodes()
  )
  (Y_to_Y): Connection(
    (source): DiehlAndCookNodes()
    (target): DiehlAndCookNodes()
  )
)

In [5]:
# Load the MNIST training split, downloading it if missing.
# Each image goes through ToTensor and is multiplied by `intensity`;
# PoissonEncoder(time, dt) is the image encoder, and labels are not encoded.
# NOTE: the lambda cannot be pickled, so with spawn-based worker processes
# (e.g. on Windows) this dataset only works with DataLoader num_workers=0.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img * intensity),
    ]
)

train_dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),  # image encoder
    None,  # label encoder: none
    "../../data/MNIST",
    download=True,
    train=True,
    transform=mnist_transform,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/MNIST\TorchvisionDatasetWrapper\raw\train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../../data/MNIST\TorchvisionDatasetWrapper\raw\train-images-idx3-ubyte.gz to ../../data/MNIST\TorchvisionDatasetWrapper\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/MNIST\TorchvisionDatasetWrapper\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../../data/MNIST\TorchvisionDatasetWrapper\raw\train-labels-idx1-ubyte.gz to ../../data/MNIST\TorchvisionDatasetWrapper\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/MNIST\TorchvisionDatasetWrapper\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../../data/MNIST\TorchvisionDatasetWrapper\raw\t10k-images-idx3-ubyte.gz to ../../data/MNIST\TorchvisionDatasetWrapper\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/MNIST\TorchvisionDatasetWrapper\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../../data/MNIST\TorchvisionDatasetWrapper\raw\t10k-labels-idx1-ubyte.gz to ../../data/MNIST\TorchvisionDatasetWrapper\raw



In [6]:
# Spike monitors for every layer, and voltage monitors for every layer except
# the input layer "X".
# Iterate the layer dict itself (insertion order: "X", then "Y") instead of a
# set(): a set's iteration order for str keys depends on hash randomization, so
# monitor registration order changed from one kernel session to the next.
spikes = {}
for layer in network.layers:
    spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
    network.add_monitor(spikes[layer], name="%s_spikes" % layer)

# NOTE(review): for "Y" this duplicates the "output_voltage" monitor registered
# when the network was built, doubling voltage-recording memory. Consider
# keeping only one of them.
voltages = {}
for layer in network.layers:
    if layer == "X":
        continue
    voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time)
    network.add_monitor(voltages[layer], name="%s_voltages" % layer)

In [8]:
# Train the network.
print("Begin training.\n")
start = t()

# n_train counts samples; convert it to a number of batches (ceiling division).
max_steps = -(-n_train // batch_size)
# Pinned host memory only helps for host -> GPU copies; on CPU it just warns.
use_pinned_memory = device.type == "cuda"

for epoch in range(n_epochs):
    if epoch % progress_interval == 0:
        print(f"Progress: {epoch} / {n_epochs} ({(t() - start):.4f} seconds)")
        start = t()

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=use_pinned_memory,
    )

    n_steps = min(max_steps, len(train_dataloader))
    for step, batch in enumerate(tqdm(train_dataloader, total=n_steps)):
        # Stop after exactly n_steps batches (the previous `step > n_train`
        # check ran one batch too many).
        if step >= n_steps:
            break

        # The DataLoader stacks samples along dim 0. Assuming each encoded
        # sample is time-major (time, 1, 28, 28), the batch is
        # (batch, time, 1, 28, 28), and the network wants time first.
        # transpose (not .view) keeps each sample's spike train intact when
        # batch_size > 1 and also handles a smaller final batch. For
        # batch_size == 1 it yields the same tensor the old .view did.
        encoded = batch["encoded_image"].transpose(0, 1).contiguous()
        inputs = {"X": encoded.to(device)}
        label = batch["label"]

        # Run the network on the input.
        # NOTE(review): input_time_dim=1 is kept from the original call; the
        # tensor passed here is time-major, so confirm how the installed bindsnet
        # version interprets this keyword.
        network.run(inputs=inputs, time=time, input_time_dim=1)

        network.reset_state_variables()  # Reset state variables.

print(f"Progress: {n_epochs} / {n_epochs} ({(t() - start):.4f} seconds)\n")
print("Training complete.\n")

Begin training.

Progress: 0 / 1 (0.0000 seconds)


  0%|          | 161/60000 [00:38<4:01:06,  4.14it/s]


KeyboardInterrupt: 

In [None]:
# Persist the network's weights. Create the target directory first:
# torch.save does not create missing parent directories, so saving into a
# fresh checkout without ./models would fail.
model_path = Path("./models/SNN_conv_test.pth")
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(network.state_dict(), model_path)