In [None]:
import numpy as np
import matplotlib.pyplot as plt
from v1sh_model.models.V1_model_2 import V1_model_2 as base_model

# 1. Two-point EI network

In [None]:
# Parameters follow Fig. 5.58 and eqs. 5.66-5.67 of "Understanding Vision" (Li Zhaoping, 2014).

# Rate constants handed to the model as alpha_x / alpha_y
# (presumably the decay rates of the x and y units -- confirm against V1_model_2).
alpha_x = alpha_y = 1.0

# Number of channels passed to the model; a single one here
# (consistent with the (3, 3, 1, 1) kernels built below).
K = 1

# Connection strengths handed to the model as J_o / W_o
# (presumably the local/self couplings -- confirm against V1_model_2).
J_o, W_o = 2.1, 1.13

def compute_connection_kernel(K, verbose=False, *, J_dash=0.4, W_dash=0.9):
    """Build the lateral connection kernels (J, W, Psi) for the two-point EI network.

    Each of ``J_dash`` and ``W_dash`` is split evenly between the two off-centre
    taps ``[0, 0]`` and ``[0, 2]`` of a 3x3 kernel; every other entry is zero.
    ``Psi`` is returned all zeros.

    Parameters
    ----------
    K : int
        Accepted for compatibility with the kernel-factory interface expected by
        the model. It is not used: the kernels are always shaped ``(3, 3, 1, 1)``.
    verbose : bool, optional
        Accepted for interface compatibility; unused.
    J_dash : float, keyword-only, default 0.4
        Total weight placed in ``J_kernel`` (half on each off-centre tap).
    W_dash : float, keyword-only, default 0.9
        Total weight placed in ``W_kernel`` (half on each off-centre tap).

    Returns
    -------
    tuple of np.ndarray
        ``(J_kernel, W_kernel, Psi_kernel)``, each of shape ``(3, 3, 1, 1)``.

    NOTE(review): on the 3-column grid built by ``column_input``, if the model applies
    these kernels with wrap-around boundaries, the two outer columns are neighbours of
    each other as well as of the centre column, so the outer units also couple to one
    another -- confirm this reproduces the intended two-point network.
    """
    J_kernel = np.zeros((3, 3, 1, 1))
    W_kernel = np.zeros((3, 3, 1, 1))
    # Split each total coupling evenly over the two off-centre taps in row 0.
    for kernel, total in ((J_kernel, J_dash), (W_kernel, W_dash)):
        kernel[0, 0, 0, 0] = total / 2
        kernel[0, 2, 0, 0] = total / 2

    Psi_kernel = np.zeros((3, 3, 1, 1))

    return (J_kernel, W_kernel, Psi_kernel)

# Normalization terms (defined with ``def`` rather than assigned lambdas, per PEP 8).
def I_o(X):
    """Return an all-zero array shaped like ``X`` (this normalization term is switched off)."""
    return np.zeros_like(X)


def I_c(I_topdown):
    """Return ``I_topdown`` unchanged (the top-down input is used as-is)."""
    return I_topdown

# Activation functions
def g_x(x, x_th=1.):
    """Threshold-linear activation for the pyramidal (x) cells.

    Computes ``max(0, x - x_th)`` element-wise, with no upper saturation.
    """
    shifted = x - x_th
    return np.maximum(0, shifted)

# Visual check of g_x over a range straddling the threshold.
x = np.linspace(-1, 3, 1000)
fig_gx, ax_gx = plt.subplots()
ax_gx.plot(x, g_x(x))
ax_gx.set_xlabel('x')
ax_gx.set_ylabel(r'$g_x(x)$')
ax_gx.set_title(r'Activation function $g_x$ for pyramidal cells')
fig_gx.tight_layout()
plt.show()

def g_y(y):
    """Linear activation for the y cells: hands back its input object unchanged."""
    activated = y
    return activated

# Tuning curves
def tuning_curve(angle: np.ndarray):
    """Delta-function tuning: 1.0 where ``angle`` equals 0, 0.0 everywhere else."""
    at_preferred = angle == 0.0
    return at_preferred.astype(float)

In [None]:
# Build the network model from the parameters, kernels and activation functions above.
EI_pair = base_model(
    K=K,
    compute_connection_kernel=compute_connection_kernel,
    tuning_curve=tuning_curve,
    J_o=J_o, W_o=W_o,
    g_x=g_x, g_y=g_y,
    alpha_x=alpha_x, alpha_y=alpha_y,
    I_o=I_o, I_c=I_c,
)

In [None]:
def column_input(I=(1., 1.)):
    """Build the ``(A, C)`` input pair on a 3x3 grid.

    ``A`` is all zeros. In ``C``, every row is identical: columns 0 and 2 receive
    ``I[0]`` and the centre column 1 receives ``I[1]``.
    """
    C = np.zeros((3, 3))
    for column, value in ((0, I[0]), (1, I[1]), (2, I[0])):
        C[:, column] = value
    return np.zeros((3, 3)), C

# Non-uniform drive: the centre column gets 3, the outer columns get 0.
# The network starts from rest.
A, C = column_input((0., 3.))
T, dt = 200.0, 0.001
X_0 = np.zeros((3, 3, 1))
Y_0 = np.zeros((3, 3, 1))
X, Y, I = EI_pair.simulate(
    A, C, dt, T,
    noisy=False,
    initial_condition=(X_0, Y_0),
    mode='wrap',
)

In [None]:
def plot_x1_y1(x_1, y_1, xlim=None, ylim=None, ax=None, figsize=(4, 3), dpi=200, threshold=1.):
    """
    Plot the trajectory of y_1 against x_1 on a Matplotlib axis.

    x_1, y_1   : sequences plotted against each other (x_1 on the horizontal axis)
    xlim, ylim : axis limits applied to ``ax``; left unchanged when None
    ax         : axis to draw on; if None, a new figure/axis is created with
                 the given ``figsize`` and ``dpi``
    threshold  : x-position of the dashed red reference line (default 1.0,
                 the same as the default ``x_th`` of ``g_x``)
    Returns the axis used.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=dpi, constrained_layout=True)
    ax.axvline(x=threshold, color='red', linestyle='--', linewidth=2)
    ax.plot(x_1, y_1, linewidth=0.5)
    # Set limits on *this* axis. plt.xlim/plt.ylim would target the current axes,
    # which is a different subplot when ``ax`` comes from a subplot grid.
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xlabel(r'$x_1$')
    ax.set_ylabel(r'$y_1$')
    ax.set_title('$x_1(t)$ versus $y_1(t)$')
    return ax

# Centre-column unit of the non-uniform run.
# Assumes X/Y are indexed (time, row, column, channel) -- confirm against simulate().
x_1, y_1 = X[:, 0, 1, 0], Y[:, 0, 1, 0]

plot_x1_y1(x_1, y_1, ylim=[0, 300], xlim=[-100, 300])
plt.show()

In [None]:
# Outer-column unit (column 2, driven by I[0]) of the non-uniform run.
x_2 = X[:, 0, 2, 0]
# NOTE(review): these two arrays are not used by the plot below, which recomputes them.
gx1_plus_gx2, gx1_minus_gx2 = (
    g_x(x_1) + g_x(x_2),
    g_x(x_1) - g_x(x_2),
)

def plot_gx_sum_diff(x1, x2, ax=None, xlim=None, ylim=None, dpi=200, dt=None, legend=False):
    """
    Plot g_x(x1) + g_x(x2) and g_x(x1) - g_x(x2) against time.

    x1, x2     : 1D array-likes; both are truncated to the shorter length
    ax         : optional matplotlib axis to plot on; if None a new figure/axis is created
    xlim, ylim : axis limits (left unchanged when None)
    dpi        : resolution of the new figure (only used when ``ax`` is None)
    dt         : time step between samples; if None, uses the notebook-global
                 ``dt`` if defined, else 1.0
    legend     : if True, add a legend identifying the two curves
    Returns the axis used.
    """
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)

    # Make sure both arrays have the same length.
    n = min(x1.size, x2.size)
    x1 = x1[:n]
    x2 = x2[:n]

    if dt is None:
        dt = globals().get("dt", 1.0)

    gx1 = g_x(x1)
    gx2 = g_x(x2)
    gx_sum = gx1 + gx2
    gx_diff = gx1 - gx2

    # Build the time axis from the sample index so it has exactly n points.
    # np.arange(0, n * dt, dt) can yield n +/- 1 points because of float rounding.
    t = np.arange(n) * dt

    if ax is None:
        _, ax = plt.subplots(figsize=(4, 3), dpi=dpi, constrained_layout=True)

    ax.plot(t, gx_sum, linewidth=1, label=r'$g_x(x_1) + g_x(x_2)$')
    ax.plot(t, gx_diff, linewidth=1, label=r'$g_x(x_1) - g_x(x_2)$', color="red", linestyle='--')
    if legend:
        ax.legend(loc="upper right", framealpha=0.98)
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xlabel("Time")
    ax.set_ylabel(r'$g_x(x_1) + / - g_x(x_2)$')
    ax.set_title(r'$g_x(x_1) \pm g_x(x_2)$')

    return ax

# Sum and difference of the two units' outputs over the whole non-uniform run.
plot_gx_sum_diff(x_1, x_2, ylim=(0, 250), xlim=(0, 200))
plt.show()

In [None]:
# Uniform drive I = (3, 3) to every column. The centre column starts from lower
# x/y values than the outer columns.
A_uni, C_uni = column_input((3., 3.))
T, dt = 60.0, 0.001

X_0 = np.empty((3, 3, 1))
X_0[:, [0, 2], :] = 4.3
X_0[:, 1, :] = 2.5

Y_0 = np.empty((3, 3, 1))
Y_0[:, [0, 2], :] = 3.3
Y_0[:, 1, :] = 1.5

X_uni, Y_uni, I_uni = EI_pair.simulate(
    A_uni, C_uni, dt, T,
    noisy=False,
    initial_condition=(X_0, Y_0),
    mode='wrap',
)

In [None]:
# Centre-column unit of the uniform-input run.
x_1_uni, y_1_uni = X_uni[:, 0, 1, 0], Y_uni[:, 0, 1, 0]

ax = plot_x1_y1(x_1_uni, y_1_uni, xlim=[-10, 20], ylim=[0, 40])
# plot_x1_y1 created a fresh figure, so ax is also the current axes.
ax.set_xticks([-10, 0, 10, 20])
ax.set_yticks([0, 10, 20, 30, 40])

plt.show()

In [None]:
# Outer-column unit (column 0) of the uniform-input run.
x_2_uni = X_uni[:, 0, 0, 0]

plot_gx_sum_diff(x_1_uni, x_2_uni, ylim=(-10, 40), xlim=(0, 60))
plt.show()

In [None]:
# 2x2 summary: left column = uniform input, right column = non-uniform input;
# top row = x_1/y_1 phase plane, bottom row = g_x(x_1) +/- g_x(x_2) over time.
fig, axs = plt.subplots(2, 2, figsize=(10, 8), dpi=200)
ax_a, ax_b, ax_c, ax_d = axs.flat

phase_panels = (
    (ax_a, x_1_uni, y_1_uni, [-10, 20], [0, 40],
     [-10, 0, 10, 20], [0, 10, 20, 30, 40], "A: $x_1$ versus $y_1$"),
    (ax_b, x_1, y_1, [-100, 300], [0, 300],
     [-100, 0, 100, 200, 300], [0, 100, 200, 300], "B: $x_1$ versus $y_1$"),
)
for panel_ax, xs, ys, x_range, y_range, x_ticks, y_ticks, title in phase_panels:
    plot_x1_y1(xs, ys, xlim=x_range, ylim=y_range, ax=panel_ax)
    panel_ax.set_title(title)
    panel_ax.set_xticks(x_ticks)
    panel_ax.set_yticks(y_ticks)
    panel_ax.set_xlim(x_range)
    panel_ax.set_ylim(y_range)

sum_diff_panels = (
    (ax_c, x_1_uni, x_2_uni, (0, 60), (-10, 40), r"C: $g_x(x_1) \pm g_x(x_2)$"),
    (ax_d, x_1, x_2, (0, 200), (0, 250), r"D: $g_x(x_1) \pm g_x(x_2)$"),
)
for panel_ax, first, second, x_range, y_range, title in sum_diff_panels:
    plot_gx_sum_diff(first, second, xlim=x_range, ylim=y_range, ax=panel_ax)
    panel_ax.set_title(title)

# Leave headroom above the grid for the two column headings.
fig.subplots_adjust(top=0.9, wspace=0.2, hspace=0.3)
fig.text(0.3, 0.95, r"Uniform inputs $I = (3, 3)$", ha='center', va='center', fontsize=14)
fig.text(0.73, 0.95, r"Non-uniform inputs $I = (3, 0)$", ha='center', va='center', fontsize=14)

plt.show()