In [None]:
#Mounting content from Google Drive
#from google.colab import drive
#import os
#drive.mount('/content/gdrive')
#!ls '/content/gdrive/My Drive/'
#path = '/content/gdrive/My Drive/Master Thesis/' #set the path as your own location
#os.chdir(path)

In [2]:
from typing import Optional, NamedTuple, Tuple, Any, Sequence
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
import matplotlib
from matplotlib import pyplot as plt
import datetime
from torchsummary import summary
import numpy as np
import heapq
%load_ext tensorboard

In [3]:
class SpikeFunction(torch.autograd.Function):
    """
    Heaviside spike nonlinearity with a rectangular surrogate gradient.

    Forward emits 1.0 wherever the input is strictly positive and 0.0 elsewhere.
    Backward lets the upstream gradient through unchanged where |v| < 0.5
    (a window of total width 1.0 centred on zero) and zeroes it elsewhere.

    Source: https://www.frontiersin.org/articles/10.3389/fnins.2018.00331/full
    Implementation: https://github.com/combra-lab/pop-spiking-deep-rl/blob/main/popsan_drl/popsan_td3/popsan.py
    """

    @staticmethod
    def forward(ctx: Any, v: torch.Tensor) -> torch.Tensor:
        # The input is remembered because the surrogate gradient depends on it.
        ctx.save_for_backward(v)
        return (v > 0.0).float()

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        (v,) = ctx.saved_tensors
        window = torch.lt(v.abs(), 0.5).float()  # 0.5 is the half-width of the rectangle
        surrogate = grad_output.clone() * window
        # Two-element tuple: gradient w.r.t. v, followed by a None placeholder.
        return surrogate, None

## Define Hyperparameters and Initialize Parameters to be Evolved

In [6]:
# Evolutionary search settings
popu = 30  # number of individuals in the population
net_size = [7840, 256, 64]  # layer widths handed to the network
N_neuron = sum(net_size)  # total number of columns in each parameter pool

# One row per individual, one column per neuron; every entry is uniform in [0, 1).
# The pools are drawn in this order: current decay, voltage decay, threshold.
i_decay_pool, v_decay_pool, thresh_pool = (
    torch.rand(popu, N_neuron) for _ in range(3)
)

# Encoder choice per individual, stored as a float tensor holding 0s and 1s.
encoder_pool = torch.empty(popu).random_(2)  # 0: population encoder, 1: position encoder

In [7]:
class LIFState(NamedTuple):
    """State a LIF layer carries from one step to the next."""

    z: torch.Tensor  # spikes emitted at the previous step
    v: torch.Tensor  # membrane voltage
    i: torch.Tensor  # synaptic current


class LIF(nn.Module):
    """
    Leaky-integrate-and-fire neuron layer with learnable parameters.

    The per-neuron current decay, voltage decay and threshold are initialised
    from row ``n`` of the module-level pools (``i_decay_pool``, ``v_decay_pool``,
    ``thresh_pool``), using the columns that belong to the layer of ``net_size``
    whose width equals ``size``.

    Args:
        size: width of this layer; must be one of the entries of ``net_size``.
        n: row of the pools to read, i.e. the index of the individual.

    Raises:
        ValueError: if ``size`` matches no entry of ``net_size``.
    """

    def __init__(self, size: int, n: int):
        super().__init__()
        self.size = size
        self.n = n  # index of each individual

        # Walk the layers, tracking where each one starts inside the pools.
        columns = None
        offset = 0
        for width in net_size:
            if width == size:
                # When several layers share a width, the last one wins.
                columns = slice(offset, offset + width)
            offset += width
        if columns is None:
            raise ValueError(
                f"LIF size {size} does not match any layer in net_size {list(net_size)}"
            )

        # NOTE(review): nn.Parameter does not appear to copy the pool slice, so
        # while the module stays on the pool's device, optimizer updates would
        # write through to the pools — confirm whether that is intended.
        self.i_decay = nn.Parameter(i_decay_pool[n, columns])
        self.v_decay = nn.Parameter(v_decay_pool[n, columns])
        self.thresh = nn.Parameter(thresh_pool[n, columns])
        self.spike = SpikeFunction.apply  # spike function

    def forward(
        self,
        synapse: nn.Module,
        z: torch.Tensor,
        state: Optional[LIFState] = None,
    ) -> Tuple[torch.Tensor, LIFState]:
        """
        Advance the layer by one step.

        Args:
            synapse: module applied to ``z`` to obtain the input current.
            z: spikes coming from the previous layer.
            state: state returned by the previous step, or None to start from zeros.

        Returns:
            The spikes emitted at this step and the updated state.
        """
        # Evaluate the synapse once; it supplies both the input and the state shape.
        drive = synapse(z)

        # Previous state
        if state is None:
            state = LIFState(
                z=torch.zeros_like(drive),
                v=torch.zeros_like(drive),
                i=torch.zeros_like(drive),
            )

        # Update state
        i = state.i * self.i_decay + drive
        # Neurons that spiked last step (state.z == 1) lose their carried-over voltage.
        v = state.v * self.v_decay * (1.0 - state.z) + i
        z = self.spike(v - self.thresh)
        return z, LIFState(z, v, i)

In [8]:
class SpikingMLP(nn.Module):
    """
    Spiking network with LIF neuron model.

    Consecutive entries of ``sizes`` define one bias-free linear synapse followed
    by one LIF layer, so ``len(sizes) - 1`` synapse/neuron pairs are created.
    Layer states are kept on the module between calls; call ``reset()`` to clear them.

    Args:
        sizes: layer widths, input first.
        n: index of the individual, forwarded unchanged to every ``LIF`` layer.
    """

    def __init__(self, sizes: Sequence[int], n: int):
        super().__init__()
        self.sizes = sizes
        self.n = n
        self.spike = SpikeFunction.apply

        # Define layers
        self.synapses = nn.ModuleList()
        self.neurons = nn.ModuleList()
        self.states = []
        # Pair every width with the next one: (fan_in, fan_out) for each synapse
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            # Synapse weights use nn.Linear's default initialisation
            self.synapses.append(nn.Linear(fan_in, fan_out, bias=False))
            self.neurons.append(LIF(fan_out, n))
            self.states.append(None)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Propagate spikes through every layer once, storing each layer's new state."""
        for i, (neuron, synapse) in enumerate(zip(self.neurons, self.synapses)):
            z, self.states[i] = neuron(synapse, z, self.states[i])
        return z

    def reset(self):
        """
        Resetting states when you're done is very important!

        Clears the stored state of every layer in place, so the next forward
        pass starts from zeros again.
        """
        self.states[:] = [None] * len(self.states)

## Encoder

In [9]:
def spike_train_encoding(x: torch.Tensor, n: int) -> torch.Tensor:
    """
    Turn a (batch, values) tensor into a random spike train of shape
    (batch, values, n).

    Each entry of x is treated as a spike probability (x is assumed to lie in
    [0, 1]): an output element is 1.0 when a fresh uniform sample falls strictly
    below the corresponding entry of x, and 0.0 otherwise.
    """
    batch, values = x.shape  # e.g. [64, 784]
    # One uniform sample per (batch, value, step), matching x in dtype and device
    noise = torch.rand(batch, values, n, dtype=x.dtype, device=x.device)
    probs = x.unsqueeze(-1)  # (batch, values, 1), broadcast over the n samples
    return torch.lt(noise, probs).float()  # e.g. [64, 784, 10]

def spike_population_encoding(x: torch.Tensor, pop: int) -> torch.Tensor:
    """
    Encode a (batch, values) tensor as a flat spike population of shape
    (batch, values * pop).

    x is assumed to lie in [0, 1] and is interpreted as spike probabilities.
    The spikes come from spike_train_encoding(x, pop); the only difference is
    that the trailing dimension is folded into the value dimension, so the pop
    samples of each input value sit next to each other in the output.
    """
    n_batch, n_values = x.shape
    spikes = spike_train_encoding(x, pop)  # e.g. [64, 784, 10]
    return spikes.reshape(n_batch, n_values * pop)  # e.g. [64, 7840]
   


def spike_position_encoding(x: torch.Tensor, bins: int) -> torch.Tensor:
    """
    A deterministic position encoding, where we have a certain number of bins
    to discretize the continuous input value.

    This is based on the idea of 'place cells'
    (https://en.wikipedia.org/wiki/Place_cell). A stochastic variant
    of this is also possible, where the bins have the shape of a normal
    distribution (radial basis functions) and represent spiking probabilities.

    The interval [0, 1] is cut into `bins` equal cells that are inclusive on
    their right edge; each input value activates exactly one neuron, the one
    whose cell contains it. A value of exactly 0 activates the first neuron.
    Values outside [0, 1] are assigned to the nearest end neuron.

    Args:
        x: tensor of shape (batch, values), assumed to lie in [0, 1].
        bins: number of neurons used per input value.

    Returns:
        Tensor of shape (batch, values * bins) holding a one-hot group of
        `bins` entries for every input value.
    """
    batch, values = x.shape
    # `bins` cells need `bins + 1` edges
    edges = torch.linspace(0, 1, bins + 1, dtype=x.dtype, device=x.device)
    # searchsorted() returns the index of the first edge >= x (right-bound
    # inclusive), so subtracting one gives the index of the containing cell.
    cell = torch.searchsorted(edges, x.contiguous()) - 1
    # x == 0 would give -1, which wraps around to the last neuron when used as
    # an index; clamp it to the first neuron instead and keep x > 1 in range.
    cell = cell.clamp(min=0, max=bins - 1)
    output = torch.zeros(batch, values, bins, device=x.device)  # silent neurons
    output.scatter_(-1, cell.unsqueeze(-1), 1.0)  # fire the selected neuron
    return output.reshape(batch, values * bins)  # e.g. [64, 7840]

# Encoder lookup table: the position in this list is the value stored per
# individual in encoder_pool (0: population encoder, 1: position encoder).
encoder_list = [spike_population_encoding,spike_position_encoding]


## Decoder

In [10]:
class VoltageDecoding(nn.Module):
    """
    Voltage decoder with learnable parameters (hence a class).

    Acts like an integrator without threshold or reset: the incoming spikes are
    passed through a linear synapse and added to the previous voltage, which is
    first scaled by a learnable per-output decay factor.

    Args:
        in_size: number of inputs to the synapse.
        out_size: number of output voltages.
    """

    def __init__(self, in_size: int, out_size: int):
        super().__init__()
        # Synapse between network and decoder
        self.synapse = nn.Linear(in_size, out_size)
        # Learnable voltage decay, initialised uniformly in [0, 1)
        self.v_decay = nn.Parameter(torch.rand(out_size))

    def forward(self, z: torch.Tensor, v: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Integrate one step.

        Args:
            z: input spikes.
            v: voltage returned by the previous step, or None to start from zero.

        Returns:
            The updated voltage, ``v * v_decay + synapse(z)``.
        """
        # Evaluate the synapse once; its output also fixes the voltage's shape.
        i = self.synapse(z)
        # Previous v
        if v is None:
            v = torch.zeros_like(i)
        # Update
        return v * self.v_decay + i

In [11]:
 class SpikingClassifier(nn.Module):
     """
     Classifier SNN: spike encoder -> spiking MLP -> voltage decoder.

     The encoder is picked from ``encoder_list`` using the entry of
     ``encoder_pool`` that belongs to individual ``n``.

     Args:
         n: index of the individual this classifier is built for.
         net_sizes: layer widths handed to the spiking network.
         out_size: number of output voltages (classes).
     """

     # Number of encoder neurons per input value, passed to the encoder
     ENCODING_SIZE = 10

     def __init__(self, n: int, net_sizes: Sequence[int] = (7840, 256, 64), out_size: int = 10):
         super().__init__()
         self.n = n
         # Encoder: a spike encoding function selected for this individual
         self.encoder = encoder_list[int(encoder_pool[n])]
         # Network
         self.snn = SpikingMLP(net_sizes, n)  # network for each individual
         # Decoder
         self.decoder = VoltageDecoding(net_sizes[-1], out_size)

     def forward(self, x):
         """
         Classify a batch of images with a single network step.

         Args:
             x: tensor of shape (batch, channel, height, width).

         Returns:
             Decoder voltages, one row per image (e.g. shape [64, 10]).
         """
         # Flatten image
         batch, channel, height, width = x.shape
         x = x.reshape(batch, -1)

         # Encode entire sequence
         i_in = self.encoder(x, self.ENCODING_SIZE)

         # Reset network so no state leaks in from the previous batch
         self.snn.reset()

         # Run: just one step!
         # Network
         z = self.snn(i_in)
         # Decoder
         v_out = self.decoder(z)

         return v_out

In [12]:
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

def train(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int,
    log_interval: int = 64,
):
    """
    Run one training epoch over `train_loader`.

    For every batch the model output is scored with cross-entropy against the
    targets, and the optimizer takes one step on the resulting gradients.
    `epoch` and `log_interval` are currently unused (progress logging is
    disabled). Nothing is returned.
    """
    # Train mode
    model.train()

    # One pass over the data
    for inputs, labels in train_loader:
        # Move the batch to the target device (GPU if applicable)
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Fresh gradients, forward pass, loss
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)

        # Backprop and parameter update
        loss.backward()
        optimizer.step()
            

def test(model: nn.Module, test_loader: DataLoader, device: torch.device):
    # Set to test/eval mode
    model.eval()
    
    # Counters
    test_loss = 0
    correct = 0
    
    # Don't update graph
    with torch.no_grad():
        for data, target in test_loader:
            # Move to GPU
            data, target = data.to(device), target.to(device)
            # Predict
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction="sum").item()
            # See if correct
            pred = output.argmax(1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)
    accuracy = 100.0 * correct / len(test_loader.dataset)
    #print(f"\ntest: avg loss: {test_loss:.4f}, accuracy: {accuracy:.1f}%\n")
    
    return test_loss, accuracy

In [13]:
def breed(min_loss, par1, par2, par3, par4):
    '''Build the next generation from the better half of the population.

    The `popu // 2` individuals with the smallest loss are kept as parents.
    Each parent also produces one child: its neuron parameters are multiplied
    element-wise by a random factor in [50%, 150%], and its encoder index is
    flipped (0 <-> 1) for one child in five.

    Args:
        min_loss: one loss value per individual.
        par1, par2, par3: neuron parameter pools, one row per individual.
        par4: encoder index (0 or 1) per individual.

    Returns:
        The new pools (par1, par2, par3, par4), parents first, children after.
    '''
    half_popu = popu // 2
    good_index = heapq.nsmallest(half_popu, range(len(min_loss)), key=min_loss.__getitem__)
    n_parents = len(good_index)
    parents1 = par1[good_index]
    parents2 = par2[good_index]
    parents3 = par3[good_index]
    parents4 = par4[good_index]

    # NOTE(review): the same random factor is applied to all three parameter
    # pools, and scaled values can exceed 1.0 — confirm both are intended.
    mu_neuron = torch.rand_like(parents1) + 0.5  # parameters vary in the range of [50%,150%]

    # Exactly one child in five gets its encoder flipped. The mask is sized from
    # the number of parents so that it always lines up with parents4.
    n_flip = n_parents // 5
    mu_encoder = torch.cat(
        [
            torch.zeros(n_parents - n_flip, device=parents4.device),
            torch.ones(n_flip, device=parents4.device),
        ],
        0,
    )
    mu_encoder = mu_encoder[torch.randperm(n_parents)]  # pick the flipped children at random

    new_par1 = torch.cat((parents1, parents1 * mu_neuron), 0)
    new_par2 = torch.cat((parents2, parents2 * mu_neuron), 0)
    new_par3 = torch.cat((parents3, parents3 * mu_neuron), 0)
    # |mask - index| flips the index where the mask is 1 and keeps it where it is 0
    new_par4 = torch.cat((parents4, torch.abs(mu_encoder - parents4)), 0)
    return new_par1, new_par2, new_par3, new_par4

In [14]:
# Data
# --- For Local Jupyter Notebook: download MNIST into the working directory ---
mnist_train, mnist_test = (
    MNIST(".", train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)
# --- For Google Colab: load a copy already stored under `path` ---
#data_path = path + 'data/'
#save_path = path + 'model/'
#mnist_train = datasets.MNIST(data_path, train = True, download = False, transform = transforms.ToTensor())
#mnist_test = datasets.MNIST(data_path, train = False, download = False, transform = transforms.ToTensor())

# Dataloaders: 2x as many workers as cores, pin_memory to True
# Only the training data is shuffled.
train_loader = DataLoader(
    mnist_train,
    batch_size=64,
    shuffle=True,
    num_workers=32,
    pin_memory=True,
)
test_loader = DataLoader(
    mnist_test,
    batch_size=64,
    shuffle=False,
    num_workers=32,
    pin_memory=True,
)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def generation(n_epochs: int = 2):
    """
    Train and evaluate every individual of the current population once.

    For each index in range(popu) a fresh SpikingClassifier is built, trained
    with Adam for `n_epochs` epochs on `train_loader`, and evaluated on
    `test_loader` after every epoch.

    Args:
        n_epochs: number of train/evaluate rounds per individual (at least 1).

    Returns:
        A list with the smallest evaluation loss reached by each individual.

    Raises:
        ValueError: if `n_epochs` is smaller than 1.
    """
    if n_epochs < 1:
        raise ValueError(f"n_epochs must be at least 1, got {n_epochs}")

    min_loss = []
    for i in range(popu):
        # Select classifier
        classifier = SpikingClassifier(i).to(device)
        # Optimizer
        optimizer = optim.Adam(classifier.parameters())
        # Loss after each epoch
        test_losses = []
        for e in range(n_epochs):
            # Train
            train(classifier, train_loader, optimizer, device, e)
            # Test/validate (the accuracy is not needed for selection)
            test_loss, _ = test(classifier, test_loader, device)
            test_losses.append(test_loss)
        min_loss.append(min(test_losses))
    return min_loss

# Evolution: evaluate the population, breed the next one, repeat N times
N = 5
for a in range(N):
    print(a)
    min_loss = generation()
    # The next generation replaces the module-level pools that the layers are
    # initialised from.
    new_i_decay, new_v_decay, new_thresh, new_encoder = breed(
        min_loss, i_decay_pool, v_decay_pool, thresh_pool, encoder_pool
    )
    i_decay_pool, v_decay_pool, thresh_pool, encoder_pool = (
        new_i_decay,
        new_v_decay,
        new_thresh,
        new_encoder,
    )
    # Losses of this generation, best first
    print(sorted(min_loss)[:popu])


0
[0.20289878854751586, 0.20316954841017723, 0.20420764437317848, 0.20534762805700302, 0.20565676441788674, 0.20794454568624496, 0.20833383709788322, 0.2113780371785164, 0.21180731242895126, 0.22232588243484497, 0.22723007258176803, 0.22791901705265044, 0.23515106321573256, 0.23961178255081178, 0.2555751905143261, 0.6953482852935791, 0.703564900970459, 0.7802230184555053, 0.7822952578544616, 0.8113776509284973, 0.8411556464195251, 0.8504612324714661, 0.866305062007904, 0.8702998962402344, 0.8789214824676513, 0.9339030861854554, 0.9376825727462769, 0.9998296499252319, 1.060860987472534, 1.1112602071762085]
1
[0.1913862202525139, 0.20168669731020927, 0.2044870375931263, 0.20738442553281783, 0.20755948738455773, 0.20846164550185203, 0.20899792821407318, 0.20936037682890893, 0.20948406853079796, 0.20950489243268966, 0.20984686757326126, 0.2099250315248966, 0.2104313025712967, 0.21255700039863587, 0.21387489842176438, 0.21466880445480346, 0.21509594742059707, 0.21616980971693991, 0.21689910

## =======================================================================