In [1]:
! pip install snntorch

Collecting snntorch
  Obtaining dependency information for snntorch from https://files.pythonhosted.org/packages/4b/55/95ee9e0e26cf74a464603ef7ab84be186133bfb95ac0c5ae9d1eb408b69b/snntorch-0.7.0-py2.py3-none-any.whl.metadata
  Downloading snntorch-0.7.0-py2.py3-none-any.whl.metadata (16 kB)
Collecting nir (from snntorch)
  Obtaining dependency information for nir from https://files.pythonhosted.org/packages/a8/e1/60b9014266c26d170b2f1bc7fe1b7b6ad823ad8cb302104ca306685311ac/nir-1.0.1-py3-none-any.whl.metadata
  Downloading nir-1.0.1-py3-none-any.whl.metadata (5.8 kB)
Collecting nirtorch (from snntorch)
  Obtaining dependency information for nirtorch from https://files.pythonhosted.org/packages/cd/74/92cc684fd83636b072318693676877af0d80c4e136067237f147f9a18d6f/nirtorch-1.0-py3-none-any.whl.metadata
  Downloading nirtorch-1.0-py3-none-any.whl.metadata (3.6 kB)
Downloading snntorch-0.7.0-py2.py3-none-any.whl (108 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.0/109.0

# SNN Autoencoder (SAE)

In [2]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import utils as utls

import snntorch as snn
from snntorch import utils
from snntorch import surrogate

import numpy as np

In [3]:
class SAE(nn.Module):
    """Convolutional spiking autoencoder built from snnTorch ``Leaky`` neurons.

    Hyper-parameters are read from module-level names, which must exist before
    the class is instantiated / called:

    * ``beta``, ``spike_grad``, ``thresh``, ``latent_dim`` -- read in ``__init__``
    * ``num_steps`` -- read on every ``forward`` call

    Sub-networks (attribute names and layer order are part of the state_dict):

    * ``encoder``   -- three stride-2 Conv2d -> BatchNorm2d -> Leaky stages
      (1 -> 32 -> 64 -> 128 channels), a flatten, then Linear(2048, latent_dim)
      and an output Leaky. 2048 = 128 * 4 * 4, i.e. the flattened feature map of
      a 1x32x32 input after three stride-2 convolutions.
    * ``linearNet`` -- Linear(latent_dim, 128*4*4) followed by an output Leaky.
    * ``decoder``   -- unflatten to (128, 4, 4), then ConvTranspose2d stages that
      double the spatial size each time (128 -> 64 -> 32 -> 1 channels), ending
      in an output Leaky whose membrane potential is the reconstruction.
    """

    def __init__(self):
        super().__init__()

        def lif(output=False, threshold=None):
            # All neurons share beta / surrogate gradient and keep their own
            # hidden state; only the threshold and the `output` flag vary.
            return snn.Leaky(
                beta=beta,
                spike_grad=spike_grad,
                init_hidden=True,
                output=output,
                threshold=thresh if threshold is None else threshold,
            )

        # ---- Encoder -------------------------------------------------------
        enc_layers = []
        for c_in, c_out in ((1, 32), (32, 64), (64, 128)):
            enc_layers.append(nn.Conv2d(c_in, c_out, 3, padding=1, stride=2))
            enc_layers.append(nn.BatchNorm2d(c_out))
            enc_layers.append(lif())
        enc_layers.append(nn.Flatten(start_dim=1, end_dim=3))
        # in_features must equal channels * height * width of the last conv output
        enc_layers.append(nn.Linear(2048, latent_dim))
        enc_layers.append(lif(output=True))
        self.encoder = nn.Sequential(*enc_layers)

        # ---- Latent vector back to a flat feature map ----------------------
        self.linearNet = nn.Sequential(
            nn.Linear(latent_dim, 128 * 4 * 4),
            lif(output=True),
        )

        # ---- Decoder -------------------------------------------------------
        dec_layers = [nn.Unflatten(1, (128, 4, 4)), lif()]
        for c_in, c_out in ((128, 64), (64, 32)):
            dec_layers.append(
                nn.ConvTranspose2d(c_in, c_out, 3, padding=1, stride=(2, 2), output_padding=1)
            )
            dec_layers.append(nn.BatchNorm2d(c_out))
            dec_layers.append(lif())
        dec_layers.append(
            nn.ConvTranspose2d(32, 1, 3, padding=1, stride=(2, 2), output_padding=1)
        )
        # Very large threshold on the read-out neuron (original note: "make
        # large so membrane can be trained").
        dec_layers.append(lif(output=True, threshold=20000))
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Reconstruct ``x`` by running the network for ``num_steps`` time steps.

        Args:
            x: input batch, [Batch, Channels, Width, Length].

        Returns:
            Membrane potential of the decoder's output neurons at the last time
            step, [Batch, Channels, Width, Length].
        """
        # Hidden LIF states persist between calls and must be cleared first.
        for block in (self.encoder, self.decoder, self.linearNet):
            utils.reset(block)

        # Encode: present the same input at every step and collect the latent
        # spike trains -> [Batch, latent_dim, Time].
        latent_spikes = torch.stack(
            [self.encode(x)[0] for _ in range(num_steps)], dim=2
        )

        # Decode: feed the latent spikes back one time step at a time.
        membranes = []
        for latent_t in latent_spikes.unbind(dim=-1):
            _, membrane_t = self.decode(latent_t)
            membranes.append(membrane_t)

        # [Batch, Channels, Width, Length, Time] -> keep the final time step.
        return torch.stack(membranes, dim=4)[..., -1]

    def encode(self, x):
        """Run one time step of the encoder; returns (latent spikes, latent membrane)."""
        latent_spk, latent_mem = self.encoder(x)
        return latent_spk, latent_mem

    def decode(self, x):
        """Run one time step of the decoder on latent spikes ``x``.

        The latent vector is first expanded back to the flattened size of the
        encoder's last feature map, then passed through the transposed
        convolutions. Returns (output spikes, output membrane).
        """
        expanded_spk, _expanded_mem = self.linearNet(x)
        out_spk, out_mem = self.decoder(expanded_spk)
        return out_spk, out_mem

In [4]:
import os


def train(network, trainloader, opti, epoch, device,
          training_dir='/kaggle/working/training/'):
    """Train ``network`` for one epoch as an autoencoder and report the mean loss.

    Each batch is reconstructed by the network and scored with MSE against the
    input image; labels are ignored.

    Args:
        network: model to optimise; called as ``network(images)``.
        trainloader: iterable yielding ``(images, labels)`` batches.
        opti: optimizer bound to ``network``'s parameters.
        epoch: epoch index, used only in the progress message.
        device: device the image batches are moved to.
        training_dir: directory created (if missing) for training artefacts.
            Defaults to the Kaggle working directory used by the original
            notebook.

    Returns:
        float: mean of the per-batch MSE losses over the epoch.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    network = network.train()

    # The directory is still created as in the original notebook, although the
    # code that saved reconstruction snapshots of the final batch
    # (utls.save_image into training_dir) is currently disabled.
    os.makedirs(training_dir, exist_ok=True)

    total_loss = 0.0
    num_batches = 0

    for real_img, _labels in trainloader:
        opti.zero_grad()
        real_img = real_img.to(device)

        # Reconstruction = output membrane potential at the last time step.
        x_recon = network(real_img)
        loss_val = F.mse_loss(x_recon, real_img)

        loss_val.backward()
        opti.step()

        total_loss += loss_val.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("trainloader yielded no batches; cannot compute an average loss")

    # Average over the batches actually seen from *this* loader (the original
    # divided by the length of a notebook-global loader instead).
    average_loss = total_loss / num_batches
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, average_loss))
    return average_loss


In [5]:
# Data configuration and MNIST loaders
batch_size = 250
input_size = 32  # side length of the square image fed to the first convolution

# Compute device: use the GPU when one is available, otherwise the CPU.
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_mnist(train):
    """Return the MNIST split (train or test), downloaded into ``dataset/``.

    Images are converted to tensors, resized to ``input_size`` x ``input_size``
    and passed through Normalize with mean 0 and std 1.
    """
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((input_size, input_size)),
        transforms.Normalize((0,), (1,)),
    ])
    return datasets.MNIST(root='dataset/', train=train, transform=pipeline, download=True)


train_dataset = _load_mnist(train=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = _load_mnist(train=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 102293377.66it/s]


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 85306826.64it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 24517134.60it/s]


Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 7842951.32it/s]


Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



In [7]:
# SNN hyper-parameters. SAE reads these as module-level names, so they must be
# defined before the network is built.
beta = 0.5                                # decay rate of the neurons
thresh = 1                                # spiking threshold (lower = more spikes are let through)
num_steps = 5                             # number of simulated time steps
latent_dim = 32                           # size of the latent layer (how compressed the information is)
spike_grad = surrogate.atan(alpha=2.0)    # alternate surrogate gradient: fast_sigmoid(slope=25)

# Training length
epochs = 30
max_epoch = epochs

# Build the network on the selected device and attach the optimizer.
net = SAE().to(device)

optimizer_settings = dict(lr=0.0001, betas=(0.9, 0.999), weight_decay=0.001)
optimizer = torch.optim.AdamW(net.parameters(), **optimizer_settings)

# One call to train() per epoch; it prints the epoch's average training loss.
for e in range(epochs):
    train(net, train_loader, optimizer, e, device)





Epoch: 0 	Training Loss: 0.300347
Epoch: 1 	Training Loss: 0.134863
Epoch: 2 	Training Loss: 0.112710
Epoch: 3 	Training Loss: 0.101490
Epoch: 4 	Training Loss: 0.091986
Epoch: 5 	Training Loss: 0.082478
Epoch: 6 	Training Loss: 0.074407
Epoch: 7 	Training Loss: 0.067139
Epoch: 8 	Training Loss: 0.060876
Epoch: 9 	Training Loss: 0.055533
Epoch: 10 	Training Loss: 0.050633
Epoch: 11 	Training Loss: 0.046136
Epoch: 12 	Training Loss: 0.042324
Epoch: 13 	Training Loss: 0.038793
Epoch: 14 	Training Loss: 0.035590
Epoch: 15 	Training Loss: 0.032958
Epoch: 16 	Training Loss: 0.030676
Epoch: 17 	Training Loss: 0.028745
Epoch: 18 	Training Loss: 0.027008
Epoch: 19 	Training Loss: 0.025682
Epoch: 20 	Training Loss: 0.024585
Epoch: 21 	Training Loss: 0.023593
Epoch: 22 	Training Loss: 0.022728
Epoch: 23 	Training Loss: 0.022042
Epoch: 24 	Training Loss: 0.021403
Epoch: 25 	Training Loss: 0.020846
Epoch: 26 	Training Loss: 0.020379
Epoch: 27 	Training Loss: 0.019951
Epoch: 28 	Training Loss: 0.01

In [10]:
# Evaluate reconstruction quality on the test set.
# eval() switches layers such as BatchNorm to their inference behaviour.
network = net.eval()

test_loss_hist = []
batch_losses = []

# No gradients are needed for evaluation.
with torch.no_grad():
    for batch_idx, (real_img, labels) in enumerate(test_loader):
        real_img = real_img.to(device)

        # Reconstruct the batch and score it against the input with MSE.
        x_recon = net(real_img)  # input dims: [Batch_size, Input_size, Image_Width, Image_Length]
        loss_val = F.mse_loss(x_recon, real_img)
        batch_losses.append(loss_val.item())

        # Optional: display/save images or other test-time operations

# Mean of the per-batch losses over the whole test loader.
total_loss = sum(batch_losses)
average_loss = total_loss / len(test_loader)
test_loss_hist.append(average_loss)

print('Test Loss: {:.6f}'.format(average_loss))




Test Loss: 0.092623
