In [14]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
import ipywidgets as widgets
from scipy.optimize import fsolve
from scipy import linalg
from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt

In [4]:
# Find roots of a function from R^2 to R^2
def find_roots(f, d, xrange, yrange, n, args, tol=1e-8, merge_tol=1e-6):
    """Find the distinct roots of ``f: R^2 -> R^2`` inside a rectangle.

    ``fsolve`` is started from every point of an ``n x n`` grid spanning
    ``xrange`` x ``yrange``. A converged point is kept only if the residual
    norm ``|f(root, *args)|`` is below ``tol``.

    Parameters
    ----------
    f : callable
        ``f(x, *args)`` returning a length-2 array of residuals.
    d : int
        Dimension of the system. Unused: the search grid is always 2D.
        Kept for backward compatibility.
    xrange, yrange : pair of float
        ``(lower, upper)`` bounds of the grid of initial guesses.
    n : int
        Number of grid points per axis (``n**2`` starting guesses).
    args : tuple
        Extra positional arguments forwarded to ``f``.
    tol : float
        Maximum residual norm for a candidate to count as a root.
    merge_tol : float or None
        A candidate closer than this (Euclidean distance) to an already
        accepted root is treated as a duplicate and dropped. ``None`` keeps
        every converged candidate, including repeats.

    Returns
    -------
    list of ndarray
        Root locations, each a length-2 array, in grid-scan order.
    """
    x_min, x_max = xrange
    y_min, y_max = yrange
    xv, yv = np.meshgrid(np.linspace(x_min, x_max, n), np.linspace(y_min, y_max, n))

    roots = []
    # ravel() walks the grid in the same row-major order as the original i/j loops
    for x0, y0 in zip(xv.ravel(), yv.ravel()):
        root = fsolve(f, [x0, y0], args, maxfev=9000)
        residual = np.linalg.norm(f(root, *args))
        # `not <` (rather than `>=`) also rejects NaN residuals
        if not residual < tol:
            continue
        # Many neighbouring guesses converge to the same root; keep one copy
        if merge_tol is not None and any(np.linalg.norm(root - r) < merge_tol for r in roots):
            continue
        roots.append(root)
    return roots

# Plotting Phase portraits
def color_map(x):
    """Map flow magnitudes to colour values on a log10 scale.

    Parameters
    ----------
    x : array_like
        Flow magnitudes (non-negative).

    Returns
    -------
    ndarray
        ``log10(x)`` with ``-inf`` (from zero magnitudes) and NaN replaced by
        0, and ``+inf`` replaced by the largest finite float.
    """
    # A zero magnitude (a fixed point lying exactly on the grid) gives -inf,
    # which is replaced below, so the divide-by-zero warning is just noise.
    with np.errstate(divide='ignore'):
        colors = np.log10(x)
    return np.nan_to_num(colors, neginf=0)
    
def phase_portrait(ax, f, t, xlim, ylim, num_pts, args, norm_arrows=True, stream=True, quiver=True):
    """Draw the 2D vector field of ``f`` on ``ax``.

    The field is sampled on a ``num_pts x num_pts`` grid covering
    ``[-xlim, xlim] x [-ylim, ylim]`` and coloured by ``color_map`` of the
    flow magnitude.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    f : callable
        ``f(t, [x, y], *args)`` returning the 2D velocity. It is always
        evaluated at time 0.
    t : any
        Unused; kept for backward compatibility.
    xlim, ylim : float
        Half-widths of the sampled region.
    num_pts : int
        Grid points per axis.
    args : tuple
        Extra positional arguments forwarded to ``f``.
    norm_arrows : bool
        If True, scale each vector to unit length so the arrows and
        streamlines show direction only. Magnitude is then conveyed by colour.
    stream : bool
        Draw streamlines.
    quiver : bool
        Draw a faint arrow field.
    """
    x = np.linspace(-xlim, xlim, num_pts)
    y = np.linspace(-ylim, ylim, num_pts)
    X, Y = np.meshgrid(x, y)

    # f takes a single state vector, so evaluate it point by point.
    u = np.zeros((num_pts, num_pts))
    v = np.zeros((num_pts, num_pts))
    for i in range(num_pts):
        for j in range(num_pts):
            u[i, j], v[i, j] = f(0, [X[i, j], Y[i, j]], *args)

    flow_mag = np.sqrt(u**2 + v**2)
    if norm_arrows:
        # Leave zero vectors (fixed points on the grid) at zero instead of 0/0.
        nonzero = flow_mag > 0
        u[nonzero] /= flow_mag[nonzero]
        v[nonzero] /= flow_mag[nonzero]

    colors = color_map(flow_mag)
    if quiver:
        ax.quiver(X, Y, u, v, colors, cmap="winter", alpha=0.1)
    if stream:
        ax.streamplot(X, Y, u, v, color=colors, cmap="winter", density=1)
    ax.grid(True, which='both')

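## Neural Network Dynamics:

$$
\begin{align}
\dot{r}_{1} &= \epsilon[ -r_{1} + \tanh(m_{11}r_{1} + m_{12}r_{2} + u_{1})] \\
\dot{r}_{2} &= \epsilon[ -r_{2} + \tanh(m_{21}r_{1} + m_{22}r_{2} + u_{2})] \\
\dot{m}_{11} &= \eta r_{1}^{2} - m_{11} \\
\dot{m}_{12} &= \eta r_{1}r_{2} - m_{12} \\
\dot{m}_{21} &= \eta r_{1}r_{2} - m_{21} \\
\dot{m}_{22} &= \eta r_{2}^{2} - m_{22} \\
\end{align}
$$

In the code, $\epsilon$ is `gamma`, and each weight obeys $\dot{m}_{ij} = \eta_1 r_i r_j - \eta_2 m_{ij}$ (`eta1`, `eta2`). This has the same fixed points as the system above with $\eta = \eta_1/\eta_2$.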


In [5]:
# Right-hand side of the full 6D system, in the form solve_ivp expects
def f_6(t, z, gamma, eta1, eta2, u1, u2):
    """Time derivatives of the 6D rate + Hebbian-weight system.

    The state is ``z = (r1, r2, m11, m12, m21, m22)``. Each rate relaxes
    towards ``tanh`` of its recurrent input, scaled by ``gamma``. Each weight
    ``m_ij`` grows with ``eta1 * r_i * r_j`` and decays at rate ``eta2``.
    ``t`` is unused (the system is autonomous) but is required by
    ``solve_ivp``.

    Returns
    -------
    ndarray of shape (6,)
        Derivatives, in the same order as ``z``.
    """
    r1, r2, m11, m12, m21, m22 = z
    drive1 = m11*r1 + m12*r2 + u1
    drive2 = m21*r1 + m22*r2 + u2
    hebb_cross = eta1*r1*r2  # shared Hebbian term for both off-diagonal weights
    return np.array([
        gamma*(-r1 + np.tanh(drive1)),
        gamma*(-r2 + np.tanh(drive2)),
        eta1*r1*r1 - eta2*m11,
        hebb_cross - eta2*m12,
        hebb_cross - eta2*m21,
        eta1*r2*r2 - eta2*m22,
    ], dtype=float)

## Nullclines:
$$
\begin{align}
\dot{r}_{1}: r_{1} &= \tanh(m_{11}r_{1} + m_{12}r_{2} + u_{1}) \\
\dot{r}_{2}: r_{2}&= \tanh(m_{21}r_{1} + m_{22}r_{2} + u_{2}) \\
\dot{m}_{11}: m_{11} &= \eta r_{1}^{2} \\
\dot{m}_{12}: m_{12} &= \eta r_{1}r_{2} \\
\dot{m}_{21}: m_{21} &= \eta r_{1}r_{2} \\
\dot{m}_{22}: m_{22} &= \eta r_{2}^{2} \\
\end{align}
$$

Substituting the $m_{ij}$ nullclines into the $r_{1}$ and $r_{2}$ nullclines reduces the system to the following pair of equations:

$$
\begin{align}
\dot{r}_{1}: r_{1} &= \tanh(\eta(r_{1}^{3} + r_{1}r_{2}^{2}) + u_{1}) \\
\dot{r}_{2}: r_{2}&= \tanh(\eta(r_{1}^{2}r_{2} + r_{2}^{3}) + u_{2}) \\
\end{align}
$$

In [6]:
# r1 nullcline of the 6D system, with the m_ij nullclines substituted in (used for contour plots)
def r1_null_6(r1, r2, eta1, eta2, u1):
    """Residual of the reduced r1 nullcline; zero where the nullcline passes.

    With ``m11 = (eta1/eta2) r1**2`` and ``m12 = (eta1/eta2) r1 r2`` the r1
    fixed-point condition becomes ``r1 = tanh(eta (r1**3 + r1 r2**2) + u1)``.
    This returns the right-hand side minus ``r1``, element-wise for arrays.
    """
    eta = eta1 / eta2
    cubic = r1**3 + r1*r2**2
    return np.tanh(eta*cubic + u1) - r1

def r2_null_6(r1, r2, eta1, eta2, u2):
    """Residual of the reduced r2 nullcline; zero where the nullcline passes.

    With ``m21 = (eta1/eta2) r1 r2`` and ``m22 = (eta1/eta2) r2**2`` the r2
    fixed-point condition becomes ``r2 = tanh(eta (r1**2 r2 + r2**3) + u2)``.
    This returns the right-hand side minus ``r2``, element-wise for arrays.
    """
    eta = eta1 / eta2
    cubic = r1**2*r2 + r2**3
    return np.tanh(eta*cubic + u2) - r2

# Stacked nullcline residuals whose common zeros are the fixed points
def nullclines_6(r, eta1, eta2, u1, u2):
    """Both reduced 6D-system nullcline residuals as one vector, for ``fsolve``.

    ``r`` is the point ``(r1, r2)``. The result is zero where the r1 and r2
    nullclines (with the weights on their own nullclines) intersect.
    """
    first, second = r
    residuals = (
        r1_null_6(first, second, eta1, eta2, u1),
        r2_null_6(first, second, eta1, eta2, u2),
    )
    return np.asarray(residuals)

In [7]:
# Jacobian matrix used to indicate stability of fixed points for 6D system
def jac(r1, r2, gamma, eta1, eta2):
    """Jacobian of the 6D system ``f_6`` evaluated at a fixed point.

    The state order is ``(r1, r2, m11, m12, m21, m22)``. The formula is only
    valid AT a fixed point, because two fixed-point identities are used to
    eliminate the weights and the inputs:

    * ``tanh(m_i1 r1 + m_i2 r2 + u_i) = r_i``, so the tanh derivative
      ``sech^2`` equals ``1 - r_i**2``;
    * ``m_ij = (eta1/eta2) r_i r_j``.

    Parameters
    ----------
    r1, r2 : float
        Rates at the fixed point.
    gamma, eta1, eta2 : float
        Parameters of ``f_6``.

    Returns
    -------
    numpy.matrix of shape (6, 6)
    """
    eta = eta1 / eta2
    s1 = 1 - r1**2  # sech^2 of the r1 input at the fixed point
    s2 = 1 - r2**2  # sech^2 of the r2 input at the fixed point
    row1 = [gamma*(-1 + eta*r1**2*s1), gamma*eta*r1*r2*s1, gamma*r1*s1, gamma*r2*s1, 0, 0]
    # d(r2_dot)/d(r2) involves m22 = eta*r2**2 and sech^2 = 1 - r2**2 (not r1)
    row2 = [gamma*eta*r1*r2*s2, gamma*(-1 + eta*r2**2*s2), 0, 0, gamma*r1*s2, gamma*r2*s2]
    row3 = [2*eta1*r1, 0, -eta2, 0, 0, 0]
    row4 = [eta1*r2, eta1*r1, 0, -eta2, 0, 0]
    row5 = [eta1*r2, eta1*r1, 0, 0, -eta2, 0]
    row6 = [0, 2*eta1*r2, 0, 0, 0, -eta2]

    return np.matrix([row1, row2, row3, row4, row5, row6])

We also want to visualise the 2D $(r_1, r_2)$ dynamics with the weights $m_{ij}$ frozen at their current values, to see how this flow changes as the trajectory evolves.

In [8]:
# Right-hand side of the 2D rate system with frozen weights (used for streamlines)
def f_2(t, r, gamma, m11, m12, m21, m22, u1, u2):
    """Time derivatives of ``(r1, r2)`` with the weights ``m_ij`` held fixed.

    ``t`` is unused (the system is autonomous). Returns a length-2 float array.
    """
    r1, r2 = r
    input1 = m11*r1 + m12*r2 + u1
    input2 = m21*r1 + m22*r2 + u2
    return np.array([gamma*(-r1 + np.tanh(input1)),
                     gamma*(-r2 + np.tanh(input2))], dtype=float)

# r1 nullcline of the 2D system with fixed m_ij (used for contour plots)
def r1_null_2(r1, r2, m11, m12, u1):
    """Residual ``tanh(m11 r1 + m12 r2 + u1) - r1``; zero on the r1 nullcline."""
    net_input = m11*r1 + m12*r2 + u1
    return np.tanh(net_input) - r1

def r2_null_2(r1, r2, m21, m22, u2):
    """Residual ``tanh(m21 r1 + m22 r2 + u2) - r2``; zero on the r2 nullcline."""
    net_input = m21*r1 + m22*r2 + u2
    return np.tanh(net_input) - r2

# Stacked nullcline residuals of the frozen-weight 2D system; common zeros are its fixed points
def nullclines_2(r, m11, m12, m21, m22, u1, u2):
    """Both 2D-system nullcline residuals as one vector, for use with ``fsolve``."""
    first, second = r
    return np.asarray((
        r1_null_2(first, second, m11, m12, u1),
        r2_null_2(first, second, m21, m22, u2),
    ))

In [27]:
@widgets.interact(gamma=(0, 5, 0.001), eta1=(0, 5, 0.001), eta2=(0, 5, 0.001), m110=(-5, 5, 0.1), m120=(-5, 5, 0.1), m210=(-5, 5, 0.1), m220=(-5, 5, 0.1), 
                  u1=(-5, 5, 0.1), u2=(-5, 5, 0.1), r10=(-1, 1, 0.1), r20=(-1, 1, 0.1), t_start=(0, 10000, 10), 
                  t_end=(0, 19999, 10))
def update(gamma=0.1, eta1=0.1, eta2=0.1, m110=0.1, m120=0.1, m210=0.1, m220=0.1, u1=0.1, u2=0.1, r10=0, r20=0, t_start=0, t_end=0):
    """Interactive view of the 6D system projected onto the (r1, r2) plane.

    Integrates ``f_6`` from the given initial state over ``t in [0, 500]``
    and samples the solution at 20000 evenly spaced times. It then draws:

    * left (large) panel: the trajectory between sample indices ``t_start``
      and ``t_end``; the frozen-weight 2D flow, nullclines (yellow/red) and
      fixed points (black squares), with weights taken at sample ``t_end``;
      the 6D-system nullclines (magenta/orange) and fixed points (filled
      circle if all Jacobian eigenvalues have negative real part, hollow
      otherwise);
    * right column: time series of m11, m12, m21, m22, r1, r2 over the same
      index window.

    ``t_start`` and ``t_end`` are indices into the 20000-sample time grid,
    not times.
    """
    t_final = 500
    n_samples = 20000

    # Solve the 6D system and sample it on a uniform time grid.
    z0 = [r10, r20, m110, m120, m210, m220]
    sln = solve_ivp(f_6, [0, t_final], z0, args=(gamma, eta1, eta2, u1, u2), dense_output=True)
    t = np.linspace(0, t_final, n_samples)
    r1, r2, m11, m12, m21, m22 = sln.sol(t)
    window = slice(t_start, t_end)

    # Layout: large phase-plane panel on the left (all 6 rows x first 6
    # columns), six small time-series panels stacked in the last column.
    fig = plt.figure(figsize=(15, 15))
    gs = GridSpec(6, 7, wspace=0.3)
    ax1 = fig.add_subplot(gs[0:6, 0:6])
    ax_series = [fig.add_subplot(gs[row, 6]) for row in range(6)]

    # Nullclines of the 2D (r1, r2) system, weights frozen at sample t_end.
    delta = 0.025
    grid = np.arange(-2, 2, delta)
    p, q = np.meshgrid(grid, grid)
    ax1.contour(p, q, r1_null_2(p, q, m11[t_end], m12[t_end], u1), [0], colors=["y"])
    ax1.contour(p, q, r2_null_2(p, q, m21[t_end], m22[t_end], u2), [0], colors=["r"])

    # Fixed points of the frozen-weight 2D system (nullcline intersections).
    params_2d = (gamma, m11[t_end], m12[t_end], m21[t_end], m22[t_end], u1, u2)
    for fp in find_roots(nullclines_2, 2, [-1, 1], [-1, 1], 20, params_2d[1:]):
        ax1.plot(fp[0], fp[1], 'ks')

    # Flow of the frozen-weight 2D system.
    phase_portrait(ax1, f_2, t, 2, 2, 15, params_2d, quiver=False)

    # Nullclines of the full 6D system (weights on their own nullclines).
    ax1.contour(p, q, r1_null_6(p, q, eta1, eta2, u1), [0], colors=["m"])
    ax1.contour(p, q, r2_null_6(p, q, eta1, eta2, u2), [0], colors=["orange"])

    # Projection of the 6D trajectory onto (r1, r2).
    ax1.plot(r1[window], r2[window], 'k', linewidth=2)

    # Fixed points of the 6D system: filled if linearly stable, hollow otherwise.
    for fp in find_roots(nullclines_6, 2, [-1, 1], [-1, 1], 20, (eta1, eta2, u1, u2)):
        evals = linalg.eigvals(jac(fp[0], fp[1], gamma, eta1, eta2))
        if np.all(evals.real < 0):
            ax1.plot(fp[0], fp[1], 'ko')
        else:
            ax1.scatter(fp[0], fp[1], s=80, facecolors='none', edgecolors='k')
    ax1.set_xlim(-2, 2)
    ax1.set_ylim(-2, 2)

    # Time series over the same index window.
    series = [(m11, "m11"), (m12, "m12"), (m21, "m21"), (m22, "m22"), (r1, "r1"), (r2, "r2")]
    for ax, (values, label) in zip(ax_series, series):
        ax.plot(t[window], values[window], label=label)
        ax.set_xlim(0, t_final)
        ax.legend(loc='upper left')
    # Rate panels get a fixed y-range (rates relax towards tanh values in (-1, 1)).
    for ax in ax_series[4:]:
        ax.set_ylim(-2, 2)

interactive(children=(FloatSlider(value=0.1, description='gamma', max=5.0, step=0.001), FloatSlider(value=0.1,…