<a href="https://colab.research.google.com/github/ergysmedaunipd/thesis/blob/main/THESIS_SNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import time

from typing import List


class ADMM_SNN:
    """ADMM-trained spiking neural network (Algorithm 2) -- work-in-progress skeleton.

    All indices are 0-based and consistent between storage and ``fit``:

    * ``self.L`` is the number of hidden (spiking) layers, ``len(hidden_dims)``.
    * ``self.W[l]`` / ``self.z[l]`` for ``l < self.L`` belong to hidden layer ``l``;
      ``self.W[self.L]`` / ``self.z[self.L]`` belong to the output layer.
    * ``self.a[l]`` holds the spikes of hidden layer ``l`` only (no spike
      variable is kept for the output layer); ``self.a0`` is the network input.
    * Every per-timestep tensor has a leading time axis of length ``self.T``,
      so the last time step is index ``self.T - 1``.

    Update rules that are not written yet raise ``NotImplementedError`` instead
    of silently returning ``None``.
    """

    def __init__(self, n_samples: int, n_timesteps: int, input_dim: int, hidden_dims: List[int], n_outputs: int, rho: float, deltas: torch.Tensor, thetas: torch.Tensor, device: str = "cpu"):
        """Allocate every ADMM optimisation variable, zero-initialised.

        Args:
            n_samples: Number of samples the per-sample variables are sized for.
            n_timesteps: Number of simulation time steps T.
            input_dim: Width of the input a_0 at each time step.
            hidden_dims: Widths of the hidden spiking layers (may be empty).
            n_outputs: Width of the output layer.
            rho: ADMM penalty parameter; all alpha_{l,t} = beta_{l,t} = rho / 2.
            deltas: Decay factors (shared, or different for each neuron).
            thetas: Firing thresholds (shared, or different for each neuron).
            device: Torch device (name or ``torch.device``) for all variables.
        """
        self.device = device

        self.rho = rho
        self.deltas = deltas
        self.thetas = thetas

        self.L = len(hidden_dims)
        self.T = n_timesteps

        # a_0_t: input to the first layer, one slice per time step.
        self.a0 = torch.zeros((n_timesteps, n_samples, input_dim), device=self.device)

        # Widths from input to output: [input, hidden_0, ..., hidden_{L-1}, output].
        dims = [input_dim, *hidden_dims, n_outputs]

        # W[l] maps a (dims[l])-wide input to a (dims[l+1])-wide output. This gives
        # L hidden-layer weights followed by the output-layer weight W[L], which
        # fit() needs but which was previously never allocated.
        self.W = [
            torch.zeros((d_out, d_in), device=self.device)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ]

        # z[l]: membrane potentials of every layer, output layer included (index L).
        self.z = [
            torch.zeros((n_timesteps, n_samples, d), device=self.device)
            for d in dims[1:]
        ]

        # a[l]: spikes of the hidden layers only.
        self.a = [
            torch.zeros((n_timesteps, n_samples, d), device=self.device)
            for d in hidden_dims
        ]

        # Lagrange multiplier, shaped like the output layer at a single time step.
        # TODO(review): zero initialisation is a placeholder -- confirm the intended start value.
        self.lambda_lagrange = torch.zeros((n_samples, n_outputs), device=self.device)

    def _heaviside(self, x):
        """Return the spike indicator ``1[x > thetas]`` with the same dtype as ``x``.

        ``self.thetas`` is broadcast against ``x``. The strict ``>`` means a
        potential exactly at threshold does not fire. This mirrors snntorch's
        ``mem - threshold > 0`` rule, which is the neuron suggested for
        ``feed_forward``.
        NOTE(review): confirm this matches the paper's H(0) convention.
        """
        return (x > self.thetas).to(x.dtype)

    # ============ W_{l} update functions ============
    def _weight_update(self, layer_output, activation_input):
        """Weight update of the hidden layers: line 2 of Algorithm 2 (Equation (4)).

        Uses alpha_{l,t} = rho / 2 (the same holds for every update below).
        """
        raise NotImplementedError("_weight_update: line 2 of Algorithm 2 (Equation (4))")

    def _weight_update_L(self, layer_output, activation_input):
        """Output-layer weight update via the auxiliary variable x_L: line 10 of Algorithm 2 (Equation (6))."""
        raise NotImplementedError("_weight_update_L: line 10 of Algorithm 2 (Equation (6))")

    # ============ z_{l,t} update functions ============
    def _z_update(self, _arguments_here):
        """First argument in line 5 of Algorithm 2 (Equation (14))."""
        raise NotImplementedError("_z_update: line 5 of Algorithm 2 (Equation (14))")

    def _z_update_T(self, _arguments_here):
        """First argument in line 8 of Algorithm 2 (Equation (14)*), last time step."""
        raise NotImplementedError("_z_update_T: line 8 of Algorithm 2 (Equation (14)*)")

    def _z_update_L(self, _arguments_here):
        """Output-layer membrane update: line 12 of Algorithm 2 (Equation (16))."""
        raise NotImplementedError("_z_update_L: line 12 of Algorithm 2 (Equation (16))")

    def _z_update_L_T(self, _arguments_here):
        """Output-layer membrane update at the last time step: line 14 (Equation (16)*)."""
        raise NotImplementedError("_z_update_L_T: line 14 of Algorithm 2 (Equation (16)*)")

    def check_entries(self, z, cost_function):
        """Algorithm 1, used in lines 5 and 8 of Algorithm 2."""
        raise NotImplementedError("check_entries: Algorithm 1")

    # ============ a_{l,t} update functions ============
    def _activation_update(self, _arguments_here):
        """Activation update for all hidden layers except the last, all but the last time step (line 4)."""
        raise NotImplementedError("_activation_update: line 4 of Algorithm 2")

    def _activation_update_T(self, _arguments_here):
        """Activation update for all hidden layers except the last, at the last time step (line 7)."""
        raise NotImplementedError("_activation_update_T: line 7 of Algorithm 2")

    def _activation_update_Lminus1(self, _arguments_here):
        """Activation update for the last hidden layer (index ``self.L - 1``), all but the last time step.

        Line 4 again; check the indicator functions.
        """
        raise NotImplementedError("_activation_update_Lminus1: line 4 of Algorithm 2")

    def _activation_update_Lminus1_T(self, _arguments_here):
        """Activation update for the last hidden layer (index ``self.L - 1``) at the last time step.

        Line 7 again; check the indicator functions.
        """
        raise NotImplementedError("_activation_update_Lminus1_T: line 7 of Algorithm 2")

    # ============ lagrange multiplier update ============

    def _lambda_update(self, arguments_here):
        """Update of the Lagrange multiplier lambda (line 15 of Algorithm 2)."""
        raise NotImplementedError("_lambda_update: line 15 of Algorithm 2")

    def feed_forward(self, inputs):
        """Forward pass of the SNN; should return the output-layer membrane potentials at the last time step.

        Suggested implementation with snntorch's leaky integrate-and-fire
        neuron (``snn.Leaky``):

        * ``beta`` = deltas (the decay factors)
        * ``threshold`` = thetas (the thresholds)
        * ``reset_mechanism="subtract"`` for the hidden layers and ``"none"``
          for the output layer

        See https://snntorch.readthedocs.io/en/latest/tutorials/index.html
        """
        raise NotImplementedError("feed_forward")

    def fit(self, _arguments_here):
        """Run one ADMM sweep of Algorithm 2 over a batch (not implemented yet).

        Intended update order, using the 0-based indices allocated in
        ``__init__`` (hidden layers ``0 .. L-1``, output layer ``L``, time steps
        ``0 .. T-1``). Each step writes the helper's result back into the
        named variable; nothing may be overwritten with ``None``::

            for l in range(self.L):                    # every hidden layer
                W[l]        <- _weight_update
                for t in range(self.T - 1):
                    a[l][t] <- _activation_update            if l < self.L - 1
                               _activation_update_Lminus1    otherwise
                    z[l][t] <- _z_update + check_entries
                a[l][T-1]   <- _activation_update_T          if l < self.L - 1
                               _activation_update_Lminus1_T  otherwise
                z[l][T-1]   <- _z_update_T + check_entries
            W[L]            <- _weight_update_L              # output layer
            for t in range(self.T - 1):
                z[L][t]     <- _z_update_L
            z[L][T-1]       <- _z_update_L_T
            lambda_lagrange <- _lambda_update

        Raises:
            NotImplementedError: always, until the update helpers are written.
        """
        raise NotImplementedError("fit: the ADMM updates of Algorithm 2 are not implemented yet")

    def evaluate(self, _arguments_here):
        """Standard evaluation phase."""
        raise NotImplementedError("evaluate")

    def warming(self, _arguments_here):
        """Warm-up phase, as in the previous implementation."""
        raise NotImplementedError("warming")


if __name__ == "__main__":
    # tonic provides the N-MNIST dataset; background in the snntorch tutorial:
    # https://snntorch.readthedocs.io/en/latest/tutorials/tutorial_7.html
    import tonic

    # Training split of N-MNIST, with tonic's files kept under ./data.
    dataset = tonic.datasets.NMNIST(train=True, save_to='./data')

    # TODO: split the dataset, then train and evaluate ADMM_SNN on it.
