In [1]:
import os, json

import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting
from bindsnet.models import DiehlAndCook2015
from bindsnet.network.monitors import Monitor

In [2]:
# Hyper-parameters for this run. Keys read later in the notebook include
# n_neurons, intensity, exc, inh, dt, theta_plus, time, update_interval,
# n_train, n_clamp and n_test.
with open("./config.json") as config_file:
    config = json.load(config_file)

In [3]:
# Select the compute device: the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG this notebook draws from.
# - torch.manual_seed seeds the CPU generator and all CUDA devices. The CPU
#   generator drives DataLoader shuffling, and presumably BindsNET's weight
#   init and Poisson encoding (TODO confirm).
# - NumPy drives the random choice of clamped neurons in the training loop.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Leave one core free for the notebook. os.cpu_count() may return None or 1,
# and torch.set_num_threads rejects 0, so never go below one thread.
torch.set_num_threads(max(1, (os.cpu_count() or 1) - 1))

n_classes = 10
# Side length of the smallest square grid that holds all excitatory neurons.
n_sqrt = int(np.ceil(np.sqrt(config["n_neurons"])))
start_intensity = config["intensity"]
# Number of excitatory neurons reserved per class (used when clamping during training).
per_class = int(config["n_neurons"] / n_classes)

In [4]:
# Build the Diehl & Cook (2015) network: a 784-unit input layer "X" (one unit per
# 28x28 MNIST pixel) feeding an excitatory layer "Ae", which is coupled to an
# inhibitory layer "Ai".
network = DiehlAndCook2015(
    n_inpt=784,
    n_neurons=config["n_neurons"],
    exc=config["exc"],
    inh=config["inh"],
    dt=config["dt"],
    # Learning rates handed to the model, presumably (pre, post) — TODO confirm.
    # The original trailing note "0.711" is kept; its meaning is unclear.
    nu=[1e-10, 1e-3],  # 0.711
    # NOTE(review): 78.4 == 0.1 * 784; presumably the weight-normalisation target
    # for the input connections — TODO confirm.
    norm=78.4,
    theta_plus=config["theta_plus"],
    inpt_shape=(1, 28, 28),
)
# Move the network to the selected device. Hard-coding "cuda" here would break
# the notebook on CPU-only machines even though `device` already handles that.
network.to(device)

DiehlAndCook2015(
  (X): Input()
  (Ae): DiehlAndCookNodes()
  (Ai): LIFNodes()
  (X_to_Ae): Connection(
    (source): Input()
    (target): DiehlAndCookNodes()
  )
  (Ae_to_Ai): Connection(
    (source): DiehlAndCookNodes()
    (target): LIFNodes()
  )
  (Ai_to_Ae): Connection(
    (source): LIFNodes()
    (target): DiehlAndCookNodes()
  )
)

In [5]:
# Voltage and spike recording for the network's layers.
def config_monitor(net, time, device):
    """Attach voltage and spike monitors to ``net``.

    Adds a ``"v"`` (voltage) monitor on the excitatory layer ``"Ae"`` under the
    name ``"exc_voltage"`` and on the inhibitory layer ``"Ai"`` under the name
    ``"inh_voltage"``. Adds one ``"s"`` (spike) monitor per layer under the name
    ``"<layer>_spikes"``.

    Args:
        net: Network to instrument. It must expose ``layers`` (a mapping of
            layer name -> layer) and ``add_monitor(monitor, name=...)``.
            It is modified in place.
        time: Recording length passed to every ``Monitor``.
        device: Device every monitor keeps its recordings on.

    Returns:
        ``(net, exc_voltage_monitor, inh_voltage_monitor, spikes)``, where
        ``spikes`` maps each layer name to its spike monitor.
    """
    # Use the `net` argument, not a global, so the helper works on any network.
    evm = Monitor(net.layers["Ae"], ["v"], time=time, device=device)
    ivm = Monitor(net.layers["Ai"], ["v"], time=time, device=device)
    net.add_monitor(evm, name="exc_voltage")
    net.add_monitor(ivm, name="inh_voltage")

    spikes_dict = {}
    # Iterate the mapping itself (not a set) so monitors are always added in the
    # network's own layer order.
    for layer in net.layers:
        spikes_dict[layer] = Monitor(
            net.layers[layer], state_vars=["s"], time=time, device=device
        )
        net.add_monitor(spikes_dict[layer], name=f"{layer}_spikes")
    return net, evm, ivm, spikes_dict
# Instrument the network; `spikes` maps each layer name to its spike monitor.
network, exc_voltage_monitor, inh_voltage_monitor, spikes = config_monitor(
    net=network,
    time=config["time"],
    device=device,
)




In [6]:
# Load the MNIST training split. Each image is converted to a tensor, scaled by
# config["intensity"] and Poisson-encoded over the simulation window; labels
# are not encoded.
dataset = MNIST(
    PoissonEncoder(time=config["time"], dt=config["dt"]),  # image encoder
    None,  # label encoder
    root=os.path.join(os.pardir, os.pardir, "data", "MNIST"),
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img * config["intensity"]),
    ]),
)

In [7]:
# Shuffled training loader. batch_size=1 because the training loop handles one
# image at a time (it reshapes each input to a single sample and reads label[0]).
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

# Excitatory-layer spikes for the last `update_interval` samples:
# shape (update_interval, time, n_neurons).
spike_record = torch.zeros(config["update_interval"], config["time"], config["n_neurons"], device=device)

# Neuron -> class assignments (-1 = not yet assigned), plus the per-neuron,
# per-class `proportions` and `rates` that are passed to and returned by
# assign_labels. Allocate directly on the device rather than building a
# throw-away CPU tensor just to copy its shape.
assignments = -torch.ones(config["n_neurons"], device=device)
proportions = torch.zeros(config["n_neurons"], n_classes, device=device)
rates = torch.zeros(config["n_neurons"], n_classes, device=device)

# Accuracy estimates (in %) recorded once per update interval.
accuracy = {"all": [], "proportion": []}

# Labels of the samples in the current update interval, used to assign neuron
# labels and to score the predictions.
labels = torch.empty(config["update_interval"], device=device)

In [9]:
# Train the network.
print("Begin training.\n")

pbar = tqdm(total=config["n_train"])
for (i, datum) in enumerate(dataloader):
    # `i` starts at 0, so `>=` processes exactly n_train samples (`>` did n_train + 1).
    if i >= config["n_train"]:
        break

    image = datum["encoded_image"]
    label = datum["label"]

    # Every `update_interval` samples: score the interval just finished, then
    # re-assign neuron labels from its spikes.
    if i % config["update_interval"] == 0 and i > 0:
        # Get network predictions.
        all_activity_pred = all_activity(spike_record, assignments, n_classes)
        proportion_pred = proportion_weighting(spike_record, assignments, proportions, n_classes)

        # Compute network accuracy according to available classification strategies.
        accuracy["all"].append(100 * torch.sum(labels.long() == all_activity_pred).item() / config["update_interval"])
        accuracy["proportion"].append(100 * torch.sum(labels.long() == proportion_pred).item() / config["update_interval"])

        print(f"\nAll activity accuracy: {accuracy['all'][-1]:.2f} (last), {np.mean(accuracy['all']):.2f} (average), {np.max(accuracy['all']):.2f} (best)")
        print(f"Proportion weighting accuracy: {accuracy['proportion'][-1]:.2f} (last), {np.mean(accuracy['proportion']):.2f} (average), {np.max(accuracy['proportion']):.2f} (best)\n")

        # Assign labels to excitatory layer neurons.
        assignments, proportions, rates = assign_labels(spike_record, labels, n_classes, rates)

    # Add the current label to the list of labels for this update_interval
    labels[i % config["update_interval"]] = label[0]

    # Clamp `n_clamp` randomly chosen neurons from the block of `per_class`
    # excitatory neurons reserved for this sample's class.
    choice = np.random.choice(per_class, size=config["n_clamp"], replace=False)
    clamp_idx = per_class * label.long() + torch.as_tensor(choice, dtype=torch.long)
    clamp = {"Ae": clamp_idx.to(device)}

    # NOTE(review): this reshape assumes dt == 1 (time steps == config["time"]);
    # the test cell uses time / dt instead — TODO confirm.
    inputs = {"X": image.to(device).view(config["time"], 1, 1, 28, 28)}
    network.run(inputs=inputs, time=config["time"], clamp=clamp)

    # Get voltage recording.
    exc_voltages = exc_voltage_monitor.get("v")
    inh_voltages = inh_voltage_monitor.get("v")

    # Add to spikes recording.
    spike_record[i % config["update_interval"]] = spikes["Ae"].get("s").view(config["time"], config["n_neurons"])

    network.reset_state_variables()  # Reset state variables.
    pbar.set_description_str("Train progress: ")
    pbar.update()
pbar.close()

print(f"Progress: {config['n_train']} / {config['n_train']} \n")
print("Training complete.\n")

Begin training.



  0%|          | 0/60000 [00:16<?, ?it/s]


NameError: name 'spikes' is not defined

In [None]:
print("Testing....\n")
# Load the MNIST test split, encoded the same way as the training data.
test_dataset = MNIST(
    PoissonEncoder(time=config["time"], dt=config["dt"]),
    None,
    root=os.path.join("..", "..", "data", "MNIST"),
    download=True,
    train=False,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * config["intensity"])]
    ),
)

# Running counts of correctly classified test samples, per classification strategy.
accuracy = {"all": 0, "proportion": 0}

# Number of simulation steps per sample.
n_steps = int(config["time"] / config["dt"])

# Excitatory-layer spikes of the current test sample: shape (1, n_steps, n_neurons).
spike_record = torch.zeros(1, n_steps, config["n_neurons"], device=device)

# Test the network: switch it out of training mode first.
print("\nBegin testing\n")
network.train(mode=False)

n_seen = 0  # test samples actually evaluated
pbar = tqdm(total=config["n_test"])
for step, batch in enumerate(test_dataset):
    # `step` starts at 0, so `>=` evaluates exactly n_test samples (`>` did n_test + 1).
    if step >= config["n_test"]:
        break
    # Get next input sample, on the same device as the network.
    inputs = {"X": batch["encoded_image"].view(n_steps, 1, 1, 28, 28).to(device)}

    # Run the network on the input.
    network.run(inputs=inputs, time=config["time"], input_time_dim=1)

    # Add to spikes recording.
    spike_record[0] = spikes["Ae"].get("s").squeeze()

    # as_tensor accepts both a plain int and an existing tensor label.
    label_tensor = torch.as_tensor(batch["label"], device=device)

    # Get network predictions.
    all_activity_pred = all_activity(spikes=spike_record, assignments=assignments, n_labels=n_classes)
    proportion_pred = proportion_weighting(
        spikes=spike_record,
        assignments=assignments,
        proportions=proportions,
        n_labels=n_classes,
    )

    # Compute network accuracy according to available classification strategies.
    accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred).item())
    accuracy["proportion"] += float(torch.sum(label_tensor.long() == proportion_pred).item())
    n_seen += 1

    network.reset_state_variables()  # Reset state variables.

    pbar.set_description_str(f"Accuracy: {(max(accuracy['all'] ,accuracy['proportion'] ) / n_seen):.3}")
    pbar.update()
pbar.close()

# Normalise by the number of samples actually evaluated: the split may hold
# fewer than n_test images. Guard against an empty run.
n_evaluated = max(n_seen, 1)
print(f"\nAll activity accuracy: {(accuracy['all'] / n_evaluated):.2f}")
print(f"Proportion weighting accuracy: {(accuracy['proportion'] / n_evaluated):.2f} \n")

print("Testing complete.\n")