## Desenvolvimento do modelo de SNN para classificar cada imagem

### Carregamento da base de imagens, junto com os valores esperados (X, Y)

In [2]:
import os
import pandas as pd
import numpy as np

import snntorch as snn
import torch
import torch.nn as nn
import torch.nn.functional as F
from snntorch import functional as SF
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split
import cv2
import itertools
import tqdm

In [3]:
img_db_dir = 'final_image_database_loss_deciles.csv'

img_deciles = pd.read_csv(img_db_dir)
img_deciles

Unnamed: 0,img_file,img_nfi,d0,d1,d2,d3,d4,d5,d6,d7,d8,d9,soft_label
0,cone_luz_sem_grafeno_n2_(11.6964_+_0.1i).jpg,0.1,0.920530,0.033113,0.006623,0.019868,0.000000,0.000000,0.000000,0.006623,0.006623,0.006623,0
1,cone_luz_sem_grafeno_n2_(11.6964_+_0.2i).jpg,0.2,0.827815,0.006623,0.000000,0.006623,0.013245,0.006623,0.006623,0.013245,0.013245,0.105960,0
2,cone_luz_sem_grafeno_n2_(11.6964_+_0.3i).jpg,0.3,0.821192,0.006623,0.019868,0.006623,0.000000,0.019868,0.000000,0.006623,0.006623,0.112583,0
3,cone_luz_sem_grafeno_n2_(11.6964_+_0.4i).jpg,0.4,0.854305,0.006623,0.013245,0.019868,0.006623,0.000000,0.000000,0.000000,0.000000,0.099338,0
4,cone_luz_sem_grafeno_n2_(11.6964_+_0.5i).jpg,0.5,0.543046,0.298013,0.006623,0.026490,0.013245,0.026490,0.006623,0.000000,0.006623,0.072848,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
228,cone_luz_sem_grafeno_n2_(11.6964_+_9.6i).jpg,9.6,0.039735,0.039735,0.079470,0.119205,0.178808,0.092715,0.198675,0.165563,0.006623,0.079470,6
229,cone_luz_sem_grafeno_n2_(11.6964_+_9.7i).jpg,9.7,0.033113,0.052980,0.072848,0.092715,0.178808,0.072848,0.205298,0.165563,0.013245,0.112583,6
230,cone_luz_sem_grafeno_n2_(11.6964_+_9.8i).jpg,9.8,0.026490,0.039735,0.059603,0.119205,0.178808,0.086093,0.198675,0.165563,0.006623,0.119205,6
231,cone_luz_sem_grafeno_n2_(11.6964_+_9.9i).jpg,9.9,0.059603,0.039735,0.059603,0.092715,0.192053,0.079470,0.185430,0.158940,0.006623,0.125828,4


### Construção do DataSet de Entrada para pré-processamento.
### Ele contém a imagem como uma matriz numérica, junto do perfil de decil (d0-d9) como valor esperado.

### Para adequar ao formato do pytorch, é necessário criar classes para criar o dataset e preprocessamento:

In [4]:
class WaveGuideImageData(Dataset):
    def __init__(self, train=True, transform=None):
        #Get images directory
        img_dir = '../imgs'
        images = os.listdir(img_dir)

        # Create empty Pandas DataFrame
        full_ds_pdf = pd.DataFrame(columns=['feature', 'd0', 'd1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8', 'd9'])

        #Load all files into one Pandas Dataframe
        for img_file in images:
            nova_img = {}
            img = cv2.imread('../imgs/'+img_file)
            #img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            nova_img['feature'] = img
            img_decil_df = img_deciles[img_deciles.img_file == img_file].copy()
            nova_img['d0'] = float(img_decil_df.d0)
            nova_img['d1'] = float(img_decil_df.d1)
            nova_img['d2'] = float(img_decil_df.d2)
            nova_img['d3'] = float(img_decil_df.d3)
            nova_img['d4'] = float(img_decil_df.d4)
            nova_img['d5'] = float(img_decil_df.d5)
            nova_img['d6'] = float(img_decil_df.d6)
            nova_img['d7'] = float(img_decil_df.d7)
            nova_img['d8'] = float(img_decil_df.d8)
            nova_img['d9'] = float(img_decil_df.d9)
            full_ds_pdf = full_ds_pdf.append(nova_img, ignore_index=True)

        #Splitting data into train and validation set
        X_train, X_test, y_d0_train, y_d0_test,\
                 y_d1_train, y_d1_test,\
                 y_d2_train, y_d2_test,\
                 y_d3_train, y_d3_test,\
                 y_d4_train, y_d4_test,\
                 y_d5_train, y_d5_test,\
                 y_d6_train, y_d6_test,\
                 y_d7_train, y_d7_test,\
                 y_d8_train, y_d8_test,\
                 y_d9_train, y_d9_test = train_test_split(full_ds_pdf.feature,
                                                          full_ds_pdf.d0,
                                                          full_ds_pdf.d1,
                                                          full_ds_pdf.d2,
                                                          full_ds_pdf.d3,
                                                          full_ds_pdf.d4,
                                                          full_ds_pdf.d5,
                                                          full_ds_pdf.d6,
                                                          full_ds_pdf.d7,
                                                          full_ds_pdf.d8,
                                                          full_ds_pdf.d9, test_size=0.2)
        if train == True:
            self.x = X_train
            self.d0_y = y_d0_train
            self.d1_y = y_d1_train
            self.d2_y = y_d2_train
            self.d3_y = y_d3_train
            self.d4_y = y_d4_train
            self.d5_y = y_d5_train
            self.d6_y = y_d6_train
            self.d7_y = y_d7_train
            self.d8_y = y_d8_train
            self.d9_y = y_d9_train
        else:
            self.x = X_test
            self.d0_y = y_d0_test
            self.d1_y = y_d1_test
            self.d2_y = y_d2_test
            self.d3_y = y_d3_test
            self.d4_y = y_d4_test
            self.d5_y = y_d5_test
            self.d6_y = y_d6_test
            self.d7_y = y_d7_test
            self.d8_y = y_d8_test
            self.d9_y = y_d9_test
        # Applying Transformation
        self.transform = transform
    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        image = np.array(self.x.iloc[idx]).astype('float')
        d0 = self.d0_y.iloc[idx].astype('float')
        d1 = self.d1_y.iloc[idx].astype('float')
        d2 = self.d2_y.iloc[idx].astype('float')
        d3 = self.d3_y.iloc[idx].astype('float')
        d4 = self.d4_y.iloc[idx].astype('float')
        d5 = self.d5_y.iloc[idx].astype('float')
        d6 = self.d6_y.iloc[idx].astype('float')
        d7 = self.d7_y.iloc[idx].astype('float')
        d8 = self.d8_y.iloc[idx].astype('float')
        d9 = self.d9_y.iloc[idx].astype('float')

        sample={'image':image,
                'd0': d0,
                'd1': d1,
                'd2': d2,
                'd3': d3,
                'd4': d4,
                'd5': d5,
                'd6': d6,
                'd7': d7,
                'd8': d8,
                'd9': d9}
        # Applying Transformation
        if self.transform:
            sample = self.transform = sample
        
        return sample

In [5]:

## Tamanho da imagem 442 x 442
transform = transforms.Compose([transforms.ToTensor(), transforms.Grayscale(), transforms.Normalize((0,),(1,))])

train_data = WaveGuideImageData(transform=transform)
test_data = WaveGuideImageData(train=False, transform=transform)

train_loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=0, drop_last=True)
test_loader = DataLoader(test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=True)

### Criação da Arquitetura do modelo inicial:

In [6]:
beta = 0.9 # membrane potential decay rate
num_steps = 10 # 10 time steps

In [7]:
class Net(torch.nn.Module):
    """Simple spiking neural network in snntorch."""

    def __init__(self, timesteps, hidden, beta):
        super().__init__()

        self.timesteps = timesteps
        self.hidden = hidden
        self.beta = beta

        # layer 1
        self.fc1 = torch.nn.Linear(in_features=586092, out_features=self.hidden)
        self.rlif1 = snn.RLeaky(beta=self.beta, V=0.5)

        # layer 2
        self.fc2 = torch.nn.Linear(in_features=self.hidden, out_features=10)
        self.rlif2 = snn.RLeaky(beta=self.beta, V=0.5)

    def forward(self, x):
        """Forward pass for several time steps."""

        # Initalize membrane potential
        spk1, mem1 = self.rlif1.init_rleaky()
        spk2, mem2 = self.rlif2.init_rleaky()

        # Empty lists to record outputs
        spk_recording = []
        for step in range(self.timesteps):
            spk1, mem1 = self.rlif1(self.fc1(x.float()), spk1, mem1)
            spk2, mem2 = self.rlif2(self.fc2(spk1), spk2, mem2)
            spk_recording.append(spk2)

        return torch.stack(spk_recording)

In [12]:
hidden = 1024
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = Net(timesteps=num_steps, hidden=hidden, beta=0.9).to(device)

### Definição da função erro

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_function = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

### Treinamento do Modelo

In [14]:
num_epochs = 5
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_hist = []

with tqdm.trange(num_epochs) as pbar:
    for _ in pbar:
        train_batch = enumerate(train_loader)
        minibatch_counter = 0
        loss_epoch = []

        for batch_idx, batched_sample in train_batch:
            
            image = batched_sample['image'].to(device)
            d0 = batched_sample['d0'].to(device)
            d1 = batched_sample['d1'].to(device)
            d2 = batched_sample['d2'].to(device)
            d3 = batched_sample['d3'].to(device)
            d4 = batched_sample['d4'].to(device)
            d5 = batched_sample['d5'].to(device)
            d6 = batched_sample['d6'].to(device)
            d7 = batched_sample['d7'].to(device)
            d8 = batched_sample['d8'].to(device)
            d9 = batched_sample['d9'].to(device)
            
            spk = model(image.flatten(1)) # forward-pass
            d0_hat = spk[0]
            d1_hat = spk[1]
            d2_hat = spk[2]
            d3_hat = spk[3]
            d4_hat = spk[4]
            d5_hat = spk[5]
            d6_hat = spk[6]
            d7_hat = spk[7]
            d8_hat = spk[8]
            d9_hat = spk[9]
            
            loss_val0 = loss_function(d0_hat, d0) # apply loss
            loss_val1 = loss_function(d1_hat, d1) # apply loss
            loss_val2 = loss_function(d2_hat, d2) # apply loss
            loss_val3 = loss_function(d3_hat, d3) # apply loss
            loss_val4 = loss_function(d4_hat, d4) # apply loss
            loss_val5 = loss_function(d5_hat, d5) # apply loss
            loss_val6 = loss_function(d6_hat, d6) # apply loss
            loss_val7 = loss_function(d7_hat, d7) # apply loss
            loss_val8 = loss_function(d8_hat, d8) # apply loss
            loss_val9 = loss_function(d9_hat, d9) # apply loss
            
            loss = (loss_val0+loss_val1+loss_val2+loss_val3+loss_val4+loss_val5+loss_val6+loss_val7+loss_val8+loss_val9)/10

            optimizer.zero_grad() # zero out gradients
            loss.backward() # calculate gradients
            optimizer.step() # update weights

            loss_hist.append(loss.item())
            minibatch_counter += 1

            avg_batch_loss = sum(loss_hist) / minibatch_counter
            pbar.set_postfix(loss="%.3e" % avg_batch_loss)

100%|██████████| 5/5 [1:09:19<00:00, 831.99s/it, loss=1.134e+00] 


### Validação do Modelo

In [15]:
test_batch = enumerate(test_loader)
minibatch_counter = 0
loss_epoch = []

model.eval()
with torch.no_grad():
  total = 0
  acc = 0
  for batch_idx, batched_sample in test_batch:
        
      image = batched_sample['image'].to(device)
      d0 = batched_sample['d0'].to(device)
      d1 = batched_sample['d1'].to(device)
      d2 = batched_sample['d2'].to(device)
      d3 = batched_sample['d3'].to(device)
      d4 = batched_sample['d4'].to(device)
      d5 = batched_sample['d5'].to(device)
      d6 = batched_sample['d6'].to(device)
      d7 = batched_sample['d7'].to(device)
      d8 = batched_sample['d8'].to(device)
      d9 = batched_sample['d9'].to(device)

      spk = model(image.flatten(1)) # forward-pass
      
      
      acc += SF.accuracy_rate(spk, d0) * spk[0].size(1)
      acc += SF.accuracy_rate(spk, d1) * spk[1].size(1)
      acc += SF.accuracy_rate(spk, d2) * spk[2].size(1)
      acc += SF.accuracy_rate(spk, d3) * spk[3].size(1)
      acc += SF.accuracy_rate(spk, d4) * spk[4].size(1)
      acc += SF.accuracy_rate(spk, d5) * spk[5].size(1)
      acc += SF.accuracy_rate(spk, d6) * spk[6].size(1)
      acc += SF.accuracy_rate(spk, d7) * spk[7].size(1)
      acc += SF.accuracy_rate(spk, d8) * spk[8].size(1)
      acc += SF.accuracy_rate(spk, d9) * spk[9].size(1)

      total += spk[0].size(1)
      total += spk[1].size(1)
      total += spk[2].size(1)
      total += spk[3].size(1)
      total += spk[4].size(1)
      total += spk[5].size(1)
      total += spk[6].size(1)
      total += spk[7].size(1)
      total += spk[8].size(1)
      total += spk[9].size(1)

print(f"The total accuracy on the test set is: {(acc/total) * 100:.2f}%")

The total accuracy on the test set is: 7.73%
