In [1]:
from typing import Optional, NamedTuple, Tuple, Any, Sequence
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import heapq
%matplotlib notebook

## Spiking network with the LIF neuron model

In [2]:
class SpikeFunction(torch.autograd.Function):
    """
    Heaviside spike function with a rectangular surrogate gradient.

    Forward: emits 1.0 where the input is strictly positive, else 0.0. The
    caller is expected to pass ``voltage - threshold``.
    Backward: passes the incoming gradient through unchanged where
    ``|input| < 0.5`` and blocks it elsewhere.

    Source: https://www.frontiersin.org/articles/10.3389/fnins.2018.00331/full
    Implementation: https://github.com/combra-lab/pop-spiking-deep-rl/blob/main/popsan_drl/popsan_td3/popsan.py
    """

    @staticmethod
    def forward(ctx: Any, v: torch.Tensor) -> torch.Tensor:
        """Return the spike tensor (float 0/1) for ``v`` and stash ``v`` for backward."""
        ctx.save_for_backward(v)  # v is already (voltage - thresh)
        return v.gt(0.0).float()

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Return the surrogate gradient w.r.t. ``v`` as a 1-tuple."""
        v, = ctx.saved_tensors
        # Rectangular window: 0.5 is the half-width, so the window spans (-0.5, 0.5).
        spike_pseudo_grad = (v.abs() < 0.5).float()
        # forward() has exactly one input, so exactly one gradient is returned.
        return (grad_output * spike_pseudo_grad,)

In [3]:
# Container for the state a LIF layer carries from one timestep to the next
class LIFState(NamedTuple):
    z: torch.Tensor  # spikes emitted at the previous step
    v: torch.Tensor  # membrane potential
    i: torch.Tensor  # synaptic current

class LIF(nn.Module):
    """
    Leaky-integrate-and-fire neuron layer.

    The current decay, voltage decay and firing threshold are drawn once from
    U(0, 1). They are stored as plain tensors (not ``nn.Parameter``), so they
    do not show up in ``parameters()`` and are never updated by an optimizer.
    """

    def __init__(self, size: int):
        super().__init__()
        self.size = size
        # Per-neuron constants, each sampled from U(0, 1)
        self.i_decay = torch.rand(size)
        self.v_decay = torch.rand(size)
        self.thresh = torch.rand(size)
        self.spike = SpikeFunction.apply  # spike function

    def forward(
        self,
        synapse: nn.Module,
        z: torch.Tensor,
        state: Optional[LIFState] = None,
    ) -> Tuple[torch.Tensor, LIFState]:
        """
        Advance the layer by one timestep.

        Args:
            synapse: module mapping the incoming spikes ``z`` to input current.
            z: incoming spikes / input for this step.
            state: state returned by the previous step, or ``None`` to start
                from all-zero spikes, voltage and current.

        Returns:
            Tuple of the emitted spikes and the new ``LIFState``.
        """
        # Evaluate the synapse once and reuse the result everywhere below.
        synaptic_input = synapse(z)

        if state is None:
            state = LIFState(
                z=torch.zeros_like(synaptic_input),
                v=torch.zeros_like(synaptic_input),
                i=torch.zeros_like(synaptic_input),
            )

        # Leaky current integration
        i = state.i * self.i_decay + synaptic_input
        # Leaky voltage integration; (1 - z) zeroes the carried-over voltage of
        # neurons that spiked on the previous step.
        v = state.v * self.v_decay * (1.0 - state.z) + i
        z = self.spike(v - self.thresh)

        return z, LIFState(z, v, i)

In [4]:
class SpikingMLP(nn.Module):
    """
    Feed-forward spiking network: a stack of (linear synapse, LIF layer) pairs.

    The per-layer LIF states are kept on the module between calls, so
    ``reset()`` must be called before presenting an unrelated input sequence.
    """

    def __init__(self, sizes: Sequence[int]):
        """Build one bias-free linear synapse and one LIF layer per consecutive size pair."""
        super().__init__()
        self.sizes = sizes
        self.spike = SpikeFunction.apply

        self.synapses = nn.ModuleList()
        self.neurons = nn.ModuleList()
        self.states = []
        # Walk over (fan-in, fan-out) pairs; synapse and neuron values are random.
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            self.synapses.append(nn.Linear(fan_in, fan_out, bias=False))
            self.neurons.append(LIF(fan_out))
            self.states.append(None)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Run one timestep through every layer, updating the stored states in place."""
        layer = 0
        for neuron, synapse in zip(self.neurons, self.synapses):
            z, new_state = neuron(synapse, z, self.states[layer])
            self.states[layer] = new_state
            layer += 1
        return z

    def reset(self):
        """
        Forget all layer states. Resetting when you're done is very important!
        """
        self.states[:] = [None] * len(self.states)


## XOR Task

In [5]:
# XOR data set: random binary input pairs and their exclusive-or as label
samples = 1000  # 10000
x = torch.randint(0, 2, (samples, 2)).to(torch.float32)
# Alternative: the four canonical patterns only
# x = torch.Tensor([[0,0],[0,1],[1,0],[1,1]])
# Label is 1 exactly when the two bits differ (i.e. they sum to 1)
y = (x[:, 0] != x[:, 1]).to(torch.float32)

In [6]:
def add(x, y):
    """Binary primitive: ``x + y``."""
    result = x + y
    return result


def sub(x, y):
    """Binary primitive: ``x - y``."""
    result = x - y
    return result


def mul(x, y):
    """Binary primitive: ``x * y``."""
    result = x * y
    return result


def div(x, y):
    """Binary primitive: ``x / y`` (defined but currently left out of ``F``)."""
    result = x / y
    return result


# Function set indexed by the operator gene of a node
F = [add, sub, mul]  # , div]

In [7]:
n_operand = 3  # number of operands a first-layer node can pick from
n_row = 4      # number of rows (nodes per layer)
n_column = 2   # number of columns of internal nodes; also used as argument genes per node
n_F = len(F)   # number of operators


def create_node():
    """
    Sample a random genome of ``2 * n_row`` nodes.

    Each row is ``[operator, arg, arg]``. In the first ``n_row`` rows the
    arguments index the operands; in the last ``n_row`` rows they index
    first-layer rows. Returns an integer tensor of shape ``(2 * n_row, 3)``
    for the current ``n_column = 2``.
    """
    # Draw order (operator genes, then argument genes, per layer) is kept fixed
    # so the random stream is consumed in the same sequence.
    first_ops = torch.randint(n_F, (n_row, 1))
    first_args = torch.randint(n_operand, (n_row, n_column))
    second_ops = torch.randint(n_F, (n_row, 1))
    second_args = torch.randint(n_row, (n_row, n_column))

    first_layer = torch.cat((first_ops, first_args), 1)
    second_layer = torch.cat((second_ops, second_args), 1)
    return torch.cat((first_layer, second_layer))

In [8]:
def compute(S,operand):
    """
    Evaluate the expression encoded by individual ``S`` on ``operand``.

    ``S`` is ``[genome, node_index]``. A node index below ``n_row`` selects a
    first-layer node, which applies its operator to two entries of ``operand``.
    Any other index selects a second-layer node, which applies its operator to
    the outputs of two first-layer nodes.
    """
    genome, node = S[0], S[1]

    def first_layer(row):
        # Row layout: [operator index, left operand index, right operand index]
        gene = genome[row]
        return F[gene[0]](operand[gene[1]], operand[gene[2]])

    if node < n_row:
        return first_layer(node)

    # Second-layer row layout: [operator index, left row, right row]
    gene = genome[node]
    combine = F[gene[0]]
    return combine(first_layer(gene[1]), first_layer(gene[2]))

def mutation(S):
    """
    Return a mutated copy of individual ``S`` as ``[genome, node_index]``.

    Standard-normal noise is added to every gene, the result is rounded and
    clamped back into each gene's valid range, and a fresh output-node index is
    drawn uniformly over all rows. ``S`` itself is not modified.

    The genome size and the boundary between the two layers are taken from
    ``S[0]`` and ``n_row`` rather than hard-coded, so the function keeps
    working when ``n_row`` is changed.
    """
    genome = S[0]
    mu = torch.randn(genome.shape)  # *0.3
    mutated_internal = torch.round(genome + mu)

    # Valid gene ranges per row are [operator, arg, arg].
    lower = torch.zeros((1, 3))
    upper_first = torch.tensor(
        [[n_F - 1, n_operand - 1, n_operand - 1]], dtype=torch.float32
    )  # first layer: arguments index operands
    upper_second = torch.tensor(
        [[n_F - 1, n_row - 1, n_row - 1]], dtype=torch.float32
    )  # second layer: arguments index first-layer rows

    # Clamp each layer into its own [lower, upper] box.
    mutated_internal[:n_row] = torch.max(
        torch.min(mutated_internal[:n_row], upper_first), lower
    )
    mutated_internal[n_row:] = torch.max(
        torch.min(mutated_internal[n_row:], upper_second), lower
    )

    mutated_internal = mutated_internal.int()
    mutated_node_index = torch.randint(len(mutated_internal), (1, 1)).item()
    return [mutated_internal, mutated_node_index]

In [9]:
def update_weight(S,snn,trace1,trace2,trace3,loss):
    """
    Compute weight increments for the first two parameters of ``snn``.

    For every connection the evolved rule ``S`` is evaluated on
    ``[pre-side trace, post-side trace, loss]``. ``trace1``/``trace2`` are used
    for the first parameter and ``trace2``/``trace3`` for the second.

    Returns:
        ``(w1, w2)``, tensors shaped like the first and second parameter.
    """
    def rule_matrix(weight, pre_trace, post_trace):
        # Weight layout is (post, pre); fill one entry per connection.
        delta = torch.zeros_like(weight)
        n_post, n_pre = delta.size()[0], delta.size()[1]
        for pre in range(n_pre):
            for post in range(n_post):
                delta[post][pre] = compute(S, [pre_trace[pre], post_trace[post], loss])
        return delta

    w1 = rule_matrix(list(snn.parameters())[0], trace1, trace2)
    w2 = rule_matrix(list(snn.parameters())[1], trace2, trace3)
    return w1, w2

In [10]:
#batch = 50
epochs = 3      # passes over the data set per call to train()
decay = 0.5     # per-timestep decay of the activity traces
alpha = 1       # scale of the new activity added to a trace each timestep
timesteps = 10  # simulation steps per sample

criterion = torch.nn.MSELoss()

def train(S,snn):
    """
    Train ``snn`` on the XOR data with the evolved update rule ``S``.

    For each sample the network is reset and simulated for ``timesteps`` steps
    while decaying traces of the input, of the first layer's spikes and of the
    second layer's spikes are accumulated. The loss is the MSE between the
    output of the last timestep and the label. The rule then produces
    increments that are added to the two synaptic weight matrices.

    No gradients are used, so everything runs under ``torch.no_grad()`` to
    avoid building autograd graphs that are never consumed.

    Returns:
        Sum of the per-sample losses of the final epoch (0 if ``epochs`` is 0).
    """
    losses = []  # defined up front so epochs == 0 returns 0 instead of raising
    with torch.no_grad():
        for e in range(epochs):
            losses = []  # only the last epoch's losses are reported
            for i in range(samples):#sample//batch

                # Reset the network
                snn.reset()

                trace1 = 0
                trace2 = 0
                trace3 = 0

                # Traces are accumulated over all timesteps of the simulation
                for d in range(timesteps):
                    y_hat = snn(x[i]) #i*batch,(i+1)*batch
                    trace1 = trace1 * decay + alpha * x[i]
                    # states[k][0] is the spike tensor of layer k
                    trace2 = trace2 * decay + alpha * snn.states[0][0]
                    trace3 = trace3 * decay + alpha * snn.states[1][0]

                # Expand the scalar label to the output shape so MSELoss sees
                # matching sizes; the value equals the broadcast result.
                loss = criterion(y_hat, y[i].expand_as(y_hat))

                # The rule's third operand is the inverse loss (0.01 avoids division by zero)
                w1,w2 = update_weight(S,snn, trace1, trace2, trace3, (1/(loss.item()+0.01)))
                for w, delta in zip(snn.parameters(), (w1, w2)):
                    w.data = delta + w.data

                # Record statistics
                losses.append(loss.item())
                #print(f"[{e + 1}, {i}] loss: {loss.item()}")
    return sum(losses)

(2 + 4) Evolution Strategy: two individuals are initialized randomly and their losses are computed. In each generation, two offspring are generated from each individual by mutation. The two individuals with the lowest loss among the parents and the offspring are kept, so an offspring replaces a parent when its performance is better.

In [11]:
# First parent: random genome plus a random output-node index
internal1 = create_node()
node_index1 = torch.randint(internal1.shape[0], (1, 1)).item()
S1 = [internal1, node_index1]

# Second parent, drawn the same way
internal2 = create_node()
node_index2 = torch.randint(internal2.shape[0], (1, 1)).item()
S2 = [internal2, node_index2]

# Current parent population
Stab = [S1, S2]

In [13]:
sizes = [2, 5, 1]  # layer widths passed to SpikingMLP
# Evaluate each parent on its own freshly initialised network
Ltab = [train(parent, SpikingMLP(sizes)) for parent in (S1, S2)]
L1, L2 = Ltab
print(Ltab)

[505.0, 242.0]


In [14]:
gen = 101
for g in range(gen):
    # Two mutants per parent, all drawn before any of them is evaluated
    Smutab = [mutation(parent) for parent in Stab for _copy in range(2)]
    Smu11, Smu12, Smu21, Smu22 = Smutab
    #print(Smutab)

    # Each mutant is scored on its own freshly initialised network
    Lmutab = [train(child, SpikingMLP(sizes)) for child in Smutab]
    #print(Lmutab)

    # Pool parents and offspring, then keep the two with the lowest loss
    Ljointtab = Ltab + Lmutab
    Sjointtab = Stab + Smutab
    good_index = heapq.nsmallest(2, range(len(Ljointtab)), key=Ljointtab.__getitem__)
    Stab = [Sjointtab[k] for k in good_index]
    Ltab = [Ljointtab[k] for k in good_index]

    # Report every tenth generation
    if g % 10 == 0:
        print(Ljointtab)
        print(Ltab)
       

[505.0, 242.0, 242.0, 242.0, 242.0, 505.0]
[242.0, 242.0]
[242.0, 242.0, 242.0, 505.0, 242.0, 242.0]
[242.0, 242.0]
[242.0, 242.0, 505.0, 505.0, 242.0, 505.0]
[242.0, 242.0]
[242.0, 242.0, 505.0, 503.0, 242.0, 242.0]
[242.0, 242.0]
[0.0, 242.0, 242.0, 471.0, 242.0, 505.0]
[0.0, 242.0]
[0.0, 242.0, 242.0, 505.0, 505.0, 505.0]
[0.0, 242.0]
[0.0, 242.0, 242.0, 242.0, 242.0, 518.0]
[0.0, 242.0]
[0.0, 242.0, 505.0, 242.0, 505.0, 242.0]
[0.0, 242.0]
[0.0, 229.0, 242.0, 242.0, 242.0, 242.0]
[0.0, 229.0]
[0.0, 229.0, 242.0, 242.0, 242.0, 242.0]
[0.0, 229.0]
[0.0, 229.0, 242.0, 242.0, 242.0, 242.0]
[0.0, 229.0]


In [15]:
# Show the two surviving individuals, each a [genome tensor, output-node index] pair
print(Stab)

[[tensor([[1, 1, 1],
        [1, 2, 1],
        [2, 0, 0],
        [1, 2, 2],
        [0, 2, 0],
        [2, 3, 0],
        [2, 1, 2],
        [2, 3, 0]], dtype=torch.int32), 5], [tensor([[1, 0, 2],
        [0, 1, 0],
        [1, 0, 0],
        [0, 1, 1],
        [0, 1, 0],
        [2, 2, 0],
        [2, 3, 2],
        [1, 3, 0]], dtype=torch.int32), 6]]


## Test=========================