In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
from mpl_toolkits.mplot3d import Axes3D

In [2]:
# Common simulation parameters
Time = 10    # total simulated time
dt = 0.0001  # forward-Euler time step
# Number of time steps. Round to the nearest integer instead of truncating:
# int() alone drops a step whenever the float quotient lands just below an
# integer (e.g. int(0.3/0.1) == 2).
Nt = int(round(Time/dt))
N = 100      # number of neurons

# Dimensionality of the problem (the Lorenz system has 3 state variables)
K = 3

# Decoding kernel
np.random.seed(12) # for reproducibility
D = np.random.randn(K,N) # K x N - column i holds the weights associated with neuron i
D = D / np.sqrt(np.diag(D.T@D)) # normalize: every column gets unit Euclidean norm
D = D / 10 # avoid too big discontinuities (column norms become 0.1)

# Firing threshold of each neuron: half the squared norm of its decoding column
T = np.diag(D.T@D)/2

In [4]:
# The dynamical system is written with a Kronecker product:
#     Y' = A Y + B (Y (x) Y),   where (x) denotes the Kronecker product
# Standard Lorenz parameters
sigma = 10
rho = 28
beta = 8/3
# - Linear part (Y' = A Y)
A = np.array([
    [-sigma, sigma, 0],
    [rho, -1, 0],
    [0, 0, -beta],
])
# - Nonlinear part: B is 3 x 9 and acts on the 9 pairwise products in Y (x) Y.
#   Only two entries are non-zero.
B = np.zeros((3, 9), dtype=int)
B[1, 2] = -1  # entry 2 of Y (x) Y is Y[0]*Y[2]; it enters the second equation with weight -1
B[2, 1] = 1   # entry 1 of Y (x) Y is Y[0]*Y[1]; it enters the third equation with weight +1

In [5]:
# Time constant for the neurons (used as the decay rate of both voltage and rate)
lam = 0.75

# State histories: one row per neuron, one column per time point (Nt steps + initial state)
state_shape = (N, Nt+1)
V = np.zeros(state_shape)  # voltages
s = np.zeros(state_shape)  # spikes
r = np.zeros(state_shape)  # rates

# Initial conditions
x0 = np.array([-11.40057002, -14.01987468,  27.49928125])
# Pseudo-inverse of the decoder: a "cheaty" way of getting rates that decode to x0
r[:,0] = np.linalg.pinv(D) @ x0
# Every neuron starts at 90% of its threshold
V[:,0] = 0.9*T

# Network connections:
# - fast
O_f = D.T @ D
# - slow
O_s = D.T @ (lam*np.eye(K) + A) @ D
# - nonlinear: acts on the Kronecker product of the rate vector with itself
O_nl = D.T @ B @ np.kron(D, D)

In [None]:
# Actual simulation: simple forward Euler over Nt steps
spike_height = 1/dt  # value written into s for a spike (loop-invariant)
for t in range(Nt):
    nxt = t + 1
    v_now = V[:, t]
    r_now = r[:, t]
    s_now = s[:, t]

    # voltage derivative: leak, fast term on the previous spikes, slow and nonlinear terms on the rates
    dV = -lam*v_now - O_f@s_now + O_s@r_now + O_nl@np.kron(r_now, r_now)
    V[:, nxt] = v_now + dt*dV

    # If any neuron crossed its threshold, only the neuron with the
    # largest voltage fires (at most one spike per time step).
    if np.any(V[:, nxt] > T):
        s[np.argmax(V[:, nxt]), nxt] = spike_height

    # update rate: driven by the new spikes, decaying at rate lam
    r[:, nxt] = r_now + dt*(s[:, nxt] - lam*r_now)

In [None]:
# "Decode" - i.e. multiply rate with decoding matrix D
xn = np.matmul(D, r)  # one row per decoded state variable, one column per time point

In [None]:
# Plot the decoded network output as a 3-D trajectory
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(1, 1, 1, projection='3d')

ax.plot(xn[0], xn[1], xn[2], color='dimgrey', alpha=0.9)

# Hide the axes and fix the visible window
ax.set_axis_off()
ax.set_xlim([-17.5,15])
ax.set_ylim([-12,15])
ax.set_zlim([15,38])