## Segunda versão de uma SNN, agora utilizando Convoluções:

In [7]:
# imports
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools
import os
import pandas as pd
import cv2
from sklearn.model_selection import train_test_split

In [8]:
class WaveGuideImageData(Dataset):
    def __init__(self, train=True, transform=None):
        #Get images directory
        img_dir = '../imgs'
        images = os.listdir(img_dir)
        img_db_dir = 'final_image_database_loss_deciles.csv'
        img_deciles = pd.read_csv(img_db_dir)
        # Create empty Pandas DataFrame
        full_ds_pdf = pd.DataFrame(columns=['feature', 'd0', 'd1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8', 'd9'])

        #Load all files into one Pandas Dataframe
        for img_file in images:
            nova_img = {}
            img = cv2.imread('../imgs/'+img_file)
            #img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            nova_img['feature'] = img
            img_decil_df = img_deciles[img_deciles.img_file == img_file].copy()
            nova_img['d0'] = float(img_decil_df.d0)
            nova_img['d1'] = float(img_decil_df.d1)
            nova_img['d2'] = float(img_decil_df.d2)
            nova_img['d3'] = float(img_decil_df.d3)
            nova_img['d4'] = float(img_decil_df.d4)
            nova_img['d5'] = float(img_decil_df.d5)
            nova_img['d6'] = float(img_decil_df.d6)
            nova_img['d7'] = float(img_decil_df.d7)
            nova_img['d8'] = float(img_decil_df.d8)
            nova_img['d9'] = float(img_decil_df.d9)
            full_ds_pdf = full_ds_pdf.append(nova_img, ignore_index=True)

        #Splitting data into train and validation set
        X_train, X_test, y_d0_train, y_d0_test,\
                 y_d1_train, y_d1_test,\
                 y_d2_train, y_d2_test,\
                 y_d3_train, y_d3_test,\
                 y_d4_train, y_d4_test,\
                 y_d5_train, y_d5_test,\
                 y_d6_train, y_d6_test,\
                 y_d7_train, y_d7_test,\
                 y_d8_train, y_d8_test,\
                 y_d9_train, y_d9_test = train_test_split(full_ds_pdf.feature,
                                                          full_ds_pdf.d0,
                                                          full_ds_pdf.d1,
                                                          full_ds_pdf.d2,
                                                          full_ds_pdf.d3,
                                                          full_ds_pdf.d4,
                                                          full_ds_pdf.d5,
                                                          full_ds_pdf.d6,
                                                          full_ds_pdf.d7,
                                                          full_ds_pdf.d8,
                                                          full_ds_pdf.d9, test_size=0.2)
        if train == True:
            self.x = X_train
            self.d0_y = y_d0_train
            self.d1_y = y_d1_train
            self.d2_y = y_d2_train
            self.d3_y = y_d3_train
            self.d4_y = y_d4_train
            self.d5_y = y_d5_train
            self.d6_y = y_d6_train
            self.d7_y = y_d7_train
            self.d8_y = y_d8_train
            self.d9_y = y_d9_train
        else:
            self.x = X_test
            self.d0_y = y_d0_test
            self.d1_y = y_d1_test
            self.d2_y = y_d2_test
            self.d3_y = y_d3_test
            self.d4_y = y_d4_test
            self.d5_y = y_d5_test
            self.d6_y = y_d6_test
            self.d7_y = y_d7_test
            self.d8_y = y_d8_test
            self.d9_y = y_d9_test
        # Applying Transformation
        self.transform = transform
    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        image = np.array(self.x.iloc[idx]).astype('float')
        d0 = self.d0_y.iloc[idx].astype('float')
        d1 = self.d1_y.iloc[idx].astype('float')
        d2 = self.d2_y.iloc[idx].astype('float')
        d3 = self.d3_y.iloc[idx].astype('float')
        d4 = self.d4_y.iloc[idx].astype('float')
        d5 = self.d5_y.iloc[idx].astype('float')
        d6 = self.d6_y.iloc[idx].astype('float')
        d7 = self.d7_y.iloc[idx].astype('float')
        d8 = self.d8_y.iloc[idx].astype('float')
        d9 = self.d9_y.iloc[idx].astype('float')

        sample={'image':image,
                'd0': d0,
                'd1': d1,
                'd2': d2,
                'd3': d3,
                'd4': d4,
                'd5': d5,
                'd6': d6,
                'd7': d7,
                'd8': d8,
                'd9': d9}
        # Applying Transformation
        if self.transform:
            sample = self.transform = sample
        
        return sample

In [9]:
spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 0.5
num_steps = 50

In [10]:
batch_size = 4
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [11]:
# Define a transform
transform = transforms.Compose([
            transforms.Resize((442, 442)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

train_data = WaveGuideImageData(transform=transform)
test_data = WaveGuideImageData(train=False, transform=transform)

train_loader = DataLoader(train_data, batch_size=4, shuffle=True, drop_last=True)
test_loader = DataLoader(test_data, batch_size=4, shuffle=True, drop_last=True)

In [12]:
# Define Network
class CSNN(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.conv1 = nn.Conv2d(1, 448, 5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(448, 896, 5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc1 = nn.Linear(896*4*4, 10)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):

        # Initialize hidden states and outputs at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        cur1 = F.max_pool2d(self.conv1(x), 2)
        spk1, mem1 = self.lif1(cur1, mem1)

        cur2 = F.max_pool2d(self.conv2(spk1), 2)
        spk2, mem2 = self.lif2(cur2, mem2)

        cur3 = self.fc1(spk2.view(batch_size, -1))
        spk3, mem3 = self.lif3(cur3, mem3)

        return spk3, mem3

In [13]:
def forward_pass(net, num_steps, data):
  mem_rec = []
  spk_rec = []
  utils.reset(net)  # resets hidden states for all LIF neurons in net

  for step in range(num_steps):
      spk_out, mem_out = net(data)
      spk_rec.append(spk_out)
      mem_rec.append(mem_out)

  return torch.stack(spk_rec), torch.stack(mem_rec)

def batch_accuracy(train_loader, net, num_steps):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    train_loader = enumerate(train_loader)
    for idx, batched_sample in train_loader:
      # Data itself: an image
      data = batched_sample['image'].to(device)

      # The Labels
      d0 = batched_sample['d0'].to(device)
      d1 = batched_sample['d1'].to(device)
      d2 = batched_sample['d2'].to(device)
      d3 = batched_sample['d3'].to(device)
      d4 = batched_sample['d4'].to(device)
      d5 = batched_sample['d5'].to(device)
      d6 = batched_sample['d6'].to(device)
      d7 = batched_sample['d7'].to(device)
      d8 = batched_sample['d8'].to(device)
      d9 = batched_sample['d9'].to(device)

      spk_rec, _ = forward_pass(net, num_steps, data)
      
      acc += SF.accuracy_rate(spk, d0) * spk[0].size(1)
      acc += SF.accuracy_rate(spk, d1) * spk[1].size(1)
      acc += SF.accuracy_rate(spk, d2) * spk[2].size(1)
      acc += SF.accuracy_rate(spk, d3) * spk[3].size(1)
      acc += SF.accuracy_rate(spk, d4) * spk[4].size(1)
      acc += SF.accuracy_rate(spk, d5) * spk[5].size(1)
      acc += SF.accuracy_rate(spk, d6) * spk[6].size(1)
      acc += SF.accuracy_rate(spk, d7) * spk[7].size(1)
      acc += SF.accuracy_rate(spk, d8) * spk[8].size(1)
      acc += SF.accuracy_rate(spk, d9) * spk[9].size(1)

      total += spk[0].size(1)
      total += spk[1].size(1)
      total += spk[2].size(1)
      total += spk[3].size(1)
      total += spk[4].size(1)
      total += spk[5].size(1)
      total += spk[6].size(1)
      total += spk[7].size(1)
      total += spk[8].size(1)
      total += spk[9].size(1)

  return acc/total

In [15]:
net = CSNN()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))
num_epochs = 1
loss_hist = []
test_acc_hist = []
counter = 0
loss_function = SF.ce_rate_loss()
# Outer training loop
for epoch in range(num_epochs):

    # Training loop
    for idx, batched_sample in enumerate(train_loader):
        # Data itself: an image
        data = batched_sample['image'].to(device)

        # The Labels
        d0 = batched_sample['d0'].to(device)
        d1 = batched_sample['d1'].to(device)
        d2 = batched_sample['d2'].to(device)
        d3 = batched_sample['d3'].to(device)
        d4 = batched_sample['d4'].to(device)
        d5 = batched_sample['d5'].to(device)
        d6 = batched_sample['d6'].to(device)
        d7 = batched_sample['d7'].to(device)
        d8 = batched_sample['d8'].to(device)
        d9 = batched_sample['d9'].to(device)

        # forward pass
        net.train()
        spk_rec, _ = forward_pass(net, num_steps, data)

        # initialize the loss & sum over time
        loss_val0 = loss_function(d0_hat, spk_rec[0]) # apply loss
        loss_val1 = loss_function(d1_hat, spk_rec[1]) # apply loss
        loss_val2 = loss_function(d2_hat, spk_rec[2]) # apply loss
        loss_val3 = loss_function(d3_hat, spk_rec[3]) # apply loss
        loss_val4 = loss_function(d4_hat, spk_rec[4]) # apply loss
        loss_val5 = loss_function(d5_hat, spk_rec[5]) # apply loss
        loss_val6 = loss_function(d6_hat, spk_rec[6]) # apply loss
        loss_val7 = loss_function(d7_hat, spk_rec[7]) # apply loss
        loss_val8 = loss_function(d8_hat, spk_rec[8]) # apply loss
        loss_val9 = loss_function(d9_hat, spk_rec[9]) # apply loss

        loss_val = (loss_val0+loss_val1+loss_val2+loss_val3+loss_val4+loss_val5+loss_val6+loss_val7+loss_val8+loss_val9)/10

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        if counter % 50 == 0:
            with torch.no_grad():
                net.eval()

                # Test set forward pass
                test_acc = batch_accuracy(test_loader, net, num_steps)
                print(f"Iteration {counter}, Test Acc: {test_acc * 100:.2f}%\n")
                test_acc_hist.append(test_acc.item())

        counter += 1

RuntimeError: Given groups=1, weight of size [448, 1, 5, 5], expected input[4, 442, 442, 3] to have 1 channels, but got 442 channels instead