# Create SNN Model

**Install packages**

In [2]:
import torch.nn as nn

!pip install git+https://github.com/miladmozafari/SpykeTorch.git
import SpykeTorch.snn as snn
import SpykeTorch.functional as sf

Collecting git+https://github.com/miladmozafari/SpykeTorch.git
  Cloning https://github.com/miladmozafari/SpykeTorch.git to /tmp/pip-req-build-hbvvl3_0
  Running command git clone --filter=blob:none --quiet https://github.com/miladmozafari/SpykeTorch.git /tmp/pip-req-build-hbvvl3_0
  Resolved https://github.com/miladmozafari/SpykeTorch.git to commit aa302c77cb61de6ec64e40927313137a9855ec6a
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: SpykeTorch-miladmozafari
  Building wheel for SpykeTorch-miladmozafari (setup.py) ... [?25l[?25hdone
  Created wheel for SpykeTorch-miladmozafari: filename=SpykeTorch_miladmozafari-0.0.1-py3-none-any.whl size=27718 sha256=fb0261c5bcfc4b302d45d46c3923abf3e502215fb60fd929666bd2287ebe84b5
  Stored in directory: /tmp/pip-ephem-wheel-cache-17o7o__6/wheels/dc/1a/be/800ff08c666101c718fdfe0659a11e19c6fe7913e354fea6b4
Successfully built SpykeTorch-miladmozafari
Installing collected packages: SpykeTorch-miladmozafari

**Create layers**

In [3]:
class DCSNN(nn.Module):
  """Three-layer convolutional spiking network with STDP and reward-modulated STDP.

  ``conv1`` and ``conv2`` are paired with plain STDP rules (``stdp1``, ``stdp2``).
  ``conv3`` is paired with two rules: ``stdp3`` (applied on reward) and
  ``anti_stdp3`` (applied on punishment).

  Attributes:
    decision_map: list indexed by a ``conv3`` feature index that gives the digit
      that feature votes for.
    ctx: dictionary holding the input spikes, potentials, output spikes and
      winners of the most recent training forward pass.

  NOTE(review): ``save_data``, ``forward``, ``stdp``, ``reward`` and ``punish`` are
  written as module-level functions in later cells of this notebook. They have to
  be bound to this class before an instance can be called; that binding is not
  visible in this notebook section.
  """

  def __init__(self):
    super(DCSNN, self).__init__()

    # snn.Convolution(in_channels, out_channels, kernel_size, weight_mean=0.8, weight_std=0.02)
    self.conv1 = snn.Convolution(6, 30, 5, 0.8, 0.05)
    self.conv2 = snn.Convolution(30, 250, 3, 0.8, 0.05)
    self.conv3 = snn.Convolution(250, 200, 5, 0.8, 0.05)

    # snn.STDP(conv_layer, learning_rate, use_stabilizer=True, lower_bound=0, upper_bound=1)
    self.stdp1 = snn.STDP(self.conv1, (0.004, -0.003))
    self.stdp2 = snn.STDP(self.conv2, (0.004, -0.003))
    self.stdp3 = snn.STDP(self.conv3, (0.004, -0.003), False, 0.2, 0.8)
    # Anti-STDP must flip the sign of both learning rates relative to STDP;
    # otherwise punish() would strengthen the synapses that produced a wrong answer.
    self.anti_stdp3 = snn.STDP(self.conv3, (-0.004, 0.0005), False, 0.2, 0.8)

    # forward() reads self.decision_map[feature] to turn the winning conv3 feature
    # into a digit. The 200 conv3 features are split into contiguous groups of 20
    # per digit (0..9), following the SpykeTorch reference MNIST model.
    n_classes = 10
    features_per_class = 200 // n_classes
    self.decision_map = [
      digit for digit in range(n_classes) for _ in range(features_per_class)
    ]

    # save_data() writes into this dictionary, so it has to exist up front.
    self.ctx = {
      "input_spikes": None,
      "potentials": None,
      "output_spikes": None,
      "winners": None,
    }

**Define a helper that saves each training forward pass's data into the context dictionary**

In [8]:
def save_data(self, input_spk, pot, spk, winners):
  """Store the tensors of the layer currently being trained in ``self.ctx``.

  Args:
    input_spk: spikes fed into the layer.
    pot: potentials of the layer.
    spk: spikes emitted by the layer.
    winners: winners selected for the layer.
  """
  snapshot = (
    ("input_spikes", input_spk),
    ("potentials", pot),
    ("output_spikes", spk),
    ("winners", winners),
  )
  for key, value in snapshot:
    self.ctx[key] = value

**Create forward function**

In [7]:
def forward(self, input, max_layer):
  """Propagate a spike-wave through the network.

  In training mode the pass stops at ``max_layer``: for layers 1 and 2 the data
  needed by STDP is stored through ``self.save_data`` and the layer's
  ``(spikes, potentials)`` pair is returned. For any other ``max_layer`` the pass
  runs through ``conv3``, stores that layer's data and returns a decision.

  In evaluation mode ``max_layer`` is ignored and a decision is always returned.

  Args:
    input: input spike-wave tensor.
    max_layer: index of the layer being trained (1, 2, or 3).

  Returns:
    ``(spk, pot)`` when training layer 1 or 2; otherwise the digit looked up in
    ``self.decision_map`` for the winning feature, or ``-1`` if there is no winner.
  """
  input = sf.pad(input, (2, 2, 2, 2))
  if self.training:  # forward pass for training
    pot = self.conv1(input)
    spk, pot = sf.fire(pot, 15, True)
    if max_layer == 1:
      winners = sf.get_k_winners(pot, 5, 3)
      self.save_data(input, pot, spk, winners)
      return spk, pot

    spk_in = sf.pad(sf.pooling(spk, 2, 2), (1, 1, 1, 1))
    pot = self.conv2(spk_in)
    spk, pot = sf.fire(pot, 10, True)
    if max_layer == 2:
      winners = sf.get_k_winners(pot, 8, 2)
      self.save_data(spk_in, pot, spk, winners)
      # Stop here: falling through would overwrite the saved layer-2 data with
      # layer-3 data before stdp(2) gets to use it.
      return spk, pot

    spk_in = sf.pad(sf.pooling(spk, 3, 3), (2, 2, 2, 2))
    pot = self.conv3(spk_in)
    # Omitting the threshold means an infinite threshold. The out-of-place fire()
    # is used so that `pot` keeps its values and `spk` is an actual spike tensor
    # that the reward/punish rules can consume.
    spk = sf.fire(pot)
    winners = sf.get_k_winners(pot, 1, 0, spk)
    self.save_data(spk_in, pot, spk, winners)
  else:  # forward pass for testing
    pot = self.conv1(input)
    spk = sf.fire(pot, 15)
    pot = self.conv2(sf.pad(sf.pooling(spk, 2, 2), (1, 1, 1, 1)))
    spk = sf.fire(pot, 10)
    pot = self.conv3(sf.pad(sf.pooling(spk, 3, 3), (2, 2, 2, 2)))
    # Omitting the threshold means an infinite threshold.
    spk = sf.fire(pot)
    winners = sf.get_k_winners(pot, 1, 0, spk)

  # Each winner is a tuple of form (feature, row, column).
  if len(winners) != 0:
    return self.decision_map[winners[0][0]]
  return -1

**Create STDP helper functions**

In [9]:
def stdp(self, layer_idx):
  """Apply the STDP rule of layer 1 or 2 to the data saved in ``self.ctx``.

  Any other ``layer_idx`` is silently ignored.
  """
  if layer_idx == 1:
    rule = self.stdp1
  elif layer_idx == 2:
    rule = self.stdp2
  else:
    return
  rule(
    self.ctx["input_spikes"],
    self.ctx["potentials"],
    self.ctx["output_spikes"],
    self.ctx["winners"],
  )

def reward(self):
  """Apply ``self.stdp3`` to the data saved in ``self.ctx`` (used after a correct decision)."""
  keys = ("input_spikes", "potentials", "output_spikes", "winners")
  self.stdp3(*(self.ctx[key] for key in keys))

def punish(self):
  """Apply ``self.anti_stdp3`` to the data saved in ``self.ctx`` (used after a wrong decision)."""
  learner = self.anti_stdp3
  saved = self.ctx
  learner(
    saved["input_spikes"],
    saved["potentials"],
    saved["output_spikes"],
    saved["winners"],
  )

**Transform input images into spike waves**

In [11]:
import SpykeTorch.utils as utils
import torchvision.transforms as transforms

class InputTransform:
  """Callable that turns an input image into a spike-wave tensor.

  Args:
    filter: callable applied to the image tensor after scaling and after the
      extra leading dimension is added.
  """

  def __init__(self, filter):
    self.filter = filter
    self.to_tensor = transforms.ToTensor()
    # presumably 15 is the number of time steps of the spike wave; confirm against SpykeTorch docs
    self.temporal_transform = utils.Intensity2Latency(15, to_spike=True)

  def __call__(self, image):
    """Scale, filter, normalise and temporally encode ``image``."""
    scaled = self.to_tensor(image) * 255  # tensor conversion, then scale by 255
    scaled.unsqueeze_(0)  # prepend one extra (time) dimension in place
    filtered = self.filter(scaled)
    # presumably 8 is the normalization radius; confirm against SpykeTorch docs
    normalized = sf.local_normalization(filtered, 8)
    return self.temporal_transform(normalized)  # spike-wave tensor

# Filter bank of six DoG kernels: for every window size there is a pair of
# kernels whose two sigma arguments (size/9 and 2*size/9) are swapped.
kernels = [
    kernel
    for size in (3, 7, 13)
    for kernel in (
        utils.DoGKernel(size, size / 9, 2 * size / 9),
        utils.DoGKernel(size, 2 * size / 9, size / 9),
    )
]
filter = utils.Filter(kernels, padding=6, thresholds=50)
transform = InputTransform(filter)

# Prepare MNIST Dataset

In [14]:
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

data_root = "./"

# Training and test splits, each wrapped in SpykeTorch's CacheDataset and
# transformed by `transform` into spike waves.
MNIST_train, MNIST_test = (
    utils.CacheDataset(
        MNIST(root=data_root, train=is_train, download=True, transform=transform)
    )
    for is_train in (True, False)
)

# Training data comes in batches of 1000; the test set is served as one batch.
MNIST_loader = DataLoader(MNIST_train, batch_size=1000)
MNIST_test_loader = DataLoader(MNIST_test, batch_size=len(MNIST_test))

# Train and Test

**Unsupervised learning for normal STDP layers**

In [19]:
import torch
# True when a CUDA device is available; the training loops use this flag to
# decide whether each sample is moved to the GPU with .cuda().
use_cuda = torch.cuda.is_available()

def train_unsupervised(network, data, layer_idx):
  """Train one STDP layer on a batch, one sample at a time.

  Args:
    network: model exposing ``train()``, ``__call__(sample, layer_idx)`` and
      ``stdp(layer_idx)``.
    data: indexable batch of samples.
    layer_idx: index of the layer to train.
  """
  network.train()
  for idx in range(len(data)):
    sample = data[idx]
    if use_cuda:
      sample = sample.cuda()
    network(sample, layer_idx)
    network.stdp(layer_idx)

**Reinforcement learning for R-STDP layer(s)**

In [20]:
import numpy as np

def train_rl(network, data, target):
  """Train the decision layer with reward/punishment on a batch.

  Each sample is passed through the network with ``max_layer`` 3. A decision of
  ``-1`` counts as silent; a decision equal to the target triggers
  ``network.reward()``, any other decision triggers ``network.punish()``.

  Args:
    network: model exposing ``train()``, ``__call__(sample, 3)``, ``reward()``
      and ``punish()``.
    data: indexable batch of samples.
    target: indexable batch of labels aligned with ``data``.

  Returns:
    NumPy array ``[correct, wrong, silent]`` as fractions of ``len(data)``.
  """
  network.train()
  n_correct = n_wrong = n_silent = 0
  total = len(data)
  for idx in range(total):
    sample, label = data[idx], target[idx]
    if use_cuda:
      sample, label = sample.cuda(), label.cuda()
    decision = network(sample, 3)
    if decision == -1:
      n_silent += 1
      continue
    if decision == label:
      n_correct += 1
      network.reward()
    else:
      n_wrong += 1
      network.punish()
  return np.array([n_correct, n_wrong, n_silent]) / total