# Import Libraries

In [24]:
import numpy as np
import numpy.matlib as npm
# import scipy as sc
import random
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import utils
from snntorch import spikegen
import snntorch.functional as SF

# Define Class

In [25]:
class DeepESN:
    """Multi-layer spiking reservoir (deep echo-state network) driven step by step.

    Layer 0 reads ``Nu`` input features per time step; every deeper layer reads
    the ``Nr``-dimensional (post-reset) state of the layer below. A unit whose
    state exceeds ``threshold`` emits a spike (1.0) and has its state reset to 0.
    All weights are drawn once at construction time and never modified by this
    class.

    NOTE(review): ``rhos``/``lis``/``iss`` are stored but not applied -- no
    spectral-radius rescaling of ``W`` is performed -- so the recurrent weights
    are raw uniform samples in [-1, 1).
    """

    def __init__(self, Nu, Nr, Nl, configs, device='cpu'):
        """Draw the random input and recurrent weights of every layer.

        Args:
            Nu: number of input features consumed by layer 0 at each step.
            Nr: number of units per layer.
            Nl: number of stacked layers.
            configs: config object; reads ``rhos``, ``lis``, ``iss``, ``IPconf``,
                ``reservoirConf.connectivity``, ``readout`` and
                ``snn.{leaky, Gain, Bias, spike_train, threshold}``.
            device: torch device on which every weight/state tensor is created.
        """
        # Hyperparameters; scalar values are broadcast to one value per layer.
        rhos = torch.tensor(configs.rhos, device=device)  # spectral radius
        lis = torch.tensor(configs.lis, device=device)    # leaky rate
        iss = torch.tensor(configs.iss, device=device)    # input scale
        if rhos.dim() == 0:
            rhos = rhos.repeat(Nl)
        if lis.dim() == 0:
            lis = lis.repeat(Nl)
        if iss.dim() == 0:
            iss = iss.repeat(Nl)

        self.Nu = Nu  # number of inputs
        self.Nr = Nr  # number of units per layer
        self.Nl = Nl  # number of layers
        self.rhos = rhos.tolist()
        self.lis = lis
        self.iss = iss
        self.IPconf = configs.IPconf  # Deep Intrinsic Plasticity config (not read by this class)
        self.readout = configs.readout
        self.device = device

        # Spiking-neuron parameters.
        self.leaky = configs.snn.leaky
        self.Gain = configs.snn.Gain
        self.Bias = configs.snn.Bias
        self.spike_train = configs.snn.spike_train  # time steps per sample
        self.threshold = configs.snn.threshold

        connectivity = configs.reservoirConf.connectivity
        self.number_row_elements = round(connectivity * Nr)

        self.W = {}
        self.Win = {}

        # Recurrent weights for all layers first, then input weights, so the
        # sequence of random draws is the same as before.
        if connectivity < 1:
            # Sparse reservoir: each row gets `number_row_elements` nonzero
            # weights in [-1, 1) at randomly chosen columns.
            k = self.number_row_elements
            for layer in range(Nl):
                self.W[layer] = torch.zeros((Nr, Nr), device=device)
                for row in range(Nr):
                    row_elements = torch.randperm(Nr, device=device)[:k]
                    self.W[layer][row, row_elements] = torch.rand(k, device=device) * 2 - 1
        else:
            # Fully connected reservoir: dense (Nr, Nr) weights in [-1, 1).
            # (Previously used an undefined name here and raised NameError.)
            for layer in range(Nl):
                self.W[layer] = torch.rand((Nr, Nr), device=device) * 2 - 1

        for layer in range(Nl):
            # Layer 0 is fed the external input, deeper layers the state below.
            in_features = Nu if layer == 0 else Nr
            self.Win[layer] = torch.rand((Nr, in_features), device=device) * 2 - 1

    def computeLayerState(self, input, layer, initialStatesSpike, initialStatesLayer):
        """Advance one layer by one time step.

        Args:
            input: input vector of this layer (length ``Nu`` for layer 0,
                ``Nr`` otherwise).
            layer: layer index.
            initialStatesSpike: previous spike vector of this layer (length ``Nr``).
            initialStatesLayer: previous state vector of this layer (length ``Nr``).

        Returns:
            ``(state, spk)``: the new state, with entries above ``threshold``
            reset to 0, and a float 0/1 vector marking those entries.
        """
        # Feed-forward drive plus recurrent drive from the previous spikes.
        drive = self.Win[layer] @ input + self.W[layer] @ initialStatesSpike
        state = self.Gain * (self.leaky * initialStatesLayer + (1 - self.leaky) * drive) + self.Bias

        fired = state > self.threshold
        spk = fired.float()
        state[fired] = 0  # reset units that spiked
        return state, spk

    def computeState(self, inputs, initialStates=None):
        """Run the reservoir over every sample in ``inputs``.

        Args:
            inputs: indexable collection of 2-D samples (rows are fed
                ``Nu // n_columns`` at a time).
            initialStates: unused; kept for interface compatibility.

        Returns:
            ``(spikes, states)`` stacked over samples, each of shape
            ``(len(inputs), spike_train, Nl * Nr)``.
        """
        spikes = []
        states = []
        for i_seq in range(len(inputs)):
            spike, state = self.computeGlobalState(inputs[i_seq])
            spikes.append(spike)
            states.append(state)
            if i_seq % 100 == 0:
                print("Number of Calculated: ", i_seq)
        return torch.stack(spikes).to(self.device), torch.stack(states).to(self.device)

    def computeGlobalState(self, input):
        """Run one sample through all layers for ``spike_train`` time steps.

        At step ``i`` layer 0 receives rows ``[r*i, r*(i+1))`` of ``input``
        flattened, where ``r = Nu // input.shape[-1]`` (so that the slice has
        exactly ``Nu`` values, e.g. 2 rows of 28 for MNIST with Nu=56).
        States start at zero for every sample.

        Returns:
            ``(out_spk, out_state)``, each of shape ``(spike_train, Nl * Nr)``,
            holding the concatenated per-layer spikes/states after each step.
        """
        total = self.Nl * self.Nr
        steps = self.spike_train
        rows_per_step = self.Nu // input.shape[-1]

        state = torch.zeros(total, device=self.device)
        spike = torch.zeros(total, device=self.device)
        out_spk = torch.zeros((steps, total), device=self.device)
        out_state = torch.zeros((steps, total), device=self.device)

        for i in range(steps):
            layer_input = input[rows_per_step * i: rows_per_step * (i + 1), :].flatten().to(self.device)
            for layer in range(self.Nl):
                block = slice(layer * self.Nr, (layer + 1) * self.Nr)
                state[block], spike[block] = self.computeLayerState(
                    layer_input, layer, spike[block], state[block])
                # The next layer is driven by this layer's post-reset state.
                layer_input = state[block]
            out_spk[i] = spike
            out_state[i] = state

        return out_spk, out_state
    
    
class Readout(nn.Module):
    """Trainable spiking readout: one Linear layer followed by Leaky LIF neurons.

    Args:
        Nr: units per reservoir layer.
        Nl: number of reservoir layers (input width is ``Nr * Nl``).
        configs: config object; ``configs.snn.spike_train`` sets the number of
            time steps processed per forward pass.
        num_classes: number of output neurons (default 10).
        beta: membrane decay of the LIF neurons (default 0.75).
        threshold: LIF firing threshold (default 1).
    """

    def __init__(self, Nr, Nl, configs, num_classes=10, beta=0.75, threshold=1):
        super().__init__()
        self.fc1 = nn.Linear(Nr * Nl, num_classes)
        self.lif1 = snn.Leaky(beta=beta, threshold=threshold)
        self.spike_train = configs.snn.spike_train

    def forward(self, x):
        """Run the readout over ``spike_train`` steps.

        Args:
            x: tensor indexed as ``x[:, step, :]``, i.e. (batch, steps, Nr*Nl).

        Returns:
            ``(spikes, membranes)``, each of shape (spike_train, batch, num_classes).
            The batch axis is kept even when batch == 1 (the previous
            ``.squeeze()`` dropped it, which breaks snntorch's count losses).
        """
        mem1 = self.lif1.init_leaky()  # fresh membrane potential per call

        spk1_rec = []
        mem1_rec = []
        for step in range(self.spike_train):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            spk1_rec.append(spk1)
            mem1_rec.append(mem1)

        return torch.stack(spk1_rec, dim=0), torch.stack(mem1_rec, dim=0)

# Configurations Setup

In [26]:
class Struct:
    """Empty, mutable attribute container used as a configuration namespace."""

def config_CIFAR10(IP_indexes):
    """Build the nested hyperparameter configuration for the DeepESN pipeline.

    Despite its name, the main cell of this notebook uses it for MNIST.

    Args:
        IP_indexes: sample indexes stored as ``IPconf.indexes``.

    Returns:
        A ``Struct`` with reservoir hyperparameters (``rhos``, ``lis``, ``iss``)
        and the sub-sections ``IPconf``, ``reservoirConf``, ``readout`` and ``snn``.
    """
    def _section(**fields):
        # Build a Struct and attach every keyword as an attribute.
        section = Struct()
        for field_name, field_value in fields.items():
            setattr(section, field_name, field_value)
        return section

    return _section(
        # Reservoir hyperparameters
        rhos=0.98,
        lis=0.03,
        iss=0.4,
        # Intrinsic-plasticity training settings
        IPconf=_section(
            DeepIP=0,
            threshold=0.1,
            eta=10 ** -5,
            mu=0,
            sigma=0.1,
            Nepochs=10,
            indexes=IP_indexes,
        ),
        reservoirConf=_section(connectivity=0.3),
        readout=_section(
            trainMethod='SVD',
            regularizations=10.0 ** np.arange(-4, -1),
        ),
        # Spiking-neuron settings
        snn=_section(
            leaky=0.75,
            Gain=1,
            Bias=0,
            threshold=1.0,
            spike_train=14,
        ),
    )

# Load Dataset

In [27]:
import torch
import torchvision
from torchvision import datasets, transforms
import functools

class Struct:
    """Bare namespace object; attributes are attached dynamically by callers."""

def metric_function(prediction, threshold=None):
    """Return the predicted class index (arg-max over dim 1) for each row.

    Args:
        prediction: 2-D score tensor, one row per sample.
        threshold: accepted and ignored. ``load_MNIST`` wraps this function in
            ``functools.partial(metric_function, threshold=0.5)``; without this
            parameter every call of that partial raised ``TypeError``.

    Returns:
        Tensor of class indices, one per row.
    """
    return torch.argmax(prediction, dim=1)

def one_hot_encode(target, num_classes):
    """Encode integer class labels as float32 one-hot vectors of width ``num_classes``."""
    encoded = torch.nn.functional.one_hot(target, num_classes=num_classes)
    return encoded.to(torch.float32)

def permute_image(tensor):
    """Swap the first two axes of a 3-D image tensor, e.g. (3, 32, 32) -> (32, 3, 32).

    The previous body called ``permute(0, 1, 2)``, which is the identity and
    contradicted the documented reordering.
    """
    return tensor.permute(1, 0, 2)

import torch
from torchvision import datasets, transforms
import functools

def _split_to_tensors(split):
    """Read every (image, label) pair of an MNIST split exactly once.

    Returns:
        images: stacked ``image.squeeze().T`` tensors, (N, 28, 28) for MNIST.
        labels: integer labels reshaped to (N, 1).
    """
    images, labels = [], []
    for i in range(len(split)):
        image, label = split[i]  # a single load/transform per sample
        images.append(image.squeeze().T)
        labels.append(label)
    return torch.stack(images), torch.tensor(labels).view(-1, 1)


def load_MNIST(metric_function, root='./datasets'):
    """Download MNIST and pack the train, validation and test images into one dataset.

    The 60k official training images are randomly split 80/20 (via
    ``torch.utils.data.random_split``) into training and validation sets; then
    training, validation and test samples are concatenated in that order.
    Each image is stored as ``image.squeeze().T``, i.e. transposed.

    Args:
        metric_function: callable wrapped with
            ``functools.partial(metric_function, threshold=0.5)`` to build the
            returned error function, so it must accept a ``threshold`` keyword.
        root: directory where torchvision stores/downloads MNIST.

    Returns:
        ``(dataset, Nu, error_function, optimization_problem, TR_indexes,
        VL_indexes, TS_indexes)`` where ``dataset.inputs`` is (N, 28, 28),
        ``dataset.targets`` is (N, 1) integer labels, ``Nu`` is 28 * 2, and the
        three ``range`` objects index the train/validation/test parts.
    """
    transform = transforms.ToTensor()
    train_dataset = datasets.MNIST(root, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root, train=False, download=True, transform=transform)

    print("First train sample shape:", train_dataset[0][0].shape)
    print("First train label:", train_dataset[0][1])

    # Split the official training set into training and validation sets.
    train_size = int(0.8 * len(train_dataset))
    validation_size = len(train_dataset) - train_size
    train_dataset, validation_dataset = torch.utils.data.random_split(train_dataset, [train_size, validation_size])

    print("Training dataset size:", len(train_dataset))
    print("Validation dataset size:", len(validation_dataset))

    class Struct:
        pass

    dataset = Struct()
    dataset.name = 'MNIST'
    dataset.inputs, dataset.targets = _split_to_tensors(train_dataset)

    print("First input sample shape:", dataset.inputs[0].shape)
    print("First target sample shape:", dataset.targets[0].shape)

    validation_inputs, validation_targets = _split_to_tensors(validation_dataset)
    dataset.inputs = torch.cat((dataset.inputs, validation_inputs))
    dataset.targets = torch.cat((dataset.targets, validation_targets))

    print("Total dataset size (after adding validation):", len(dataset.inputs))

    test_inputs, test_targets = _split_to_tensors(test_dataset)
    dataset.inputs = torch.cat((dataset.inputs, test_inputs))
    dataset.targets = torch.cat((dataset.targets, test_targets))

    print("Total dataset size (after adding test):", len(dataset.inputs))
    print("input shape:", dataset.inputs.shape)
    print("target shape:", dataset.targets.shape)

    # Input dimension
    Nu = 28 * 2

    # Function used for model evaluation
    error_function = functools.partial(metric_function, threshold=0.5)

    # Select the model that achieves the maximum accuracy on validation set
    optimization_problem = torch.argmax

    TR_indexes = range(train_size)
    VL_indexes = range(train_size, train_size + validation_size)
    TS_indexes = range(train_size + validation_size, train_size + validation_size + len(test_dataset))

    return dataset, Nu, error_function, optimization_problem, TR_indexes, VL_indexes, TS_indexes


# Smoke-test load_MNIST: download MNIST and print the shapes it reports.
# NOTE(review): the main cell (In [29]) calls load_MNIST again, so the dataset
# is built twice; this call can be skipped once the loader is trusted.
dataset, Nu, error_function, optimization_problem, TR_indexes, VL_indexes, TS_indexes = load_MNIST(metric_function)


First train sample shape: torch.Size([1, 28, 28])
First train label: 5
Training dataset size: 48000
Validation dataset size: 12000
First input sample shape: torch.Size([28, 28])
First target sample shape: torch.Size([1])
Total dataset size (after adding validation): 60000
Total dataset size (after adding test): 70000
input shape: torch.Size([70000, 28, 28])
target shape: torch.Size([70000, 1])


# Extra Functions

In [28]:
def select_indexes(data, indexes):
    """Return ``[data[i] for i in indexes]`` as a list.

    Special case: when ``data`` holds exactly one element it is treated as a
    single whole sequence and returned (wrapped in a list) regardless of
    ``indexes``.
    """
    if len(data) != 1:
        return list(map(data.__getitem__, indexes))
    return [data[0]]


def computeMNISTAccuracy(predictions, targets):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Args:
        predictions: (N, num_classes) scores; the predicted class is the
            arg-max over dim 1.
        targets: N integer labels, shaped (N,) or (N, 1) (as in
            ``dataset.targets``), as a tensor or a plain sequence.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ZeroDivisionError: if there are no targets.
    """
    predicted_classes = torch.argmax(predictions, dim=1)
    # Flatten targets: comparing (N,) with (N, 1) would broadcast to (N, N)
    # and count N*N pairs, yielding "accuracies" far above 1.
    true_classes = torch.as_tensor(targets, device=predicted_classes.device).reshape(-1)
    correct = (predicted_classes == true_classes).sum().item()
    return correct / true_classes.numel()

# Example usage:
# Assuming `model` is your neural network, `data_loader` is your data loader for the MNIST dataset

def evaluate_model(model, data_loader, device=None):
    """Compute classification accuracy of ``model`` over ``data_loader``.

    Puts the model in eval mode and runs it without gradient tracking.

    Args:
        model: a ``torch.nn.Module`` returning (batch, num_classes) scores.
        data_loader: iterable of ``(inputs, targets)`` tensor batches.
        device: where to move each batch. Defaults to the device of the model's
            first parameter (falling back to CUDA if available, else CPU, for a
            parameterless model). The old hard-coded ``.cuda()`` crashed on
            CPU-only machines.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            all_predictions.append(model(inputs))
            all_targets.append(targets)

    if not all_predictions:
        raise ValueError("data_loader yielded no batches")

    return computeMNISTAccuracy(torch.cat(all_predictions), torch.cat(all_targets))

# Example of how to use evaluate_model:
# Assuming you have a trained model and a data loader for the test dataset
# test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
# model = YourModelClass().cuda()
# test_accuracy = evaluate_model(model, test_loader)
# print(f'Test Accuracy: {test_accuracy * 100:.2f}%')


# Main Function

In [29]:
# Check if CUDA is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the MNIST dataset and related configurations
# NOTE(review): this reloads the dataset already built by the smoke-test cell.
dataset, Nu, error_function, optimization_problem, TR_indexes, VL_indexes, TS_indexes = load_MNIST(metric_function)
print("Optimization Problem:", optimization_problem)
print(dataset.inputs.shape)

# Hyperparameter configuration (the helper is named after CIFAR-10 but is used here for MNIST)
configs = config_CIFAR10(list(TR_indexes) + list(VL_indexes))

# Define parameters for DeepESN
Nr = 400  # Number of recurrent units
Nl = 2    # Number of recurrent layers
reg = 0   # NOTE(review): not used in the visible cells

# Create the reservoir and the trainable spiking readout
deepESN = DeepESN(Nu, Nr, Nl, configs, device)
net = Readout(Nr, Nl, configs)

# Compute reservoir spikes/states for the entire dataset
spikes, states = deepESN.computeState(dataset.inputs.to(device))
states = states.to('cpu')
torch.set_printoptions(threshold=float('inf'))
# `spikes` is already a tensor: wrapping it in torch.tensor() copied the whole
# (70000, 14, 800) array and emitted a UserWarning.
print(spikes.shape[0])
print(spikes[0])
print(states[0])

# Select training (train + validation) and test spikes/targets by index
train_val_indexes = list(TR_indexes) + list(VL_indexes)
train_states = select_indexes(spikes, train_val_indexes)
test_states = select_indexes(spikes, TS_indexes)
train_targets = select_indexes(dataset.targets, train_val_indexes)
test_targets = select_indexes(dataset.targets, TS_indexes)

# train_states = select_indexes(states, list(TR_indexes)[0:800])
# train_targets = select_indexes(dataset.targets, list(TR_indexes)[0:800])
# test_states = select_indexes(states, list(TR_indexes)[800:1000])
# test_targets = select_indexes(dataset.targets, list(TR_indexes)[800:1000])

# Convert the per-sample target lists back into tensors
train_targets = torch.tensor(train_targets)
# train_targets = train_targets.reshape(800, 1)
test_targets = torch.tensor(test_targets)
# test_targets = test_targets.reshape(200, 1)


Using device: cpu
First train sample shape: torch.Size([1, 28, 28])
First train label: 5
Training dataset size: 48000
Validation dataset size: 12000
First input sample shape: torch.Size([28, 28])
First target sample shape: torch.Size([1])
Total dataset size (after adding validation): 60000
Total dataset size (after adding test): 70000
input shape: torch.Size([70000, 28, 28])
target shape: torch.Size([70000, 1])
Optimization Problem: <built-in method argmax of type object at 0x00007FFF853E1D40>
torch.Size([70000, 28, 28])
Number of Calculated:  0
Number of Calculated:  100
Number of Calculated:  200
Number of Calculated:  300
Number of Calculated:  400
Number of Calculated:  500
Number of Calculated:  600
Number of Calculated:  700
Number of Calculated:  800
Number of Calculated:  900
Number of Calculated:  1000
Number of Calculated:  1100
Number of Calculated:  1200
Number of Calculated:  1300
Number of Calculated:  1400
Number of Calculated:  1500
Number of Calculated:  1600
Number of

  print(torch.tensor(spikes).shape[0])


70000
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0

In [30]:
lr = 5*(10**-5)
# Keep the readout on the same device as the reservoir spikes it consumes.
net = net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criterion = SF.loss.ce_count_loss()

num_epochs = 10
batch_size = 20
num_batches = len(train_states) // batch_size  # 3000 for the 60000 train+val samples
log_every = 1500

# Training Session
for epoch in range(num_epochs):
    running_loss = 0.0

    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        iter_spike = torch.stack(train_states[start:start + batch_size])
        iter_target = train_targets[start:start + batch_size].to(device)

        optimizer.zero_grad()
        outputs, _ = net(iter_spike)
        loss = criterion(outputs, iter_target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (batch_idx + 1) % log_every == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / log_every))
            running_loss = 0.0

print('Finished Training')

# Testing 10,000 samples from the dataset; no autograd graph is needed here.
with torch.no_grad():
    outputs, _ = net(torch.stack(test_states))
print(outputs[:,0:1,:])
acc_rate = SF.acc.accuracy_rate(outputs, test_targets.to(outputs.device))

print(f"Accuracy: {acc_rate}")

[1,  1500] loss: 1.173
[1,  3000] loss: 0.734
[2,  1500] loss: 0.621
[2,  3000] loss: 0.559
[3,  1500] loss: 0.513
[3,  3000] loss: 0.484
[4,  1500] loss: 0.457
[4,  3000] loss: 0.442
[5,  1500] loss: 0.422
[5,  3000] loss: 0.413
[6,  1500] loss: 0.397
[6,  3000] loss: 0.391
[7,  1500] loss: 0.379
[7,  3000] loss: 0.374
[8,  1500] loss: 0.364
[8,  3000] loss: 0.362
[9,  1500] loss: 0.352
[9,  3000] loss: 0.350
[10,  1500] loss: 0.342
[10,  3000] loss: 0.341
Finished Training
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 1., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 1., 1., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0., 1., 0., 1., 0., 0.]],

        [[0., 0., 0., 1., 0., 1., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 1., 1., 0., 0., 0., 1., 0