# Naive Implementation of the Model
Idea: start with a fast feedback loop in a single (messy) notebook; afterwards, organize the code into a folder structure and iterate from there.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.figure

## 1. Create and visualize some input

In [None]:
def plot_bars(A, W, l=9, r=1.3, verbose=True, dpi=500):
    """
    Plot a grid of oriented bars, one bar per cell of A.

    The bar of cell (i, j):
        - has length ``l`` and is centred in its grid cell;
        - is rotated by ``A[i, j]`` radians (0 = vertical, pi/2 = horizontal;
          increasing angles rotate the bar clockwise);
        - is drawn with line width ``W[i, j]`` (matplotlib line widths are in
          points).

    Row 0 is drawn at the bottom of the image. Cells whose width is 0 are
    skipped.

    Parameters:
        A (np.ndarray): 2D array of angles (radians), shape (N_y, N_x)
        W (np.ndarray or None): 2D array of line widths, same shape as A;
            if None, every bar is drawn with width 1
        l (float): Bar length, in data units
        r (float): Grid spacing factor; bar centres are ``l * r`` data units apart
        verbose (bool): If True, call ``plt.show()`` before returning
        dpi (int): Dots per inch of the figure

    Returns:
        fig (matplotlib.figure.Figure): The figure containing the bars

    Raises:
        AssertionError: if A is not 2D or W does not have the same shape as A
    """
    A = np.asarray(A)
    assert A.ndim == 2, "A must be a 2D array"
    if W is None:
        # No widths given: draw every bar with the same width of 1 point.
        W = np.ones(A.shape)
    W = np.asarray(W)
    assert W.shape == A.shape, "W must have the same shape as A"
    N_y, N_x = A.shape

    # Extent of the drawing in data units.
    d = l * r  # grid spacing
    img_height = int(N_y * d)
    img_width = int(N_x * d)

    # figsize is in inches, so 100 data units map to one inch; the rendered
    # image is therefore (img_width * dpi / 100) pixels wide.
    fig, ax = plt.subplots(figsize=(img_width / 100, img_height / 100), dpi=dpi)
    ax.set_xlim(0, img_width)
    ax.set_ylim(0, img_height)
    ax.set_aspect('equal')  # same scale on x and y, so bars are not distorted
    ax.axis('off')

    for i in range(N_y):
        for j in range(N_x):
            width = W[i, j]
            if width == 0:
                # Nothing visible to draw; skip instead of adding an empty artist.
                continue
            # centre of the bar
            cx = (j + 0.5) * d
            cy = (i + 0.5) * d
            # half-extent of the bar along x and y (angle measured from vertical)
            angle = A[i, j]
            dx = l * np.sin(angle) / 2
            dy = l * np.cos(angle) / 2
            ax.plot([cx - dx, cx + dx], [cy - dy, cy + dy],
                    color="k",
                    linewidth=width,
                    solid_capstyle='butt',  # bar ends exactly at its endpoints
            )

    if verbose:
        plt.show()

    return fig

def visualize_input(A, C, l=9, r=1.3, verbose=True, dpi=500):
    """
    Visualize input bars with orientations A and contrasts C as a grid of bars.

    Each bar's line width is its contrast divided by 3. A contrast of 4 is
    drawn 4/3 points wide. A contrast of 0 gives width 0, i.e. no bar.

    Parameters:
        A (np.ndarray): 2D array of angles (radians), shape (N_y, N_x)
        C (np.ndarray): 2D array of contrasts, same shape as A; every value must
            be 0 (no bar) or lie in [1, 4]
        l (float): Bar length
        r (float): Grid spacing factor
        verbose (bool): If True, show the plot
        dpi (int): Dots per inch for rendering

    Returns:
        fig (matplotlib.figure.Figure): The matplotlib figure object

    Raises:
        AssertionError: if any contrast is neither 0 nor in [1, 4]
    """
    C = np.asarray(C)
    valid = ((C >= 1) & (C <= 4)) | (C == 0)
    assert np.all(valid), "C values must be 0 or in [1, 4]"
    W = C / 3  # contrast -> line width (points)
    return plot_bars(A, W, l=l, r=r, verbose=verbose, dpi=dpi)

def visualize_output(A, S, l=9, r=1.3, verbose=True, dpi=500):
    """
    Visualize a non-negative 2D map S as a grid of bars.

    The bar of cell (i, j) has orientation ``A[i, j]``. Its line width is
    ``S[i, j]``, used without any rescaling.

    Parameters:
        A (np.ndarray): 2D array of bar angles (radians), shape (N_y, N_x)
        S (np.ndarray): 2D array of non-negative values (e.g. saliency or model
            input of one channel), shape (N_y, N_x)
        l (float): Bar length
        r (float): Grid spacing factor
        verbose (bool): If True, show the plot
        dpi (int): Dots per inch for rendering

    Returns:
        fig (matplotlib.figure.Figure): The matplotlib figure object

    Raises:
        AssertionError: if any value of S is negative
    """
    # TODO: how to scale and normalize when reading out S?
    S = np.asarray(S)
    assert np.all(S >= 0), "S values must be non-negative"
    return plot_bars(A, S, l=l, r=r, verbose=verbose, dpi=dpi)

In [None]:
def _check_input_figure(A, C, verbose):
    """Render (A, C) with visualize_input, check a Figure is returned, then close it.

    Closing removes the figure from pyplot's registry, so running the tests
    with verbose=False does not accumulate open figures.
    """
    fig = visualize_input(A, C, verbose=verbose)
    assert isinstance(fig, matplotlib.figure.Figure)
    plt.close(fig)


def test_bar_without_surround(verbose=False):
    """Single high-contrast vertical bar in the centre of an empty 9x9 grid."""
    C = np.zeros((9, 9))
    C[4, 4] = 3.5
    A = np.zeros((9, 9))
    _check_input_figure(A, C, verbose)


def test_iso_orientation(verbose=False):
    """9x9 grid of identical vertical bars."""
    C = np.full((9, 9), 3.5)
    A = np.zeros((9, 9))
    _check_input_figure(A, C, verbose)


def test_random_background(verbose=False, seed=42):
    """Vertical centre bar among randomly oriented bars (reproducible via seed)."""
    C = np.full((9, 9), 3.5)
    rng = np.random.default_rng(seed)
    A = rng.uniform(0, np.pi, (9, 9))
    A[4, 4] = 0.
    _check_input_figure(A, C, verbose)


def test_cross_orientation(verbose=False):
    """Vertical centre bar surrounded by horizontal bars."""
    C = np.full((9, 9), 3.5)
    A = np.full((9, 9), np.pi / 2)
    A[4, 4] = 0
    _check_input_figure(A, C, verbose)


def test_bar_without_surround_low_contrast(verbose=False):
    """Single low-contrast vertical bar in an empty grid."""
    C = np.zeros((9, 9))
    C[4, 4] = 1.05
    A = np.zeros((9, 9))
    _check_input_figure(A, C, verbose)


def test_with_one_flanker(verbose=False):
    """Low-contrast centre bar with one collinear high-contrast flanker."""
    C = np.zeros((9, 9))
    C[4, 4] = 1.05
    C[5, 4] = 3.5
    A = np.zeros((9, 9))
    _check_input_figure(A, C, verbose)


def test_with_two_flankers(verbose=False):
    """Low-contrast centre bar with collinear high-contrast flankers on both ends."""
    C = np.zeros((9, 9))
    C[4, 4] = 1.5
    C[3, 4] = 3.5
    C[5, 4] = 3.5
    A = np.zeros((9, 9))
    _check_input_figure(A, C, verbose)


def test_with_flanking_line_and_noise(verbose=False):
    """Vertical line through a low-contrast centre bar, embedded in random orientations."""
    rng = np.random.default_rng(42)
    A = rng.uniform(0, np.pi, (9, 9))
    A[:, 4] = 0.
    C = np.full((9, 9), 3.5)
    C[4, 4] = 1.5
    _check_input_figure(A, C, verbose)


def test_neighboring_textures(verbose=False):
    """Two adjacent textures: horizontal bars in the left 14 columns, vertical on the right."""
    A = np.zeros((11, 27))
    A[:, :14] = np.pi / 2
    C = np.full((11, 27), 2.0)
    _check_input_figure(A, C, verbose)


# run tests
test_bar_without_surround(verbose=True)
test_iso_orientation(verbose=True)
test_random_background(verbose=True)
test_cross_orientation(verbose=True)

test_bar_without_surround_low_contrast(verbose=True)
test_with_one_flanker(verbose=True)
test_with_two_flankers(verbose=True)
test_with_flanking_line_and_noise(verbose=True)

test_neighboring_textures(verbose=True)

## 2. Implement the Naive Model

In [None]:
def tuning_curve(angle: np.ndarray) -> np.ndarray:
    """Orientation tuning curve phi of an angle difference.

    Orientations are pi-periodic. The angle is therefore first folded to its
    distance d from the nearest multiple of pi, so that d lies in [0, pi/2].
    Then

        phi = exp(-d / (pi / 8))   if d <  pi / 6
        phi = 0                    if d >= pi / 6

    Parameters:
        angle (np.ndarray or float): angle difference(s) in radians. Any shape is
            accepted, and values may lie outside [-pi/2, pi/2] (e.g. the
            difference of two angles in [0, pi)).

    Returns:
        (np.ndarray): tuning curve values with the same shape as ``angle``. Each
            value is 0 or lies in (exp(-4/3), 1]. A scalar input yields a 0-d
            array. The input is not modified.
    """
    angle = np.asarray(angle)
    if not np.issubdtype(angle.dtype, np.floating):
        # Promote integer/bool input so the folding below is not truncated.
        angle = angle.astype(float)
    # Distance to the nearest multiple of pi (handles |angle| > pi as well).
    distance = np.abs(angle) % np.pi
    distance = np.minimum(distance, np.pi - distance)
    phi = np.exp(-distance / (np.pi / 8))
    return np.where(distance >= np.pi / 6, 0.0, phi)

In [None]:
# Sample the tuning curve on a grid that is symmetric around 0 and reaches
# slightly past +/- pi/2, then check that the sampled curve is an even function.
k = 10000
angles_1 = np.linspace(-np.pi / 2 - np.pi / k, 0, k)  # negative half, ends at 0
angles_2 = -angles_1[-2::-1]  # mirrored positive half, 0 not repeated
angles = np.concatenate((angles_1, angles_2))
tc_values = tuning_curve(angles)
print(np.allclose(tc_values, tc_values[::-1], atol=1e-6))  # Should be True for perfect symmetry

# Plot phi over the sampled range.
plt.rc('font', size=17)
plt.figure(figsize=(6, 4), dpi=1000, constrained_layout=True)
plt.plot(angles, tc_values, linewidth=4)
plt.title(r"Tuning Curve $\phi(x)$")
plt.xlabel(r"Angle $x$ (radians)")
plt.ylabel(r"$\phi(x)$")
plt.grid(True)
plt.show()

In [None]:
from typing import Optional, Tuple

def get_model_input(
    A: np.ndarray,
    C: np.ndarray,
    M: Optional[np.ndarray] = None,
    K: int = 12,
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the model input from a visual input of oriented bars.

    For every location and orientation channel, the input is the bar contrast
    weighted by the tuning curve of the angle between the bar and the channel's
    preferred orientation:

        I[y, x, k] = C[y, x] * tuning_curve(A[y, x] - M[y, x, k])

    TODO: extend to multiple input bars per location (i.e. A and C of shape
    (N_y, N_x, L) where L is the number of input bars per location).

    Parameters:
        A (np.ndarray): 2D array of input bar angles (radians), shape (N_y, N_x);
            reduced modulo pi.
        C (np.ndarray): 2D array of input bar contrasts, same shape as A; values
            in [1, 4], or 0 for no bar.
        M (np.ndarray or None): preferred orientations (radians) of the model
            neurons, shape (N_y, N_x, K); reduced modulo pi. If None, the K
            orientations 0, pi/K, ..., (K-1)*pi/K are used at every location.
        K (int): number of orientation channels; only used when M is None.

    Returns:
        I (np.ndarray): model input, shape (N_y, N_x, K).
        M (np.ndarray): the preferred orientations that were used, reduced
            modulo pi.
    """
    A = np.asarray(A)
    C = np.asarray(C)

    if M is None:
        N_y, N_x = A.shape
        channel_angles = np.linspace(0, np.pi, K, endpoint=False)
        M = np.broadcast_to(channel_angles, (N_y, N_x, K))

    # Orientations are pi-periodic.
    M = np.asarray(M) % np.pi
    A = A % np.pi

    # Add a trailing channel axis so (N_y, N_x) inputs broadcast against M.
    model_input = C[:, :, np.newaxis] * tuning_curve(A[:, :, np.newaxis] - M)
    return model_input, M

In [None]:
# Random bar orientations with a uniform contrast of 2.5 as a quick smoke test.
N_y, N_x = 3, 9
rng = np.random.default_rng(42)
A_test = rng.uniform(0, np.pi, (N_y, N_x))
C_test = np.full((N_y, N_x), 2.5)  # alternative: random contrasts via rng.uniform(1, 4, (N_y, N_x))

I, M = get_model_input(A_test, C_test)
print("Input shape:", I.shape)
print("Input min/max:", I.min(), I.max())

# One figure per orientation channel: bars are drawn at the channel's preferred
# orientation, with widths equal to that channel's input.
for k in range(I.shape[-1]):
    print(M[0, 0, k] / np.pi * 180)  # preferred orientation of channel k, in degrees
    visualize_output(M[..., k], I[..., k])