In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from v1sh_model.inputs.visualize import visualize_input, visualize_output
from v1sh_model.models.V1_model_2 import V1_model_2 as V1_model

In [None]:
def figure(n_fig = 3, I_input=3.0, n_rows = 40):
    """Build the two input grids for a vertical strip on a uniform background.

    Both grids have shape ``(n_rows, n_fig + 2 * n_rows)``: a strip of
    ``n_fig`` columns spanning every row, with ``n_rows`` background columns
    on either side of it.

    Parameters
    ----------
    n_fig : int
        Width of the central strip, in columns.
    I_input : float
        Value used to fill the second grid.
    n_rows : int
        Number of rows; also the width of each background flank.

    Returns
    -------
    tuple of numpy.ndarray
        ``(A, C)``. ``A`` holds ``pi / 2`` everywhere except the strip
        columns, which hold ``0.0`` (presumably orientation angles in
        radians -- confirm against the model's input convention). ``C`` holds
        ``I_input`` in every cell.
    """
    flank = n_rows
    grid_shape = (n_rows, flank + n_fig + flank)

    angles = np.full(grid_shape, np.pi / 2)
    # Overwrite the central strip with the second value.
    angles[:, flank:flank + n_fig] = 0.0

    strength = np.full(grid_shape, I_input)
    return angles, strength

In [None]:
# Build the V1 model (V1_model_2, imported as V1_model) with a fixed seed, and
# define the total duration T and time step dt that are passed to
# model.simulate in the simulation cell.
seed = 42
T, dt = 12.0, 0.001
model = V1_model(alpha_x=1.0, alpha_y=1.0, seed=seed)

In [None]:
# Run the model once per strip width and keep the first returned array of each
# run, keyed by n_fig, so the plotting cell can reuse it without re-simulating.
n_figs = [3, 11, 15, 35]
sim_options = dict(dt=dt, T=T, verbose=False, noisy=True, mode="wrap")

X_gens = {}
for n_fig in n_figs:
    A, C = figure(n_fig=n_fig)
    # simulate returns three values; only X_gen is stored.
    X_gen, Y_gen, I = model.simulate(A, C, **sim_options)
    X_gens[n_fig] = X_gen

In [None]:
# Size of the cropped window shown in the figures. The slices below include
# both end points, so the crop is (N_y_plot + 1) x (N_x_plot + 1) cells centred
# on the middle of the grid.
N_x_plot = 40
N_y_plot = 20

# plt.savefig does not create missing folders, so make sure the target exists
# before the loop instead of failing on the first save.
fig_dir = Path("../../docs/figure/medial_axis_effect")
fig_dir.mkdir(parents=True, exist_ok=True)

for n_fig in n_figs:
    # Plot input
    A, C = figure(n_fig=n_fig)

    # Crop bounds around the grid centre; reused for the model output below so
    # input and output figures show the same region.
    y_center, x_center = A.shape[0] // 2, A.shape[1] // 2
    y_lower = y_center - N_y_plot // 2
    y_upper = y_center + N_y_plot // 2
    x_lower = x_center - N_x_plot // 2
    x_upper = x_center + N_x_plot // 2
    A_plot = A[y_lower:y_upper + 1, x_lower:x_upper + 1]
    C_plot = C[y_lower:y_upper + 1, x_lower:x_upper + 1]

    visualize_input(A_plot, C_plot, verbose=False, dpi = 400)
    plt.savefig(fig_dir / "Fig. 5.34_input_{}.png".format(n_fig))
    plt.show()
    plt.close()

    # Plot model output
    X_gen = X_gens[n_fig]
    # Leading axis is averaged away below (presumably time -- confirm against
    # model.simulate); the next two axes are cropped like the input.
    X_gen_to_plot = X_gen[:, y_lower:y_upper + 1, x_lower:x_upper + 1]
    model_output = model.g_x(X_gen_to_plot).mean(axis=0)  # N_y x N_x x K
    C_out = model_output.max(axis=-1)  # N_y x N_x

    # Index of the strongest channel along the last axis, mapped to the angle
    # index * pi / K.
    argmax_angle_indices = model_output.argmax(axis=-1)  # N_y x N_x
    A_out = np.pi / model.K * argmax_angle_indices  # N_y x N_x

    # NOTE(review): the factor 5 only rescales the values handed to the
    # plotting helper; presumably chosen for visibility -- confirm.
    visualize_output(A_out, C_out * 5, verbose=False, dpi=400)
    plt.savefig(fig_dir / "Fig. 5.34_output_{}.png".format(n_fig))
    plt.show()
    plt.close()
