In [1]:
# autoreload: re-import edited local modules (e.g. src.dynamics) before each cell runs
%load_ext autoreload
%autoreload 2

# imports
import numpy as np
from matplotlib import pyplot as plt
from ipywidgets import widgets

import scienceplots  # imported for its side effect of making the 'science' styles available to plt.style
plt.style.use(['science','no-latex'])
# plt.style.use(['science'])  # LaTeX-rendered variant; presumably needs a LaTeX install (hence 'no-latex' above)


# change matplotlib colormap (set after plt.style.use so it overrides any style default)
cmap = 'magma'
plt.rcParams['image.cmap'] = cmap

# make the repository root importable so `src` can be imported:
# os.path.dirname('') is '', so this appends the absolute path of the PARENT of the
# current working directory (assumes the notebook is run from a subfolder of the repo root)
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(''), '..')))
# NOTE(review): `dynamics` is not used in the cells shown here
from src.dynamics import dynamics



In [2]:

def response(W, I, atol=1e-12):
    """Compute the linear steady-state response of a 3-population network.

    Solves (Id - W) R = I for the first two populations in closed form
    (adjugate / Cramer's rule), assuming all cellular gains are 1.
    Indices 0, 1, 2 correspond to the `_l`, `_e`, `_p` suffixes below
    (late, early and the inhibitory population, judging by
    `early_vs_late_inhibition`).

    Column 2 of W holds the inhibitory weights. They are negated here, so
    `w_lp`, `w_ep` and `w_pp` are inhibitory magnitudes (positive when
    W[:, 2] <= 0).

    Parameters
    ----------
    W : array_like, shape (3, 3)
        Connectivity matrix. W[i, j] is the weight from population j onto i.
    I : sequence of length 3
        Input perturbation (dI_l, dI_e, dI_p).
    atol : float, optional
        Id - W is treated as singular when |det(Id - W)| <= atol.
        Determinants of singular matrices computed in floating point are
        rarely exactly 0, so an exact `== 0` test would not catch them.
        Pass 0 to recover the exact-zero check.

    Returns
    -------
    R_l, R_e : tuple of float
        Responses of populations 0 and 1.

    Raises
    ------
    ValueError
        If Id - W is (numerically) singular.
    """
    W = np.asarray(W)

    w_ll, w_le, w_lp = W[0, 0], W[0, 1], -W[0, 2]
    w_el, w_ee, w_ep = W[1, 0], W[1, 1], -W[1, 2]
    w_pl, w_pe, w_pp = W[2, 0], W[2, 1], -W[2, 2]

    dI_l, dI_e, dI_p = I[0], I[1], I[2]

    det = np.linalg.det(np.eye(3) - W)
    if abs(det) <= atol:
        raise ValueError("The matrix (I - W) is singular, cannot compute response.")

    # Row 0 of adj(Id - W) dotted with the input: cofactors C11, C21, C31.
    R_l = (dI_l * (1 + w_pp - w_ee - w_ee * w_pp + w_ep * w_pe)
           + dI_e * (w_le + w_le * w_pp - w_lp * w_pe)
           + dI_p * (-w_lp - w_le * w_ep + w_lp * w_ee)) / det

    # Row 1 of adj(Id - W) dotted with the input: cofactors C12, C22, C32.
    R_e = (dI_l * (w_el + w_el * w_pp - w_ep * w_pl)
           + dI_e * (1 + w_pp - w_ll - w_ll * w_pp + w_lp * w_pl)
           + dI_p * (-w_ep - w_el * w_lp + w_ep * w_ll)) / det

    return R_l, R_e

def response_regime_metric(R_l, R_e):
    """Classify each (R_l, R_e) pair into a discrete response regime.

    Regimes:
        0: both positive
        1: late positive, early negative or zero
        2: late negative or zero, early positive
        3: both negative or zero

    Every finite pair falls into exactly one regime. If either value is NaN,
    the result is NaN (no regime) rather than a regime code.

    Parameters
    ----------
    R_l, R_e : array_like
        Late and early responses. The two inputs must be broadcast-compatible.

    Returns
    -------
    numpy.ndarray of float
        Regime codes, with the broadcast shape of the inputs.
    """
    R_l = np.asarray(R_l)
    R_e = np.asarray(R_e)

    late_pos, early_pos = R_l > 0, R_e > 0
    # Use explicit `<= 0` rather than `~(x > 0)` so NaN matches no mask and stays NaN.
    late_nonpos, early_nonpos = R_l <= 0, R_e <= 0

    regime = np.full(np.broadcast(R_l, R_e).shape, np.nan)
    regime[late_pos & early_pos] = 0
    regime[late_pos & early_nonpos] = 1
    regime[late_nonpos & early_pos] = 2
    # `<=` on both sides: a zero paired with a non-positive value used to fall
    # through every mask and be mislabelled as regime 0 ("both positive").
    regime[late_nonpos & early_nonpos] = 3
    return regime

def response_regime_metric_2(R_l, R_e):
    """Continuous response-regime metric: the elementwise ratio R_l / R_e.

    The ratio is positive when R_l and R_e have the same sign and negative
    when their signs differ. Its magnitude gives the relative size of the
    late and early responses.

    Where R_e == 0, the result follows IEEE division: +/-inf if R_l != 0 and
    NaN if R_l == 0. NumPy's divide-by-zero and invalid-value warnings are
    suppressed for these cases, and scalar inputs no longer raise
    ZeroDivisionError.

    Parameters
    ----------
    R_l, R_e : array_like or float
        Late and early responses. The two inputs must be broadcast-compatible.

    Returns
    -------
    numpy.ndarray or numpy.float64
        R_l / R_e.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.true_divide(R_l, R_e)

In [3]:
def early_vs_late_inhibition(J, w1, w2, alpha, beta, g, a, b, c):
    """Return (R_l, R_e) for one parameter set of the late/early/inhibitory network.

    All weights scale with J. Row i lists the inputs to population i:
      - late  (row 0): self w1*J, from early J, from inhibitory -alpha*J
      - early (row 1): from late J, self w2*J,  from inhibitory -beta*J
      - inh.  (row 2): from late J, from early J, self -g*J
    External inputs to (late, early, inhibitory) are (a, b, c).
    """
    inputs = np.array([a, b, c])
    connectivity = np.array([
        [w1 * J, J, -alpha * J],
        [J, w2 * J, -beta * J],
        [J, J, -g * J],
    ])
    late_resp, early_resp = response(connectivity, inputs)
    return late_resp, early_resp
    
def _response_grids(J, w1, w2, g, a, b, c, alpha_arr, beta_arr):
    """Evaluate (R_L, R_E) on the grid: row i uses alpha_arr[i], column j uses beta_arr[j]."""
    R_l = np.zeros((len(alpha_arr), len(beta_arr)))
    R_e = np.zeros_like(R_l)
    for alpha_idx, alpha in enumerate(alpha_arr):
        for beta_idx, beta in enumerate(beta_arr):
            R_l[alpha_idx, beta_idx], R_e[alpha_idx, beta_idx] = early_vs_late_inhibition(
                J, w1, w2, alpha, beta, g, a, b, c)
    return R_l, R_e


def _centered_extent(x_arr, y_arr):
    """Return an imshow `extent` that puts each pixel centre on its (uniform) grid coordinate."""
    half_dx = (x_arr[-1] - x_arr[0]) / (2 * (len(x_arr) - 1))
    half_dy = (y_arr[-1] - y_arr[0]) / (2 * (len(y_arr) - 1))
    return (x_arr[0] - half_dx, x_arr[-1] + half_dx,
            y_arr[0] - half_dy, y_arr[-1] + half_dy)


def _plot_response_panel(fig, ax, R, alpha_arr, beta_arr, label, vmin, vmax, fs):
    """Draw a heatmap of one response over (beta, alpha), with its R = 0 level dashed in white."""
    im = ax.imshow(R, extent=_centered_extent(beta_arr, alpha_arr), origin='lower',
                   aspect='auto', vmin=vmin, vmax=vmax)
    ax.set_title(rf"${label}$")
    ax.set_ylabel(r"$\alpha$", fontsize=fs)
    ax.set_xlabel(r"$\beta$", fontsize=fs)
    fig.colorbar(im, ax=ax)

    # Only draw the zero level when the data actually crosses it. This avoids
    # matplotlib's "No contour levels were found" warning.
    if np.nanmin(R) < 0 < np.nanmax(R):
        # contour(X, Y, Z): X runs along Z's columns (beta), Y along its rows (alpha).
        cs = ax.contour(beta_arr, alpha_arr, R, levels=[0], colors='white',
                        linewidths=1, linestyles='dashed')
        # clabel formats a plain-string fmt as `fmt % level`, which raises
        # TypeError for a label with no format specifier. A callable avoids that.
        zero_label = rf"${label}=0$"
        ax.clabel(cs, inline=True, fontsize=5, fmt=lambda _level: zero_label)


def plot_response_differential_inh(J, w1, w2, g, a, b, c, density=100, save_fig_location='results/',
                                   alpha_range=(0.1, 4.0), beta_range=(0.1, 4.0), vmin=-1, vmax=4):
    """Plot R_L and R_E over a grid of inhibition strengths alpha (y-axis) and beta (x-axis).

    Produces two figures:
      1. Side-by-side heatmaps of R_L and R_E, sharing the colour range
         [vmin, vmax], each with its zero level dashed. Saved as
         `<save_fig_location>/responses_differential_inh.png`; the directory is
         created if it is missing.
      2. The response-regime map from `response_regime_metric`. This figure is
         displayed but not saved.

    Parameters
    ----------
    J, w1, w2, g, a, b, c : float
        Passed to `early_vs_late_inhibition` at every grid point.
    density : int
        Number of grid points per axis; must be >= 2.
    save_fig_location : str
        Directory for the saved response figure.
    alpha_range, beta_range : (float, float)
        Inclusive (min, max) of the alpha and beta grids.
    vmin, vmax : float
        Colour limits shared by both response heatmaps.

    Raises
    ------
    ValueError
        If density < 2. Contours and pixel extents need at least two points per axis.
    """
    if density < 2:
        raise ValueError("density must be at least 2")

    alpha_arr = np.linspace(*alpha_range, density)
    beta_arr = np.linspace(*beta_range, density)
    R_l, R_e = _response_grids(J, w1, w2, g, a, b, c, alpha_arr, beta_arr)
    R_eff = response_regime_metric(R_l, R_e)

    fs = 15  # axis-label font size

    fig, ax = plt.subplots(1, 2, figsize=(7, 3), dpi=250)
    _plot_response_panel(fig, ax[0], R_l, alpha_arr, beta_arr, "R_L", vmin, vmax, fs)
    _plot_response_panel(fig, ax[1], R_e, alpha_arr, beta_arr, "R_E", vmin, vmax, fs)
    fig.tight_layout()
    if save_fig_location:
        os.makedirs(save_fig_location, exist_ok=True)
    fig.savefig(os.path.join(save_fig_location, 'responses_differential_inh.png'), dpi=500)

    fig_reg, ax_reg = plt.subplots(1, 1, figsize=(4, 2.5), dpi=300)
    # Pin the colour range to the four regime codes, so each regime keeps its
    # colour regardless of which regimes appear in this parameter slice.
    im = ax_reg.imshow(R_eff, extent=_centered_extent(beta_arr, alpha_arr), origin='lower',
                       aspect='auto', vmin=0, vmax=3)
    ax_reg.set_title('Response Regimes')
    ax_reg.set_ylabel(r"$\alpha$", fontsize=fs)
    ax_reg.set_xlabel(r"$\beta$", fontsize=fs)
    cbar = fig_reg.colorbar(im, ax=ax_reg, ticks=[0, 1, 2, 3])
    cbar.ax.set_yticklabels(['Acquisition', 'Early recall', 'Late recall', 'Both non-positive'])
    fig_reg.tight_layout()
    # plt.show() (unlike fig.show()) renders through the notebook backend.
    # With the default inline-backend settings it also closes the figures
    # afterwards, so repeated widget callbacks do not pile up open figures.
    plt.show()

In [4]:
# continuous_update=False: each redraw evaluates the full alpha x beta grid and saves a
# 500-dpi PNG, so recompute only when a slider is released, not on every drag step.
@widgets.interact(J=widgets.FloatSlider(0.5, min=0, max=2.0, step=0.01, continuous_update=False),
                  w1=widgets.FloatSlider(1.0, min=0.1, max=2.0, step=0.01, continuous_update=False),
                  w2=widgets.FloatSlider(1.0, min=0.1, max=2.0, step=0.01, continuous_update=False),
                  g=widgets.FloatSlider(0.5, min=0.1, max=1, step=0.01, continuous_update=False),
                  a=widgets.FloatSlider(1.0, min=0.0, max=2.0, step=0.1, continuous_update=False),
                  b=widgets.FloatSlider(1.0, min=0.0, max=2.0, step=0.1, continuous_update=False),
                  c=widgets.FloatSlider(1.0, min=0.0, max=2.0, step=0.1, continuous_update=False))
def explore_response_differential_inh(J, w1, w2, g, a, b, c):
    """Interactive explorer for `plot_response_differential_inh`.

    Each slider feeds the argument of the same name. On every update, the
    (alpha, beta) response maps and the regime map are recomputed and
    redrawn. The response figure is written to ../results/diff_inh_plastic.
    """
    plot_response_differential_inh(J, w1, w2, g, a, b, c, save_fig_location='../results/diff_inh_plastic')
    

interactive(children=(FloatSlider(value=0.5, description='J', max=2.0, step=0.01), FloatSlider(value=1.0, desc…