<a href="https://colab.research.google.com/github/ergysmedaunipd/thesis/blob/main/ThesisUnipdSNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tonic
!pip install snntorch
!pip install psutil

Collecting tonic
  Downloading tonic-1.5.0-py3-none-any.whl.metadata (5.4 kB)
Collecting importRosbag>=1.0.4 (from tonic)
  Downloading importRosbag-1.0.4-py3-none-any.whl.metadata (4.3 kB)
Collecting pbr (from tonic)
  Downloading pbr-6.1.0-py2.py3-none-any.whl.metadata (3.4 kB)
Collecting expelliarmus (from tonic)
  Downloading expelliarmus-1.1.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading tonic-1.5.0-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.6/116.6 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading importRosbag-1.0.4-py3-none-any.whl (28 kB)
Downloading expelliarmus-1.1.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pbr-6.1.0-py2.py3-none-any.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import time
import snntorch as snn
from snntorch import surrogate


from typing import List


class ADMM_SNN:
    """ Class for ADMM Neural Network. """

    def __init__(self, n_samples: int, n_timesteps: int, input_dim: int, hidden_dims: List[int], n_outputs: int, rho: float, delta: float, theta: float):
        """
        Initialize the ADMM-SNN model.

        Parameters:
        n_samples (int): Number of samples in the batch.
        n_timesteps (int): Number of timesteps.
        input_dim (int): Dimension of the input features.
        hidden_dims (List[int]): List of hidden layer dimensions.
        n_outputs (int): Number of output classes.
        rho (float): ADMM penalty parameter.
        delta (float): Decay factor for the intermediate variables (z).
        theta (float): Activation threshold for the spiking neurons.
        """
        # Define the device to run the model on (GPU if available)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Store model parameters
        self.n_samples = n_samples
        self.n_timesteps = n_timesteps
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.n_outputs = n_outputs


        # Loss function for training
        self.loss_fn = nn.CrossEntropyLoss()

        # Hyperparameters for ADMM-SNN
        self.rho = rho             # ADMM penalty parameter
        self.delta = delta         # Decay factor for z updates
        self.theta = theta         # Activation threshold for spiking neurons

        # Model architecture specifications
        self.L = len(hidden_dims)+1  # Number of hidden layers
        self.T = n_timesteps       # Number of timesteps in each forward pass

        # Initialize initial activations (a0) as a tensor with zero values
        self.a_minus_one = torch.rand((n_timesteps, n_samples, input_dim)).to(self.device)

        # Initialize weights (W) for each layer in hidden_dims
        prev_dim = input_dim
        self.W = nn.ParameterList()
        for hidden_dim in hidden_dims:
            self.W.append(nn.Parameter(
                torch.randn(hidden_dim, prev_dim).to(self.device)   # Changed from zeros to small random
            ))
            prev_dim = hidden_dim

        # Output layer weights
        self.W.append(nn.Parameter(
            torch.randn(n_outputs, hidden_dims[-1]).to(self.device)  # Changed from zeros to small random
        ))

        # Initialize intermediate variables (z) for each layer and timestep
        self.z = []
        for hidden_dim in hidden_dims:
            z_layer = torch.rand((n_timesteps, n_samples, hidden_dim)).to(self.device)  # small random values
            self.z.append(z_layer)
        self.z.append(torch.rand((n_timesteps, n_samples, n_outputs)).to(self.device))



        # Initialize activations (a) for each layer and timestep
        self.a = []
        for hidden_dim in hidden_dims:
            a_layer = torch.rand((n_timesteps, n_samples, hidden_dim)).to(self.device)   # small random values
            self.a.append(a_layer)
        self.a.append(torch.rand((n_timesteps, n_samples, n_outputs)).to(self.device))

        # Initialize Lagrange multipliers (lambda) for the output layer
        self.lambda_lagrange = torch.ones((n_samples, n_outputs)).to(self.device)


    def __str__(self):
        model_str = "ADMM SNN Model Structure:\n"
        model_str += f" - rho: {self.rho}, delta: {self.delta}, theta: {self.theta}\n"
        model_str += f" - Number of timesteps: {self.T}\n"
        model_str += f" - Input dimension: {self.a_minus_one.size()}\n"
        model_str += f" - Hidden layers: {[w.shape for w in self.W]}\n"
        model_str += f" - Output dimension (Lagrange Multiplier): {self.lambda_lagrange.size()}\n"
        model_str += f" - self.L : {self.L }\n"
        model_str += f" - self.T : {self.T }\n"

        """Helper method to print shapes of initialized tensors"""
        print(f"\nModel Initialization Details:")
        print(f"Device: {self.device}")
        print(f"Previous activation shape: {self.a_minus_one.shape}")

        for i, w in enumerate(self.W):
            print(f"Weight layer {i} shape: {w.shape}")

        for i, z_layer in enumerate(self.z):
            print(f"z[{i}] shape: {z_layer.shape}")

        for i, a_layer in enumerate(self.a):
            print(f"a[{i}] shape: {a_layer.shape}")

        for i, lambda_layer in enumerate(self.lambda_lagrange):
            print(f"lambda[{i}] shape: {lambda_layer.shape}")
            return model_str

    def _heaviside(self, z):
        """
        Implement the Heaviside function to calculate activations based on a threshold.
        This function returns 1 where z meets or exceeds the threshold and 0 otherwise.

        Parameters:
        z (torch.Tensor): The input tensor representing the intermediate variable z at layer l and time t.

        Returns:
        torch.Tensor: A tensor of the same shape as z with values 0 or 1, based on whether each element
                      in z is above or below the threshold self.theta.
        """
        # Apply the Heaviside step function:
        # Convert continuous values in z to binary values (0 or 1) based on self.theta
        # - Returns 1 where z >= theta (neuron spikes).
        # - Returns 0 where z < theta (neuron does not spike).
        return (z >= self.theta).float()


    # ============ W_{l} update functions ============
    def _weight_update(self, l):
        """
        Update weights based on Equation (4) for layers 1 to L-1.
        Parameters:
        - l: layer index
        Returns: W_l with shape [features_out, features_in]
        """
        alpha = self.rho / 2
        z_l = self.z[l]
        a_l = self.a[l]
        a_l_minus_1 = self.a[l-1] if l > 0 else self.a_minus_one

        n_timesteps, batch_size, n_features = z_l.shape
        _, _, n_prev_features = a_l_minus_1.shape

        numerator = torch.zeros((n_features, n_prev_features), device=self.device)
        denominator = torch.zeros((n_prev_features, n_prev_features), device=self.device)

        for t in range(n_timesteps):
            if t == 0:
                x_l_t = z_l[0]
            else:
                x_l_t = z_l[t] - self.delta * z_l[t-1] + self.theta * a_l[t-1]

            numerator += alpha * torch.einsum('bf,bp->fp', x_l_t, a_l_minus_1[t])
            denominator += alpha * torch.einsum('bp,bq->pq', a_l_minus_1[t], a_l_minus_1[t])

        denominator += torch.eye(denominator.shape[0], device=self.device) * 1e-6
        return numerator @ torch.inverse(denominator)

    def _weight_update_L(self, y):
        """
        Update weights for the final layer L based on Equation (6).
        Parameters:
        - y: target values [batch, n_outputs]
        Returns: W_L with shape [features_out, features_in]
        """
        alpha = self.rho / 2
        z_L = self.z[self.L-1]
        a_L_minus_1 = self.a[self.L-2]

        n_timesteps, batch_size, n_features = z_L.shape
        _, _, n_prev_features = a_L_minus_1.shape

        # Calculate lambda term
        lambda_term = -0.5 * torch.einsum('bf,bp->fp', self.lambda_lagrange, a_L_minus_1[-1])

        # Calculate sum term for numerator
        sum_term = torch.zeros((n_features, n_prev_features), device=self.device)
        for t in range(n_timesteps):
            if t == 0:
                x_L_t = z_L[0]
            else:
                x_L_t = z_L[t] - self.delta * z_L[t-1]

            sum_term += alpha * torch.einsum('bf,bp->fp', x_L_t, a_L_minus_1[t])

        # Calculate denominator
        denominator = torch.zeros((n_prev_features, n_prev_features), device=self.device)
        for t in range(n_timesteps):
            denominator += alpha * torch.einsum('bp,bq->pq', a_L_minus_1[t], a_L_minus_1[t])

        denominator += torch.eye(denominator.shape[0], device=self.device) * 1e-6
        return (lambda_term + sum_term) @ torch.inverse(denominator)


    # ============ a_{l,t} update functions ============
    # Activation update for l=1,...,L-2, t=1,...,T-1 (Equation 8)
    def _calculate_u_w_v(self, l, t):
        """
        Calculate helper vectors u_l, v_l, w_l
        """

        z_l = self.z[l]
        a_l = self.a[l]
        a_l_minus_1 = self.a[l-1] if l > 0 else self.a_minus_one

        # Calculate u_l[t]
        if t == 0:
            u_l_t = z_l[0]
        else:
            u_l_t = z_l[t] - self.delta * z_l[t-1]

        # Calculate v_l[t]
        if t == 0:
            v_l_t = u_l_t
        else:
            v_l_t = u_l_t + self.theta * a_l[t-1]

        # Calculate w_l[t]
        projected_a = torch.matmul(a_l_minus_1[t], self.W[l].t())
        w_l_t = u_l_t - projected_a

        return u_l_t, w_l_t, v_l_t

    def _activation_update(self, l, t):
        """
        Update activations based on equation (8)
        """

        alpha = self.rho / 2
        beta = alpha

        # Get dimensions
        batch_size = self.z[l][t].shape[0]  # 256
        n_features = self.z[l][t].shape[1]  # 128 for layer 0
        W_next = self.W[l+1]  # [64, 128] for layer 0->1

        # Matrix terms
        I = torch.eye(n_features, device=self.device)
        term1 = (self.theta ** 2) * alpha * I
        term2 = alpha * W_next.t() @ W_next  # [128, 64] @ [64, 128] -> [128, 128]
        term3 = beta * I

        matrix_to_invert = term1 + term2 + term3

        # Calculate RHS terms
        _, w_next, _ = self._calculate_u_w_v(l, t+1)
        rhs_term1 = -self.theta * alpha * w_next  # [256, 128]

        _, _, v_next = self._calculate_u_w_v(l+1, t)  # v_next shape: [256, 64]
        # Fix: multiply in correct order
        rhs_term2 = alpha * torch.matmul(v_next, W_next)  # [256, 64] @ [64, 128] -> [256, 128]

        heaviside_term = self._heaviside(self.z[l][t])  # [256, 128]
        rhs_term3 = beta * heaviside_term

        # Combine RHS terms
        rhs = rhs_term1 + rhs_term2 + rhs_term3

        # Solve the system
        updated_a = torch.linalg.solve(matrix_to_invert, rhs.t()).t()

        return updated_a

    # Activation update for l=L-1, t=1,...,T-1 (Equation 10)
    def _activation_update_T(self, l):
        """
        Update activations for l=1,...,L-2 at t=T based on equation (9):
        a_l,T = (α_l+1,t*W_{l+1}^T*W_{l+1} + β_l,t*I)^{-1} *
                (α_l+1,t*W_{l+1}^T*v_{l+1,t} + β_l,t*h_l(z_l,t))
        """

        alpha = self.rho / 2
        beta = alpha

        # Get dimensions
        batch_size = self.z[l][self.T-1].shape[0]  # 256
        n_features = self.z[l][self.T-1].shape[1]  # current layer features
        W_next = self.W[l+1]

        # Calculate matrix to invert: α_l+1,t*W_{l+1}^T*W_{l+1} + β_l,t*I
        term1 = alpha * W_next.t() @ W_next
        term2 = beta * torch.eye(n_features, device=self.device)
        matrix_to_invert = term1 + term2

        # Calculate v_{l+1,t} for the first RHS term
        _, _, v_next = self._calculate_u_w_v(l+1, self.T-1)
        # Fix: multiply in correct order
        rhs_term1 = alpha * torch.matmul(v_next, W_next)  # [256, 64] @ [64, 128] -> [256, 128]

        # Calculate h_l(z_l,t) for second RHS term
        heaviside_term = self._heaviside(self.z[l][self.T-1])  # [256, 128]
        rhs_term2 = beta * heaviside_term

        # Combine RHS terms
        rhs = rhs_term1 + rhs_term2

        # Solve the system
        updated_a = torch.linalg.solve(matrix_to_invert, rhs.t()).t()

        return updated_a

    def _activation_update_Lminus1(self, t):
        """
        Update activations for l=L-1 based on equation (11):
        a_L-1,t = (α_L,t*W_L^T*W_L + β_L-1,t*I)^{-1} *
                  (W_L^T(α_L,t*u_L,t - λ/2*1(t=T)) + β_L-1,t*h_L-1(z_L-1,t - θ))
        """

        alpha = self.rho / 2
        beta = alpha
        l = self.L - 2  # L-1 in zero-based indexing

        # Get dimensions
        batch_size = self.z[l][t].shape[0]
        n_features = self.z[l][t].shape[1]
        W_L = self.W[self.L-1]

        # Calculate matrix to invert: α_L,t*W_L^T*W_L + β_L-1,t*I
        term1 = alpha * W_L.t() @ W_L
        term2 = beta * torch.eye(n_features, device=self.device)
        matrix_to_invert = term1 + term2

        # Calculate u_L,t
        u_L, _, _ = self._calculate_u_w_v(self.L-1, t)

        # Calculate first part of RHS: W_L^T(α_L,t*u_L,t)
        # Transpose u_L to match W_L dimensions
        rhs_term1 = alpha * (W_L.t() @ u_L.t()).t()  # Resulting in shape [batch_size, n_features]

        # Add lambda term if t=T-1
        if t == self.T-1:
            rhs_term1 -= (W_L.t() @ (self.lambda_lagrange / 2)).t()

        # Calculate h_L-1(z_L-1,t - θ)
        heaviside_term = self._heaviside(self.z[l][t] - self.theta)
        rhs_term2 = beta * heaviside_term

        # Combine RHS terms
        rhs = rhs_term1 + rhs_term2

        # Solve the system
        updated_a = torch.linalg.solve(matrix_to_invert, rhs.t()).t()

        return updated_a

    def _activation_update_Lminus1_T(self):
        """
        Update activations for l=L-1 at t=T based on equation (11) with t=T:
        a_L-1,T = (α_L,T*W_L^T*W_L + β_L-1,T*I)^{-1} *
                  (W_L^T(α_L,T*u_L,T - λ/2) + β_L-1,T*h_L-1(z_L-1,T - θ))
        """

        alpha = self.rho / 2
        beta = alpha
        l = self.L - 2  # L-1 in zero-based indexing
        t = self.T - 1  # T in zero-based indexing

        # Get dimensions
        batch_size = self.z[l][t].shape[0]
        n_features = self.z[l][t].shape[1]
        W_L = self.W[self.L-1]

        # Calculate matrix to invert: α_L,T*W_L^T*W_L + β_L-1,T*I
        term1 = alpha * W_L.t() @ W_L
        term2 = beta * torch.eye(n_features, device=self.device)
        matrix_to_invert = term1 + term2

        # Calculate u_L,T for final timestep
        u_L, _, _ = self._calculate_u_w_v(self.L-1, t)

        # Calculate RHS: W_L^T(α_L,T*u_L,T - λ/2)
        # Ensure correct multiplication order and transpose where necessary
        rhs_term1 = (W_L.t() @ (alpha * u_L.t() - (self.lambda_lagrange / 2).t())).t()  # Adjusted for shape compatibility


        # Calculate h_L-1(z_L-1,T - θ)
        heaviside_term = self._heaviside(self.z[l][t] - self.theta)
        rhs_term2 = beta * heaviside_term

        # Combine RHS terms
        rhs = rhs_term1 + rhs_term2

        # Solve the system
        updated_a = torch.linalg.solve(matrix_to_invert, rhs.t()).t()

        return updated_a

     # ============ z_{l,t} update functions ============
    def _calculate_p_s_q_r(self, l, t):
        """
        Calculate helper vectors defined in paper:
        p_l = W_l[a_l-1,1, ..., a_l-1,T]
        s_l = p_l + δ[0, z_l,1, ..., z_l,T-1]
        q_l = s_l - θ[0, a_l,1, ..., a_l,T-1]
        r_l = -p_l + z_l
        """

        # Get previous layer activations and current layer data
        a_l_minus_1 = self.a[l-1] if l > 0 else self.a_minus_one

        # Calculate p_l: project previous layer activations through weights
        p_l = torch.matmul(a_l_minus_1[t], self.W[l].t())  # [batch, features]

        # Calculate s_l: add decayed previous z values
        if t == 0:
            s_l = p_l
        else:
            s_l = p_l + self.delta * self.z[l][t-1]

        # Calculate q_l: subtract theta-weighted activations
        if t == 0:
            q_l = s_l
        else:
            q_l = s_l - self.theta * self.a[l][t-1]

        # Calculate r_l: negative p_l plus current z
        r_l = -p_l + self.z[l][t]

        return p_l, s_l, q_l, r_l

    def _z_update(self, l, t):
        """
        Update z_{l,t} for l=1,...,L-1 and t=1,...,T-1
        z_{l,t} = (α_l,t*q_l,t + α_l,t+1*δ(r_l,t+1 + θa_l,t)*1(t<T)) / (α_l,t + δ²α_l,t+1*1(t<T))
        """
        alpha = self.rho / 2

        # Calculate helper vectors
        _, _, q_l, r_l = self._calculate_p_s_q_r(l, t)

        # Calculate initial z update (before check_entries)
        numerator = alpha * q_l

        if t < self.T - 1:
            _, _, _, r_next = self._calculate_p_s_q_r(l, t+1)
            second_term = alpha * self.delta * (r_next + self.theta * self.a[l][t])
            numerator = numerator + second_term

        denominator = alpha
        if t < self.T - 1:
            denominator = denominator + (self.delta ** 2) * alpha

        # Initial z update
        z_update = numerator / denominator

        # Apply check_entries
        return self.check_entries(z_update, lambda z: alpha * torch.norm(z - q_l)**2)

    def _z_update_T(self, l):
        """
        Update z_{l,T} for l=1,...,L-1
        z_{l,T} = q_{l,T}
        """
        # Calculate helper vectors for final timestep
        _, _, q_l, _ = self._calculate_p_s_q_r(l, self.T-1)

        # For final timestep, z_{l,T} = q_{l,T}
        z_update = q_l

        # Apply check_entries
        return self.check_entries(z_update, lambda z: self.rho/2 * torch.norm(z - q_l)**2)

    def _z_update_L(self, t, y):
        """
        Update z_{L,t} for t=1,...,T-1 based on equation (16):
        z_{L,t} = (α_L,t*s_L,t + α_L,t+1*δr_L,t+1*1(t<T) + (y-λ/2)*1(t=T)) /
                  (α_L,t + δ²α_L,t+1*1(t<T) + 1(t=T))
        """
        alpha = self.rho / 2
        l = self.L - 1  # Last layer index

        # Calculate helper vectors
        _, s_l, _, r_l = self._calculate_p_s_q_r(l, t)

        # First term: α_L,t*s_L,t
        numerator = alpha * s_l

        # Handle t < T case
        if t < self.T - 1:
            # Calculate r_{L,t+1} for next timestep
            _, _, _, r_next = self._calculate_p_s_q_r(l, t+1)
            # Add second term: α_L,t+1*δr_L,t+1
            numerator = numerator + alpha * self.delta * r_next
            denominator = alpha + (self.delta**2 * alpha)
        else:
            # Final timestep: add (y-λ/2) term
            numerator = numerator + (y - self.lambda_lagrange/2)
            denominator = alpha + 1.0

        # Calculate update
        z_update = numerator / denominator

        # Apply check_entries with appropriate cost function
        def cost_fn(z):
            cost = alpha * torch.norm(z - s_l)**2
            if t < self.T - 1:
                # Add cost term for t < T case
                _, _, _, r_next = self._calculate_p_s_q_r(l, t+1)
                cost += alpha * torch.norm(r_next - self.delta * z)**2
            else:
                # Add cost term for t = T case
                cost += torch.norm(z - y)**2 + torch.sum(self.lambda_lagrange * z)
            return cost

        return self.check_entries(z_update, cost_fn)

    def _z_update_L_T(self, y):
        """
        Update z_{L,T} based on equation (16) with t=T:
        z_{L,T} = (α_L,T*s_L,T + (y-λ/2)*1(t=T)) / (α_L,T + 1)

        Parameters:
        - y: target values [batch_size, n_outputs]
        """
        alpha = self.rho / 2
        l = self.L - 1  # Last layer index
        t = self.T - 1  # Last timestep

        # Calculate helper vectors for final timestep
        _, s_l, _, _ = self._calculate_p_s_q_r(l, t)
        # s_l is already [batch_size, n_outputs] for last layer

        # Reshape y if needed
        batch_size = s_l.shape[0]
        n_outputs = s_l.shape[1]  # Use s_l shape since it's already correct
        if y.shape != (batch_size, n_outputs):
            y = y.view(batch_size, n_outputs)

        # Calculate update with matched dimensions
        numerator = (alpha * s_l) + (y - self.lambda_lagrange/2)  # All [batch_size, n_outputs]
        denominator = alpha + 1.0

        # Calculate update
        z_update = numerator / denominator

        # Define cost function according to equation (15)
        def cost_fn(z):
            return (torch.norm(z - y)**2 +
                    alpha * torch.norm(z - s_l)**2 +
                    torch.sum(self.lambda_lagrange * z))

        return self.check_entries(z_update, cost_fn)

    def check_entries(self, z, cost_function, params=None):
        """
        Vectorized implementation of Algorithm 1
        """
        # Work with a copy of z
        z_adjusted = z.clone()

        # Create mask for active neurons (z > θ)
        active_mask = z > self.theta

        # Calculate costs for current values
        current_costs = cost_function(z_adjusted)

        # Calculate costs at theta (for turning off)
        theta_costs = cost_function(torch.full_like(z_adjusted, self.theta))

        # Where z > θ and cost at θ is lower or equal, set to θ (turn off)
        turn_off_mask = active_mask & (theta_costs <= current_costs)
        z_adjusted[turn_off_mask] = self.theta

        # Create mask for inactive neurons (z ≤ θ)
        inactive_mask = ~active_mask

        # Calculate costs at theta + ε (for turning on)
        epsilon = 1e-6
        theta_plus_eps_costs = cost_function(torch.full_like(z_adjusted, self.theta + epsilon))

        # Where z ≤ θ and cost at θ+ε is lower, set to θ+ε (turn on)
        turn_on_mask = inactive_mask & (theta_plus_eps_costs < current_costs)
        z_adjusted[turn_on_mask] = self.theta + epsilon

        return z_adjusted
    # ============ lagrange multiplier update ============

    def _lambda_update(self):
        """
        Update Lagrange multipliers (lambda) based on equation:
        λ^(+1) ← λ + ρ(z_{L,T} - δz_{L,T-1} - W_L a_{L-1,T})
        """
        print("\nDebugging _lambda_update:")

        # Parameters
        rho = self.rho
        delta = self.delta
        l = self.L - 1  # Output layer index
        t = self.T - 1  # Final timestep

        # Get tensors with proper shapes
        z_L_T = self.z[l][t]                    # [batch_size, n_outputs]
        z_L_T_minus_1 = self.z[l][t - 1]        # [batch_size, n_outputs]
        a_L_minus_1_T = self.a[l-1][t]          # [batch_size, hidden_dim]
        W_L = self.W[l]                         # [n_outputs, hidden_dim]

        # Calculate projection term: W_L @ a_{L-1,T}
        projection = torch.matmul(a_L_minus_1_T, W_L.t())  # [batch_size, n_outputs]

        # Calculate the full update term
        update_term = z_L_T - delta * z_L_T_minus_1 - projection

        # Update lambda
        lambda_update = self.lambda_lagrange + rho * update_term

        return lambda_update

    def feed_forward(self, inputs):
        """
        Implement forward pass using SNNTorch LIF neurons.
        Returns membrane potentials of final layer.

        Parameters:
        - inputs: [timesteps, batch_size, input_dim]
        Returns:
        - mem[-1]: Final layer membrane potentials [batch_size, n_outputs]
        """
        # Initialize membrane potentials for each layer
        mem = []
        spikes = []  # Track spikes for debugging
        for l in range(self.L):
            if l == 0:
                mem.append(torch.zeros(inputs.shape[1], self.hidden_dims[0], device=self.device))
            elif l == self.L - 1:
                mem.append(torch.zeros(inputs.shape[1], self.n_outputs, device=self.device))
            else:
                mem.append(torch.zeros(inputs.shape[1], self.hidden_dims[l], device=self.device))

        # Create LIF neurons for each layer
        neurons = []
        for l in range(self.L):
            if l < self.L - 1:
                # Hidden layers: reset by subtraction
                lif = snn.Leaky(
                    beta=self.delta,
                    threshold=self.theta,
                    reset_mechanism="subtract",
                    learn_beta=False,
                    learn_threshold=False
                )
            else:
                # Output layer: no reset
                lif = snn.Leaky(
                    beta=self.delta,
                    threshold=self.theta,
                    reset_mechanism="none",
                    learn_beta=False,
                    learn_threshold=False
                )
            neurons.append(lif)

        # Process each timestep and track activations
        for t in range(inputs.shape[0]):
            x = inputs[t]  # Current input slice

            for l in range(self.L):

                # Apply weights and LIF neuron dynamics
                x = x @ self.W[l].T  # Apply weights

                spike, mem[l] = neurons[l](x, mem[l])  # Apply LIF neuron
                spikes.append(spike)  # Track spikes


                # Spikes become input to the next layer
                x = spike

        # Check if final layer membrane potentials seem reasonable
        final_output = mem[-1]
        print(f"\nFinal Layer Membrane Potentials - Mean: {final_output.mean():.4f}, Std: {final_output.std():.4f}")
        print(f"Final Layer Potential Range - Min: {final_output.min().item():.4f}, Max: {final_output.max().item():.4f}")

        return final_output  # [batch_size, n_outputs]
    def evaluate(self, inputs, targets):
        """
        Evaluate model performance on N-MNIST dataset with detailed prediction information.

        Parameters:
        - inputs: [timesteps, batch_size, input_dim]
        - targets: [batch_size, 10] (one-hot encoded)

        Returns:
        - loss: Cross-entropy loss
        - predictions: Final layer output
        """
        print("\n=== N-MNIST Evaluation ===")

        # Get final layer activity
        raw_predictions = self.feed_forward(inputs)  # [batch_size, n_outputs]
        print("\nFinal Layer Activity:")
        print(f"Shape: {raw_predictions.shape}")
        print(f"Activity Stats - Mean: {raw_predictions.mean():.4f}, Std: {raw_predictions.std():.4f}")
        print(f"Activity Range - Min: {raw_predictions.min():.4f}, Max: {raw_predictions.max():.4f}")

        # Calculate predictions and true classes
        pred_classes = raw_predictions.argmax(dim=1)  # Predicted classes
        true_classes = targets.argmax(dim=1)          # True classes (from one-hot labels)

        # Initialize confusion matrix and metrics storage
        confusion_matrix = torch.zeros(10, 10, dtype=torch.int32)
        metrics_per_class = []

        # Display predictions count for each real class with highlight format
        print("\nDetailed Predictions for Each Real Class:")
        for class_idx in range(10):
            # Mask for current class
            class_mask = (true_classes == class_idx)

            # Predictions for current class
            predicted_for_class = pred_classes[class_mask]

            # Count predictions for each possible class (0 to 9), with brackets for the current class index
            predictions_count = []
            for i in range(10):
                count = (predicted_for_class == i).sum().item()
                if i == class_idx:
                    predictions_count.append(f"[{count}]")  # Highlight correct predictions with brackets
                else:
                    predictions_count.append(str(count))

            print(f"Real Class {class_idx} -> Prediction Counts: ({', '.join(predictions_count)})")

            # Populate confusion matrix row for real class
            confusion_matrix[class_idx] = torch.tensor([int(predictions_count[i].strip("[]")) for i in range(10)])

            # Calculate metrics for current class
            TP = confusion_matrix[class_idx, class_idx].float()
            FP = confusion_matrix[:, class_idx].sum().float() - TP
            FN = confusion_matrix[class_idx, :].sum().float() - TP

            precision = TP / (TP + FP) if TP + FP > 0 else torch.tensor(0.0)
            recall = TP / (TP + FN) if TP + FN > 0 else torch.tensor(0.0)
            f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else torch.tensor(0.0)

            metrics_per_class.append({'precision': precision, 'recall': recall, 'f1': f1})

        # Calculate overall metrics
        macro_precision = torch.stack([m['precision'] for m in metrics_per_class]).mean()
        macro_recall = torch.stack([m['recall'] for m in metrics_per_class]).mean()
        macro_f1 = torch.stack([m['f1'] for m in metrics_per_class]).mean()
        accuracy = (pred_classes == true_classes).float().mean()

        print("\nOverall Metrics:")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Macro Precision: {macro_precision:.4f}")
        print(f"Macro Recall: {macro_recall:.4f}")
        print(f"Macro F1-Score: {macro_f1:.4f}")

        # Calculate cross-entropy loss
        loss = self.loss_fn(raw_predictions, true_classes)
        print(f"\nLoss: {loss.item():.6f}")

        return loss, raw_predictions
    def fit(self, inputs, targets,warm = True):
      """
      Update the optimization variables following Algorithm 2 from the paper.
      Architecture:
      - L = 4 (3 hidden + 1 output)
      - Layer 0: 1156 -> 128
      - Layer 1: 128 -> 64
      - Layer 2: 64 -> 32
      - Layer 3: 32 -> 10
      """
      print("\nStarting optimization updates...")
      # First update hidden layers (0 to 2)
      for l in range(self.L - 1):  # 0,1,2

          # Update weights using _weight_update
          self.W[l] = self._weight_update(l)

          # Update timesteps 0 to T-2
          for t in range(self.T-1):

              if l < self.L - 2:  # Layers 0,1
                  self.a[l][t] = self._activation_update(l, t)
              else:  # Layer 2
                  self.a[l][t] = self._activation_update_Lminus1(t)


              # Update z for all hidden layers
              self.z[l][t] = self._z_update(l, t)


          # Handle final timestep (T-1)

          if l < self.L - 2:
              self.a[l][self.T-1] = self._activation_update_T(l)

          else:  # Layer 2

              self.a[l][self.T-1] = self._activation_update_Lminus1_T()



          self.z[l][self.T-1] = self._z_update_T(l)



      # Update output layer weights

      self.W[self.L-1] = self._weight_update_L(y=targets)


      # Update output layer z values for timesteps 0 to T-2
      for t in range(self.T-1):
          self.z[self.L-1][t] = self._z_update_L(t, y=targets)

      # Update final timestep of output layer
      self.z[self.L-1][self.T-1] = self._z_update_L_T(y=targets)

      # Update Lagrange multiplier
      if not warm:
          self.lambda_lagrange = self._lambda_update()

      # Evaluate current performance
      loss, predictions = self.evaluate(inputs, targets)

      return loss, predictions

In [3]:
import torch
import torch.nn as nn
import torchvision
import tonic
import tonic.transforms as transforms
from torch.utils.data import DataLoader
from tonic import DiskCachedDataset
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils


# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transformations
sensor_size = tonic.datasets.NMNIST.sensor_size
frame_transform = transforms.Compose([
    transforms.Denoise(filter_time=10000),
    transforms.ToFrame(sensor_size=sensor_size, time_window=1000)
])

# Load datasets
trainset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=True)
testset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=False)





Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/1afc103f-8799-464a-a214-81bb9b1f9337 to ./data/NMNIST/train.zip


  0%|          | 0/1011893601 [00:00<?, ?it/s]

Extracting ./data/NMNIST/train.zip to ./data/NMNIST
Downloading https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/a99d0fee-a95b-4231-ad22-988fdb0a2411 to ./data/NMNIST/test.zip


  0%|          | 0/169674850 [00:00<?, ?it/s]

Extracting ./data/NMNIST/test.zip to ./data/NMNIST


In [4]:
def prepare_nmnist_data(inputs, labels, device, n_timesteps=300):
    batch_size = inputs.shape[0]
    inputs = inputs.to(device)
    labels = labels.to(device)
    inputs = inputs[:, :n_timesteps]
    # Reshape inputs to [timesteps, batch_size, input_dim]
    inputs = inputs.reshape(batch_size, n_timesteps, -1).permute(1, 0, 2).float()
    # One-hot encode labels
    labels_onehot = torch.zeros(batch_size, 10, device=device)
    labels_onehot.scatter_(1, labels.unsqueeze(1), 1)
    print(f"\nData preparation stats:")
    print(f"Input shape: {inputs.shape}")
    print(f"Labels shape: {labels_onehot.shape}")

    return inputs, labels_onehot

def train(model, trainloader, num_epochs):
    """
    Train ADMM-SNN model
    """
    # Warming phase
    print("\nWarming Phase:")
    warming_losses = []
    for epoch in range(5):
        print(f'Warming Epoch [{epoch+1}/5]')
        epoch_losses = []
        epoch_accuracies = []

        for batch_idx, (inputs, labels) in enumerate(trainloader):
            print(f'  Batch [{batch_idx+1}/{len(trainloader)}]')
            # Prepare data
            inputs, labels = prepare_nmnist_data(inputs, labels, device)

            # Warming step
            loss, predictions = model.fit(inputs, labels, True)

            # Calculate accuracy
            pred_classes = torch.argmax(predictions, dim=1)
            true_classes = torch.argmax(labels, dim=1)
            accuracy = (pred_classes == true_classes).float().mean().item()

            epoch_losses.append(loss)
            epoch_accuracies.append(accuracy)

        # Epoch summary
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        avg_acc = sum(epoch_accuracies) / len(epoch_accuracies)
        warming_losses.append(avg_loss)
        print(f'  Epoch Summary - Loss: {avg_loss:.6f}, Accuracy: {avg_acc:.4f}')


    # Main training
    print("\nMain Training Phase:")
    training_metrics = []
    for epoch in range(num_epochs):
        print(f'Epoch [{epoch+1}/{num_epochs}]')
        epoch_losses = []
        epoch_accuracies = []

        for batch_idx, (inputs, labels) in enumerate(trainloader):
            print(f'  Batch [{batch_idx+1}/{len(trainloader)}]')

            # Prepare data
            inputs, labels = prepare_nmnist_data(inputs, labels, device)

            # Training step
            loss, predictions = model.fit(inputs, labels,False)

            # Ensure predictions and labels have correct shape
            labels = labels.t()

            # Calculate accuracy
            pred_classes = torch.argmax(predictions, dim=1)
            true_classes = torch.argmax(labels, dim=1)
            accuracy = (pred_classes == true_classes).float().mean().item()

            epoch_losses.append(loss)
            epoch_accuracies.append(accuracy)



        # Epoch summary
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        avg_acc = sum(epoch_accuracies) / len(epoch_accuracies)
        training_metrics.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'accuracy': avg_acc
        })
        print(f'  Epoch Summary - Loss: {avg_loss:.6f}, Accuracy: {avg_acc:.4f}')

    return warming_losses, training_metrics

def evaluate(model, testloader):
    print("\nEvaluation Phase:")
    total_loss = 0
    total_acc = 0
    num_batches = 0
    batch_metrics = []

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(testloader):
            print(f'  Batch [{batch_idx+1}/{len(testloader)}]')

            # Prepare data
            inputs, labels = prepare_nmnist_data(inputs, labels, device)

            # Forward pass
            loss, predictions = model.evaluate(inputs, labels)

            # Calculate accuracy
            pred_classes = torch.argmax(predictions, dim=0)
            true_classes = torch.argmax(labels, dim=0)
            accuracy = (pred_classes == true_classes).float().mean().item()

            # Store batch metrics
            batch_metrics.append({
                'batch': batch_idx + 1,
                'loss': loss,
                'accuracy': accuracy
            })

            total_loss += loss
            total_acc += accuracy
            num_batches += 1

            print(f'    Loss: {loss:.6f}, Accuracy: {accuracy:.4f}')

    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches

    print(f'\nFinal Evaluation Results:')
    print(f'  Average Loss: {avg_loss:.6f}')
    print(f'  Average Accuracy: {avg_acc:.4f}')

    return avg_loss, avg_acc, batch_metrics

# Initialize model and training
n_timesteps = 300
input_dim = 2* sensor_size[0] * sensor_size[1]
hidden_dims = [256,64]
n_outputs = 10
# DataLoaders


# Create DataLoaders with the subsets
batch_size = 64
# Calculate size of 20% of data
train_size = int(0.2 * len(trainset)) // batch_size * batch_size
test_size = int(0.2 * len(testset)) // batch_size * batch_size

from torch.utils.data import Subset
train_subset = Subset(trainset, torch.randperm(len(trainset))[:train_size])
test_subset = Subset(testset, torch.randperm(len(testset))[:test_size])



trainloader = DataLoader(train_subset,
                        batch_size=batch_size,
                        collate_fn=tonic.collation.PadTensors(),
                        shuffle=True)
testloader = DataLoader(test_subset,
                       batch_size=batch_size,
                       collate_fn=tonic.collation.PadTensors())

print(f"Original training set size: {len(trainset)}")
print(f"training set size: {len(train_subset)}")
print(f"Original test set size: {len(testset)}")
print(f"test set size: {len(test_subset)}")



Original training set size: 60000
training set size: 11968
Original test set size: 10000
test set size: 1984


In [5]:
# Hyperparameters
rho = 0.25
delta = 0.85
theta = 0.2


# Create model
print("\nInitializing model...")
model = ADMM_SNN(
    n_samples=batch_size,
    n_timesteps=n_timesteps,
    input_dim=input_dim,
    hidden_dims=hidden_dims,  # Changed final dimension to 10
    n_outputs=10,
    rho=rho,
    delta=delta,
    theta=theta
)
print(model)


# Train model
print("\nStarting training process...")
num_epochs = 20
warming_losses, training_metrics = train(model, trainloader, num_epochs)

# Evaluate model
print("\nEvaluating model...")
test_loss, test_acc, test_metrics = evaluate(model, testloader)

# Print final results
print("\nTraining Summary:")
print(f"  Warming phase final loss: {warming_losses[-1]:.6f}")
print(f"  Training final loss: {training_metrics[-1]['loss']:.6f}")
print(f"  Training final accuracy: {training_metrics[-1]['accuracy']:.4f}")
print("\nTest Results:")
print(f"  Test Loss: {test_loss:.6f}")
print(f"  Test Accuracy: {test_acc:.4f}")


Initializing model...

Model Initialization Details:
Device: cpu
Previous activation shape: torch.Size([300, 64, 2312])
Weight layer 0 shape: torch.Size([256, 2312])
Weight layer 1 shape: torch.Size([64, 256])
Weight layer 2 shape: torch.Size([10, 64])
z[0] shape: torch.Size([300, 64, 256])
z[1] shape: torch.Size([300, 64, 64])
z[2] shape: torch.Size([300, 64, 10])
a[0] shape: torch.Size([300, 64, 256])
a[1] shape: torch.Size([300, 64, 64])
a[2] shape: torch.Size([300, 64, 10])
lambda[0] shape: torch.Size([10])
ADMM SNN Model Structure:
 - rho: 0.25, delta: 0.85, theta: 0.2
 - Number of timesteps: 300
 - Input dimension: torch.Size([300, 64, 2312])
 - Hidden layers: [torch.Size([256, 2312]), torch.Size([64, 256]), torch.Size([10, 64])]
 - Output dimension (Lagrange Multiplier): torch.Size([64, 10])
 - self.L : 3
 - self.T : 300


Starting training process...

Warming Phase:
Warming Epoch [1/5]
  Batch [1/187]

Data preparation stats:
Input shape: torch.Size([300, 64, 2312])
Labels sha

KeyboardInterrupt: 