### SNN trained with weight perturbation (WP) vs. backpropagation (BP)

In [2]:
import snntorch as snn
import snntorch.functional as SF

from snntorch import spikeplot as splt

from snntorch import spikegen

import torch

import torch.nn as nn 

import weight_perturbation as wp

In [3]:
# define a network


class SNN_WP(nn.Module):
    """
    SNN for weight perturbation (WP). Two fully connected layers, each
    followed by a leaky integrate-and-fire (LIF) layer (``snn.Leaky``).
    """

    def __init__(self, beta, num_inputs, num_hidden, num_outputs, loss):
        """
        Initialize the network

        Parameters
        ----------
        beta : float
            Membrane decay rate of both LIF layers (passed to ``snn.Leaky``)
        num_inputs : int
            The size of the input layer
        num_hidden : int
            The size of the hidden layer
        num_outputs : int
            The size of the output layer
        loss : callable
            Loss used during training; called as ``loss(y_pred, y)``
        """
        super().__init__()

        # Surrogate gradient for the non-differentiable spike; needed for BP.
        spike_grad = snn.surrogate.fast_sigmoid(slope=25)
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.loss = loss

    def clean_forward(self, x):
        """Run an unperturbed forward pass (alias of ``forward`` for readability)."""
        return self.forward(x)

    def noisy_forward(self, x, noise):
        """Run a forward pass with the weights temporarily perturbed by ``noise``.

        The parameters are restored to their unperturbed values afterwards,
        even if the forward pass raises.

        Parameters
        ----------
        x : Tensor
            Input data as spike trains (same layout as for ``forward``)
        noise : dict
            Noise for each network parameter, combined with the current
            ``state_dict`` via ``wp.dictionary_add``

        Returns
        -------
        spk : Tensor
            Output of ``forward`` under the perturbed weights
        """
        # state_dict() holds references to the live parameter tensors, and
        # load_state_dict() copies into those same tensors in place. An
        # un-cloned snapshot would therefore be overwritten by the perturbed
        # values and the "restore" below would be a no-op, so clone it.
        original_state = {
            name: tensor.clone() for name, tensor in self.state_dict().items()
        }
        perturbed_params = wp.dictionary_add(self.state_dict(), noise)
        self.load_state_dict(perturbed_params)
        try:
            spk = self.forward(x)
        finally:
            # reset the parameters back to the unperturbed parameters
            self.load_state_dict(original_state)

        return spk

    def forward(self, x):
        """Run an unperturbed forward pass over every time step of ``x``.

        Membrane potentials are re-initialised at the start of each call.

        Parameters
        ----------
        x : Tensor
            Input spike trains, time-first: ``x[t]`` is fed to ``fc1`` at
            step ``t``.

        Returns
        -------
        spk : Tensor
            Output spikes of the LAST time step only.
            NOTE(review): snntorch's ``ce_rate_loss`` / ``accuracy_rate``
            expect the full output spike record (time, batch, outputs);
            confirm whether the per-step outputs should be stacked instead.

        Raises
        ------
        ValueError
            If ``x`` contains no time steps.
        """
        if x.shape[0] == 0:
            raise ValueError("x must contain at least one time step")
        self.mem1 = self.lif1.init_leaky()
        self.mem2 = self.lif2.init_leaky()
        for step in range(x.shape[0]):
            cur1 = self.fc1(x[step])  # post-synaptic current <-- spk_in x weight
            # mem[t+1] <-- post-syn current + decayed membrane
            spk1, self.mem1 = self.lif1(cur1, self.mem1)
            cur2 = self.fc2(spk1)
            spk2, self.mem2 = self.lif2(cur2, self.mem2)
        return spk2

    def forward_pass(self, x, y, noise=None):
        """Run a clean or weight-perturbed forward pass and return its loss.

        Parameters
        ----------
        x : Tensor
            Input data as spike trains
        y : Tensor
            True labels
        noise : dict, optional
            Noise for each network parameter. When None the pass is clean;
            otherwise the weights are perturbed for this pass only.

        Returns
        -------
        loss
            Whatever ``self.loss(y_pred, y)`` returns (the training loops call
            ``.item()`` on it, i.e. a scalar tensor is expected).
        """
        if noise is None:
            y_pred = self.clean_forward(x)
        else:
            y_pred = self.noisy_forward(x, noise)
        return self.loss(y_pred, y)
        


In [None]:
class ClassificationDataset(torch.utils.data.Dataset):
    """Synthetic binary classification dataset with a linear decision rule.

    Each sample has ``dim_in`` firing rates drawn uniformly from [0, 1). Its
    class is 1 when a fixed random linear function of the centred rates is
    positive, else 0. Inputs are rate-coded as Bernoulli spike trains over
    ``timesteps`` steps.
    """

    def __init__(self, num_samples, timesteps, dim_in, seed=None):
        """Generate all samples up front.

        Parameters
        ----------
        num_samples : int
            Number of samples in the dataset
        timesteps : int
            Length of each input spike train
        dim_in : int
            Number of input channels per time step
        seed : int, optional
            Seed for a private ``torch.Generator`` so the data is
            reproducible; when None, torch's global RNG is used.
        """
        self.num_samples = num_samples
        self.timesteps = timesteps
        self.dim_in = dim_in

        gen = None if seed is None else torch.Generator().manual_seed(seed)
        rates = torch.rand(num_samples, dim_in, generator=gen)
        teacher = torch.randn(dim_in, generator=gen)
        # Centring the rates makes the linear score symmetric around zero,
        # so the two classes are roughly balanced.
        self.labels = ((rates - 0.5) @ teacher > 0).long()
        # Shape (timesteps, num_samples, dim_in): an independent Bernoulli
        # draw per time step, with each sample's rates as probabilities.
        self.features = torch.bernoulli(
            rates.unsqueeze(0).repeat(timesteps, 1, 1), generator=gen
        )

    def __len__(self):
        """Number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(spikes, label)`` for sample ``idx``.

        ``spikes`` has shape (timesteps, dim_in) and ``label`` is a 0-dim long
        tensor (class index). The default DataLoader collation therefore
        yields (batch, timesteps, dim_in) inputs and (batch,) labels.
        NOTE(review): the network iterates over dim 0 as time, so batches
        presumably need ``transpose(0, 1)`` before being fed to it — confirm.
        """
        return self.features[:, idx, :], self.labels[idx]

In [None]:
# NOTE(review): dim_in=8 here, but the network is built with num_inputs=5;
# fc1 cannot accept 8-feature inputs, so the two values need to agree.
dataset = ClassificationDataset(10000, 100, 8)

train_set, val_set = torch.utils.data.random_split(dataset, [9000, 1000])

train_loader = torch.utils.data.DataLoader(
    dataset=train_set, batch_size=100, drop_last=True
)
# Held-out split. The training loops evaluate it through
# batch_accuracy(test_loader, ...), so it must be bound to `test_loader`
# (rebinding `val_set` would leave `test_loader` undefined).
test_loader = torch.utils.data.DataLoader(
    dataset=val_set, batch_size=100, drop_last=True
)

In [11]:
# ---- network architecture ----
num_inputs, num_hidden, num_outputs = 5, 20, 2
num_steps = 20
beta = 0.99  # LIF membrane decay for both layers
loss = SF.ce_rate_loss()
SNN = SNN_WP(beta, num_inputs, num_hidden, num_outputs, loss)

# ---- optimisation settings ----
epochs = 5
method = "cfd"  # gradient-estimation method handed to wp.compute_snn_gradient
sigma = 1  # std of the weight-perturbation noise, also passed to wp.update_weights
lr = 1e-6
device = "cpu"

# ---- histories filled by the WP and BP training loops ----
loss_hist_wp, test_acc_hist_wp = [], []
loss_hist_bp, test_acc_hist_bp = [], []

In [None]:
def batch_accuracy(test_loader, net):
    """Return the rate-coded classification accuracy of ``net`` over a loader.

    Each batch is moved to the module-level ``device``, run through ``net``
    without gradient tracking, and scored with ``SF.accuracy_rate``. The
    per-batch accuracies are averaged, weighted by ``spk_rec.size(1)``. That
    is the batch dimension if ``spk_rec`` is (time, batch, outputs) —
    NOTE(review): confirm this matches what ``net`` returns.

    ``net`` is switched to eval mode for the evaluation. Its previous
    train/eval mode is restored afterwards.

    Parameters
    ----------
    test_loader : iterable
        Yields ``(data, targets)`` batches.
    net : torch.nn.Module
        Network to evaluate.

    Returns
    -------
    The weighted mean accuracy, in [0, 1].

    Raises
    ------
    ValueError
        If the loader yields no samples.
    """
    was_training = net.training
    net.eval()
    total = 0
    acc = 0
    try:
        with torch.no_grad():
            for data, targets in test_loader:
                data = data.to(device)
                targets = targets.to(device)
                spk_rec = net(data)

                batch_size = spk_rec.size(1)
                acc += SF.accuracy_rate(spk_rec, targets) * batch_size
                total += batch_size
    finally:
        net.train(was_training)

    if total == 0:
        raise ValueError("test_loader produced no samples to evaluate")
    return acc / total

In [None]:
# Training loop for weight perturbation (WP)

# Gaussian noise source handed to wp.compute_snn_gradient for the perturbations.
sampler = torch.distributions.Normal(0, sigma)

# WP only needs loss values, never autograd, so the whole loop runs without
# gradient tracking (this also covers the test-set evaluation below).
with torch.no_grad():
    for e in range(epochs):
        loss_epoch = []
        for data, targets in train_loader:
            data = data.to(device)
            targets = targets.to(device)

            params = SNN.state_dict()

            # Loss of the unperturbed network. It is named loss_val so the
            # module-level `loss` function object is not rebound to a tensor.
            loss_val = SNN.forward_pass(data, targets)

            # do forward passes and compute gradient
            wp_grad = wp.compute_snn_gradient(
                SNN.forward_pass, data, targets, params, sampler, method
            )

            new_weights = wp.update_weights(wp_grad, params, sigma, lr)
            # NOTE(review): this presumably scales every parameter by 2 on
            # every batch, which would make the weights grow without bound
            # unless wp.update_weights compensates — looks like a debugging
            # leftover; confirm before relying on the results.
            new_weights = wp.dictionary_mult(new_weights, 2)

            SNN.load_state_dict(new_weights)  # update the weights

            loss_epoch.append(loss_val.item())

        # One entry per epoch (mean batch loss), matching the BP loop.
        loss_hist_wp.append(torch.mean(torch.tensor(loss_epoch)))

        # Test set forward pass
        test_acc = batch_accuracy(test_loader, SNN)
        print(f"Epoch {e}, Test Acc: {test_acc * 100:.2f}%\n")
        test_acc_hist_wp.append(test_acc.item())

In [None]:
# Training loop for backpropagation (BP)

# NOTE(review): SNN still holds the weights produced by the WP loop, so BP
# continues from those rather than from a fresh initialisation — confirm this
# is the intended WP-vs-BP comparison.
optimizer = torch.optim.SGD(SNN.parameters(), lr=lr)


for e in range(epochs):

    loss_epoch = []

    for data, targets in train_loader:
        data = data.to(device)
        targets = targets.to(device)

        SNN.train()

        y_pred = SNN(data)

        # loss of this batch
        loss_val = SNN.loss(y_pred, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_epoch.append(loss_val.item())

    # One entry per epoch (mean batch loss).
    loss_hist_bp.append(torch.mean(torch.tensor(loss_epoch)))

    with torch.no_grad():
        # Test set forward pass
        test_acc = batch_accuracy(test_loader, SNN)
        print(f"Epoch {e}, Test Acc: {test_acc * 100:.2f}%\n")
        test_acc_hist_bp.append(test_acc.item())