<a href="https://colab.research.google.com/github/ergysmedaunipd/thesis/blob/main/ThesisUnipdSNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tonic
!pip install snntorch
!pip install psutil

Collecting tonic
  Downloading tonic-1.5.0-py3-none-any.whl.metadata (5.4 kB)
Collecting importRosbag>=1.0.4 (from tonic)
  Downloading importRosbag-1.0.4-py3-none-any.whl.metadata (4.3 kB)
Collecting pbr (from tonic)
  Downloading pbr-6.1.0-py2.py3-none-any.whl.metadata (3.4 kB)
Collecting expelliarmus (from tonic)
  Downloading expelliarmus-1.1.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading tonic-1.5.0-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.6/116.6 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading importRosbag-1.0.4-py3-none-any.whl (28 kB)
Downloading expelliarmus-1.1.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m938.3 kB/s[0m eta [36m0:00:00[0m
[?25hDownloading pbr-6.1.0-py2.py3-none-an

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import time
import snntorch as snn
from snntorch import surrogate


from typing import List


class ADMM_SNN:
    """ Class for ADMM Neural Network. """

    def __init__(self, n_samples: int, n_timesteps: int, input_dim: int, hidden_dims: List[int], n_outputs: int, rho: float, delta: float, theta: float):
        """
        Initialize the ADMM-SNN model.

        Parameters:
        n_samples (int): Number of samples in the batch.
        n_timesteps (int): Number of timesteps.
        input_dim (int): Dimension of the input features.
        hidden_dims (List[int]): List of hidden layer dimensions.
        n_outputs (int): Number of output classes.
        rho (float): ADMM penalty parameter.
        delta (float): Decay factor for the intermediate variables (z).
        theta (float): Activation threshold for the spiking neurons.
        """
        # Define the device to run the model on (GPU if available)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Store model parameters
        self.n_samples = n_samples
        self.n_timesteps = n_timesteps
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.n_outputs = n_outputs


        # Loss function for training
        self.loss_fn = nn.CrossEntropyLoss()

        # Hyperparameters for ADMM-SNN
        self.rho = rho             # ADMM penalty parameter
        self.delta = delta         # Decay factor for z updates
        self.theta = theta         # Activation threshold for spiking neurons

        # Model architecture specifications
        self.L = len(hidden_dims) + 1  # Number of hidden layers
        self.T = n_timesteps       # Number of timesteps in each forward pass

        # Initialize initial activations (a0) as a tensor with zero values
        self.a_minus_one = torch.zeros( (n_timesteps, n_samples, input_dim)).to(self.device)

        # Initialize weights (W) for each layer in hidden_dims
        self.W = []
        # First hidden layer (input -> first hidden)
        self.W.append(torch.normal(0, 0.01, (hidden_dims[0], input_dim)).to(self.device))

        # Other hidden layers
        for i in range(1, len(hidden_dims)):
            self.W.append(torch.normal(0, 0.01, (hidden_dims[i], hidden_dims[i-1])).to(self.device))

        # Output layer
        self.W.append(torch.normal(0, 0.01, (n_outputs, hidden_dims[-1])).to(self.device))




        # Initialize intermediate variables (z) for each layer and timestep
        self.z = []
        # Hidden layers
        for hidden_dim in hidden_dims:
            self.z.append(torch.rand((n_timesteps, n_samples, hidden_dim)).to(self.device))
        # Output layer
        self.z.append(torch.rand((n_timesteps, n_samples, n_outputs)).to(self.device))




        # Initialize activations (a) for each layer and timestep (except output)
        self.a = []
        for hidden_dim in hidden_dims:
            self.a.append(torch.rand((n_timesteps, n_samples, hidden_dim)).to(self.device))



        # Initialize Lagrange multipliers (lambda) for the output layer
        self.lambda_lagrange = torch.ones((n_samples, n_outputs)).to(self.device)


    def __str__(self):
        model_str = "ADMM SNN Model Structure:\n"
        model_str += f" - rho: {self.rho}, delta: {self.delta}, theta: {self.theta}\n"
        model_str += f" - Number of timesteps: {self.T}\n"
        model_str += f" - Input dimension: {self.a_minus_one.size()}\n"
        model_str += f" - Hidden layers: {[w.shape for w in self.W]}\n"
        model_str += f" - Output dimension (Lagrange Multiplier): {self.lambda_lagrange.size()}\n"
        model_str += f" - self.L : {self.L }\n"
        model_str += f" - self.T : {self.T }\n"

        """Helper method to print shapes of initialized tensors"""
        print(f"\nModel Initialization Details:")
        print(f"Device: {self.device}")
        print(f"Previous activation shape: {self.a_minus_one.shape}")

        for i, w in enumerate(self.W):
            print(f"Weight layer {i} shape: {w.shape}")

        for i, z_layer in enumerate(self.z):
            print(f"z[{i}] shape: {z_layer.shape}")

        for i, a_layer in enumerate(self.a):
            print(f"a[{i}] shape: {a_layer.shape}")

        print(f"lambda shape: {self.lambda_lagrange.shape}")

        return model_str

    def _heaviside(self, z):
        """
        Implement the Heaviside function to calculate activations based on a threshold.
        This function returns 1 where z meets or exceeds the threshold and 0 otherwise.

        Parameters:
        z (torch.Tensor): The input tensor representing the intermediate variable z at layer l and time t.

        Returns:
        torch.Tensor: A tensor of the same shape as z with values 0 or 1, based on whether each element
                      in z is above or below the threshold self.theta.
        """
        # Apply the Heaviside step function:
        # Convert continuous values in z to binary values (0 or 1) based on self.theta
        # - Returns 1 where z >= theta (neuron spikes).
        # - Returns 0 where z < theta (neuron does not spike).
        return (z >= self.theta).float()


    # ============ W_{l} update functions ============
    def _weight_update(self, l: int) -> torch.Tensor:
        """
        Update weights for layers 1 to L-1 following equation (4) from the paper:
        W_l = (sum_t alpha_l,t * x_l,t * a_l-1,t^T)(sum_t alpha_l,t * a_l-1,t * a_l-1,t^T)^-1

        Args:
            l: Layer index
        Returns:
            Updated weight matrix for layer l
        """
        alpha = self.rho / 2  # Following paper's simplified alpha_l,t = rho/2
        z_l = self.z[l]  # [timesteps, batch, features_out]
        a_prev = self.a[l-1] if l > 0 else self.a_minus_one  # [timesteps, batch, features_in]

        n_timesteps, batch_size, n_features_out = z_l.shape
        _, _, n_features_in = a_prev.shape

        # Initialize accumulator tensors
        x_a_sum = torch.zeros((n_features_out, n_features_in), device=self.device)
        a_a_sum = torch.zeros((n_features_in, n_features_in), device=self.device)

        for t in range(n_timesteps):
            # Construct x_l,t according to paper definition
            if t == 0:
                x_l_t = z_l[0]  # First timestep
            else:
                # x_l,t = z_l,t - δz_l,t-1 + θa_l,t-1
                x_l_t = z_l[t] - self.delta * z_l[t-1] + self.theta * self.a[l][t-1]

            # Update sums using einsum for better efficiency
            x_a_sum += alpha * torch.einsum('bf,bp->fp', x_l_t, a_prev[t])
            a_a_sum += alpha * torch.einsum('bp,bq->pq', a_prev[t], a_prev[t])


        eps = 1e-6
        a_a_sum += eps * torch.eye(n_features_in, device=self.device)

        # Return W_l = x_a_sum @ (a_a_sum)^-1
        return x_a_sum @ torch.inverse(a_a_sum)

    def _weight_update_L(self, y: torch.Tensor) -> torch.Tensor:
        """
        Update weights for the final layer L following equation (6) from the paper:
        W_L = (-λ/2 * a_L-1,T^T + sum_t alpha_L,t * x_L,t * a_L-1,t^T)(sum_t alpha_L,t * a_L-1,t * a_L-1,t^T)^-1

        Args:
            y: Target values [batch, n_outputs]
        Returns:
            Updated weight matrix for output layer
        """
        alpha = self.rho / 2
        z_L = self.z[-1]  # Last layer's z values
        a_L_minus_1 = self.a[-2]  # Second to last layer's activations

        n_timesteps, batch_size, n_features_out = z_L.shape
        _, _, n_features_in = a_L_minus_1.shape

        # Calculate λ term from equation (6)
        lambda_term = -0.5 * torch.einsum('bf,bp->fp',
                                        self.lambda_lagrange,
                                        a_L_minus_1[-1])  # Use last timestep

        # Initialize accumulator tensors
        x_a_sum = torch.zeros((n_features_out, n_features_in), device=self.device)
        a_a_sum = torch.zeros((n_features_in, n_features_in), device=self.device)

        for t in range(n_timesteps):
            # Construct x_L,t according to paper definition for output layer
            if t == 0:
                x_L_t = z_L[0]
            else:
                # Note: output layer doesn't include θa term
                x_L_t = z_L[t] - self.delta * z_L[t-1]

            # Update sums
            x_a_sum += alpha * torch.einsum('bf,bp->fp', x_L_t, a_L_minus_1[t])
            a_a_sum += alpha * torch.einsum('bp,bq->pq', a_L_minus_1[t], a_L_minus_1[t])

        # Add small identity matrix for numerical stability
        eps = 1e-6
        a_a_sum += eps * torch.eye(n_features_in, device=self.device)

        # Return W_L = (lambda_term + x_a_sum) @ (a_a_sum)^-1
        return (lambda_term + x_a_sum) @ torch.inverse(a_a_sum)


    # ============ a_{l,t} update functions ============
    # Activation update for l=1,...,L-2, t=1,...,T-1 (Equation 8)
    def _calculate_u_w_v(self, l, t):
        """
        Calculate helper vectors u_l, v_l, w_l
        """
        z_l = self.z[l]
        a_l = self.a[l]
        a_l_minus_1 = self.a[l-1] if l > 0 else self.a_minus_one

        # Calculate u_l[t]
        if t == 0:
            u_l_t = z_l[0]
        else:
            u_l_t = z_l[t] - self.delta * z_l[t-1]

        # Calculate v_l[t]
        if t == 0:
            v_l_t = u_l_t
        else:
            v_l_t = u_l_t + self.theta * a_l[t-1]

        # Calculate w_l[t]
        projected_a = torch.matmul(a_l_minus_1[t], self.W[l].t())
        w_l_t = u_l_t - projected_a

        return u_l_t, w_l_t, v_l_t

    def _activation_update(self, l, t):
        """
        Update activations based on equation (8)
        """

        alpha = self.rho / 2
        beta = alpha

        # Get dimensions
        batch_size = self.z[l][t].shape[0]  # 256
        n_features = self.z[l][t].shape[1]  # 128 for layer 0
        W_next = self.W[l+1]  # [64, 128] for layer 0->1

        # Matrix terms
        I = torch.eye(n_features, device=self.device)
        term1 = (self.theta ** 2) * alpha * I
        term2 = alpha * W_next.t() @ W_next  # [128, 64] @ [64, 128] -> [128, 128]
        term3 = beta * I

        matrix_to_invert = term1 + term2 + term3

        # Calculate RHS terms
        _, w_next, _ = self._calculate_u_w_v(l, t+1)
        rhs_term1 = -self.theta * alpha * w_next  # [256, 128]

        _, _, v_next = self._calculate_u_w_v(l+1, t)  # v_next shape: [256, 64]
        # Fix: multiply in correct order
        rhs_term2 = alpha * torch.matmul(v_next, W_next)  # [256, 64] @ [64, 128] -> [256, 128]

        heaviside_term = self._heaviside(self.z[l][t])  # [256, 128]
        rhs_term3 = beta * heaviside_term

        # Combine RHS terms
        rhs = rhs_term1 + rhs_term2 + rhs_term3

        # Solve the system
        updated_a = torch.linalg.solve(matrix_to_invert, rhs.t()).t()

        return updated_a

    # Activation update for l=L-1, t=1,...,T-1 (Equation 10)
    def _activation_update_T(self, l):
        """
        Update activations for l=1,...,L-2 at t=T based on equation (9):
        a_l,T = (α_l+1,t*W_{l+1}^T*W_{l+1} + β_l,t*I)^{-1} *
                (α_l+1,t*W_{l+1}^T*v_{l+1,t} + β_l,t*h_l(z_l,t))
        """

        alpha = self.rho / 2
        beta = alpha

        # Get dimensions
        batch_size = self.z[l][self.T-1].shape[0]  # 256
        n_features = self.z[l][self.T-1].shape[1]  # current layer features
        W_next = self.W[l+1]

        # Calculate matrix to invert: α_l+1,t*W_{l+1}^T*W_{l+1} + β_l,t*I
        term1 = alpha * W_next.t() @ W_next
        term2 = beta * torch.eye(n_features, device=self.device)
        matrix_to_invert = term1 + term2

        # Calculate v_{l+1,t} for the first RHS term
        _, _, v_next = self._calculate_u_w_v(l+1, self.T-1)
        # Fix: multiply in correct order
        rhs_term1 = alpha * torch.matmul(v_next, W_next)  # [256, 64] @ [64, 128] -> [256, 128]

        # Calculate h_l(z_l,t) for second RHS term
        heaviside_term = self._heaviside(self.z[l][self.T-1])  # [256, 128]
        rhs_term2 = beta * heaviside_term

        # Combine RHS terms
        rhs = rhs_term1 + rhs_term2

        # Solve the system
        updated_a = torch.linalg.solve(matrix_to_invert, rhs.t()).t()

        return updated_a

    def _activation_update_Lminus1(self, t):
        """
        Update activations for l=L-1 based on equation (11):
        a_L-1,t = (α_L,t*W_L^T*W_L + β_L-1,t*I)^{-1} *
                  (W_L^T(α_L,t*u_L,t - λ/2*1(t=T)) + β_L-1,t*h_L-1(z_L-1,t - θ))
        """

        alpha = self.rho / 2
        beta = alpha
        l = self.L - 2  # L-1 in zero-based indexing
        # Get dimensions
        batch_size = self.z[l][t].shape[0]
        n_features = self.z[l][t].shape[1]
        W_L = self.W[l]

        # Calculate matrix to invert: α_L,t*W_L^T*W_L + β_L-1,t*I
        term1 = alpha * W_L.t() @ W_L
        term2 = beta * torch.eye(n_features, device=self.device)
        matrix_to_invert = term1 + term2

        # Calculate u_L,t
        u_L, _, _ = self._calculate_u_w_v(l, t)

        # Calculate first part of RHS: W_L^T(α_L,t*u_L,t)
        # Transpose u_L to match W_L dimensions
        rhs_term1 = alpha * (W_L.t() @ u_L.t()).t()  # Resulting in shape [batch_size, n_features]

        # Calculate h_L-1(z_L-1,t - θ)
        heaviside_term = self._heaviside(self.z[l][t] - self.theta)
        rhs_term2 = beta * heaviside_term

        # Combine RHS terms
        rhs = rhs_term1 + rhs_term2

        # Solve the system
        updated_a = torch.linalg.solve(matrix_to_invert, rhs.t()).t()

        return updated_a
    def _activation_update_Lminus1_T(self):
        """
        Update activations for l=L-1 at t=T based on equation (11) with t=T:
        a_L-1,T = (α_L,T*W_L^T*W_L + β_L-1,T*I)^{-1} *
                  (W_L^T(α_L,T*u_L,T - λ/2) + β_L-1,T*h_L-1(z_L-1,T - θ))
        """
        alpha = self.rho / 2
        beta = alpha
        l = self.L - 2  # L-1 in zero-based indexing
        t = self.T - 1  # T in zero-based indexing

        # Get dimensions
        batch_size = self.z[l][t].shape[0]  # Should be 64
        n_features = self.z[l][t].shape[1]  # Should be 64 (hidden dim)
        W_L = self.W[l+1]  # Should be [10, 64]


        # Calculate matrix to invert: α_L,T*W_L^T*W_L + β_L-1,T*I
        term1 = alpha * W_L.t() @ W_L  # [64, 10] @ [10, 64] -> [64, 64]
        term2 = beta * torch.eye(n_features, device=self.device)
        matrix_to_invert = term1 + term2  # [64, 64]

        # Calculate u_L,T for final timestep
        u_L, _, _ = self._calculate_u_w_v(l, t)  # Shape: [64, 64]

        # Reshape lambda to match batch size and features
        lambda_term = self.lambda_lagrange / 2  # [64, 10]

        # Calculate W_L^T(α_L,T*u_L,T - λ/2)
        # First, project u_L to output dimension
        u_L_projected = u_L @ W_L.t()  # [64, 64] @ [64, 10] -> [64, 10]

        # Now subtract lambda_term
        diff = alpha * u_L_projected - lambda_term  # [64, 10]

        # Project back using W_L.t()
        rhs_term1 = diff @ W_L  # [64, 10] @ [10, 64] -> [64, 64]

        # Calculate h_L-1(z_L-1,T - θ)
        heaviside_term = self._heaviside(self.z[l][t] - self.theta)  # [64, 64]
        rhs_term2 = beta * heaviside_term

        # Combine RHS terms
        rhs = rhs_term1 + rhs_term2  # [64, 64]

        # Solve the system
        updated_a = torch.linalg.solve(matrix_to_invert, rhs.t()).t()


        return updated_a
     # ============ z_{l,t} update functions ============
    def _calculate_p_s_q_r(self, l, t):
        """
        Calculate helper vectors defined in paper:
        p_l = W_l[a_l-1,1, ..., a_l-1,T]
        s_l = p_l + δ[0, z_l,1, ..., z_l,T-1]
        q_l = s_l - θ[0, a_l,1, ..., a_l,T-1]
        r_l = -p_l + z_l
        """
        # Get previous layer activations and current layer data
        a_l_minus_1 = self.a[l-1] if l > 0 else self.a_minus_one

        # Calculate p_l: project previous layer activations through weights
        p_l = torch.matmul(a_l_minus_1[t], self.W[l].t())  # [batch, features]

        # Calculate s_l: add decayed previous z values
        if t == 0:
            s_l = p_l
        else:
            s_l = p_l + self.delta * self.z[l][t-1]

        # Calculate q_l: subtract theta-weighted activations if available
        if t == 0:
            q_l = s_l
        else:
            q_l = s_l - self.theta * self.a[l][t-1] if l < len(self.a) else s_l

        # Calculate r_l: negative p_l plus current z
        r_l = -p_l + self.z[l][t]

        return p_l, s_l, q_l, r_l

    def _z_update(self, l, t):
        """
        Update z_{l,t} for l=1,...,L-1 and t=1,...,T-1
        z_{l,t} = (α_l,t*q_l,t + α_l,t+1*δ(r_l,t+1 + θa_l,t)*1(t<T)) / (α_l,t + δ²α_l,t+1*1(t<T))
        """
        alpha = self.rho / 2

        # Calculate helper vectors
        _, _, q_l, r_l = self._calculate_p_s_q_r(l, t)

        # Calculate initial z update (before check_entries)
        numerator = alpha * q_l

        if t < self.T - 1:
            _, _, _, r_next = self._calculate_p_s_q_r(l, t+1)
            second_term = alpha * self.delta * (r_next + self.theta * self.a[l][t])
            numerator = numerator + second_term

        denominator = alpha
        if t < self.T - 1:
            denominator = denominator + (self.delta ** 2) * alpha

        # Initial z update
        z_update = numerator / denominator

        # Apply check_entries
        return self.check_entries(z_update)

    def _z_update_T(self, l):
        """
        Update z_{l,T} for l=1,...,L-1
        z_{l,T} = q_{l,T}
        """
        # Calculate helper vectors for final timestep
        _, _, q_l, _ = self._calculate_p_s_q_r(l, self.T-1)

        return self.check_entries(q_l)

    def _z_update_L(self, t):
        """
        Update z_{L,t} for the output layer (L) at intermediate timesteps t = 1,...,T-1
        using the simplified formula:
        z_{L,t} = (s_{L,t} + δ * r_{L,t+1}) / (1 + δ²)

        Parameters:
        - t: current timestep

        Returns:
        - z_{L,t} update
        """

        # Calculate s_{L,t} and r_{L,t+1} for the intermediate timestep
        _, s_L_t, _, _ = self._calculate_p_s_q_r(self.L - 1, t)
        _, _, _, r_L_t_plus_1 = self._calculate_p_s_q_r(self.L - 1, t + 1)

        # Numerator and denominator for the z update
        numerator = s_L_t + self.delta * r_L_t_plus_1
        denominator = 1 + self.delta ** 2

        # Calculate z_{L,t}
        z_update = numerator / denominator

        return self.check_entries(z_update)

    def _z_update_L_T(self, y):
        """
        Update z_{L,T} for the output layer (L) at the final timestep T
        using the formula:
        z_{L,T} = (α * s_{L,T} + α * δ * r_{L,T+1} + (y - λ / 2)) / (α + δ² * α + 1)

        Parameters:
        - y: target output at the final timestep

        Returns:
        - z_{L,T} update
        """

        # Set alpha as a constant value
        alpha = self.rho / 2

        # Calculate s_{L,T} and r_{L,T+1}
        _, s_L_T, _, _ = self._calculate_p_s_q_r(self.L - 1, self.T - 2)
        _, _, _, r_L_T_plus_1 = self._calculate_p_s_q_r(self.L - 1, self.T-1)

        # Numerator and denominator for the z update
        numerator = alpha * s_L_T + alpha * self.delta * r_L_T_plus_1 + (y - self.lambda_lagrange / 2)
        denominator = alpha + (self.delta ** 2) * alpha + 1

        # Calculate z_{L,T}
        z_update = numerator / denominator

        return self.check_entries(z_update)


    def check_entries(self, z):
        """
        Adjust entries of z based on a simplified cost approximation in a vectorized way.

        Parameters:
        - z: Tensor representing z_{l,t}^{n,m}, shape [N_l, M]

        Returns:
        - Adjusted tensor z with entries modified according to the conditions.
        """
        theta = self.theta
        epsilon = self.epsilon if hasattr(self, 'epsilon') else 1e-2

        # Condition 1: z > theta and abs difference <= current value
        mask1 = (z > theta) & (torch.abs(z - theta) <= torch.abs(z))
        z[mask1] = theta

        # Condition 2: z <= theta and abs difference with theta + epsilon < current value
        mask2 = (z <= theta) & (torch.abs(z - (theta + epsilon)) < torch.abs(z))
        z[mask2] = theta + epsilon

        return z
    # ============ lagrange multiplier update ============

    def _lambda_update(self):
        """
        Update the Lagrange multiplier lambda according to the ADMM update rule.

        Formula:
        lambda^+ = lambda + rho * (z_{L,T} - delta * z_{L,T-1} - W_L * a_{L-1,T})

        Returns:
        - Updated lambda
        """
        # Retrieve the necessary variables for the update
        z_L_T = self.z[self.L - 1][self.T - 1]       # z_{L,T}, shape [64, 10]
        z_L_T_minus_1 = self.z[self.L - 1][self.T - 2]  # z_{L,T-1}, shape [64, 10]
        W_L = self.W[self.L - 1]                     # Weight matrix W_L, shape [10, size of a_{L-1,T}]
        a_L_minus_1_T = self.a[self.L - 2][self.T - 1]  # a_{L-1,T}, shape [64, W_L.shape[1]]

        # Calculate the expression inside the parentheses
        term = z_L_T - self.delta * z_L_T_minus_1 - torch.matmul(a_L_minus_1_T, W_L.T)

        # Update lambda
        lambda_update = self.lambda_lagrange + self.rho * term

        # Update and return the lambda variable
        self.lambda_lagrange = lambda_update
        return self.lambda_lagrange

    def feed_forward(self, inputs):
        """
        Implement forward pass using SNNTorch LIF neurons.
        Returns membrane potentials of final layer.

        Parameters:
        - inputs: [timesteps, batch_size, input_dim]
        Returns:
        - mem[-1]: Final layer membrane potentials [batch_size, n_outputs]
        """
        # Initialize membrane potentials for each layer
        mem = []
        spikes = []  # Track spikes for debugging
        for l in range(self.L):
            if l == 0:
                mem.append(torch.zeros(inputs.shape[1], self.hidden_dims[0], device=self.device))
            elif l == self.L - 1:
                mem.append(torch.zeros(inputs.shape[1], self.n_outputs, device=self.device))
            else:
                mem.append(torch.zeros(inputs.shape[1], self.hidden_dims[l], device=self.device))

        # Create LIF neurons for each layer
        neurons = []
        for l in range(self.L):
            if l < self.L - 1:
                # Hidden layers: reset by subtraction
                lif = snn.Leaky(
                    beta=self.delta,
                    threshold=self.theta,
                    reset_mechanism="subtract",
                    learn_beta=False,
                    learn_threshold=False
                )
            else:
                # Output layer: no reset
                lif = snn.Leaky(
                    beta=self.delta,
                    threshold=self.theta,
                    reset_mechanism="none",
                    learn_beta=False,
                    learn_threshold=False
                )
            neurons.append(lif)

        # Process each timestep and track activations
        for t in range(inputs.shape[0]):
            x = inputs[t]  # Current input slice

            for l in range(self.L):

                # Apply weights and LIF neuron dynamics
                x = x @ self.W[l].T  # Apply weights

                spike, mem[l] = neurons[l](x, mem[l])  # Apply LIF neuron
                spikes.append(spike)  # Track spikes


                # Spikes become input to the next layer
                x = spike

        # Check if final layer membrane potentials seem reasonable
        final_output = mem[-1]
        print(f"\nFinal Layer Membrane Potentials - Mean: {final_output.mean():.4f}, Std: {final_output.std():.4f}")
        print(f"Final Layer Potential Range - Min: {final_output.min().item():.4f}, Max: {final_output.max().item():.4f}")

        return final_output  # [batch_size, n_outputs]
    def evaluate(self, inputs, targets):
        """
        Evaluate model performance on N-MNIST dataset with detailed prediction information.

        Parameters:
        - inputs: [timesteps, batch_size, input_dim]
        - targets: [batch_size, 10] (one-hot encoded)

        Returns:
        - loss: Cross-entropy loss
        - predictions: Final layer output
        """
        print("\n=== N-MNIST Evaluation ===")

        # Get final layer activity
        raw_predictions = self.feed_forward(inputs)  # [batch_size, n_outputs]
        predictions = F.softmax(raw_predictions, dim=1)
        print("\nFinal Layer Activity:")
        print(f"Shape: {predictions.shape}")
        print(f"Activity Stats - Mean: {predictions.mean():.4f}, Std: {predictions.std():.4f}")
        print(f"Activity Range - Min: {predictions.min():.4f}, Max: {predictions.max():.4f}")

        # Calculate predictions and true classes
        pred_classes = predictions.argmax(dim=1)  # Predicted classes
        true_classes = targets.argmax(dim=1)          # True classes (from one-hot labels)

        # Initialize confusion matrix and metrics storage
        confusion_matrix = torch.zeros(10, 10, dtype=torch.int32)
        metrics_per_class = []

        # Display predictions count for each real class with highlight format
        print("\nDetailed Predictions for Each Real Class:")
        for class_idx in range(10):
            # Mask for current class
            class_mask = (true_classes == class_idx)

            # Predictions for current class
            predicted_for_class = pred_classes[class_mask]

            # Count predictions for each possible class (0 to 9), with brackets for the current class index
            predictions_count = []
            for i in range(10):
                count = (predicted_for_class == i).sum().item()
                if i == class_idx:
                    predictions_count.append(f"[{count}]")  # Highlight correct predictions with brackets
                else:
                    predictions_count.append(str(count))

            print(f"Real Class {class_idx} -> Prediction Counts: ({', '.join(predictions_count)})")

            # Populate confusion matrix row for real class
            confusion_matrix[class_idx] = torch.tensor([int(predictions_count[i].strip("[]")) for i in range(10)])

            # Calculate metrics for current class
            TP = confusion_matrix[class_idx, class_idx].float()
            FP = confusion_matrix[:, class_idx].sum().float() - TP
            FN = confusion_matrix[class_idx, :].sum().float() - TP

            precision = TP / (TP + FP) if TP + FP > 0 else torch.tensor(0.0)
            recall = TP / (TP + FN) if TP + FN > 0 else torch.tensor(0.0)
            f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else torch.tensor(0.0)

            metrics_per_class.append({'precision': precision, 'recall': recall, 'f1': f1})

        # Calculate overall metrics
        macro_precision = torch.stack([m['precision'] for m in metrics_per_class]).mean()
        macro_recall = torch.stack([m['recall'] for m in metrics_per_class]).mean()
        macro_f1 = torch.stack([m['f1'] for m in metrics_per_class]).mean()
        accuracy = (pred_classes == true_classes).float().mean()

        print("\nOverall Metrics:")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Macro Precision: {macro_precision:.4f}")
        print(f"Macro Recall: {macro_recall:.4f}")
        print(f"Macro F1-Score: {macro_f1:.4f}")

        # Calculate cross-entropy loss
        loss = self.loss_fn(predictions, true_classes)
        print(f"\nLoss: {loss.item():.6f}")

        return loss, predictions



    def fit(self, inputs, targets, warm=True):
        """
        Update the optimization variables following Algorithm 2 from the paper.
        Architecture:
        - L = 4 (3 hidden + 1 output)
        - Layer 0: 1156 -> 128
        - Layer 1: 128 -> 64
        - Layer 2: 64 -> 32
        - Layer 3: 32 -> 10
        """
      #  print("\nStarting optimization updates...")
        self.a_minus_one = inputs  # Initialize a_minus_one with the current batch of inputs

        # First update hidden layers (0 to 2)
        for l in range(self.L - 1):  # Layers 0, 1, 2
            self.W[l] = self._weight_update(l)

            for t in range(self.T - 1):
                if l < self.L - 2:
                    self.a[l][t] = self._activation_update(l, t)
                else:
                    self.a[l][t] = self._activation_update_Lminus1(t)

                self.z[l][t] = self._z_update(l, t)


            if l < self.L - 2:
                self.a[l][self.T - 1] = self._activation_update_T(l)
            else:
                self.a[l][self.T - 1] = self._activation_update_Lminus1_T()


            self.z[l][self.T - 1] = self._z_update_T(l)

        self.W[self.L - 1] = self._weight_update_L(y=targets)

        for t in range(self.T - 1):
            self.z[self.L - 1][t] = self._z_update_L(t)

        self.z[self.L - 1][self.T - 1] = self._z_update_L_T(targets)

        if not warm:
            self._lambda_update()

        # Evaluate current performance (optional for debugging)
        loss, predictions = self.evaluate(inputs, targets)
        return loss, predictions



In [3]:
import torch
import torch.nn as nn
import torchvision
import tonic
import tonic.transforms as transforms
from torch.utils.data import DataLoader
from tonic import DiskCachedDataset
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils


# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transformations
sensor_size = tonic.datasets.NMNIST.sensor_size
frame_transform = transforms.Compose([
    transforms.Denoise(filter_time=10000),
    transforms.ToFrame(sensor_size=sensor_size, time_window=1000)
])

# Load datasets
trainset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=True)
testset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=False)





Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/1afc103f-8799-464a-a214-81bb9b1f9337 to ./data/NMNIST/train.zip


  0%|          | 0/1011893601 [00:00<?, ?it/s]

Extracting ./data/NMNIST/train.zip to ./data/NMNIST
Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/a99d0fee-a95b-4231-ad22-988fdb0a2411 to ./data/NMNIST/test.zip


  0%|          | 0/169674850 [00:00<?, ?it/s]

Extracting ./data/NMNIST/test.zip to ./data/NMNIST


In [4]:
import torch
import torch.nn as nn
import tonic
import tonic.transforms as transforms
from torch.utils.data import DataLoader, Subset
from typing import Tuple, List, Dict

def prepare_nmnist_data(inputs: torch.Tensor,
                       labels: torch.Tensor,
                       device: torch.device,
                       n_timesteps: int = 300) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Prepare NMNIST data for ADMM-SNN training.

    Args:
        inputs (torch.Tensor): Input spike data [batch_size, timesteps, channels]
        labels (torch.Tensor): Labels
        device (torch.device): Device to move data to
        n_timesteps (int): Number of timesteps to use
    """
    batch_size = inputs.shape[0]
    inputs = inputs.to(device)
    labels = labels.to(device)

    # Truncate to desired number of timesteps
    inputs = inputs[:, :n_timesteps]

    # Reshape to [timesteps, batch_size, input_dim]
    inputs = inputs.reshape(batch_size, n_timesteps, -1).permute(1, 0, 2).float()

    # Normalize inputs to [0, 1] range
    inputs = inputs / inputs.max()

    # One-hot encode labels
    labels_onehot = torch.zeros(batch_size, 10, device=device)
    labels_onehot.scatter_(1, labels.unsqueeze(1), 1)

    return inputs, labels_onehot

def train(model, trainloader, num_epochs):
    """
    Train ADMM-SNN model
    """
    # Warming phase
    print("\nWarming Phase:")
    warming_losses = []
    for epoch in range(2):
        print(f'Warming Epoch [{epoch+1}/5]')
        epoch_losses = []
        epoch_accuracies = []

        for batch_idx, (inputs, labels) in enumerate(trainloader):
            print(f'  Batch [{batch_idx+1}/{len(trainloader)}]')
            # Prepare data
            inputs, labels = prepare_nmnist_data(inputs, labels, device)

            # Warming step
            loss, predictions = model.fit(inputs, labels, True)
            # Calculate accuracy
            pred_classes = torch.argmax(predictions, dim=1)
            true_classes = torch.argmax(labels, dim=1)
            accuracy = (pred_classes == true_classes).float().mean().item()

            epoch_losses.append(loss)
            epoch_accuracies.append(accuracy)
        # Epoch summary
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        avg_acc = sum(epoch_accuracies) / len(epoch_accuracies)
        warming_losses.append(avg_loss)
        print(f'  Epoch Summary - Loss: {avg_loss:.6f}, Accuracy: {avg_acc:.4f}')


    # Main training
    print("\nMain Training Phase:")
    training_metrics = []
    for epoch in range(num_epochs):
        print(f'Epoch [{epoch+1}/{num_epochs}]')
        epoch_losses = []
        epoch_accuracies = []

        for batch_idx, (inputs, labels) in enumerate(trainloader):
            print(f'  Batch [{batch_idx+1}/{len(trainloader)}]')

            # Prepare data
            inputs, labels = prepare_nmnist_data(inputs, labels, device)

            # Training step
            loss, predictions = model.fit(inputs, labels,False)

            # Ensure predictions and labels have correct shape
            labels = labels.t()

            # Calculate accuracy
            pred_classes = torch.argmax(predictions, dim=1)
            true_classes = torch.argmax(labels, dim=1)
            accuracy = (pred_classes == true_classes).float().mean().item()

            epoch_losses.append(loss)
            epoch_accuracies.append(accuracy)



        # Epoch summary
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        avg_acc = sum(epoch_accuracies) / len(epoch_accuracies)
        training_metrics.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'accuracy': avg_acc
        })
        print(f'  Epoch Summary - Loss: {avg_loss:.6f}, Accuracy: {avg_acc:.4f}')

    return warming_losses, training_metrics

def evaluate(model, testloader):
    print("\nEvaluation Phase:")
    total_loss = 0
    total_acc = 0
    num_batches = 0
    batch_metrics = []

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(testloader):
            print(f'  Batch [{batch_idx+1}/{len(testloader)}]')

            # Prepare data
            inputs, labels = prepare_nmnist_data(inputs, labels, device)

            # Forward pass
            loss, predictions = model.evaluate(inputs, labels)

            # Calculate accuracy
            pred_classes = torch.argmax(predictions, dim=0)
            true_classes = torch.argmax(labels, dim=0)
            accuracy = (pred_classes == true_classes).float().mean().item()

            # Store batch metrics
            batch_metrics.append({
                'batch': batch_idx + 1,
                'loss': loss,
                'accuracy': accuracy
            })

            total_loss += loss
            total_acc += accuracy
            num_batches += 1

            print(f'    Loss: {loss:.6f}, Accuracy: {accuracy:.4f}')

    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches

    print(f'\nFinal Evaluation Results:')
    print(f'  Average Loss: {avg_loss:.6f}')
    print(f'  Average Accuracy: {avg_acc:.4f}')

    return avg_loss, avg_acc, batch_metrics

In [6]:


# Initialize model and training
n_timesteps = 300
input_dim =   sensor_size[0] * sensor_size[1]
hidden_dims = [256,256]
n_outputs = 10
# DataLoaders


# Create DataLoaders with the subsets
batch_size = 256
# Calculate size of 20% of data
train_size = int( len(trainset)) // batch_size * batch_size
test_size = int(  len(testset)) // batch_size * batch_size

from torch.utils.data import Subset
train_subset = Subset(trainset, torch.randperm(len(trainset))[:train_size])
test_subset = Subset(testset, torch.randperm(len(testset))[:test_size])



trainloader = DataLoader(train_subset,
                        batch_size=batch_size,
                        collate_fn=tonic.collation.PadTensors(),
                        shuffle=True)
testloader = DataLoader(test_subset,
                       batch_size=batch_size,
                       collate_fn=tonic.collation.PadTensors())

print(f"Original training set size: {len(trainset)}")
print(f"training set size: {len(train_subset)}")
print(f"Original test set size: {len(testset)}")
print(f"test set size: {len(test_subset)}")


rho = 0.15
delta = 0.5
theta = 0.2


# Create model
print("\nInitializing model...")
model = ADMM_SNN(
    n_samples=batch_size,
    n_timesteps=n_timesteps,
    input_dim=input_dim,
    hidden_dims=hidden_dims,  # Changed final dimension to 10
    n_outputs=10,
    rho=rho,
    delta=delta,
    theta=theta
)
print(model)


# Train model
print("\nStarting training process...")
num_epochs = 20
warming_losses, training_metrics = train(model, trainloader, num_epochs)

# Evaluate model
print("\nEvaluating model...")
test_loss, test_acc, test_metrics = evaluate(model, testloader)

# Print final results
print("\nTraining Summary:")
print(f"  Warming phase final loss: {warming_losses[-1]:.6f}")
print(f"  Training final loss: {training_metrics[-1]['loss']:.6f}")
print(f"  Training final accuracy: {training_metrics[-1]['accuracy']:.4f}")
print("\nTest Results:")
print(f"  Test Loss: {test_loss:.6f}")
print(f"  Test Accuracy: {test_acc:.4f}")

Original training set size: 60000
training set size: 59904
Original test set size: 10000
test set size: 9984

Initializing model...

Model Initialization Details:
Device: cpu
Previous activation shape: torch.Size([300, 256, 1156])
Weight layer 0 shape: torch.Size([256, 1156])
Weight layer 1 shape: torch.Size([256, 256])
Weight layer 2 shape: torch.Size([10, 256])
z[0] shape: torch.Size([300, 256, 256])
z[1] shape: torch.Size([300, 256, 256])
z[2] shape: torch.Size([300, 256, 10])
a[0] shape: torch.Size([300, 256, 256])
a[1] shape: torch.Size([300, 256, 256])
lambda shape: torch.Size([256, 10])
ADMM SNN Model Structure:
 - rho: 0.15, delta: 0.5, theta: 0.2
 - Number of timesteps: 300
 - Input dimension: torch.Size([300, 256, 1156])
 - Hidden layers: [torch.Size([256, 1156]), torch.Size([256, 256]), torch.Size([10, 256])]
 - Output dimension (Lagrange Multiplier): torch.Size([256, 10])
 - self.L : 3
 - self.T : 300


Starting training process...

Warming Phase:
Warming Epoch [1/5]
  Batc

KeyboardInterrupt: 