In [None]:
import numpy as np 
import matplotlib.pyplot as plt

from v1sh_model.inputs.visualize import visualize_input, visualize_output
from v1sh_model.models.V1_model_2 import V1_model_2 as V1_model

In [None]:
def grating_disc(diameter = 1, n_image = 4, I_input = 3.5):
    """Build a pair of square arrays describing a filled disc of constant value.

    Pixel coordinates are integer offsets from the pixel at index
    ``n_image // 2`` on both axes. For odd ``n_image`` that pixel is the exact
    centre; for even ``n_image`` the offsets run from ``-n_image / 2`` to
    ``n_image / 2 - 1``.

    Parameters
    ----------
    diameter : number
        Disc diameter in pixels. A pixel belongs to the disc when its
        Euclidean distance from the origin pixel is ``<= diameter / 2``.
    n_image : int
        Side length of the square arrays.
    I_input : number
        Value written to every pixel inside the disc.

    Returns
    -------
    A : ndarray, shape (n_image, n_image)
        All zeros.
    C : ndarray, shape (n_image, n_image)
        ``I_input`` inside the disc and zero everywhere else.
    """
    shape = (n_image, n_image)
    zeros_map = np.zeros(shape)
    disc_map = np.zeros(shape)

    # One formula covers both parities: integer offsets around index n_image // 2.
    offsets = np.arange(n_image) - n_image // 2

    # Row vector against column vector broadcasts to the per-pixel distance grid.
    distance = np.sqrt(offsets[np.newaxis, :] ** 2 + offsets[:, np.newaxis] ** 2)

    disc_map[distance <= diameter / 2] = I_input
    return zeros_map, disc_map

In [None]:
# Instantiate the V1 model (V1_model_2, imported as V1_model) and simulate its
# response to grating-disc inputs, plotting the input and the output per disc.
seed = 42
model = V1_model(seed=seed, alpha_x=1.0, alpha_y=1.0)
T = 2 # passed as T to model.simulate; other values left here by the author: 12, 1.6, 12.0
dt = 0.001  # passed as dt to model.simulate

# X_gen returned by each simulation, keyed by the disc diameter d.
X_gens = {}
for d in [6]: # other diameters noted by the author: 1, 2, 4, 6, 8, 20
    # The image side is always 31 pixels larger than the disc diameter.
    A, C = grating_disc(diameter=d, n_image=d+31, I_input=3.0)
    visualize_input(A, C, verbose=True)
    plt.show()
    
    # NOTE(review): Y_gen and I are assigned but not used within this cell.
    X_gen, Y_gen, I = model.simulate(
    A, C, dt=dt, T=T, verbose=False, noisy=False, mode="wrap")
    X_gens[d] = X_gen
    
    # Average g_x(X_gen) over axis 0 (presumably the time axis -- TODO confirm),
    # then reduce the last axis to its maximum value and the index of that maximum.
    model_output = model.g_x(X_gen).mean(axis=0)  # N_y x N_x x K
    C_out = model_output.max(axis=-1)  # N_y x N_x
    argmax_angle_indices = model_output.argmax(axis=-1)  # N_y x N_x
    # Scale the argmax index by pi / K to obtain the value plotted as A_out.
    A_out = np.pi / model.K * argmax_angle_indices  # N_y x N_x
    visualize_output(A_out, C_out, verbose=False)
    plt.show()