<a href="https://colab.research.google.com/github/ergysmedaunipd/thesis/blob/main/ThesisUnipdSNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install snntorch --quiet

In [None]:
import torch
import torch.nn as nn
import snntorch as snn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import snntorch.functional as SF
# Spike-count MSE loss built from snntorch.functional.
# It is not referenced by any other code visible in this notebook.
# NOTE(review): presumably left over from a gradient-trained baseline — confirm before removing.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

class ADMM_NN:
    """Three-layer fully connected network trained with ADMM and run as a spiking net.

    Training alternates closed-form updates of the weights ``w*``, the
    activations ``a*`` and the pre-activations ``z*``. Inference
    (:meth:`feed_forward`) drives the learned weights through leaky
    integrate-and-fire layers for ``T`` timesteps.

    Samples are stored as columns: inputs are ``(n_inputs, n_samples)`` and
    labels are ``(n_outputs, n_samples)``.
    """

    def __init__(self, n_inputs, n_hiddens, n_outputs, n_batches,delta,theta,timestep):
        """
        Allocate the ADMM variables and the spiking layers.

        :param n_inputs: Number of input features.
        :param n_hiddens: Number of units in each of the two hidden layers.
        :param n_outputs: Number of outputs.
        :param n_batches: Number of training samples (columns of the data matrix).
        :param delta: Value passed to ``snn.Leaky`` as ``beta``.
        :param theta: Value passed to ``snn.Leaky`` as ``threshold``.
        :param timestep: Number of simulation steps used by :meth:`feed_forward`.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.delta = delta
        self.T = timestep
        self.theta = theta
        self.a0 = torch.zeros((n_inputs, n_batches), device=self.device)

        # Weights start at zero; every ADMM sweep overwrites them.
        self.w1 = torch.zeros((n_hiddens, n_inputs), device=self.device)
        self.w2 = torch.zeros((n_hiddens, n_hiddens), device=self.device)
        self.w3 = torch.zeros((n_outputs, n_hiddens), device=self.device)

        # Pre-activations (z) and activations (a) start from uniform noise.
        self.z1 = torch.rand((n_hiddens, n_batches), device=self.device)
        self.a1 = torch.rand((n_hiddens, n_batches), device=self.device)

        self.z2 = torch.rand((n_hiddens, n_batches), device=self.device)
        self.a2 = torch.rand((n_hiddens, n_batches), device=self.device)

        self.z3 = torch.rand((n_outputs, n_batches), device=self.device)

        # The fc* layers only mirror w*; feed_forward multiplies by w* directly.
        self.fc1 = nn.Linear(n_inputs, n_hiddens).to(self.device)
        self.lif1 = snn.Leaky(beta=self.delta, threshold=self.theta).to(self.device)

        self.fc2 = nn.Linear(n_hiddens, n_hiddens).to(self.device)
        self.lif2 = snn.Leaky(beta=self.delta, threshold=self.theta).to(self.device)

        self.fc3 = nn.Linear(n_hiddens, n_outputs).to(self.device)
        self.lif3 = snn.Leaky(beta=self.delta, threshold=self.theta).to(self.device)

        self.lambda_larange = torch.ones((n_outputs, n_batches)).to(self.device)

    def __str__(self):
        """Return a multi-line summary of tensor shapes and layers.

        Nothing is printed here: ``str(model)`` must not write to stdout. The
        Lagrange-multiplier shape is part of the returned text.
        """
        rows = [
            "ADMM SNN Model Structure:",
            f" - Number of timesteps: {self.T}",
            f" - Input dimension: {self.a0.size()}",
            f" - W1 : {self.w1.shape}",
            f" - W2 : {self.w2.shape}",
            f" - W3 : {self.w3.shape}",
            f" - z1 : {self.z1.shape}",
            f" - z2 : {self.z2.shape}",
            f" - z3 : {self.z3.shape}",
            f" - a0 : {self.a0.shape}",
            f" - a1 : {self.a1.shape}",
            f" - a2 : {self.a2.shape}",
            f" - fc1 : {self.fc1}",
            f" - lif1 : {self.lif1}",
            f" - fc2 : {self.fc2}",
            f" - lif2 : {self.lif2}",
            f" - fc3 : {self.fc3}",
            f" - lif3 : {self.lif3}",
            f" - Output dimension (Lagrange Multiplier): {self.lambda_larange.size()}",
        ]
        return "\n".join(rows) + "\n"

    def _relu(self, x):
        """
        Relu activation function
        :param x: input x
        :return: element-wise max(0, x)
        """
        return F.relu(x)

    def _weight_update(self, layer_output, activation_input):
        """
        Least-squares weight update: minimise ||z_l - W_l a_l-1||^2 over W_l.
        The solution is W_l = z_l pinv(a_l-1).
        :param layer_output: output matrix (z_l)
        :param activation_input: activation matrix of the previous layer (a_l-1)
        :return: weight matrix
        """
        pinv = torch.pinverse(activation_input)
        return torch.mm(layer_output.float(), pinv.float())

    def _activation_update(self, next_weight, next_layer_output, layer_nl_output, beta, gamma):
        """
        Closed-form minimiser over a_l of
            beta ||z_l+1 - W_l+1 a_l||^2 + gamma ||a_l - h(z_l)||^2
        :param next_weight: weight matrix l+1 (w_l+1)
        :param next_layer_output: output matrix l+1 (z_l+1)
        :param layer_nl_output: pre-activation z_l; ReLU is applied to it here
        :param beta: value of beta
        :param gamma: value of gamma
        :return: activation matrix
        """
        activated = self._relu(layer_nl_output)

        # (beta W^T W + gamma I)^-1
        gram = beta * torch.mm(next_weight.t(), next_weight)
        ridge = gamma * torch.eye(gram.shape[0], device=gram.device)
        av = torch.inverse(gram.float() + ridge.float())

        # beta W^T z_l+1 + gamma h(z_l)
        back_projected = beta * torch.mm(next_weight.t(), next_layer_output)
        af = back_projected.float() + (gamma * activated).float()

        return torch.mm(av, af)

    def _argminz(self, a, w, a_in, beta, gamma):
        """
        Entry-wise minimiser over z of
            gamma (a - relu(z))^2 + beta (z - w a_in)^2
        Two candidates are built, one restricted to z >= 0 and one to z <= 0,
        and the cheaper one is kept per entry.
        :param a: activation matrix (a_l)
        :param w: weight matrix (w_l)
        :param a_in: activation matrix l-1 (a_l-1)
        :param beta: value of beta
        :param gamma: value of gamma
        :return: output matrix
        """
        m = torch.mm(w.float(), a_in.float())
        zeros = torch.zeros_like(a)

        blended = (gamma * a + beta * m) / (gamma + beta)
        z_pos = torch.where(blended >= 0, blended, zeros)
        z_neg = torch.where(m <= 0, m, zeros)

        cost_pos = gamma * (a - self._relu(z_pos)).pow(2) + beta * (z_pos - m).pow(2)
        cost_neg = gamma * (a - self._relu(z_neg)).pow(2) + beta * (z_neg - m).pow(2)

        # Entries where neither comparison holds (NaN costs) stay at zero.
        return torch.where(
            cost_pos <= cost_neg,
            z_pos,
            torch.where(cost_neg < cost_pos, z_neg, zeros),
        )

    def _argminlastz(self, targets, eps, w, a_in, beta):
        """
        Closed-form update of the last-layer output matrix.
        :param targets: target matrix, same dimensions as z (y)
        :param eps: lagrange multiplier matrix, same dimensions as z (lambda)
        :param w: weight matrix (w_l)
        :param a_in: activation matrix l-1 (a_l-1)
        :param beta: value of beta
        :return: output matrix last layer
        """
        m = torch.mm(w.float(), a_in.float())
        return (targets - eps + beta * m) / (1 + beta)

    def _lambda_update(self, zl, w, a_in, beta):
        """
        Lagrange multiplier update.
        :param zl: output matrix last layer (z_L)
        :param w: weight matrix last layer (w_L)
        :param a_in: activation matrix l-1 (a_L-1)
        :param beta: value of beta
        :return: beta * (z_L - w_L a_L-1)
        """
        mpt = torch.mm(w.float(), a_in.float())
        return beta * (zl - mpt)

    def feed_forward(self, inputs):
        """
        Forward pass using the ADMM weights and spiking dynamics.
        :param inputs: data matrix with samples as columns
        :return: last-layer membrane potential after the final timestep, transposed
        """
        inputs = inputs.to(self.device)

        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # The same input is presented at every step, so its current is computed once.
        cur1 = torch.mm(self.w1, inputs)

        for _ in range(self.T):
            spk1, mem1 = self.lif1(cur1, mem1)
            spk2, mem2 = self.lif2(torch.mm(self.w2, spk1), mem2)
            # Output spikes are not used; only the membrane potential is returned.
            _, mem3 = self.lif3(torch.mm(self.w3, spk2), mem3)

        return mem3.T

    def _admm_step(self, labels, beta, gamma, update_lambda):
        """Run one sweep of the ADMM sub-problems over the three layers.

        The order is significant: each sub-problem uses the latest values of
        the variables updated before it. ``fc*`` weights are kept equal to ``w*``.
        """
        # Input layer
        self.w1 = self._weight_update(self.z1, self.a0)
        self.fc1.weight.data = self.w1
        self.a1 = self._activation_update(self.w2, self.z2, self.z1, beta, gamma)
        self.z1 = self._argminz(self.a1, self.w1, self.a0, beta, gamma)

        # Hidden layer
        self.w2 = self._weight_update(self.z2, self.a1)
        self.fc2.weight.data = self.w2
        self.a2 = self._activation_update(self.w3, self.z3, self.z2, beta, gamma)
        self.z2 = self._argminz(self.a2, self.w2, self.a1, beta, gamma)

        # Output layer
        self.w3 = self._weight_update(self.z3, self.a2)
        self.fc3.weight.data = self.w3
        self.z3 = self._argminlastz(labels, self.lambda_larange, self.w3, self.a2, beta)
        if update_lambda:
            self.lambda_larange = self._lambda_update(self.z3, self.w3, self.a2, beta)

    def fit(self, inputs, labels, beta, gamma):
        """
        Run one ADMM sweep (including the lambda update) and score the result.
        :param inputs: training data, samples as columns
        :param labels: training targets, samples as columns
        :param beta: value of beta
        :param gamma: value of gamma
        :return: (loss, accuracy) from :meth:`evaluate` on the same data
        """
        self.a0 = inputs.to(self.device)
        labels = labels.to(self.device)
        self._admm_step(labels, beta, gamma, update_lambda=True)
        return self.evaluate(inputs, labels)

    def evaluate(self, inputs, labels, isCategories=True):
        """
        Calculate the mean squared error and, for classification, the accuracy.
        :param inputs: data matrix, samples as columns
        :param labels: one-hot targets as (n_outputs, n_samples) or
            (n_samples, n_outputs), or 1-D class indices
        :param isCategories: when False, the loss is returned in place of accuracy
        :return: (loss, accuracy)
        """
        inputs = inputs.to(self.device)
        labels = labels.to(self.device)

        forward = self.feed_forward(inputs)  # rows are samples

        # Orient 2-D labels by comparing with the network output, not with the
        # number of input features.
        if labels.ndim == 2 and labels.shape == forward.T.shape:
            labels = labels.T

        class_ids = None
        if isCategories and labels.ndim == 1:
            # Class indices: expand to one-hot so the squared error is well defined.
            class_ids = labels.long()
            targets = F.one_hot(class_ids, num_classes=forward.shape[1]).to(forward.dtype)
        else:
            targets = labels

        loss = torch.mean((forward - targets).pow(2))

        if not isCategories:
            return loss, loss

        if class_ids is None:
            class_ids = torch.argmax(targets, dim=1)
        accuracy = (class_ids == torch.argmax(forward, dim=1)).float().mean()
        return loss, accuracy

    def warming(self, inputs, labels, epochs, beta, gamma):
        """
        Warm up the ADMM variables by running sweeps without the lambda update.
        :param inputs: training data, samples as columns
        :param labels: training targets, samples as columns
        :param epochs: number of warm-up sweeps
        :param beta: value of beta
        :param gamma: value of gamma
        :return: None
        """
        self.a0 = inputs.to(self.device)
        labels = labels.to(self.device)
        for i in range(epochs):
            print(f"------ Warming: {i} ------")
            self._admm_step(labels, beta, gamma, update_lambda=False)

    @staticmethod
    def _to_flat_array(values):
        """Convert a tensor, or a sequence that may hold tensors, to a flat host array.

        ``np.array`` cannot read CUDA tensors, so every tensor is detached and
        copied to the CPU first.
        """
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy().flatten()
        if isinstance(values, (list, tuple)):
            values = [
                v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v
                for v in values
            ]
        return np.array(values).flatten()

    def drawcurve(self, train_, valid_, id, legend_1, legend_2):
        """
        Plot a training curve and a validation curve on the same figure.
        :param train_: training values (tensor, or sequence of numbers/tensors)
        :param valid_: validation values (tensor, or sequence of numbers/tensors)
        :param id: matplotlib figure identifier
        :param legend_1: legend label for the training curve
        :param legend_2: legend label for the validation curve
        :return: 0
        """
        acc_train = self._to_flat_array(train_)
        acc_test = self._to_flat_array(valid_)

        plt.figure(id)
        plt.plot(acc_train, label=legend_1)
        plt.plot(acc_test, label=legend_2)
        plt.ylim(bottom=0)
        plt.legend(loc='upper left')
        plt.draw()
        plt.pause(0.001)
        return 0

In [None]:
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

# Pick the compute device: CUDA when present, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training split (downloaded into ./data on first use) and the held-out split.
# NOTE(review): `.data` is read directly below, so the ToTensor transform is
# presumably never applied to these tensors — confirm whether scaling is wanted.
mnist = datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())
mnist_valid = datasets.MNIST('./data', train=False, transform=transforms.ToTensor())

# Flatten each 28x28 image and store samples as columns: (784, n_samples), float32.
train_data = mnist.data.reshape(-1, 28 * 28).t().to(torch.float32).to(device)
# One-hot targets, also with samples as columns: (10, n_samples), float32.
train_labels = F.one_hot(mnist.targets, num_classes=10).t().to(torch.float32).to(device)

# Same layout for the held-out split.
valid_data = mnist_valid.data.reshape(-1, 28 * 28).t().to(torch.float32).to(device)
valid_labels = F.one_hot(mnist_valid.targets, num_classes=10).t().to(torch.float32).to(device)


In [None]:
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

from itertools import product

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the MNIST training split.
# NOTE(review): `.data` is read directly below, so the ToTensor transform is
# presumably never applied to these tensors — confirm whether scaling is wanted.
mnist = datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())

# Flatten images and store samples as columns: (784, n_samples), float32, on `device`.
train_data = mnist.data.view(-1, 28 * 28).T.float().to(device)
train_labels = F.one_hot(mnist.targets, num_classes=10).T.float().to(device)  # One-hot encode labels

# Load validation data (the MNIST split with train=False).
mnist_valid = datasets.MNIST('./data', train=False, transform=transforms.ToTensor())

valid_data = mnist_valid.data.view(-1, 28 * 28).T.float().to(device)
valid_labels = F.one_hot(mnist_valid.targets, num_classes=10).T.float().to(device)  # One-hot encode labels

# Final evaluation arrays. `data`/`targets` replace the deprecated
# `test_data`/`test_labels` aliases. They do not change between configurations,
# so they are built once.
# NOTE(review): this is the same split as `valid_data`, so early stopping is
# tuned on the data that is reported as "test" — consider a separate hold-out.
test_data = mnist_valid.data.numpy().reshape(-1, 28*28).T.astype(np.float32)
test_labels = F.one_hot(mnist_valid.targets, num_classes=10).numpy().T.astype(np.float32)

n_inputs = 28*28  # MNIST image shape 28*28
n_outputs = 10    # MNIST classes from 0-9 digits
n_hiddens = 512   # number of neurons
n_batches = train_data.shape[1]  # number of training samples (columns of train_data)
train_epochs = 100
warm_epochs = 10

# Hyperparameter grid
beta_values = [10]
gamma_values = [ 1]
delta_values = [ 0.99]
theta_values = [1]
timestep_values = [75]
patience = 5
# Store results
results = []

# Grid search; `product` varies the last list fastest, matching the original
# nesting order (timestep outermost, beta innermost).
for timestep, theta, delta, gamma, beta in product(
    timestep_values, theta_values, delta_values, gamma_values, beta_values
):
    print(f"Testing: beta={beta}, gamma={gamma}, delta={delta}, theta={theta}, timestep={timestep}")

    # Initialize model with current hyperparameters
    model = ADMM_NN(n_inputs, n_hiddens, n_outputs, n_batches, delta, theta, timestep)

    # Warming phase. The data tensors are already detached and on `device`,
    # so they are passed directly instead of being cloned on every call.
    model.warming(train_data, train_labels, warm_epochs, beta, gamma)

    # Early stopping variables
    best_acc = 0
    patience_counter = 0

    # Training phase
    list_loss_train = []
    list_loss_valid = []
    list_accuracy_train = []
    list_accuracy_valid = []
    for epoch in range(train_epochs):
        loss_train, accuracy_train = model.fit(train_data, train_labels, beta, gamma)
        loss_valid, accuracy_valid = model.evaluate(valid_data, valid_labels)

        # Store plain Python floats: the metrics are 0-d tensors that may live
        # on the GPU, and NumPy/matplotlib cannot read CUDA tensors.
        loss_train, accuracy_train = loss_train.item(), accuracy_train.item()
        loss_valid, accuracy_valid = loss_valid.item(), accuracy_valid.item()

        print(f"------ Training Epoch: {epoch}  accuracy train: {accuracy_train:.3f}, accuracy valid: {accuracy_valid:.3f}")

        list_loss_train.append(loss_train)
        list_loss_valid.append(loss_valid)
        list_accuracy_train.append(accuracy_train)
        list_accuracy_valid.append(accuracy_valid)

        # Early stopping logic
        if accuracy_valid > best_acc:
            best_acc = accuracy_valid
            patience_counter = 0  # Reset patience counter
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch}")
            break

    model.drawcurve(list_accuracy_train, list_accuracy_valid, 2, 'acc_train', 'acc_valid')

    # Evaluate on test set
    loss_test, accuracy_test = model.evaluate(
        torch.from_numpy(test_data).to(device),
        torch.from_numpy(test_labels).to(device)
    )
    print(f"Final Test Accuracy: {accuracy_test.item():.3f}")

    # Store results
    results.append({
        "beta": beta,
        "gamma": gamma,
        "delta": delta,
        "theta": theta,
        "timestep": timestep,
        "loss_test": loss_test.item(),
        "accuracy_test": accuracy_test.item()
    })

# Display the best result
best_result = max(results, key=lambda x: x['accuracy_test'])
print(f"\nBest Hyperparameters:")
print(f"Beta: {best_result['beta']}")
print(f"Gamma: {best_result['gamma']}")
print(f"Delta: {best_result['delta']}")
print(f"Theta: {best_result['theta']}")
print(f"Timestep: {best_result['timestep']}")
print(f"Test Accuracy: {best_result['accuracy_test']:.3f}")


Testing: beta=10, gamma=1, delta=0.99, theta=1, timestep=75
------ Warming: 0 ------
------ Warming: 1 ------
------ Warming: 2 ------
------ Warming: 3 ------
------ Warming: 4 ------
------ Warming: 5 ------
------ Warming: 6 ------
------ Warming: 7 ------
------ Warming: 8 ------
------ Warming: 9 ------
------ Training Epoch: 0  accuracy train: 0.133, accuracy valid: 0.131
------ Training Epoch: 1  accuracy train: 0.152, accuracy valid: 0.148
------ Training Epoch: 2  accuracy train: 0.181, accuracy valid: 0.181
------ Training Epoch: 3  accuracy train: 0.231, accuracy valid: 0.233
------ Training Epoch: 4  accuracy train: 0.336, accuracy valid: 0.333
------ Training Epoch: 5  accuracy train: 0.465, accuracy valid: 0.468
------ Training Epoch: 6  accuracy train: 0.598, accuracy valid: 0.603
------ Training Epoch: 7  accuracy train: 0.666, accuracy valid: 0.671
------ Training Epoch: 8  accuracy train: 0.717, accuracy valid: 0.722
------ Training Epoch: 9  accuracy train: 0.741, ac

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [None]:
import torch
import torch.nn as nn
import snntorch as snn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import snntorch.functional as SF
# Spike-count MSE loss built from snntorch.functional.
# It is not referenced by any other code visible in this notebook.
# NOTE(review): presumably left over from a gradient-trained baseline — confirm before removing.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

class ADMM_NN:
    """ Class for ADMM Neural Network. """

    def __init__(self, n_inputs, n_hiddens, n_outputs, n_batches,delta,theta,timesteps):
        """
        Allocate the ADMM variables and the spiking layers.

        :param n_inputs: Number of input features.
        :param n_hiddens: Number of units in each of the two hidden layers.
        :param n_outputs: Number of outputs.
        :param n_batches: Number of training samples (columns of the data matrix).
        :param delta: Value passed to ``snn.Leaky`` as ``beta``.
        :param theta: Value passed to ``snn.Leaky`` as ``threshold``.
        :param timesteps: Number of simulation steps, stored as ``self.T``.
        """
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = dev
        self.delta, self.T, self.theta = delta, timesteps, theta

        hidden_by_sample = (n_hiddens, n_batches)
        output_by_sample = (n_outputs, n_batches)

        # Input activations and weights start at zero.
        self.a0 = torch.zeros((n_inputs, n_batches), device=dev)
        self.w1 = torch.zeros((n_hiddens, n_inputs), device=dev)
        self.w2 = torch.zeros((n_hiddens, n_hiddens), device=dev)
        self.w3 = torch.zeros((n_outputs, n_hiddens), device=dev)

        # Pre-activations (z) and activations (a) start from uniform noise;
        # the draw order is z1, a1, z2, a2, z3.
        self.z1 = torch.rand(hidden_by_sample, device=dev)
        self.a1 = torch.rand(hidden_by_sample, device=dev)
        self.z2 = torch.rand(hidden_by_sample, device=dev)
        self.a2 = torch.rand(hidden_by_sample, device=dev)
        self.z3 = torch.rand(output_by_sample, device=dev)

        def make_lif():
            # One leaky integrate-and-fire layer with the shared decay/threshold.
            return snn.Leaky(beta=delta, threshold=theta).to(dev)

        self.fc1 = nn.Linear(n_inputs, n_hiddens).to(dev)
        self.lif1 = make_lif()

        self.fc2 = nn.Linear(n_hiddens, n_hiddens).to(dev)
        self.lif2 = make_lif()

        self.fc3 = nn.Linear(n_hiddens, n_outputs).to(dev)
        self.lif3 = make_lif()

        # Lagrange multiplier for the output constraint, initialised to ones.
        self.lambda_larange = torch.ones(output_by_sample, device=dev)



    def __str__(self):
        model_str = "ADMM SNN Model Structure:\n"

        print(f"lambda shape: {self.lambda_larange.shape}")
        return model_str
    def _relu(self, x):
        """
        Relu activation function
        :param x: input x
        :return: max 0 and x
        """
        return F.relu(x)

    def _weight_update(self, layer_output, activation_input):
        """
        Update weights by minimizing ||z_l - W_l a_l-1||^2 with spike train temporal aggregation.
        :param layer_output: Output matrix (z_l), shape [n_hiddens, n_batches].
        :param activation_input: Spike train activations (a_l-1), shape [timesteps, n_features, n_batches].
        :return: Weight matrix for the layer (W_l).
        """
        # Aggregate activations over timesteps
        if activation_input.dim() == 3:  # Spike train input with temporal dimension
            integrated_activation = activation_input.sum(dim=0)  # Shape: [n_features, n_batches]
        elif activation_input.dim() == 2:  # Already aggregated input
            integrated_activation = activation_input  # Shape: [n_features, n_batches]
        else:
            raise ValueError("activation_input must have 2 or 3 dimensions.")
        # Compute pseudo-inverse
        pinv = torch.pinverse(integrated_activation)

        # Calculate updated weight matrix
        weight_matrix = torch.mm(layer_output.float(), pinv.float())

        return weight_matrix.to(self.device)


    def _activation_update(self, next_weight, next_layer_output, layer_nl_output, beta, gamma):
        """
        Minimization for a_l with spike train temporal aggregation.
        The problem involves minimizing:
            beta ||z_l+1 - W_l+1 a_l||^2 + gamma ||a_l - h(z_l)||^2
        :param next_weight: Weight matrix l+1 (W_l+1).
        :param next_layer_output: Output matrix l+1 (z_l+1), shape [n_hiddens, n_batches].
        :param layer_nl_output: Non-linear activation matrix h(z_l), shape [timesteps, n_hiddens, n_batches].
        :param beta: Regularization parameter for z_l+1.
        :param gamma: Regularization parameter for a_l.
        :return: Activation matrix a_l, shape [n_hiddens, n_batches].
        """
        # Aggregate layer_nl_output over timesteps
        integrated_activation = layer_nl_output.sum(dim=0)  # Shape: [n_hiddens, n_batches]

        # Activation inverse (matrix formulation)
        m1 = beta * torch.mm(next_weight.t(), next_weight)  # [n_hiddens, n_hiddens]
        m2 = gamma * torch.eye(m1.shape[0], device=m1.device)  # [n_hiddens, n_hiddens]
        # Regularize to ensure invertibility
        epsilon = 1e-6  # Small regularization constant
        av = torch.inverse(m1.float() + m2.float() + epsilon * torch.eye(m1.shape[0], device=m1.device))

        # Activation formula
        m3 = beta * torch.mm(next_weight.t(), next_layer_output)  # [n_hiddens, n_batches]
        m4 = gamma * layer_nl_output  # [n_hiddens, n_batches]
        af = m3.float() + m4.float()

        # Output
        return torch.mm(av, af).to(self.device)

    def _argminz(self, a, w, a_in, beta, gamma):
        """
        Minimization for z_l with temporal aggregation of spike train data.
        :param a: Activation matrix (a_l), shape [n_hiddens, n_batches].
        :param w: Weight matrix (w_l), shape [n_hiddens, n_features].
        :param a_in: Input activations (spike train, a_l-1), shape [timesteps, n_features, n_batches] or [n_features, n_batches].
        :param beta: Regularization parameter for z_l+1.
        :param gamma: Regularization parameter for a_l.
        :return: Output matrix z_l, shape [n_hiddens, n_batches].
        """
        # Handle temporal aggregation if a_in is 3D
        if a_in.dim() == 3:  # Spike train input with temporal dimension
            integrated_a_in = a_in.sum(dim=0)  # Shape: [n_features, n_batches]
        elif a_in.dim() == 2:  # Already aggregated input
            integrated_a_in = a_in  # Shape: [n_features, n_batches]
        else:
            raise ValueError("a_in must have 2 or 3 dimensions.")

        # Compute intermediate variables
        m = torch.mm(w.float(), integrated_a_in.float())  # Weighted input activations
        sol1 = (gamma * a + beta * m) / (gamma + beta)  # First candidate solution
        sol2 = m  # Second candidate solution

        # Initialize z candidates
        z1 = torch.zeros_like(a)
        z2 = torch.zeros_like(a)
        z = torch.zeros_like(a)

        # Apply conditions to determine z1 and z2
        z1[sol1 >= 0] = sol1[sol1 >= 0]
        z2[sol2 <= 0] = sol2[sol2 <= 0]

        # Compute objective function values for z1 and z2
        fz_1 = gamma * (a - self._relu(z1)).pow(2) + beta * (z1 - m).pow(2)
        fz_2 = gamma * (a - self._relu(z2)).pow(2) + beta * (z2 - m).pow(2)

        # Select the better solution for each element
        index_z1 = fz_1 <= fz_2
        index_z2 = fz_2 < fz_1

        z[index_z1] = z1[index_z1]
        z[index_z2] = z2[index_z2]

        return z

    def _argminlastz(self, targets, eps, w, a_in, beta):
        """
        Minimization of the last output matrix with temporal aggregation of spike train data.
        :param targets: Target matrix (equal dimensions of z), shape [n_outputs, n_batches].
        :param eps: Lagrange multiplier matrix (equal dimensions of z), shape [n_outputs, n_batches].
        :param w: Weight matrix (w_L), shape [n_outputs, n_hiddens].
        :param a_in: Input activations (spike train, a_L-1), shape [timesteps, n_hiddens, n_batches] or [n_hiddens, n_batches].
        :param beta: Regularization parameter for the last layer.
        :return: Output matrix z_L, shape [n_outputs, n_batches].
        """
        # Handle temporal aggregation if a_in is 3D
        if a_in.dim() == 3:  # Spike train input with temporal dimension
            integrated_a_in = a_in.sum(dim=0)  # Shape: [n_hiddens, n_batches]
        elif a_in.dim() == 2:  # Already aggregated input
            integrated_a_in = a_in  # Shape: [n_hiddens, n_batches]
        else:
            raise ValueError("a_in must have 2 or 3 dimensions.")

        # Compute z_L using the aggregated input
        m = torch.mm(w.float(), integrated_a_in.float())  # Weighted input activations
        z = (targets - eps + beta * m) / (1 + beta)  # Closed-form solution for z

        return z


    def _lambda_update(self, zl, w, a_in, beta):
        """
        Lagrange multiplier update with spike train temporal aggregation.
        :param zl: Output matrix last layer (z_L), shape [n_outputs, n_batches].
        :param w: Weight matrix last layer (w_L), shape [n_outputs, n_hiddens].
        :param a_in: Input activations (spike train, a_L-1), shape [timesteps, n_hiddens, n_batches] or [n_hiddens, n_batches].
        :param beta: Regularization parameter for the Lagrange multiplier update.
        :return: Updated Lagrange multiplier matrix, shape [n_outputs, n_batches].
        """
        # Handle temporal aggregation if a_in is 3D
        if a_in.dim() == 3:  # Spike train input with temporal dimension
            integrated_a_in = a_in.sum(dim=0)  # Shape: [n_hiddens, n_batches]
        elif a_in.dim() == 2:  # Already aggregated input
            integrated_a_in = a_in  # Shape: [n_hiddens, n_batches]
        else:
            raise ValueError("a_in must have 2 or 3 dimensions.")


        # Compute the Lagrange multiplier update
        mpt = torch.mm(w.float(), integrated_a_in.float())  # Weighted activations
        lambda_up = beta * (zl - mpt)  # Update rule for lambda

        return lambda_up



    def feed_forward(self, inputs):
        """
        Forward pass using ADMM weights and spiking dynamics with spike train inputs.
        :param inputs: Spike train tensor, shape [timesteps, n_inputs, n_batches].
        :return: Aggregated membrane potential or spike activity, shape [n_outputs, n_batches].
        """
        inputs = inputs.to(self.device)  # Ensure inputs are on GPU

        # Initialize membrane potentials for spiking neurons
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Initialize spike accumulation for output layer
        spk3_rec = []  # Record spikes from the output layer

        if inputs.dim() == 3:  # Spike train input with temporal dimension
            integrated_inputs = inputs.sum(dim=0)  # Shape: [n_features, n_batches]
        elif inputs.dim() == 2:  # Already aggregated input
            integrated_inputs = inputs  # Shape: [n_features, n_batches]
        else:
            raise ValueError("a_in must have 2 or 3 dimensions.")

        # Process inputs over timesteps
        for step in range(self.T):
            cur1 = torch.mm(self.w1, integrated_inputs)  # Replace fc1
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = torch.mm(self.w2, spk1)  # Replace fc2
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = torch.mm(self.w3, spk2)  # Replace fc3
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)


        # Return aggregated spike counts or final membrane potential
        return mem3.T

    def fit(self, inputs, labels, beta, gamma):
        """
        Run one sweep of the ADMM sub-problems, then score the model.

        ``inputs`` is stored as ``self.a0`` without aggregation; the helper
        methods sum a 3-D tensor over its leading axis where they need to.
        The update order is significant: each step uses the latest values of
        the variables updated before it.
        :param inputs: Training inputs, 2-D or 3-D.
        :param labels: Label tensor, shape [n_outputs, n_batches].
        :param beta: Penalty weight for the z terms.
        :param gamma: Penalty weight for the a terms.
        :return: Loss value and accuracy from ``evaluate`` on the same data.
        """
        labels = labels.to(self.device)
        self.a0 = inputs.to(self.device)
        penalties = (beta, gamma)

        # Layer 1: weights, then activation, then pre-activation.
        self.w1 = self._weight_update(self.z1, self.a0)
        self.a1 = self._activation_update(self.w2, self.z2, self.z1, *penalties)
        self.z1 = self._argminz(self.a1, self.w1, self.a0, *penalties)

        # Layer 2: same three steps, fed by the fresh a1.
        self.w2 = self._weight_update(self.z2, self.a1)
        self.a2 = self._activation_update(self.w3, self.z3, self.z2, *penalties)
        self.z2 = self._argminz(self.a2, self.w2, self.a1, *penalties)

        # Output layer: weights, output matrix, then the multiplier.
        self.w3 = self._weight_update(self.z3, self.a2)
        self.z3 = self._argminlastz(labels, self.lambda_larange, self.w3, self.a2, beta)
        self.lambda_larange = self._lambda_update(self.z3, self.w3, self.a2, beta)

        # Score on the data that was just fitted.
        loss, accuracy = self.evaluate(inputs, labels)
        return loss, accuracy

    def evaluate(self, inputs, labels, isCategories=True):
        """
        Calculate the mean-squared-error loss and, for classification, the accuracy.

        :param inputs: Tensor accepted by ``feed_forward`` (spike train
            [timesteps, n_inputs, n_batches] or aggregated [n_inputs, n_batches]).
        :param labels: One-hot label tensor, shape [n_outputs, n_batches], or a
            1-D tensor of class indices, shape [n_batches].
        :param isCategories: Whether the task is classification. When False the
            loss is returned in place of the accuracy.
        :return: Tuple ``(loss, accuracy)`` of scalar tensors.
        """
        inputs = inputs.to(self.device)
        labels = labels.to(self.device)

        # feed_forward returns [n_batches, n_outputs]; flip it so that columns
        # are samples, matching the label layout.
        forward = self.feed_forward(inputs).T

        if labels.ndim == 1:
            # Class indices: expand to one-hot columns so the MSE compares
            # like with like instead of broadcasting raw indices over every
            # output row.
            classes = torch.arange(forward.shape[0], device=labels.device).unsqueeze(1)
            targets = (classes == labels.unsqueeze(0)).to(forward.dtype)
        else:
            targets = labels

        # Compute loss
        loss = torch.mean((forward - targets).pow(2))

        # Compute accuracy (if applicable)
        if isCategories:
            true_classes = labels if labels.ndim == 1 else torch.argmax(labels, dim=0)
            accuracy = (true_classes == torch.argmax(forward, dim=0)).float().mean()
        else:
            accuracy = loss

        return loss, accuracy



    def warming(self, inputs, labels, epochs, beta, gamma):
        """
        Warm up the ADMM variables by repeating the sub-problem updates
        without updating lambda.

        After each weight update the new matrix is also copied into the
        corresponding ``fc`` layer.

        :param inputs: Spike train tensor, shape [timesteps, n_inputs, n_batches],
            or an already aggregated tensor, shape [n_inputs, n_batches].
        :param labels: Label tensor, shape [n_outputs, n_batches].
        :param epochs: Number of warming epochs.
        :param beta: Penalty parameter forwarded to the sub-problem solvers.
        :param gamma: Penalty parameter forwarded to the sub-problem solvers.
        """
        # Only reduce the time axis when there is one; a 2-D tensor is taken as
        # already aggregated, the same convention feed_forward uses.
        aggregated = inputs.sum(dim=0) if inputs.dim() == 3 else inputs
        self.a0 = aggregated.to(self.device)  # Shape: [n_inputs, n_batches]
        labels = labels.to(self.device)

        for _ in range(epochs):
            # Input layer
            self.w1 = self._weight_update(self.z1, self.a0)
            self.fc1.weight.data = self.w1
            self.a1 = self._activation_update(self.w2, self.z2, self.z1, beta, gamma)
            self.z1 = self._argminz(self.a1, self.w1, self.a0, beta, gamma)

            # Hidden layer
            self.w2 = self._weight_update(self.z2, self.a1)
            self.fc2.weight.data = self.w2
            self.a2 = self._activation_update(self.w3, self.z3, self.z2, beta, gamma)
            self.z2 = self._argminz(self.a2, self.w2, self.a1, beta, gamma)

            # Output layer (the multiplier is read but deliberately left untouched)
            self.w3 = self._weight_update(self.z3, self.a2)
            self.fc3.weight.data = self.w3
            self.z3 = self._argminlastz(labels, self.lambda_larange, self.w3, self.a2, beta)


    def drawcurve(self, train_, valid_, id, legend_1, legend_2):
        """
        Plot a training curve and a validation curve on one figure.

        :param train_: Training metric values; flattened to 1-D before plotting.
        :param valid_: Validation metric values; flattened to 1-D before plotting.
        :param id: Figure identifier handed to ``plt.figure``.
        :param legend_1: Legend label for the training curve.
        :param legend_2: Legend label for the validation curve.
        :return: Always 0.
        """
        # Convert both series up front so a bad input fails before any figure is touched.
        curves = [
            (np.array(series).flatten(), caption)
            for series, caption in ((train_, legend_1), (valid_, legend_2))
        ]

        plt.figure(id)
        for values, caption in curves:
            plt.plot(values, label=caption)

        plt.ylim(bottom=0)  # Anchor the y-axis at zero
        plt.legend(loc='upper left')

        # Redraw and yield briefly to the GUI loop so the figure refreshes.
        plt.draw()
        plt.pause(0.001)
        return 0

In [None]:
def rate_coding(data, timesteps):
    """
    Converts static input data to spike trains using rate coding.

    Every element of ``data`` is used as the probability of emitting a spike
    at each timestep, so values are expected to lie in [0, 1]: an element
    >= 1 spikes at every step and an element <= 0 never spikes.

    :param data: Input tensor of any shape holding firing probabilities.
    :param timesteps: Number of timesteps for spiking simulation.
    :return: Float spike train tensor of shape (timesteps, *data.shape) with
        1.0 where a spike occurred and 0.0 elsewhere, on ``data``'s device.
    """
    # Draw the uniform noise directly on the target device: this avoids
    # building a (timesteps x data) tensor on the host and copying it over.
    noise = torch.rand((timesteps, *data.shape), device=data.device)
    return (noise < data.unsqueeze(0)).float()


In [None]:
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load MNIST. The ToTensor transform only applies when samples are fetched
# through indexing / a DataLoader; the raw `.data` attribute used below stays
# uint8 in 0..255 and must be scaled by hand.
mnist = datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())

# Training data as [n_features, n_samples], scaled to [0, 1] so that each pixel
# is a valid firing probability for rate_coding (unscaled, every nonzero pixel
# would spike at every timestep).
train_data = mnist.data.view(-1, 28 * 28).T.float().div(255.0).to(device)
train_labels = F.one_hot(mnist.targets, num_classes=10).T.float().to(device)  # One-hot encode labels

# Load validation data
mnist_valid = datasets.MNIST('./data', train=False, transform=transforms.ToTensor())

# Validation data, same layout and scaling as the training data
valid_data = mnist_valid.data.view(-1, 28 * 28).T.float().div(255.0).to(device)
valid_labels = F.one_hot(mnist_valid.targets, num_classes=10).T.float().to(device)  # One-hot encode labels


# Convert the entire dataset into spike trains
timestep = 20  # Number of timesteps
spike_train_data = rate_coding(train_data, timestep)
spike_valid_data = rate_coding(valid_data, timestep)

In [None]:
# Parameters
# NOTE(review): the spike trains are summed over time before they reach the
# weights, so the effective feature count looks like 28 * 28 rather than
# timestep * 28 * 28; the larger value only inflates the initial buffers.
# Confirm before changing.
n_inputs = timestep* 28 * 28  # MNIST image shape 28x28
n_outputs = 10      # MNIST classes from 0-9 digits
n_hiddens = 512     # Number of neurons in the hidden layer
n_batches = train_data.shape[1]  # Number of samples for training
train_epochs = 100
warm_epochs = 10
beta = 12
gamma = 0.5
delta = 0.99
theta = 0.7

# Initialize the model
model = ADMM_NN(n_inputs, n_hiddens, n_outputs, n_batches, delta, theta, timestep)

# Warm-up sweeps (no multiplier update) before the real ADMM iterations
model.warming(spike_train_data, train_labels, warm_epochs, beta, gamma)

# Per-epoch metric history, e.g. for plotting with model.drawcurve
list_loss_train = []
list_loss_valid = []
list_accuracy_train = []
list_accuracy_valid = []

# Training loop: one full-batch ADMM sweep per epoch
for epoch in range(train_epochs):
    print(f"Epoch {epoch + 1}/{train_epochs}")

    # Training phase
    loss_train, accuracy_train = model.fit(spike_train_data, train_labels, beta, gamma)
    epoch_accuracy_train = accuracy_train.item()

    # Validation phase
    loss_valid, accuracy_valid = model.evaluate(spike_valid_data, valid_labels)
    epoch_accuracy_valid = accuracy_valid.item()

    # Record the metrics as plain Python floats so the history holds no
    # device tensors.
    list_loss_train.append(loss_train.item())
    list_loss_valid.append(loss_valid.item())
    list_accuracy_train.append(epoch_accuracy_train)
    list_accuracy_valid.append(epoch_accuracy_valid)

    # Print epoch metrics
    print(f"Training Accuracy: {epoch_accuracy_train:.4f}")
    print(f" Validation Accuracy: {epoch_accuracy_valid:.4f}")

Epoch 1/100
Training Accuracy: 0.1795
 Validation Accuracy: 0.1810
Epoch 2/100
Training Accuracy: 0.1781
 Validation Accuracy: 0.1811
Epoch 3/100
Training Accuracy: 0.1793
 Validation Accuracy: 0.1723
Epoch 4/100
Training Accuracy: 0.1798
 Validation Accuracy: 0.1707
Epoch 5/100
Training Accuracy: 0.1803
 Validation Accuracy: 0.1786
Epoch 6/100
Training Accuracy: 0.1807
 Validation Accuracy: 0.1801
Epoch 7/100
Training Accuracy: 0.1873
 Validation Accuracy: 0.1814
Epoch 8/100
Training Accuracy: 0.1919
 Validation Accuracy: 0.1864
Epoch 9/100
Training Accuracy: 0.1959
 Validation Accuracy: 0.1956
Epoch 10/100
Training Accuracy: 0.2040
 Validation Accuracy: 0.1981
Epoch 11/100
Training Accuracy: 0.2117
 Validation Accuracy: 0.2042
Epoch 12/100
Training Accuracy: 0.2189
 Validation Accuracy: 0.2134
Epoch 13/100
Training Accuracy: 0.2302
 Validation Accuracy: 0.2236
Epoch 14/100
Training Accuracy: 0.2410
 Validation Accuracy: 0.2361
Epoch 15/100
Training Accuracy: 0.2524
 Validation Accura