In [1]:
# autoreload
%load_ext autoreload
%autoreload 2
# imports
import numpy as np
from matplotlib import pyplot as plt
from ipywidgets import widgets

import scienceplots
plt.style.use(['science','no-latex'])
# plt.style.use(['science'])


# change matplotlib default colormap
cmap = 'magma'
plt.rcParams['image.cmap'] = cmap

# add src to path
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(''), '..')))
from src.dynamics import dynamics


In [2]:

def response(W, I):
    """Compute the linear response of a 3-population network (L, E, P).

    Returns the first two entries of ``(1 - W)^{-1} dI``, written out in
    closed form via the cofactors of ``1 - W``. Cellular gains are assumed
    to be 1.

    The third column of ``W`` holds the weights from P. They are negated
    below, so the ``w_xp`` terms are magnitudes. This presumes P is
    inhibitory (``W[:, 2] <= 0``).

    Parameters
    ----------
    W : array_like, shape (3, 3)
        Connectivity; ``W[i, j]`` is the weight onto population i from
        population j (consistent with ``r = W r + dI``). Nested lists are
        accepted.
    I : sequence
        Input perturbations ``(dI_l, dI_e, dI_p)``. Only ``I[0]``, ``I[1]``
        and ``I[2]`` are read.

    Returns
    -------
    R_l, R_e
        Linear responses of the L and E populations.

    Raises
    ------
    ValueError
        If ``det(1 - W)`` is exactly zero. Near-singular matrices are NOT
        rejected and produce very large responses.
    """
    W = np.asarray(W)

    # Column 2 carries the (negative) weights from P; keep their magnitudes.
    w_ll, w_le, w_lp = W[0, 0], W[0, 1], - W[0, 2]
    w_el, w_ee, w_ep = W[1, 0], W[1, 1], - W[1, 2]
    w_pl, w_pe, w_pp = W[2, 0], W[2, 1], - W[2, 2]

    dI_l, dI_e, dI_p = I[0], I[1], I[2]

    det = np.linalg.det(np.eye(3) - W)
    # Exact float comparison: only a truly singular (1 - W) is rejected.
    if det == 0:
        raise ValueError("The matrix (I - W) is singular, cannot compute response.")
    inv_det = 1 / det

    # Row 0 of adj(1 - W): cofactors C_00, C_10, C_20.
    R_l = inv_det * (dI_l*(1 + w_pp - w_ee - w_ee*w_pp + w_ep*w_pe)
                     + dI_e*(w_le + w_le*w_pp - w_lp*w_pe)
                     + dI_p*(- w_lp - w_le*w_ep + w_lp*w_ee))

    # Row 1 of adj(1 - W): cofactors C_01, C_11, C_21.
    R_e = inv_det * (dI_l*(w_el + w_el*w_pp - w_ep*w_pl)
                     + dI_e*(1 + w_pp - w_ll - w_ll*w_pp + w_lp*w_pl)
                     + dI_p*(- w_ep - w_el*w_lp + w_ep*w_ll))

    return R_l, R_e

def response_regime_metric(R_l, R_e, nonpositive_value=0.0):
    """Classify late/early responses into regimes by their signs.

    Codes:
        0: both positive
        1: late positive, early negative or zero
        2: late negative or zero, early positive
        ``nonpositive_value``: both negative or zero. The default (0) merges
            this case with "both positive", as the original implementation
            did; pass e.g. ``np.nan`` to keep the two cases apart.
        NaN: either response is NaN.

    Parameters
    ----------
    R_l, R_e : float or array_like
        Late and early responses; must be broadcastable against each other.
        Plain Python floats are accepted.
    nonpositive_value : float, optional
        Code written where both responses are <= 0.

    Returns
    -------
    numpy.ndarray
        Float array with the broadcast shape of the inputs (0-d for scalars).
    """
    R_l = np.asarray(R_l)
    R_e = np.asarray(R_e)

    regime = np.zeros(np.broadcast(R_l, R_e).shape)
    regime[(R_l > 0) & (R_e <= 0)] = 1
    regime[(R_l <= 0) & (R_e > 0)] = 2
    regime[(R_l <= 0) & (R_e <= 0)] = nonpositive_value
    # NaN compares False everywhere above and would silently read as code 0.
    regime[np.isnan(R_l) | np.isnan(R_e)] = np.nan
    return regime

def response_regime_metric_2(R_l, R_e):
    """Return the ratio of late to early response, ``R_l / R_e``.

    Unlike ``response_regime_metric`` this is a continuous measure, not a
    regime code. The ratio is positive when both responses share a sign and
    negative when they differ. Its magnitude exceeds 1 when the late response
    is larger in magnitude than the early one.

    Parameters
    ----------
    R_l, R_e : float or array_like
        Late and early responses; must be broadcastable against each other.
        Lists and plain Python floats are accepted.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Element-wise ratio. Where ``R_e == 0``, NumPy division yields
        ``+/-inf`` (``nan`` for ``0/0``) with a RuntimeWarning instead of
        raising ``ZeroDivisionError``.
    """
    return np.true_divide(R_l, R_e)


In [3]:
def simple_exc_vs_inh(J, g, a, b, c):
    """Linear response of the toy 'excitation vs inhibition' network.

    Every population receives the same weights: ``J`` from L and from E,
    and ``-g * J`` from P. All three rows of the connectivity matrix are
    therefore identical. ``(a, b, c)`` are the input perturbations to
    (L, E, P).

    Returns the ``(R_l, R_e)`` pair produced by ``response``.
    """
    shared_row = [J, J, -g * J]
    connectivity = np.array([shared_row, shared_row, shared_row])
    drive = np.array([a, b, c])
    return response(connectivity, drive)

    

In [4]:
@widgets.interact(a=widgets.FloatSlider(1.0, min=0.0, max=2.0, step=0.1),
                  b=widgets.FloatSlider(1.0, min=0.0, max=2.0, step=0.1),
                  c=widgets.FloatSlider(0.5, min=0.0, max=2.0, step=0.1),)
def explore_response_exc_vs_inh(a, b, c):
    """Sweep (J, g) for the toy network and plot response regimes and responses.

    For every point of a 100 x 100 grid (J from 0.1 to 2.0, g from 0.1 to
    1.0), the responses R_l and R_e from ``simple_exc_vs_inh`` are computed
    for inputs (a, b, c). They are then classified with
    ``response_regime_metric``.

    Two figures are drawn: the regime map, and side-by-side heatmaps of R_l
    and R_e. Grid points where ``response`` reports a singular ``(1 - W)``
    are stored as NaN and appear as blank pixels, rather than aborting the
    whole sweep.
    """
    J = np.linspace(0.1, 2.0, 100)
    g = np.linspace(0.1, 1.0, 100)
    grid_shape = (len(J), len(g))
    R_l_mat = np.zeros(grid_shape)
    R_e_mat = np.zeros(grid_shape)
    regime_mat = np.zeros(grid_shape)

    for i, J_val in enumerate(J):
        for j, g_val in enumerate(g):
            try:
                R_l, R_e = simple_exc_vs_inh(J_val, g_val, a, b, c)
            except ValueError:
                # For this rank-1 W, det(1 - W) = 1 - J(2 - g), which is zero
                # in exact arithmetic at the grid point J = 11/15, g = 7/11.
                # Depending on rounding, response() may see exactly 0 there.
                R_l_mat[i, j] = np.nan
                R_e_mat[i, j] = np.nan
                regime_mat[i, j] = np.nan
                continue
            R_l_mat[i, j] = R_l
            R_e_mat[i, j] = R_e
            regime_mat[i, j] = response_regime_metric(R_l, R_e)

    # Shared by all maps. Matrices are transposed so J runs along x, g along y.
    extent = (J[0], J[-1], g[0], g[-1])

    fig, ax = plt.subplots(figsize=(4, 3), dpi=100)
    im = ax.imshow(regime_mat.T, origin='lower', extent=extent, aspect='auto', cmap='viridis')
    ax.set_xlabel('J')
    ax.set_ylabel('g')
    ax.set_title('Response Regimes')
    cbar = fig.colorbar(im, ax=ax, ticks=[0, 1, 2])
    # NOTE(review): code 0 also covers "R_l <= 0 and R_e <= 0", and code 1
    # (R_l > 0, R_e <= 0) is labelled 'Early recall' -- confirm these labels.
    cbar.ax.set_yticklabels(['Acquisition', 'Early recall', 'Late recall'])
    fig.tight_layout()
    # plt.show() renders via pyplot; Figure.show() can emit a "non-GUI
    # backend" warning under the inline backend.
    plt.show()

    # Responses scale as 1/det(1 - W), so pixels near the singular line can
    # be very large and dominate the colour scale.
    fig, ax = plt.subplots(1, 2, figsize=(7, 3), dpi=100)
    im0 = ax[0].imshow(R_l_mat.T, origin='lower', extent=extent, aspect='auto', cmap='magma')
    ax[0].set_title('R_l')
    ax[0].set_xlabel('J ')
    ax[0].set_ylabel('g')
    fig.colorbar(im0, ax=ax[0])

    im1 = ax[1].imshow(R_e_mat.T, origin='lower', extent=extent, aspect='auto', cmap='magma')
    ax[1].set_title('R_e (Hz)')
    ax[1].set_xlabel('J')
    ax[1].set_ylabel('g')
    fig.colorbar(im1, ax=ax[1])
    fig.tight_layout()
    # To export, call before plt.show():
    # fig.savefig('results/response_regimes_exc_vs_inh.png', dpi=500)
    plt.show()


interactive(children=(FloatSlider(value=1.0, description='a', max=2.0), FloatSlider(value=1.0, description='b'…

In [5]:

def plot_dynamics(J, g, a, b, c, save_fig_location='../results/exc_vs_inh_simple/'):
    """Simulate the toy excitation-vs-inhibition network and plot L and E rates.

    Builds the same connectivity as ``simple_exc_vs_inh`` (every row is
    ``[J, J, -g*J]``). It then calls ``dynamics(W, bg, I)`` from
    ``src.dynamics`` with ``bg`` all ones and ``I = (a, b, c)``, and plots
    the L and E traces it returns.

    Parameters
    ----------
    J, g : float
        Weights from L and E are ``J``; the weight from P is ``-g*J``.
    a, b, c : float
        Inputs to L, E and P, passed to ``dynamics`` as ``I``.
    save_fig_location : str, os.PathLike or None, optional
        Where to save the figure (at dpi=500). If it names a directory (ends
        with a path separator, or is an existing directory), the file is
        written there as ``dynamics_exc_vs_inh.png``. Missing parent
        directories are created. ``None`` skips saving.
    """
    W = np.array([[J, J, - g*J],
                  [J, J, - g*J],
                  [J, J, - g*J]])
    I = np.array([a, b, c])

    # Second argument of dynamics(); presumably a background drive -- TODO confirm.
    bg = np.array([1.0, 1.0, 1.0])
    # r_i (third population) is returned but not plotted.
    r_l, r_e, r_i, time_arr = dynamics(W, bg, I)

    fig, ax = plt.subplots(figsize=(3.5, 2), dpi=300)
    ax.plot(time_arr, r_l, label='L', color='C0')
    ax.plot(time_arr, r_e, label='E', color='C1')
    # ax.plot(time_arr, r_i, label='PV', color='C2')

    # NOTE(review): label assumes dynamics() returns time in seconds -- confirm.
    ax.set_xlabel('Time (sec)')
    ax.set_ylabel('Firing Rate (Hz)')
    ax.set_title('Population Dynamics')
    ax.legend()
    fig.tight_layout()

    if save_fig_location is not None:
        path = os.fspath(save_fig_location)
        # A bare directory has no extension: matplotlib would append '.png'
        # and write a hidden file literally named '.png' inside it.
        if path.endswith(('/', os.sep)) or os.path.isdir(path):
            path = os.path.join(path, 'dynamics_exc_vs_inh.png')
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        fig.savefig(path, dpi=500)
    plt.show()

In [6]:
@widgets.interact(J=widgets.FloatSlider(0.5, min=0, max=2.0, step=0.01),
                  g=widgets.FloatSlider(0.5, min=0.1, max=1, step=0.01), 
                  a=widgets.FloatSlider(1.0, min=0.0, max=2.0, step=0.1),
                  b=widgets.FloatSlider(1.0, min=0.0, max=2.0, step=0.1), 
                  c=widgets.FloatSlider(1.0, min=0.0, max=2.0, step=0.1))
def explore_plot_dynamics(J, g, a, b, c):
    """Slider front-end for ``plot_dynamics``.

    Each call re-plots the L/E traces for the current slider values. It also
    (re)writes the figure to
    ``../results/exc_vs_inh_simple/dynamics_exc_vs_inh.svg``.
    """
    svg_target = '../results/exc_vs_inh_simple/dynamics_exc_vs_inh.svg'
    plot_dynamics(J=J, g=g, a=a, b=b, c=c, save_fig_location=svg_target)
    

interactive(children=(FloatSlider(value=0.5, description='J', max=2.0, step=0.01), FloatSlider(value=0.5, desc…