In [1]:
from typing import Optional, NamedTuple, Tuple, Any, Sequence
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import heapq
%matplotlib notebook

## Spiking network with the LIF neuron model

In [2]:
class SpikeFunction(torch.autograd.Function):
    """
    Heaviside spike non-linearity with a rectangular surrogate gradient.

    Forward: returns 1.0 where the input ``v`` (voltage minus threshold) is
    strictly positive and 0.0 elsewhere.
    Backward: lets the incoming gradient through where ``|v| < 0.5`` (a
    rectangle of half-width 0.5, i.e. total width 1.0, and height 1, centred
    on the threshold) and blocks it everywhere else.

    Source: https://www.frontiersin.org/articles/10.3389/fnins.2018.00331/full
    Implementation: https://github.com/combra-lab/pop-spiking-deep-rl/blob/main/popsan_drl/popsan_td3/popsan.py
    """

    @staticmethod
    def forward(ctx: Any, v: torch.Tensor) -> torch.Tensor:
        # Keep (voltage - threshold) for the surrogate gradient in backward().
        ctx.save_for_backward(v)
        return v.gt(0.0).float()

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
        v, = ctx.saved_tensors
        # Rectangular window of half-width 0.5 around v == 0.
        spike_pseudo_grad = (v.abs() < 0.5).to(grad_output.dtype)
        # forward() has exactly one tensor input, so exactly one gradient is returned.
        return grad_output * spike_pseudo_grad

In [3]:
# Recurrent state of one LIF layer, carried from one time step to the next.
class LIFState(NamedTuple):
    z: torch.Tensor  # spikes emitted at the previous step (0./1.)
    v: torch.Tensor  # membrane voltage
    i: torch.Tensor  # synaptic current

class LIF(nn.Module):
    """
    Layer of leaky-integrate-and-fire neurons.

    The per-neuron current decay, voltage decay and firing threshold are drawn
    once from U(0, 1). They are plain tensors, NOT ``nn.Parameter``s, so they
    are not trained and are not returned by ``parameters()``; the weight-update
    code in this notebook relies on ``parameters()`` yielding only the synapse
    weights.
    """

    def __init__(self, size: int):
        super().__init__()
        self.size = size
        # Fixed random neuron constants, U(0, 1).
        self.i_decay = torch.rand(size)
        self.v_decay = torch.rand(size)
        self.thresh = torch.rand(size)
        self.spike = SpikeFunction.apply  # Heaviside with surrogate gradient

    def forward(
        self,
        synapse: nn.Module,
        z: torch.Tensor,
        state: Optional[LIFState] = None,
    ) -> Tuple[torch.Tensor, LIFState]:
        """
        Advance the layer by one time step.

        Args:
            synapse: callable (e.g. ``nn.Linear``) mapping presynaptic spikes to input current.
            z: presynaptic spikes for this step.
            state: state from the previous step, or ``None`` to start from rest (all zeros).

        Returns:
            ``(spikes, new_state)``.
        """
        # Compute the synaptic input once; it is needed both for the
        # zero-initialised state and for the current update.
        current = synapse(z)
        if state is None:
            rest = torch.zeros_like(current)
            state = LIFState(z=rest, v=rest, i=rest)
        i = state.i * self.i_decay + current
        # Multiplying by (1 - z) resets the voltage of neurons that spiked last step.
        v = state.v * self.v_decay * (1.0 - state.z) + i
        z = self.spike(v - self.thresh)

        return z, LIFState(z, v, i)

In [4]:
class SpikingMLP(nn.Module):
    """
    Feed-forward spiking network: bias-free linear synapses, each followed by a
    layer of LIF neurons.

    The recurrent state of every layer lives in ``self.states`` (one entry per
    layer, ``None`` meaning "at rest"), so consecutive calls to ``forward``
    integrate over time. Call ``reset()`` before presenting a new sequence.
    """

    def __init__(self, sizes: Sequence[int]):
        super().__init__()
        self.sizes = sizes
        self.spike = SpikeFunction.apply

        self.synapses = nn.ModuleList()
        self.neurons = nn.ModuleList()
        self.states = []
        # Walk consecutive (in_width, out_width) pairs; weights are randomly initialised.
        for n_in, n_out in zip(sizes[:-1], sizes[1:]):
            self.synapses.append(nn.Linear(n_in, n_out, bias=False))
            self.neurons.append(LIF(n_out))
            self.states.append(None)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Propagate one time step of input spikes through every layer and return the output spikes."""
        for layer in range(len(self.neurons)):
            z, new_state = self.neurons[layer](self.synapses[layer], z, self.states[layer])
            self.states[layer] = new_state
        return z

    def reset(self):
        """
        Put every layer back at rest. Resetting states when you're done is very important!
        """
        self.states[:] = [None] * len(self.states)


## XOR Task

In [5]:
# Data and labels: each sample is two random bits, labelled 1 when they differ (XOR).
samples = 1000  # 10000
x = torch.randint(2, (samples, 2)).float()
y = x.sum(dim=-1).eq(1).float()
print(x[:10])
print(y[:10])
print(f"Class imbalance: {y.sum() / len(y)}")

tensor([[1., 0.],
        [1., 1.],
        [1., 1.],
        [1., 0.],
        [1., 1.],
        [0., 1.],
        [1., 1.],
        [1., 0.],
        [0., 0.],
        [1., 0.]])
tensor([1., 0., 0., 1., 0., 1., 0., 1., 0., 1.])
Class imbalance: 0.5040000081062317


## CGP (Cartesian Genetic Programming)
Operands: one presynaptic trace, one postsynaptic trace, and the reward $1/(\text{loss}+0.01)$.

Operators: `+`, `-`, `*` (`/` is excluded for now because there is no validity check yet, so division could produce invalid learning rules, e.g. by dividing by zero).

In [6]:
def add(x, y):
    """Return ``x + y``."""
    return x + y


def sub(x, y):
    """Return ``x - y``."""
    return x - y


def mul(x, y):
    """Return ``x * y``."""
    return x * y


def div(x, y):
    """Return ``x / y`` (not in ``F``: nothing guards against division by zero)."""
    return x / y


# Operator set available to CGP nodes; div stays out until rules are validity-checked.
F = [add, sub, mul]

This very simple index graph has 8 internal nodes arranged in 2 columns of 4 rows. Each internal node is encoded by three integers: the first selects the operator, and the second and third select the inputs. Inputs of first-column nodes are the operands, while inputs of second-column nodes are the outputs of the first column.

In [7]:
n_operand = 3  # number of operands (presynaptic trace, postsynaptic trace, reward)
n_row = 4  # number of rows (nodes per column)
n_column = 2  # number of columns of internal nodes (create_node builds exactly 2)
n_input = 2  # number of inputs per node (every operator in F is binary)
n_F = len(F)  # number of operators

def create_node():
    """
    Build a random CGP genome as a ``(2 * n_row, 1 + n_input)`` integer tensor.

    Each row is one node: ``[operator index into F, input index, input index]``.
    Rows ``0..n_row-1`` form the first column and index the operands; rows
    ``n_row..2*n_row-1`` form the second column and index first-column nodes.
    """
    # The arity of a node is n_input, not n_column; the two only coincide at 2.
    first = torch.cat((torch.randint(n_F, (n_row, 1)), torch.randint(n_operand, (n_row, n_input))), 1)
    second = torch.cat((torch.randint(n_F, (n_row, 1)), torch.randint(n_row, (n_row, n_input))), 1)
    return torch.cat((first, second))


In [8]:
def compute(S, operand):
    """
    Evaluate the output node of CGP individual ``S`` on the list ``operand``.

    ``S`` is ``[genome, output_node]``; each genome row is
    ``[operator index into F, input a, input b]``. Nodes ``0..n_row-1`` take
    their inputs from ``operand``; a later node takes them from first-column
    nodes, which are evaluated on demand.
    """
    genome, out_node = S[0], S[1]

    def first_column(node):
        # Apply a first-column node directly to the operands.
        gene = genome[node]
        return F[gene[0]](operand[gene[1]], operand[gene[2]])

    if out_node >= n_row:
        gene = genome[out_node]
        op = F[gene[0]]
        return op(first_column(gene[1]), first_column(gene[2]))
    return first_column(out_node)

def mutation(S):
    """
    Return a mutated copy of the CGP individual ``S = [genome, output_node]``.

    Every gene gets standard-normal noise, is rounded, and is clamped back into
    its valid range:
    - operator genes go to ``[0, n_F-1]``;
    - first-column input genes go to ``[0, n_operand-1]``;
    - second-column input genes go to ``[0, n_row-1]``.

    The output node is re-drawn uniformly over all nodes. ``S`` is not modified.
    """
    genome = S[0]
    # Noise and bounds follow the genome's shape instead of a hard-coded 8x3 layout.
    mu = torch.randn(genome.shape)  # *0.3
    mutated_internal = torch.round(genome + mu)
    n_genes = genome.shape[1]  # 1 operator gene + input genes
    lower = torch.zeros(1, n_genes)
    upper_first = torch.tensor([[n_F - 1] + [n_operand - 1] * (n_genes - 1)], dtype=torch.float)
    upper_second = torch.tensor([[n_F - 1] + [n_row - 1] * (n_genes - 1)], dtype=torch.float)
    mutated_internal[:n_row] = torch.max(torch.min(mutated_internal[:n_row], upper_first), lower)
    mutated_internal[n_row:] = torch.max(torch.min(mutated_internal[n_row:], upper_second), lower)
    mutated_internal = mutated_internal.int()
    mutated_node_index = torch.randint(len(mutated_internal), (1, 1)).item()
    return [mutated_internal, mutated_node_index]

In [9]:
def update_weight(S, snn, trace1, trace2, trace3, loss):
    """
    Evaluate the evolved plasticity rule ``S`` once per synapse of a two-layer SNN.

    Args:
        S: CGP individual ``[genome, output_node]``, evaluated with ``compute``.
        snn: network whose first two parameters are the layer-1 and layer-2
            weight matrices, each shaped ``(n_post, n_pre)``.
        trace1: trace of the input layer (presynaptic for layer 1).
        trace2: trace of the hidden layer (postsynaptic for layer 1,
            presynaptic for layer 2).
        trace3: trace of the output layer (postsynaptic for layer 2).
        loss: third operand of the rule. The callers in this notebook pass the
            reward ``1/(loss+0.01)`` here, not the raw loss.

    Returns:
        ``(w1, w2)``: weight increments with the shapes of the two weight matrices.
    """
    params = list(snn.parameters())
    # The increments are added straight to the weights; tracking them with
    # autograd would only keep the traces' graph alive.
    with torch.no_grad():
        w1 = torch.zeros_like(params[0])
        n_post1, n_pre1 = w1.shape
        for i in range(n_pre1):
            for j in range(n_post1):
                w1[j, i] = compute(S, [trace1[i], trace2[j], loss])

        w2 = torch.zeros_like(params[1])
        n_post2, n_pre2 = w2.shape
        for i in range(n_pre2):
            for j in range(n_post2):
                w2[j, i] = compute(S, [trace2[i], trace3[j], loss])
    return w1, w2

In [10]:
batch = 50
epochs = 3
decay = 0.5  # per-step multiplicative decay of every trace
alpha = 1  # increment added to a trace per unit of activity

criterion = torch.nn.MSELoss()

def train(S, snn):
    """
    Train ``snn`` with the evolved plasticity rule ``S`` and return a fitness.

    Each batch of ``batch`` consecutive samples is presented one sample per
    time step (network state is reset at the start of the batch). During the
    batch, exponentially-decaying traces of the input, hidden and output
    activity are accumulated. After the batch:
    - the MSE between the output spikes and the labels is computed;
    - ``update_weight`` adds the rule's increment to both weight matrices,
      using the reward ``1/(loss+0.01)``.

    No gradient-based learning takes place.

    Returns:
        Sum of the per-batch losses of the final epoch (``losses`` is reset at
        the start of every epoch); lower is better.
        NOTE(review): confirm that final-epoch-only fitness is intended.
    """
    losses = []
    with torch.no_grad():  # learning is done by the evolved rule, not by autograd
        for e in range(epochs):
            losses = []
            for b in range(samples // batch):
                start, stop = b * batch, (b + 1) * batch
                snn.reset()

                trace1 = trace2 = trace3 = 0
                outputs = []
                # Feed one sample per time step and update the traces after each step.
                for d in range(start, stop):
                    outputs.append(snn(x[d]))
                    trace1 = trace1 * decay + alpha * x[d]
                    trace2 = trace2 * decay + alpha * snn.states[0].z
                    trace3 = trace3 * decay + alpha * snn.states[1].z
                loss = criterion(torch.cat(outputs), y[start:stop])

                w1, w2 = update_weight(S, snn, trace1, trace2, trace3, 1 / (loss.item() + 0.01))
                for k, w in enumerate(snn.parameters()):
                    w.add_(w1 if k == 0 else w2)

                losses.append(loss.item())
    return sum(losses)

(2 + 4) Evolutionary Strategy: 2 individuals are initialized randomly and their losses are computed. Each individual produces two offspring by mutation. Parents and offspring are then ranked together by loss and the two best survive as the next parents, so an offspring replaces a parent only if it performs better.

In [11]:
# Two random parents, each stored as [genome, output-node index].
internal1 = create_node()
node_index1 = torch.randint(high=len(internal1), size=(1, 1)).item()
S1 = [internal1, node_index1]

internal2 = create_node()
node_index2 = torch.randint(high=len(internal2), size=(1, 1)).item()
S2 = [internal2, node_index2]

Stab = [S1, S2]  # current parent population

In [12]:
sizes = [2, 5, 1]  # input, hidden, output widths
print(train(S1, snn=SpikingMLP(sizes)))

9.819999933242798


In [13]:
# Score both parents, each on a freshly initialised network.
L1, L2 = train(S1, SpikingMLP(sizes)), train(S2, SpikingMLP(sizes))
Ltab = [L1, L2]
print(Ltab)

[9.819999933242798, 9.819999933242798]


In [14]:
gen = 500
for g in range(gen):
    # Two offspring per parent; all mutations happen before any evaluation.
    Smu11, Smu12 = mutation(Stab[0]), mutation(Stab[0])
    Smu21, Smu22 = mutation(Stab[1]), mutation(Stab[1])
    Smutab = [Smu11, Smu12, Smu21, Smu22]
    # Each offspring is scored on its own freshly initialised network.
    Lmutab = [train(child, SpikingMLP(sizes)) for child in Smutab]

    # (mu + lambda) selection: keep the two lowest losses among parents + offspring.
    Ljointtab = Ltab + Lmutab
    Sjointtab = Stab + Smutab
    good_index = heapq.nsmallest(2, range(len(Ljointtab)), key=Ljointtab.__getitem__)
    Stab = [Sjointtab[m] for m in good_index]
    Ltab = [Ljointtab[n] for n in good_index]

    if g % 10 == 0:
        print(Ljointtab)
        print(Ltab)
       

[9.819999933242798, 9.819999933242798, 9.819999933242798, 10.0799999833107, 9.840000063180923, 9.819999933242798]
[9.819999933242798, 9.819999933242798]
[7.000000014901161, 9.059999942779541, 9.819999933242798, 9.819999933242798, 10.0799999833107, 9.819999933242798]
[7.000000014901161, 9.059999942779541]
[7.000000014901161, 9.059999942779541, 10.0799999833107, 9.819999933242798, 9.819999933242798, 9.819999933242798]
[7.000000014901161, 9.059999942779541]
[7.000000014901161, 9.059999942779541, 9.819999933242798, 9.819999933242798, 9.139999955892563, 10.0799999833107]
[7.000000014901161, 9.059999942779541]
[7.000000014901161, 9.059999942779541, 10.0799999833107, 9.819999933242798, 10.0799999833107, 9.819999933242798]
[7.000000014901161, 9.059999942779541]
[7.000000014901161, 9.059999942779541, 10.0799999833107, 10.46000000834465, 9.819999933242798, 9.819999933242798]
[7.000000014901161, 9.059999942779541]
[7.000000014901161, 8.499999970197678, 9.819999933242798, 9.819999933242798, 10.079

In [15]:
print(f"{Stab}")  # surviving parents: [genome, output-node index] each

[[tensor([[1, 2, 2],
        [1, 0, 1],
        [0, 2, 2],
        [1, 1, 0],
        [2, 1, 2],
        [0, 3, 0],
        [2, 3, 2],
        [1, 1, 3]], dtype=torch.int32), 0], [tensor([[1, 1, 1],
        [0, 2, 0],
        [2, 1, 2],
        [1, 2, 0],
        [2, 3, 3],
        [1, 2, 2],
        [2, 1, 0],
        [2, 1, 3]], dtype=torch.int32), 5]]


## Test

In [15]:
sizes = [2, 5, 1]
snn = SpikingMLP(sizes)
snn.reset()

decay = 0.5
alpha = 1
trace1 = trace2 = trace3 = 0

# Note: here `loss` is the criterion object, not a loss value.
loss = torch.nn.MSELoss()

# Run the first 10 samples as 10 time steps, tracing input, hidden and output activity.
outputs = []
for i in range(10):
    trace1 = trace1 * decay + alpha * x[i]
    output = snn(x[i])
    outputs.append(output)
    trace2 = trace2 * decay + alpha * snn.states[0].z
    trace3 = trace3 * decay + alpha * snn.states[1].z
prediction = torch.cat(outputs)
print(prediction)
print(loss(prediction, y[:10]))

tensor([0., 1., 1., 1., 1., 0., 0., 1., 1., 0.], grad_fn=<CatBackward>)
tensor(0.9000, grad_fn=<MseLossBackward>)


In [12]:
# Network
sizes = [2, 5, 1]
snn = SpikingMLP(sizes)

# Loss function and optimizer
def most_basic_loss_ever(x, y):
    """Summed absolute difference (L1, summed) between two tensors after flattening both."""
    return (x.view(-1) - y.view(-1)).abs().sum()  # ensure both are flat

criterion = most_basic_loss_ever
# NOTE(review): this optimizer is never stepped; weights change only via the evolved rule.
optimizer = optim.Adam(snn.parameters())

# Batch size of 100, 3 epochs
batch = 100
epochs = 3

def train(S, snn):
    """
    Variant of ``train`` that feeds a whole batch to ``snn`` in a single call.

    For each batch:
    - traces are computed with ``get_trace`` over the batch rows;
    - the evolved rule ``S`` is applied via ``update_weight`` with reward
      ``1/(loss+0.01)``.

    NOTE: this redefinition shadows the earlier ``train``.

    Returns:
        The smallest per-batch loss seen over all epochs.
    """
    losses = []
    for e in range(epochs):
        for b in range(samples // batch):
            # Non-overlapping batch b (previously x[b:b+batch], which overlapped).
            start, stop = b * batch, (b + 1) * batch

            # Zero the parameter gradients
            for p in snn.parameters():
                p.grad = None

            # Reset the network
            snn.reset()

            y_hat = snn(x[start:stop])
            loss = criterion(y_hat, y[start:stop])
            trace1, trace2, trace3 = get_trace(x[start:stop], snn.states[0][0], snn.states[1][0])
            # update_weight needs the network to read the weight shapes from.
            w1, w2 = update_weight(S, snn, trace1, trace2, trace3, (1 / (loss.item() + 0.01)))
            for k, w in enumerate(snn.parameters()):
                if k == 0:
                    w.data = w1 + w.data
                else:
                    w.data = w2 + w.data

            losses.append(loss.item())
    return min(losses)

#train(mutation(S))

#print('Finished Training')

#plt.plot(losses)
#plt.show()

In [21]:
population = 20
Ltab = []
for p in range(population):
    # Random genome and output node, scored on a freshly initialised network.
    internal = create_node()
    node_index = torch.randint(high=len(internal), size=(1, 1)).item()
    S = [internal, node_index]
    snn = SpikingMLP(sizes)
    L = train(S, snn)
    Ltab.append(L)
    print(L)
    print(list(snn.parameters()))

45.0
[Parameter containing:
tensor([[-0.3738, -0.3887],
        [13.5311, 13.8867],
        [13.1364, 12.7798],
        [-0.4300, -0.4871],
        [ 0.1038, -0.2728]], requires_grad=True), Parameter containing:
tensor([[ 0.2200,  0.2119, -0.2160, -0.3945,  0.0668]], requires_grad=True)]
18.0
[Parameter containing:
tensor([[329.4209, 523.0168],
        [329.2086, 524.1368],
        [328.8070, 523.6957],
        [328.1838, 523.1481],
        [329.5231, 524.0338]], requires_grad=True), Parameter containing:
tensor([[829.4711, 827.4448, 828.0056, 827.7215, 827.9511]],
       requires_grad=True)]
18.0
[Parameter containing:
tensor([[252.2332, 339.0680],
        [251.8015, 338.3237],
        [252.3256, 339.1833],
        [251.9291, 338.7458],
        [252.8493, 338.3752]], requires_grad=True), Parameter containing:
tensor([[464.6884, 462.7307, 463.8718, 463.2959, 465.1858]],
       requires_grad=True)]
45.0
[Parameter containing:
tensor([[ 0.0946,  0.0338],
        [ 0.0810,  0.0420],
     

In [33]:
e = [3, 5]
f = [1, 4, 6]
joint = [*e, *f]
print(joint)
# Indices of the two smallest entries, smallest first.
good_index = heapq.nsmallest(2, range(len(joint)), key=lambda k: joint[k])
print(good_index)
parents = list(map(joint.__getitem__, good_index))
print(parents)
#for i in f:
    #if i<e[0]:
        #e[0] = i
    #else:
        #if i<e[1]:
            #e[1] = i
    #print(e)

[3, 5, 1, 4, 6]
[2, 0]
[1, 3]


In [12]:
weights = [*snn.parameters()]  # [layer-1 weights, layer-2 weights]
print(weights)

[Parameter containing:
tensor([[ 0.5049, -0.5355],
        [-0.6936,  0.8443],
        [ 0.0882, -0.3541],
        [ 0.3409,  0.5484],
        [ 0.2396, -0.7130]], requires_grad=True), Parameter containing:
tensor([[ 0.3296,  0.2162, -0.1045,  0.3895,  0.4227]], requires_grad=True)]


In [54]:
# Draw fresh random XOR inputs and relabel them. Regenerating `x` alone would
# leave the global `y` describing the previous inputs and silently corrupt
# every later `train` call.
x = torch.randint(2, (samples, 2)).float()
y = (x.sum(-1) == 1).float()
print(x)

tensor([[1., 1.],
        [0., 0.],
        [0., 1.],
        [0., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 0.],
        [1., 1.],
        [0., 0.]])


In [11]:
# Start from rest: state left over from earlier cells may have a different
# batch shape and would otherwise be broadcast silently into this run.
snn.reset()
y_hat = snn(x)
print(snn.states[0][0])

tensor([[0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], grad_fn=<SpikeFunctionBackward>)


In [14]:
decay = 0.5  # multiplicative decay applied to the trace at every step
alpha = 1  # weight of the newest row when it is added to the trace

def update_trace(prespike: torch.Tensor):
    """
    Fold an activity history into an exponentially-decaying trace.

    Rows of ``prespike`` are consecutive steps: ``trace <- trace * decay + alpha * row``.
    Returns the integer ``0`` when ``prespike`` has no rows, otherwise a tensor
    shaped like one row.
    """
    acc = 0
    n_steps = prespike.shape[0]
    for t in range(n_steps):
        acc = acc * decay + alpha * prespike[t]
    return acc

def get_trace(layer1, layer2, layer3):
    """Return the traces of the three activity histories, in argument order."""
    return tuple(update_trace(layer) for layer in (layer1, layer2, layer3))

# Traces of input, hidden spikes and output spikes over the batch rows.
a, b, c = get_trace(x, snn.states[0].z, snn.states[1].z)
print(a, b, c)

tensor([1.9038, 0.5303]) tensor([1.9766, 0.0000, 0.0000, 0.0000, 1.9766], grad_fn=<AddBackward0>) tensor([1.9766], grad_fn=<AddBackward0>)


In [74]:
# Scratch check of a hand-written rule (pre trace + post trace) that OVERWRITES the weights.
w1 = torch.zeros_like(list(snn.parameters())[0])
for i in range(w1.shape[1]):  # presynaptic (input) index
    for j in range(w1.shape[0]):  # postsynaptic (hidden) index
        w1[j, i] = a[i] + b[j]

w2 = torch.zeros_like(list(snn.parameters())[1])
for i in range(w2.shape[1]):  # presynaptic (hidden) index
    for j in range(w2.shape[0]):  # postsynaptic (output) index
        w2[j, i] = b[i] + c[j]

for i, w in enumerate(snn.parameters()):
    w.data = w1 if i == 0 else w2

tensor([[0.4258, 0.4902],
        [0.4258, 0.4902],
        [0.4902, 0.5547],
        [0.4258, 0.4902],
        [0.9160, 0.9805]], grad_fn=<CopySlices>)
tensor([[0.0000, 0.0000, 0.0645, 0.0000, 0.4902]], grad_fn=<CopySlices>)
[Parameter containing:
tensor([[0.4258, 0.4902],
        [0.4258, 0.4902],
        [0.4902, 0.5547],
        [0.4258, 0.4902],
        [0.9160, 0.9805]], requires_grad=True), Parameter containing:
tensor([[ 0.2868, -0.8998, -1.0204, -0.0059,  0.0199]], requires_grad=True)]
[Parameter containing:
tensor([[0.4258, 0.4902],
        [0.4258, 0.4902],
        [0.4902, 0.5547],
        [0.4258, 0.4902],
        [0.9160, 0.9805]], requires_grad=True), Parameter containing:
tensor([[0.0000, 0.0000, 0.0645, 0.0000, 0.4902]], requires_grad=True)]


In [84]:
# NOTE(review): duplicates the earlier CGP-constants cell; keep only one copy.
n_operand = 3  # number of operands (presynaptic trace, postsynaptic trace, reward)
n_row = 4  # number of rows (nodes per column)
n_column = 2  # number of columns of internal nodes (create_node builds exactly 2)
n_input = 2  # number of inputs per node (every operator in F is binary)
n_F = len(F)  # number of operators

def create_node():
    """
    Build a random CGP genome as a ``(2 * n_row, 1 + n_input)`` integer tensor.

    Each row is one node: ``[operator index into F, input index, input index]``.
    Rows ``0..n_row-1`` form the first column and index the operands; rows
    ``n_row..2*n_row-1`` form the second column and index first-column nodes.
    """
    # The arity of a node is n_input, not n_column; the two only coincide at 2.
    first = torch.cat((torch.randint(n_F, (n_row, 1)), torch.randint(n_operand, (n_row, n_input))), 1)
    second = torch.cat((torch.randint(n_F, (n_row, 1)), torch.randint(n_row, (n_row, n_input))), 1)
    return torch.cat((first, second))

internal = create_node()

node_index = torch.randint(len(internal), (1, 1)).item()

S = [internal, node_index]

In [16]:
def compute(S, operand):
    """
    Evaluate the output node of CGP individual ``S`` on the list ``operand``.

    ``S`` is ``[genome, output_node]``; each genome row is
    ``[operator index into F, input a, input b]``. Nodes ``0..n_row-1`` take
    their inputs from ``operand``; a later node takes them from first-column
    nodes, which are evaluated on demand.
    """
    genome, out_node = S[0], S[1]

    def first_column(node):
        # Apply a first-column node directly to the operands.
        gene = genome[node]
        return F[gene[0]](operand[gene[1]], operand[gene[2]])

    if out_node >= n_row:
        gene = genome[out_node]
        op = F[gene[0]]
        return op(first_column(gene[1]), first_column(gene[2]))
    return first_column(out_node)

def mutation(S):
    """
    Return a mutated copy of the CGP individual ``S = [genome, output_node]``.

    Every gene gets standard-normal noise, is rounded, and is clamped back into
    its valid range:
    - operator genes go to ``[0, n_F-1]``;
    - first-column input genes go to ``[0, n_operand-1]``;
    - second-column input genes go to ``[0, n_row-1]``.

    The output node is re-drawn uniformly over all nodes of ``S[0]``.
    ``S`` is not modified.
    """
    genome = S[0]
    mu = torch.randn(genome.shape)  # *0.3
    mutated_internal = torch.round(genome + mu)
    n_genes = genome.shape[1]  # 1 operator gene + input genes
    lower = torch.zeros(1, n_genes)
    # Operator upper bound is n_F-1 (a hard-coded 3 indexed past the end of F),
    # and second-column inputs may reference every first-column node (lower bound 0).
    upper_first = torch.tensor([[n_F - 1] + [n_operand - 1] * (n_genes - 1)], dtype=torch.float)
    upper_second = torch.tensor([[n_F - 1] + [n_row - 1] * (n_genes - 1)], dtype=torch.float)
    mutated_internal[:n_row] = torch.max(torch.min(mutated_internal[:n_row], upper_first), lower)
    mutated_internal[n_row:] = torch.max(torch.min(mutated_internal[n_row:], upper_second), lower)
    mutated_internal = mutated_internal.int()
    # Draw from this genome's own size, not from a global `internal`.
    mutated_node_index = torch.randint(len(mutated_internal), (1, 1)).item()
    return [mutated_internal, mutated_node_index]