In [None]:
!pip install git+https://github.com/BindsNET/bindsnet.git

In [1]:
import os
from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from SNN_supervised import create_network, create_monitorings, train_network, save_network, test_network
from torchvision import transforms
import torch

In [2]:
# Experiment hyperparameters, consumed by the SNN_supervised helpers and the cells below.
config = {
    "n_neurons": 100,      # passed to create_network
    "n_train": 50000,      # train/val lengths for random_split; must sum to len(dataset)
    "n_test": 10000,       # (60000 for the MNIST training split)
    "n_eval": 500,         # read inside SNN_supervised — semantics defined there
    "exc": 22.5,           # presumably excitatory connection strength — confirm in create_network
    "inh": 120,            # presumably inhibitory connection strength — confirm in create_network
    "theta_plus": 0.05,    # presumably adaptive-threshold increment — confirm in create_network
    "time": 1,             # simulation length for PoissonEncoder / monitors; with dt=1.0 this
                           # is a single timestep. NOTE(review): unusually short — confirm intended.
    "dt": 1.0,
    "intensity": 32,       # multiplier applied to pixel values before Poisson encoding
    "update_interval": 10,
    "eval_interval": 1000,
    "dataset": "mnist",
    "n_classes": 10,
    "lr": [1e-10, 1e-3],   # passed to create_network; presumably (pre, post) learning rates
    "seed": 42,            # global RNG seed so init, Poisson spikes and data split are reproducible
}

# Seed torch's global RNG (all devices) before anything stochastic runs.
torch.manual_seed(config["seed"])

In [3]:
# Build the network from the config hyperparameters (argument order is create_network's).
network = create_network(
    config["n_neurons"],
    config["exc"],
    config["inh"],
    config["dt"],
    config["theta_plus"],
    config["lr"],
)

# Warm-start from a previously saved state dict. map_location="cpu" lets a checkpoint
# saved on a GPU machine load on a CPU-only one; load_state_dict then copies the values
# onto whatever device the network's own tensors live on.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint_path = "models/standard_5000.pth"
network.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

# Attach monitors; the returned `spikes` is handed to train_network / test_network later.
network, spikes = create_monitorings(network, config["time"])

In [4]:
# Preprocessing: ToTensor turns each image into a float tensor (in [0, 1] for 8-bit
# images), which is then scaled by config["intensity"] before Poisson spike encoding.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img * config["intensity"]),
    ]
)

# Dataset location, relative to the notebook's working directory.
mnist_root = os.path.join("..", "..", "data", "MNIST")

# Images are Poisson-encoded over config["time"] with step config["dt"]; labels are
# left unencoded (label encoder = None). No `train=` argument is passed, so the
# wrapper's default split is used (presumably the training split — confirm).
dataset = MNIST(
    PoissonEncoder(time=config["time"], dt=config["dt"]),
    None,
    root=mnist_root,
    download=True,
    transform=mnist_transform,
)

In [5]:
# Split the dataset into train / validation subsets of sizes n_train / n_test
# (integer lengths must sum to len(dataset), otherwise random_split raises).
# A dedicated seeded generator makes the partition identical across kernel restarts;
# falls back to 42 if config carries no "seed".
# NOTE(review): the network was warm-started from models/standard_5000.pth; if that
# checkpoint was trained on a different split of the same data, val_set may overlap
# its training samples — confirm.
split_generator = torch.Generator().manual_seed(config.get("seed", 42))
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [config["n_train"], config["n_test"]],
    generator=split_generator,
)

In [None]:
# Continue training the warm-started network on train_set. val_set, the monitors and
# config are passed in as well (presumably for periodic evaluation every
# config["eval_interval"] samples — confirm in SNN_supervised).
network = train_network(network, train_set, val_set, spikes, config)

In [None]:
# Persist the trained network.
# NOTE(review): "test.pth" is a relative path (presumably the current working directory,
# not models/ where the starting checkpoint came from) and a non-descriptive name for a
# trained model — consider something like models/<run-name>.pth.
save_network(network, "test.pth")

In [None]:
# Evaluate the trained network on val_set.
# NOTE(review): val_set is the random_split slice of the loaded MNIST dataset (presumably
# the training split), not the official MNIST test set. Also note the argument order
# (config, spikes, data) differs from train_network's (data, spikes, config) — confirm
# against SNN_supervised.
test_network(network, config, spikes, val_set)