In [None]:
import numpy as np 
import matplotlib.pyplot as plt 
import random
from tqdm import tqdm
import time
class NEURON:
    """Integrate-and-fire neuron used for the cortex and SNr layers.

    Each step the candidate membrane potential is the resting potential plus
    ``SYN_GAIN`` times the weighted sum of the spike-response kernels of the
    presynaptic spike trains (plus, for cortex cells, an external peg drive).
    If it exceeds ``th`` the cell spikes: ``fire`` becomes True, ``act`` is
    returned and a refractory period of ``ref`` ms starts, during which input
    is ignored and ``E`` is returned.
    """

    SYN_GAIN = 50  # scale applied to the weighted synaptic current

    def __init__(self, dt):
        self.dt = dt  # simulation step (ms); used to count down refractoriness
        self.E = -65  # resting membrane potential (mV)
        self.th = -50  # firing threshold (mV)
        self.act = 20  # value returned on the step a spike is emitted (mV)
        self.ref = 40  # refractory period (ms)
        self.rest = 0  # remaining refractory time (ms)
        self.last = 0  # time the last refractory period ended (ms); older spikes are ignored
        self.tau = 15  # time constant of the spike-response kernel
        # Random phase in {0.0, 0.1, ..., 0.9}; cortex cells use it to choose the
        # 10% slice of each peg interval during which they receive input.
        self.phase = random.randint(0, 9) * 0.1
        self.fire = False  # True only on the step a spike is emitted

    def inputctx(self, time, peg):
        """External peg drive for cortex cells.

        Args:
            time: current time (ms).
            peg: event times, scanned as consecutive pairs (peg[i], peg[i+1]).

        Returns:
            A uniform random value in [0, 16) when ``time`` lies strictly inside
            an interval and inside this cell's phase window
            (phase .. phase + 0.1 of that interval); otherwise 0.
        """
        for start, end in zip(peg, peg[1:]):
            if start < time < end:
                span = end - start
                if start + self.phase * span < time < start + (self.phase + 0.1) * span:
                    # NOTE(review): the original comment claimed about half of these
                    # cross threshold; on its own only draws above th - E = 15 mV do
                    # (~1/16). Confirm the intended range.
                    return random.random() * 16
        return 0

    def _synaptic_drive(self, weight, spikes, time):
        """Weighted, gain-scaled sum of the kernel currents of each spike train."""
        currents = [self.current(sp, time) for sp in spikes]
        return np.dot(weight, currents) * self.SYN_GAIN

    def _update_state(self, m, time):
        """Apply refractoriness and threshold to candidate potential ``m``.

        Returns the membrane potential to record for this step and updates
        ``fire``, ``rest`` and ``last``.
        """
        if self.rest > 0:  # refractory: input is ignored
            self.rest -= self.dt
            self.fire = False
            if self.rest <= 0:
                self.last = time
            return self.E
        if m > self.th:  # spike!
            self.fire = True
            self.rest = self.ref
            return self.act
        # Clear explicitly so a stale flag can never survive a quiet step.
        self.fire = False
        return m

    def v_cortex(self, weight, spikes, time, peg):
        """Advance a cortex cell by one step and return its membrane potential (mV).

        Args:
            weight: one weight per presynaptic spike train.
            spikes: presynaptic spike trains (lists of spike times in ms).
            time: current time (ms).
            peg: peg event times passed to ``inputctx``.
        """
        # Drive is computed before the refractory check, as before, so the
        # random draw inside inputctx happens on every step.
        m = self._synaptic_drive(weight, spikes, time) + self.E + self.inputctx(time, peg)
        return self._update_state(m, time)

    def v_other(self, weight, spikes, time):
        """Advance a non-cortex (e.g. SNr) cell by one step; same as v_cortex without peg input."""
        m = self._synaptic_drive(weight, spikes, time) + self.E
        return self._update_state(m, time)

    def kernel(self, time):
        """Spike-response kernel t * exp(-t / tau); peaks at t = tau with value tau/e."""
        return time * np.exp(-time / self.tau)

    def current(self, spikes, time):
        """Sum of kernel responses at ``time`` to spikes in (last, time].

        Assumes ``spikes`` is in ascending order: scanning stops at the first
        spike later than ``time``.
        """
        total = 0
        for spike in spikes:
            if spike > time:
                break
            if spike > self.last:
                total += self.kernel(time - spike)
        return total

In [None]:
class strNEURON:
    """Integrate-and-fire neuron for the striatum layer.

    Each step the candidate membrane potential is the resting potential plus
    ``SYN_GAIN`` times the weighted sum of spike-response kernels, plus a
    sinusoidal resonance term and a random "dopamine" term. Crossing ``th``
    emits a spike followed by a refractory period of ``ref`` ms.
    """

    SYN_GAIN = 100  # scale applied to the weighted synaptic current

    def __init__(self, dt):
        self.dt = dt  # simulation step (ms); used to count down refractoriness
        self.E = -65  # resting membrane potential (mV)
        self.th = -50  # firing threshold (mV)
        self.act = 20  # value returned on the step a spike is emitted (mV)
        self.ref = 40  # refractory period (ms)
        self.rest = 0  # remaining refractory time (ms)
        self.last = 0  # time the last refractory period ended (ms); older spikes are ignored
        self.tau = 15  # time constant of the spike-response kernel
        self.rsn = random.randint(200, 400)  # preferred resonance period of the cell (ms)
        self.lph = random.randint(0, 9) * 0.1  # phase preference for input from the left
        self.rph = random.randint(0, 9) * 0.1  # phase preference for input from the right
        self.fire = False  # True only on the step a spike is emitted
        self.d = random.random() < 0.5  # True: D1 cell, False: D2 cell

    def v_str(self, weight, spikes, time):
        """Advance the cell by one step and return its membrane potential (mV).

        Args:
            weight: one weight per presynaptic spike train.
            spikes: presynaptic spike trains (lists of spike times in ms).
            time: current time (ms).
        """
        currents = [self.current(sp, time) for sp in spikes]
        m = (np.dot(weight, currents) * self.SYN_GAIN + self.E
             + self.resonant(time) + self.d_input())
        # Refractory whenever time remains. The previous `> 50` test could never
        # be true with ref = 40, so refractoriness was skipped and `fire` stayed
        # True forever after the first spike.
        if self.rest > 0:
            self.fire = False
            self.rest -= self.dt
            if self.rest <= 0:
                self.last = time
            return self.E
        if m > self.th:  # spike!
            self.fire = True
            self.rest = self.ref
            return self.act
        self.fire = False
        return m

    def resonant(self, time):
        """Sinusoidal fluctuation of the resting potential, amplitude 10 mV.

        NOTE(review): sin(pi * t / rsn) has period 2 * rsn ms, while ``rsn`` is
        described as the preferred period; confirm whether 2 * pi was intended.
        """
        return np.sin(time * np.pi / self.rsn) * 10

    def kernel(self, time):
        """Spike-response kernel t * exp(-t / tau); peaks at t = tau with value tau/e."""
        return time * np.exp(-time / self.tau)

    def current(self, spikes, time):
        """Sum of kernel responses at ``time`` to spikes in (last, time].

        Assumes ``spikes`` is in ascending order: scanning stops at the first
        spike later than ``time``.
        """
        total = 0
        for spike in spikes:
            if spike > time:
                break
            if spike > self.last:
                total += self.kernel(time - spike)
        return total

    def d_input(self):
        """Dopamine input: currently a placeholder uniform random value in [0, 15.5)."""
        return random.random() * 15.5

In [None]:
class weight:
    """Builds random, phase-dependent connection matrices between layers.

    Every matrix is indexed (presynaptic, postsynaptic). A pair is connected
    when ``random.random() > (pre_phase - post_phase) ** 2``, i.e. with
    probability ``1 - d**2`` (0 once d**2 >= 1), so cells with similar phases
    are more likely to be connected.
    """

    def __init__(self, cortex, striatum, snr):
        # Number of cells in each layer (per hemisphere).
        self.cortex = cortex
        self.striatum = striatum
        self.snr = snr

    @staticmethod
    def _linked(pre_phase, post_phase):
        """Draw whether a pre/post pair with these phases is connected."""
        return random.random() > (pre_phase - post_phase) ** 2

    def create_weight_str_snr(self, prephasearr, postphasearr, dlabel):
        """Striatum -> SNr weights, shape (striatum, snr).

        Args:
            prephasearr: striatal phases; only the first ``striatum`` entries are read.
            postphasearr: SNr phases, at least ``snr`` entries.
            dlabel: per striatal cell, truthy for D1 (weight +0.5) and falsy
                for D2 (weight -0.5, i.e. sign-inverted).
        """
        weights = np.zeros((self.striatum, self.snr))
        for i in range(self.striatum):
            sign = 0.5 if dlabel[i] else -0.5
            for j in range(self.snr):
                if self._linked(prephasearr[i], postphasearr[j]):
                    weights[i, j] = sign
        return weights

    def create_weight_ctx_str(self, prephasearr, postphasearr):
        """Cortex (right and left) -> striatum weights, shape (2*cortex, striatum).

        Args:
            prephasearr: length 2*cortex; rows [0, cortex) are right-cortex
                cells and rows [cortex, 2*cortex) are left-cortex cells.
            postphasearr: length 2*striatum; entries [0, striatum) are each
                striatal cell's phase for right-side input (rph) and entries
                [striatum, 2*striatum) its phase for left-side input (lph).

        Right-cortex rows are matched against rph and left-cortex rows against
        lph. (The previous version hard-coded 500, drew every entry twice and
        never read the lph half.)
        """
        weights = np.zeros((self.cortex * 2, self.striatum))
        for i in range(self.cortex * 2):
            # Left-cortex rows read the second (lph) half of postphasearr.
            offset = self.striatum if i >= self.cortex else 0
            for j in range(self.striatum):
                if self._linked(prephasearr[i], postphasearr[offset + j]):
                    weights[i, j] = 0.5
        return weights

    def create_weight_snr_ctx(self, prephasearr, postphasearr):
        """SNr -> cortex weights, shape (snr, cortex), each link 0.5.

        Args:
            prephasearr: SNr phases, at least ``snr`` entries.
            postphasearr: cortex phases (one hemisphere), at least ``cortex`` entries.
        """
        weights = np.zeros((self.snr, self.cortex))
        for i in range(self.snr):
            for j in range(self.cortex):
                if self._linked(prephasearr[i], postphasearr[j]):
                    weights[i, j] = 0.5
        return weights

In [None]:
if __name__ == "__main__":
    T = 1000  # total simulated time (ms)
    dt = 10   # integration step (ms)
    SEED = None  # set to an int for a reproducible run; None seeds from OS entropy
    random.seed(SEED)

    ## Number of cells per layer (each hemisphere has its own copy of every layer).
    cortex = 500
    striatum = 500
    snr = 500

    ## Build each population as its own list. Chained assignment such as
    ## `Rcortex = Lcortex = [...]` binds both names to ONE list, which made the
    ## right and left populations the very same neuron objects.
    Rcortex = [NEURON(dt) for _ in range(cortex)]
    Lcortex = [NEURON(dt) for _ in range(cortex)]
    Rstr = [strNEURON(dt) for _ in range(striatum)]
    Lstr = [strNEURON(dt) for _ in range(striatum)]
    Rsnr = [NEURON(dt) for _ in range(snr)]
    Lsnr = [NEURON(dt) for _ in range(snr)]

    ## D1 (True) / D2 (False) label of each striatal cell.
    Rdlabel = np.array([cell.d for cell in Rstr], dtype=bool)
    Ldlabel = np.array([cell.d for cell in Lstr], dtype=bool)

    ## Peg input times (ms): one event per 260 ms slot, jittered by 0-180 ms.
    ## Built separately per side so each list stays sorted, which inputctx
    ## relies on when it scans consecutive pairs.
    Rpeg = [260 * i + random.randint(0, 180) for i in range(400)]
    Lpeg = [260 * i + random.randint(0, 180) for i in range(400)]

    ## Phase arrays handed to the weight builders.
    ## ctx_phase  = [right-cortex phases, left-cortex phases]
    ## *str_phase = [rph of every cell, lph of every cell]
    ctx_phase = np.array([c.phase for c in Rcortex] + [c.phase for c in Lcortex], dtype=float)
    rstr_phase = np.array([c.rph for c in Rstr] + [c.lph for c in Rstr], dtype=float)
    lstr_phase = np.array([c.rph for c in Lstr] + [c.lph for c in Lstr], dtype=float)
    rsnr_phase = np.array([c.phase for c in Rsnr], dtype=float)
    lsnr_phase = np.array([c.phase for c in Lsnr], dtype=float)

    ## Connection matrices, each indexed (pre, post):
    ## w[0] = ctx(R+L) -> Rstr, w[1] = ctx(R+L) -> Lstr,
    ## w[2] = Rstr -> Rsnr,     w[3] = Lstr -> Lsnr,
    ## w[4] = Rsnr -> Rctx,     w[5] = Lsnr -> Lctx
    builder = weight(cortex, striatum, snr)
    w = [
        builder.create_weight_ctx_str(ctx_phase, rstr_phase),
        builder.create_weight_ctx_str(ctx_phase, lstr_phase),
        builder.create_weight_str_snr(rstr_phase, rsnr_phase, Rdlabel),
        builder.create_weight_str_snr(lstr_phase, lsnr_phase, Ldlabel),
        builder.create_weight_snr_ctx(rsnr_phase, ctx_phase[:cortex]),
        builder.create_weight_snr_ctx(lsnr_phase, ctx_phase[cortex:]),
    ]

    ## Spike trains (spike times in ms). ctx_spikes holds the right cortex in
    ## [0, cortex) and the left cortex in [cortex, 2*cortex).
    ctx_spikes = [[] for _ in range(cortex * 2)]
    rstr_spikes = [[] for _ in range(striatum)]
    lstr_spikes = [[] for _ in range(striatum)]
    rsnr_spikes = [[] for _ in range(snr)]
    lsnr_spikes = [[] for _ in range(snr)]

    ## Membrane-potential traces, one value appended per step.
    RCtxmemV = [[] for _ in range(cortex)]
    LCtxmemV = [[] for _ in range(cortex)]
    RstrmemV = [[] for _ in range(striatum)]
    LstrmemV = [[] for _ in range(striatum)]
    RsnrmemV = [[] for _ in range(snr)]
    LsnrmemV = [[] for _ in range(snr)]

    ## Simulation. Spikes are recorded at `now` (ms), the same unit as the
    ## `time` argument the neurons compare spike times against; recording the
    ## step index instead mixed steps with milliseconds.
    tau = T // dt  # number of simulation steps
    for t in tqdm(range(tau)):
        now = t * dt
        for i in range(cortex):
            RCtxmemV[i].append(Rcortex[i].v_cortex(w[4][:, i], rsnr_spikes, now, Rpeg))
            LCtxmemV[i].append(Lcortex[i].v_cortex(w[5][:, i], lsnr_spikes, now, Lpeg))
            if Rcortex[i].fire:
                ctx_spikes[i].append(now)
            if Lcortex[i].fire:
                ctx_spikes[cortex + i].append(now)

        for i in range(striatum):
            RstrmemV[i].append(Rstr[i].v_str(w[0][:, i], ctx_spikes, now))
            LstrmemV[i].append(Lstr[i].v_str(w[1][:, i], ctx_spikes, now))
            if Rstr[i].fire:
                rstr_spikes[i].append(now)
            if Lstr[i].fire:
                lstr_spikes[i].append(now)

        for i in range(snr):
            RsnrmemV[i].append(Rsnr[i].v_other(w[2][:, i], rstr_spikes, now))
            LsnrmemV[i].append(Lsnr[i].v_other(w[3][:, i], lstr_spikes, now))
            if Rsnr[i].fire:
                rsnr_spikes[i].append(now)
            if Lsnr[i].fire:
                lsnr_spikes[i].append(now)

    ## Plot every 10th cell (up to 50 per population), one figure each.
    ## Figure numbers: 0-49 R cortex, 50-99 L cortex, 100-149 R striatum,
    ## 150-199 L striatum, 200-249 R SNr, 250-299 L SNr.
    times = np.arange(tau) * dt
    populations = [
        ("R cortex", RCtxmemV),
        ("L cortex", LCtxmemV),
        ("R striatum", RstrmemV),
        ("L striatum", LstrmemV),
        ("R SNr", RsnrmemV),
        ("L SNr", LsnrmemV),
    ]
    n_plot = 50
    for layer, (label, memV) in enumerate(populations):
        for k, idx in enumerate(range(0, len(memV), 10)[:n_plot]):
            plt.figure(num=layer * n_plot + k)
            plt.plot(times, memV[idx])
            plt.title(f"{label} neuron {idx}")
            plt.xlabel("time (ms)")
            plt.ylabel("membrane potential (mV)")

    plt.show()