In [None]:
import numpy as np
import matplotlib.pyplot as plt
from v1sh_model.models.V1_model_2 import V1_model_2 as base_model

# 1. Two-point EI network

In [None]:
# Parameters following Fig. 5.58 and eqs. 5.66-5.67 of "Understanding Vision"
# (Li Zhaoping, 2014).

# alpha parameters of the x (pyramidal) and y (interneuron) populations,
# passed straight to the model as alpha_x / alpha_y.
alpha_x = alpha_y = 1.0

# Number of channels; NOTE(review): presumably the orientation count used by
# V1_model_2 (the kernels below are built for a single channel) — confirm.
K = 1

# Connection strengths handed to the model as J_o and W_o.
J_o, W_o = 0.8, 1.13

def compute_connection_kernel(K, verbose=False, *, J_dash=0.4, W_dash=0.9):
    """Build the lateral connection kernels for the two-point EI network.

    Each kernel has shape ``(3, 3, 1, 1)``. The lateral strength is split evenly
    between the two entries ``[0, 0, 0, 0]`` and ``[0, 2, 0, 0]``, so each kernel
    sums to its total strength. The Psi kernel is all zeros.

    Parameters
    ----------
    K : int
        Channel count requested by the model. Not used: the kernels are always
        built for a single channel. NOTE(review): presumably only K == 1 is
        meaningful here — confirm against V1_model_2.
    verbose : bool, optional
        Accepted for interface compatibility with the model; unused.
    J_dash : float, keyword-only
        Total strength of the J lateral kernel (default 0.4, the value used
        for this notebook's parameter set).
    W_dash : float, keyword-only
        Total strength of the W lateral kernel (default 0.9).

    Returns
    -------
    tuple of np.ndarray
        ``(J_kernel, W_kernel, Psi_kernel)``, each of shape ``(3, 3, 1, 1)``.
    """

    def _two_point_kernel(total_strength):
        # Spread the lateral strength evenly over the two off-centre taps.
        kernel = np.zeros((3, 3, 1, 1))
        kernel[0, 0, 0, 0] = total_strength / 2
        kernel[0, 2, 0, 0] = total_strength / 2
        return kernel

    J_kernel = _two_point_kernel(J_dash)
    W_kernel = _two_point_kernel(W_dash)
    # No Psi interactions in this network.
    Psi_kernel = np.zeros((3, 3, 1, 1))

    return (J_kernel, W_kernel, Psi_kernel)

# Normalization terms
def I_o(X):
    """Normalization input: always zero, with the same shape and dtype as ``X``."""
    return np.zeros_like(X)


def I_c(I_topdown):
    """Pass-through term: returns ``I_topdown`` unchanged (no normalization)."""
    return I_topdown

# Activation functions
def g_x(x, x_th=1.):
    """Threshold-linear activation for the pyramidal (x) units.

    Computes ``max(0, x - x_th)`` elementwise: zero at or below the
    threshold ``x_th`` (default 1.0), linear with slope 1 above it.
    """
    shifted = x - x_th
    return np.maximum(0, shifted)

# Visual check of the threshold-linear pyramidal activation.
x = np.linspace(-1, 3, 1000)
fig, ax = plt.subplots()
ax.plot(x, g_x(x))
ax.set_xlabel('x')
ax.set_ylabel(r'$g_x(x)$')
ax.set_title(r'Activation function $g_x$ for pyramidal cells')
fig.tight_layout()
plt.show()

def g_y(y):
    """Linear activation for the interneuron (y) units: the identity map."""
    activation = y  # no threshold or saturation is applied
    return activation

# tuning curves
def tuning_curve(angle: np.ndarray):
    """Delta-function tuning: 1.0 where ``angle`` is exactly 0, else 0.0.

    ``angle`` may be any array-like (ndarray, list, or scalar). It is converted
    with ``np.asarray`` so that plain Python numbers also work; calling
    ``.astype`` on a bare Python bool would otherwise raise AttributeError.
    The result is a float array with the same shape as ``angle``.
    """
    return (np.asarray(angle) == 0.).astype(float)

In [None]:
# Initialize the model from the functions and parameters defined above.
EI_pair = base_model(
    # structure: channel count, lateral kernels, tuning
    K=K,
    compute_connection_kernel=compute_connection_kernel,
    tuning_curve=tuning_curve,
    # activation functions of the x (pyramidal) and y (interneuron) units
    g_x=g_x,
    g_y=g_y,
    # normalization terms
    I_o=I_o,
    I_c=I_c,
    # scalar parameters from the parameter cell
    J_o=J_o,
    W_o=W_o,
    alpha_x=alpha_x,
    alpha_y=alpha_y,
)

In [None]:
def column_input(I=(1., 1.)):
    """Build a column-constant stimulus on a 3x3 grid.

    Parameters
    ----------
    I : sequence of two numbers
        ``I[0]`` is the level of the two flanking columns (0 and 2);
        ``I[1]`` is the level of the centre column (1).

    Returns
    -------
    A : np.ndarray, shape (3, 3)
        All zeros. NOTE(review): passed to ``simulate`` as its first input;
        presumably a per-location angle map (0 matches the delta tuning) — confirm.
    C : np.ndarray, shape (3, 3), float
        Every row equals ``[I[0], I[1], I[0]]``.
    """
    flank, centre = I[0], I[1]
    C = np.zeros((3, 3))
    for col, level in enumerate((flank, centre, flank)):
        C[:, col] = level
    A = np.zeros_like(C)
    return A, C

# Stimulus: flanking columns get I[0] = 0, the centre column gets I[1] = 3.
A, C = column_input(I=(0., 3.))

# Integration settings: total duration and step size.
T = 30.0
dt = 0.001

# Spatially uniform initial state for the x (pyramidal) and y (interneuron) units.
X_0 = 1.1 * np.ones((3, 3, 1))
Y_0 = 0.1 * np.ones((3, 3, 1))

X, Y, I = EI_pair.simulate(
    A, C, dt, T,
    noisy=False,
    initial_condition=(X_0, Y_0),
    mode='wrap',
)

In [None]:
# Time axis taken from the trajectory itself. np.arange(0, T, dt) with a float
# step is not guaranteed to have as many samples as simulate() returns (e.g. if
# the initial condition is included), which would make plot() fail.
t = np.arange(X.shape[0]) * dt

# Show both points of the network. Assuming X/Y are indexed
# (time, row, column, channel), consistent with the (3, 3, 1) initial state,
# column 0 is a flanking point (input I[0]) and column 1 the centre point
# (input I[1]). NOTE(review): confirm the axis order in V1_model_2.
fig, ax = plt.subplots()
for col, color in ((0, 'C0'), (1, 'C1')):
    ax.plot(t, X[:, 1, col, 0], color=color, linestyle='-',
            label=f'Pyramidal cell (E), column {col}')
    ax.plot(t, Y[:, 1, col, 0], color=color, linestyle='--',
            label=f'Interneuron (I), column {col}')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Firing rate')
ax.set_title('Two-point EI network response')
ax.legend()
fig.tight_layout()
plt.show()