In [None]:
import numpy as np 
import matplotlib.pyplot as plt

from v1sh_model.inputs.visualize import visualize_input, visualize_output
from v1sh_model.models.V1_model_2 import V1_model_2 as V1_model

In [None]:
def figure(n_fig=3, I_input=3.0, n_rows=20):
    """Build a figure-ground texture stimulus (orientation map + input map).

    The texture has ``n_rows`` rows and ``n_fig + 2 * n_rows`` columns.
    Every bar has orientation ``pi / 2`` except a vertical strip of ``n_fig``
    columns, starting at column ``n_rows``, whose bars have orientation ``0``.
    That strip is the "figure"; the ground on each side of it is ``n_rows``
    columns wide.

    Parameters
    ----------
    n_fig : int
        Width of the figure strip, in texture columns (must be >= 0).
    I_input : float
        Value written to every entry of ``C`` (the per-bar input passed to
        ``model.simulate``).
    n_rows : int
        Number of texture rows; also the ground width on each side of the
        figure (must be >= 0). Defaults to 20, the previously hard-coded value.

    Returns
    -------
    A : ndarray, shape (n_rows, n_fig + 2 * n_rows)
        Bar orientations in radians.
    C : ndarray, same shape as ``A``
        Per-bar input, filled with ``I_input``.

    Raises
    ------
    ValueError
        If ``n_fig`` or ``n_rows`` is negative. A negative ``n_fig`` would
        otherwise silently produce an empty slice, i.e. no figure at all.
    """
    if n_fig < 0 or n_rows < 0:
        raise ValueError(
            f"n_fig and n_rows must be non-negative, got n_fig={n_fig}, n_rows={n_rows}"
        )
    n_cols = n_fig + 2 * n_rows
    A = np.full((n_rows, n_cols), np.pi / 2)
    # Figure strip sits between two ground regions of width n_rows each.
    A[:, n_rows: n_rows + n_fig] = 0.0
    C = np.full((n_rows, n_cols), I_input)
    return A, C

In [None]:
# Simulation settings (seconds) and the V1 model instance used by the cells below.
T = 12        # total simulated duration, passed to model.simulate(T=...)
dt = 0.001    # integration step, passed to model.simulate(dt=...)

seed = 42     # presumably seeds the model's noise so runs are reproducible — confirm in V1_model_2
model = V1_model(
    seed=seed,
    alpha_x=1.0,
    alpha_y=1.0,
)

In [None]:
# For several figure widths: simulate the model on the figure-ground texture and
# plot (1) the input, (2) the time-averaged output, and (3) the per-column
# response profile, cf. fig. 5.21 in "Understanding Vision" (Li Zhaoping, 2014).
X_gens = {}

# Time points (seconds) at which to plot snapshot response profiles.
# These do not depend on n_fig, so compute them once. round() is needed because
# int(t / dt) truncates float error: int(0.35 / 0.001) == 349, not 350.
time_points = np.array([0.35, 0.45])
steps = [int(round(t / dt)) for t in time_points]

for n_fig in [3, 10, 15, 28, 60]:
    A, C = figure(n_fig=n_fig)

    X_gen, Y_gen, I = model.simulate(
        A, C, dt=dt, T=T, verbose=False, noisy=True, mode="symmetric"
    )
    X_gens[n_fig] = X_gen

    # Gain-transformed responses, computed once and reused by panels 2 and 3.
    # Axis 0 is time (indexed by `steps` below); the last axis holds the K orientation channels.
    gx = model.g_x(X_gen)

    fig, axes = plt.subplots(3, 1, figsize=(6, 7.5), constrained_layout=True, dpi=400)

    # Panel 1: the input stimulus.
    visualize_input(A, C, verbose=False, axis=axes[0])
    axes[0].set_title(r"Input Image $\hat{I}_{i \theta}$ to model")

    # Panel 2: time-averaged output. At each location, show the strongest
    # orientation channel: its response is the contrast, its index gives the angle.
    model_output = gx.mean(axis=0)  # N_y x N_x x K
    C_out = model_output.max(axis=-1)  # N_y x N_x
    argmax_angle_indices = model_output.argmax(axis=-1)  # N_y x N_x
    A_out = np.pi / model.K * argmax_angle_indices  # N_y x N_x
    visualize_output(A_out, C_out, verbose=False, axis=axes[1])
    axes[1].set_title(r"Time-averaged model output $g_x(x_{i \theta})$")

    # Panel 3: per-column response over time. Take the max over orientation
    # channels, then average over rows -> (time, N_x). g_x must be applied exactly
    # once; previously it was applied a second time to already-transformed values.
    X_gen_max = gx.max(axis=-1)  # time x N_y x N_x
    X_per_column = X_gen_max.mean(axis=1)  # time x N_x
    assert max(steps) < X_per_column.shape[0], (
        f"requested time step {max(steps)} exceeds simulated length {X_per_column.shape[0]}"
    )
    x_axis = np.arange(X_per_column.shape[1])  # column indices
    for t_sec, step in zip(time_points, steps):
        response = X_per_column[step]
        # .2f: with .1f, 0.35 s and 0.45 s were mislabelled as 0.3 and 0.5.
        axes[2].plot(x_axis, response, label=f"t = {t_sec:.2f} s")

    avg_response = X_per_column.mean(axis=0)
    axes[2].plot(
        x_axis, avg_response, label="Temporal average", linewidth=3, linestyle="--", color="k"
    )

    axes[2].set_xlabel("Texture column number")
    axes[2].set_ylabel("Model response " + r"$g_x(x)$")
    axes[2].set_xlim(0, A.shape[1] - 1)
    axes[2].legend(framealpha=0.8, loc="upper left")
    axes[2].grid(True)

    plt.show()
    # Release the figure; otherwise high-dpi figures accumulate across iterations.
    plt.close(fig)