In [1]:
import numpy as np
import scipy.sparse as sp
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
import ipywidgets as widgets
from scipy.optimize import fsolve



## Neural Network Dynamics:

$$
\begin{align}
\dot{r}_{1} &= \gamma\left[ -r_{1} + \tanh(m_{11}r_{1} + m_{12}r_{2} + u_{1})\right] \\
\dot{r}_{2} &= \gamma\left[ -r_{2} + \tanh(m_{21}r_{1} + m_{22}r_{2} + u_{2})\right] \\
\dot{m}_{11} &= \eta_{1}r_{1}^{2} - \eta_{2}m_{11} \\
\dot{m}_{12} &= \eta_{1}r_{1}r_{2} - \eta_{2}m_{12} \\
\dot{m}_{21} &= \eta_{1}r_{1}r_{2} - \eta_{2}m_{21} \\
\dot{m}_{22} &= \eta_{1}r_{2}^{2} - \eta_{2}m_{22} \\
\end{align}
$$


In [2]:
def f_s(t, z, gamma, u1, u2, eta1, eta2):
    """Right-hand side of the full plastic network (rates and weights).

    Parameters
    ----------
    t : float
        Time. Unused because the system is autonomous; kept for ``solve_ivp``.
    z : sequence of 6 floats
        State ``(r1, r2, m11, m12, m21, m22)``.
    gamma : float
        Prefactor of the rate equations.
    u1, u2 : float
        External inputs to units 1 and 2.
    eta1, eta2 : float
        Hebbian growth rate and weight-decay rate.

    Returns
    -------
    numpy.ndarray
        Length-6 float array of time derivatives, ordered like ``z``.
    """
    r1, r2, m11, m12, m21, m22 = z
    drive1 = m11*r1 + m12*r2 + u1
    drive2 = m21*r1 + m22*r2 + u2
    # m12 and m21 share the same Hebbian term eta1*r1*r2.
    hebb_cross = eta1*r1*r2
    return np.array([
        gamma*(-r1 + np.tanh(drive1)),
        gamma*(-r2 + np.tanh(drive2)),
        eta1*r1*r1 - eta2*m11,
        hebb_cross - eta2*m12,
        hebb_cross - eta2*m21,
        eta1*r2*r2 - eta2*m22,
    ], dtype=float)

def f(t, z, gamma, m11, m12, m21, m22, u1, u2):
    """Rate dynamics of the two-unit network with the weights held constant.

    Parameters
    ----------
    t : float
        Time. Unused because the system is autonomous; kept for ODE-solver and
        ``phase_portrait`` call signatures.
    z : sequence of 2 floats
        Rates ``(r1, r2)``.
    gamma : float
        Prefactor of the rate equations.
    m11, m12, m21, m22 : float
        Fixed connection weights.
    u1, u2 : float
        External inputs.

    Returns
    -------
    numpy.ndarray
        Length-2 float array ``(dr1/dt, dr2/dt)``.
    """
    r1, r2 = z
    rate1 = gamma*(-r1 + np.tanh(m11*r1 + m12*r2 + u1))
    rate2 = gamma*(-r2 + np.tanh(m21*r1 + m22*r2 + u2))
    return np.array([rate1, rate2], dtype=float)

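## Nullclines:
The $\dot{r}$ nullclines invert $\tanh$ using $\tanh^{-1}(r) = \frac{1}{2}\ln\frac{1+r}{1-r}$, so they exist only for $|r_i| < 1$ (and require $m_{12}, m_{21} \neq 0$). Setting $\dot{m}_{ij} = 0$ gives each weight's equilibrium value; substituting those into the $\dot{r}$ nullclines eliminates the weights.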
$$
\begin{align}
\dot{r}_{1}: r_{2} &= \frac{1}{m_{12}}\left(\frac{1}{2}\ln\left(\frac{1+r_1}{1-r_1}\right) - m_{11}r_{1} - u_{1} \right) \\
\dot{r}_{2}: r_{1} &= \frac{1}{m_{21}}\left(\frac{1}{2}\ln\left(\frac{1+r_2}{1-r_2}\right) - m_{22}r_{2} - u_{2} \right) \\
\dot{m}_{11}: m_{11} &= \frac{\eta_{1}}{\eta_{2}}r_{1}^{2} \\
\dot{m}_{12}: m_{12} &= \frac{\eta_{1}}{\eta_{2}}r_{1}r_{2} \\
\dot{m}_{21}: m_{21} &= \frac{\eta_{1}}{\eta_{2}}r_{1}r_{2} \\
\dot{m}_{22}: m_{22} &= \frac{\eta_{1}}{\eta_{2}}r_{2}^{2} \\
\end{align}
$$

## Plotting Evolution of Trajectories in $r_{1}, r_{2}$ Space:

In [3]:
# For plotting r1, r2 nullclines for each set of weights
def r1_null(r1, m11, m12, u1):
    """r2 coordinate of the r1-nullcline (dr1/dt = 0) for given r1.

    Solves r1 = tanh(m11*r1 + m12*r2 + u1) for r2, writing arctanh(r1) as
    0.5*log((1 + r1)/(1 - r1)). Real-valued only for |r1| < 1.
    """
    atanh_r1 = 1/2*np.log((1+r1)/(1-r1))
    return 1/m12*(atanh_r1 - m11*r1 - u1)

def r2_null(r2, m21, m22, u2):
    """r1 coordinate of the r2-nullcline (dr2/dt = 0) for given r2.

    Solves r2 = tanh(m21*r1 + m22*r2 + u2) for r1, writing arctanh(r2) as
    0.5*log((1 + r2)/(1 - r2)). Real-valued only for |r2| < 1.
    """
    atanh_r2 = 1/2*np.log((1+r2)/(1-r2))
    return 1/m21*(atanh_r2 - m22*r2 - u2)

In [4]:
# Plotting Phase portraits
def color_map(x):
    """Map flow magnitudes to a log10 colour scale.

    Zero magnitudes (log10 -> -inf) are mapped to 0. Other non-finite values
    follow numpy.nan_to_num defaults: NaN -> 0 and +inf -> largest finite float.
    """
    return np.nan_to_num(np.log10(x), neginf=0)
    
def phase_portrait(ax, f, t, xlim, ylim, num_pts, args, norm_arrows=True, stream=True, quiver=True):
    """Draw the direction field of a planar vector field on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    f : callable
        Vector field, called as ``f(0, [x, y], *args)`` and returning two
        components. Evaluated at time 0, i.e. the field is treated as autonomous.
    t : object
        Unused; kept for backward compatibility with existing callers.
    xlim, ylim : float
        The grid spans ``[-xlim, xlim] x [-ylim, ylim]``.
    num_pts : int
        Grid points per axis.
    args : tuple
        Extra positional arguments forwarded to ``f``.
    norm_arrows : bool
        If True, scale every non-zero vector to unit length so that only
        direction is shown. Colour always encodes the raw magnitude.
    stream : bool
        If True, draw a streamplot.
    quiver : bool
        If True, draw a faint quiver (arrow) plot.
    """
    x = np.linspace(-xlim, xlim, num_pts)
    y = np.linspace(-ylim, ylim, num_pts)
    X, Y = np.meshgrid(x, y)
    u = np.zeros((num_pts, num_pts))
    v = np.zeros((num_pts, num_pts))

    # f is evaluated point by point because it is written for scalar states.
    for i in range(num_pts):
        for j in range(num_pts):
            u[i, j], v[i, j] = f(0, [X[i, j], Y[i, j]], *args)

    # Raw magnitude, captured before any normalisation, drives the colouring.
    flow_mag = np.sqrt(u**2 + v**2)

    if norm_arrows:
        # Leave zero vectors (fixed points) untouched to avoid 0/0.
        nonzero = flow_mag > 0
        u[nonzero] /= flow_mag[nonzero]
        v[nonzero] /= flow_mag[nonzero]

    if quiver:
        ax.quiver(X, Y, u, v, color_map(flow_mag), cmap="winter", alpha=0.1)

    if stream:
        ax.streamplot(X, Y, u, v, color=color_map(flow_mag), cmap="winter", density=1)

    ax.grid(True, which='both')

In [5]:
# Functions used to find fixed points with fsolve
def nullclines_s(r, gamma, m11, m12, m21, m22, u1, u2):
    """Residual whose zeros are intersections of the r1- and r2-nullclines.

    A fixed point (r1, r2) lies on both nullclines:
    ``r2 == r1_null(r1, ...)`` (r1-nullcline) and ``r1 == r2_null(r2, ...)``
    (r2-nullcline). Each residual component therefore compares a nullcline
    value with the *other* coordinate.

    Parameters
    ----------
    r : sequence of 2 floats
        Candidate point ``(r1, r2)``; needs ``|r1|, |r2| < 1``.
    gamma : float
        Unused; accepted so the same args tuple as ``f`` can be passed.
    m11, m12, m21, m22, u1, u2 : float
        Fixed weights and inputs.

    Returns
    -------
    tuple of 2 floats
        ``(r1_null(r1) - r2, r2_null(r2) - r1)``.
    """
    r1, r2 = r
    return (r1_null(r1, m11, m12, u1) - r2, r2_null(r2, m21, m22, u2) - r1)

def nullclines_o(r, gamma, m11, m12, m21, m22, u1, u2):
    """Rate-equation residual without the gamma prefactor (for fsolve).

    Its zeros are the points where both ``-r_i + tanh(h_i)`` vanish.

    Parameters
    ----------
    r : sequence of 2 floats
        Candidate point ``(r1, r2)``.
    gamma : float
        Unused; accepted so the same args tuple as ``f`` can be passed.
    m11, m12, m21, m22, u1, u2 : float
        Fixed weights and inputs.

    Returns
    -------
    numpy.ndarray
        Length-2 float residual.
    """
    r1, r2 = r
    h1 = m11*r1 + m12*r2 + u1
    h2 = m21*r1 + m22*r2 + u2
    return np.array([np.tanh(h1) - r1, np.tanh(h2) - r2], dtype=float)
    

In [6]:
# Find roots of a function from R^2 to R^2
def find_roots(f, d, xrange, yrange, n, args, tol=1e-8, dedup_tol=1e-6):
    """Find zeros of a map R^2 -> R^2 by running fsolve from a grid of seeds.

    Parameters
    ----------
    f : callable
        ``f(x, *args)`` returning a length-2 residual.
    d : int
        Unused; kept for backward compatibility with existing callers.
    xrange, yrange : (float, float)
        Bounds of the seed grid along each axis. Roots that converge outside
        this box are still returned.
    n : int
        Seeds per axis, so ``n*n`` fsolve runs in total.
    args : tuple
        Extra positional arguments forwarded to ``f``.
    tol : float
        A candidate is accepted only if ``norm(f(root, *args)) < tol``.
        NaN residuals are rejected.
    dedup_tol : float or None
        Candidates within this Euclidean distance of an already accepted root
        are dropped, so each root is reported once. ``None`` keeps every
        converged seed, including duplicates.

    Returns
    -------
    list of numpy.ndarray
        Accepted roots in the order they were first found (row-major over
        the seed grid).
    """
    x1, x2 = xrange
    y1, y2 = yrange
    xv, yv = np.meshgrid(np.linspace(x1, x2, n), np.linspace(y1, y2, n))

    roots = []
    for x0, y0 in zip(xv.ravel(), yv.ravel()):
        root = fsolve(f, [x0, y0], args, maxfev=9000)
        residual = np.linalg.norm(f(root, *args))
        # `not <` (rather than `>=`) so a NaN residual is rejected too.
        if not residual < tol:
            continue
        # Many seeds converge to the same root; keep only the first copy.
        if dedup_tol is not None and any(
                np.linalg.norm(root - known) < dedup_tol for known in roots):
            continue
        roots.append(root)
    return roots

In [7]:
# equations for r1, r2 nullclines with mij nullcline equations subbed in

def r1_null_m(r1, eta1, eta2, u1):
    """r1-nullcline with the weights at their plasticity equilibrium.

    Substituting m11 = (eta1/eta2)*r1**2 and m12 = (eta1/eta2)*r1*r2 into
    arctanh(r1) = m11*r1 + m12*r2 + u1 gives
    r2**2 = eta2/(eta1*r1) * (arctanh(r1) - u1) - r1**2.

    Returns
    -------
    numpy.ndarray
        Stacked ``[+r2, -r2]`` branches. Entries are NaN where the right-hand
        side is negative, i.e. where no real r2 exists.
    """
    r2_squared = eta2/(eta1*r1) * (1/2*np.log((1+r1)/(1-r1)) - u1) - r1**2
    upper = np.sqrt(r2_squared)
    return np.asarray([upper, -upper])

def r2_null_m(r2, eta1, eta2, u2):
    """r2-nullcline with the weights at their plasticity equilibrium.

    Substituting m21 = (eta1/eta2)*r1*r2 and m22 = (eta1/eta2)*r2**2 into
    arctanh(r2) = m21*r1 + m22*r2 + u2 gives
    r1**2 = eta2/(eta1*r2) * (arctanh(r2) - u2) - r2**2.

    Returns
    -------
    numpy.ndarray
        Stacked ``[+r1, -r1]`` branches. Entries are NaN where the right-hand
        side is negative, i.e. where no real r1 exists.
    """
    r1_squared = eta2/(eta1*r2) * (1/2*np.log((1+r2)/(1-r2)) - u2) - r2**2
    upper = np.sqrt(r1_squared)
    return np.asarray([upper, -upper])

In [9]:
@widgets.interact(eta1=(0, 5, 0.001), eta2=(0, 5, 0.001), u1=(-5, 5, 0.1), u2=(-5, 5, 0.1))
def update(eta1=0.01, eta2=0.01, u1=0.1, u2=0.1):
    """Interactively plot the r-nullclines with the weights eliminated.

    Both +/- branches of the r1-nullcline (blue, ``r1_null_m``) and of the
    r2-nullcline (green, ``r2_null_m``) are drawn for the current slider
    values. Gaps in a curve are where the square root has no real solution.
    NOTE(review): eta1 = 0 (slider minimum) divides by zero in the nullcline
    formulas, so the plot is empty there.
    """
    r_range = np.linspace(-0.99, 0.99, 100000)

    fig, ax = plt.subplots(1, 1, figsize=(15, 15))

    r1_branches = r1_null_m(r_range, eta1, eta2, u1)
    r2_branches = r2_null_m(r_range, eta1, eta2, u2)
    # Raw strings: '\d' is an invalid escape in a normal string literal.
    # Only the first branch of each pair is labelled, giving one legend entry per nullcline.
    ax.plot(r_range, r1_branches[0, :], 'b', linewidth=1, label=r'$\dot{r}_{1}=0$')
    ax.plot(r_range, r1_branches[1, :], 'b', linewidth=1)
    ax.plot(r2_branches[0, :], r_range, 'g', linewidth=1, label=r'$\dot{r}_{2}=0$')
    ax.plot(r2_branches[1, :], r_range, 'g', linewidth=1)
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_xlabel(r'$r_{1}$')
    ax.set_ylabel(r'$r_{2}$')
    ax.legend(loc='upper left')

interactive(children=(FloatSlider(value=0.01, description='eta1', max=5.0, step=0.001), FloatSlider(value=0.01…

In [8]:
@widgets.interact(gamma=(0, 5, 0.1), m110=(-5, 5, 0.1), m120=(-5, 5, 0.1), m210=(-5, 5, 0.1), m220=(-5, 5, 0.1),
                  u1=(-5, 5, 0.1), u2=(-5, 5, 0.1), r10=(-1, 1, 0.1), r20=(-1, 1, 0.1), t_start=(0, 10000, 10),
                  t_end=(0, 19999, 10))
def update(gamma=0.1, m110=0.1, m120=0.1, m210=0.1, m220=0.1, u1=0.1, u2=0.1, r10=0, r20=0, t_start=0, t_end=19999):
    """Simulate the plastic network and plot its phase plane and time courses.

    Parameters
    ----------
    gamma, u1, u2 : float
        Rate prefactor and inputs passed to ``f_s``.
    m110, m120, m210, m220, r10, r20 : float
        Initial weights and rates.
    t_start, t_end : int
        Sample *indices* (not times) into the 20000-point evaluation grid.
        The trajectory and time courses use ``t[t_start:t_end]``. The
        "m fixed" phase portrait uses the weights at sample ``t_end``.
    """
    eta1, eta2 = 0.0013, 0.001   # Hebbian rate and weight decay (not exposed as sliders)
    t_final, n_samples = 800, 20000

    # Integrate over the full evaluation window; evaluating the dense output
    # beyond t_span would extrapolate the last interpolant instead of solving.
    t = np.linspace(0, t_final, n_samples)
    z0 = [r10, r20, m110, m120, m210, m220]
    sln = solve_ivp(f_s, [0, t_final], z0, args=(gamma, u1, u2, eta1, eta2), dense_output=True)
    r1, r2, m11, m12, m21, m22 = sln.sol(t)
    window = slice(t_start, t_end)

    # --- Phase plane with the weights frozen at sample t_end ---
    fig, ax = plt.subplots(1, 1, figsize=(15, 15))
    frozen = (gamma, m11[t_end], m12[t_end], m21[t_end], m22[t_end], u1, u2)
    fixed_points = find_roots(nullclines_o, 2, [-1, 1], [-1, 1], 20, frozen)
    r_range = np.linspace(-0.99, 0.99, 10000)
    phase_portrait(ax, f, t, 2, 2, 15, frozen, quiver=False)
    ax.plot(r_range, r1_null(r_range, m11[t_end], m12[t_end], u1), 'y', linewidth=2,
            label=r'$\dot{r}_{1}=0$, m fixed')
    ax.plot(r2_null(r_range, m21[t_end], m22[t_end], u2), r_range, 'r', linewidth=2,
            label=r'$\dot{r}_{2}=0$, m fixed')
    ax.plot(r1[window], r2[window], 'k', linewidth=2)
    for fp in fixed_points:
        ax.plot(fp[0], fp[1], 'ko')

    # Nullclines with the weights eliminated via their own nullclines. Distinct
    # names keep r1/r2 holding the simulated rates for the time-course panels.
    r1_nc = r1_null_m(r_range, eta1=eta1, eta2=eta2, u1=u1)
    r2_nc = r2_null_m(r_range, eta1=eta1, eta2=eta2, u2=u2)
    ax.plot(r_range, r1_nc[0, :], color='orange', linewidth=2, label=r'$\dot{r}_{1}=0$')
    ax.plot(r_range, r1_nc[1, :], color='orange', linewidth=2)
    ax.plot(r2_nc[0, :], r_range, 'm-', linewidth=2, label=r'$\dot{r}_{2}=0$')
    ax.plot(r2_nc[1, :], r_range, 'm-', linewidth=2)
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.legend(loc='upper left')

    # --- Time courses over the selected window ---
    fig, axes = plt.subplots(1, 6, figsize=(15, 5))
    traces = [(m11, "m11", (-5, 5)), (m12, "m12", (-5, 5)), (m21, "m21", (-5, 5)),
              (m22, "m22", (-5, 5)), (r1, "r1", (-1, 1)), (r2, "r2", (-1, 1))]
    for panel, (series, label, ylim) in zip(axes, traces):
        panel.plot(t[window], series[window], label=label)
        # NOTE(review): this fixed [0, 50] window ignores t_start/t_end. The
        # weights decay at rate eta2, so most of their evolution is off-screen.
        panel.set_xlim(0, 50)
        panel.set_ylim(*ylim)
        panel.legend(loc='upper left')

interactive(children=(FloatSlider(value=0.1, description='gamma', max=5.0), FloatSlider(value=0.1, description…

## Jacobian Matrix at fixed points:

In [10]:
def Jac(r1, r2, eps, eta):
    """Jacobian of the 6-D plastic network evaluated at a fixed point.

    State ordering: ``(r1, r2, m11, m12, m21, m22)``. The entries assume the
    point is a fixed point, which allows two substitutions:
    sech^2(h_i) = 1 - r_i**2 (because tanh(h_i) = r_i) and
    m_ij = eta * r_i * r_j (with m21 driven by r1*r2, as in ``f_s``).
    NOTE(review): the weight rows have -1 on the diagonal, so eps and eta
    appear to be the time-rescaled gamma/eta2 and eta1/eta2. Confirm this.

    Parameters
    ----------
    r1, r2 : float
        Fixed-point rates, ``|r_i| < 1``.
    eps : float
        Rate-equation prefactor.
    eta : float
        Hebbian gain in the weight equations.

    Returns
    -------
    numpy.ndarray
        6x6 Jacobian matrix.
    """
    s1 = 1 - r1**2   # sech^2(h1) at the fixed point
    s2 = 1 - r2**2   # sech^2(h2) at the fixed point
    row1 = [eps*(-1+eta*r1**2*s1), eps*eta*r1*r2*s1, eps*r1*s1, eps*r2*s1, 0, 0]
    # d(dr2/dt)/dr2 depends on r2 (m22 = eta*r2**2, sech^2 = 1 - r2**2).
    row2 = [eps*eta*r1*r2*s2, eps*(-1+eta*r2**2*s2), 0, 0, eps*r1*s2, eps*r2*s2]
    row3 = [2*eta*r1, 0, -1, 0, 0, 0]
    row4 = [eta*r2, eta*r1, 0, -1, 0, 0]
    row5 = [eta*r2, eta*r1, 0, 0, -1, 0]
    row6 = [0, 2*eta*r2, 0, 0, 0, -1]

    return np.asarray([row1, row2, row3, row4, row5, row6])