<a href="https://colab.research.google.com/github/ergysmedaunipd/thesis/blob/main/ThesisUnipdSNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tonic
!pip install snntorch
!pip install psutil

Collecting tonic
  Downloading tonic-1.5.0-py3-none-any.whl.metadata (5.4 kB)
Collecting importRosbag>=1.0.4 (from tonic)
  Downloading importRosbag-1.0.4-py3-none-any.whl.metadata (4.3 kB)
Collecting pbr (from tonic)
  Downloading pbr-6.1.0-py2.py3-none-any.whl.metadata (3.4 kB)
Collecting expelliarmus (from tonic)
  Downloading expelliarmus-1.1.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading tonic-1.5.0-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.6/116.6 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading importRosbag-1.0.4-py3-none-any.whl (28 kB)
Downloading expelliarmus-1.1.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pbr-6.1.0-py2.py3-none-any.

In [133]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import time
import snntorch as snn

from typing import List


class ADMM_SNN:

    """ Class for ADMM Neural Network. """

    def __init__(self, n_samples: int, n_timesteps: int, input_dim: int, hidden_dims: List[int], n_outputs: int, rho: float, deltas: torch.Tensor, thetas: torch.Tensor):
        self.device = "cpu"

        # Define hyperparameters:
        # - thetas = Thresholds (can be all the same or different for each neuron)
        # - deltas = Decay factors (can be all the same or different for each neuron)
        # - roh = Penalty parameter. All the \alpha_{l,t} = \beta_{l,t} = \rho/2

        self.rho = rho
        self.deltas = deltas
        self.thetas = thetas

        self.L = len(hidden_dims)
        self.T = n_timesteps

        # Define a_0_t, which will be the input to the first layer
        self.a0 = torch.zeros(
            (n_timesteps, n_samples, input_dim)).to(self.device)

        # === Initialize W_l ===
        self.W = []

        # Now define the weights for each layer
        for i, hidden_dim in enumerate(hidden_dims):
            if i == 0:
                self.W.append(torch.zeros(
                    (hidden_dim, input_dim)).to(self.device))
            else:
                self.W.append(torch.zeros(
                    (hidden_dim, hidden_dims[i-1])).to(self.device))

        # === Initialize z_l ===
        self.z = []

        for i, hidden_dim in enumerate(hidden_dims):
            self.z.append(torch.zeros(
                (n_timesteps, n_samples, hidden_dim)).to(self.device))

        # === Initialize a_l ===
        self.a = []

        for i, hidden_dim in enumerate(hidden_dims):
            self.a.append(torch.zeros(
                (n_timesteps, n_samples, hidden_dim)).to(self.device))

        # Check how to make initialization?
        self.lambda_lagrange = torch.zeros(
            (n_samples, n_outputs)).to(self.device)

        print("Shapes in __init__:")
        for i, W in enumerate(self.W):
            print(f"W[{i}] shape: {W.shape}")
        for i, z in enumerate(self.z):
            print(f"z[{i}] shape: {z.shape}")
        for i, a in enumerate(self.a):
            print(f"a[{i}] shape: {a.shape}")

    def _heaviside(self, x):
        # Implement the Heaviside function that compares each element in x to self.thetas
        # Returns 1 if the element in x exceeds the threshold in self.thetas, otherwise 0
        return (x >= self.thetas).float()

    # ============ W_{l} update functions ============
    def _weight_update(self, layer_output, activation_input):
        # Implements the weight update for layers 1 to L-1 as in line 2 of Algorithm 2 (Equation 4)
        # where α_{l,t} = ρ/2

        T = self.T
        rho_half = self.rho / 2
        numerator = sum(rho_half * layer_output[t] @ activation_input[t].T for t in range(T))
        denominator = sum(rho_half * (activation_input[t] @ activation_input[t].T) for t in range(T))

        return numerator @ torch.linalg.pinv(denominator)

    def _weight_update_L(self, layer_output, activation_input):
        # Implements the weight update for the output layer L as in line 10 of Algorithm 2 (Equation 6)
        # where α_{L,t} = ρ/2 and includes the Lagrange multiplier λ

        T = self.T
        rho_half = self.rho / 2
        lambda_term = -self.lambda_lagrange / self.rho

        numerator = lambda_term @ activation_input[-1].T + sum(rho_half * layer_output[t] @ activation_input[t].T for t in range(T))
        denominator = sum(rho_half * (activation_input[t] @ activation_input[t].T) for t in range(T))

        return numerator @ torch.linalg.pinv(denominator)

    # ============ z_{l,t} update functions ============
    def _z_update(self, q, r, alpha_l, delta, a, t):
        # Implements the z_{l,t} update for l = 1, ..., L - 1 and t < T (Equation 14)
        # q and r are precomputed terms as defined in the equations

        theta_a = self.thetas * a
        delta_term = alpha_l[t + 1] * delta * (r + theta_a) if t < self.T - 1 else 0
        numerator = alpha_l[t] * q + delta_term
        denominator = alpha_l[t] + delta**2 * alpha_l[t + 1] if t < self.T - 1 else alpha_l[t]

        return numerator / denominator

    def _z_update_T(self, q, alpha_l, t):
        # Implements the z_{l,T} update for l = 1, ..., L - 1 and t = T (Equation 14 with t = T)

        return q  # The simplification for the last time step where t = T

    def _z_update_L(self, s, r, alpha_l, delta, a, t):
        # Implements the z_{L,t} update for l = L and t < T (Equation 16 with t < T)

        theta_a = self.thetas * a
        delta_term = alpha_l[t + 1] * delta * (r + theta_a) if t < self.T - 1 else 0
        numerator = alpha_l[t] * s + delta_term
        denominator = alpha_l[t] + delta**2 * alpha_l[t + 1] if t < self.T - 1 else alpha_l[t]

        return numerator / denominator

    def _z_update_L_T(self, s, y, alpha_l, rho, lambda_lagrange, t):
        # Implements the z_{L,T} update for l = L and t = T (Equation 16 with t = T)

        lambda_term = (y - lambda_lagrange / rho) if t == self.T - 1 else 0
        numerator = alpha_l[t] * s + lambda_term
        denominator = alpha_l[t] + 1 if t == self.T - 1 else alpha_l[t]

        return numerator / denominator

    def check_entries(self, z, cost_function):
        # Implements Algorithm 1 from the document to check and modify entries in z if needed
        for n in range(z.shape[0]):  # Assuming z has shape (N_l, M) where N_l is number of neurons
            for m in range(z.shape[1]):  # M is the number of samples
                if z[n, m] > self.thetas and cost_function(z[n, m]) > cost_function(self.thetas):
                    z[n, m] = self.thetas
                elif z[n, m] <= self.thetas and cost_function(z[n, m] + 1e-5) < cost_function(z[n, m]):
                    z[n, m] = self.thetas + 1e-5
        return z

    # ============ a_{l,t} update functions ============
    def _activation_update(self, wl, wl_next, alpha_l, beta_l, delta, theta, z, a_prev, t):
        # Implement the Activation update for l=1,...,L-2, t=1,...,T-1 (line 4 of Algorithm 2)

        rho_theta_I = (theta**2) * torch.eye(wl.size(0)).to(self.device)  # Term (θ^2) * I
        WtW = wl_next.T @ wl_next
        term1 = -theta * delta * (wl @ a_prev)  # Term with -θ * delta * wl
        term2 = WtW + rho_theta_I + beta_l * torch.eye(wl.size(0)).to(self.device)  # Adding W_{l+1}^T W_{l+1} + (θ^2) * I
        term3 = wl.T @ z[t + 1]  # Term with W_{l+1}^T * z_{l+1, t+1}

        # Compute activation update
        activation_update = torch.inverse(term2) @ (term1 + term3)

        return activation_update

    def _activation_update_T(self, wl, wl_next, alpha_l, beta_l, z, theta, a_prev, t):
        # Implement the Activation update for l=1,...,L-2, t=T (line 7 of Algorithm 2)

        term1 = wl.T @ (z[-1] - theta)  # Term W_{l+1}^T * z_{l+1, T}
        WtW = wl_next.T @ wl_next
        term2 = WtW + beta_l * torch.eye(wl.size(0)).to(self.device)  # Adding W_{l+1}^T W_{l+1}

        # Compute activation update for last time step
        activation_update_T = torch.inverse(term2) @ (term1 + wl.T)

        return activation_update_T

    def _activation_update_Lminus1(self, wl, wl_next, alpha_l, beta_l, delta, theta, z, a_prev, t):
        # Implement the Activation update for l=L-1, t=1,...,T-1 (line 4 again, check Indicator functions)

        rho_theta_I = (theta**2) * torch.eye(wl.size(0)).to(self.device)  # Term (θ^2) * I
        WtW = wl_next.T @ wl_next
        term1 = -theta * delta * (wl @ a_prev)  # Term with -θ * delta * wl
        term2 = WtW + rho_theta_I + beta_l * torch.eye(wl.size(0)).to(self.device)  # Adding W_{l+1}^T W_{l+1} + (θ^2) * I
        term3 = wl.T @ z[t + 1]  # Term with W_{l+1}^T * z_{l+1, t+1}

        # Compute activation update for L-1 layer
        activation_update_Lminus1 = torch.inverse(term2) @ (term1 + term3)

        return activation_update_Lminus1

    def _activation_update_Lminus1_T(self, wl, wl_next, alpha_l, beta_l, z, theta, a_prev, t):
        # Implement the Activation update for l=L-1, t=T (line 7 again, check Indicator functions)

        term1 = wl.T @ (z[-1] - theta)  # Term W_{l+1}^T * z_{l+1, T}
        WtW = wl_next.T @ wl_next
        term2 = WtW + beta_l * torch.eye(wl.size(0)).to(self.device)  # Adding W_{l+1}^T W_{l+1}

        # Compute activation update for last layer at T
        activation_update_Lminus1_T = torch.inverse(term2) @ (term1 + wl.T)

        return activation_update_Lminus1_T

    # ============ lagrange multiplier update ============
    def _lambda_update(self, zL_T, delta_zL_T_minus_1, WL, aL_minus_1_T, rho):
        # Implement the update of the Lagrange multiplier lambda (Line 15 of Algorithm 2)

        # Calculate the term inside the parentheses: zL,T - delta * zL,T-1 - WL * aL-1,T
        term = zL_T - delta_zL_T_minus_1 - WL @ aL_minus_1_T

        # Update lambda using the formula in Line 15
        lambda_update = self.lambda_lagrange + rho * term

        return lambda_update

    def feed_forward(self, inputs):
        # Ensure inputs are on the correct device
        inputs = inputs.to(self.device)

        # Initialize hidden states for each layer
        mem_states = [snn.Leaky(beta=self.deltas[i]).init_leaky() for i in range(self.L)]

        # Record outputs for the final layer
        mem_rec = []

        # Run the simulation for each timestep
        for t in range(self.T):
            current_input = inputs[t] if t < inputs.size(0) else torch.zeros_like(inputs[0])

            # Process through each layer
            for l in range(self.L):
                # Calculate post-synaptic current for the current layer
                if l == 0:
                    post_syn_current = current_input @ self.W[l].T
                else:
                    post_syn_current = self.a[l - 1][t] @ self.W[l].T

                # Apply snn.Leaky (LIF neuron model) to the post-synaptic current
                lif_neuron = snn.Leaky(beta=self.deltas[l], threshold=self.thetas[l],
                                       reset_mechanism='subtract' if l < self.L - 1 else 'none').to(self.device)
                mem, mem_states[l] = lif_neuron(post_syn_current, mem_states[l])

                # Apply Heaviside function to determine if there is a spike
                spike = self._heaviside(mem)

                # Store membrane potential and spike
                self.z[l][t] = mem  # Membrane potential
                self.a[l][t] = spike  # Binary spike output

            # Record membrane potential of the last layer for the final timestep
            if t == self.T - 1:
                mem_rec.append(self.z[-1][t])

        # Return the membrane potentials of the last layer at the final timestep
        return mem_rec[-1]

    def fit(self, inputs):
        # This function updates the optimization variables, given an input batch of data samples.

        # Carry out the updates following algorithm (2)

        # Here is a skeleton of the implementation:
        for l in range(1, self.L):
            # Update self.W[l] using the function _weight_update
            self.W[l] = self._weight_update(self.z[l], self.a[l - 1])

            for t in range(1, self.T):
                if l < self.L - 1:
                    self.a[l][t] = self._activation_update(self.W[l], self.W[l + 1], self.rho / 2, self.rho / 2, self.deltas[l], self.thetas[l], self.z[l], self.a[l - 1][t], t)
                else:
                    self.a[l][t] = self._activation_update_Lminus1(self.W[l], self.W[l + 1], self.rho / 2, self.rho / 2, self.deltas[l], self.thetas[l], self.z[l], self.a[l - 1][t], t)

                # update selfz[l][t] using the function _z_update and check_entries
                self.z[l][t] = self.check_entries(self._z_update(self.z[l][t], self.a[l][t], self.rho / 2, self.deltas[l], self.a[l][t], t), self._heaviside)

            if l < self.L-1:
                self.a[l][self.T] = self._activation_update_T(self.W[l], self.W[l + 1], self.rho / 2, self.rho / 2, self.z[l], self.thetas[l], self.a[l - 1][self.T - 1], self.T)
            else:
                # update self.a[l][T] using the function _activation_update_Lminus1_T
                self.a[l][self.T] = self._activation_update_Lminus1_T(self.W[l], self.W[l + 1], self.rho / 2, self.rho / 2, self.z[l], self.thetas[l], self.a[l - 1][self.T - 1], self.T)


            # update self.z[l][T] using the function _z_update_T and check_entries
            self.z[l][self.T] = self.check_entries(self._z_update_T(self.z[l][self.T], self.rho / 2, self.T), self._heaviside)


        # ----- Update the last layer -----
        # Update self.W[L] using the function _weight_update_L
        self.W[self.L] = self._weight_update_L(self.z[self.L], self.a[self.L - 1])
        for t in range(1, self.T):
            # update self.z[L][t] using the function _z_update_L
            self.z[self.L][t] = self._z_update_L(self.z[self.L][t], self.a[self.L][t], self.rho / 2, self.deltas[self.L], self.a[self.L][t], t)

        # ----- Update the last layer at time T -----
        # update self.z[L][T] using the function _z_update_L_T
        self.z[self.L][self.T] = self._z_update_L_T(self.z[self.L][self.T], inputs, self.rho, self.lambda_lagrange, self.T)

        # Update the lagrange multiplier using the function _lambda_update
        self.lambda_lagrange = self._lambda_update(self.z[self.L][self.T], self.deltas[self.L] * self.z[self.L][self.T - 1], self.W[self.L], self.a[self.L - 1][self.T - 1], self.rho)
        return

    def evaluate(self, inputs):
        """
        Standard evaluation phase.

        Args:
            inputs (torch.Tensor): Input data batch to be evaluated.

        Returns:
            torch.Tensor: Output membrane potentials or activations of the last layer.
        """
        # Ensure inputs are on the correct device
        inputs = inputs.to(self.device)

        # Initialize hidden states for each layer
        mem_states = [snn.Leaky(beta=self.deltas[i]).init_leaky() for i in range(self.L)]

        # Record the output of the final layer for each time step
        mem_rec = []

        # Run the simulation for each timestep
        for t in range(self.T):
            current_input = inputs[t] if t < inputs.size(0) else torch.zeros_like(inputs[0])

            # Process through each layer
            for l in range(self.L):
                # Calculate post-synaptic current for the current layer
                if l == 0:
                    post_syn_current = current_input @ self.W[l].T
                else:
                    post_syn_current = self.a[l - 1][t] @ self.W[l].T

                # Apply snn.Leaky (LIF neuron model) to the post-synaptic current
                lif_neuron = snn.Leaky(beta=self.deltas[l], threshold=self.thetas[l],
                                       reset_mechanism='subtract' if l < self.L - 1 else 'none').to(self.device)
                mem, mem_states[l] = lif_neuron(post_syn_current, mem_states[l])

                # Apply Heaviside function to determine if there is a spike
                spike = self._heaviside(mem)

                # Store membrane potential and spike
                self.z[l][t] = mem  # Membrane potential
                self.a[l][t] = spike  # Binary spike output

            # Record the membrane potential of the last layer at each time step
            if l == self.L - 1:
                mem_rec.append(self.z[l][t])

        # Return the final membrane potentials (or spikes) of the last layer at each time step
        return torch.stack(mem_rec, dim=0)

    def warming(self, inputs, labels, epochs, beta, gamma):
        """
        Warming phase for ADMM SNN by minimizing sub-problems without updating lambda.

        Args:
            inputs (torch.Tensor): Input data samples for training.
            labels (torch.Tensor): Labels for training data samples.
            epochs (int): Number of warming-up epochs.
            beta (float): Value of beta for activation update.
            gamma (float): Value of gamma for membrane potential update.

        """
        self.a0 = inputs  # Set initial input layer activations

        for epoch in range(epochs):
            print(f"------ Warming Epoch: {epoch} ------")

            # Iterate over each layer
            for l in range(1, self.L):
                # Update weights W_l using the weight update function
                self.W[l] = self._weight_update(self.z[l], self.a[l - 1])

                # Update activations (a) and membrane potentials (z) for each time step
                for t in range(1, self.T):
                    if l < self.L - 1:
                        # Intermediate layer activation update
                        self.a[l][t] = self._activation_update(self.W[l], self.W[l + 1], beta, gamma, self.deltas[l], self.thetas[l], self.z[l], self.a[l - 1][t], t)
                    else:
                        # Last hidden layer activation update
                        self.a[l][t] = self._activation_update_Lminus1(self.W[l], self.W[l + 1], beta, gamma, self.deltas[l], self.thetas[l], self.z[l], self.a[l - 1][t], t)

                    # Update z for current layer and time step
                    self.z[l][t] = self._z_update(self.z[l][t], self.a[l][t], beta, self.deltas[l], self.a[l][t], t)

                # Update activations and membrane potential for the last time step T
                if l < self.L - 1:
                    self.a[l][self.T] = self._activation_update_T(self.W[l], self.W[l + 1], beta, gamma, self.z[l], self.thetas[l], self.a[l - 1][self.T - 1], self.T)
                else:
                    self.a[l][self.T] = self._activation_update_Lminus1_T(self.W[l], self.W[l + 1], beta, gamma, self.z[l], self.thetas[l], self.a[l - 1][self.T - 1], self.T)

                # Update z for the last time step T
                self.z[l][self.T] = self._z_update_T(self.z[l][self.T], beta, self.T)

            # ----- Warming up the last layer -----
            # Update weights for the output layer
            self.W[self.L] = self._weight_update_L(self.z[self.L], self.a[self.L - 1])

            for t in range(1, self.T):
                # Update z in the last layer for each time step
                self.z[self.L][t] = self._z_update_L(self.z[self.L][t], self.a[self.L][t], beta, self.deltas[self.L], self.a[self.L][t], t)

            # Update z for the last layer at time T
            self.z[self.L][self.T] = self._z_update_L_T(self.z[self.L][self.T], labels, beta, self.lambda_lagrange, self.T)






In [104]:
import torch
import torch.nn as nn
import torchvision
import tonic
import tonic.transforms as transforms
from torch.utils.data import DataLoader
from tonic import DiskCachedDataset
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils


# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transformations
sensor_size = tonic.datasets.NMNIST.sensor_size
frame_transform = transforms.Compose([
    transforms.Denoise(filter_time=10000),
    transforms.ToFrame(sensor_size=sensor_size, time_window=1000)
])

# Load datasets
trainset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=True)
testset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=False)

# Cache datasets
cached_trainset = DiskCachedDataset(trainset, cache_path='./cache/nmnist/train')
cached_testset = DiskCachedDataset(testset, cache_path='./cache/nmnist/test')

# DataLoaders
batch_size = 128
trainloader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(), shuffle=True)
testloader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors())



Shapes in __init__:
W[0] shape: torch.Size([100, 1156])
W[1] shape: torch.Size([50, 100])
z[0] shape: torch.Size([300, 128, 100])
z[1] shape: torch.Size([300, 128, 50])
a[0] shape: torch.Size([300, 128, 100])
a[1] shape: torch.Size([300, 128, 50])
Sample data shape: torch.Size([128, 309, 2, 34, 34])
Sample target shape: torch.Size([128])


Shapes in __init__:
W[0] shape: torch.Size([100, 1156])
W[1] shape: torch.Size([50, 100])
z[0] shape: torch.Size([300, 128, 100])
z[1] shape: torch.Size([300, 128, 50])
a[0] shape: torch.Size([300, 128, 100])
a[1] shape: torch.Size([300, 128, 50])
Sample data shape: torch.Size([128, 311, 2, 34, 34])
Sample target shape: torch.Size([128])


In [134]:

# Define the ADMM SNN model
n_samples = batch_size
n_timesteps = 300  # Adjust based on data
input_dim = sensor_size[0] * sensor_size[1]
hidden_dims = [128, 64]  # Example hidden layer dimensions
n_outputs = 10
rho = 0.5
deltas = torch.tensor([0.9, 0.8, 0.7])  # Adjust for each layer
thetas = torch.tensor([1.0, 1.0, 1.0])  # Adjust for each layer

model = ADMM_SNN(n_samples, n_timesteps, input_dim, hidden_dims, n_outputs, rho, deltas, thetas)


sample_data, sample_target = next(iter(trainloader))
print("Sample data shape:", sample_data.shape)
print("Sample target shape:", sample_target.shape)

# Warming phase parameters
warming_epochs = 5
beta = 0.5
gamma = 0.5

# Perform the warming phase
for i, (data, targets) in enumerate(trainloader):
    data, targets = data.to(device), targets.to(device)

    # Flatten spatial dimensions
    data = data.view(data.size(0), data.size(1), -1)
    if data.size(1) > 300:
        data = data[:, :300, :]

    model.warming(data, targets, epochs=warming_epochs, beta=beta, gamma=gamma)
    print(f"Warming completed on batch {i+1}/{len(trainloader)}")
    break  # Run warming on only one batch for simplicity; you could extend it if needed




Shapes in __init__:
W[0] shape: torch.Size([128, 1156])
W[1] shape: torch.Size([64, 128])
z[0] shape: torch.Size([300, 128, 128])
z[1] shape: torch.Size([300, 128, 64])
a[0] shape: torch.Size([300, 128, 128])
a[1] shape: torch.Size([300, 128, 64])
Sample data shape: torch.Size([128, 310, 2, 34, 34])
Sample target shape: torch.Size([128])
------ Warming Epoch: 0 ------


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x64 and 128x128)

In [None]:
# Training parameters
num_epochs = 1

# Main ADMM training loop
for epoch in range(num_epochs):
    for i, (data, targets) in enumerate(trainloader):
        data, targets = data.to(device), targets.to(device)

        # Flatten spatial dimensions
        data = data.view(data.size(0), data.size(1), -1)
        if data.size(1) > 300:
            data = data[:, :300, :]

        # Perform ADMM optimization using the `fit` method
        model.fit(data)

        # Compute loss (for monitoring purposes only)
        spk_rec = model.evaluate(data)
        loss_val = SF.mse_count_loss(spk_rec, targets)

        if i % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(trainloader)}], Loss: {loss_val.item():.4f}")

# Evaluation
total = 0
correct = 0
for data, targets in testloader:
    data, targets = data.to(device), targets.to(device)

    # Flatten spatial dimensions
    data = data.view(data.size(0), data.size(1), -1)
    if data.size(1) > 300:
        data = data[:, :300, :]

    # Forward pass using `evaluate` method of ADMM_SNN
    spk_rec = model.evaluate(data)
    _, predicted = spk_rec.sum(dim=0).max(1)
    total += targets.size(0)
    correct += (predicted == targets).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")