<a href="https://colab.research.google.com/github/ergysmedaunipd/thesis/blob/main/ThesisUnipdSNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# %pip (not !pip) installs into the running kernel's environment.
# tonic is pinned to the version this notebook was last run with; pin snntorch once verified.
%pip install -q tonic==1.5.0 snntorch psutil

Collecting tonic
  Downloading tonic-1.5.0-py3-none-any.whl.metadata (5.4 kB)
Collecting importRosbag>=1.0.4 (from tonic)
  Downloading importRosbag-1.0.4-py3-none-any.whl.metadata (4.3 kB)
Collecting pbr (from tonic)
  Downloading pbr-6.1.0-py2.py3-none-any.whl.metadata (3.4 kB)
Collecting expelliarmus (from tonic)
  Downloading expelliarmus-1.1.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading tonic-1.5.0-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.6/116.6 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading importRosbag-1.0.4-py3-none-any.whl (28 kB)
Downloading expelliarmus-1.1.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pbr-6.1.0-py2.py3-none-any.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import time
import snntorch as snn
from snntorch import surrogate


from typing import List


class ADMM_SNN:
    """ Class for ADMM Neural Network. """

    def __init__(self, n_samples: int, n_timesteps: int, input_dim: int, hidden_dims: List[int], n_outputs: int, rho: float, delta: float, theta: float):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.loss_fn = nn.CrossEntropyLoss()
        # Define hyperparameters:
        self.rho = rho
        self.delta = delta
        self.theta = theta

        self.L = len(hidden_dims)  # Number of layers
        self.T = n_timesteps       # Number of timesteps

        self.a0 = torch.zeros((n_timesteps, n_samples, input_dim)).to(self.device)
        print(f"Initialized a0 with shape: {self.a0.shape}")  # Debugging

        # Initialize W_l (weights for each layer)
        self.W = nn.ParameterList()
        for i, hidden_dim in enumerate(hidden_dims):
                if i == 0:
                    self.W.append(nn.Parameter(
                        torch.randn((hidden_dim, input_dim)).to(self.device) *
                        np.sqrt(2.0 / input_dim)
                    ))
                else:
                    self.W.append(nn.Parameter(
                        torch.randn((hidden_dim, hidden_dims[i-1])).to(self.device) *
                        np.sqrt(2.0 / hidden_dims[i-1])
                    ))

        # Initialize z_l (intermediate variables for each layer and timestep)
        self.z = [torch.rand((n_timesteps, n_samples, hidden_dim)).to(self.device) for hidden_dim in hidden_dims]
        for i, z_layer in enumerate(self.z):
            print(f"Initialized z[{i}] with shape: {z_layer.shape}")  # Debugging

        # Initialize a_l (activations for each layer and timestep)
        self.a = [torch.rand((n_timesteps, n_samples, hidden_dim)).to(self.device) for hidden_dim in hidden_dims]
        for i, a_layer in enumerate(self.a):
            print(f"Initialized a[{i}] with shape: {a_layer.shape}")  # Debugging

        # Initialize lambda (Lagrange multipliers for the output layer)
        self.lambda_lagrange = torch.ones((n_samples, n_outputs)).to(self.device)
        print(f"Initialized lambda_lagrange with shape: {self.lambda_lagrange.shape}")  # Debugging


    def __str__(self):
        model_str = "ADMM SNN Model Structure:\n"
        model_str += f" - rho: {self.rho}, delta: {self.delta}, theta: {self.theta}\n"
        model_str += f" - Number of timesteps: {self.T}\n"
        model_str += f" - Input dimension: {self.a0.size()}\n"
        model_str += f" - Hidden layers: {[w.shape for w in self.W]}\n"
        model_str += f" - Output dimension (Lagrange Multiplier): {self.lambda_lagrange.size()}\n"
        return model_str

    def _heaviside(self, z):
        """
        Implement the Heaviside function to calculate activations based on a single threshold.
        This function returns 1 where z exceeds the threshold and 0 otherwise.

        Parameters:
        z (torch.Tensor): The input tensor (intermediate variable z at layer l and time t).

        Returns:
        torch.Tensor: A tensor of the same shape as z with values 0 or 1 based on the threshold self.theta.
        """
        return (z >= self.theta).float()

    # ============ W_{l} update functions ============
    def _weight_update(self, z_l, a_l_minus_1):
        """
        Update the weights for intermediate layers (1 to L-1) based on Equation (4).
        Parameters:
        - z_l (torch.Tensor): Shape [T, batch, n_features]
        - a_l_minus_1 (torch.Tensor): Shape [T, batch, n_prev_features]

        Returns:
        - Updated weights for the layer.
        """
        # Define alpha as rho / 2

        alpha = self.rho / 2

        # Get dimensions
        n_timesteps, batch_size, n_features = z_l.shape
        _, _, n_prev_features = a_l_minus_1.shape

        # Initialize accumulator tensors
        numerator = torch.zeros((n_features, n_prev_features), device=self.device)
        denominator = torch.zeros((n_prev_features, n_prev_features), device=self.device)

        # Iterate over timesteps
        for t in range(n_timesteps):
            if t > 0:
                # For each timestep after 0, compute the adjusted z term
                # All terms should maintain their original dimensions
                current_z = z_l[t]  # [batch, n_features]
                prev_z = z_l[t - 1]  # [batch, n_features]
                current_a = a_l_minus_1[t]  # [batch, n_prev_features]
                prev_a = a_l_minus_1[t - 1]  # [batch, n_prev_features]

                # Compute z term adjustments within the same feature space
                z_diff = current_z - self.delta * prev_z  # [batch, n_features]

                # Add contribution to numerator
                # Use einsum to handle the batch dimension properly
                term = torch.einsum('bf,bp->fp', z_diff, current_a)
                numerator += alpha * term

                # Add theta * prev_a term separately to avoid dimension mismatch
                theta_term = torch.einsum('bp,bf->fp', prev_a, z_diff)
                numerator += alpha * self.theta * theta_term
            else:
                # For t=0, use simple matrix multiplication
                term = torch.einsum('bf,bp->fp', z_l[0], a_l_minus_1[0])
                numerator += alpha * term

            # Update denominator using current timestep activations
            denom_term = torch.einsum('bp,bq->pq', a_l_minus_1[t], a_l_minus_1[t])
            denominator += alpha * denom_term

        # Add small value to diagonal of denominator for numerical stability
        denominator += torch.eye(denominator.shape[0], device=self.device) * 1e-6

        # Compute final weight update
        updated_weights = numerator @ torch.inverse(denominator)

        return updated_weights

    def _weight_update_L(self, z_L, a_L_minus_1, y):
        """
        Update the weights for the final layer (L) based on Equation (6).
        """
        # Define alpha as rho / 2
        alpha = self.rho / 2

        # Get dimensions
        n_timesteps, batch_size, n_features = z_L.shape
        _, _, n_prev_features = a_L_minus_1.shape

        # Initialize accumulator tensors
        numerator = torch.zeros((n_features, n_prev_features), device=self.device)
        denominator = torch.zeros((n_prev_features, n_prev_features), device=self.device)

        # Handle the last timestep separately for Lagrange multiplier term
        a_L_minus_1_T = a_L_minus_1[-1]  # [batch, n_prev_features]

        # Project lambda and y to match feature dimensions if needed
        if self.lambda_lagrange.shape[1] != n_features:
            # Project lambda from [batch, 10] to [batch, n_features]
            lambda_proj = torch.zeros((batch_size, n_features), device=self.device)
            lambda_proj[:, :self.lambda_lagrange.shape[1]] = self.lambda_lagrange

            # Project y from [10, batch] to [batch, n_features]
            y = y.t()  # [batch, 10]
            y_proj = torch.zeros((batch_size, n_features), device=self.device)
            y_proj[:, :y.shape[1]] = y
        else:
            lambda_proj = self.lambda_lagrange
            y_proj = y.t()

        # Compute Lagrange multiplier contribution with correct dimensions
        lambda_term = torch.einsum('bf,bp->fp', lambda_proj, a_L_minus_1_T) / 2
        numerator += lambda_term

        # Process all timesteps
        for t in range(n_timesteps):
            if t > 0:
                # Compute z term adjustments
                z_diff = z_L[t] - self.delta * z_L[t - 1]  # [batch, n_features]

                # Update numerator
                term = torch.einsum('bf,bp->fp', z_diff, a_L_minus_1[t])
                numerator += alpha * term
            else:
                # First timestep
                term = torch.einsum('bf,bp->fp', z_L[0], a_L_minus_1[0])
                numerator += alpha * term

            # Update denominator
            denom_term = torch.einsum('bp,bq->pq', a_L_minus_1[t], a_L_minus_1[t])
            denominator += alpha * denom_term

        # Add small value to diagonal of denominator for numerical stability
        denominator += torch.eye(denominator.shape[0], device=self.device) * 1e-6

        # Compute final weight update
        updated_weights = numerator @ torch.inverse(denominator)

        return updated_weights
    # ============ z_{l,t} update functions ============
    def _z_update(self, l, t):
        """
        Update z_{l,t} for l = 1, ..., L-1 and t = 1, ..., T-1 based on Equation (14).
        Now handles batched operations properly.
        """
        alpha = self.rho / 2
        q_l_t = self._calculate_q_l(l, t)  # Shape: [batch_size, n_features]
        r_l_t_plus1 = self._calculate_r_l(l, t + 1)  # Shape: [batch_size, n_features]

        # Compute components for update
        if t < self.T:
            a_term = self.theta * self.a[l][t]  # Shape: [batch_size, n_features]
            r_term = r_l_t_plus1 + a_term
            numerator = alpha * q_l_t + alpha * self.delta * r_term
            denominator = alpha + self.delta**2 * alpha
        else:
            numerator = alpha * q_l_t
            denominator = alpha

        # Perform division (broadcasting handles batch dimension)
        z_update = numerator / denominator

        return z_update

    def _z_update_T(self, l):
        """
        Update z_{l,T} for l = 1, ..., L-1 based on Equation (14)*.
        Applies Heaviside function for thresholding.
        """
        alpha = self.rho / 2
        q_l_T = self._calculate_q_l(l, self.T)  # Get q_{l,T}
        r_l_T_minus1 = self._calculate_r_l(l, self.T - 1)  # Get r_{l,T-1}

        # Compute the update for z_{l,T} based on Equation (14)* with the Heaviside function
        numerator = alpha * q_l_T + alpha * self.delta * (r_l_T_minus1 + self.theta * self.a[l][self.T - 1])
        denominator = alpha + self.delta**2 * alpha
        z_update_T = numerator / denominator

        # Apply Heaviside function to enforce thresholding
        return self._heaviside(z_update_T)

    def _z_update_L(self, y):
        """
        Update z_{L,T} for the last layer L based on Equation (16).
        Following similar structure to _z_update with proper batch handling.
        """
        alpha = self.rho / 2
        s_L_T = self._calculate_s_L()  # Shape: [batch_size, n_features]
        r_L_T_minus1 = self._calculate_r_l(self.L - 1, self.T - 1)  # Shape: [batch_size, n_features]

        # Transpose y to match batch dimension
        y = y.t()  # [128, 10]

        # Compute terms with proper broadcasting
        numerator = alpha * s_L_T  # [128, 64]
        r_term = r_L_T_minus1  # [128, 64]

        # Project y and lambda to feature space if needed
        if y.shape[1] != r_term.shape[1]:
            W_L = self.W[self.L-1]  # [64, 128]
            y = torch.matmul(y, W_L[:, :y.shape[1]].t())  # [128, 64]
            lambda_term = torch.matmul(self.lambda_lagrange, W_L[:, :self.lambda_lagrange.shape[1]].t())  # [128, 64]
        else:
            lambda_term = self.lambda_lagrange

        # Add terms with proper broadcasting
        r_term = r_term + y - lambda_term / 2  # [128, 64]
        numerator = numerator + alpha * self.delta * r_term  # [128, 64]
        denominator = alpha + self.delta**2 * alpha  # scalar

        # Compute final update
        z_update = numerator / denominator  # [128, 64]

        # Apply Heaviside function to enforce thresholding
        return self._heaviside(z_update)

    def _z_update_L_T(self, y):
        """
        Update z_{L,T} for the last layer L at time T.
        Following same structure as _z_update_L.
        """
        alpha = self.rho / 2
        s_L_T = self._calculate_s_L()  # [batch_size, n_features]
        r_L_T = self._calculate_r_l(self.L - 1, self.T - 1)  # [batch_size, n_features]

        # Transpose y to match batch dimension
        y = y.t()  # [128, 10]

        # Compute terms with proper broadcasting
        numerator = alpha * s_L_T  # [128, 64]
        r_term = r_L_T  # [128, 64]

        # Project y and lambda to feature space if needed
        if y.shape[1] != r_term.shape[1]:
            W_L = self.W[self.L-1]  # [64, 128]
            y = torch.matmul(y, W_L[:, :y.shape[1]].t())  # [128, 64]
            lambda_term = torch.matmul(self.lambda_lagrange, W_L[:, :self.lambda_lagrange.shape[1]].t())  # [128, 64]
        else:
            lambda_term = self.lambda_lagrange

        # Add terms with proper broadcasting
        r_term = r_term + y - lambda_term / 2  # [128, 64]
        numerator = numerator + alpha * self.delta * r_term  # [128, 64]
        denominator = alpha + self.delta**2 * alpha  # scalar

        # Compute final update
        z_update = numerator / denominator  # [128, 64]


        # Apply Heaviside function to enforce thresholding
        return self._heaviside(z_update)

    def check_entries(self, z, cost_function):
        """
        Implements Algorithm 1 to check and adjust entries in z_{l,t}.
        Handles batched operations properly.

        Parameters:
        - z (torch.Tensor): Input tensor with shape [batch_size, n_features]
        - cost_function: Function to compute cost (e.g., torch.norm)

        Returns:
        - torch.Tensor: Adjusted tensor with same shape as input
        """
        # Get original shape and reshape if needed
        original_shape = z.shape
        if len(original_shape) > 2:
            z = z.view(-1, original_shape[-1])

        # Create a copy to modify
        z_adjusted = z.clone()

        # Get mask for values above threshold
        above_threshold = z > self.theta

        # Compute costs for current values and threshold
        current_costs = cost_function(z_adjusted)
        threshold_costs = cost_function(torch.full_like(z_adjusted, self.theta))

        # Where current cost is higher than threshold cost, switch to threshold
        should_switch_off = above_threshold & (current_costs > threshold_costs)
        z_adjusted[should_switch_off] = self.theta

        # Check for values below or at threshold that should be switched on
        below_or_at_threshold = ~above_threshold
        switched_on_values = z_adjusted + self.theta
        switched_on_costs = cost_function(switched_on_values)
        should_switch_on = below_or_at_threshold & (switched_on_costs < current_costs)
        z_adjusted[should_switch_on] = switched_on_values[should_switch_on]

        # Apply Heaviside function
        z_final = self._heaviside(z_adjusted)

        # Reshape back to original shape if needed
        if len(original_shape) > 2:
            z_final = z_final.view(original_shape)

        return z_final

    def _calculate_q_l(self, l, t):
        """
        Calculate q_{l,t} based on s_{l,t} and activation threshold.
        Following the paper's equation definitions.
        """
        s_l_t = self._calculate_s_l(l, t)

        # Make sure the theta multiplication maintains proper dimensions
        theta_term = self.theta * self.a[l][t - 1]  # This should broadcast correctly

        # Dimensions should match for subtraction
        result = s_l_t - theta_term
        return result

    def _calculate_r_l(self, l, t):
        """
        Calculate r_{l,t} based on z_{l,t} and weight projection.
        Following the paper's equation definitions.
        """

        # Handle the case when t is out of bounds
        if t >= len(self.a[l-1]):
            t = len(self.a[l-1]) - 1

        # Get activation from previous layer
        a_prev = self.a[l-1][t]  # Shape: [batch_size, n_prev_features]

        # Compute W_l @ a_{l-1,t} with proper dimensions
        p_l = torch.matmul(a_prev, self.W[l].t())  # Shape: [batch_size, n_features]

        # Make sure z has same shape for subtraction
        if t < len(self.z[l]):
            z_term = self.z[l][t]
        else:
            # If t is beyond z's length, use the last timestep
            z_term = self.z[l][-1]

        # Compute result with proper dimensions
        result = -p_l + z_term
        return result

    def _calculate_s_l(self, l, t):
        """
        Calculate s_{l,t} for intermediate updates.
        Following the paper's equation definitions.
        """

        # Handle the case when t is out of bounds
        if t >= len(self.a[l-1]):
            t = len(self.a[l-1]) - 1

        # Get activation from previous layer
        a_prev = self.a[l-1][t]  # Shape: [batch_size, n_prev_features]

        # Compute W_l @ a_{l-1,t} with proper dimensions
        p_l = torch.matmul(a_prev, self.W[l].t())  # Shape: [batch_size, n_features]

        if t > 1 and t-1 < len(self.z[l]):
            # Add delta * z_{l,t-1} term if t > 1
            z_prev = self.z[l][t-1]
            result = p_l + self.delta * z_prev
        else:
            result = p_l

        return result

    def _calculate_s_L(self):
        """
        Calculate s_L for the last layer update.
        Following the paper's equation definitions.
        """

        # Get activation from second-to-last layer at time T
        a_Lminus1_T = self.a[self.L-2][self.T-1]  # Shape: [batch_size, n_prev_features]

        # Compute W_L @ a_{L-1,T} with proper dimensions
        result = torch.matmul(a_Lminus1_T, self.W[self.L-1].t())  # Shape: [batch_size, n_features]
        return result

    # ============ a_{l,t} update functions ============
    # Activation update for l=1,...,L-2, t=1,...,T-1 (Equation 8)
    def _activation_update(self, u_l, w_l, v_l, t):
        """
        Update activations based on equation (8) in the paper:
        For l=1,...,L-2, t=1,...,T-1
        """
        alpha = self.rho / 2
        beta = alpha
        batch_size = u_l.shape[0]
        n_features = u_l.shape[1]
        next_layer_idx = min(t + 1, len(self.W)-1)
        W_next = self.W[next_layer_idx]

        # Calculate θ²I term
        theta_squared_term = (self.delta**2 * alpha * torch.eye(n_features, device=self.device))

        # Calculate W_{l+1}^T W_{l+1} term
        w_term = alpha * (W_next.t() @ W_next)

        # Combine terms
        term1 = theta_squared_term + w_term + beta * torch.eye(n_features, device=self.device)
        # Initialize right-hand side term
        term2 = torch.zeros((n_features, batch_size), device=self.device)

        # Add -θw_{l,t+1} term
        term2 = term2 - self.delta * alpha * w_l.t()

        # Add W_{l+1}^T v_{l+1} term
        term2 = term2 + alpha * W_next.t() @ v_l.t()

        # Add Heaviside term
        if t < len(self.z):
            z_term = self.z[t] - self.theta
            h_term = self._heaviside(z_term)
            term2 = term2 + beta * h_term.t()


        # Solve the system
        update = torch.linalg.solve(term1, term2)

        # Transpose back to match expected shape
        update = update.t()

        return update

    # Activation update for l=1,...,L-2 at t=T (Equation 9)
    def _activation_update_T(self, u_l, w_l, v_l):
        """
        Update activations for l=1,...,L-2 at t=T based on equation (9)
        """
        alpha = self.rho / 2
        beta = alpha
        batch_size = u_l.shape[0]
        n_features = u_l.shape[1]

        # Calculate W_{l+1}^T W_{l+1} term
        term1 = alpha * self.W[self.L-2].t() @ self.W[self.L-2]

        # Add identity term
        term1 = term1 + beta * torch.eye(n_features, device=self.device)

        # Initialize right-hand side term
        term2 = torch.zeros((batch_size, n_features), device=self.device)

        # Add W_{l+1}^T v_{l+1} term
        term2 = term2 + alpha * self.W[self.L-2].t() @ v_l

        # Add Heaviside term
        term2 = term2 + beta * self._heaviside(self.z[self.L-2][self.T] - self.theta)

        # Solve the system for each batch
        update = torch.solve(term2.t(), term1)[0].t()

        return update

    # Activation update for l=L-1, t=1,...,T-1 (Equation 10)
    def _activation_update_Lminus1(self, u_Lminus1, w_Lminus1, v_Lminus1, t):
        """
        Update activations for l=L-1, t=1,...,T-1 based on equation (10)
        """
        alpha = self.rho / 2
        beta = alpha
        batch_size = u_Lminus1.shape[0]
        n_features = u_Lminus1.shape[1]  # This is 64 (output of previous layer)
        W_L = self.W[self.L-1]  # Shape: [64, 128]


        # Calculate W_L W_L^T term with correct dimensions
        term1 = alpha * (W_L @ W_L.t())  # Shape: [64, 64]

        # Add identity term matching W_L dimensions
        term1 = term1 + beta * torch.eye(n_features, device=self.device)  # Shape: [64, 64]

        # Initialize right-hand side term
        term2 = torch.zeros((n_features, batch_size), device=self.device)  # Shape: [64, 128]

        # Add W_L^T v_L term - need to transpose v_Lminus1 for correct dimensions
        term2 = term2 + alpha * v_Lminus1.t()  # Shape: [64, 128]

        # Add Heaviside term
        if t < len(self.z[self.L-1]):
            z_term = self.z[self.L-1][t] - self.theta
            h_term = self._heaviside(z_term)
            term2 = term2 + beta * h_term.t()

        # Add Lagrange multiplier term only at final timestep
        if t == self.T:
            term2 = term2 + W_L.t() @ self.lambda_lagrange.t()


        # Solve the system
        # term1: [64, 64], term2: [64, 128]
        update = torch.linalg.solve(term1, term2)  # Shape: [64, 128]

        # Transpose back to match expected shape
        update = update.t()  # Shape: [128, 64]

        return update

    # Activation update for l=L-1, t=T (Equation 11)
    def _activation_update_Lminus1_T(self, u_Lminus1, w_Lminus1, v_Lminus1):
        """
        Update activations for l=L-1, t=T based on equation (11)
        Following equation: a_{L-1,T} = (W_L^T W_L + I)^{-1} [W_L^T (αv_L - λ/2) + β h(z_{L-1,T} - θ)]
        """
        alpha = self.rho / 2
        beta = alpha
        batch_size = u_Lminus1.shape[0]  # 128
        n_features = u_Lminus1.shape[1]  # 64
        W_L = self.W[self.L-1]  # Shape: [64, 128]


        # Calculate W_L W_L^T term with correct dimensions
        term1 = alpha * (W_L @ W_L.t())  # Shape: [64, 64]

        # Add identity term matching W_L dimensions
        term1 = term1 + beta * torch.eye(n_features, device=self.device)  # Shape: [64, 64]

        # Initialize right-hand side term
        term2 = torch.zeros((n_features, batch_size), device=self.device)  # Shape: [64, 128]

        # Add W_L^T v_L term - need to transpose v_Lminus1 for correct dimensions
        term2 = term2 + alpha * v_Lminus1.t()  # Shape: [64, 128]

        # Add Heaviside term for final timestep
        z_term = self.z[self.L-1][self.T-1] - self.theta  # Use T-1 for indexing
        h_term = self._heaviside(z_term)
        term2 = term2 + beta * h_term.t()  # Shape: [64, 128]

        # Add Lagrange multiplier term for final timestep with proper dimensions
        # lambda_lagrange: [128, 10]
        # First reshape lambda to match batch dimension: [128, 64]
        lambda_reshaped = self.lambda_lagrange @ W_L[:, :10].t()  # [128, 64]
        term2 = term2 - (lambda_reshaped.t() / 2)  # [64, 128]


        # Solve the system
        # term1: [64, 64], term2: [64, 128]
        update = torch.linalg.solve(term1, term2)  # Shape: [64, 128]

        # Transpose back to match expected shape
        update = update.t()  # Shape: [128, 64]

        return update
    # ============ lagrange multiplier update ============

    def _lambda_update(self, y):
        """
        Update the Lagrange multiplier lambda.

        Parameters:
        - y (torch.Tensor): Target values [num_classes, batch_size]

        Returns:
        - Updated lambda [batch_size, num_classes]
        """
        # Get last timestep activations: [batch_size, hidden_dim]
        last_layer_activations = self.a[self.L - 2][self.T-1]

        # Get last timestep outputs: [batch_size, num_classes]
        last_layer_outputs = self.z[self.L - 1][self.T-1]

        # Compute W @ a term: [batch_size, num_classes]
        # W: [num_classes, hidden_dim], a: [batch_size, hidden_dim]
        Wa_term = last_layer_activations @ self.W[self.L - 1].t()

        # Compute rho term: [batch_size, num_classes]
        rho_term = self.rho * (last_layer_outputs - Wa_term)

        # Ensure y has shape [batch_size, num_classes]
        if y.shape[0] != last_layer_outputs.shape[0]:
            y = y.t()

        # Compute y term: [batch_size, num_classes]
        y_term = y - last_layer_outputs

        # Update lambda
        self.lambda_lagrange = self.lambda_lagrange + rho_term + y_term

        return self.lambda_lagrange

    def feed_forward(self, inputs):
        """
        Run the inputs through a stack of snntorch Leaky (LIF) layers over all timesteps.

        Parameters:
        - inputs (torch.Tensor): Shape [timesteps, in_features, batch_size]

        Returns:
        - torch.Tensor: membrane potential of the last layer after the final timestep,
          shape [batch_size, self.W[-1].shape[0]].

        Every layer uses decay beta = self.delta and threshold self.theta, neither of
        which is learned. Hidden layers reset by subtraction after a spike. The last
        layer never resets, so its membrane keeps integrating (with leak) and serves as
        the class score.
        """
        n_steps, _, batch_size = inputs.shape
        n_layers = len(self.W)

        membranes = [torch.zeros(batch_size, weight.shape[0], device=self.device) for weight in self.W]

        neurons = [
            snn.Leaky(
                beta=self.delta,
                threshold=self.theta,
                reset_mechanism="none" if idx == n_layers - 1 else "subtract",
                learn_beta=False,
                learn_threshold=False,
            )
            for idx in range(n_layers)
        ]

        for step in range(n_steps):
            signal = inputs[step].T  # [batch_size, in_features]
            for idx, (weight, neuron) in enumerate(zip(self.W, neurons)):
                # Spikes of each layer are the input current of the next one.
                signal, membranes[idx] = neuron(signal @ weight.T, membranes[idx])

        return membranes[-1]
    def calculate_loss(self, raw_predictions, targets):
        """
        Calculate Cross Entropy Loss

        Parameters:
        - raw_predictions (torch.Tensor): Output activities [batch_size, num_classes]
        - targets (torch.Tensor): Target values as class indices [batch_size]
        """
        # Convert one-hot encoded targets to class indices if necessary
        if targets.shape[0] != raw_predictions.shape[0]:
            targets = targets.t()
        if targets.dim() > 1:
            targets = targets.argmax(dim=1)

        # Calculate cross-entropy loss
        loss = self.loss_fn(raw_predictions, targets)

        return loss
    def evaluate(self, inputs, targets):
        """
        Evaluate N-MNIST specific metrics based on class membership
        """
        print("\n=== N-MNIST Analysis ===")

        # Get final layer activity
        raw_predictions = self.feed_forward(inputs)  # [batch_size, 10]
        print("\nFinal Layer Activity:")
        print(f"Shape: {raw_predictions.shape}")
        print(raw_predictions)
        print(f"Activity Stats - Mean: {raw_predictions.mean():.4f}, Std: {raw_predictions.std():.4f}")
        print(f"Activity Range - Min: {raw_predictions.min():.4f}, Max: {raw_predictions.max():.4f}")

        # Threshold-based class assignments
        threshold = self.theta
        class_assignments = (raw_predictions > threshold).float()

        # Get predicted classes and actual classes
        pred_classes = raw_predictions.argmax(dim=1)
        true_classes = targets.t().argmax(dim=1)

        # Initialize metrics storage
        metrics_per_class = []

        print("\nConfusion Matrix and Metrics:")
        for class_idx in range(10):
            # Get predictions for each true class
            predicted_for_class = pred_classes[true_classes == class_idx]

            # Confusion matrix row for true class `class_idx`
            predictions_count = [(predicted_for_class == pred_class).sum().item() for pred_class in range(10)]
            print(f"Real Class {class_idx}: [{', '.join(map(str, predictions_count))}]")

            # Calculate True Positives, False Positives, False Negatives
            TP = (pred_classes == true_classes) & (true_classes == class_idx)
            FP = (pred_classes == class_idx) & (true_classes != class_idx)
            FN = (pred_classes != class_idx) & (true_classes == class_idx)

            TP_count = TP.sum().float()
            FP_count = FP.sum().float()
            FN_count = FN.sum().float()

            # Calculate precision, recall, F1 for each class
            precision = TP_count / (TP_count + FP_count) if TP_count + FP_count > 0 else torch.tensor(0.0)
            recall = TP_count / (TP_count + FN_count) if TP_count + FN_count > 0 else torch.tensor(0.0)
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0)

            # Print metrics in one line for each class
            print(f"Class {class_idx} - Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

            # Store metrics for macro-averaging
            metrics_per_class.append({'precision': precision, 'recall': recall, 'f1': f1})

        # Calculate macro-averaged metrics
        macro_precision = torch.stack([m['precision'] for m in metrics_per_class]).mean()
        macro_recall = torch.stack([m['recall'] for m in metrics_per_class]).mean()
        macro_f1 = torch.stack([m['f1'] for m in metrics_per_class]).mean()
        accuracy = (pred_classes == true_classes).float().mean()

        print("\nOverall Metrics:")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Macro Precision: {macro_precision:.4f}")
        print(f"Macro Recall: {macro_recall:.4f}")
        print(f"Macro F1-Score: {macro_f1:.4f}")

        # Calculate firing accuracy
        firing_correct = class_assignments[range(len(true_classes)), true_classes].sum()
        firing_accuracy = firing_correct / len(true_classes)
        print(f"\nFiring Accuracy: {firing_accuracy:.4f}")

        # Calculate and print loss
        loss = self.calculate_loss(raw_predictions, targets)
        print(f"\nLoss: {loss.item():.6f}")

        return loss, raw_predictions

    def fit(self, inputs, targets):
        """
        Run one ADMM training sweep over layers 1..L-1, update the Lagrange
        multiplier, then evaluate the network.

        For every layer l >= 1 the sweep updates W[l], then the activations a[l][t]
        and auxiliary variables z[l][t] for the interior timesteps 1..T-2, and
        finally timestep T-1 with the dedicated ``*_T`` rules. The last layer is
        then re-updated with the target-aware rules (``_weight_update_L``,
        ``_z_update_L``, ``_z_update_L_T``), which overwrite the values the layer
        loop wrote for it.

        Parameters:
        - inputs: batch forwarded to ``self.evaluate``; the update steps do not read it.
        - targets: used by the last-layer updates, ``_lambda_update`` and ``self.evaluate``.

        Returns:
        - (loss, predictions) exactly as returned by ``self.evaluate(inputs, targets)``.

        NOTE(review): the layer loop starts at l = 1 and the time loops at t = 1,
        so W[0], z[0], a[0] and timestep 0 of every layer are never updated here —
        confirm this is intended.
        """
        def norm_cost(z):
            # Cost function handed to check_entries: the norm of the candidate tensor.
            return torch.norm(z)

        last = self.L - 1
        final_t = self.T - 1  # last valid timestep index

        for l in range(1, self.L):
            # Generic weight update; for the last layer it is overwritten below
            # by _weight_update_L.
            self.W[l] = self._weight_update(self.z[l], self.a[l - 1])
            is_hidden = l < last

            # Interior timesteps 1 .. T-2 (timestep 0 is not updated).
            for t in range(1, final_t):
                u_l, w_l, v_l = self._calculate_u_w_v(l, t)
                if is_hidden:
                    self.a[l][t] = self._activation_update(u_l, w_l, v_l, t)
                else:
                    self.a[l][t] = self._activation_update_Lminus1(u_l, w_l, v_l, t)
                self.z[l][t] = self._z_update(l, t)
                self.z[l][t] = self.check_entries(self.z[l][t], cost_function=norm_cost)

            # Final timestep T-1 uses the dedicated *_T update rules.
            u_l, w_l, v_l = self._calculate_u_w_v(l, final_t)
            if is_hidden:
                self.a[l][final_t] = self._activation_update_T(u_l, w_l, v_l)
                self.z[l][final_t] = self._z_update_T(l)
            else:
                self.a[l][final_t] = self._activation_update_Lminus1_T(u_l, w_l, v_l)
                self.z[l][final_t] = self._z_update_L(targets)
            self.z[l][final_t] = self.check_entries(self.z[l][final_t], cost_function=norm_cost)

        # ----- Target-aware update of the last layer (replaces W/z set above) -----
        self.W[last] = self._weight_update_L(self.z[last], self.a[last - 1], targets)

        for t in range(1, final_t):
            self.z[last][t] = self._z_update_L(targets)
            self.z[last][t] = self.check_entries(self.z[last][t], cost_function=norm_cost)

        self.z[last][final_t] = self._z_update_L_T(targets)
        self.z[last][final_t] = self.check_entries(self.z[last][final_t], cost_function=norm_cost)

        # Dual (Lagrange multiplier) step, then evaluation of the updated network.
        self.lambda_lagrange = self._lambda_update(targets)
        loss, predictions = self.evaluate(inputs, targets)

        print("\nTraining phase interim metrics:")
        print(f"  Loss: {loss:.6f}")

        return loss, predictions

    def warming(self, inputs, targets):
        """
        Warm-up pass: apply the layer-wise ADMM updates (weights, activations,
        auxiliary z) without touching the Lagrange multiplier, then evaluate and
        print a short report.

        Layers 1..L-1 are visited in order. For each, W is updated first, then
        timesteps 1..T-2 with the per-step rules, then timestep T-1 with the
        ``*_T`` rules. Afterwards the last layer is re-updated with the
        target-aware rules, which replace what the layer loop wrote for it.

        Parameters:
        - inputs: forwarded to ``self.evaluate``; the update steps do not read it.
        - targets: used by the last-layer updates and by ``self.evaluate``.

        Returns:
        - (loss, predictions) exactly as returned by ``self.evaluate(inputs, targets)``.
        """
        final_layer = self.L - 1
        final_step = self.T - 1

        def commit_z(layer, step, candidate):
            # Store the raw update, then overwrite it with the check_entries result.
            self.z[layer][step] = candidate
            self.z[layer][step] = self.check_entries(
                self.z[layer][step], cost_function=lambda z: torch.norm(z)
            )

        for layer in range(1, self.L):
            self.W[layer] = self._weight_update(self.z[layer], self.a[layer - 1])

            # Hidden layers and the last layer use different activation rules.
            hidden = layer < final_layer
            step_rule = self._activation_update if hidden else self._activation_update_Lminus1
            last_rule = self._activation_update_T if hidden else self._activation_update_Lminus1_T

            for step in range(1, final_step):
                u_vec, w_vec, v_vec = self._calculate_u_w_v(layer, step)
                self.a[layer][step] = step_rule(u_vec, w_vec, v_vec, step)
                commit_z(layer, step, self._z_update(layer, step))

            u_vec, w_vec, v_vec = self._calculate_u_w_v(layer, final_step)
            self.a[layer][final_step] = last_rule(u_vec, w_vec, v_vec)
            commit_z(
                layer,
                final_step,
                self._z_update_T(layer) if hidden else self._z_update_L(targets),
            )

        # Target-aware pass over the last layer; it replaces that layer's W and z.
        self.W[final_layer] = self._weight_update_L(
            self.z[final_layer], self.a[final_layer - 1], targets
        )
        for step in range(1, final_step):
            commit_z(final_layer, step, self._z_update_L(targets))
        commit_z(final_layer, final_step, self._z_update_L_T(targets))

        loss, predictions = self.evaluate(inputs, targets)

        print("\n=== Final State ===")
        print(f"Final weight samples: {self.W[-1][0, :5]}")
        print(f"Final lambda sample: {self.lambda_lagrange[0, :5]}")
        print("\nWarming phase interim metrics:")
        print(f"  Loss: {loss:.6f}")

        return loss, predictions


    def _calculate_u_w_v(self, l, t):
        """
        Calculate vectors based on equations in Section 2.2 of the paper:
        u_l = [z_{l,1}, z_{l,2} - δz_{l,1}, ..., z_{l,T} - δz_{l,T-1}]
        v_l = u_l + ϑ[0, a_{l,1}, ..., a_{l,T-1}]
        w_l = u_l - W_l[a_{l-1,1}, ..., a_{l-1,T}]

        Parameters:
        - l (int): Layer index
        - t (int): Time step

        Returns:
        - u_l, w_l, v_l (torch.Tensor): Calculated vectors with shape [batch_size, n_features]
        """
        batch_size = self.z[l].shape[1]
        n_features = self.z[l].shape[2]
        device = self.device

        # Calculate u_l
        if t == 0:  # First timestep
            u_l = self.z[l][0]  # Shape: [batch_size, n_features]
        else:
            # z_{l,t} - δz_{l,t-1}
            u_l = self.z[l][t] - self.delta * self.z[l][t-1]  # Shape: [batch_size, n_features]

        # Calculate v_l with proper broadcasting
        if t == 0:
            v_l = u_l  # No previous activations for t=0
        else:
            # Add ϑ * previous activation
            v_l = u_l.clone()
            if t > 1 and t-1 < len(self.a[l]):
                v_l = v_l + self.theta * self.a[l][t-1]  # Shape: [batch_size, n_features]

        # Calculate w_l
        if l > 0:  # Not input layer
            # Get previous layer activation
            a_prev = self.a[l-1][t]  # Shape: [batch_size, n_prev_features]

            # Compute W_l @ a_{l-1,t} with proper dimensions
            w_l = u_l - torch.matmul(a_prev, self.W[l].t())  # Shape: [batch_size, n_features]
        else:
            w_l = u_l


        return u_l, v_l, w_l


In [None]:
import torch
import torch.nn as nn
import torchvision
import tonic
import tonic.transforms as transforms
from torch.utils.data import DataLoader
from tonic import DiskCachedDataset
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils


# Device configuration: prefer the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Event-to-frame preprocessing. Both values are in the units of the event
# timestamps (microseconds for NMNIST, i.e. 10 ms and 1 ms).
DENOISE_FILTER_TIME = 10000  # Denoise: max time to a neighbouring event before an event is dropped
FRAME_TIME_WINDOW = 1000     # ToFrame: length of the time window integrated into one frame
DATA_ROOT = './data'

sensor_size = tonic.datasets.NMNIST.sensor_size
frame_transform = transforms.Compose([
    transforms.Denoise(filter_time=DENOISE_FILTER_TIME),
    transforms.ToFrame(sensor_size=sensor_size, time_window=FRAME_TIME_WINDOW),
])

# Load (downloading on first use) the train and test splits, in that order.
trainset, testset = (
    tonic.datasets.NMNIST(save_to=DATA_ROOT, transform=frame_transform, train=is_train)
    for is_train in (True, False)
)





Shapes in __init__:
W[0] shape: torch.Size([100, 1156])
W[1] shape: torch.Size([50, 100])
z[0] shape: torch.Size([300, 128, 100])
z[1] shape: torch.Size([300, 128, 50])
a[0] shape: torch.Size([300, 128, 100])
a[1] shape: torch.Size([300, 128, 50])
Sample data shape: torch.Size([128, 309, 2, 34, 34])
Sample target shape: torch.Size([128])


In [None]:
def prepare_nmnist_data(inputs, labels, device, n_timesteps=100, n_classes=10):
    """
    Convert a padded NMNIST frame batch into the layout used by the ADMM model.

    Parameters:
    - inputs (torch.Tensor): frames shaped [batch, time, polarity, height, width]
      (e.g. torch.Size([128, 309, 2, 34, 34]) from ToFrame + PadTensors). A tensor
      without the polarity axis ([batch, time, height, width]) is accepted too.
    - labels (torch.Tensor): integer class labels shaped [batch].
    - device (torch.device): device both outputs are placed on.
    - n_timesteps (int): number of time bins kept; longer inputs are truncated,
      shorter ones are zero-padded at the end.
    - n_classes (int): size of the one-hot encoding (default 10).

    Returns:
    - inputs (torch.Tensor): float tensor [n_timesteps, height * width, batch].
    - labels_onehot (torch.Tensor): float tensor [n_classes, batch].
    """
    batch_size, n_frames = inputs.shape[0], inputs.shape[1]
    n_pixels = inputs.shape[-2] * inputs.shape[-1]

    inputs = inputs.to(device).float()
    labels = labels.to(device).long()  # scatter_ needs int64 indices

    # Merge the polarity channels per pixel. Reshaping [B, T, P, H, W] straight to
    # [B, -1, H*W] would instead interleave the P polarity maps along the time axis.
    inputs = inputs.reshape(batch_size, n_frames, -1, n_pixels).sum(dim=2)

    # [batch, time, features] -> [time, features, batch]
    inputs = inputs.permute(1, 2, 0).contiguous()

    # Truncate or zero-pad the time axis to exactly n_timesteps.
    if n_frames > n_timesteps:
        inputs = inputs[:n_timesteps]
    elif n_frames < n_timesteps:
        padding = torch.zeros((n_timesteps - n_frames, n_pixels, batch_size),
                              dtype=inputs.dtype, device=inputs.device)
        inputs = torch.cat((inputs, padding), dim=0)

    # One-hot encode labels as [n_classes, batch].
    labels_onehot = torch.zeros((n_classes, batch_size), device=device)
    labels_onehot.scatter_(0, labels.unsqueeze(0), 1)

    return inputs, labels_onehot

def train(model, trainloader, num_epochs, num_warming_epochs=5):
    """
    Train ``model`` with a warming phase followed by the main ADMM phase.

    Each batch is converted with ``prepare_nmnist_data`` and passed to
    ``model.warming`` (warming phase) or ``model.fit`` (main phase); both must
    return ``(loss, predictions)``.

    Parameters:
    - model: object exposing ``warming(inputs, labels)`` and ``fit(inputs, labels)``.
    - trainloader: iterable of ``(inputs, labels)`` batches that supports ``len()``.
    - num_epochs (int): number of main-phase epochs.
    - num_warming_epochs (int): number of warming epochs (default 5).

    Returns:
    - warming_losses (list[float]): mean loss of each warming epoch.
    - training_metrics (list[dict]): ``{'epoch', 'loss', 'accuracy'}`` per main epoch.

    Raises:
    - ValueError: if ``trainloader`` yields no batches, or a batch's predictions
      match neither ``[batch, classes]`` nor ``[classes, batch]``.
    """

    def batch_accuracy(predictions, labels_onehot):
        # Fraction of samples whose arg-max prediction matches the one-hot label.
        # labels_onehot is [num_classes, batch]; predictions are expected as
        # [batch, num_classes], but [num_classes, batch] is accepted as well.
        true_classes = torch.argmax(labels_onehot, dim=0)
        if predictions.shape == labels_onehot.t().shape:
            pred_classes = torch.argmax(predictions, dim=1)
        elif predictions.shape == labels_onehot.shape:
            pred_classes = torch.argmax(predictions, dim=0)
        else:
            raise ValueError(
                f"predictions shape {tuple(predictions.shape)} does not match "
                f"labels shape {tuple(labels_onehot.shape)} in either orientation"
            )
        return (pred_classes == true_classes).float().mean().item()

    def run_epoch(step, log_batches):
        # One pass over trainloader with `step`; returns (mean loss, mean accuracy).
        losses, accuracies = [], []
        for batch_idx, (inputs, labels) in enumerate(trainloader):
            if log_batches:
                print(f'  Batch [{batch_idx+1}/{len(trainloader)}]')
            inputs, labels = prepare_nmnist_data(inputs, labels, device)
            loss, predictions = step(inputs, labels)
            # Keep plain floats so the per-epoch lists do not hold device tensors.
            losses.append(float(loss))
            accuracies.append(batch_accuracy(predictions, labels))
        if not losses:
            raise ValueError("trainloader yielded no batches")
        return sum(losses) / len(losses), sum(accuracies) / len(accuracies)

    # Warming phase
    print("\nWarming Phase:")
    warming_losses = []
    for epoch in range(num_warming_epochs):
        print(f'Warming Epoch [{epoch+1}/{num_warming_epochs}]')
        avg_loss, avg_acc = run_epoch(model.warming, log_batches=False)
        warming_losses.append(avg_loss)
        print(f'  Epoch Summary - Loss: {avg_loss:.6f}, Accuracy: {avg_acc:.4f}')

    # Main training
    print("\nMain Training Phase:")
    training_metrics = []
    for epoch in range(num_epochs):
        print(f'Epoch [{epoch+1}/{num_epochs}]')
        avg_loss, avg_acc = run_epoch(model.fit, log_batches=True)
        training_metrics.append({
            'epoch': epoch + 1,
            'loss': avg_loss,
            'accuracy': avg_acc
        })
        print(f'  Epoch Summary - Loss: {avg_loss:.6f}, Accuracy: {avg_acc:.4f}')

    return warming_losses, training_metrics

def evaluate(model, testloader):
    """
    Evaluate ``model`` on every batch of ``testloader`` without gradient tracking.

    Each raw batch is converted with ``prepare_nmnist_data`` and passed to
    ``model.evaluate(inputs, labels)``, which must return ``(loss, predictions)``.

    Returns:
    - avg_loss (float): mean of the per-batch losses.
    - avg_acc (float): mean of the per-batch accuracies.
    - batch_metrics (list[dict]): ``{'batch', 'loss', 'accuracy'}`` per batch.

    Raises:
    - ValueError: if ``testloader`` yields no batches, or a batch's predictions
      match neither ``[batch, classes]`` nor ``[classes, batch]``.
    """
    print("\nEvaluation Phase:")
    batch_metrics = []

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(testloader):
            print(f'  Batch [{batch_idx+1}/{len(testloader)}]')

            # Prepare data
            inputs, labels = prepare_nmnist_data(inputs, labels, device)

            # Forward pass
            loss, predictions = model.evaluate(inputs, labels)
            loss = float(loss)

            # labels is one-hot [num_classes, batch]. Predictions are expected as
            # [batch, num_classes]; [num_classes, batch] is accepted as well.
            true_classes = torch.argmax(labels, dim=0)
            if predictions.shape == labels.t().shape:
                pred_classes = torch.argmax(predictions, dim=1)
            elif predictions.shape == labels.shape:
                pred_classes = torch.argmax(predictions, dim=0)
            else:
                raise ValueError(
                    f"predictions shape {tuple(predictions.shape)} does not match "
                    f"labels shape {tuple(labels.shape)} in either orientation"
                )
            accuracy = (pred_classes == true_classes).float().mean().item()

            batch_metrics.append({
                'batch': batch_idx + 1,
                'loss': loss,
                'accuracy': accuracy
            })

            print(f'    Loss: {loss:.6f}, Accuracy: {accuracy:.4f}')

    if not batch_metrics:
        raise ValueError("testloader yielded no batches")

    avg_loss = sum(m['loss'] for m in batch_metrics) / len(batch_metrics)
    avg_acc = sum(m['accuracy'] for m in batch_metrics) / len(batch_metrics)

    print(f'\nFinal Evaluation Results:')
    print(f'  Average Loss: {avg_loss:.6f}')
    print(f'  Average Accuracy: {avg_acc:.4f}')

    return avg_loss, avg_acc, batch_metrics

# Initialize model and training
SEED = 42              # seeds the random choice of the train/test subsets
SUBSET_FRACTION = 0.5  # fraction of each NMNIST split that is used

# train()/evaluate() call prepare_nmnist_data with its default n_timesteps (100),
# so keep this value in sync with that default.
n_timesteps = 100
input_dim = sensor_size[0] * sensor_size[1]
hidden_dims = [128, 10]
n_outputs = 10
batch_size = 256

# Use SUBSET_FRACTION of each split, rounded down to a multiple of batch_size so
# every batch has exactly batch_size samples (the model is built with
# n_samples=batch_size).
train_size = int(SUBSET_FRACTION * len(trainset)) // batch_size * batch_size
test_size = int(SUBSET_FRACTION * len(testset)) // batch_size * batch_size

from torch.utils.data import Subset
subset_generator = torch.Generator().manual_seed(SEED)
train_subset = Subset(trainset, torch.randperm(len(trainset), generator=subset_generator)[:train_size])
test_subset = Subset(testset, torch.randperm(len(testset), generator=subset_generator)[:test_size])

# Create DataLoaders with the subsets
trainloader = DataLoader(train_subset,
                         batch_size=batch_size,
                         collate_fn=tonic.collation.PadTensors(),
                         shuffle=True)
testloader = DataLoader(test_subset,
                        batch_size=batch_size,
                        collate_fn=tonic.collation.PadTensors())

print(f"Original training set size: {len(trainset)}")
print(f"training set size: {len(train_subset)}")
print(f"Original test set size: {len(testset)}")
print(f"test set size: {len(test_subset)}")

# Hyperparameters
rho = 0.05
delta = torch.tensor(0.7, device=device)
theta = torch.tensor(0.15, device=device)

# Create model
print("\nInitializing model...")
model = ADMM_SNN(
    n_samples=batch_size,
    n_timesteps=n_timesteps,
    input_dim=input_dim,
    hidden_dims=hidden_dims,
    n_outputs=n_outputs,
    rho=rho,
    delta=delta,
    theta=theta
)
print(model)

# Train model
print("\nStarting training process...")
num_epochs = 20
warming_losses, training_metrics = train(model, trainloader, num_epochs)

# Evaluate model
print("\nEvaluating model...")
test_loss, test_acc, test_metrics = evaluate(model, testloader)

# Print final results
print("\nTraining Summary:")
print(f"  Warming phase final loss: {warming_losses[-1]:.6f}")
print(f"  Training final loss: {training_metrics[-1]['loss']:.6f}")
print(f"  Training final accuracy: {training_metrics[-1]['accuracy']:.4f}")
print("\nTest Results:")
print(f"  Test Loss: {test_loss:.6f}")
print(f"  Test Accuracy: {test_acc:.4f}")

Shapes in __init__:
W[0] shape: torch.Size([128, 1156])
W[1] shape: torch.Size([64, 128])
z[0] shape: torch.Size([300, 128, 128])
z[1] shape: torch.Size([300, 128, 64])
a[0] shape: torch.Size([300, 128, 128])
a[1] shape: torch.Size([300, 128, 64])
Sample data shape: torch.Size([128, 310, 2, 34, 34])
Sample target shape: torch.Size([128])
------ Warming Epoch: 0 ------


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x64 and 128x128)