In [1532]:
from pymonntorch import *
import torch
from matplotlib import pyplot as plt
import math
import random


In [142]:
class TimeResolution(Behavior):
    """Network-level behavior that publishes the simulation time step.

    Parameters:
        dt: time step stored on the network as ``net.dt`` (default 1).
    """

    def initialize(self, net):
        """Store ``dt`` on the network and zero its ``t_passed`` attribute."""
        step = self.parameter('dt', 1)
        net.dt = step
        net.t_passed = 0


In [1757]:
class ConstantCurrent(Behavior):
    """Input-current behavior that rewrites ``ng.I`` with a fixed level each step.

    Parameters:
        value: passed as ``mode`` to ``ng.vector`` (``None`` if omitted).
    """

    def initialize(self, ng):
        """Read the ``value`` parameter; ``ng.I`` is first set in ``forward``."""
        self.value = self.parameter('value', None)

    def forward(self, ng):
        """Replace ``ng.I`` with a new vector built from ``value``."""
        level = self.value
        ng.I = ng.vector(mode=level)

class unitStepCurrent(Behavior):
    """Baseline current with a single upward step of height ``value`` at ``tp``.

    Parameters:
        value: amount added to ``ng.I`` when the step fires.
        tp: simulation time (``iteration * dt``) at which the step fires;
            ``None`` means the step never fires.
        base: current level before the step (default 50, the level that was
            previously hard-coded).
    """

    def initialize(self, ng):
        """Read parameters and set ``ng.I`` to the baseline level."""
        self.value = self.parameter('value', None)
        self.tp = self.parameter('tp', None)
        self.base = self.parameter('base', 50)
        # Latch so the step is applied exactly once.
        self._stepped = False
        ng.I = ng.vector(mode=self.base)

    def forward(self, ng):
        """Add the step to ``ng.I`` the first time simulated time reaches ``tp``."""
        if self._stepped or self.tp is None:
            return
        now = ng.network.iteration * ng.network.dt
        # Compare with >= instead of ==: with a fractional dt the product
        # iteration * dt may never equal tp exactly, and the step would be
        # silently skipped.
        if now >= self.tp:
            ng.I += ng.vector(mode=self.value)
            self._stepped = True

class SinCurrent(Behavior):
    """Sinusoidal input current oscillating around ``value``.

    Parameters:
        value: offset of the sinusoid and the initial current level.
        amp: amplitude of the sinusoid (default 1).
        freq: factor multiplying time inside the sine argument (default 1).
    """

    def initialize(self, ng):
        """Read parameters and start ``ng.I`` at the offset level."""
        self.value = self.parameter('value', None)
        self.amp = self.parameter('amp', 1)
        self.freq = self.parameter('freq', 1)
        ng.I = ng.vector(mode=self.value)

    def _sine_argument(self, ng):
        """Return ``freq * iteration * dt / (2 * pi)`` for the current step."""
        scaled = self.freq * ng.network.iteration
        return scaled * ng.network.dt / (2 * math.pi)

    def forward(self, ng):
        """Set ``ng.I`` to ``value + amp * sin(argument)`` for every neuron."""
        level = self.value + self.amp * math.sin(self._sine_argument(ng))
        ng.I = ng.vector(mode=level)

class SquareCurrent(Behavior):
    """Two-level input current switched by the sign of a sine wave.

    The current is ``value`` while ``value + amp * sin(...)`` exceeds
    ``value`` and ``value / 2`` otherwise.

    Parameters:
        value: high level of the current (the low level is half of it).
        amp: amplitude of the underlying sine (default 1).
        freq: factor multiplying time inside the sine argument (default 1).
    """

    def initialize(self, ng):
        """Read parameters and start ``ng.I`` at the high level."""
        self.value = self.parameter('value', None)
        self.amp = self.parameter('amp', 1)
        self.freq = self.parameter('freq', 1)
        ng.I = ng.vector(mode=self.value)

    def forward(self, ng):
        """Pick the high or low level from the sine and write it to ``ng.I``."""
        argument = self.freq * ng.network.iteration * ng.network.dt / (2 * math.pi)
        wave = self.value + self.amp * math.sin(argument)
        if wave > self.value:
            level = self.value
        else:
            level = self.value / 2
        ng.I = ng.vector(mode=level)

class StepSlopeCurrent(Behavior):
    """Input current that climbs by 2 per step during the sine's upper half.

    Parameters:
        value: initial current level and offset of the gating sine.
        p: read with default 0.01; not used by ``forward``.
        amp: amplitude of the gating sine (default 1).
        freq: factor multiplying time inside the sine argument (default 1).
    """

    def initialize(self, ng):
        """Read parameters and start ``ng.I`` at ``value``."""
        self.value = self.parameter('value', None)
        self.p = self.parameter('p', 0.01)
        self.amp = self.parameter('amp', 1)
        self.freq = self.parameter('freq', 1)
        ng.I = ng.vector(mode=self.value)

    def forward(self, ng):
        """Add 2 to ``ng.I`` while the sine is above ``value``, else add 0."""
        argument = self.freq * ng.network.iteration * ng.network.dt / (2 * math.pi)
        wave = self.value + self.amp * math.sin(argument)
        if wave > self.value:
            increment = 2
        else:
            increment = 0
        ng.I += ng.vector(mode=increment)

class StepCurrent(Behavior):
    """Staircase input current: starts at 0 and rises at regular time intervals.

    Parameters:
        value, amp, freq: read for interface compatibility; not used by
            ``forward``.
        period: the current rises whenever ``iteration * dt`` is a multiple
            of this value (default 50, previously hard-coded).
        step: amount added at each rise (default 20, previously hard-coded).
    """

    def initialize(self, ng):
        """Read parameters and start ``ng.I`` at zero."""
        self.value = self.parameter('value', None)
        self.amp = self.parameter('amp', 1)
        self.freq = self.parameter('freq', 1)
        self.period = self.parameter('period', 50)
        self.step = self.parameter('step', 20)
        ng.I = ng.vector(mode=0)

    def forward(self, ng):
        """Add ``step`` to ``ng.I`` when the simulated time hits a period boundary."""
        # The unused sine computation of the original was removed; it had no
        # effect on the current.
        now = ng.network.iteration * ng.network.dt
        if now % self.period == 0:
            ng.I += ng.vector(mode=self.step)

class NoiseCurrent(Behavior):
    """Input current that takes occasional random jumps.

    Parameters:
        value: initial current level.
        limit: maximum magnitude of a single jump (default 10).
        rate: per-step probability of a jump (default 0.2).
    """

    def initialize(self, ng):
        """Read parameters and start ``ng.I`` at ``value``."""
        self.value = self.parameter('value', None)
        self.limit = self.parameter('limit', 10)
        self.rate = self.parameter('rate', 0.2)
        ng.I = ng.vector(mode=self.value)

    def forward(self, ng):
        """With probability ``rate`` add one random offset in (-limit, limit).

        The same scalar offset is added to every entry of ``ng.I``.
        """
        if not random.random() < self.rate:
            return
        direction = 2 * random.random() - 1
        ng.I += self.limit * direction


In [1470]:
def drawPlots(net, th=-15, rest=-67):
    """Plot recorded potential, input current and spikes on three stacked axes.

    Reads the recorders tagged ``'ng1_rec'`` and ``'ng1_eventrec'`` from
    ``net``. At most the first five columns of ``u`` and ``I`` are drawn.

    Args:
        net: network holding the two tagged recorders.
        th: y-position of the dashed threshold guide line.
        rest: y-position of the dashed resting-potential guide line.
    """
    fig, (ax_u, ax_i, ax_spikes) = plt.subplots(3, sharex=True, figsize=(10, 10))

    state = net['ng1_rec', 0].variables
    ax_u.plot(state['u'][:, :5])
    for level in (th, rest):
        ax_u.axhline(y=level, color='black', linestyle='--')
    ax_u.set_ylabel("u(t)")

    ax_i.plot(state['I'][:, :5])
    ax_i.set_ylabel("I(t)")

    # Column 0 goes on the x axis, column 1 on the y axis.
    events = net['ng1_eventrec', 0].variables['spike']
    ax_spikes.scatter(events[:, 0], events[:, 1])
    ax_spikes.set_xlabel("time")
    ax_spikes.set_ylabel("spikes")

In [1861]:
class LIF(Behavior):
    """Leaky integrate-and-fire neuron model with a clamping refractory period.

    Parameters:
        R: factor multiplying the input current (default 1).
        tau: membrane time constant (default 10).
        threshold: potential at or above which a neuron spikes (default -50).
        u_rest: potential the leak term pulls towards (default -67).
        u_reset: potential assigned after a spike (default -72).
        rp: refractory counter value assigned after a spike (default 0).
    """

    def initialize(self, ng):
        """Read parameters, randomize ``ng.u`` and process initial spikes."""
        self.R = self.parameter('R', 1)
        self.tau = self.parameter('tau', 10)
        self.threshold = self.parameter('threshold', -50)
        self.u_rest = self.parameter('u_rest', -67)
        self.u_reset = self.parameter('u_reset', -72)
        self.rp = self.parameter('rp', 0)

        # Spread starting potentials over 1.1x the reset-to-threshold span
        # (assuming 'uniform' yields values in [0, 1)), so some neurons may
        # start above threshold.
        span = self.threshold - self.u_reset
        ng.u = ng.vector('uniform') * span * 1.1
        ng.u += self.u_reset
        ng.rp = ng.vector(mode=0)

        self._fire_and_reset(ng)

    def _fire_and_reset(self, ng):
        """Flag neurons at/above threshold, reset them and start refractoriness."""
        ng.spike = ng.u >= self.threshold
        ng.u[ng.spike] = self.u_reset
        ng.rp[ng.spike] = self.rp

    def forward(self, ng):
        """Advance the potential by one Euler step, then fire and reset."""
        step = ng.network.dt
        leak = -(ng.u - self.u_rest)
        drive = self.R * ng.I
        ng.u += (leak + drive) / self.tau * step

        # Neurons still refractory are held at the reset potential.
        refractory = ng.rp > 0
        ng.u[refractory] = self.u_reset
        ng.rp -= step

        self._fire_and_reset(ng)

class ELIF(Behavior):
    """Exponential leaky integrate-and-fire neuron model.

    Spikes are detected against ``theta_rh``; ``threshold`` only serves as
    the default for ``theta_rh``.

    Parameters:
        R: factor multiplying the input current (default 1).
        tau: membrane time constant (default 10).
        threshold: default value for ``theta_rh`` (default -50).
        u_rest: potential the leak term pulls towards (default -67).
        u_reset: potential assigned after a spike (default -72).
        rp: refractory counter value assigned after a spike (default 0).
        theta_rh: offset inside the exponential term and spike level.
        sharpness: scale of the exponential term (default 30).
    """

    def initialize(self, ng):
        """Read parameters, randomize ``ng.u`` and process initial spikes."""
        self.R = self.parameter('R', 1)
        self.tau = self.parameter('tau', 10)
        self.threshold = self.parameter('threshold', -50)
        self.u_rest = self.parameter('u_rest', -67)
        self.u_reset = self.parameter('u_reset', -72)
        self.rp = self.parameter('rp', 0)
        self.theta_rh = self.parameter('theta_rh', self.threshold)
        self.sharpness = self.parameter('sharpness', 30)

        # Starting potentials span 1.1x the distance from u_reset to theta_rh
        # (assuming 'uniform' yields values in [0, 1)).
        span = self.theta_rh - self.u_reset
        ng.u = ng.vector('uniform') * span * 1.1
        ng.u += self.u_reset
        ng.rp = ng.vector(mode=0)

        self._fire_and_reset(ng)

    def _fire_and_reset(self, ng):
        """Flag neurons at/above ``theta_rh``, reset them, start refractoriness."""
        ng.spike = ng.u >= self.theta_rh
        ng.u[ng.spike] = self.u_reset
        ng.rp[ng.spike] = self.rp

    def forward(self, ng):
        """Advance the potential by one Euler step, then fire and reset."""
        step = ng.network.dt
        leak = -(ng.u - self.u_rest)
        drive = self.R * ng.I
        upswing = self.sharpness * torch.exp((ng.u - self.theta_rh) / self.sharpness)
        ng.u += (leak + drive + upswing) / self.tau * step

        # Neurons still refractory are held at the reset potential.
        refractory = ng.rp > 0
        ng.u[refractory] = self.u_reset
        ng.rp -= step

        self._fire_and_reset(ng)

class AELIF(Behavior):
    """Adaptive exponential leaky integrate-and-fire neuron model.

    Adds an adaptation variable ``ng.w`` to the exponential model. ``w`` is
    subtracted (scaled by ``R``) from the membrane drive and is itself driven
    by the sub-threshold potential (``a``) and by each neuron's own spikes
    (``b``).

    Parameters:
        R: factor multiplying the input current and ``w`` (default 1).
        tau_m: membrane time constant (default 10).
        threshold: potential at or above which a neuron spikes (default -50).
        u_rest: potential the leak term pulls towards (default -67).
        u_reset: potential assigned after a spike (default -72).
        rp: refractory counter value assigned after a spike (default 0).
        theta_rh: offset inside the exponential term (default ``threshold``).
        sharpness: scale of the exponential term (default 20).
        tau_w: adaptation time constant (default 10).
        a: sub-threshold adaptation coupling (default 1).
        b: spike-triggered adaptation increment (default 1).
    """

    def initialize(self, ng):
        """Read parameters, randomize ``ng.u``, set ``ng.w`` to 20, fire once."""
        self.R = self.parameter('R', 1)
        self.tau_m = self.parameter('tau_m', 10)
        self.threshold = self.parameter('threshold', -50)
        self.u_rest = self.parameter('u_rest', -67)
        self.u_reset = self.parameter('u_reset', -72)
        self.rp = self.parameter('rp', 0)
        self.theta_rh = self.parameter('theta_rh', self.threshold)
        self.sharpness = self.parameter('sharpness', 20)
        self.tau_w = self.parameter('tau_w', 10)
        self.a = self.parameter('a', 1)
        self.b = self.parameter('b', 1)

        # Starting potentials span 1.1x the reset-to-threshold distance
        # (assuming 'uniform' yields values in [0, 1)).
        ng.u = ng.vector('uniform') * (self.threshold - self.u_reset) * 1.1
        ng.u += self.u_reset
        ng.w = ng.vector(mode=20)
        ng.rp = ng.vector(mode=0)

        ng.spike = ng.u >= self.threshold
        ng.u[ng.spike] = self.u_reset
        ng.rp[ng.spike] = self.rp

    def forward(self, ng):
        """Advance ``u`` by one Euler step, fire/reset, then update ``w``."""
        # Membrane dynamics; leakage is taken from the potential before the step.
        leakage = -(ng.u - self.u_rest)
        current = self.R * ng.I
        expn = self.sharpness * torch.exp((ng.u - self.theta_rh) / self.sharpness)
        adapt = self.R * ng.w
        ng.u += (leakage + current - adapt + expn) / self.tau_m * ng.network.dt
        # Neurons still refractory are held at the reset potential.
        ng.u[ng.rp > 0] = self.u_reset
        ng.rp -= ng.network.dt

        # Fire and reset.
        ng.spike = ng.u >= self.threshold
        ng.u[ng.spike] = self.u_reset
        ng.rp[ng.spike] = self.rp

        # Adaptation update. Each neuron is incremented by its OWN spike only;
        # the previous code summed spikes over the whole group, which coupled
        # all neurons together for group sizes above 1 (identical for size 1).
        fired = ng.spike.to(ng.w.dtype)
        ng.w += (
            self.a * (-1 * leakage) - ng.w + self.b * self.tau_w * fired
        ) / self.tau_w * ng.network.dt


In [1418]:
class LIF_R(Behavior):
    """Leaky integrate-and-fire model whose refractoriness raises the threshold.

    After a spike the per-neuron threshold ``ng.th`` is lifted by 100 and is
    restored to ``threshold`` once the refractory counter ``ng.rp`` is at or
    below zero.

    Parameters:
        R: factor multiplying the input current (default 1).
        tau: membrane time constant (default 10).
        threshold: resting spike threshold (default -50).
        u_rest: potential the leak term pulls towards (default -67).
        u_reset: potential assigned after a spike (default -72).
        rp: refractory counter value assigned after a spike (default 0).
    """

    def initialize(self, ng):
        """Read parameters, randomize ``ng.u`` and process initial spikes."""
        self.R = self.parameter('R', 1)
        self.tau = self.parameter('tau', 10)
        self.threshold = self.parameter('threshold', -50)
        self.u_rest = self.parameter('u_rest', -67)
        self.u_reset = self.parameter('u_reset', -72)
        self.rp = self.parameter('rp', 0)

        # Starting potentials span 1.1x the reset-to-threshold distance
        # (assuming 'uniform' yields values in [0, 1)).
        span = self.threshold - self.u_reset
        ng.u = ng.vector('uniform') * span * 1.1
        ng.u += self.u_reset
        ng.rp = ng.vector(mode=0)
        ng.th = ng.vector(mode=self.threshold)

        self._fire_and_reset(ng)

    def _fire_and_reset(self, ng):
        """Flag neurons at/above their threshold and apply the post-spike state."""
        ng.spike = ng.u >= ng.th
        ng.u[ng.spike] = self.u_reset
        ng.rp[ng.spike] = self.rp
        ng.th[ng.spike] = self.threshold + 100

    def forward(self, ng):
        """Advance the potential by one Euler step, then fire and reset."""
        step = ng.network.dt
        leak = -(ng.u - self.u_rest)
        drive = self.R * ng.I
        ng.u += (leak + drive) / self.tau * step

        # Count down refractoriness and restore the threshold when it is over.
        ng.rp -= step
        recovered = ng.rp <= 0
        ng.th[recovered] = self.threshold

        self._fire_and_reset(ng)

class ELIF_R(Behavior):
    """Exponential integrate-and-fire model with a raised refractory threshold.

    After a spike the per-neuron threshold ``ng.th`` is lifted by 100 and is
    restored to ``threshold`` once the refractory counter ``ng.rp`` is at or
    below zero.

    Parameters:
        R: factor multiplying the input current (default 1).
        tau: membrane time constant (default 10).
        threshold: resting spike threshold (default -50).
        u_rest: potential the leak term pulls towards (default -67).
        u_reset: potential assigned after a spike (default -72).
        rp: refractory counter value assigned after a spike (default 0).
        theta_rh: offset inside the exponential term (default ``threshold``).
        sharpness: scale of the exponential term (default 20).
    """

    def initialize(self, ng):
        """Read parameters, randomize ``ng.u`` and process initial spikes."""
        self.R = self.parameter('R', 1)
        self.tau = self.parameter('tau', 10)
        self.threshold = self.parameter('threshold', -50)
        self.u_rest = self.parameter('u_rest', -67)
        self.u_reset = self.parameter('u_reset', -72)
        self.rp = self.parameter('rp', 0)
        self.theta_rh = self.parameter('theta_rh', self.threshold)
        self.sharpness = self.parameter('sharpness', 20)

        # Starting potentials span 1.1x the distance from u_reset to theta_rh
        # (assuming 'uniform' yields values in [0, 1)).
        span = self.theta_rh - self.u_reset
        ng.u = ng.vector('uniform') * span * 1.1
        ng.u += self.u_reset
        ng.th = ng.vector(mode=self.threshold)
        ng.rp = ng.vector(mode=0)

        self._fire_and_reset(ng)

    def _fire_and_reset(self, ng):
        """Flag neurons at/above their threshold and apply the post-spike state."""
        ng.spike = ng.u >= ng.th
        ng.u[ng.spike] = self.u_reset
        ng.rp[ng.spike] = self.rp
        ng.th[ng.spike] = self.threshold + 100

    def forward(self, ng):
        """Advance the potential by one Euler step, then fire and reset."""
        step = ng.network.dt
        leak = -(ng.u - self.u_rest)
        drive = self.R * ng.I
        upswing = self.sharpness * torch.exp((ng.u - self.theta_rh) / self.sharpness)
        ng.u += (leak + drive + upswing) / self.tau * step

        # Count down refractoriness and restore the threshold when it is over.
        ng.rp -= step
        recovered = ng.rp <= 0
        ng.th[recovered] = self.threshold

        self._fire_and_reset(ng)

class AELIF_R(Behavior):
    """Adaptive exponential model with a raised refractory threshold.

    After a spike the per-neuron threshold ``ng.th`` is lifted by 100 and is
    restored to ``threshold`` once the refractory counter ``ng.rp`` is at or
    below zero. The adaptation variable ``ng.w`` is subtracted (scaled by
    ``R``) from the membrane drive.

    Parameters:
        R: factor multiplying the input current and ``w`` (default 1).
        tau_m: membrane time constant (default 10).
        threshold: resting spike threshold (default -50).
        u_rest: potential the leak term pulls towards (default -67).
        u_reset: potential assigned after a spike (default -72).
        rp: refractory counter value assigned after a spike (default 0).
        theta_rh: offset inside the exponential term (default ``threshold``).
        sharpness: scale of the exponential term (default 20).
        tau_w: adaptation time constant (default 10).
        a: sub-threshold adaptation coupling (default 1).
        b: spike-triggered adaptation increment (default 1).
    """

    def initialize(self, ng):
        """Read parameters, randomize ``ng.u``, set ``ng.w`` to 15, fire once."""
        self.R = self.parameter('R', 1)
        self.tau_m = self.parameter('tau_m', 10)
        self.threshold = self.parameter('threshold', -50)
        self.u_rest = self.parameter('u_rest', -67)
        self.u_reset = self.parameter('u_reset', -72)
        self.rp = self.parameter('rp', 0)
        self.theta_rh = self.parameter('theta_rh', self.threshold)
        self.sharpness = self.parameter('sharpness', 20)
        self.tau_w = self.parameter('tau_w', 10)
        self.a = self.parameter('a', 1)
        self.b = self.parameter('b', 1)

        # Starting potentials span 1.1x the distance from u_reset to theta_rh
        # (assuming 'uniform' yields values in [0, 1)).
        ng.u = ng.vector('uniform') * (self.theta_rh - self.u_reset) * 1.1
        ng.u += self.u_reset
        ng.th = ng.vector(mode=self.threshold)
        ng.rp = ng.vector(mode=0)
        ng.w = ng.vector(mode=15)

        ng.spike = ng.u >= ng.th
        ng.u[ng.spike] = self.u_reset
        ng.th[ng.spike] = self.threshold + 100
        ng.rp[ng.spike] = self.rp

    def forward(self, ng):
        """Advance ``u`` by one Euler step, update ``w``, then fire and reset."""
        # Membrane dynamics; leakage is taken from the potential before the step.
        leakage = -(ng.u - self.u_rest)
        current = self.R * ng.I
        expn = self.sharpness * torch.exp((ng.u - self.theta_rh) / self.sharpness)
        adapt = self.R * ng.w
        ng.u += (leakage + current - adapt + expn) / self.tau_m * ng.network.dt

        # Count down refractoriness and restore the threshold when it is over.
        ng.rp -= ng.network.dt
        ng.th[ng.rp <= 0] = self.threshold

        # Adaptation update. Each neuron is incremented by its OWN spike only;
        # the previous code summed spikes over the whole group, which coupled
        # all neurons together for group sizes above 1 (identical for size 1).
        # NOTE(review): ng.spike still holds the previous step's spikes here,
        # because it is recomputed below; this ordering is kept unchanged.
        fired = ng.spike.to(ng.w.dtype)
        ng.w += (
            self.a * (-1 * leakage) - ng.w + self.b * self.tau_w * fired
        ) / self.tau_w * ng.network.dt

        # Fire and reset.
        ng.spike = ng.u >= ng.th
        ng.u[ng.spike] = self.u_reset
        ng.rp[ng.spike] = self.rp
        ng.th[ng.spike] = self.threshold + 100


In [None]:
def FI_curve(net):
    """Plot mean spike count against input current, with a linear fit.

    For steps 1..99 the function reads a spike count and the (integer
    truncated) current, averages the counts per distinct current, prints the
    sorted ``(current, mean count)`` pairs and plots them. A least-squares
    line is overlaid in red when at least two distinct currents exist.

    Args:
        net: network providing ``net['spike']`` and ``net['I', 0]``.
    """
    # Local import: the notebook's import cell does not bring in numpy.
    import numpy as np

    # NOTE(review): dt is 0, so ``i * dt`` is always 0 and the spike count is
    # the same for every step. This probably should be the network time step;
    # left unchanged because the layout of net['spike'][0] is not visible here.
    dt = 0
    it = 100

    # Group counts by current in a dict. The previous list-of-lists started
    # with a spare empty list, which shifted every bucket by one position
    # relative to the list of currents.
    counts_by_current = {}
    for i in range(1, it):
        ns = int(sum(net['spike'][0] == i * dt)[0])
        cur = int(net['I', 0][i, 0])
        counts_by_current.setdefault(cur, []).append(ns)

    fi_curve = sorted(
        ((cur, sum(counts) / len(counts)) for cur, counts in counts_by_current.items()),
        key=lambda pair: pair[0],
    )
    print(fi_curve)

    i_curve = [pair[0] for pair in fi_curve]
    f_curve = [pair[1] for pair in fi_curve]

    plt.plot(i_curve, f_curve)

    # A degree-1 fit needs at least two distinct currents to be meaningful.
    if len(i_curve) >= 2:
        slope, intercept = np.polyfit(i_curve, f_curve, 1)
        line = slope * np.array(i_curve) + intercept
        plt.plot(i_curve, line, color='red')

    plt.show()

In [None]:
# Experiment: a single LIF neuron driven by a constant current of 50,
# simulated for 200 iterations with dt = 1.
net = Network(behavior={1: TimeResolution(dt=1)})

ng1_behaviors = {
    3: ConstantCurrent(value=50),
    5: LIF(R=1, threshold=-23, u_rest=-70, u_reset=-74, tau=10),
    7: Recorder(tag='ng1_rec', variables=['u', 'I']),
    8: EventRecorder(tag='ng1_eventrec', variables=['spike']),
}
ng1 = NeuronGroup(net=net, size=1, behavior=ng1_behaviors)

net.initialize()
net.simulate_iterations(200)

# Guide lines are drawn at the model's threshold and resting potential.
drawPlots(net, -23, -70)


In [None]:
# Experiment: a single ELIF neuron driven by a current that steps up by 50
# at time 20, simulated for 200 iterations with dt = 1.
net = Network(behavior={1: TimeResolution(dt=1)})

ng1_behaviors = {
    3: unitStepCurrent(value=50, tp=20),
    5: ELIF(
        R=1,
        threshold=-30,
        u_rest=-70,
        u_reset=-74,
        tau=10,
        rp=0,
        sharpness=30,
    ),
    7: Recorder(tag='ng1_rec', variables=['u', 'I']),
    8: EventRecorder(tag='ng1_eventrec', variables=['spike']),
}
ng1 = NeuronGroup(net=net, size=1, behavior=ng1_behaviors)

net.initialize()
net.simulate_iterations(200)

# Guide lines are drawn at the model's threshold and resting potential.
drawPlots(net, -30, -70)

# Disabled: plot of a recorded 'th' trace (the recorder above only stores
# 'u' and 'I', so 'th' would have to be added to its variables first).
# plt.plot(net['ng1_rec', 0].variables['th'][:,:1])
# plt.xlabel("time")
# plt.ylabel("I")
# plt.show()

In [None]:
# Experiment: a single AELIF neuron driven by a constant current of 80,
# simulated for 200 iterations with dt = 1.
net = Network(behavior={1: TimeResolution(dt=1)})

ng1_behaviors = {
    3: ConstantCurrent(value=80),
    5: AELIF(
        R=1,
        threshold=-30,
        u_rest=-70,
        u_reset=-74,
        tau_m=20,
        tau_w=30,
        rp=0,
        sharpness=30,
        a=2,
        b=40,
    ),
    7: Recorder(tag='ng1_rec', variables=['u', 'I']),
    8: EventRecorder(tag='ng1_eventrec', variables=['spike']),
}
ng1 = NeuronGroup(net=net, size=1, behavior=ng1_behaviors)

net.initialize()
net.simulate_iterations(200)

# Guide lines are drawn at the model's threshold and resting potential.
drawPlots(net, -30, -70)

# Disabled: plot of a recorded 'th' trace (the recorder above only stores
# 'u' and 'I', so 'th' would have to be added to its variables first).
# plt.plot(net['ng1_rec', 0].variables['th'][:,:1])
# plt.xlabel("time")
# plt.ylabel("I")
# plt.show()