# Spiking Neural Network (SNN) with PyTorch: towards bridging the gap between deep learning and the human brain

So I simply thought "hey, what about coding a Spiking Neural Network?" Here it is. I couldn't get that out of my head, so I coded it out of curiosity.

Spiking Neural Networks (SNNs) are neural networks that are closer to what happens in the brain than what people usually code when doing Machine Learning and Deep Learning. In an SNN, the neurons accumulate the input activation until a threshold is reached; when it is reached, the neuron empties itself of its activation and fires. Once empty, it should indeed take some [refractory period](https://en.wikipedia.org/wiki/Refractory_period_(physiology)) until it fires again, as happens in the brain. So I roughly replicated this behavior here with PyTorch. I coded this without reading existing code, so as to come up with a solution by myself.

## How does it work?

The way I define a neuron's firing method is through the following steps, where the argument `x` is an input:

- Before anything, we need an empty inner state for each neuron, and for each sample in the batch. 
```python
    self.prev_inner = torch.zeros([batch_size, self.n_hidden]).to(self.device)
    self.prev_outer = torch.zeros([batch_size, self.n_hidden]).to(self.device)
```
- Then, a weight matrix multiplies the input `x`, which comes from the spiking neurons of the layer below:
```python
    input_excitation = self.fc(x)
```
- We then add the result to a decayed version of the information we already had at the previous time step (time tick). The `decay_multiplier` slowly fades the inner activation so that stimuli are not accumulated for too long, which lets the neurons rest. The `decay_multiplier` could have a value of 0.9, for example. Such decay is also called exponential decay and yields the effect of an [exponential moving average](https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average) over the most recent values seen, which also affects the gradients. In this sense, it's now really true that neurons that fire together wire together: when an input is received close to the moment of giving an output, the gradient will be strong and learning will be able to take place. In the opposite case, a stimulus that happened too long ago will suffer from vanishing gradients, since it has been exponentially decayed.
```python
    inner_excitation = input_excitation + self.prev_inner * self.decay_multiplier
```
- Now, we compute the activation of the neurons to find their output value. A threshold must be reached before a neuron activates. The ReLU function may not be the most appropriate here, but let's get a prototype working fast:
```python
    outer_excitation = F.relu(inner_excitation - self.threshold)
```
- If the neuron fires, its activation is subtracted from its inner state to bring the neuron back to a reset position, so that it doesn't get jammed firing constantly. Here, I even added an extra penalty named `penalty_threshold` (used instead of `threshold` as above) to lengthen the refractory time even more. I wasn't sure whether the negative phase of the refractory period was on the outputs of the neurons or inside the neurons, so here I exaggerated it inside by subtracting a bigger value than the initial threshold to punish the neurons, although it would probably also work to have this negative part appear in the output signal rather than in the inner signal. I went with what I judged safe without digging far into research. Here is how I subtract this only when the neuron fires, to give it a refractory period:
```python
    do_penalize_gate = (outer_excitation > 0).float()
    inner_excitation = inner_excitation - (self.penalty_threshold + outer_excitation) * do_penalize_gate
```
- Finally, I return the previous output, simulating a small firing delay. This is useless for now, but it may be interesting if the SNN I coded were ever to have recurrent connections, which would require time offsets in the connections from top layers near the outputs back into bottom layers near the input:
```python
    delayed_return_state = self.prev_inner
    delayed_return_output = self.prev_outer
    self.prev_inner = inner_excitation
    self.prev_outer = outer_excitation
    return delayed_return_state, delayed_return_output
```
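
To see the steps above in action on a single neuron, here is a minimal sketch. It is not part of the notebook: the helper name `simulate_neuron` and the constant input of 0.5 per tick are made up for illustration, and the one-tick output delay of the last step is left out. It reuses the default values `decay_multiplier=0.9`, `threshold=2.0` and `penalty_threshold=2.5` from the layer defined further down.

```python
import torch
import torch.nn.functional as F


def simulate_neuron(inputs, decay_multiplier=0.9, threshold=2.0, penalty_threshold=2.5):
    """Apply the accumulate / threshold / reset steps to one neuron, tick by tick."""
    inner = torch.zeros(())  # empty inner state
    outputs = []
    for input_excitation in inputs:
        # Accumulate the new input on top of the decayed previous state.
        inner = input_excitation + inner * decay_multiplier
        # The neuron outputs something only once the threshold is passed.
        outer = F.relu(inner - threshold)
        # Reset (with the extra penalty) only on the ticks where it fired.
        gate = (outer > 0).float()
        inner = inner - (penalty_threshold + outer) * gate
        outputs.append(outer.item())
    return outputs


# Hypothetical stimulus: a constant excitation of 0.5 for 12 ticks.
outputs = simulate_neuron(torch.full((12,), 0.5))
spike_ticks = [t for t, o in enumerate(outputs) if o > 0]
print(spike_ticks)  # prints [4, 10]
```

With these values the neuron needs five ticks to charge up to its first spike, then six ticks for the second one: the reset leaves the inner state at `threshold - penalty_threshold = -0.5`, so it starts again from below zero.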

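The claim about gradients fading with the decay can be checked with autograd. The sketch below is an illustration under simplifying assumptions, not code from the notebook: a single neuron, scalar inputs, and no firing in between (so only the decayed accumulation is applied). In that setting, the gradient of the last inner state with respect to an input received `k` ticks earlier should be `decay_multiplier ** k`.

```python
import torch

decay_multiplier = 0.9
n_ticks = 6

# One scalar input per tick, each tracked by autograd.
inputs = torch.ones(n_ticks, requires_grad=True)

inner = torch.zeros(())
for t in range(n_ticks):
    # Decayed accumulation only: no threshold and no reset in this sketch.
    inner = inputs[t] + inner * decay_multiplier

# Gradient of the final inner state with respect to every past input.
inner.backward()
print(inputs.grad)  # oldest input first: about 0.59, 0.66, 0.73, 0.81, 0.9, 1.0
```

The most recent input gets a gradient of 1.0 while the oldest one is already down to about 0.59 after five ticks, which is the "fire together, wire together" effect described above.
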
Amazingly, it worked on the first try once the dimension mismatch errors were fixed. The accuracy was about the same as that of a simple non-spiking feedforward neural network with the same number of neurons, and I didn't even tune the threshold. In the end, I realized that coding and training a Spiking Neural Network (SNN) with PyTorch is easy enough, as shown above: it can be coded in an evening.



Basically, the neurons' activation must decay through time and fire only when getting past a certain threshold. So I've gated the output of the 

## SNNs v.s. RNNs

The SNN is NOT an RNN, even though it evolves through time too. For this SNN to be an RNN, I believe it would require some more connections, such as from the outputs back into the inputs. In fact, RNNs are defined as a function of some inputs and of many neurons at the previous time step, such as:

<a href="https://www.codecogs.com/eqnedit.php?latex=$o_t&space;=&space;f(o_{t-1},&space;x_t)$" target="_blank"><img src="https://latex.codecogs.com/gif.latex?$o_t&space;=&space;f(o_{t-1},&space;x_t)$" title="$o_t = f(o_{t-1}, x_t)$" /></a> 

for example. In our case, we keep some state, but it's nothing comparable to having a connection back to other neurons in the past. 
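
To make the distinction concrete, here is a small sketch (not part of the notebook; `leaky_step` and the layer sizes are made up for illustration) comparing how the previous state enters the next state in both cases. PyTorch's `nn.RNNCell` mixes every neuron's previous value into every other neuron through a full hidden-to-hidden weight matrix, whereas the decayed accumulation used here only feeds each neuron's previous value back to itself. The threshold and reset are left out, since they are element-wise too.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)
n_inputs, n_hidden = 4, 3
x = torch.randn(1, n_inputs)
prev_state = torch.zeros(1, n_hidden)

# Classic recurrent cell: h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh).
rnn_cell = nn.RNNCell(n_inputs, n_hidden)
rnn_jacobian = jacobian(lambda h: rnn_cell(x, h), prev_state).reshape(n_hidden, n_hidden)

# Decayed accumulation as in this SNN layer (threshold and reset omitted).
fc = nn.Linear(n_inputs, n_hidden)
decay_multiplier = 0.9


def leaky_step(prev_inner):
    return fc(x) + prev_inner * decay_multiplier


snn_jacobian = jacobian(leaky_step, prev_state).reshape(n_hidden, n_hidden)

print(rnn_jacobian)  # dense: each neuron depends on all neurons of the previous step
print(snn_jacobian)  # 0.9 on the diagonal, zeros elsewhere
```

Row `i`, column `j` of each matrix tells how much neuron `i` at this tick depends on neuron `j` at the previous tick. The diagonal matrix of the SNN layer is what "we keep some state" amounts to.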

## Results

Cleaner-than-average code awaits below. Scroll on!

```python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets
import torchvision.transforms as transforms
```


```python
def train(model, device, train_set_loader, optimizer, epoch, logging_interval=100):
    # This method is derived from:
    # https://github.com/pytorch/examples/blob/master/mnist/main.py
    # Was licensed BSD-3-clause

    model.train()
    for batch_idx, (data, target) in enumerate(train_set_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % logging_interval == 0:
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct = pred.eq(target.view_as(pred)).float().mean().item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.2f}%'.format(
                epoch, batch_idx * len(data), len(train_set_loader.dataset),
                100. * batch_idx / len(train_set_loader), loss.item(),
                100. * correct))


def test(model, device, test_set_loader):
    # This method is derived from:
    # https://github.com/pytorch/examples/blob/master/mnist/main.py
    # Was licensed BSD-3-clause

    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_set_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # reduction='sum' returns a scalar for any batch size, so `.item()` is always valid.
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_set_loader.dataset)
    print("")
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        test_loss,
        correct, len(test_set_loader.dataset),
        100. * correct / len(test_set_loader.dataset)))
    print("")


def download_mnist(data_path):
    os.makedirs(data_path, exist_ok=True)
    transformation = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    training_set = torchvision.datasets.MNIST(data_path, train=True, transform=transformation, download=True)
    testing_set = torchvision.datasets.MNIST(data_path, train=False, transform=transformation, download=True)
    return training_set, testing_set
```


```python
batch_size = 1000
test_batch_size = 1  # TODO: fix this
DATA_PATH = './data'

training_set, testing_set = download_mnist(DATA_PATH)
train_set_loader = torch.utils.data.DataLoader(
    dataset=training_set,
    batch_size=batch_size,
    shuffle=True)
test_set_loader = torch.utils.data.DataLoader(
    dataset=testing_set,
    batch_size=test_batch_size,
    shuffle=False)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
```


```python
class SpikingNeuronLayerRNN(nn.Module):

    def __init__(self, device, n_inputs=28*28, n_hidden=100, decay_multiplier=0.9, threshold=2.0, penalty_threshold=2.5):
        super(SpikingNeuronLayerRNN, self).__init__()
        self.device = device
        self.n_inputs = n_inputs
        self.n_hidden = n_hidden
        self.decay_multiplier = decay_multiplier
        self.threshold = threshold
        self.penalty_threshold = penalty_threshold

        self.fc = nn.Linear(n_inputs, n_hidden)

        self.init_parameters()
        self.reset_state()
        self.to(self.device)

    def init_parameters(self):
        for param in self.parameters():
            if param.dim() >= 2:
                nn.init.xavier_uniform_(param)

    def reset_state(self):
        self.prev_inner = torch.zeros([self.n_hidden]).to(self.device)
        self.prev_outer = torch.zeros([self.n_hidden]).to(self.device)

    def forward(self, x):
        """
        Call the neuron at every time step.

        x: activated_neurons_below

        return: a tuple of (state, output) for each time step. Each item in the tuple
        is of shape (batch_size, n_hidden) and is a PyTorch tensor, such
        that the whole return value would be of shape (2, batch_size, n_hidden) if stacked.
        """
        if self.prev_inner.dim() == 1:
            # Adding batch_size dimension directly after doing a `self.reset_state()`:
            batch_size = x.shape[0]
            self.prev_inner = torch.stack(batch_size * [self.prev_inner])
            self.prev_outer = torch.stack(batch_size * [self.prev_outer])

        # 1. Weight matrix multiplies the input x
        input_excitation = self.fc(x)

        # 2. We add the result to a decayed version of the information we already had.
        inner_excitation = input_excitation + self.prev_inner * self.decay_multiplier

        # 3. We compute the activation of the neuron to find its output value,
        #    but before the activation, there is also a negative bias (the threshold)
        #    that keeps neurons from firing too much.
        outer_excitation = F.relu(inner_excitation - self.threshold)

        # 4. If the neuron fires, the activation of the neuron is subtracted from its inner state
        #    (with an extra penalty to increase the refractory time),
        #    because it discharges naturally so it shouldn't fire twice.
        do_penalize_gate = (outer_excitation > 0).float()
        inner_excitation = inner_excitation - (self.penalty_threshold + outer_excitation) * do_penalize_gate

        # 5. The returned value is normalized from a learnable scalar to help normalize the learning.
        outer_excitation = outer_excitation  # TODO: code this

        # 6. Setting internal values before returning.
        #    The returned value is the one of the previous time step, to delay
        #    activation by 1 time step of "processing" time. For logits, we don't take activation.
        delayed_return_state = self.prev_inner
        delayed_return_output = self.prev_outer
        self.prev_inner = inner_excitation
        self.prev_outer = outer_excitation
        return delayed_return_state, delayed_return_output


class InputDataToSpikingPerceptronLayer(nn.Module):

    def __init__(self, device):
        super(InputDataToSpikingPerceptronLayer, self).__init__()
        self.device = device

        self.reset_state()
        self.to(self.device)

    def reset_state(self):
        #     self.prev_state = torch.zeros([self.n_hidden]).to(self.device)
        pass

    def forward(self, x, is_2D=True):
        x = x.view(x.size(0), -1)  # Flatten 2D image to 1D for FC
        # Each call scales every input value by a fresh uniform random factor in [0, 1).
        random_activation_perceptron = torch.rand(x.shape).to(self.device)
        return random_activation_perceptron * x


class OutputDataToSpikingPerceptronLayer(nn.Module):

    def __init__(self, average_output=True):
        """
        average_output: might be needed if this is used within a regular neural net as a layer.
        Otherwise, sum may be numerically more stable for gradients with setting average_output=False.
        """
        super(OutputDataToSpikingPerceptronLayer, self).__init__()
        if average_output:
            self.reducer = lambda x, dim: x.mean(dim=dim)
        else:
            self.reducer = lambda x, dim: x.sum(dim=dim)

    def forward(self, x):
        if isinstance(x, list):
            x = torch.stack(x)
        return self.reducer(x, 0)


class SpikingNet(nn.Module):

    def __init__(self, device, n_time_steps, begin_eval):
        super(SpikingNet, self).__init__()
        assert 0 <= begin_eval < n_time_steps
        self.device = device
        self.n_time_steps = n_time_steps
        self.begin_eval = begin_eval

        self.input_conversion = InputDataToSpikingPerceptronLayer(device)

        self.layer1 = SpikingNeuronLayerRNN(
            device, n_inputs=28*28, n_hidden=100,
            decay_multiplier=0.9, threshold=2.0, penalty_threshold=2.5
        )

        self.layer2 = SpikingNeuronLayerRNN(
            device, n_inputs=100, n_hidden=10,
            decay_multiplier=0.9, threshold=2.0, penalty_threshold=2.5
        )

        self.output_conversion = OutputDataToSpikingPerceptronLayer(average_output=True)  # Average over the time axis.

        self.to(self.device)

    def forward_through_time(self, x):
        """
        This acts as a layer. Its input is non-time-related, and so is its output.
        The time iterations happen inside, and the result is
        passed through global average pooling on the time axis before the return,
        so that this pipeline can be mixed with regular backprop layers such
        as the input data and the output data.
        """
        self.input_conversion.reset_state()
        self.layer1.reset_state()
        self.layer2.reset_state()

        out = []

        for _ in range(self.n_time_steps):
            xi = self.input_conversion(x)

            # For layer 1, we take the regular output.
            _, layer1_output = self.layer1(xi)

            # We take the inner state of layer 2 because it's pre-activation and thus acts as our logits.
            layer2_state, _ = self.layer2(layer1_output)

            out.append(layer2_state)
        out = self.output_conversion(out[self.begin_eval:])
        return out

    def forward(self, x):
        out = self.forward_through_time(x)
        return F.log_softmax(out, dim=-1)


class NonSpikingNet(nn.Module):

    def __init__(self):
        super(NonSpikingNet, self).__init__()
        self.layer1 = nn.Linear(28*28, 100)
        self.layer2 = nn.Linear(100, 10)

    def forward(self, x, is_2D=True):
        x = x.view(x.size(0), -1)  # Flatten 2D image to 1D for FC
        x = F.relu(self.layer1(x))
        x = self.layer2(x)
        return F.log_softmax(x, dim=-1)
```

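A note on what `InputDataToSpikingPerceptronLayer` does over time: at every time step it multiplies the input by fresh uniform random numbers, so a static image becomes a noisy, time-varying signal. The sketch below is only an illustration (the pixel values and the number of time steps are made up, and it uses non-negative values for readability); it checks that, averaged over many time steps, this random scaling comes back to about half of the original input.

```python
import torch

torch.manual_seed(0)

# Hypothetical flattened "image" with four pixel intensities, batch size of 1.
x = torch.tensor([[0.0, 0.25, 0.5, 1.0]])
n_time_steps = 10000

# One random scaling per time step, as the input layer does on each call.
samples = torch.rand(n_time_steps, *x.shape) * x

# Averaging over the time axis recovers roughly 0.5 * x.
time_average = samples.mean(dim=0)
print(time_average)
```

Brighter pixels therefore inject more excitation on average, while any single time step only sees a random fraction of it.
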
## Training a Spiking Neural Network (SNN)

Let's use our `SpikingNet`!

```python
model = SpikingNet(device, n_time_steps=128, begin_eval=0)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.5)

# 1 pass on the training data:
train(model, device, train_set_loader, optimizer, epoch=1, logging_interval=1)
test(model, device, test_set_loader)
```



Test set: Average loss: 0.4547, Accuracy: 8555/10000 (85.55%)



## Training a Feedforward Neural Network (for comparison)

It has the same number of layers and neurons and also uses the ReLU activation, but it is a regular, non-spiking network, defined in the code above as the `NonSpikingNet` class.

```python
model = NonSpikingNet().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.5)

# 1 pass on the training data:
train(model, device, train_set_loader, optimizer, epoch=1, logging_interval=1)
test(model, device, test_set_loader)
```



Test set: Average loss: 0.5126, Accuracy: 8690/10000 (86.90%)



## Conclusion

Well, I've trained for just 1 epoch here, but the results are about the same. Absolutely no tuning has been performed yet, and I coded this in one shot. It would be worth trying a few things to see how it goes.

Using SNNs should act as a regularizer, just like dropout, as I wouldn't expect the neurons to all fire at the same time. That said, it's an interesting path to explore, as [Brain Rhythms](https://www.youtube.com/watch?v=OCpYdSN_kts) seem to play an important role in the brain, whereas in Deep Learning no such thing happens.
