[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googlecolab/colabtools/blob/master/notebooks/colab-github-demo.ipynb)

In [None]:
# Clone the project repository. IPython expands {TOKEN} from a Python variable
# named TOKEN, which must be defined before this cell runs.
# NOTE(review): the token becomes part of the clone's saved remote URL
# (.git/config), so treat the cloned directory as containing a credential.
!git clone https://{TOKEN}@github.com/mapolinario94/ECE570-Project.git

In [None]:
import sys

# Make the project's modules importable. sys.path entries must be directories
# (or zip archives); individual .py files are not valid entries, so only the
# folders are added:
#   - the repository root, so `models.<module>` imports work;
#   - models/ itself (inserted last, so it is searched first), so
#     `import spiking_layers` works directly.
# The membership check keeps re-runs of this cell from piling up duplicates.
_PROJECT_ROOT = '/content/ECE570-Project'
for _path in (_PROJECT_ROOT, _PROJECT_ROOT + '/models'):
  if _path not in sys.path:
    sys.path.insert(0, _path)

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import timeit
import math
try:
    # Preferred: models/ itself is on sys.path.
    from spiking_layers import LinearLIF, Conv2dLIF
    from conversion_method import SpikeNorm
except ImportError:
    # Fallback: the repository root is importable instead, so go through the
    # `models` package. Only import failures trigger the fallback; any other
    # error raised while loading the modules now surfaces instead of being hidden.
    # NOTE(review): confirm the module is really named `conversion_method`
    # (singular) and not `conversion_methods`.
    from models.spiking_layers import LinearLIF, Conv2dLIF
    from models.conversion_method import SpikeNorm

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Preprocessing: ToTensor scales pixels to [0, 1]; normalising every RGB channel
# with mean 0.5 and std 0.5 then maps them to [-1, 1].
_half = (0.5, 0.5, 0.5)
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_half, std=_half),
])

# CIFAR-10 train split first, then test split; both are downloaded into /data
# when not already present.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10('/data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Class names in label-index order (0-9).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print(train_dataset)

In [None]:
# Training batches are reshuffled every epoch; evaluation batches are larger
# and keep a fixed order.
batch_size_train = 64
batch_size_test = 1000

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size_test, shuffle=False)

In [None]:
def train(classifier, epoch, optimizer):
  """Run one training epoch of `classifier` over the global `train_loader`.

  Every 10 batches the current loss is appended to the global `train_losses`,
  and the number of training samples consumed before that batch (across all
  epochs so far) is appended to the global `train_counter`. Progress is printed
  every 300 batches.

  Args:
    classifier: model to train; each batch is moved to the global DEVICE.
    epoch: 1-based epoch number, used to offset the sample counter.
    optimizer: optimizer that updates `classifier`'s parameters.
  """
  classifier.train()  # switch to training mode (affects layers such as dropout/batch-norm)

  # Use the loader's real batch size rather than a hard-coded 64, so the
  # counter stays correct if batch_size_train changes.
  batch_size = train_loader.batch_size
  num_train = len(train_loader.dataset)
  epoch_offset = (epoch - 1) * num_train

  for batch_idx, (images, targets) in enumerate(train_loader):
    images, targets = images.to(DEVICE), targets.to(DEVICE)
    optimizer.zero_grad()
    output = classifier(images)
    loss = F.cross_entropy(output, targets)
    loss.backward()
    optimizer.step()

    samples_seen = batch_idx * batch_size  # samples consumed before this batch in this epoch
    if batch_idx % 10 == 0:  # record the loss every 10 batches
      train_losses.append(loss.item())  # .item() extracts the Python float from the tensor
      train_counter.append(samples_seen + epoch_offset)
    if batch_idx % 300 == 0:  # print progress every 300 batches
      print(f'Epoch {epoch}: [{samples_seen}/{num_train}] Loss: {loss.item()}')


def test(classifier, epoch):
  """Evaluate `classifier` on the global `test_loader` and log the results.

  Appends the mean per-sample cross-entropy to the global `test_losses` and
  the number of training samples seen after `epoch` epochs to the global
  `test_counter`, then prints the loss and accuracy.

  Args:
    classifier: model to evaluate; each batch is moved to the global DEVICE.
    epoch: 1-based epoch number used for logging.

  Returns:
    Test accuracy in percent as a Python float. The function previously
    returned None, so existing callers are unaffected.
  """
  classifier.eval()  # switch to evaluation mode (affects layers such as dropout/batch-norm)

  num_samples = len(test_loader.dataset)
  test_loss = 0.0
  correct = 0

  with torch.no_grad():
    for images, targets in test_loader:
      images, targets = images.to(DEVICE), targets.to(DEVICE)
      output = classifier(images)
      # Sum (not mean) per batch so dividing by the dataset size gives an exact
      # per-sample average even when the last batch is smaller.
      test_loss += F.cross_entropy(output, targets, reduction='sum').item()
      pred = output.argmax(dim=1)  # predicted class = index of the largest score
      # .item() keeps `correct` a Python int; accumulating the tensor made the
      # printed accuracy render as "tensor(..., device=...)".
      correct += (pred == targets).sum().item()

  test_loss /= num_samples
  accuracy = 100. * correct / num_samples
  test_losses.append(test_loss)
  test_counter.append(len(train_loader.dataset) * epoch)

  print(f'Test result on epoch {epoch}: Avg loss is {test_loss}, Accuracy: {accuracy}%')
  return accuracy


In [None]:
class SpikingModel(nn.Module):
  """Spiking counterpart of ANNModel, simulated over a fixed number of time steps.

  Convolution and linear layers are the project's `Conv2dLIF` / `LinearLIF`
  layers (presumably leaky integrate-and-fire neurons; `leak=1.0` is passed to
  each). Every such layer is called as `layer(input, state)` and returns
  `(output, new_state)`. The per-layer states live in `mem_features` /
  `mem_classifier` and are reset at the start of every forward pass. The
  flattened size 8*8*64 implies 32x32 inputs, assuming the LIF convolutions
  honour padding='same' before the two 2x2 max-pools.

  Args:
    device: forwarded to every LIF layer's constructor.
    timesteps: simulation steps per forward pass. The default of 10 is the
      value previously hard-coded in `forward`.
  """

  def __init__(self, device=None, timesteps=10):
    super().__init__()
    self.timesteps = timesteps
    self.features = nn.Sequential(
        Conv2dLIF(3, 32, kernel_size=5, padding='same', device=device, leak=1.0),
        nn.MaxPool2d(kernel_size=2, stride=2),
        Conv2dLIF(32, 64, kernel_size=5, padding='same', device=device, leak=1.0),
        nn.MaxPool2d(kernel_size=2, stride=2),
        Conv2dLIF(64, 64, kernel_size=3, padding='same', device=device, leak=1.0)
        )

    self.classifier = nn.Sequential(
        LinearLIF(8*8*64, 256, device=device, leak=1.0),
        # NOTE(review): forward() only steps nn.Linear layers, so this ReLU is
        # never applied in the SNN. It is presumably kept so that Sequential
        # indices line up with ANNModel.classifier; confirm before removing.
        nn.ReLU(inplace=True),
        # NOTE(review): semantics of cumulative=True live in spiking_layers;
        # presumably the last layer accumulates rather than resets its state.
        LinearLIF(256, 10, cumulative=True, device=device, leak=1.0)
        )

    self._init_weights()

  def _init_internal_states(self):
    """Reset the per-layer state lists to None, one slot per Sequential index."""
    self.mem_classifier = [None] * len(self.classifier)
    self.mem_features = [None] * len(self.features)

  def _init_weights(self):
    """Initialise parameters in place.

    Conv layers get weights from N(0, sqrt(2 / (k_h * k_w * out_channels))),
    a He-style fan-out scale. Batch-norm gets weight 1. Linear layers get
    weights from N(0, std=0.01). Every present bias is zeroed. The LIF layers
    are covered only if they subclass nn.Conv2d / nn.Linear (assumed; forward
    relies on the same isinstance checks).
    """
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
          m.bias.data.zero_()
      elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
      elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
          m.bias.data.zero_()

  def forward(self, X):
    """Simulate the network for `self.timesteps` steps on the batch `X`.

    The same input `X` is fed at every time step. Its first dimension is
    treated as the batch size.

    Returns:
      The state returned by the last classifier layer after the final time
      step. This is what the training loop uses as the class scores.
    """
    batch_size = X.shape[0]
    self._init_internal_states()  # fresh per-layer state for every forward call
    for _ in range(self.timesteps):
      spk = X
      for idx, layer in enumerate(self.features):
        if isinstance(layer, nn.Conv2d):
          spk, self.mem_features[idx] = layer(spk, self.mem_features[idx])
        elif isinstance(layer, nn.MaxPool2d):
          spk = layer(spk)

      spk = spk.view(batch_size, -1)
      for idx, layer in enumerate(self.classifier):
        # Only LIF linear layers are stepped; the nn.ReLU is skipped.
        if isinstance(layer, nn.Linear):
          spk, self.mem_classifier[idx] = layer(spk, self.mem_classifier[idx])
    # The last Sequential slot is the final LIF layer.
    return self.mem_classifier[-1]

# Build the spiking model on the selected device, plus its Adam optimiser.
snn_model = SpikingModel(device=DEVICE).to(DEVICE)
optimizer_snn = optim.Adam(params=snn_model.parameters(), lr=1e-3)

In [None]:
# Fresh history buffers for the SNN run; train() and test() append to these
# module-level lists.
train_losses = []
train_counter = []
test_losses = []
test_counter = []
max_epoch = 6


def training_snn(num_epochs=max_epoch):
  """Train and evaluate `snn_model` for `num_epochs` epochs (default 6).

  `train` and `test` are called once per 1-based epoch number. The default is
  bound when this cell executes. A later cell that reassigns the global
  `max_epoch` (the ANN cell sets it to 20) therefore no longer changes the
  SNN schedule.
  """
  for epoch in range(1, num_epochs + 1):
    train(snn_model, epoch, optimizer_snn)
    test(snn_model, epoch)

# print(f'Total time for SNN: {timeit.timeit(training_snn, number=1)/60} min')

In [None]:
class ANNModel(nn.Module):
  """Conventional (non-spiking), bias-free CNN with the same layout as SpikingModel.

  `features` is conv5x5(3->32), max-pool, conv5x5(32->64), max-pool,
  conv3x3(64->64), all with padding='same'. `classifier` is
  Linear(8*8*64 -> 256), ReLU, Linear(256 -> 10). The flatten size implies
  32x32 inputs.

  NOTE(review): no activation follows the convolutions (only max-pooling), so
  `features` is close to linear. Confirm this is intended. Inserting layers
  would shift the Sequential indices that state_dict keys (and possibly
  SpikeNorm) depend on.
  """

  def __init__(self):
    super().__init__()

    def conv(in_ch, out_ch, k):
      return nn.Conv2d(in_ch, out_ch, kernel_size=k, padding='same', bias=False)

    def pool():
      return nn.MaxPool2d(kernel_size=2, stride=2)

    self.features = nn.Sequential(
        conv(3, 32, 5),
        pool(),
        conv(32, 64, 5),
        pool(),
        conv(64, 64, 3),
    )
    self.classifier = nn.Sequential(
        nn.Linear(8 * 8 * 64, 256, bias=False),
        nn.ReLU(inplace=True),
        nn.Linear(256, 10, bias=False),
    )

  def forward(self, X):
    """Return class scores for the batch `X` (first dimension is the batch)."""
    flat = self.features(X).view(X.shape[0], -1)
    return self.classifier(flat)

# Build the conventional CNN on the selected device, plus its Adam optimiser.
ann_model = ANNModel().to(DEVICE)
optimizer_ann = optim.Adam(params=ann_model.parameters(), lr=1e-4)

In [None]:
# Fresh history buffers for the ANN run; train() and test() append to these
# module-level lists.
train_losses = []
train_counter = []
test_losses = []
test_counter = []
max_epoch = 20


def training_ann(num_epochs=max_epoch):
  """Train and evaluate `ann_model` for `num_epochs` epochs (default 20).

  `train` and `test` are called once per 1-based epoch number. The default is
  bound when this cell executes, so later reassignments of the global
  `max_epoch` do not change this schedule.
  """
  for epoch in range(1, num_epochs + 1):
    train(ann_model, epoch, optimizer_ann)
    test(ann_model, epoch)

# Run the full ANN training schedule once and report the elapsed time in minutes.
_ann_minutes = timeit.timeit(training_ann, number=1) / 60
print(f'Total time for CNN: {_ann_minutes} min')

In [None]:
# Hand the trained ANN, the spiking model, the training loader and the device
# to the project's SpikeNorm helper; its result is stored as model_cifar_cnn.
# NOTE(review): SpikeNorm is defined in the project's conversion module. It is
# presumably an ANN->SNN threshold-balancing step; the meaning of the final
# argument (10), e.g. time steps or number of batches, must be confirmed there.
model_cifar_cnn = SpikeNorm(ann_model, snn_model, train_loader, DEVICE, 10)