<a href="https://colab.research.google.com/github/Singular-Brain/bindsnet/blob/master/lc_net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Notebook setups

In [11]:
!pip install -q git+https://github.com/Singular-Brain/bindsnet

In [12]:
!wget https://data.deepai.org/mnist.zip
!mkdir -p ../data/MNIST/TorchvisionDatasetWrapper/raw
!unzip mnist.zip -d ../data/MNIST/TorchvisionDatasetWrapper/raw/

--2021-07-25 07:39:33--  https://data.deepai.org/mnist.zip
Resolving data.deepai.org (data.deepai.org)... 138.201.36.183
Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11597176 (11M) [application/x-zip-compressed]
Saving to: ‘mnist.zip.1’


2021-07-25 07:39:35 (8.61 MB/s) - ‘mnist.zip.1’ saved [11597176/11597176]

Archive:  mnist.zip
replace ../data/MNIST/TorchvisionDatasetWrapper/raw/train-labels-idx1-ubyte.gz? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/train-labels-idx1-ubyte.gz  
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/train-images-idx3-ubyte.gz  
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/t10k-images-idx3-ubyte.gz  
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/t10k-labels-idx1-ubyte.gz  


In [113]:
from bindsnet.network.nodes import Nodes
import os
import torch
import torchvision
import numpy as np
import argparse
import matplotlib.pyplot as plt

from torchvision import transforms
from tqdm.notebook import tqdm

from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes, AdaptiveLIFNodes
from bindsnet.network.topology import LocalConnection, Connection
from bindsnet.network.monitors import Monitor
from bindsnet.learning import PostPre, MSTDP, MSTDPET 
from bindsnet.learning.reward import DynamicDopamineInjection
from bindsnet.utils import get_square_assignments, get_square_weights
from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels
from bindsnet.analysis.plotting import (
    plot_input,
    plot_assignments,
    plot_performance,
    plot_weights,
    plot_spikes,
    plot_voltages,
)

## Set up hyper-parameters

In [114]:
n_neurons = 100
n_train = 500
n_test = 300
n_val = 100
theta_plus = 0.05
time = 250
dt = 1
intensity = 32
train = True
gpu = True
device_id = 0
n_classes = 10
neuron_per_class = 10
seed = 2045 # The Singularity is Near!

reward_kwargs = {
    'dopaminergic_layer': 'output', 
    'n_labels': n_classes,
    'neuron_per_class': neuron_per_class,
    'dopamine_per_spike': 0.1, 
    'tc_reward': 20,
    'dopamine_base': 0.02,
}

## Sets up Gpu use


In [115]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if gpu and torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
else:
    torch.manual_seed(seed)
    device = "cpu"
    if gpu:
        gpu = False

torch.set_num_threads(os.cpu_count() - 1)
print("Running on Device = ", device)

if not train:
    update_interval = n_test

Running on Device =  cuda


# Design network

In [116]:
C = 25
K = 12
S = 4
crop_size = 20
theta_plus = 0.05   ## Adaptive LIF
inh_factor = 25
compute_size = lambda inp_size, k, s: int((inp_size-k)/s) + 1

### Reward function
reward_fn = DynamicDopamineInjection
network = Network(dt = 1, reward_fn = reward_fn)

### nodes
inp = Input(shape= [1,20,20], traces=True)

main = AdaptiveLIFNodes(shape= [C, compute_size(20, K, S), compute_size(20, K, S)], traces=True, tc_trace=20, theta_plus=theta_plus)
#main = LIFNodes(shape= [C, compute_size(20, K, S), compute_size(20, K, S)], traces=True)

# TODO: Diehl & Cook 2015 (v2) 
out = AdaptiveLIFNodes(n= 100, traces=True, tc_trace=20, theta_plus=theta_plus)

### connections 
LC = LocalConnection(inp, main, K, S, C, nu = [1e-2, 1e-4], update_rule = MSTDPET)
# LC.w *= np.sqrt(2/(inp.n))
main_out = Connection(main, out, nu = [1e-2, 1e-4], update_rule = MSTDPET)



w = -inh_factor * (
            torch.ones(self.n_neurons, self.n_neurons)
            - torch.diag(torch.ones(self.n_neurons))
        )
        recurrent_connection = Connection(
            source=self.layers["Y"],
            target=self.layers["Y"],
            w=w,
            wmin=-self.inh,
            wmax=0,
        )
self.add_connection(recurrent_connection, source="Y", target="Y")


out_inhibition = Connection(out, out,)
# main_out.w *= np.sqrt(2/(main.n))


network.add_layer(main, "main")
network.add_layer(inp, "input")
network.add_layer(out, "output")
network.add_connection(LC, "input", "main")
network.add_connection(main_out, "main", "output")
network.add_connection(out_inhibition, "output", "output")

# Directs network to GPU
if gpu:
    network.to("cuda")

# Voltage recording for excitatory and inhibitory layers.
main_monitor = Monitor(network.layers["main"], ["v"], time=time, device=device)
output_monitor = Monitor(network.layers["output"], ["v"], time=time, device=device)
network.add_monitor(main_monitor, name="main")
network.add_monitor(output_monitor, name="output")


#Load Dataset

In [117]:
class ClassSelector(torch.utils.data.sampler.Sampler):
    """Select target classes from the dataset"""
    def __init__(self, target_classes, data_source):
        self.mask = torch.tensor([1 if data_source[i][1] in target_classes else 0 for i in range(len(data_source))])
        self.data_source = data_source

    def __iter__(self):
        return iter([i.item() for i in torch.nonzero(self.mask)])

    def __len__(self):
        return len(self.data_source)

In [118]:
# Load MNIST data.
dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    root=os.path.join("..", "..", "data", "MNIST"),
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
        transforms.Lambda(lambda x: x * intensity),
        transforms.CenterCrop(crop_size)]
    ),
)

# Create a dataloader to iterate and batch data
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True,
                                        #  sampler = ClassSelector(
                                        #         target_classes = (2,4),
                                        #         data_source = dataset)
                                         )

# Load test dataset
test_dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    root=os.path.join("..", "..", "data", "MNIST"),
    download=True,
    train=False,
    transform=transforms.Compose(
        [transforms.ToTensor(),
        transforms.Lambda(lambda x: x * intensity),
        transforms.CenterCrop(crop_size)]
    ),
)

val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True,
                                        #  sampler = ClassSelector(
                                        #         target_classes = (2,4),
                                        #         data_source = dataset)
                                         )

# Evaluation Utils

In [119]:
def evaluate(val_loader):
    correct = 0
    false = 0
    spikes_val = {}
    for layer in set(network.layers):
        spikes_val[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
        network.add_monitor(spikes_val[layer], name="%s_spikes" % layer)

    network.train(mode=False)

    for (i, datum) in enumerate(val_loader):
        if i > n_val:
            break

        image = datum["encoded_image"]
        label = datum["label"]

        # Run the network on the input.
        if gpu:
            inputs = {"input": image.cuda().view(time, 1, 1, 20, 20)}
        else:
            inputs = {"input": image.view(time, 1, 1, 20, 20)}
        network.run(inputs=inputs, time=time, **reward_kwargs, labels = label)

        # Add to spikes recording.
        output_spikes = spikes_val["output"].get("s").view(time, n_classes, neuron_per_class).sum(0)
        predicted_label = torch.argmax(output_spikes.sum(1))

        if predicted_label == label:
            correct+=1
        else:
            false+=1
    val_acc = 100 * correct/(correct + false)
    return val_acc

# Train

In [120]:
# Train the network.
print("Begin training.\n")

correct = 0
false = 0
spikes = {}
for layer in set(network.layers):
    spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
    network.add_monitor(spikes[layer], name="%s_spikes" % layer)

val_acc = 0.0

pbar = tqdm(total=n_train)
for (i, datum) in enumerate(dataloader):
    if i > n_train:
        break

    image = datum["encoded_image"]
    label = datum["label"]

    # Run the network on the input.
    if gpu:
        inputs = {"input": image.cuda().view(time, 1, 1, 20, 20)}
    else:
        inputs = {"input": image.view(time, 1, 1, 20, 20)}
    network.run(inputs=inputs, time=time, **reward_kwargs, labels = label)

    # Get voltage recording.
    # main_voltage = main_monitor.get("v")
    # out_voltage = output_monitor.get("v")

    # Add to spikes recording.
    output_spikes = spikes["output"].get("s").view(time, n_classes, neuron_per_class).sum(0)
    predicted_label = torch.argmax(output_spikes.sum(1))

    print("\routput", output_spikes.sum(1), 'predicted_label:', predicted_label.item(), 'GT:', label.item(), end = '')
    # print('main', spikes["main"].get("s").sum(0).sum(1))
    # print("input", spikes["input"].get("s").sum(0).sum(1))
    if  i % 100 == 0 and i!=0:
        if i != n_train:
            correct = 0
            false = 0
        val_acc = evaluate(val_loader)

    if predicted_label == label:
        correct+=1
    else:
        false+=1
        
    acc = 100 * correct/(correct + false)
    # network.reset_state_variables()  # Reset state variables.
    
    pbar.set_description_str("Running accuracy: " + "{:.2f}".format(acc) + "%, " + "Current val accuracy: " + "{:.2f}".format(val_acc) + "%, ")
    pbar.update()

Begin training.



  0%|          | 0/500 [00:00<?, ?it/s]

output tensor([80, 81, 80, 81, 80, 80, 81, 80, 80, 77]) predicted_label: 1 GT: 0