In [9]:
import numpy as np
import scipy.sparse as sp
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
import ipywidgets as widgets

## Neural Network Dynamics:

$$\dot{r}_{1} = \gamma[ -r_{1} + tanh(m_{11}r_{1} + m_{12}r_{2} + u_{1})]$$
$$\dot{r}_{2} = \gamma[ -r_{2} + tanh(m_{21}r_{1} + m_{22}r_{2} + u_{2})]$$

In [10]:
def f(t, z, gamma, m11, m12, m21, m22, u1, u2):
    r1, r2 = z
    z_next = np.zeros(2)
    z_next[0] = gamma*(-r1 + np.tanh(m11*r1 + m12*r2 + u1))
    z_next[1] = gamma*(-r2 + np.tanh(m21*r1 + m22*r2 + u2))
    return z_next

## Nullclines:

$$\dot{r_1}: r_{2} = \frac{1}{m_{12}}\left(\frac{1}{2}\ln\left(\frac{1+r_1}{1-r_1}\right) - m_{11}r_{1} - u_{1} \right)$$
$$\dot{r_2}: r_{1} = \frac{1}{m_{21}}\left(\frac{1}{2}\ln\left(\frac{1+r_2}{1-r_2}\right) - m_{22}r_{2} - u_{2} \right)$$

In [11]:
def r1_null(r1, m11, m12, u1):
    return 1/m12*(1/2*np.log((1+r1)/(1-r1)) - m11*r1 -u1)

def r2_null(r2, m21, m22, u2):
    return 1/m21*(1/2*np.log((1+r2)/(1-r2)) - m22*r2 -u2)

In [12]:
@widgets.interact(gamma=(0, 5, 0.1), m11=(0, 5, 0.1), m12=(0, 5, 0.1), m21=(0, 5, 0.1), m22=(0, 5, 0.1), 
                 u1=(0, 5, 0.1), u2=(0, 5, 0.1), r10=(-1, 1, 0.1), r20=(-1, 1, 0.1))
def update(gamma=0.1, m11=0.1, m12=0.1, m21=0.1, m22=0.1, u1=0.1, u2=0.1, r10=0, r20=0):
    t_span = [0, 100]
    z0 = [r10, r20]
    params = (gamma, m11, m12, m21, m22, u1, u2)
    sln = solve_ivp(f, t_span, z0, args=params, dense_output=True)
    t = np.linspace(0, 100, 10000)
    z = sln.sol(t)
    
    
    fig, ax = plt.subplots(1, 1, figsize=(11,7))
    r_range = np.linspace(-1, 1, 1000)
    ax.plot(r_range, r1_null(r_range, m11, m12, u1))
    ax.plot(r2_null(r_range, m21, m22, u2), r_range)
    ax.plot(z[0], z[1])
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)

interactive(children=(FloatSlider(value=0.1, description='gamma', max=5.0), FloatSlider(value=0.1, description…

## Plot Vector Field and Nullclines:

In [13]:
def color_map(x):
    colors = np.log10(x)
    return np.nan_to_num(colors, neginf=0) 
    
def phase_portrait(ax, f, t, xlim, ylim, num_pts, args, norm_arrows=True, stream=True, quiver=True):
    
    x = np.linspace(-xlim, xlim, num_pts)
    y = np.linspace(-ylim, ylim, num_pts)
    X, Y = np.meshgrid(x, y)
    u = np.zeros((num_pts, num_pts))
    v = np.zeros((num_pts, num_pts))
    flow_mag = np.zeros((num_pts, num_pts))

    for i in range(num_pts):
        for j in range(num_pts):
            u[i,j], v[i,j] = f(0, [X[i,j], Y[i,j]], *args)
            flow_mag[i,j] = np.sqrt(u[i,j]**2 + v[i,j]**2)
            if (u[i,j], v[i,j]) != (0,0):
                u[i,j] *= 1/flow_mag[i,j]
                v[i,j] *= 1/flow_mag[i,j]
    
    ax.quiver(X[::1, ::1], Y[::1, ::1], u[::1, ::1], v[::1, ::1], color_map(flow_mag[::1, ::1]), cmap="winter", alpha=0.1)
        
    ax.streamplot(X, Y, u, v,  color=color_map(flow_mag), cmap="winter", density=1)
    ax.grid(True, which='both')  

In [75]:
@widgets.interact(gamma=(0, 5, 0.1), m11=(-5, 5, 0.1), m12=(-5, 5, 0.1), m21=(-5, 5, 0.1), m22=(-5, 5, 0.1), 
                u1=(-5, 5, 0.1), u2=(-5, 5, 0.1), r10=(-1, 1, 0.1), r20=(-1, 1, 0.1))
def update(gamma=0.1, m11=0.1, m12=0.1, m21=0.1, m22=0.1, u1=0.1, u2=0.1, r10=0, r20=0):
    t_span = [0, 500]
    z0 = [r10, r20]
    params = (gamma, m11, m12, m21, m22, u1, u2)
    sln = solve_ivp(f, t_span, z0, args=params, dense_output=True)
    t = np.linspace(0, 500, 10000)
    z = sln.sol(t)
    
    
    fig, ax = plt.subplots(1, 1, figsize=(11,7))
    r_range = np.linspace(-1, 1, 10000)
    phase_portrait(ax, f, t, 2, 2, 15, params)
    ax.plot(r_range, r1_null(r_range, m11, m12, u1), 'y', linewidth=2, label='$\dot{r}_{1}=0$')
    ax.plot(r2_null(r_range, m21, m22, u2), r_range, 'r', linewidth=2, label='$\dot{r}_{2}=0$')
    ax.plot(z[0], z[1], 'k', linewidth=2)
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.legend()

interactive(children=(FloatSlider(value=0.1, description='gamma', max=5.0), FloatSlider(value=0.1, description…

## Find Fixed Points:

In [8]:
from scipy.optimize import fsolve

In [61]:
def nullclines_s(r, gamma, m11, m12, m21, m22, u1, u2):
    r1, r2 = r
    return (r1_null(r1, m11, m12, u1) - r1, r2_null(r2, m21, m22, u2) - r2)

def nullclines_o(r, gamma, m11, m12, m21, m22, u1, u2):
    r1, r2 = r
    r_next = np.zeros(2)
    r_next[0] = -r1 + np.tanh(m11*r1 + m12*r2 + u1)
    r_next[1] = -r2 + np.tanh(m21*r1 + m22*r2 + u2)
    return r_next
    

In [20]:
nx, ny = (10, 10)
x = np.linspace(-1, 1, nx)
y = np.linspace(-1, 1, ny)
xv, yv = np.meshgrid(x, y)
xv, yv

for i in range(3):
    for j in range(3):
        print(xv[i][j], yv[i][j])
        
import itertools  
list(itertools.product(xv, yv))

0.0 0.0
0.5 0.0
1.0 0.0
0.0 0.5
0.5 0.5
1.0 0.5
0.0 1.0
0.5 1.0
1.0 1.0


[(array([0. , 0.5, 1. ]), array([0., 0., 0.])),
 (array([0. , 0.5, 1. ]), array([0.5, 0.5, 0.5])),
 (array([0. , 0.5, 1. ]), array([1., 1., 1.])),
 (array([0. , 0.5, 1. ]), array([0., 0., 0.])),
 (array([0. , 0.5, 1. ]), array([0.5, 0.5, 0.5])),
 (array([0. , 0.5, 1. ]), array([1., 1., 1.])),
 (array([0. , 0.5, 1. ]), array([0., 0., 0.])),
 (array([0. , 0.5, 1. ]), array([0.5, 0.5, 0.5])),
 (array([0. , 0.5, 1. ]), array([1., 1., 1.]))]

In [37]:
gamma = 0.1 
m11=0.1
m12=0.1
m21=0.1
m22=0.1
u1=0.1
u2=0.1

args = (gamma, m11, m12, m21, m22, u1, u2)
fsolve(nullclines, [0, 0], args)


    

array([0.11060624, 0.11060624])

In [65]:
nx, ny = (10, 10)
x = np.linspace(-1, 1, nx)
y = np.linspace(-1, 1, ny)
xv, yv = np.meshgrid(x, y)
xv, yv

for i in range(3):
    for j in range(3):
        print(xv[i][j], yv[i][j])

def find_roots(f, n, xrange, yrange, args):
    nx, ny = (10, 10)
    x = np.linspace(-1, 1, nx)
    y = np.linspace(-1, 1, ny)
    xv, yv = np.meshgrid(x, y)
    xv, yv
    roots = np.zeros((len(guesses), n))
    for i in range(len(guesses)):
        roots[i] = fsolve(f, guesses[i], args, maxfev=5000)
    return roots
    

-1.0 -1.0
-0.7777777777777778 -1.0
-0.5555555555555556 -1.0
-1.0 -0.7777777777777778
-0.7777777777777778 -0.7777777777777778
-0.5555555555555556 -0.7777777777777778
-1.0 -0.5555555555555556
-0.7777777777777778 -0.5555555555555556
-0.5555555555555556 -0.5555555555555556


In [72]:
nx, ny = (10, 10)
x = np.linspace(-1, 1, nx)
y = np.linspace(-1, 1, ny)
xv, yv = np.meshgrid(x, y)
xv, yv

for i in range(3):
    for j in range(3):
        print(xv[i][j], yv[i][j])

def find_roots(f, d, xrange, yrange, n, args):
    x1, x2 = xrange
    y1, y2 = yrange
    x = np.linspace(x1, x2, n)
    y = np.linspace(y1, y2, n)
    xv, yv = np.meshgrid(x, y)
    
    roots = np.zeros((n*n, d))
    
    for i in range(n):
        for j in range(n):
            roots[i+j] = fsolve(f, [xv[i][j], yv[i][j]], args, maxfev=5000)
    return roots

-1.0 -1.0
-0.7777777777777778 -1.0
-0.5555555555555556 -1.0
-1.0 -0.7777777777777778
-0.7777777777777778 -0.7777777777777778
-0.5555555555555556 -0.7777777777777778
-1.0 -0.5555555555555556
-0.7777777777777778 -0.5555555555555556
-0.5555555555555556 -0.5555555555555556


In [47]:
gamma = 0.1 
m11=0.1
m12=1.3
m21=1.7
m22=0.1
u1=0.1
u2=0.1

args = (gamma, m11, m12, m21, m22, u1, u2)
find_roots(nullclines, 2, [[0, 0], [1/2, 1/2]], args)

  return 1/m12*(1/2*np.log((1+r1)/(1-r1)) - m11*r1 -u1)
  improvement from the last ten iterations.


array([[-0.26647662, -0.12583826],
       [ 0.5       ,  0.5       ]])

In [77]:
@widgets.interact(gamma=(0, 5, 0.1), m11=(0, 5, 0.1), m12=(0, 5, 0.1), m21=(0, 5, 0.1), m22=(0, 5, 0.1), 
                 u1=(0, 5, 0.1), u2=(0, 5, 0.1), r10=(-1, 1, 0.1), r20=(-1, 1, 0.1))
def update(gamma=0.1, m11=0.1, m12=0.1, m21=0.1, m22=0.1, u1=0.1, u2=0.1, r10=0, r20=0):
    t_span = [0, 100]
    z0 = [r10, r20]
    params = (gamma, m11, m12, m21, m22, u1, u2)
    sln = solve_ivp(f, t_span, z0, args=params, dense_output=True)
    t = np.linspace(0, 100, 10000)
    z = sln.sol(t)
    fixed_points = find_roots(nullclines_o, 2, [-1, 1], [-1, 1], 100, params)
    
    
    fig, ax = plt.subplots(1, 1, figsize=(11,7))
    r_range = np.linspace(-1, 1, 1000)
    ax.plot(r_range, r1_null(r_range, m11, m12, u1))
    ax.plot(r2_null(r_range, m21, m22, u2), r_range)
    ax.plot(z[0], z[1])
    for x in fixed_points:
        ax.plot(x[0], x[1], 'ko')
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)

interactive(children=(FloatSlider(value=0.1, description='gamma', max=5.0), FloatSlider(value=0.1, description…