# Create SNN Model

**Install packages**

In [15]:
import torch.nn as nn

!pip install git+https://github.com/miladmozafari/SpykeTorch.git
import SpykeTorch.snn as snn
import SpykeTorch.functional as sf

Collecting git+https://github.com/miladmozafari/SpykeTorch.git
  Cloning https://github.com/miladmozafari/SpykeTorch.git to /tmp/pip-req-build-u0jnfilv
  Running command git clone --filter=blob:none --quiet https://github.com/miladmozafari/SpykeTorch.git /tmp/pip-req-build-u0jnfilv
  Resolved https://github.com/miladmozafari/SpykeTorch.git to commit aa302c77cb61de6ec64e40927313137a9855ec6a
  Preparing metadata (setup.py) ... [?25l[?25hdone


**Create DCSNN model**

In [16]:
class DCSNN(nn.Module):
  def __init__(self):
    super(DCSNN, self).__init__()

    #(in_channels, out_channels, kernel_size, weight_mean=0.8, weight_std=0.02)
    self.conv1 = snn.Convolution(6, 30, 5, 0.8, 0.05)
    self.conv2 = snn.Convolution(30, 250, 3, 0.8, 0.05)
    self.conv3 = snn.Convolution(250, 200, 5, 0.8, 0.05)

    #(conv_layer, learning_rate, use_stabilizer=True, lower_bound=0, upper_bound=1)
    self.stdp1 = snn.STDP(self.conv1, (0.004, -0.003))
    self.stdp2 = snn.STDP(self.conv2, (0.004, -0.003))
    self.stdp3 = snn.STDP(self.conv3, (0.004, -0.003), False, 0.2, 0.8)
    self.anti_stdp3 = snn.STDP(self.conv3, (0.004, -0.0005), False, 0.2, 0.8)

    def save_data(self, input_spk, pot, spk, winners):
      self.ctx["input_spikes"] = input_spk
      self.ctx["potentials"] = pot
      self.ctx["output_spikes"] = spk
      self.ctx["winners"] = winners

    def forward(self, input, max_layer):
      input = sf.pad(input, (2, 2, 2, 2))
      if self.training: # forward pass for training
        pot = self.conv1(input)
        spk, pot = sf.fire(pot, 15, True)
        if max_layer == 1:
          winners = sf.get_k_winners(pot, 5, 3)
          self.save_data(input, pot, spk, winners)
          return spk, pot
        spk_in = sf.pad(sf.pooling(spk, 2, 2), (1, 1, 1, 1))
        pot = self.conv2(spk_in)
        spk, pot = sf.fire(pot, 10, True)
        if max_layer == 2:
          winners = sf.get_k_winners(pot, 8, 2)
          self.save_data(spk_in, pot, spk, winners)
        spk_in = sf.pad(sf.pooling(spk, 3, 3), (2, 2, 2, 2))
        pot = self.conv3(spk_in)
        spk = sf.fire_(pot)
        winners = sf.get_k_winners(pot, 1)
        self.save_data(spk_in, pot, spk, winners)
        output = -1
        if len(winners) != 0:
          output = self.decision_map[winners[0][0]]
        return output
      else: # forward pass for testing
        pot = self.conv1(input)
        spk = sf.fire(pot, 15)
        pot = self.conv2(sf.pad(sf.pooling(spk, 2, 2), (1, 1, 1, 1)))
        spk = sf.fire(pot, 10)
        pot = self.conv3(sf.pad(sf.pooling(spk, 3, 3), (2, 2, 2, 2)))
        # omitting the threshold parameter means infinite threshold
        spk = sf.fire_(pot)
        winners = sf.get_k_winners(pot, 1)
        output = -1
        # each winner is a tuple of form (feature, row, column)
        if len(winners) != 0:
          output = self.decision_map[winners[0][0]]
        return output

      def stdp(self, layer_idx):
        if layer_idx == 1:
          self.stdp1(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])
        if layer_idx == 2:
          self.stdp2(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])

      def reward(self):
        self.stdp3(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])

      def punish(self):
        self.anti_stdp3(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])

**Transform input images into spike waves**

In [17]:
import SpykeTorch.utils as utils
import torchvision.transforms as transforms

class InputTransform:
  def __init__(self, filter):
    self.to_tensor = transforms.ToTensor()
    self.filter = filter
    self.temporal_transform = utils.Intensity2Latency(15, to_spike=True)

  def __call__(self, image):
    image = self.to_tensor(image) * 255 # convert image to tensor
    image.unsqueeze_(0) # add 1 extra time dimension
    image = self.filter(image) # apply filter
    image = sf.local_normalization(image, 8) # normalize
    return self.temporal_transform(image) # generate spike wave tensor

kernels = [utils.DoGKernel(3, 3/9, 6/9), utils.DoGKernel(3, 6/9, 3/9),
           utils.DoGKernel(7, 7/9, 14/9), utils.DoGKernel(7, 14/9, 7/9),
           utils.DoGKernel(13, 13/9, 26/9), utils.DoGKernel(13, 26/9, 13/9)]
filter = utils.Filter(kernels, padding=6, thresholds=50)
transform = InputTransform(filter)

# Prepare MNIST Dataset

In [18]:
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

data_root = "./"

MNIST_train = utils.CacheDataset(MNIST(root=data_root, train=True, download=True, transform=transform))
MNIST_test = utils.CacheDataset(MNIST(root=data_root, train=False, download=True, transform=transform))
MNIST_loader = DataLoader(MNIST_train, batch_size=1000)
MNIST_test_loader = DataLoader(MNIST_test, batch_size=len(MNIST_test))

# Train and Test

**Unsupervised learning for normal STDP layers**

In [19]:
import torch
use_cuda = torch.cuda.is_available()

def train_unsupervised(network, data, layer_idx):
  network.train()
  for i in range(len(data)):
    data_in = data[i].cuda() if use_cuda else data[i]
    network(data_in, layer_idx)
    network.stdp(layer_idx)

**Reinforcement learning for R-STDP layer(s)**

In [20]:
import numpy as np

def train_rl(network, data, target):
  network.train()
  perf = np.array([0, 0, 0]) # correct, wrong, silent
  for i in range(len(data)):
    data_in = data[i].cuda() if use_cuda else data[i]
    target_in = target[i].cuda() if use_cuda else target[i]
    d = network(data_in, 3)
    if d != -1:
      if d == target_in:
        perf[0] += 1
        network.reward()
      else:
        perf[1] += 1
        network.punish()
    else:
      perf[2] += 1
  return perf / len(data)

**Train and test the network**

In [21]:
net = DCSNN()

if use_cuda:
  net.cuda()

epochs_1 = 10
epochs_2 = 10
epochs_3 = 10

# first layer
for epoch in range(epochs_1):
  for data, targets in MNIST_loader:
    train_unsupervised(net, data, 1)

# second layer
for epoch in range(epochs_2):
  for data, targets in MNIST_loader:
    train_unsupervised(net, data, 2)

# third layer
for epoch in range(epochs_3):
  for data, targets in MNIST_loader: # training
    print(train_rl(net, data, targets))
  for data, targets in MNIST_test_loader: # testing
    print(test(net, data, targets))

NotImplementedError: Module [DCSNN] is missing the required "forward" function

# Kheradpisheh Model

In [None]:
###################################################################################
# Reimplementation of the Digit Recognition Experiment (MNIST) Performed in:      #
# https://www.sciencedirect.com/science/article/pii/S0893608017302903             #
#                                                                                 #
# Reference:                                                                      #
# Kheradpisheh, Saeed Reza, et al.                                                #
# "STDP-based spiking deep convolutional neural networks for object recognition." #
# Neural Networks 99 (2018): 56-67.                                               #
#                                                                                 #
###################################################################################

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.nn.parameter import Parameter
import torchvision
import numpy as np
from SpykeTorch import snn
from SpykeTorch import functional as sf
from SpykeTorch import visualization as vis
from SpykeTorch import utils
from torchvision import transforms

use_cuda = True

class KheradpishehMNIST(nn.Module):
    def __init__(self):
        super(KheradpishehMNIST, self).__init__()

        self.conv1 = snn.Convolution(2, 32, 5, 0.8, 0.05)
        self.conv1_t = 10
        self.k1 = 5
        self.r1 = 2

        self.conv2 = snn.Convolution(32, 150, 2, 0.8, 0.05)
        self.conv2_t = 1
        self.k2 = 8
        self.r2 = 1

        self.stdp1 = snn.STDP(self.conv1, (0.004, -0.003))
        self.stdp2 = snn.STDP(self.conv2, (0.004, -0.003))
        self.max_ap = Parameter(torch.Tensor([0.15]))

        self.ctx = {"input_spikes":None, "potentials":None, "output_spikes":None, "winners":None}
        self.spk_cnt1 = 0
        self.spk_cnt2 = 0

    def save_data(self, input_spike, potentials, output_spikes, winners):
        self.ctx["input_spikes"] = input_spike
        self.ctx["potentials"] = potentials
        self.ctx["output_spikes"] = output_spikes
        self.ctx["winners"] = winners

    def forward(self, input, max_layer):
        input = sf.pad(input.float(), (2,2,2,2), 0)
        if self.training:
            pot = self.conv1(input)
            spk, pot = sf.fire(pot, self.conv1_t, True)
            if max_layer == 1:
                self.spk_cnt1 += 1
                if self.spk_cnt1 >= 500:
                    self.spk_cnt1 = 0
                    ap = torch.tensor(self.stdp1.learning_rate[0][0].item(), device=self.stdp1.learning_rate[0][0].device) * 2
                    ap = torch.min(ap, self.max_ap)
                    an = ap * -0.75
                    self.stdp1.update_all_learning_rate(ap.item(), an.item())
                pot = sf.pointwise_inhibition(pot)
                spk = pot.sign()
                winners = sf.get_k_winners(pot, self.k1, self.r1, spk)
                self.save_data(input, pot, spk, winners)
                return spk, pot
            spk_in = sf.pad(sf.pooling(spk, 2, 2, 1), (1,1,1,1))
            spk_in = sf.pointwise_inhibition(spk_in)
            pot = self.conv2(spk_in)
            spk, pot = sf.fire(pot, self.conv2_t, True)
            if max_layer == 2:
                pot = sf.pointwise_inhibition(pot)
                spk = pot.sign()
                winners = sf.get_k_winners(pot, self.k2, self.r2, spk)
                self.save_data(spk_in, pot, spk, winners)
                return spk, pot
            spk_out = sf.pooling(spk, 2, 2, 1)
            return spk_out
        else:
            pot = self.conv1(input)
            spk, pot = sf.fire(pot, self.conv1_t, True)
            pot = self.conv2(sf.pad(sf.pooling(spk, 2, 2, 1), (1,1,1,1)))
            spk, pot = sf.fire(pot, self.conv2_t, True)
            spk = sf.pooling(spk, 2, 2, 1)
            return spk

    def stdp(self, layer_idx):
        if layer_idx == 1:
            self.stdp1(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])
        if layer_idx == 2:
            self.stdp2(self.ctx["input_spikes"], self.ctx["potentials"], self.ctx["output_spikes"], self.ctx["winners"])

def train_unsupervise(network, data, layer_idx):
    network.train()
    for i in range(len(data)):
        data_in = data[i]
        if use_cuda:
            data_in = data_in.cuda()
        network(data_in, layer_idx)
        network.stdp(layer_idx)

def test(network, data, target, layer_idx):
    network.eval()
    ans = [None] * len(data)
    t = [None] * len(data)
    for i in range(len(data)):
        data_in = data[i]
        if use_cuda:
            data_in = data_in.cuda()
        output,_ = network(data_in, layer_idx).max(dim = 0)
        ans[i] = output.reshape(-1).cpu().numpy()
        t[i] = target[i]
    return np.array(ans), np.array(t)

class S1Transform:
    def __init__(self, filter, timesteps = 15):
        self.to_tensor = transforms.ToTensor()
        self.filter = filter
        self.temporal_transform = utils.Intensity2Latency(timesteps)
        self.cnt = 0
    def __call__(self, image):
        if self.cnt % 1000 == 0:
            print(self.cnt)
        self.cnt+=1
        image = self.to_tensor(image) * 255
        image.unsqueeze_(0)
        image = self.filter(image)
        image = sf.local_normalization(image, 8)
        temporal_image = self.temporal_transform(image)
        return temporal_image.sign().byte()

kernels = [ utils.DoGKernel(7,1,2),
            utils.DoGKernel(7,2,1),]
filter = utils.Filter(kernels, padding = 3, thresholds = 50)
s1 = S1Transform(filter)

data_root = "data"
MNIST_train = utils.CacheDataset(torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform = s1))
MNIST_test = utils.CacheDataset(torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform = s1))
MNIST_loader = DataLoader(MNIST_train, batch_size=len(MNIST_train), shuffle=False)
MNIST_testLoader = DataLoader(MNIST_test, batch_size=len(MNIST_test), shuffle=False)

kheradpisheh = KheradpishehMNIST()
if use_cuda:
    kheradpisheh.cuda()

# Training The First Layer
print("Training the first layer")
if os.path.isfile("saved_l1.net"):
    kheradpisheh.load_state_dict(torch.load("saved_l1.net"))
else:
    for epoch in range(2):
        print("Epoch", epoch)
        iter = 0
        for data,_ in MNIST_loader:
            print("Iteration", iter)
            train_unsupervise(kheradpisheh, data, 1)
            print("Done!")
            iter+=1
    torch.save(kheradpisheh.state_dict(), "saved_l1.net")

# Training The Second Layer
print("Training the second layer")
if os.path.isfile("saved_l2.net"):
    kheradpisheh.load_state_dict(torch.load("saved_l2.net"))
for epoch in range(20):
    print("Epoch", epoch)
    iter = 0
    for data,_ in MNIST_loader:
        print("Iteration", iter)
        train_unsupervise(kheradpisheh, data, 2)
        print("Done!")
        iter+=1
torch.save(kheradpisheh.state_dict(), "saved_l2.net")

# Classification
# Get train data
for data,target in MNIST_loader:
    train_X, train_y = test(kheradpisheh, data, target, 2)


# Get test data
for data,target in MNIST_testLoader:
    test_X, test_y = test(kheradpisheh, data, target, 2)

# SVM
from sklearn.svm import LinearSVC
clf = LinearSVC(C=2.4)
clf.fit(train_X, train_y)
predict_train = clf.predict(train_X)
predict_test = clf.predict(test_X)

def get_performance(X, y, predictions):
    correct = 0
    silence = 0
    for i in range(len(predictions)):
        if X[i].sum() == 0:
            silence += 1
        else:
            if predictions[i] == y[i]:
                correct += 1
    return (correct/len(X), (len(X)-(correct+silence))/len(X), silence/len(X))

print(get_performance(train_X, train_y, predict_train))
print(get_performance(test_X, test_y, predict_test))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 17892323.08it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 474504.17it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4463446.34it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3760467.58it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Training the first layer
Epoch 0
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
Iteration 0
Done!
Epoch 1
Iteration 0
Done!
Training the second layer
Epoch 0
Iteration 0
Done!
Epoch 1
Iteration 0
Done!
Epoch 2
Iteration 0
Done!
Epoch 3
Iteration 0
Done!
Epoch 4
Iteration 0
Done!
Epoch 5
Iteration 0
Done!
Epoch 6
Iteration 0
Done!
Epoch 7
Iteration 0
Done!
Epoch 8
Iteration 0
Done!
Epoch 9
Iteration 0
Done!
Epoch 10
Iteration 0
Done!
Epoch 11
Iteration 0
Done!
Epoch 12
Iteration 0
Done!
Epoch 13
Iteration 0
Done!
Epoch 14
Iteration 0
Done!
Epoch 15
Iteration 0
Done!
Epoch 16
Iteration 0
Done!
Epoch 17
Iteration 0
Done!
Epo