Trying out a coupled neuron–astrocyte network with plastic synaptic weights:

$$
\begin{eqnarray*}
    \dot{x} & = & W x \\
    \dot{a} & = & r_U \left( \frac{1}{N_a} \mathbf{1}_{N_a} - I_{N_a} \right)a + V x \\
    \dot{W} & = & \alpha \left( I_{N_x} - x x^T \right) + \beta \sum_k S^{(k)} a_k
\end{eqnarray*}
$$

where

$$ V_{ij} = -\ell_{ij} r_V $$

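Here $\ell_{ij}$ denotes the number of times neuron $j$ is presynaptic for a connection owned by astrocyte $i$, and $S^{(k)}_{ij} = 1$ if synapse $j \to i$ is owned by astrocyte $k$ (and $0$ otherwise). $\mathbf{1}_{N_a}$ is the $N_a \times N_a$ matrix of all ones and $I_N$ is the $N \times N$ identity.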

In [1]:
import numpy as np
import scipy.integrate

import matplotlib.pyplot as plt

In [27]:
# Population sizes: Nx neurons and Na astrocytes.
Nx, Na = 20, 20

# Astrocyte rate constants: ru scales the astrocyte coupling matrix U,
# rv is the magnitude of each neuron-to-astrocyte entry accumulated in V.
ru, rv = 1e-3, 0.25

# Gains of the weight dynamics: alpha multiplies (I - x x^T),
# beta multiplies the astrocyte-gated sum over the ownership masks.
alpha, beta = 0.1, 1e-2

In [28]:
# Astrocyte-astrocyte coupling U = ru * (ones / Na - I): every entry of the
# all-ones matrix is scaled by 1 / Na and the identity is subtracted.
U = ru * ( np.full( (Na, Na), 1 / Na ) - np.identity( Na ) )

In [29]:
# All ordered pairs (i, j) of distinct neuron indices; self-connections are
# excluded. Downstream cells index weights as [i, j] and treat j as the
# presynaptic neuron.
synapses = [ (i, j) for i in range( Nx ) for j in range( Nx ) if i != j ]

In [30]:
# Owning astrocyte of each synapse: one random integer in [0, Na) per synapse.
synapse_astrocytes = np.random.randint( 0, Na, size = len( synapses ) )

In [31]:
# Invert the synapse -> astrocyte assignment: astrocyte_synapses[k] lists the
# (i, j) synapses owned by astrocyte k, in the order they appear in synapses.
astrocyte_synapses = [ [] for _ in range( Na ) ]
for owner, synapse in zip( synapse_astrocytes, synapses ):
    astrocyte_synapses[owner].append( synapse )

In [32]:
# astrocyte_presynaptic[k] holds the presynaptic neuron (second index) of every
# synapse owned by astrocyte k.
# NOTE: repeats are kept on purpose (no set() de-duplication), so a neuron that
# is presynaptic for several synapses of the same astrocyte appears once per
# synapse; the V matrix built from these lists counts that multiplicity.
astrocyte_presynaptic = [ [] for _ in range( Na ) ]
for ( _, pre ), owner in zip( synapses, synapse_astrocytes ):
    astrocyte_presynaptic[owner].append( pre )

In [33]:
# Wk[k] is the Nx-by-Nx ownership mask of astrocyte k: entry [i, j] is 1 when
# synapse (i, j) belongs to astrocyte k and 0 everywhere else.
Wk = []
for owned in astrocyte_synapses:
    mask = np.zeros( (Nx, Nx) )
    for post, pre in owned:
        mask[post, pre] = 1.
    Wk.append( mask )

In [34]:
# Neuron -> astrocyte coupling: V[k, j] = -rv * (number of synapses owned by
# astrocyte k whose presynaptic neuron is j). Built by subtracting rv once per
# occurrence of j in astrocyte_presynaptic[k].
V = np.zeros( (Na, Nx) )

for owner, presynaptic in enumerate( astrocyte_presynaptic ):
    for pre in presynaptic:
        V[owner, pre] -= rv

In [35]:
def deriv( t, y ):
    """Time derivative of the packed state vector.

    The flat vector ``y`` stores, in order: the Nx neuron activities ``x``,
    the Na astrocyte activities ``a``, and the Nx*Nx weight matrix ``W``
    flattened row by row. The return value uses the same layout.

    ``t`` is accepted to match the ``(t, y)`` call signature but is not used;
    the right-hand side has no explicit time dependence. Reads the
    module-level Nx, Na, U, V, Wk, alpha and beta.
    """
    n_units = Nx + Na

    # Unpack the three sub-states from the flat vector.
    x, a = y[:Nx], y[Nx:n_units]
    W = y[n_units:].reshape( Nx, Nx )

    # Astrocyte-gated weight drive: each ownership mask scaled by the
    # activity of the astrocyte it belongs to, summed over astrocytes.
    gated = np.zeros( W.shape )
    for mask, level in zip( Wk, a ):
        gated += mask * level

    W_dot = alpha * ( np.identity( Nx ) - np.outer( x, x ) ) + beta * gated

    # Re-pack as [x_dot, a_dot, W_dot (row-major)].
    return np.concatenate( (W @ x, U @ a + V @ x, W_dot.ravel()) )

In [36]:
# Integration window [start, stop] and a uniform sampling grid over it with a
# step of 0.05 (np.arange excludes the stop value).
t_span = [0, 3e3]
t_eval = np.arange( start = t_span[0], stop = t_span[-1], step = 5e-2 )

In [37]:
# Random standard-normal initial conditions for neurons, astrocytes and weights.
x0 = np.random.standard_normal( Nx )
a0 = np.random.standard_normal( Na )
W0 = np.random.standard_normal( (Nx, Nx) )

# Pack into one flat state vector: [x0, a0, W0 flattened row by row].
y0 = np.concatenate( (x0, a0, W0.ravel()) )

In [None]:
# Integrate the packed system over t_span starting from y0.
# t_eval is passed so the stored solution lies on the uniform sampling grid;
# without it solve_ivp returns only the solver's own adaptive time steps, and
# the correlation statistics and zoomed plots below would be computed on
# unevenly spaced samples.
sol = scipy.integrate.solve_ivp( deriv, t_span, y0, t_eval = t_eval )

In [None]:
# Sample times and the state trajectory (one column per time sample).
t_star, y_star = sol.t, sol.y

# Split the trajectory back into its three sub-states, using the same layout
# as the packed state vector: neurons, astrocytes, then row-major weights.
x_star, a_star = y_star[:Nx], y_star[Nx:(Nx+Na)]
W_star = y_star[(Nx+Na):].reshape( (Nx, Nx, y_star.shape[1]) )

In [None]:
# Neuron activities over the whole run, one line per neuron.
plt.figure( figsize = (12, 5) )
plt.plot( t_star, np.transpose( x_star ) )
plt.show()

In [None]:
# Heat map of neuron activity (one row per neuron, time along the horizontal
# axis), with the colour scale limited to [-8, 8].
plt.figure( figsize = (24, 5) )
heatmap = plt.imshow(
    x_star,
    aspect = 'auto',
    cmap = 'Spectral_r',
    extent = (t_star[0], t_star[-1], 0, x_star.shape[0]),
)
heatmap.set_clim( -8, 8 )

In [None]:
# Zoom on the first neuron's trace in the time window 2500-2550.
plt.figure( figsize = (48, 5) )
plt.plot( t_star, x_star[0], 'k-' )
plt.gca().set_xlim( 2500, 2550 )
plt.show()

In [None]:
# Astrocyte activities over the whole run, one line per astrocyte.
plt.figure( figsize = (12, 5) )
plt.plot( t_star, np.transpose( a_star ) )
plt.show()

In [None]:
# Pairwise correlation between astrocyte traces (each row of a_star is one
# astrocyte, so rows are the variables).
c = np.corrcoef( a_star, rowvar = True )

In [None]:
# Show the astrocyte correlation matrix on the full [-1, 1] colour range.
image = plt.imshow( c )
image.set_clim( -1, 1 )

In [None]:
# Pairwise correlation between neuron traces (each row of x_star is one
# neuron, so rows are the variables).
cx = np.corrcoef( x_star, rowvar = True )

In [None]:
# Show the neuron correlation matrix on the full [-1, 1] colour range.
image = plt.imshow( cx )
image.set_clim( -1, 1 )

In [None]:
# Collect each distinct neuron pair once: the entries of cx strictly above the
# diagonal, in row-major order (the diagonal self-correlations are skipped).
corrs = cx[np.triu_indices_from( cx, k = 1 )]

In [None]:
# Median magnitude of the pairwise neuron correlations.
np.median( np.absolute( corrs ) )