In [None]:
import numpy as np
from matplotlib import pyplot as plt

import scienceplots
plt.style.use(['science', 'no-latex'])

import ipywidgets as widgets

In [2]:
def response(W, B, dI):
    """
    Linear response of populations 0 (LB) and 1 (EB) to an input change.

    The response matrix is R = (inv(B) - W)^{-1}; this returns the first
    two components of R @ dI.

    W  : weight matrix (n x n)
    B  : diagonal matrix of cellular gains (n x n)
    dI : input change, one entry per population (length n)

    Returns (R_l, R_e).
    Raises numpy.linalg.LinAlgError if B or inv(B) - W is singular.
    """
    A = np.linalg.inv(B) - W
    # Solving A x = dI gives R @ dI directly; this avoids forming inv(A),
    # which is costlier and less accurate near singularity.
    dR = np.linalg.solve(A, dI)

    R_l = dR[0]
    R_e = dR[1]

    return R_l, R_e

def determinant(W, B):
    """
    Return det(inv(B) - W).

    W : weight matrix
    B : diagonal matrix of cellular gains

    This value is the common denominator of the response functions;
    where it crosses zero the linear response diverges.
    """
    inverse_gains = np.linalg.inv(B)
    system_matrix = inverse_gains - W
    return np.linalg.det(system_matrix)

def response_regime_metric(R_l, R_e):
    """
    Classify the sign pattern of the response pair (R_l, R_e).

    Returns an integer regime code:
        0 : training      (R_l > 0  and R_e > 0)
        1 : early recall  (R_l > 0  and R_e <= 0)
        2 : late recall   (R_l <= 0 and R_e > 0)
        3 : no response   (R_l <= 0 and R_e <= 0)

    Every comparison with NaN is False, so a NaN in either input falls
    through to code 3.
    """
    if R_l > 0 and R_e > 0:  # training
        return 0
    elif R_l > 0 and R_e <= 0:  # early recall
        return 1
    elif R_l <= 0 and R_e > 0:  # late recall
        return 2
    else:  # no response (also reached for NaN inputs)
        return 3

## connectivity no - 1
def exc_vs_inh_simple(J, g, I_l, I_e, I_p, I_c, B=None):
    """
    4 populations: LB, EB, PV, CCK
        all excitatory weights : J
        all inhibitory weights : g*J

    I_l, I_e, I_p, I_c : inputs to LB, EB, PV, CCK
    B : diagonal matrix of cellular gains; None (default) means the 4x4
        identity. A fresh array is built per call instead of sharing one
        default array created at definition time.

    Returns (R_l, R_e, det): the LB/EB responses computed by `response`
    and det(inv(B) - W) computed by `determinant`.
    Raises numpy.linalg.LinAlgError if inv(B) - W is singular.
    """
    if B is None:
        B = np.eye(4)

    W = np.array([[J, J, -g*J, -g*J],
                  [J, J, -g*J, -g*J],
                  [J, J, -g*J, -g*J],
                  [J, J, -g*J, -g*J]])

    I = np.array([I_l, I_e, I_p, I_c])
    R_l, R_e = response(W, B, I)
    det = determinant(W, B)

    return R_l, R_e, det

## connectivity no - 2
def inh_PV_CCK(J, g_p, g_c, I_l, I_e, I_p, I_c, B=None):
    """
    4 populations: LB, EB, PV, CCK
        all excitatory weights : J
        PV inhibitory weights : g_p*J
        CCK inhibitory weights : g_c*J

    I_l, I_e, I_p, I_c : inputs to LB, EB, PV, CCK
    B : diagonal matrix of cellular gains; None (default) means the 4x4
        identity. A fresh array is built per call instead of sharing one
        default array created at definition time.

    Returns (R_l, R_e, det): the LB/EB responses computed by `response`
    and det(inv(B) - W) computed by `determinant`.
    Raises numpy.linalg.LinAlgError if inv(B) - W is singular.
    """
    if B is None:
        B = np.eye(4)

    W = np.array([[J, J, -g_p*J, -g_c*J],
                  [J, J, -g_p*J, -g_c*J],
                  [J, J, -g_p*J, -g_c*J],
                  [J, J, -g_p*J, -g_c*J]])

    I = np.array([I_l, I_e, I_p, I_c])
    R_l, R_e = response(W, B, I)
    det = determinant(W, B)
    return R_l, R_e, det

## connectivity no - 3
def inh_PV_diff(J, g_p, g_c, alpha_p, beta_p, I_l, I_e, I_p, I_c, B=None):
    """
    4 populations: LB, EB, PV, CCK
        all excitatory weights : J
        PV inhibitory weights : g_p*J
            - to late born excitatory neurons : alpha_p*g_p*J
            - to early born excitatory neurons : beta_p*g_p*J
        CCK inhibitory weights : g_c*J

    I_l, I_e, I_p, I_c : inputs to LB, EB, PV, CCK
    B : diagonal matrix of cellular gains; None (default) means the 4x4
        identity. A fresh array is built per call instead of sharing one
        default array created at definition time.

    Returns (R_l, R_e, det): the LB/EB responses computed by `response`
    and det(inv(B) - W) computed by `determinant`.
    Raises numpy.linalg.LinAlgError if inv(B) - W is singular.
    """
    if B is None:
        B = np.eye(4)

    W = np.array([[J, J, -alpha_p*g_p*J, -g_c*J],
                  [J, J, -beta_p*g_p*J, -g_c*J],
                  [J, J, -g_p*J, -g_c*J],
                  [J, J, -g_p*J, -g_c*J]])

    I = np.array([I_l, I_e, I_p, I_c])
    R_l, R_e = response(W, B, I)
    det = determinant(W, B)
    return R_l, R_e, det

## connectivity no - 4
def inh_CCK_diff(J, g_p, g_c, alpha_c, beta_c, I_l, I_e, I_p, I_c, B=None):
    """
    4 populations: LB, EB, PV, CCK
        all excitatory weights : J
        PV inhibitory weights : g_p*J
        CCK inhibitory weights : g_c*J
            - to late born excitatory neurons : alpha_c*g_c*J
            - to early born excitatory neurons : beta_c*g_c*J

    I_l, I_e, I_p, I_c : inputs to LB, EB, PV, CCK
    B : diagonal matrix of cellular gains; None (default) means the 4x4
        identity. A fresh array is built per call instead of sharing one
        default array created at definition time.

    Returns (R_l, R_e, det): the LB/EB responses computed by `response`
    and det(inv(B) - W) computed by `determinant`.
    Raises numpy.linalg.LinAlgError if inv(B) - W is singular.
    """
    if B is None:
        B = np.eye(4)

    W = np.array([[J, J, -g_p*J, -alpha_c*g_c*J],
                  [J, J, -g_p*J, -beta_c*g_c*J],
                  [J, J, -g_p*J, -g_c*J],
                  [J, J, -g_p*J, -g_c*J]])

    I = np.array([I_l, I_e, I_p, I_c])
    R_l, R_e = response(W, B, I)
    det = determinant(W, B)
    return R_l, R_e, det

## connectivity no - 5
def inh_PV_diff_CCK_diff(J, g_p, g_c, alpha_p, beta_p, alpha_c, beta_c, I_l, I_e, I_p, I_c, B=None):
    """
    4 populations: LB, EB, PV, CCK
        all excitatory weights : J
        PV inhibitory weights : g_p*J
            - to late born excitatory neurons : alpha_p*g_p*J
            - to early born excitatory neurons : beta_p*g_p*J
        CCK inhibitory weights : g_c*J
            - to late born excitatory neurons : alpha_c*g_c*J
            - to early born excitatory neurons : beta_c*g_c*J

    I_l, I_e, I_p, I_c : inputs to LB, EB, PV, CCK
    B : diagonal matrix of cellular gains; None (default) means the 4x4
        identity. A fresh array is built per call instead of sharing one
        default array created at definition time.

    Returns (R_l, R_e, det): the LB/EB responses computed by `response`
    and det(inv(B) - W) computed by `determinant`.
    Raises numpy.linalg.LinAlgError if inv(B) - W is singular.
    """
    if B is None:
        B = np.eye(4)

    W = np.array([[J, J, -alpha_p*g_p*J, -alpha_c*g_c*J],
                  [J, J, -beta_p*g_p*J, -beta_c*g_c*J],
                  [J, J, -g_p*J, -g_c*J],
                  [J, J, -g_p*J, -g_c*J]])

    I = np.array([I_l, I_e, I_p, I_c])
    R_l, R_e = response(W, B, I)
    det = determinant(W, B)
    return R_l, R_e, det


In [3]:
@widgets.interact(I_l = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_e = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_p = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_c = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01))
def explore_exc_vs_inh_simple(I_l, I_e, I_p, I_c):
    """
    Sweep connectivity no. 1 (`exc_vs_inh_simple`) over a (J, g) grid.

    Plots the determinant of inv(B) - W, R_l, R_e and the response regime.

    I_l, I_e, I_p, I_c : inputs to the LB, EB, PV and CCK populations.

    Grid points where the linear system is singular raise
    numpy.linalg.LinAlgError; they are left as NaN (blank pixels) instead of
    aborting the whole plot.
    """
    J_arr = np.linspace(0, 2, 100)
    g_arr = np.linspace(0, 4, 100)

    # Grids are indexed [J index, g index]. The explicit shape keeps that
    # indexing valid even if the two sweeps have different lengths.
    shape = (J_arr.size, g_arr.size)
    R_l = np.full(shape, np.nan)
    R_e = np.full(shape, np.nan)
    determinant_grid = np.full(shape, np.nan)
    response_regime = np.full(shape, np.nan)

    for i, j in enumerate(J_arr):
        for k, gg in enumerate(g_arr):
            try:
                R_l[i, k], R_e[i, k], determinant_grid[i, k] = exc_vs_inh_simple(j, gg, I_l, I_e, I_p, I_c)
            except np.linalg.LinAlgError:
                continue  # singular system: no finite linear response here
            response_regime[i, k] = response_regime_metric(R_l[i, k], R_e[i, k])

    extent = (J_arr.min(), J_arr.max(), g_arr.min(), g_arr.max())
    panels = [(determinant_grid, 'Determinant'),
              (R_l, r'$R_l$'),
              (R_e, r'$R_e$')]

    fig, ax = plt.subplots(1, 4, figsize=(10, 2.5), dpi=200)
    for axis, (grid, title) in zip(ax, panels):
        im = axis.imshow(grid.T, extent=extent, origin='lower', aspect='auto')
        axis.set_title(title)
        axis.set_xlabel(r'$J$')
        axis.set_ylabel(r'$g$')
        fig.colorbar(im, ax=axis)
        # Dashed white line where the quantity changes sign. Skip it when
        # there is no sign change, since there is no zero level to draw.
        finite = grid[np.isfinite(grid)]
        if finite.size and finite.min() < 0 < finite.max():
            cx = axis.contour(J_arr, g_arr, grid.T, levels=[0], colors='white', linestyles='dashed')
            cx.clabel(fmt='%.2f', inline=True, fontsize=8)

    # One colour per regime with fixed limits. This keeps each regime on the
    # same colour, aligned with its tick label, even when some regimes are absent.
    im4 = ax[3].imshow(response_regime.T, extent=extent, origin='lower', aspect='auto',
                       cmap=plt.get_cmap(None, 4), vmin=-0.5, vmax=3.5)
    ax[3].set_title('Response Regime')
    ax[3].set_xlabel(r'$J$')
    ax[3].set_ylabel(r'$g$')
    cbar = fig.colorbar(im4, ax=ax[3], ticks=[0, 1, 2, 3])
    cbar.ax.set_yticklabels(['Training', 'Early Recall', 'Late Recall', 'No Response'])

    fig.suptitle('Exc vs Inh simple ($g$ vs $J$)')
    fig.tight_layout()
    plt.show()

interactive(children=(FloatSlider(value=1.0, description='I_l', max=2.0, min=-1.0, step=0.01), FloatSlider(val…

In [4]:
# Exploring connectivity no - 2
@widgets.interact(J = widgets.FloatSlider(value=1, min=0, max=2, step=0.01),
                  I_l = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_e = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_p = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_c = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  v_min = widgets.FloatSlider(value=-1, min=-10, max=0, step=0.1),
                  v_max = widgets.FloatSlider(value=1, min=0, max=10, step=0.1))
def explore_inh_PV_CCK(J, I_l, I_e, I_p, I_c, v_min=-10, v_max=10):
    """
    Sweep connectivity no. 2 (`inh_PV_CCK`) over a (g_p, g_c) grid at fixed J.

    Plots the determinant of inv(B) - W, R_l, R_e and the response regime.

    J : excitatory weight.
    I_l, I_e, I_p, I_c : inputs to the LB, EB, PV and CCK populations.
    v_min, v_max : colour limits of the R_l and R_e panels.

    Grid points where the linear system is singular raise
    numpy.linalg.LinAlgError; they are left as NaN (blank pixels) instead of
    aborting the whole plot.
    """
    g_p_arr = np.linspace(0, 4, 100)
    g_c_arr = np.linspace(0, 4, 100)

    # Grids are indexed [g_p index, g_c index]. The explicit shape keeps that
    # indexing valid even if the two sweeps have different lengths.
    shape = (g_p_arr.size, g_c_arr.size)
    R_l = np.full(shape, np.nan)
    R_e = np.full(shape, np.nan)
    determinant_grid = np.full(shape, np.nan)
    response_regime = np.full(shape, np.nan)

    for k, gp in enumerate(g_p_arr):
        for l, gc in enumerate(g_c_arr):
            try:
                R_l[k, l], R_e[k, l], determinant_grid[k, l] = inh_PV_CCK(J, gp, gc, I_l, I_e, I_p, I_c)
            except np.linalg.LinAlgError:
                continue  # singular system: no finite linear response here
            response_regime[k, l] = response_regime_metric(R_l[k, l], R_e[k, l])

    extent = (g_p_arr.min(), g_p_arr.max(), g_c_arr.min(), g_c_arr.max())
    clim = {'vmin': v_min, 'vmax': v_max}
    panels = [(determinant_grid, 'Determinant', {}),
              (R_l, r'$R_l$', clim),
              (R_e, r'$R_e$', clim)]

    fig, ax = plt.subplots(1, 4, figsize=(10, 2.5), dpi=200)
    for axis, (grid, title, limits) in zip(ax, panels):
        im = axis.imshow(grid.T, extent=extent, origin='lower', aspect='auto', **limits)
        axis.set_title(title)
        axis.set_xlabel(r'$g_p$')
        axis.set_ylabel(r'$g_c$')
        fig.colorbar(im, ax=axis)
        # Dashed white line where the quantity changes sign. Skip it when
        # there is no sign change, since there is no zero level to draw.
        finite = grid[np.isfinite(grid)]
        if finite.size and finite.min() < 0 < finite.max():
            cx = axis.contour(g_p_arr, g_c_arr, grid.T, levels=[0], colors='white', linestyles='dashed')
            cx.clabel(fmt='%.2f', inline=True, fontsize=8)

    # One colour per regime with fixed limits. This keeps each regime on the
    # same colour, aligned with its tick label, even when some regimes are absent.
    im4 = ax[3].imshow(response_regime.T, extent=extent, origin='lower', aspect='auto',
                       cmap=plt.get_cmap(None, 4), vmin=-0.5, vmax=3.5)
    ax[3].set_title('Response Regime')
    ax[3].set_xlabel(r'$g_p$')
    ax[3].set_ylabel(r'$g_c$')
    cbar = fig.colorbar(im4, ax=ax[3], ticks=[0, 1, 2, 3])
    cbar.ax.set_yticklabels(['Training', 'Early Recall', 'Late Recall', 'No Response'])

    fig.suptitle('Inh PV and CCK')
    fig.tight_layout()
    plt.show()

interactive(children=(FloatSlider(value=1.0, description='J', max=2.0, step=0.01), FloatSlider(value=1.0, desc…

In [5]:
# Exploring connectivity no - 3
@widgets.interact(J = widgets.FloatSlider(value=1, min=0, max=2, step=0.01),
                  g_p = widgets.FloatSlider(value=1, min=0, max=4, step=0.01),
                  g_c = widgets.FloatSlider(value=1, min=0, max=4, step=0.01),
                  I_l = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_e = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_p = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_c = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  v_min = widgets.FloatSlider(value=-1, min=-10, max=0, step=0.1),
                  v_max = widgets.FloatSlider(value=1, min=0, max=10, step=0.1))
def explore_inh_PV_diff(J, g_p, g_c, I_l, I_e, I_p, I_c, v_min, v_max):
    """
    Sweep connectivity no. 3 (`inh_PV_diff`) over an (alpha_p, beta_p) grid.

    Plots the determinant of inv(B) - W, R_l, R_e and the response regime.

    J, g_p, g_c : excitatory weight and PV / CCK inhibitory scales.
    I_l, I_e, I_p, I_c : inputs to the LB, EB, PV and CCK populations.
    v_min, v_max : colour limits of the R_l and R_e panels.

    Grid points where the linear system is singular raise
    numpy.linalg.LinAlgError; they are left as NaN (blank pixels) instead of
    aborting the whole plot.
    """
    alpha_p_arr = np.linspace(0, 4, 100)
    beta_p_arr = np.linspace(0, 4, 100)

    # Grids are indexed [alpha_p index, beta_p index]. The explicit shape
    # keeps that indexing valid even if the two sweeps have different lengths.
    shape = (alpha_p_arr.size, beta_p_arr.size)
    R_l = np.full(shape, np.nan)
    R_e = np.full(shape, np.nan)
    determinant_grid = np.full(shape, np.nan)
    response_regime = np.full(shape, np.nan)

    for m, ap in enumerate(alpha_p_arr):
        for n, bp in enumerate(beta_p_arr):
            try:
                R_l[m, n], R_e[m, n], determinant_grid[m, n] = inh_PV_diff(J, g_p, g_c, ap, bp, I_l, I_e, I_p, I_c)
            except np.linalg.LinAlgError:
                continue  # singular system: no finite linear response here
            response_regime[m, n] = response_regime_metric(R_l[m, n], R_e[m, n])

    extent = (alpha_p_arr.min(), alpha_p_arr.max(), beta_p_arr.min(), beta_p_arr.max())
    clim = {'vmin': v_min, 'vmax': v_max}
    panels = [(determinant_grid, 'Determinant', {}),
              (R_l, r'$R_l$', clim),
              (R_e, r'$R_e$', clim)]

    fig, ax = plt.subplots(1, 4, figsize=(10, 2.5), dpi=200)
    for axis, (grid, title, limits) in zip(ax, panels):
        im = axis.imshow(grid.T, extent=extent, origin='lower', aspect='auto', **limits)
        axis.set_title(title)
        axis.set_xlabel(r'$\alpha_p$')
        axis.set_ylabel(r'$\beta_p$')
        fig.colorbar(im, ax=axis)
        # Dashed white line where the quantity changes sign. Skip it when
        # there is no sign change, since there is no zero level to draw.
        finite = grid[np.isfinite(grid)]
        if finite.size and finite.min() < 0 < finite.max():
            cx = axis.contour(alpha_p_arr, beta_p_arr, grid.T, levels=[0], colors='white', linestyles='dashed')
            cx.clabel(fmt='%.2f', inline=True, fontsize=8)

    # One colour per regime with fixed limits. This keeps each regime on the
    # same colour, aligned with its tick label, even when some regimes are absent.
    im4 = ax[3].imshow(response_regime.T, extent=extent, origin='lower', aspect='auto',
                       cmap=plt.get_cmap(None, 4), vmin=-0.5, vmax=3.5)
    ax[3].set_title('Response Regime')
    ax[3].set_xlabel(r'$\alpha_p$')
    ax[3].set_ylabel(r'$\beta_p$')
    cbar = fig.colorbar(im4, ax=ax[3], ticks=[0, 1, 2, 3])
    cbar.ax.set_yticklabels(['Training', 'Early Recall', 'Late Recall', 'No Response'])

    fig.suptitle(r'3. Inh PV different $\alpha_p$ and $\beta_p$')
    fig.tight_layout()
    plt.show()



interactive(children=(FloatSlider(value=1.0, description='J', max=2.0, step=0.01), FloatSlider(value=1.0, desc…

In [6]:
# Exploring connectivity no - 4
@widgets.interact(J = widgets.FloatSlider(value=1, min=0, max=2, step=0.01),
                  g_p = widgets.FloatSlider(value=1, min=0, max=4, step=0.01),
                  g_c = widgets.FloatSlider(value=1, min=0, max=4, step=0.01),
                  I_l = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_e = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_p = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_c = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  v_min = widgets.FloatSlider(value=-1, min=-10, max=0, step=0.1),
                  v_max = widgets.FloatSlider(value=1, min=0, max=10, step=0.1))
def explore_inh_CCK_diff(J, g_p, g_c, I_l, I_e, I_p, I_c, v_min, v_max):
    """
    Sweep connectivity no. 4 (`inh_CCK_diff`) over an (alpha_c, beta_c) grid.

    Plots the determinant of inv(B) - W, R_l, R_e and the response regime.

    J, g_p, g_c : excitatory weight and PV / CCK inhibitory scales.
    I_l, I_e, I_p, I_c : inputs to the LB, EB, PV and CCK populations.
    v_min, v_max : colour limits of the R_l and R_e panels.

    Grid points where the linear system is singular raise
    numpy.linalg.LinAlgError; they are left as NaN (blank pixels) instead of
    aborting the whole plot.
    """
    alpha_c_arr = np.linspace(0, 4, 100)
    beta_c_arr = np.linspace(0, 4, 100)

    # Grids are indexed [alpha_c index, beta_c index]. The explicit shape
    # keeps that indexing valid even if the two sweeps have different lengths.
    shape = (alpha_c_arr.size, beta_c_arr.size)
    R_l = np.full(shape, np.nan)
    R_e = np.full(shape, np.nan)
    determinant_grid = np.full(shape, np.nan)
    response_regime = np.full(shape, np.nan)

    for p, ac in enumerate(alpha_c_arr):
        for q, bc in enumerate(beta_c_arr):
            try:
                R_l[p, q], R_e[p, q], determinant_grid[p, q] = inh_CCK_diff(J, g_p, g_c, ac, bc, I_l, I_e, I_p, I_c)
            except np.linalg.LinAlgError:
                continue  # singular system: no finite linear response here
            response_regime[p, q] = response_regime_metric(R_l[p, q], R_e[p, q])

    extent = (alpha_c_arr.min(), alpha_c_arr.max(), beta_c_arr.min(), beta_c_arr.max())
    clim = {'vmin': v_min, 'vmax': v_max}
    panels = [(determinant_grid, 'Determinant', {}),
              (R_l, r'$R_l$', clim),
              (R_e, r'$R_e$', clim)]

    fig, ax = plt.subplots(1, 4, figsize=(10, 2.5), dpi=200)
    for axis, (grid, title, limits) in zip(ax, panels):
        im = axis.imshow(grid.T, extent=extent, origin='lower', aspect='auto', **limits)
        axis.set_title(title)
        axis.set_xlabel(r'$\alpha_c$')
        axis.set_ylabel(r'$\beta_c$')
        fig.colorbar(im, ax=axis)
        # Dashed white line where the quantity changes sign. Skip it when
        # there is no sign change, since there is no zero level to draw.
        finite = grid[np.isfinite(grid)]
        if finite.size and finite.min() < 0 < finite.max():
            cx = axis.contour(alpha_c_arr, beta_c_arr, grid.T, levels=[0], colors='white', linestyles='dashed')
            cx.clabel(fmt='%.2f', inline=True, fontsize=8)

    # One colour per regime with fixed limits. This keeps each regime on the
    # same colour, aligned with its tick label, even when some regimes are absent.
    im4 = ax[3].imshow(response_regime.T, extent=extent, origin='lower', aspect='auto',
                       cmap=plt.get_cmap(None, 4), vmin=-0.5, vmax=3.5)
    ax[3].set_title('Response Regime')
    ax[3].set_xlabel(r'$\alpha_c$')
    ax[3].set_ylabel(r'$\beta_c$')
    cbar = fig.colorbar(im4, ax=ax[3], ticks=[0, 1, 2, 3])
    cbar.ax.set_yticklabels(['Training', 'Early Recall', 'Late Recall', 'No Response'])

    fig.suptitle(r'4. Inh CCK different $\alpha_c$ and $\beta_c$')
    fig.tight_layout()
    plt.show()

interactive(children=(FloatSlider(value=1.0, description='J', max=2.0, step=0.01), FloatSlider(value=1.0, desc…

In [None]:
# Exploring connectivity no - 5

@widgets.interact(J = widgets.FloatSlider(value=1, min=0, max=2, step=0.01),
                  g_p = widgets.FloatSlider(value=1, min=0, max=4, step=0.01),
                  g_c = widgets.FloatSlider(value=1, min=0, max=4, step=0.01),
                  alpha_c = widgets.FloatSlider(value=1, min=0, max=4, step=0.01),
                  beta_c = widgets.FloatSlider(value=1, min=0, max=4, step=0.01),
                  I_l = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_e = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_p = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  I_c = widgets.FloatSlider(value=1, min=-1, max=2, step=0.01),
                  v_min = widgets.FloatSlider(value=-1, min=-10, max=0, step=0.1),
                  v_max = widgets.FloatSlider(value=1, min=0, max=10, step=0.1))
def explore_inh_PV_diff_CCK_diff(J, g_p, g_c, alpha_c, beta_c, I_l, I_e, I_p, I_c, v_min, v_max):
    """
    Sweep connectivity no. 5 (`inh_PV_diff_CCK_diff`) over an (alpha_p, beta_p)
    grid, with the CCK scalings alpha_c and beta_c held at their slider values.

    Plots the determinant of inv(B) - W, R_l, R_e and the response regime.

    J, g_p, g_c : excitatory weight and PV / CCK inhibitory scales.
    alpha_c, beta_c : CCK scaling onto LB / EB.
    I_l, I_e, I_p, I_c : inputs to the LB, EB, PV and CCK populations.
    v_min, v_max : colour limits of the R_l and R_e panels.

    Grid points where the linear system is singular raise
    numpy.linalg.LinAlgError; they are left as NaN (blank pixels) instead of
    aborting the whole plot.
    """
    alpha_p_arr = np.linspace(0, 4, 100)
    beta_p_arr = np.linspace(0, 4, 100)

    # Grids are indexed [alpha_p index, beta_p index]. The explicit shape
    # keeps that indexing valid even if the two sweeps have different lengths.
    shape = (alpha_p_arr.size, beta_p_arr.size)
    R_l = np.full(shape, np.nan)
    R_e = np.full(shape, np.nan)
    determinant_grid = np.full(shape, np.nan)
    response_regime = np.full(shape, np.nan)

    for m, ap in enumerate(alpha_p_arr):
        for n, bp in enumerate(beta_p_arr):
            try:
                R_l[m, n], R_e[m, n], determinant_grid[m, n] = inh_PV_diff_CCK_diff(J, g_p, g_c, ap, bp, alpha_c, beta_c, I_l, I_e, I_p, I_c)
            except np.linalg.LinAlgError:
                continue  # singular system: no finite linear response here
            response_regime[m, n] = response_regime_metric(R_l[m, n], R_e[m, n])

    extent = (alpha_p_arr.min(), alpha_p_arr.max(), beta_p_arr.min(), beta_p_arr.max())
    clim = {'vmin': v_min, 'vmax': v_max}
    panels = [(determinant_grid, 'Determinant', {}),
              (R_l, r'$R_l$', clim),
              (R_e, r'$R_e$', clim)]

    fig, ax = plt.subplots(1, 4, figsize=(10, 2.5), dpi=200)
    for axis, (grid, title, limits) in zip(ax, panels):
        im = axis.imshow(grid.T, extent=extent, origin='lower', aspect='auto', **limits)
        axis.set_title(title)
        axis.set_xlabel(r'$\alpha_p$')
        axis.set_ylabel(r'$\beta_p$')
        fig.colorbar(im, ax=axis)
        # Dashed white line where the quantity changes sign. Skip it when
        # there is no sign change, since there is no zero level to draw.
        finite = grid[np.isfinite(grid)]
        if finite.size and finite.min() < 0 < finite.max():
            cx = axis.contour(alpha_p_arr, beta_p_arr, grid.T, levels=[0], colors='white', linestyles='dashed')
            cx.clabel(fmt='%.2f', inline=True, fontsize=8)

    # One colour per regime with fixed limits. This keeps each regime on the
    # same colour, aligned with its tick label, even when some regimes are absent.
    im4 = ax[3].imshow(response_regime.T, extent=extent, origin='lower', aspect='auto',
                       cmap=plt.get_cmap(None, 4), vmin=-0.5, vmax=3.5)
    ax[3].set_title('Response Regime')
    ax[3].set_xlabel(r'$\alpha_p$')
    ax[3].set_ylabel(r'$\beta_p$')
    cbar = fig.colorbar(im4, ax=ax[3], ticks=[0, 1, 2, 3])
    cbar.ax.set_yticklabels(['Training', 'Early Recall', 'Late Recall', 'No Response'])

    fig.suptitle(r'5. Inh PV and CCK different $\alpha_p$, $\beta_p$ and $\alpha_c$, $\beta_c$')
    fig.tight_layout()
    plt.show()


interactive(children=(FloatSlider(value=1.0, description='J', max=2.0, step=0.01), FloatSlider(value=1.0, desc…