In [1]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

In [2]:
N = 2
rcpt_types = 3
t = np.arange(0,5000.1, 0.1)
fs = np.arange(0, 101, 1)
fs = fs/1000 #convert from Hz to kHz
c = np.array([0, 25, 50, 100])
cons = len(c)

#[1.7 -1.525; 1.7 -0.5]
Jee = 1.7
Jei = 1.525
Jie = 1.7
Jii = 0.5 
i2e = 0.6

J0 = np.array([[Jee, -Jei], [Jie, -Jii]])
J0



DeviceArray([[ 1.70000005, -1.52499998],
             [ 1.70000005, -0.5  ]], dtype=float32)

In [3]:
W = J0
print('Det(W) =', '%.3f' % np.linalg.det(W))

#define nonlinearity parameters
kk = 0.04
nn = 2

if rcpt_types > 1:
    g = np.array([1, i2e, 0, 0, 0, 0])
else:
    g = np.array([1, i2e])

tauE = 15
tau_ratio = 1
tauI = tauE/tau_ratio

# tau = np.ones(N)
# tau[:2:] = tauE
# tau[1:2:] = tauI

t_scale = 1
tauNMDA = 100 * t_scale
tauAMPA = 3 * t_scale
tauGABA = 5 * t_scale
nmdaRatio = 0.1 # sets the ratio of NMDA cells to AMPA cell 

NoiseNMDAratio = 0
NoiseTau = 1 * t_scale


totalT = t[-1]
dt = np.mean(np.diff(t))
dt2 = np.sqrt(dt)

Det(W) = 1.742


In [4]:
if rcpt_types > 1:
    tauS = np.array([tauAMPA, tauNMDA, tauGABA])
    tauSvec = np.kron(tauS, np.ones(N))
    
    Wtot = np.array([[(1-nmdaRatio)*Jee, 0, 0, 0, 0, 0], [(1-nmdaRatio)* Jie, 0, 0, 0, 0, 0], [0, 0, nmdaRatio * Jee, 0, 0, 0], [0, 0, nmdaRatio * Jie, 0, 0, 0], [0, 0, 0, 0, 0, -Jei], [0, 0, 0, 0, 0, -Jii]])
    
else:
    tauSvec = tau
    Wrcpt = W
    Wtot = W
    

In [5]:
v1 = np.zeros([N*rcpt_types, cons])
r_starcons = np.zeros([N, cons])

# vv_t = np.zeros([len(t), N*rcpt_types, cons])
# v1 = np.zeros(N*rcpt_types)
# r_t = np.zeros([len(t), N, cons])
# tt_c = np.zeros([len(t), cons])

In [6]:
def rect_powerLaw(vv, kk, nn):
    fv = kk*np.maximum(np.array([np.sum(vv[::2,:], axis=0), np.sum(vv[1::2,:],axis=0)]), np.zeros([N, cons]))**nn
    return fv

def dvdt(vv):
    delta_v = np.reshape(dt/tauSvec, [6,1]) * (-vv + Wtot @ np.kron(np.ones([rcpt_types,1]), rect_powerLaw(vv, kk ,nn)) + I_total)
    return delta_v


In [7]:
I_total = np.kron( g.reshape(N*rcpt_types,1),  c.reshape(1,cons))
# I_total = np.kron(c, g)
# I_total = np.transpose(np.reshape(I_total, [4, 6]))
I_total

DeviceArray([[  0.      ,  25.      ,  50.      , 100.      ],
             [  0.      ,  15.00000095,  30.00000191,  60.00000381],
             [  0.      ,   0.      ,   0.      ,   0.      ],
             [  0.      ,   0.      ,   0.      ,   0.      ],
             [  0.      ,   0.      ,   0.      ,   0.      ],
             [  0.      ,   0.      ,   0.      ,   0.      ]],
            dtype=float32)

In [10]:
Conv = True
indt = 0

for tt in t:

    dv = dvdt(v1)
    v1 = dv + v1
#     vv_t[:,:, tt] = v1
    indt += 1
    
    if tt >= totalT - 1000*dt:
        itr = np.max(np.abs(dv))
        
        if itr > 0.01:
            Conv = False

In [22]:
r_starcons = rect_powerLaw(v1, kk, nn)
rs = nn*kk**(1/nn)*r_starcons**(1-1/nn)
v1

DeviceArray([[  0.       ,  36.47018433,  63.64903641, 111.50494385],
             [  0.       ,  26.4701767,  43.64904022,  71.50494385],
             [  0.       ,   1.27440214,   1.51650786,   1.27829814],
             [  0.       ,   1.27440214,   1.51650786,   1.27829814],
             [  0.       , -24.05441475, -50.23154068, -99.07228851],
             [  0.       ,  -7.88669729, -16.46936417, -32.48272324]],
            dtype=float32)

In [123]:
Phi = lambda rr: np.diag(rr)
eE = np.array([[1], [0]])
eE = np.kron(np.ones([rcpt_types,1]), eE)
J = np.array([[Wtot @ np.kron(np.ones([rcpt_types, rcpt_types]), Phi(rs[:,cc])) -np.eye(N*rcpt_types)] for cc in range(cons)])
Gf = np.array([-1j * 2 * np.pi * ff * np.diag(np.kron(tauS, np.ones(N))) - J[cc,1] for cc in range(cons) for ff in fs])

cuE = np.array([eE for cc in range(cons) for ff in fs])

In [184]:
iGf = np.linalg.inv(Gf)

x = np.einsum("ijk, ikm-> ijm", iGf, cuE)

y = (1-NoiseNMDAratio) * x[:, :N] + NoiseNMDAratio * x[:, N:(N+N)]
spect = np.einsum('ijk, ijm -> ikm ', np.conj(y), y)

In [186]:
spect = spect/np.mean(spect)

In [187]:
np.mean(spect)

DeviceArray(1.-6.1831726e-18j, dtype=complex64)