# Import Libraries

In [7]:
import numpy as np
import numpy.matlib as npm
# import scipy as sc
import random
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import utils
from snntorch import spikegen
import snntorch.functional as SF

# Define Class

In [8]:
class DeepESN():
    """Stack of fixed (untrained) recurrent spiking reservoir layers.

    Each layer holds a recurrent weight matrix ``W[layer]`` of shape
    ``(Nr, Nr)`` and an input weight matrix ``Win[layer]`` of shape
    ``(Nr, Nu)`` for the first layer and ``(Nr, Nr)`` for deeper layers.
    All weights are drawn uniformly from ``[-1, 1)``; no spectral-radius
    rescaling is applied.

    Args:
        Nu: Size of the input vector fed to the first layer.
        Nr: Number of units per layer.
        Nl: Number of layers.
        configs: Object exposing ``reservoirConf.connectivity`` and
            ``snn.leaky``, ``snn.Gain``, ``snn.Bias``, ``snn.spike_train``
            and ``snn.threshold``.
        device: Torch device on which weights and states are allocated.
    """

    def __init__(self, Nu, Nr, Nl, configs, device='cpu'):
        reservoirConf = configs.reservoirConf  # reservoir configurations

        self.W = {}
        self.Win = {}

        self.Nu = Nu  # number of inputs
        self.Nr = Nr  # number of units per layer
        self.Nl = Nl  # number of layers

        self.device = device

        self.leaky = configs.snn.leaky
        self.Gain = configs.snn.Gain
        self.Bias = configs.snn.Bias  # stored, but not used by the active update rule
        self.spike_train = configs.snn.spike_train
        self.threshold = configs.snn.threshold
        # number of non-zero recurrent connections per unit (per row of W)
        self.number_row_elements = round(reservoirConf.connectivity * Nr)

        if reservoirConf.connectivity < 1:
            # Sparse reservoir: each row gets number_row_elements random
            # columns with weights in [-1, 1); every other entry stays zero.
            for layer in range(Nl):
                self.W[layer] = torch.zeros((Nr, Nr), device=device)
                for row in range(Nr):
                    row_elements = torch.randperm(Nr, device=device)[:self.number_row_elements]
                    self.W[layer][row, row_elements] = (
                        torch.rand(self.number_row_elements, device=device) * 2 - 1
                    )
        else:
            # Fully connected reservoir: a dense (Nr, Nr) matrix, so that
            # W[layer] @ spikes is well defined in computeLayerState.
            for layer in range(Nl):
                self.W[layer] = torch.rand((Nr, Nr), device=device) * 2 - 1

        # Input weights: the first layer reads the external input (Nu values),
        # deeper layers read the previous layer's Nr-dimensional state.
        for layer in range(Nl):
            fan_in = Nu if layer == 0 else Nr
            self.Win[layer] = torch.rand((Nr, fan_in), device=device) * 2 - 1

    def computeLayerState(self, input, layer, initialStatesSpike, initialStatesLayer):
        """Advance one layer by a single update.

        Args:
            input: Vector driving the layer; multiplied by ``Win[layer]``.
            layer: Index of the layer to update.
            initialStatesSpike: The layer's previous spike vector.
            initialStatesLayer: The layer's previous state vector.

        Returns:
            ``(state, spk)``: the new state, with entries above the threshold
            reset to 0, and the 0/1 float spike vector marking those entries.
        """
        drive = self.Win[layer] @ input
        recurrent = self.W[layer] @ initialStatesSpike

        # Leaky integration of the previous state plus input and recurrent drive
        state = self.Gain * (self.leaky * initialStatesLayer + (drive + recurrent))

        # Units strictly above threshold emit a spike and are reset to 0
        fired = state > self.threshold
        spk = fired.float()
        state[fired] = 0
        return state, spk

    def computeState(self, inputs, initialStates=None):
        """Run the reservoir over every sample in ``inputs``.

        Args:
            inputs: Iterable of samples, each accepted by
                ``computeGlobalState``.
            initialStates: Unused; kept for interface compatibility.

        Returns:
            ``(spikes, states)``: per-sample outputs of
            ``computeGlobalState`` stacked along a new leading dimension.
        """
        spikes = []
        states = []

        for i_seq, sample in enumerate(inputs):
            spike, state = self.computeGlobalState(sample)
            spikes.append(spike)
            states.append(state)
            if i_seq % 100 == 0:
                print("Number of Calculated: ", i_seq)

        return torch.stack(spikes).to(self.device), torch.stack(states).to(self.device)

    def computeGlobalState(self, input):
        """Run all layers over one sample and record the result per time step.

        ``input`` is indexed as ``input[step, 4*i:4*(i+1), :]`` for
        ``i in range(7)``: each time step is presented as 7 consecutive
        chunks of 4 rows, and each chunk is flattened before being fed to
        the first layer. NOTE(review): this presumably assumes 28 rows and
        ``4 * input.shape[2] == Nu`` -- confirm against the caller.

        Args:
            input: Tensor with at least ``spike_train`` entries on axis 0.

        Returns:
            ``(out_spk, out_state)``, each of shape
            ``(spike_train, Nl * Nr)``, holding the concatenated layer
            spikes / states after the last chunk of every time step.
        """
        rows_per_chunk = 4
        chunks_per_step = 7
        total_units = self.Nl * self.Nr

        state = torch.zeros(total_units, device=self.device)
        spike = torch.zeros(total_units, device=self.device)
        # One output row per time step (previously hard-coded to 10 rows)
        out_spk = torch.zeros((self.spike_train, total_units), device=self.device)
        out_state = torch.zeros((self.spike_train, total_units), device=self.device)

        for step in range(self.spike_train):
            for i in range(chunks_per_step):
                rows = slice(rows_per_chunk * i, rows_per_chunk * (i + 1))
                parsedinput = input[step, rows, :].flatten().to(self.device)
                for layer in range(self.Nl):
                    units = slice(layer * self.Nr, (layer + 1) * self.Nr)
                    state[units], spike[units] = self.computeLayerState(
                        parsedinput, layer, spike[units], state[units]
                    )
                    # The next layer is driven by this layer's (post-reset) state
                    parsedinput = state[units]
            # Snapshot taken once all chunks of this time step were consumed
            out_spk[step] = spike
            out_state[step] = state

        return out_spk, out_state

    
    
class Readout(nn.Module):
    """Trainable readout: a linear layer followed by a leaky spiking layer.

    The linear layer maps the ``Nr * Nl`` reservoir units to 10 outputs and
    is applied independently at every time step; its output is fed as input
    current to an ``snn.Leaky`` neuron layer.
    """

    def __init__(self, Nr, Nl, configs):
        super().__init__()
        self.fc1 = nn.Linear(Nr*Nl, 10)
        self.lif1 = snn.Leaky(beta=0.8, threshold=1, output=True)
        # number of time steps consumed by forward()
        self.spike_train = configs.snn.spike_train

    def forward(self, x):
        """Run the readout over ``spike_train`` time steps.

        Args:
            x: Tensor indexed as ``x[:, step, :]``, i.e. time on axis 1.

        Returns:
            ``(spikes, membranes)``: per-step outputs of the spiking layer,
            each stacked along a new leading (time) dimension.
        """
        # fresh membrane potential for every forward pass
        membrane = self.lif1.init_leaky()

        spike_history = []
        membrane_history = []

        for t in range(self.spike_train):
            current = self.fc1(x[:, t, :])
            spikes, membrane = self.lif1(current, membrane)
            spike_history.append(spikes)
            membrane_history.append(membrane)

        return (
            torch.stack(spike_history, dim=0),
            torch.stack(membrane_history, dim=0),
        )

# Configurations Setup

In [9]:
class Struct(object):
    """Empty attribute container; callers attach fields ad hoc."""

def config_CIFAR10(IP_indexes):
    """Build the reservoir / spiking-neuron configuration object.

    Args:
        IP_indexes: Accepted for interface compatibility; not used.

    Returns:
        A ``Struct`` with two nested ``Struct`` members: ``reservoirConf``
        (field ``connectivity``) and ``snn`` (fields ``leaky``, ``Gain``,
        ``Bias``, ``threshold`` and ``spike_train``).
    """
    reservoir_settings = Struct()
    reservoir_settings.connectivity = .3

    neuron_settings = Struct()
    for field, value in (
        ("leaky", 0.8),
        ("Gain", 1),
        ("Bias", 0),
        ("threshold", 1),
        ("spike_train", 10),
    ):
        setattr(neuron_settings, field, value)

    configs = Struct()
    configs.reservoirConf = reservoir_settings
    configs.snn = neuron_settings
    return configs

# Load Dataset

In [10]:
import torch
import torchvision
from torchvision import datasets, transforms
import functools

# Number of time steps per rate-coded spike train; load_MNIST passes it
# to spikegen.rate as num_steps.
steps = 10

class Struct:
    """Empty attribute container; callers attach fields ad hoc."""

def load_MNIST():
    """Download MNIST, rate-encode every image and return one combined dataset.

    The 60k training images are randomly split 80/20 into training and
    validation subsets; both are then concatenated, followed by the test
    set, so the returned index ranges address contiguous slices of the
    combined tensors.

    Returns:
        ``(dataset, Nu, TR_indexes, VL_indexes, TS_indexes)`` where
        ``dataset`` is a ``Struct`` with ``name``, ``inputs`` and
        ``targets``; ``Nu`` is the input dimension (``28*4``); and the three
        index objects are ``range`` instances for the training, validation
        and test portions.
    """
    to_tensor = transforms.ToTensor()
    train_dataset = datasets.MNIST('./datasets', train=True, download=True, transform=to_tensor)
    test_dataset = datasets.MNIST('./datasets', train=False, download=True, transform=to_tensor)

    print("First train sample shape:", train_dataset[0][0].shape)
    print("First train label:", train_dataset[0][1])

    # 80/20 random split of the official training set
    train_size = int(0.8 * len(train_dataset))
    validation_size = len(train_dataset) - train_size
    train_dataset, validation_dataset = torch.utils.data.random_split(
        train_dataset, [train_size, validation_size]
    )

    print("Training dataset size:", len(train_dataset))
    print("Validation dataset size:", len(validation_dataset))

    def encode(subset):
        # Rate-encode each (transposed) image into `steps` spike frames and
        # collect the labels as a column vector.
        count = len(subset)
        spike_trains = torch.stack(
            [spikegen.rate(subset[i][0].squeeze().T, num_steps=steps) for i in range(count)]
        )
        labels = torch.stack(
            [torch.tensor(subset[i][1]) for i in range(count)]
        ).view(-1, 1)
        return spike_trains, labels

    dataset = Struct()
    dataset.name = 'MNIST'
    dataset.inputs, dataset.targets = encode(train_dataset)

    print("First input sample shape:", dataset.inputs[0].shape)
    print("First target sample shape:", dataset.targets[0].shape)

    # Append the validation portion
    validation_inputs, validation_targets = encode(validation_dataset)
    dataset.inputs = torch.cat((dataset.inputs, validation_inputs))
    dataset.targets = torch.cat((dataset.targets, validation_targets))

    print("Total dataset size (after adding validation):", len(dataset.inputs))

    # Append the test portion
    test_inputs, test_targets = encode(test_dataset)
    dataset.inputs = torch.cat((dataset.inputs, test_inputs))
    dataset.targets = torch.cat((dataset.targets, test_targets))

    print("Total dataset size (after adding test):", len(dataset.inputs))

    print("input shape:", dataset.inputs.shape)
    print("target shape:", dataset.targets.shape)

    # Input dimension
    Nu = 28*4

    # Contiguous index ranges into the combined tensors
    validation_end = train_size + validation_size
    TR_indexes = range(train_size)
    VL_indexes = range(train_size, validation_end)
    TS_indexes = range(validation_end, validation_end + len(test_dataset))

    return dataset, Nu, TR_indexes, VL_indexes, TS_indexes

# Smoke-test load_MNIST: downloads and encodes the full dataset, printing its
# shapes, and binds the results to module-level names.
dataset, Nu, TR_indexes, VL_indexes, TS_indexes = load_MNIST()


First train sample shape: torch.Size([1, 28, 28])
First train label: 5
Training dataset size: 48000
Validation dataset size: 12000
First input sample shape: torch.Size([10, 28, 28])
First target sample shape: torch.Size([1])
Total dataset size (after adding validation): 60000
Total dataset size (after adding test): 70000
input shape: torch.Size([70000, 10, 28, 28])
target shape: torch.Size([70000, 1])


# Extra Functions

In [11]:
def select_indexes(data, indexes):
    """Return the elements of ``data`` at the positions in ``indexes``.

    When ``data`` holds exactly one element, that element is returned as a
    single-item list and ``indexes`` is ignored.
    """
    if len(data) != 1:
        return [data[position] for position in indexes]

    return [data[0]]


# Main Function

In [12]:
# Select the compute device: CUDA when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the rate-coded MNIST dataset and its train/validation/test index ranges
dataset, Nu, TR_indexes, VL_indexes, TS_indexes = load_MNIST()
print(dataset.inputs.shape)

# Training and validation samples are used together to fit the readout
fit_indexes = list(TR_indexes) + list(VL_indexes)

# Build the reservoir / spiking-neuron configuration
# (config_CIFAR10 does not use its argument)
configs = config_CIFAR10(fit_indexes)

# Define parameters for DeepESN
Nr = 400  # Number of recurrent units
Nl = 2    # Number of recurrent layers
reg = 0

# Create the fixed reservoir and the trainable readout. The readout is moved
# to `device` because the reservoir spikes it consumes stay on `device`.
deepESN = DeepESN(Nu, Nr, Nl, configs, device)
net = Readout(Nr, Nl, configs).to(device)

# Compute reservoir spikes and states for the entire dataset
spikes, states = deepESN.computeState(dataset.inputs.to(device))
states = states.to('cpu')

# Dump the first sample in full (printing is configured not to truncate)
torch.set_printoptions(threshold=float('inf'))
print(spikes[0])
print(states[0])

# Select the reservoir spikes and targets for fitting and for testing
train_states = select_indexes(spikes, fit_indexes)
test_states = select_indexes(spikes, TS_indexes)
train_targets = select_indexes(dataset.targets, fit_indexes)
test_targets = select_indexes(dataset.targets, TS_indexes)

# Collapse the per-sample target lists into tensors on the same device as the
# readout outputs, so loss and accuracy can be computed without a mismatch.
train_targets = torch.tensor(train_targets).to(device)
test_targets = torch.tensor(test_targets).to(device)


Using device: cpu
First train sample shape: torch.Size([1, 28, 28])
First train label: 5
Training dataset size: 48000
Validation dataset size: 12000
First input sample shape: torch.Size([10, 28, 28])
First target sample shape: torch.Size([1])
Total dataset size (after adding validation): 60000
Total dataset size (after adding test): 70000
input shape: torch.Size([70000, 10, 28, 28])
target shape: torch.Size([70000, 1])
torch.Size([70000, 10, 28, 28])
Number of Calculated:  0
Number of Calculated:  100
Number of Calculated:  200


KeyboardInterrupt: 

In [None]:
# Optimiser and loss for the readout network
lr = 5*(10**-5)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criterion = SF.loss.ce_count_loss()

num_epochs = 10
batch_size = 20
log_every = 1500  # report the mean loss every `log_every` batches

# Derive the batch count from the data instead of a hard-coded sample count;
# a trailing partial batch is dropped.
num_batches = len(train_states) // batch_size

# Training session
for epoch in range(num_epochs):
    running_loss = 0.0

    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        stop = start + batch_size
        iter_spike = torch.stack(train_states[start:stop])
        iter_target = train_targets[start:stop]

        net.zero_grad()
        outputs, _ = net(iter_spike)
        loss = criterion(outputs, iter_target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (batch_idx + 1) % log_every == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / log_every))
            running_loss = 0.0

print('Finished Training')

# Evaluate on the test portion. No gradients are needed here, so the forward
# pass runs under no_grad to avoid building an autograd graph for all samples.
with torch.no_grad():
    outputs, _ = net(torch.stack(test_states))
print(outputs[:,0:1,:])
acc_rate = SF.acc.accuracy_rate(outputs, test_targets)

print(f"Accuracy: {acc_rate}")

[1,  1500] loss: 1.474
[1,  3000] loss: 0.929
[2,  1500] loss: 0.792
[2,  3000] loss: 0.720
[3,  1500] loss: 0.668
[3,  3000] loss: 0.630
[4,  1500] loss: 0.605
[4,  3000] loss: 0.579
[5,  1500] loss: 0.565
[5,  3000] loss: 0.547
[6,  1500] loss: 0.537
[6,  3000] loss: 0.523
[7,  1500] loss: 0.517
[7,  3000] loss: 0.504
[8,  1500] loss: 0.501
[8,  3000] loss: 0.489
[9,  1500] loss: 0.487
[9,  3000] loss: 0.478
[10,  1500] loss: 0.476
[10,  3000] loss: 0.468
Finished Training
tensor([[[0., 0., 0., 0., 0., 1., 0., 1., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 1., 0., 1., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 1., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 1., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 1., 0., 1.]],

        [[0., 0., 1., 0., 0., 0., 0., 1., 0