In [None]:
# autoreload: load the IPython extension and enable mode 2 so edited project
# modules are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

# imports
import numpy as np
from matplotlib import pyplot as plt
from ipywidgets import widgets

# imported only for its side effect — presumably registers the 'science'
# style sheets used on the next line (TODO confirm against scienceplots docs)
import scienceplots
plt.style.use(['science','no-latex'])
# plt.style.use(['science'])

# change matplotlib colormap
cmap = 'magma'
plt.rcParams['image.cmap'] = cmap

# add src to path
# os.path.dirname('') is '', so '../..' is resolved against the current working
# directory; this must run before the `src` imports below
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(''), '../..')))
from src.dynamics import dynamics
from src.linear_response import response, response_regime_metric


In [6]:
def early_vs_late_inhibition(J, alpha, beta, g, a, b, c):
    """Evaluate the linear response for one (alpha, beta) inhibition setting.

    A 3x3 weight matrix is assembled whose third column holds the negative
    weights ``-alpha*J``, ``-beta*J`` and ``-g*J``; the remaining entries are
    ``J`` except for a single zero in the first row.  The matrix and the input
    vector ``[a, b, c]`` are handed to ``response``.

    Returns
    -------
    tuple
        ``(R_l, R_e, det)`` as produced by ``response(W, I, ret_det=True)``.
    """
    # First two columns of each row, followed by the gain scaling the third column.
    row_heads = ([J, 0], [J, J], [J, J])
    third_column_gains = (alpha, beta, g)

    weights = np.array([
        head + [-gain * J]
        for head, gain in zip(row_heads, third_column_gains)
    ])
    drive = np.array([a, b, c])

    resp_l, resp_e, determinant = response(weights, drive, ret_det=True)
    return resp_l, resp_e, determinant


def plot_response_differential_inh(J, g, a, b, c, density=100, save_fig_location='results/'):
    """Sweep the gains (alpha, beta) and plot determinant, responses and regimes.

    For every (alpha, beta) pair on a ``density`` x ``density`` grid spanning
    [0.1, 4.0] along both axes, ``early_vs_late_inhibition`` is evaluated and
    its ``R_l``, ``R_e`` and ``det`` results are stored.  Three figures are
    then drawn:

    1. ``det`` with a dashed ``det = 0`` contour (shown, not saved);
    2. ``R_L`` and ``R_E`` side by side, each with its zero contour
       (saved as ``responses_differential_inh.png``);
    3. the output of ``response_regime_metric(R_l, R_e)``
       (saved as ``response_regimes_differential_inh.png`` and shown).

    Parameters
    ----------
    J, g, a, b, c : float
        Forwarded unchanged to ``early_vs_late_inhibition`` for every grid point.
    density : int, optional
        Number of samples along each of the alpha and beta axes.
    save_fig_location : str, optional
        Directory the two PNG files are written to.  It is created when it
        does not exist yet.
    """
    # Single source of truth for the sweep range; reused for the image extents.
    sweep_min, sweep_max = 0.1, 4.0
    alpha_arr = np.linspace(sweep_min, sweep_max, density)
    beta_arr = np.linspace(sweep_min, sweep_max, density)

    # Grids are indexed [alpha_idx, beta_idx]: rows follow alpha, columns follow beta.
    grid_shape = (len(alpha_arr), len(beta_arr))
    R_l = np.zeros(grid_shape)
    R_e = np.zeros(grid_shape)
    det = np.zeros(grid_shape)

    for alpha_idx, alpha in enumerate(alpha_arr):
        for beta_idx, beta in enumerate(beta_arr):
            (R_l[alpha_idx, beta_idx],
             R_e[alpha_idx, beta_idx],
             det[alpha_idx, beta_idx]) = early_vs_late_inhibition(J, alpha, beta, g, a, b, c)

    R_eff = response_regime_metric(R_l, R_e)

    # imshow puts columns (beta) on x and rows (alpha) on y:
    # extent = (x_min, x_max, y_min, y_max).
    extent = (sweep_min, sweep_max, sweep_min, sweep_max)
    fs = 15

    # savefig does not create missing directories, so make sure the target exists.
    if save_fig_location:
        os.makedirs(save_fig_location, exist_ok=True)

    # --- Figure 1: determinant -------------------------------------------------
    fig, ax = plt.subplots(figsize=(4, 3), dpi=250)
    im = ax.imshow(det, extent=extent, origin='lower', aspect='auto')
    # Dashed line where det = 0.  contour(X, Y, Z) expects X along Z's columns
    # (beta) and Y along Z's rows (alpha), matching the imshow orientation.
    cs = ax.contour(beta_arr, alpha_arr, det, levels=[0], colors='white', linewidths=1, linestyles='dashed')
    ax.clabel(cs, inline=True, fontsize=8, fmt=r"$\det(W)=0$")
    ax.set_title(r"Determinant of $W$")
    ax.set_ylabel(r"$\alpha$", fontsize=fs)
    ax.set_xlabel(r"$\beta$", fontsize=fs)
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.show()

    # --- Figure 2: R_L and R_E -------------------------------------------------
    fig, ax = plt.subplots(1, 2, figsize=(7, 3), dpi=250)

    # Fixed colour limits shared by both panels so they are directly comparable.
    vmin = -1
    vmax = 4

    im1 = ax[0].imshow(R_l, extent=extent, origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
    ax[0].set_title(r"$R_L$")
    ax[0].set_ylabel(r"$\alpha$", fontsize=fs)
    ax[0].set_xlabel(r"$\beta$", fontsize=fs)
    fig.colorbar(im1, ax=ax[0])

    # Dashed line where R_l = 0.
    cs = ax[0].contour(beta_arr, alpha_arr, R_l, levels=[0], colors='white', linewidths=1, linestyles='dashed')
    ax[0].clabel(cs, inline=True, fontsize=5, fmt=r"$R_L=0$")

    im2 = ax[1].imshow(R_e, extent=extent, origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
    ax[1].set_title(r"$R_E$")
    ax[1].set_ylabel(r"$\alpha$", fontsize=fs)
    ax[1].set_xlabel(r"$\beta$", fontsize=fs)
    fig.colorbar(im2, ax=ax[1])

    # Dashed line where R_e = 0.
    cs = ax[1].contour(beta_arr, alpha_arr, R_e, levels=[0], colors='white', linewidths=1, linestyles='dashed')
    ax[1].clabel(cs, inline=True, fontsize=5, fmt=r"$R_E=0$")

    fig.tight_layout()
    fig.savefig(os.path.join(save_fig_location, 'responses_differential_inh.png'), dpi=500)

    # --- Figure 3: response regimes --------------------------------------------
    fig, ax = plt.subplots(1, 1, figsize=(4, 2.5), dpi=300)
    im = ax.imshow(R_eff, extent=extent, origin='lower', aspect='auto')
    ax.set_title('Response Regimes')
    ax.set_ylabel(r"$\alpha$", fontsize=fs)
    ax.set_xlabel(r"$\beta$", fontsize=fs)
    # NOTE(review): tick labels assume response_regime_metric encodes the
    # regimes as 0, 1, 2 in this order — confirm against src.linear_response.
    cbar = fig.colorbar(im, ax=ax, ticks=[0, 1, 2])
    cbar.ax.set_yticklabels(['Acquisition', 'Early recall', 'Late recall'])
    fig.tight_layout()
    fig.savefig(os.path.join(save_fig_location, 'response_regimes_differential_inh.png'), dpi=500)
    fig.show()

In [7]:
@widgets.interact(
    J=widgets.FloatSlider(value=0.5, min=0, max=2.0, step=0.01),
    g=widgets.FloatSlider(value=0.5, min=0.1, max=1, step=0.01),
    a=widgets.FloatSlider(value=1.0, min=0.0, max=2.0, step=0.1),
    b=widgets.FloatSlider(value=1.0, min=0.0, max=2.0, step=0.1),
    c=widgets.FloatSlider(value=1.0, min=0.0, max=2.0, step=0.1),
)
def explore_response_differential_inh(J, g, a, b, c):
    """Redraw the differential-inhibition response plots for the slider values.

    The five slider values are forwarded to ``plot_response_differential_inh``,
    which writes its figures under ``'../../results/diff_inh_asymmetric'``.
    """
    plot_response_differential_inh(
        J=J, g=g, a=a, b=b, c=c,
        save_fig_location='../../results/diff_inh_asymmetric',
    )
    

interactive(children=(FloatSlider(value=0.5, description='J', max=2.0, step=0.01), FloatSlider(value=0.5, desc…

### Plot dynamics

In [None]:
def plot_dynamics(J, alpha, beta, g, delta, a, b, c, save_fig_location='../results/diff_inh/dynamics.svg', title='Population dynamics'):
    """Run ``dynamics`` for one parameter set and plot the L and E traces.

    A 3x3 weight matrix is built from ``J``, ``alpha``, ``beta``, ``delta`` and
    ``g`` and passed, together with a vector of ones and the input vector
    ``[a, b, c]``, to ``dynamics``.  The first two returned traces are plotted
    against the returned time array; the third trace is not drawn.

    Parameters
    ----------
    J, alpha, beta, g, delta : float
        Entries/scalings of the weight matrix (see the ``W`` construction below).
    a, b, c : float
        Components of the input vector handed to ``dynamics``.
    save_fig_location : str, optional
        Full path (including file name) of the saved figure.  Its parent
        directory is created when it does not exist yet.
    title : str, optional
        Title placed above the axes.
    """
    W = np.array([[J, beta*J, - delta*J],
                  [alpha*J, J, - delta*J],
                  [J, J, - g*J]])
    I = np.array([a, b, c])

    # Second argument to dynamics; presumably a constant background drive — TODO confirm.
    bg = np.array([1.0, 1.0, 1.0])
    # The third trace is returned but intentionally not plotted here.
    r_l, r_e, _r_i, time_arr = dynamics(W, bg, I)

    fig, ax = plt.subplots(figsize=(3.5, 2), dpi=300)
    ax.plot(time_arr, r_l, label='L', color='C0')
    ax.plot(time_arr, r_e, label='E', color='C1')

    ax.set_xlabel('Time (sec)')
    ax.set_ylabel('Firing Rate (Hz)')
    ax.set_xlim((0, 10))
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()

    # savefig does not create missing directories; a bare file name has no
    # parent directory to create.
    out_dir = os.path.dirname(save_fig_location)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_fig_location, dpi=500)
    plt.show()