In [2]:
from typing import Optional, NamedTuple, Tuple, Any, Sequence
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import heapq
import math
%matplotlib notebook

## Spiking network with the LIF neuron model

In [3]:
class SpikeFunction(torch.autograd.Function):
    """
    Spiking function with rectangular gradient.
    Source: https://www.frontiersin.org/articles/10.3389/fnins.2018.00331/full
    Implementation: https://github.com/combra-lab/pop-spiking-deep-rl/blob/main/popsan_drl/popsan_td3/popsan.py
    """

    @staticmethod
    def forward(ctx: Any, v: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(v)  # save voltage - thresh for backwards pass
        return v.gt(0.0).float()

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        v, = ctx.saved_tensors
        grad_input = grad_output.clone()
        spike_pseudo_grad = (v.abs() < 0.5).float()  # 0.5 is the width of the rectangle
        return grad_input * spike_pseudo_grad, None  # ensure a tuple is returned

In [4]:
# Placeholder for LIF state
class LIFState(NamedTuple):
    z: torch.Tensor
    v: torch.Tensor
    i: torch.Tensor

class LIF(nn.Module):
    """
    Leaky-integrate-and-fire neuron with learnable parameters.
    """

    def __init__(self, size: int):
        super().__init__()
        self.size = size
        # Initialize all parameters randomly as U(0, 1)
        self.i_decay = torch.rand(size) #self.i_decay = nn.Parameter(torch.rand(size))
        self.v_decay = torch.rand(size)
        self.thresh = torch.rand(size)
        self.spike = SpikeFunction.apply  # spike function

    def forward(
        self,
        synapse: nn.Module,
        z: torch.Tensor,
        state: Optional[LIFState] = None,
    ) -> Tuple[torch.Tensor, LIFState]:
        # Previous state
        if state is None:
            state = LIFState(
                z=torch.zeros_like(synapse(z)),
                v=torch.zeros_like(synapse(z)),
                i=torch.zeros_like(synapse(z)),
            )
        # Update state
        i = state.i * self.i_decay + synapse(z)
        #print(self.i_decay)
        #print(synapse(z))
        v = state.v * self.v_decay * (1.0 - state.z) + i
        z = self.spike(v - self.thresh)

        return z, LIFState(z, v, i)

In [5]:
class SpikingMLP(nn.Module):
    """
    Spiking network with LIF neuron model.
    """

    def __init__(self, sizes: Sequence[int]):
        super().__init__()
        self.sizes = sizes
        self.spike = SpikeFunction.apply

        # Define layers
        self.synapses = nn.ModuleList()
        self.neurons = nn.ModuleList()
        self.states = []
        # Loop over current (accessible with 'size') and next (accessible with 'sizes[i]') element
        for i, size in enumerate(sizes[:-1], start=1):
            # Parameters of synapses and neurons are randomly initialized
            self.synapses.append(nn.Linear(size, sizes[i], bias=False))
            self.neurons.append(LIF(sizes[i]))
            self.states.append(None)
           

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        for i, (neuron, synapse) in enumerate(zip(self.neurons, self.synapses)):
            z, self.states[i] = neuron(synapse, z, self.states[i])
        return z

    def reset(self):
        """
        Resetting states when you're done is very important!
        """
        for i, _ in enumerate(self.states):
            self.states[i] = None


## XOR Task

In [6]:
# Data and labels
#samples = 50 #10000
#x = torch.randint(2, (samples, 2)).float()
x = torch.Tensor([[0,0],[0,1],[1,0],[1,1]])
y = (x.sum(-1) == 1).float()
print(x)
print(y)
#print(f"Class imbalance: {y.sum() / y.shape[0]}")

tensor([[0., 0.],
        [0., 1.],
        [1., 0.],
        [1., 1.]])
tensor([0., 1., 1., 0.])


In [47]:
T = 10

def encoding(inp):
    #height, width = x.shape
    inp = inp.unsqueeze(-1)
    rate = 0.7
    return (torch.rand((len(inp),T))<(inp*rate+0.1)).float()
    
#test = encoding(x[3])  
#print(test)

In [44]:
def reward(output,label):
    if output == 0 and label == 0:
        R = 1
    if output == 1 and label == 1:
        R = 1
    if output == 1 and label == 0:
        R = -1 
    if output == 0 and label == 1:
        R = -1
    return R

In [82]:
sizes = [2,5,1]
snn = SpikingMLP(sizes)
snn

SpikingMLP(
  (synapses): ModuleList(
    (0): Linear(in_features=2, out_features=5, bias=False)
    (1): Linear(in_features=5, out_features=1, bias=False)
  )
  (neurons): ModuleList(
    (0): LIF()
    (1): LIF()
  )
)

## RSTDP

In [76]:
delta = 1
taup = 20
taum = 20
Ap = 2
Am = -1
tauz = 25
lr = 0.0005

In [33]:
def update_E():
    for i in range(e1.size()[1]):
        for j in range(e1.size()[0]):
            e1[j][i] = e1[j][i]*math.exp(-delta/tauz) + P12p[j]*a[i]+P12m[i]*snn.states[0][0][j]  
            
    for i in range(e2.size()[1]):
        for j in range(e2.size()[0]):
            e2[j][i] = e2[j][i]*math.exp(-delta/tauz) + P23p[j]*snn.states[0][0][i]+P23m[i]*snn.states[1][0][j]
    return e1,e2    

In [83]:
for i in range(1000):
    #print(i)
    for d in range(len(x)):
        inp_spike = encoding(x[d])
        P12p = 0
        P12m = 0
        P23p = 0
        P23m = 0
        e1 = torch.zeros_like(list(snn.parameters())[0])
        e2 = torch.zeros_like(list(snn.parameters())[1])
        snn.reset()
        #print(list(snn.parameters()))
        for t in range(T):
            out = snn(inp_spike[...,t])
            #print(out.item())
            R = reward(out.item(),y[d].item())
            P12p = P12p*math.exp(-delta/taup)+Ap*snn.states[0][0]
            P12m = P12m*math.exp(-delta/taum)+Am*inp_spike[...,t]
            P23p = P23p*math.exp(-delta/taup)+Ap*snn.states[1][0]
            P23m = P23m*math.exp(-delta/taum)+Am*snn.states[0][0]
            e1,e2 = update_E()
        #print(e1,e2)
            R = reward(out.item(),y[d].item())
            for idx,w in enumerate(snn.parameters()):
                if idx == 0:
                    w.data = w.data + lr*R*e1
                    w.data = torch.clamp(w.data,-1,1)
                else:
                    w.data = w.data + lr*R*e2
                    w.data = torch.clamp(w.data,0,1)
        if i%100 == 0:
            print('output',out.item())
            print('label',y[d].item())
            print(list(snn.parameters()))
            print(e1,e2)
        
    #print(i)
    #print(e1,e2)
    #print(P12p,P12m,P23p,P23m)
  

output 0.0
label 0.0
[Parameter containing:
tensor([[ 0.6909,  0.0824],
        [ 0.6587, -0.2232],
        [ 0.2863,  0.0133],
        [ 0.4934, -0.0448],
        [ 0.4625,  0.6565]], requires_grad=True), Parameter containing:
tensor([[0.1993, 0.0339, 0.0000, 0.2545, 0.0000]], requires_grad=True)]
output 0.0
label 1.0
[Parameter containing:
tensor([[ 0.6909,  0.0824],
        [ 0.6587, -0.2232],
        [ 0.2836,  0.0105],
        [ 0.4908, -0.0477],
        [ 0.4011,  0.6204]], requires_grad=True), Parameter containing:
tensor([[1.9925e-01, 3.3895e-02, 7.5048e-05, 2.5459e-01, 8.4750e-04]],
       requires_grad=True)]
output 1.0
label 1.0
[Parameter containing:
tensor([[ 0.7061,  0.1210],
        [ 0.7502, -0.0724],
        [ 0.3746,  0.1685],
        [ 0.5902,  0.1285],
        [ 0.5005,  0.7966]], requires_grad=True), Parameter containing:
tensor([[0.2359, 0.0904, 0.0748, 0.3450, 0.0912]], requires_grad=True)]
output 1.0
label 0.0
[Parameter containing:
tensor([[ 0.6798,  0.0968],
 

tensor([[0.1336, 0.1338, 0.0010, 0.0202, 0.0071]], requires_grad=True)]
output 0.0
label 0.0
[Parameter containing:
tensor([[ 0.6179,  0.0292],
        [ 0.5686, -0.0325],
        [ 0.0782, -0.0053],
        [ 0.0266, -1.0000],
        [-1.0000,  0.0204]], requires_grad=True), Parameter containing:
tensor([[0.1336, 0.1338, 0.0010, 0.0202, 0.0071]], requires_grad=True)]
output 0.0
label 1.0
[Parameter containing:
tensor([[ 0.6179,  0.0292],
        [ 0.5686, -0.0325],
        [ 0.0782, -0.0053],
        [ 0.0266, -1.0000],
        [-1.0000,  0.0204]], requires_grad=True), Parameter containing:
tensor([[0.1336, 0.1338, 0.0010, 0.0202, 0.0071]], requires_grad=True)]
output 0.0
label 1.0
[Parameter containing:
tensor([[ 0.6179,  0.0292],
        [ 0.5686, -0.0325],
        [ 0.0782, -0.0053],
        [ 0.0293, -1.0000],
        [-1.0000,  0.0204]], requires_grad=True), Parameter containing:
tensor([[0.1336, 0.1338, 0.0010, 0.0202, 0.0071]], requires_grad=True)]
output 0.0
label 0.0
[Parame

## Hebbian Rule

In [46]:
def update_weight_Hebb(trace1,trace2,trace3):
    w1 = torch.zeros_like(list(snn.parameters())[0])
    for i in range(w1.size()[1]):
        for j in range(w1.size()[0]):
            w1[j][i] = alpha*trace2[j]*trace1[i]            

    w2 = torch.zeros_like(list(snn.parameters())[1])
    for i in range(w2.size()[1]):
        for j in range(w2.size()[0]):
            w2[j][i] = alpha*trace3[j]*trace2[i]
    return w1,w2 
#print(update_weight_STDP(trace1,trace2,trace3))

(tensor([[0.0006, 0.0006],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0001, 0.0001],
        [0.0000, 0.0000]], grad_fn=<CopySlices>), tensor([[2.8000e-04, 0.0000e+00, 0.0000e+00, 7.0000e-05, 0.0000e+00]],
       grad_fn=<CopySlices>))


In [59]:
n_iter = 201
decay = 0.002
alpha = 0.175

for i in range(n_iter):
    output = []
    for inp in x:
        inp_spike = encoding(inp)
        trace1 = 0
        trace2 = 0
        trace3 = 0
        snn.reset()
        for t in range(T):
            out = snn(inp_spike[...,t])
            trace1 += inp_spike[...,t]
            trace2 += snn.states[0][0]
            trace3 += snn.states[1][0]
        output.append(out.item())
        trace1 = trace1/T
        trace2 = trace2/T
        trace3 = trace3/T
        w1,w2 = update_weight_Hebb(trace1, trace2, trace3)
        for idx,w in enumerate(snn.parameters()):
            if idx == 0:
                w.data = w1 + w.data*(1-decay)
            else:
                w.data = w2 + w.data*(1-decay)
    if i%100 == 0:
        print(i)
        print(output)
            

0
[0.0, 0.0, 0.0, 0.0]
100
[0.0, 0.0, 0.0, 0.0]
200
[0.0, 0.0, 0.0, 0.0]


In [45]:
print(list(snn.parameters()))

[Parameter containing:
tensor([[ 0.4554, -0.2323],
        [-0.7028, -0.2477],
        [-0.4719, -0.0389],
        [-0.3774,  0.6311],
        [ 0.0950,  0.0724]], requires_grad=True), Parameter containing:
tensor([[ 0.3258,  0.0890,  0.3129, -0.1612, -0.2478]], requires_grad=True)]


## Test=========================

In [84]:
print(e1,e2)

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], grad_fn=<CopySlices>) tensor([[0., 0., 0., 0., 0.]], grad_fn=<CopySlices>)


In [101]:
y_hat = (count>5).float()
y_hat = y_hat.squeeze(-1)
criterion = torch.nn.MSELoss()
loss = criterion(y_hat.resize(50),y)
print(loss)

tensor(0.4800)


In [119]:
print(x[0])
snna = SpikingMLP(sizes)
out = snna(x[0])
print(snna.states[0][0][0])
print(test[0][...,1])#[0:2,0])

tensor([0., 1.])
tensor(0., grad_fn=<SelectBackward>)
tensor([0., 1.])
