### SNN WP: training a spiking neural network with weight perturbation

In [1]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn 
import weight_perturbation as wp

In [2]:
# define a network


class SNN_WP(nn.Module):
    """Two-layer leaky integrate-and-fire network for weight-perturbation training.

    Architecture: ``fc1 -> lif1 -> fc2 -> lif2``. No autograd step is taken
    here. Instead, ``forward_pass`` evaluates ``loss`` either with the current
    weights or with an additive ``noise`` perturbation applied. It is the loss
    callable handed to ``wp.compute_snn_gradient`` in the training loop.

    Args:
        beta: ``beta`` (membrane decay) passed to both ``Leaky`` layers.
        num_inputs: number of input features per time step.
        num_hidden: number of hidden LIF neurons.
        num_outputs: number of output LIF neurons.
        loss: callable invoked as ``loss(y_pred, y)``.
    """

    def __init__(self, beta, num_inputs, num_hidden, num_outputs, loss):
        super().__init__()
        # Import locally rather than using the global ``snn`` alias: this
        # notebook later rebinds ``snn`` to an SNN_WP instance
        # (``snn = SNN_WP(...)``), after which ``snn.Leaky`` would raise
        # AttributeError when constructing another network.
        import snntorch

        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snntorch.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snntorch.Leaky(beta=beta)
        self.loss = loss

    def clean_forward(self, x):
        """Simulate the network over every time step of ``x`` with the current weights.

        Args:
            x: input tensor whose first dimension is time; ``x[step]`` is fed
                to ``fc1`` at each step. The training loop builds it as
                ``(num_steps, 1, num_inputs)``.

        Returns:
            Output-layer spikes from the final time step only.

        Raises:
            ValueError: if ``x`` has no time steps (there would be no output
                spikes to return).
        """
        if x.shape[0] == 0:
            raise ValueError("clean_forward expects at least one time step")

        # Re-initialise membrane potentials so each call starts fresh.
        self.mem1 = self.lif1.init_leaky()
        self.mem2 = self.lif2.init_leaky()
        spk2 = None
        for step in range(x.shape[0]):
            cur1 = self.fc1(x[step])  # post-synaptic current <-- spk_in x weight
            spk1, self.mem1 = self.lif1(cur1, self.mem1)  # mem[t+1] <-- current + decayed membrane
            cur2 = self.fc2(spk1)
            spk2, self.mem2 = self.lif2(cur2, self.mem2)
        return spk2

    def noisy_forward(self, x, noise):
        """Run ``forward`` with ``noise`` temporarily added to the weights.

        Args:
            x: input tensor, as for ``clean_forward``.
            noise: perturbation combined with the state dict via
                ``wp.dictionary_add``. Presumably it holds one tensor per
                state-dict key (TODO confirm against ``weight_perturbation``).

        Returns:
            Output spikes computed with the perturbed weights. The original
            weights are restored afterwards, even if the forward pass raises.
        """
        # ``state_dict()`` returns tensors sharing storage with the live
        # parameters, and ``load_state_dict`` copies in place. An un-cloned
        # backup would therefore be overwritten by the perturbed weights,
        # making the restore a no-op and leaving the noise in the model.
        backup = {name: t.clone() for name, t in self.state_dict().items()}
        perturbed_params = wp.dictionary_add(self.state_dict(), noise)
        self.load_state_dict(perturbed_params)
        try:
            return self.forward(x)
        finally:
            self.load_state_dict(backup)

    def forward(self, x):
        """Standard ``nn.Module`` entry point; same as ``clean_forward``."""
        return self.clean_forward(x)

    def forward_pass(self, x, y, noise=None):
        """Return ``self.loss`` of the network output against targets ``y``.

        Args:
            x: input tensor, as for ``clean_forward``.
            y: targets, passed as the second argument to ``self.loss``.
            noise: optional weight perturbation. ``None`` evaluates the
                current weights; otherwise ``noisy_forward`` is used.
        """
        if noise is None:
            y_pred = self.clean_forward(x)
        else:
            y_pred = self.noisy_forward(x, noise)
        return self.loss(y_pred, y)

In [3]:
# Network hyper-parameters
num_steps = 20      # length of the time dimension of each input
num_inputs = 5
num_hidden = 20
num_outputs = 2
beta = 0.99         # ``beta`` passed to both Leaky layers
loss = nn.CrossEntropyLoss()

# NOTE(review): this rebinds ``snn``, which until now referred to the snntorch
# module (``import snntorch as snn``). Any later code that expects ``snn`` to be
# the module (e.g. ``snn.Leaky``) will get this network instead. Consider
# renaming the instance (e.g. ``net``) together with its uses in the training
# loop.
snn = SNN_WP(
    beta=beta,
    num_inputs=num_inputs,
    num_hidden=num_hidden,
    num_outputs=num_outputs,
    loss=loss,
)



In [4]:
# Training loop: estimate the gradient by weight perturbation and apply one
# update per epoch.

epochs = 5
method = "cfd"  # estimation scheme for wp.compute_snn_gradient (presumably central finite differences)
sigma = 1e-6    # standard deviation of the Gaussian weight noise
lr = 1e-8
# Sampler for the weight noise.
sampler = torch.distributions.Normal(0, sigma)

for _ in range(epochs):
    # Random rate-coded input of shape (num_steps, 1, num_inputs) and a fixed
    # target class index of 1.
    spk_in = spikegen.rate_conv(torch.rand((num_steps, num_inputs))).unsqueeze(1)
    y = torch.ones(1).long()

    # Snapshot the weights. ``state_dict()`` tensors share storage with the
    # live parameters, so without ``clone()`` any perturbation left in the
    # model during gradient estimation would leak into ``params`` and then
    # into the weight update below.
    params = {name: t.clone() for name, t in snn.state_dict().items()}

    # Forward passes to estimate the gradient.
    wp_grad = wp.compute_snn_gradient(
        snn.forward_pass, spk_in, y, params, sampler, method
    )

    new_weights = wp.update_weights(wp_grad, params, sigma, lr)
    snn.load_state_dict(new_weights)