# Naive implementation of Model
Idea: start with a fast feedback loop in one (messy) notebook; afterwards, organize the code into a folder structure and iterate from there.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.ndimage import convolve
import time 

## 1. Create and visualize some input

In [None]:
def plot_bars(A, W, l=9, r=1.3, verbose=True, dpi=500, axis=None):
    """
    Plots a grid of bars with given angles and linewidths.

    Bar (i, j) is centred at ((j + 0.5) * d, (i + 0.5) * d) with d = l * r, so
    the row index i grows upwards. An angle of 0 draws a vertical bar and an
    angle of pi / 2 a horizontal one.

    Parameters:
        A (np.ndarray): 2D array of angles (radians), shape (N_y, N_x)
        W (np.ndarray or None): 2D array of linewidths, same shape as A;
            if None, every bar is drawn with linewidth 1
        l (float): Bar length
        r (float): Grid spacing factor
        verbose (bool): If True, show the plot
        dpi (int): Dots per inch for rendering
        axis (matplotlib axis or None): If provided, plot on this axis instead of creating a new figure.

    Returns:
        fig (matplotlib.figure.Figure or None): The new figure, or None when
            drawing onto a provided ``axis``

    Raises:
        AssertionError: if A is not 2D or W does not match the shape of A
    """
    assert A.ndim == 2, "A must be a 2D array"
    if W is None:
        W = np.ones_like(A, dtype=float)
    assert W.shape == A.shape, "W must have the same shape as A"
    N_y, N_x = A.shape

    # Calculate image size in pixels
    d = l * r  # grid spacing
    img_height = int(N_y * d)
    img_width = int(N_x * d)

    # Create figure
    if axis is not None:
        ax = axis
        fig = None
    else:
        fig, ax = plt.subplots(figsize=(img_width / 100, img_height / 100), dpi=dpi)
    ax.set_xlim(0, img_width)
    ax.set_ylim(0, img_height)
    ax.set_aspect("equal")  # keep x and y scales the same, avoiding distortion
    ax.axis("off")

    # Draw bars
    for i in range(N_y):
        for j in range(N_x):
            # compute center of the bar
            cx = (j + 0.5) * d
            cy = (i + 0.5) * d
            # half-extent of the bar along x and y
            angle = A[i, j]
            dx = l * np.sin(angle) / 2
            dy = l * np.cos(angle) / 2
            # draw the bar between its two endpoints
            ax.plot(
                [cx - dx, cx + dx],
                [cy - dy, cy + dy],
                color="k",
                linewidth=W[i, j],
                solid_capstyle="butt",
            )

    if verbose:
        plt.show()

    return fig


def visualize_input(A, C, l=9, r=1.3, verbose=True, dpi=500, axis=None):
    """
    Visualizes the input angles A and contrasts C as a grid of bars.

    Contrast is mapped linearly to line width (width = C / 3), so locations
    with contrast 0 are effectively not drawn.

    Parameters:
        A (np.ndarray): 2D array of angles (radians), shape (N_y, N_x)
        C (np.ndarray): 2D array of contrasts, same shape as A, values 0 (no bar) or in [1, 4]
        l (float): Bar length
        r (float): Grid spacing factor
        verbose (bool): If True, show the plot
        dpi (int): Dots per inch for rendering
        axis (matplotlib axis or None): If provided, plot on this axis instead of creating a new figure.

    Returns:
        fig (matplotlib.figure.Figure or None): The matplotlib figure object
            (None when drawing onto a provided ``axis``)

    Raises:
        AssertionError: if any contrast is neither 0 nor in [1, 4]
    """
    is_valid = (C == 0) | ((C >= 1) & (C <= 4))
    assert np.all(is_valid), "C values must be 0 or in [1, 4]"
    W = C / 3  # contrast -> line width
    return plot_bars(A, W, l=l, r=r, verbose=verbose, dpi=dpi, axis=axis)


def visualize_output(A, S, l=9, r=1.3, verbose=True, dpi=500, axis=None):
    """
    Draws the model output as bars whose line width is the saliency.

    Each bar keeps the orientation given in A; only its thickness encodes S.

    Parameters:
        A (np.ndarray): 2D array of angles (radians), shape (N_y, N_x)
        S (np.ndarray): 2D array of non-negative saliency values, shape (N_y, N_x)
        l (float): Bar length
        r (float): Grid spacing factor
        verbose (bool): If True, show the plot
        dpi (int): Dots per inch for rendering
        axis (matplotlib axis or None): If provided, plot on this axis instead of creating a new figure.

    Returns:
        fig (matplotlib.figure.Figure or None): the figure created by plot_bars
            (None when drawing onto a provided ``axis``)
    """
    # TODO: how to scale and normalize when reading out S?
    assert (S >= 0).all(), "S values must be non-negative"
    plot_options = dict(l=l, r=r, verbose=verbose, dpi=dpi, axis=axis)
    return plot_bars(A, S, **plot_options)

In [None]:
def bar_without_surround(N_y=9, N_x=9):
    """Single bar (contrast 3.5, angle 0) at the grid centre, no surround.

    Returns:
        A (np.ndarray): angles, all 0, shape (N_y, N_x)
        C (np.ndarray): contrasts, 3.5 at the centre and 0 elsewhere
    """
    A = np.zeros((N_y, N_x))
    C = np.zeros_like(A)
    C[(N_y - 1) // 2, (N_x - 1) // 2] = 3.5
    return A, C


def iso_orientation(N_y=9, N_x=9):
    """Full grid of identical bars: angle 0 and contrast 3.5 everywhere.

    Returns:
        A (np.ndarray): angles, all 0, shape (N_y, N_x)
        C (np.ndarray): contrasts, all 3.5, shape (N_y, N_x)
    """
    grid_shape = (N_y, N_x)
    return np.zeros(grid_shape), np.full(grid_shape, 3.5)


def random_background(N_y=9, N_x=9, seed=None):
    """Centre bar at angle 0 embedded in bars of uniformly random orientation.

    Parameters:
        N_y, N_x (int): grid size
        seed: seed for np.random.default_rng (None -> non-deterministic)

    Returns:
        A (np.ndarray): angles drawn from U[0, pi), centre set to 0, shape (N_y, N_x)
        C (np.ndarray): contrasts, all 3.5, shape (N_y, N_x)
    """
    rng = np.random.default_rng(seed)
    A = rng.uniform(0, np.pi, size=(N_y, N_x))
    A[(N_y - 1) // 2, (N_x - 1) // 2] = 0.0
    C = np.full(A.shape, 3.5)
    return A, C


def cross_orientation(N_y=9, N_x=9):
    """Centre bar at angle 0 surrounded by bars at angle pi / 2, all contrast 3.5.

    Returns:
        A (np.ndarray): angles, pi / 2 except 0 at the centre, shape (N_y, N_x)
        C (np.ndarray): contrasts, all 3.5, shape (N_y, N_x)
    """
    C = np.full((N_y, N_x), 3.5)
    A = np.full_like(C, np.pi / 2)
    A[(N_y - 1) // 2, (N_x - 1) // 2] = 0
    return A, C


def bar_without_surround_low_contrast(N_y=9, N_x=9):
    """Single low-contrast bar (contrast 1.05, angle 0) at the grid centre.

    Returns:
        A (np.ndarray): angles, all 0, shape (N_y, N_x)
        C (np.ndarray): contrasts, 1.05 at the centre and 0 elsewhere
    """
    A = np.zeros((N_y, N_x))
    C = np.zeros_like(A)
    C[(N_y - 1) // 2, (N_x - 1) // 2] = 1.05
    return A, C


def with_one_flanker(N_y=9, N_x=9):
    """Low-contrast centre bar (1.05) with one high-contrast (3.5) bar in the next row.

    Returns:
        A (np.ndarray): angles, all 0, shape (N_y, N_x)
        C (np.ndarray): contrasts, non-zero only at the centre and the row above it
            (row index + 1)
    """
    y_mid, x_mid = (N_y - 1) // 2, (N_x - 1) // 2
    C = np.zeros((N_y, N_x))
    C[y_mid, x_mid] = 1.05
    C[y_mid + 1, x_mid] = 3.5
    A = np.zeros_like(C)
    return A, C


def with_two_flankers(N_y=9, N_x=9):
    """Centre bar (contrast 1.5) flanked above and below by bars of contrast 3.5.

    Returns:
        A (np.ndarray): angles, all 0, shape (N_y, N_x)
        C (np.ndarray): contrasts, non-zero only in the centre column at rows
            y_mid - 1, y_mid, y_mid + 1
    """
    y_mid = (N_y - 1) // 2
    x_mid = (N_x - 1) // 2
    C = np.zeros((N_y, N_x))
    C[y_mid, x_mid] = 1.5
    for row in (y_mid - 1, y_mid + 1):
        C[row, x_mid] = 3.5
    return np.zeros((N_y, N_x)), C


def with_flanking_line_and_noise(N_y=9, N_x=9, seed=None):
    """Vertical line of angle-0 bars through the centre column in a random-orientation background.

    The centre bar has contrast 1.5; every other bar has contrast 3.5.

    Parameters:
        N_y, N_x (int): grid size
        seed: seed for np.random.default_rng (None -> non-deterministic)

    Returns:
        A (np.ndarray): angles, U[0, pi) except 0 in the centre column, shape (N_y, N_x)
        C (np.ndarray): contrasts, shape (N_y, N_x)
    """
    rng = np.random.default_rng(seed)
    A = rng.uniform(0, np.pi, (N_y, N_x))
    y_mid = (N_y - 1) // 2
    x_mid = (N_x - 1) // 2
    A[:, x_mid] = 0.0
    C = np.full_like(A, 3.5)
    C[y_mid, x_mid] = 1.5
    return A, C


def neighboring_textures(n_rows=11, n_cols=27):
    """Two adjacent textures: left columns at angle pi / 2, right columns at angle 0.

    Returns:
        A (np.ndarray): angles, pi / 2 for the first n_cols // 2 columns, else 0
        C (np.ndarray): contrasts, all 2.0, shape (n_rows, n_cols)
    """
    grid_shape = (n_rows, n_cols)
    A = np.zeros(grid_shape)
    A[:, : n_cols // 2] = np.pi / 2
    return A, np.full(grid_shape, 2.0)

In [None]:
# Render all stimulus test cases side by side (2 rows x 4 columns).
dpi = 400
N_y, N_x = 9, 9
l = 9
r = 1.3
d = l * r  # grid spacing
img_height = int(N_y * d)
img_width = int(N_x * d)

plt.rcParams.update({"font.size": 6})
fig, axes = plt.subplots(
    2,
    4,
    figsize=(img_width / 100 * 4, img_height / 100 * 2),
    dpi=dpi,
    constrained_layout=True,
)

axes = axes.flatten()

# Panel labels run A-H in plotting order.
test_cases = [
    ("A: Bar without\nsurround", bar_without_surround),
    ("B: Iso-\norientation", iso_orientation),
    ("C: Random\nbackground", random_background),
    ("D: Cross-\norientation", cross_orientation),
    ("E: Bar without\nsurround", bar_without_surround_low_contrast),
    ("F: With one\nflanker", with_one_flanker),
    ("G: With two\nflankers", with_two_flankers),
    ("H: With flanking\nline and noise", with_flanking_line_and_noise),
]

for ax, (title, func) in zip(axes, test_cases):
    A, C = func()
    visualize_input(A, C, verbose=False, axis=ax)
    ax.set_title(title)

plt.show()

## 2. Implement Tuning Curve

In [None]:
def tuning_curve(angle: np.ndarray) -> np.ndarray:
    """Orientation tuning curve.

    phi(x) = exp(-|x| / (pi / 8)) for |x| < pi / 6 and 0 otherwise, where |x| is
    the orientation difference folded into [0, pi / 2]. Orientations are
    pi-periodic, so any real angle difference is accepted.

    Parameters:
        angle (np.ndarray or float): angle difference (radians), any shape

    Returns:
        (np.ndarray): tuning curve values, same shape as angle, values in [0, 1]
    """
    # Fold |angle| into [0, pi) first so that the second fold lands in
    # [0, pi/2] for every input (without it, |angle| > pi gave negative
    # distances and phi > 1).
    folded = np.abs(angle) % np.pi
    folded = np.minimum(folded, np.pi - folded)  # wrap to [0, pi/2]
    phi = np.exp(-folded / (np.pi / 8))
    # np.where instead of in-place masking also works for scalar input.
    return np.where(folded >= np.pi / 6, 0.0, phi)

In [None]:
# Sample the tuning curve densely over roughly [-90°, 90°], check that it is
# symmetric around 0 and plot it.
# NOTE: `k_post` is used here only as the number of samples per half-axis.
k_post = 10000
# Negative half-axis: from pi/k_post rad below -pi/2 up to 0 (inclusive).
angles_1 = np.linspace(-np.pi / 2 - np.pi / k_post, 0, k_post)
# Mirror the negative half onto the positive axis, dropping the duplicated 0.
angles_2 = -angles_1[::-1][1:]
angles = np.concatenate([angles_1, angles_2])
tc_values = tuning_curve(angles)
print(
    np.allclose(tc_values, tc_values[::-1], atol=1e-6)
)  # Should be True for perfect symmetry

plt.rcParams.update({"font.size": 17})
plt.figure(figsize=(6, 4), dpi=1000, constrained_layout=True)
plt.plot(angles * 180 / np.pi, tc_values, linewidth=4)
plt.xticks(np.arange(-90, 91, 30))
plt.xlabel(r"Angle $x$ [°]")
plt.ylabel(r"$\phi(x)$")
plt.title(r"Tuning Curve $\phi(x)$")
plt.grid(True)
plt.show()

In [None]:
from typing import Optional


def get_model_input(
    A: np.ndarray, C: np.ndarray, M: Optional[np.ndarray] = None, K=12
) -> "tuple[np.ndarray, np.ndarray]":
    """Computes model input from visual input.

    TODO: extend to multiple input bars per location (i.e. A and C of shape (N_y, N_x, L) where L is number of input bars per location)

    Parameters:
        A (np.ndarray): 2D array of angles (radians) of input bars, shape (N_y, N_x); wrapped into [0, pi) internally
        C (np.ndarray): 2D array of contrasts of input bars, same shape as A, values in [1, 4] or 0 (no bar)
        M (np.ndarray or None): preferred orientations of model neurons (radians), shape (N_y, N_x, K)
            or broadcastable against (N_y, N_x, 1); if None, K orientations evenly spaced in
            [0, pi) are used at every location
        K (int): number of orientation channels; only used when M is None

    Returns:
        I (np.ndarray): model input C * tuning_curve(A - M), shape (N_y, N_x, K)
        M (np.ndarray): the preferred orientations actually used, wrapped into [0, pi)
            (shape (N_y, N_x, K) when M is None)
    """
    if M is None:
        N_y, N_x = A.shape
        preferred = np.linspace(0, np.pi, K, endpoint=False)
        M = np.broadcast_to(preferred[np.newaxis, np.newaxis, :], (N_y, N_x, K))

    M = M % np.pi  # ensure M in [0, pi)

    # Add a trailing channel axis so A and C broadcast against M.
    A = A[:, :, np.newaxis] % np.pi  # shape (N_y, N_x, 1), in [0, pi)
    C = C[:, :, np.newaxis]  # shape (N_y, N_x, 1)
    return C * tuning_curve(A - M), M

In [None]:
# Visual check of get_model_input: one figure per orientation channel.
test = False  # set True to display the per-channel figures

# Generate random valid inputs
N_y, N_x = 3, 9
rng = np.random.default_rng(42)
A_test = rng.uniform(0, np.pi, (N_y, N_x))
C_test = np.full((N_y, N_x), 2.5)  # rng.uniform(1, 4, (N_y, N_x))

# Call get_model_input
I, M = get_model_input(A_test, C_test)

print("Input shape:", I.shape)
print("Input min/max:", I.min(), I.max())

for k_post in range(I.shape[2]):
    print(M[0, 0, k_post] / np.pi * 180)  # preferred orientation in degrees
    fig_k = visualize_output(M[:, :, k_post], I[:, :, k_post], verbose=test)
    if not test:
        # Close figures that are not shown so they do not accumulate (and, in
        # Jupyter's inline backend, are not auto-displayed at the end of the cell).
        plt.close(fig_k)

## 3. Implement Naive Model

In [None]:
class NaiveModel:
    """Uncoupled leaky-integrator model with orientation-tuned input.

    Every neuron (y, x, k) independently follows dX/dt = -alpha * X + I, where
    I is the tuned response of channel k to the bar at location (y, x).

    Parameters:
        K (int): number of orientation channels per grid location
        alpha (float): leak rate of the dynamics
    """

    def __init__(self, K=12, alpha=1.0):
        self.alpha = alpha
        self.K = K

        # Preferred orientations: K angles evenly spaced in [0, pi),
        # shape (1, 1, K) so they broadcast over the grid.
        angles = np.linspace(0, np.pi, self.K, endpoint=False)
        self.M = angles[np.newaxis, np.newaxis, :]

    def get_input(
        self, A: np.ndarray, C: np.ndarray, verbose: bool = False
    ) -> np.ndarray:
        """Computes model input from visual input

        TODO: extend to multiple input bars per location (i.e. A and C of shape (N_y, N_x, L) where L is number of input bars per location)

        Parameters:
            A (np.ndarray): 2D array of angles (radians) of input bars, shape (N_y, N_x), values in [0, pi]
            C (np.ndarray): 2D array of contrasts of input bars, same shape as A, values in [1, 4] or 0 (no bar)
            verbose (bool): If True, visualize input and output

        Returns:
            I (np.ndarray): 3D array of model input, shape (N_y, N_x, K)

        Raises:
            AssertionError: if A is not 2D or C does not match its shape
        """
        assert A.ndim == 2, "A must be a 2D array"
        assert C.shape == A.shape, "C must have the same shape as A"
        # Grid size is taken from the input itself.
        N_y, N_x = A.shape
        M = np.broadcast_to(self.M, (N_y, N_x, self.K)) % np.pi  # ensure M in [0, pi)

        A = A % np.pi  # ensure A in [0, pi)
        I = C[:, :, np.newaxis] * tuning_curve(A[:, :, np.newaxis] - M)

        if verbose:
            # Visualize with the 2D arrays: plot_bars requires A.ndim == 2.
            visualize_input(A, C, verbose=True)
            for k in range(I.shape[2]):
                print("==================================")
                print(f"Neurons {k}, tuned to {M[0, 0, k] / np.pi * 180}°")
                visualize_output(M[:, :, k], I[:, :, k], verbose=True)

        return I

    def derivative(self, X: np.ndarray, I: np.ndarray) -> np.ndarray:
        """Computes the derivative dX/dt = -alpha * X + I

        Parameters:
            X (np.ndarray): Current state, shape (N_y, N_x, K)
            I (np.ndarray): Input, shape (N_y, N_x, K)

        Returns:
            (np.ndarray): Derivative dX/dt, shape (N_y, N_x, K)
        """

        return -self.alpha * X + I

    def euler_method(self, I: np.ndarray, dt: float, T: float) -> np.ndarray:
        """Integrates the model with the forward Euler method for constant input I.

        Parameters:
            I (np.ndarray): Input, shape (N_y, N_x, K)
            dt (float): Time step
            T (float): Total simulation time

        Returns:
            X (np.ndarray): trajectory, shape (steps, N_y, N_x, K) with
                steps = round(T / dt) + 1; X[n] is the state at time n * dt,
                so X[0] = 0 (initial condition) and X[-1] is the state at time T.
        """
        # round() rather than floor division: in floating point e.g.
        # 1 // 0.1 == 9.0, which silently drops a step.
        steps = int(round(T / dt)) + 1
        X = np.zeros((steps, *I.shape))  # X[0] = 0 is the initial state

        for t in range(1, steps):
            X[t] = X[t - 1] + dt * self.derivative(X[t - 1], I)

        return X

    def simulate(
        self,
        A: np.ndarray,
        C: np.ndarray,
        dt: float = 0.001,
        T: float = 12.0,
        verbose: bool = False,
    ) -> "tuple[np.ndarray, np.ndarray]":
        """Runs the full simulation given angles A and contrasts C

        Parameters:
            A (np.ndarray): 2D array of angles (radians) of input bars, shape (N_y, N_x), values in [0, pi]
            C (np.ndarray): 2D array of contrasts of input bars, same shape as A, values in [1, 4] or 0 (no bar)
            dt (float): Time step
            T (float): Total simulation time
            verbose (bool): If True, visualize input and per-channel model input

        Returns:
            X (np.ndarray): trajectory, shape (steps, N_y, N_x, K) (see euler_method)
            I (np.ndarray): constant model input, shape (N_y, N_x, K)
        """

        I = self.get_input(A, C, verbose=verbose)
        X = self.euler_method(I, dt, T)
        return X, I

In [None]:
# Instantiate and test the NaiveModel
# A single grid location (1x1) with one bar at 90° and contrast 2.
model = NaiveModel(K=10, alpha=1.0)

A = np.array([[90 / 180 * np.pi]])
C = np.array([[2.0]])

T = 8
dt = 0.001
# X_gen: simulated trajectory (time axis first); I: constant model input.
X_gen, I = model.simulate(A, C, dt=dt, T=T, verbose=False)


def analytical_solution(t: np.ndarray, I: np.ndarray, alpha: float) -> np.ndarray:
    """Computes the analytical solution of dX/dt = -alpha * X + I at times t.

    Assumes constant input I and initial condition X(t = 0) = 0, which gives
    X(t) = (1 - exp(-alpha * t)) / alpha * I, and X(t) = t * I in the limit
    alpha -> 0 (handled explicitly instead of producing NaN).

    Parameters:
        t (np.ndarray): Time points, shape (N_t,)
        I (np.ndarray): Input, shape (N_y, N_x, K) (any shape is accepted)
        alpha (float): Model parameter (leak rate)

    Returns:
        (np.ndarray): Analytical solution at times t, shape (N_t, *I.shape),
            i.e. (N_t, N_y, N_x, K) for model input
    """
    I_ = np.asarray(I)[np.newaxis, ...]  # shape (1, *I.shape)
    # Time along the first axis, singleton axes for every axis of I.
    t_ = np.asarray(t).reshape((-1,) + (1,) * (I_.ndim - 1))
    if alpha == 0:
        return t_ * I_
    # -expm1(-a t) == 1 - exp(-a t), but accurate for small a * t.
    return -np.expm1(-alpha * t_) / alpha * I_


# Time axis matching the simulated trajectory (one sample per Euler step).
t = np.arange(X_gen.shape[0]) * dt
# Closed-form solution for the same constant input, for comparison.
X_gt = analytical_solution(t, I, model.alpha)

plt.rcParams.update({"font.size": 12})
plt.figure(figsize=(6, 4), dpi=300, constrained_layout=True)
labels = [f"{round(180 / model.K * k)}°" for k in range(model.K)]
colors = plt.cm.hsv(np.linspace(0, 1, model.K, endpoint=False))  # Use hsv colormap (cyclic)
linestyles = ["-", "--", "-.", ":"] * ((model.K // 4) + 1)  # Cycle through styles
markers = ["o", "s", "D", "^", "v", "<", ">", "p", "*", "h"] * (
    (model.K // 10) + 1
)  # Cycle through markers

for k_post in range(model.K):
    # Simulated response of channel k_post at grid location (0, 0).
    plt.plot(
        t,
        X_gen[:, 0, 0, k_post],
        linestyle=linestyles[k_post],
        label=f"{labels[k_post]}",
        color=colors[k_post],
        alpha=0.8,
        linewidth=2,
        marker=markers[k_post],
        markevery=(
            k_post * round(len(t) / 6 / model.K),
            round(len(t) / 6),
        ),  # Offset markers for each line
        markeredgecolor="black",  # Add black outline
        markeredgewidth=0.6,
    )
    # Faint thick line underneath: analytical solution for the same channel;
    # labelled only once so the legend has a single entry.
    plt.plot(
        t,
        X_gt[:, 0, 0, k_post],
        linestyle="-",
        color="k",
        alpha=0.1,
        linewidth=4,
        label="Analytical\nSolutions" if k_post == model.K - 1 else None,
    )
plt.legend(
    loc="center right",
    title="Preferred\norientation",
    framealpha=1.0,
    fontsize=8,
    title_fontsize=9,
)
plt.xlabel("Time [per time constant]")
plt.ylabel("Model response")
plt.title(
    "Response of model neurons in one hypercolumn\n"
    + r"to a bar of orientation $\theta = 90^\circ$ and contrast $\hat{I} = 2$"
)
plt.show()

## 4. Implement noise model

In [None]:
class NaiveNoisyModel:
    """Uncoupled model neurons driven by a piecewise-constant random noise input.

    Parameters:
        K (int): number of orientation channels per grid location
        alpha (float): leak rate of the (currently disabled) deterministic dynamics
        average_noise_height (float): standard deviation of each noise segment's amplitude
        average_noise_temporal_width (float): mean duration of a noise segment
            (scale of the exponential distribution)
        seed: seed for np.random.default_rng (None -> non-deterministic)
    """

    def __init__(
        self,
        K=12,
        alpha=1.0,
        average_noise_height=0.1,
        average_noise_temporal_width=0.1,
        seed=None,
    ):
        self.alpha = alpha
        self.K = K

        # Preferred orientations: K angles evenly spaced in [0, pi),
        # shape (1, 1, K) so they broadcast over the grid.
        angles = np.linspace(0, np.pi, self.K, endpoint=False)
        self.M = angles[np.newaxis, np.newaxis, :]

        # Noise parameters
        assert average_noise_height >= 0, "average_noise_height must be non-negative"
        assert average_noise_temporal_width > 0, (
            "average_noise_temporal_width must be positive"
        )
        self.noise_std = average_noise_height
        self.noise_tau = average_noise_temporal_width
        self.rng = np.random.default_rng(seed)

    def get_input(
        self, A: np.ndarray, C: np.ndarray, verbose: bool = False
    ) -> np.ndarray:
        """Computes model input from visual input

        TODO: extend to multiple input bars per location (i.e. A and C of shape (N_y, N_x, L) where L is number of input bars per location)

        Parameters:
            A (np.ndarray): 2D array of angles (radians) of input bars, shape (N_y, N_x), values in [0, pi]
            C (np.ndarray): 2D array of contrasts of input bars, same shape as A, values in [1, 4] or 0 (no bar)
            verbose (bool): If True, visualize input and output

        Returns:
            I (np.ndarray): 3D array of model input, shape (N_y, N_x, K)

        Raises:
            AssertionError: if A is not 2D or C does not match its shape
        """
        assert A.ndim == 2, "A must be a 2D array"
        assert C.shape == A.shape, "C must have the same shape as A"
        N_y, N_x = A.shape
        M = np.broadcast_to(self.M, (N_y, N_x, self.K)) % np.pi  # ensure M in [0, pi)

        A = A % np.pi  # ensure A in [0, pi)
        I = C[:, :, np.newaxis] * tuning_curve(A[:, :, np.newaxis] - M)

        if verbose:
            # Visualize with the 2D arrays: plot_bars requires A.ndim == 2.
            visualize_input(A, C, verbose=True)
            for k in range(I.shape[2]):
                print("==================================")
                print(f"Neurons {k}, tuned to {M[0, 0, k] / np.pi * 180}°")
                visualize_output(M[:, :, k], I[:, :, k], verbose=True)

        return I

    def derivative(self, X: np.ndarray, I: np.ndarray) -> np.ndarray:
        """Computes the derivative dX/dt = -alpha * X + I

        Parameters:
            X (np.ndarray): Current state, shape (N_y, N_x, K)
            I (np.ndarray): Input, shape (N_y, N_x, K)

        Returns:
            (np.ndarray): Derivative dX/dt, shape (N_y, N_x, K)
        """

        return -self.alpha * X + I

    def euler_method(self, I: np.ndarray, dt: float, T: float) -> np.ndarray:
        """Generates the noise-driven trajectory given input I.

        Every neuron receives an independent piecewise-constant noise current:
        each segment has a normally distributed amplitude (mean 0, std
        ``noise_std``) and an exponentially distributed duration (mean
        ``noise_tau``).

        NOTE(review): the deterministic update (``derivative``) is currently
        commented out, so X[t] holds only ``I_noise * dt`` for that step. It is
        neither integrated over time nor driven by I. Downstream checks of the
        noise statistics may rely on this.

        Parameters:
            I (np.ndarray): Input, shape (N_y, N_x, K) (currently only its shape is used)
            dt (float): Time step
            T (float): Total simulation time

        Returns:
            X (np.ndarray): trajectory, shape (steps, N_y, N_x, K) with
                steps = round(T / dt) + 1; X[n] corresponds to time n * dt and X[0] = 0.
        """
        # round() rather than floor division: in floating point e.g.
        # 1 // 0.1 == 9.0, which silently drops a step.
        steps = int(round(T / dt)) + 1
        X = np.zeros((steps, *I.shape))  # X[0] = 0 is the initial state

        I_noise = np.zeros_like(I)
        noise_duration = np.zeros_like(I)  # remaining lifetime of each noise segment

        for t in range(1, steps):
            # dXdt = self.derivative(X[t-1], I)
            # X[t] = X[t-1] + dt * dXdt

            # Redraw amplitude and duration for every segment that has expired.
            # The RNG is called in the same order (normal, then exponential)
            # on every step so a given seed reproduces the same noise.
            noise_duration -= dt
            expired = noise_duration <= 0
            I_noise[expired] = self.rng.normal(0, self.noise_std, size=I.shape)[expired]
            noise_duration[expired] = self.rng.exponential(
                self.noise_tau, size=I.shape
            )[expired]
            X[t] += I_noise * dt

        return X

    def simulate(
        self,
        A: np.ndarray,
        C: np.ndarray,
        dt: float = 0.001,
        T: float = 12.0,
        verbose: bool = False,
    ) -> "tuple[np.ndarray, np.ndarray]":
        """Runs the full simulation given angles A and contrasts C

        Parameters:
            A (np.ndarray): 2D array of angles (radians) of input bars, shape (N_y, N_x), values in [0, pi]
            C (np.ndarray): 2D array of contrasts of input bars, same shape as A, values in [1, 4] or 0 (no bar)
            dt (float): Time step
            T (float): Total simulation time
            verbose (bool): If True, visualize input and per-channel model input

        Returns:
            X (np.ndarray): trajectory, shape (steps, N_y, N_x, K) (see euler_method)
            I (np.ndarray): constant model input, shape (N_y, N_x, K)
        """

        I = self.get_input(A, C, verbose=verbose)
        X = self.euler_method(I, dt, T)
        return X, I

In [None]:
# Instantiate and test the NaiveNoisyModel
# All contrasts are zero, so the model input I is zero and X_gen only
# reflects the noise process.
model = NaiveNoisyModel(
    K=12, alpha=1.0, average_noise_height=0.1, average_noise_temporal_width=0.1
)

A = np.full((9, 9), 90 / 180 * np.pi)
C = np.zeros((9, 9))

T = 10
dt = 0.001
X_gen, I = model.simulate(A, C, dt=dt, T=T, verbose=False)

# Drop the first 10 * noise_tau of simulated time (presumably to skip the
# initial transient before computing statistics).
discard_initial_steps = int(model.noise_tau / dt * 10)
X_gen = X_gen[discard_initial_steps:]

In [None]:
# Sanity check of the noise amplitude: NaiveNoisyModel.euler_method currently
# stores X[t] = I_noise * dt (dynamics disabled), so the per-neuron RMS of
# X_gen / dt, averaged over all neurons, should be close to noise_std.
# NOTE(review): this check depends on the dynamics being disabled.
np.allclose(model.noise_std, np.sqrt(np.mean(X_gen**2, axis=0)).mean() / dt, atol=1e-3)

In [None]:
from scipy.signal import correlate

# Compute the normalised autocorrelation of every (y, x, k) time series for
# lags 0..n_lags (in time steps).
n_lags = 1000
X_gen_centered = X_gen - np.mean(X_gen, axis=0)
autocorrs = np.zeros((n_lags + 1, X_gen.shape[1], X_gen.shape[2], X_gen.shape[3]))
for i in range(X_gen.shape[2]):
    for j in range(X_gen.shape[1]):
        for k_post in range(X_gen.shape[3]):
            autocorr = correlate(
                X_gen_centered[:, j, i, k_post],
                X_gen_centered[:, j, i, k_post],
                method="auto",
            )
            # The full correlation has length 2N - 1 with lag 0 at index
            # size // 2; keep the non-negative lags only.
            autocorr = autocorr[autocorr.size // 2 : autocorr.size // 2 + n_lags + 1]
            # autocorr /= (X_gen.shape[0] - np.arange(n_lags + 1))
            autocorr /= autocorr[0]  # Normalize (lag 0 -> 1; NaN if a series is constant)
            autocorrs[:, j, i, k_post] = autocorr

In [None]:
# Lag axis in simulation time units (lag index * dt).
lags = np.arange(n_lags + 1) * dt


def analytical_autocorrelation(lags: np.ndarray, tau: float) -> np.ndarray:
    """Normalised autocorrelation exp(-lag / tau) of the piecewise-constant noise.

    Parameters:
        lags (np.ndarray): Lag times, shape (N_lags,)
        tau (float): Mean duration of a noise segment

    Returns:
        (np.ndarray): autocorrelation values, shape (N_lags,), equal to 1 at lag 0
    """
    scaled_lags = lags / tau
    return np.exp(-scaled_lags)


plt.figure(figsize=(8, 4))
mean = autocorrs.mean(axis=(1, 2, 3))
plt.plot(lags, mean, color="blue", linewidth=2, label="Simulation")
# Standard error of the mean over all N_y * N_x * K time series (each neuron
# draws its own noise): the sample count is the PRODUCT of the averaged axis
# lengths, not their sum.
n_series = autocorrs.shape[1] * autocorrs.shape[2] * autocorrs.shape[3]
sem = autocorrs.std(axis=(1, 2, 3)) / np.sqrt(n_series)
plt.fill_between(lags, mean - sem, mean + sem, color="blue", alpha=0.2)
plt.plot(
    lags,
    analytical_autocorrelation(lags, model.noise_tau),
    color="red",
    linewidth=2,
    label="Analytical",
    linestyle="--",
)
plt.title("Autocorrelation of $X_{gen}$")
plt.xlabel("Lag")
plt.ylabel("Autocorrelation")
plt.legend(loc="upper right", framealpha=1.0)
plt.grid(True)
plt.xlim(0, 0.1)
plt.show()

TODO: debug this

In [None]:
# X_gen was truncated by `discard_initial_steps` above, so shift the time axis
# to keep each sample aligned with the time at which it was simulated.
t = (np.arange(X_gen.shape[0]) + discard_initial_steps) * dt
X_gt = analytical_solution(t, I, model.alpha)

plt.rcParams.update({"font.size": 12})
plt.figure(figsize=(6, 4), dpi=300, constrained_layout=True)
labels = [f"{round(180 / model.K * k)}°" for k in range(model.K)]
colors = plt.cm.hsv(np.linspace(0, 1, model.K, endpoint=False))  # Use hsv colormap (cyclic)
markers = ["o", "s", "D", "^", "v", "<", ">", "p", "*", "h"] * (
    (model.K // 10) + 1
)  # Cycle through markers

for k_post in range(model.K):
    # Simulated (noisy) response of channel k_post at grid location (0, 0).
    plt.plot(
        t,
        X_gen[:, 0, 0, k_post],
        linestyle="-",
        label=f"{labels[k_post]}",
        color=colors[k_post],
        alpha=0.8,
        linewidth=2,
        marker=markers[k_post],
        markevery=(
            k_post * round(len(t) / 6 / model.K),
            round(len(t) / 6),
        ),  # Offset markers for each line
        markeredgecolor="black",  # Add black outline
        markeredgewidth=0.6,
    )
    # Noise-free analytical solution for the same channel (labelled once).
    plt.plot(
        t,
        X_gt[:, 0, 0, k_post],
        linestyle="-",
        color="k",
        alpha=0.1,
        linewidth=4,
        label="Analytical\nSolutions" if k_post == model.K - 1 else None,
    )
plt.legend(
    loc="center right",
    title="Preferred\norientation",
    framealpha=1.0,
    fontsize=8,
    title_fontsize=9,
)
plt.xlabel("Time [per time constant]")
plt.ylabel("Model response")
# The stimulus in this experiment has zero contrast everywhere (C = 0).
plt.title(
    "Response of noisy model neurons in one hypercolumn\n"
    + r"without input (contrast $\hat{I} = 0$)"
)
plt.show()

## 5. Implement utility functions

### 5.a) Distance metric

In [None]:
def compute_psi(theta, K, atol=1e-6):
    """Intracortical connection strength psi(theta) between orientation channels.

    The angle difference is folded into [0, pi/2]; psi is 1 for theta = 0,
    0.8 for |theta| = pi/K, 0.7 for |theta| = 2*pi/K, and 0 otherwise.

    Parameters:
        theta (np.ndarray): Angle differences (radians), any shape
        K (int): number of orientation channels (sets the angular step pi/K)
        atol (float): absolute tolerance when matching the folded angle

    Returns:
        (np.ndarray): psi values, same shape as theta, values in {0, 0.7, 0.8, 1}
    """
    folded = np.abs(theta) % np.pi  # wrap to [0, pi)
    folded = np.minimum(folded, np.pi - folded)  # wrap to [0, pi/2]

    psi = np.zeros_like(folded)
    # Later entries overwrite earlier ones on (near-)coincident angles.
    for target_angle, strength in (
        (0.0, 1),
        (np.pi / K, 0.8),
        (2 * np.pi / K, 0.7),
    ):
        psi[np.isclose(folded, target_angle, atol=atol)] = strength

    return psi

In [None]:
K = 12
N = 1
# N * K + 1 angles from -N * 90° to +N * 90° in steps of 180° / K.
thetas = np.linspace(-N * K, N * K, N * K + 1) * np.pi / 2 / K
psi_values = compute_psi(thetas, K, atol=1e-6)

plt.rcParams.update({"font.size": 14})
plt.figure(figsize=(8, 3), dpi=200)
plt.plot(thetas / np.pi * 180, psi_values, linewidth=1, marker="o", markersize=8)
plt.xlabel(r"angle $\theta$ [°]")
plt.ylabel(r"$\psi(\theta)$")
plt.title(r"Intracortical connection strength $\psi(\theta)$ for $K = 12$")
plt.xticks(np.arange(-90, 91, 15))
plt.xlim(-95, 95)
plt.ylim(-0.05, 1.05)
plt.grid(True)
plt.show()

### 5.b) Activation functions

In [None]:
def g_x(x, T_x=1.0):
    """Activation function for pyramidal cells

    Piecewise linear: 0 for x < T_x, x - T_x for T_x <= x <= T_x + 1, and
    saturating at 1 for x > T_x + 1.

    Parameters:
        x (np.ndarray or float): deviation from resting state (input current), any shape, values in R
        T_x (float): threshold parameter, default 1.0

    Returns:
        (np.ndarray): activation values, same shape as x, values in [0, 1]
    """
    # np.clip (instead of writing into np.zeros_like(x)) keeps the result
    # floating point for integer input and also accepts Python scalars.
    # NaN inputs propagate as NaN.
    return np.clip(np.asarray(x) - T_x, 0.0, 1.0)

In [None]:
# Plot g_x on [0, 3] with the default threshold T_x.
x_vals = np.linspace(0, 3, 300)
y_vals = g_x(x_vals)

plt.figure(figsize=(6, 4))
plt.plot(x_vals, y_vals, label=r"$g_x(x)$", linewidth=3)  # label unused: no legend is drawn
plt.xlabel("x")
plt.ylabel(r"$g_x(x)$")
plt.title("Activation function $g_x(x)$")
plt.grid(True)
plt.show()

In [None]:
def g_y(y, L_y=1.2, g_1=0.21, g_2=2.5):
    """Activation function for interneurons

    Piecewise linear: 0 for y < 0, g_1 * y for 0 <= y < L_y, and
    g_1 * L_y + g_2 * (y - L_y) for y >= L_y (the steeper branch takes
    precedence, matching the original assignment order).

    Parameters:
        y (np.ndarray or float): deviation from resting state (input current), any shape, values in R
        L_y (float): threshold parameter, default 1.2
        g_1 (float): slope below L_y, default 0.21
        g_2 (float): slope above L_y, default 2.5

    Returns:
        (np.ndarray): activation values, same shape as y, values in [0, inf)
    """
    y = np.asarray(y)
    # np.where (instead of writing into np.zeros_like(y)) keeps the result
    # floating point for integer input and also accepts Python scalars.
    return np.where(
        y >= L_y,
        g_1 * L_y + g_2 * (y - L_y),
        np.where(y >= 0, g_1 * y, 0.0),
    )

In [None]:
# Plot g_y on [-1, 3] with the default parameters.
y_range = np.linspace(-1, 3, 300)
g_y_vals = g_y(y_range)

plt.figure(figsize=(6, 4))
plt.plot(y_range, g_y_vals, label=r"$g_y(y)$", linewidth=3)  # label unused: no legend is drawn
plt.xlabel("y")
plt.ylabel(r"$g_y(y)$")
plt.title("Activation function $g_y(y)$")
plt.grid(True)
plt.show()

### 5.c) normalization term

In [None]:
def I_c(I_top_down=0.0):
    """Background input term of the interneurons: a baseline of 1 plus top-down input.

    Parameters:
        I_top_down (float or np.ndarray): top-down input, default 0.0

    Returns:
        (float or np.ndarray): 1.0 + I_top_down
    """
    baseline = 1.0
    return baseline + I_top_down


def I_o(X: np.ndarray):
    """Computes the normalization term of the pyramidal cells.

    The activity g_x(X) is summed over orientations, averaged over a 5x5 square
    window of grid locations (centre included, periodic boundaries via
    mode="wrap"), and mapped to 0.85 - 2 * mean**2.

    Parameters:
        X (np.ndarray): Current state of pyramidal cells, shape (N_y, N_x, K)

    Returns:
        (np.ndarray): normalization term, shape (N_y, N_x, 1); at most 0.85 and
            decreasing with neighbourhood activity (it can become negative)
    """
    # neighbors[2, 2] = 0  # TODO: center included or not?
    window = np.ones((5, 5), dtype=X.dtype)

    # Total activity per location, summed over the K orientation channels.
    activity = g_x(X).sum(axis=-1, keepdims=True)  # shape (N_y, N_x, 1)
    pooled = convolve(activity, window[:, :, np.newaxis], mode="wrap")
    mean_activity = pooled / window.sum()

    return 0.85 - 2.0 * mean_activity ** 2

In [None]:
# test normalization

# 1) Single active unit X(y=0, x=0, k=0) = 2, i.e. g_x = 1 there (default T_x = 1).
#    At (0, 0) the 5x5 neighbourhood contains exactly this one unit.
X = np.zeros((9, 9, 12))
X[0, 0, 0] = 2.0

I_o_result = I_o(X)
print(
    "Analytical value:", 0.85 - 2 * (1 / 25) ** 2, "<-> Code:", I_o_result[0, 0, 0]
)  # expected value at (0, 0)

fig = plt.figure(figsize=(6, 5), constrained_layout=True)
# fignum=0 draws into the figure created above; without it matshow opens a
# second figure and leaves this one empty.
plt.matshow(I_o_result[..., 0], cmap="viridis", vmin=0.82, vmax=0.85, fignum=0)
plt.grid(True, which="both", color="black", linestyle="-", linewidth=1)
plt.colorbar(label=r"$I_o(X)$", fraction=0.046, pad=0.04)
plt.plot(0, 0, marker="o", markersize=15, color="white", mew=2)
plt.title(r"X(x=0, y=0, k=0) = 2 else 0")
plt.xlabel("x")
plt.ylabel("y")
plt.show()
plt.close()

# 2) Additionally X(y=4, x=4, k={5, 6}) = 2. The 5x5 neighbourhood of (2, 2)
#    covers rows/cols 0..4 and therefore contains all three active units.
X = np.zeros((9, 9, 12))
X[0, 0, 0] = 2
X[4, 4, 5] = 2
X[4, 4, 6] = 2

I_o_result = I_o(X)
print(
    "Analytical value:", 0.85 - 2 * (3 / 25) ** 2, "<-> Code:", I_o_result[2, 2, 0]
)  # expected value at (2, 2)

fig = plt.figure(figsize=(6, 5), constrained_layout=True)
plt.matshow(I_o_result[..., 0], cmap="viridis", vmin=0.82, vmax=0.85, fignum=0)
plt.grid(True, which="both", color="black", linestyle="-", linewidth=1)
plt.colorbar(label=r"$I_o(X)$", fraction=0.046, pad=0.04)
plt.plot(0, 0, marker="o", markersize=15, color="white", mew=2)
plt.plot(4, 4, marker="o", markersize=15, color="white", mew=2)
plt.title(r"also X(x=4, y=4, k={5, 6}) = 2")
plt.xlabel("x")
plt.ylabel("y")
plt.show()
plt.close()

## 6. Implement Connections

In [None]:
def plot_bars_2(A, W, l=9, r=1.3, verbose=True, dpi=500, axis=None, color="k"):
    """
    Plots a grid with several bars per location (e.g. one per orientation channel).

    Bars at location (i, j) are centred at ((j + 0.5) * d, (i + 0.5) * d) with
    d = l * r, so the row index i grows upwards. An angle of 0 draws a vertical
    bar and an angle of pi / 2 a horizontal one.

    Parameters:
        A (np.ndarray): 3D array of angles (radians), shape (N_y, N_x, K)
        W (np.ndarray or None): 3D array of linewidths, same shape as A;
            if None, every bar is drawn with linewidth 1
        l (float): Bar length
        r (float): Grid spacing factor
        verbose (bool): If True, show the plot
        dpi (int): Dots per inch for rendering
        axis (matplotlib axis or None): If provided, plot on this axis instead of creating a new figure.
        color: matplotlib color used for all bars

    Returns:
        fig (matplotlib.figure.Figure or None): The new figure, or None when
            drawing onto a provided ``axis``

    Raises:
        AssertionError: if W does not match the shape of A
    """
    if W is None:
        W = np.ones_like(A, dtype=float)
    assert W.shape == A.shape, "W must have the same shape as A"
    N_y, N_x, K = A.shape

    # Calculate image size in pixels
    d = l * r  # grid spacing
    img_height = int(N_y * d)
    img_width = int(N_x * d)

    # Create figure
    if axis is not None:
        ax = axis
        fig = None
    else:
        fig, ax = plt.subplots(figsize=(img_width / 100, img_height / 100), dpi=dpi)
    ax.set_xlim(0, img_width)
    ax.set_ylim(0, img_height)
    ax.set_aspect("equal")  # keep x and y scales the same, avoiding distortion
    ax.axis("off")

    # Draw bars
    for i in range(N_y):
        cy = (i + 0.5) * d  # centre y-coordinate (rows grow upwards)
        for j in range(N_x):
            cx = (j + 0.5) * d  # centre x-coordinate
            for k in range(K):
                # half-extent of the bar along x and y
                angle = A[i, j, k]
                dx = l * np.sin(angle) / 2
                dy = l * np.cos(angle) / 2
                # draw the bar between its two endpoints
                ax.plot(
                    [cx - dx, cx + dx],
                    [cy - dy, cy + dy],
                    color=color,
                    linewidth=W[i, j, k],
                    solid_capstyle="butt",
                )

    if verbose:
        plt.show()

    return fig

In [None]:
def conjuction_features():
    """Builds a conjunction-feature texture with up to two bars per grid location.

    The grid is 11 x 27 locations with two bar slots per location (last axis):
    - columns 0..13 hold two bars: one at 0 rad and one at pi/4 rad, both at contrast 1;
    - columns 14..26 hold only the first bar (0 rad, contrast 1); the second slot is empty.

    Returns:
        A (np.ndarray): bar angles (radians), shape (11, 27, 2)
        C (np.ndarray): bar contrasts, shape (11, 27, 2)
    """
    grid_shape = (11, 27, 2)
    split = 14  # first column of the right-hand (single-bar) region

    angles = np.zeros(grid_shape)
    contrasts = np.zeros(grid_shape)

    # left region: second bar rotated by 45 degrees
    angles[:, :split, 1] = np.pi / 4
    contrasts[:, :split, :] = 1.0

    # right region: only the first bar is present
    contrasts[:, split:, 0] = 1.0

    return angles, contrasts


# Build the conjunction texture (two bars per location in the left region) and draw it
A, C = conjuction_features()
plot_bars_2(A, C, l=9, r=1.3, dpi=500, verbose=True)
plt.show()

In [None]:
import matplotlib.lines as mlines


def visualize_weights(W, J, Psi, k_pres=(0, 6), K=12, dpi=200):
    """Plots the connection kernels J, W and Psi for selected presynaptic orientation channels.

    One figure is drawn per entry of ``k_pres``. Every grid location of the kernel shows
    bars at all K preferred orientations. Their line widths are proportional to the
    connection strength from the presynaptic neuron (black bar at the kernel centre):
    J in blue, W in red and Psi (embedded at the kernel centre) in green.

    Parameters:
        W (np.ndarray): Inhibitory kernel, shape (N_y, N_x, K, K); axis 2 post-, axis 3 presynaptic channel
        J (np.ndarray): Excitatory kernel, same shape as W
        Psi (np.ndarray): Intra-hypercolumn kernel, shape (1, 1, K, K)
        k_pres (sequence of int): Presynaptic orientation channels to plot, default (0, 6)
        K (int): Number of orientation channels, default 12
        dpi (int): Figure resolution
    """
    N_y, N_x = W.shape[0], W.shape[1]
    # kernel centre = location of the presynaptic neuron; derived from the kernel
    # shape so that non-default kernel sizes are drawn correctly
    y_c, x_c = N_y // 2, N_x // 2
    width_scale = 7.5  # scales J and W strengths to visible line widths

    # every grid location shows all K preferred orientations
    angles = np.linspace(0, np.pi, K, endpoint=False)
    A = np.broadcast_to(angles[np.newaxis, np.newaxis, :], (N_y, N_x, K))

    # Psi acts within one hypercolumn, so it is drawn at the kernel centre only.
    # It does not depend on k_pre and is therefore built once.
    Psi_broadcasted = np.zeros_like(W)
    Psi_broadcasted[y_c, x_c, :, :] = Psi[0, 0, :, :]

    plt.rcParams.update({"font.size": 8})

    for k_pre in k_pres:  # preferred orientation of presynaptic neuron
        _, axis = plt.subplots(figsize=(12, 5), constrained_layout=True, dpi=dpi)

        # single bar marking the presynaptic neuron
        center_bar = np.zeros((N_y, N_x, K))
        center_bar[y_c, x_c, k_pre] = 1

        layers = (
            (J[:, :, :, k_pre] * width_scale, "tab:blue", r"$J$"),
            (W[:, :, :, k_pre] * width_scale, "tab:red", r"$W$"),
            (Psi_broadcasted[:, :, :, k_pre], "tab:green", r"$\psi$"),
            (center_bar, "k", "presynaptic neuron"),
        )
        handles = []
        for widths, color, label in layers:
            plot_bars_2(A, widths, verbose=False, dpi=dpi, axis=axis, color=color)
            handles.append(mlines.Line2D([], [], color=color, label=label))

        axis.legend(handles=handles, loc="upper right", framealpha=1.0)
        plt.show()

In [None]:
def compute_W_values(d, beta, delta_theta, theta_1, theta_2):
    """Returns one entry of the inhibitory kernel W, following p. 314 of
    "Understanding Vision" (Li Zhaoping, 2014).

    All arguments are scalars. The conditions use Python ``and``, so array inputs are
    not supported.

    Parameters:
        d (float): Distance between pre- and post-synaptic neurons (grid units), >= 0
        beta (float): beta = 2|theta_1| + 2 sin(|theta_1 + theta_2|) as computed by the caller (not bounded by pi)
        delta_theta (float): Difference of the two preferred orientations (radians), in [0, pi/2]
        theta_1 (float): Signed angle between a preferred orientation and the connecting line with the smaller magnitude (radians)
        theta_2 (float): Signed angle with the larger magnitude (radians); not used for W

    Returns:
        (float): Connection strength; 0.0 when the neurons are not connected
    """
    connected = (
        d > 0  # only neurons in other hypercolumns
        and d / np.cos(beta / 4) < 10  # elliptical reach, elongated along the connecting line
        and beta >= np.pi / 1.1  # together with the next line: flanking, non-colinear bars only
        and np.abs(theta_1) > np.pi / 11.999
        and delta_theta < np.pi / 3  # similarly oriented neurons only
    )
    if not connected:
        return 0.0

    ratio = beta / d
    distance_term = 1 - np.exp(-0.4 * ratio**1.5)
    orientation_term = np.exp(-((delta_theta / (np.pi / 4)) ** 1.5))
    return 0.141 * distance_term * orientation_term


def compute_J_values(d, beta, delta_theta, theta_1, theta_2):
    """Returns one entry of the excitatory kernel J, following p. 314 of
    "Understanding Vision" (Li Zhaoping, 2014).

    All arguments are scalars. The conditions use Python ``and``/``or``, so array inputs
    are not supported.

    Parameters:
        d (float): Distance between pre- and post-synaptic neurons (grid units), >= 0
        beta (float): beta = 2|theta_1| + 2 sin(|theta_1 + theta_2|) as computed by the caller (not bounded by pi)
        delta_theta (float): Difference of the two preferred orientations (radians); not used for J
        theta_1 (float): Signed angle between a preferred orientation and the connecting line with the smaller magnitude (radians)
        theta_2 (float): Signed angle with the larger magnitude (radians)

    Returns:
        (float): Connection strength; 0.0 when the neurons are not connected
    """
    # other hypercolumns within a circular reach of radius 10
    if not (0 < d <= 10):
        return 0.0

    # beta < pi/2.69 favours colinear pairs (smooth contours). Otherwise both angles must be
    # close to the connecting line; checking theta_2 suffices because |theta_1| <= |theta_2|.
    # The extra beta < pi/1.1 bound is then always met.
    aligned = beta < np.pi / 2.69 or (
        beta < np.pi / 1.1 and np.abs(theta_2) < np.pi / 5.9
    )
    if not aligned:
        return 0.0

    ratio = beta / d
    return 0.126 * np.exp(-ratio**2 - 2 * ratio**7 - d**2 / 90)


def compute_angle_between_bar_and_line(bar_angle, line_angle):
    """Computes the signed angle from a bar to a line, both given as orientations (radians).

    Works element-wise on scalars or broadcastable arrays. Scalar inputs give a numpy
    scalar; array inputs give an array of the broadcast shape.

    Parameters:
        bar_angle (float or np.ndarray): Orientation of the bar (radians), values in [0, pi)
        line_angle (float or np.ndarray): Orientation of the line (radians), values in [0, pi)

    Returns:
        (np.float64 or np.ndarray): line_angle - bar_angle wrapped into [-pi/2, pi/2) (radians)
    """
    angle_diff = np.subtract(line_angle, bar_angle)
    # orientations are pi-periodic: fold the difference back into [-pi/2, pi/2).
    # For inputs in [0, pi) a single +/- pi shift suffices (same as the scalar if/elif).
    angle_diff = np.where(angle_diff >= np.pi / 2, angle_diff - np.pi, angle_diff)
    angle_diff = np.where(angle_diff < -np.pi / 2, angle_diff + np.pi, angle_diff)
    # [()] turns a 0-d result back into a scalar and leaves n-d arrays unchanged
    return angle_diff[()]


def compute_connection_kernel(K=12, verbose=False) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Computes intracortical connection kernels J, W and Psi, according to pp. 314, "Understanding Vision" (Li Zhaoping, 2014).
        Note: 3. dimension is post-synaptic, 4. dimension is pre-synaptic orientation channel

    The spatial part covers grid offsets -10..10 in both directions (21 x 21 locations).
    Orientation channel k has preferred orientation k * pi / K (radians, measured from the y-axis).

    Parameters:
        K (int): Number of orientation channels, default 12
        verbose (bool): If True, plot the distance, coordinate and angle grids used to build the kernels (debugging aid)

    Returns:
        J (np.ndarray): 4D array, inter-hypercolumn excitatory connection kernel, shape (21, 21, K, K)
        W (np.ndarray): 4D array, inter-hypercolumn inhibitory connection kernel, same shape as J
        Psi (np.ndarray): 4D array, intra-hypercolumn connection kernel, shape (1, 1, K, K);
            entries come from compute_psi (defined elsewhere in this notebook), diagonal is 0

    """

    kernel_size = 10  # maximal offset; matches the d <= 10 reach in compute_J_values
    N_x, N_y = 2 * kernel_size + 1, 2 * kernel_size + 1
    J = np.zeros((N_y, N_x, K, K))
    W = np.zeros((N_y, N_x, K, K))
    Psi = np.zeros((1, 1, K, K))

    A = np.linspace(0, np.pi, K, endpoint=False, dtype=np.float64)  # shape (K,)
    x, y = (
        np.arange(-kernel_size, kernel_size + 1, 1, dtype=np.float64),
        np.arange(-kernel_size, kernel_size + 1, 1, dtype=np.float64),
    )
    # with indexing="ij" the first output varies along axis 0 (rows): Y is the row
    # offset and X the column offset (x and y hold identical values here)
    Y, X = np.meshgrid(x, y, indexing="ij")  # shape (21, 21)

    # angle between y-axis and vector to neuron
    alphas = np.arctan2(X, Y) % np.pi  # shape (21, 21), values in [0, pi)

    if verbose:
        # for testing: show the geometry the kernels are built from
        plt.imshow(np.sqrt(X**2 + Y**2), cmap="viridis", origin="lower")
        plt.colorbar(label="D", fraction=0.046, pad=0.04)
        plt.title("Distance D from center neuron")
        plt.show()

        plt.imshow(X, cmap="viridis", origin="lower")
        plt.colorbar(label="X", fraction=0.046, pad=0.04)
        plt.title("X coordinate")
        plt.show()

        plt.imshow(Y, cmap="viridis", origin="lower")
        plt.colorbar(label="Y", fraction=0.046, pad=0.04)
        plt.title("Y coordinate")
        plt.show()

        plt.imshow(
            alphas / np.pi * 180,
            cmap="viridis",
            origin="lower",
            extent=(-1, 1, -1, 1),
            vmin=0,
            vmax=180,
        )
        plt.colorbar(
            label="Angle [°]", fraction=0.046, pad=0.04, ticks=[0, 45, 90, 135, 180]
        )
        plt.title("Angle between y-axis and vector to neuron")
        plt.show()

    for k_pre in range(K):
        for k_post in range(K):
            # difference of the two preferred orientations, folded into [0, pi/2]
            a = np.abs(A[k_post] - A[k_pre]) % np.pi
            delta_theta = np.minimum(a, np.pi - a)

            # compute non-zero entries of Psi
            Psi[0, 0, k_post, k_pre] = compute_psi(delta_theta, K)
            if k_post == k_pre:
                Psi[0, 0, k_post, k_pre] = 0.0  # no self-connection

            for i in range(0, N_y):
                for j in range(0, N_x):
                    # distance between the two hypercolumns in grid units
                    d = np.sqrt(X[i, j] ** 2 + Y[i, j] ** 2)
                    if d > 0:  # the centre (d == 0) stays 0 in J and W
                        # angle between y-axis and vector to neuron
                        alpha = alphas[i, j]

                        # angle between preferred orientation
                        # of centered neuron and vector to neuron
                        theta_1_dash = compute_angle_between_bar_and_line(
                            A[k_pre], alpha
                        )
                        theta_2_dash = compute_angle_between_bar_and_line(
                            A[k_post], alpha
                        )

                        # name theta_1 and theta_2 correctly:
                        # theta_1 is the one with the smaller magnitude
                        if np.abs(theta_1_dash) < np.abs(theta_2_dash):
                            theta_1 = theta_1_dash
                            theta_2 = theta_2_dash
                        else:
                            theta_1 = theta_2_dash
                            theta_2 = theta_1_dash

                        # beta = 2|theta_1| + 2 sin(|theta_1 + theta_2|)
                        beta = 2 * np.abs(theta_1) + 2 * np.sin(
                            np.abs(theta_1 + theta_2)
                        )

                        # compute non-zero entries of J
                        J[i, j, k_post, k_pre] = compute_J_values(
                            d, beta, delta_theta, theta_1, theta_2
                        )

                        # compute non-zero entries of W
                        W[i, j, k_post, k_pre] = compute_W_values(
                            d, beta, delta_theta, theta_1, theta_2
                        )

    return J, W, Psi


# Test the connection kernels: build them for K orientation channels and plot them
# for three presynaptic orientations (0°, 45° and 90° when K = 12)
K = 12
J, W, Psi = compute_connection_kernel(K=K, verbose=False)
visualize_weights(W=W, J=J, Psi=Psi, K=K, k_pres=[0, K // 4, K // 2], dpi=150)

In [None]:
# Visualize which (theta_1, theta_2) combinations pass the threshold conditions used in
# compute_J_values. Each plot adds one more condition.
theta_1 = np.arange(-np.pi / 2, np.pi / 2, 0.01)
theta_2 = np.arange(-np.pi / 2, np.pi / 2, 0.01)
# indexing="xy": Theta_1 varies along the x-axis, Theta_2 along the y-axis
Theta_1, Theta_2 = np.meshgrid(theta_1, theta_2, indexing="xy")
beta = 2 * np.abs(Theta_1) + 2 * np.sin(np.abs(Theta_1 + Theta_2))

# theta_1 is by construction the angle with the smaller magnitude
ordered = np.abs(Theta_1) <= np.abs(Theta_2)
near_line = np.abs(Theta_2) <= np.pi / 5.9

threshold_masks = {
    r"$|\theta_1| \leq |\theta_2|$": ordered,
    r"$|\theta_2| \leq \pi/5.9$ and $|\theta_1| \leq |\theta_2|$": near_line & ordered,
    r"($|\theta_2| \leq \pi/5.9$ or $\beta \leq \pi/2.69$) and $|\theta_1| \leq |\theta_2|$": (
        (near_line | (beta <= np.pi / 2.69)) & ordered
    ),
    r"(($\beta \leq \pi/1.1$ and $|\theta_2| \leq \pi/5.9$) or $\beta \leq \pi/2.69$) and $|\theta_1| \leq |\theta_2|$": (
        (((beta <= np.pi / 1.1) & near_line) | (beta <= np.pi / 2.69)) & ordered
    ),
}

for title, im in threshold_masks.items():
    plt.imshow(im, cmap="hot", origin="lower", extent=(-90, 90, -90, 90), vmin=0, vmax=1)
    plt.colorbar(label="Threshold", fraction=0.046, pad=0.04, ticks=[0, 1])
    plt.xlabel(r"$\theta_1$ [°]")
    plt.ylabel(r"$\theta_2$ [°]")
    plt.title(title)
    plt.show()

## 7. Full Model

In [None]:
from numba import njit, prange

@njit(parallel=True)
def summation_numba(S_padded, kernel, N_y, N_x, K, kernel_size):
    """Applies a (kernel_size, kernel_size, K, K) connection kernel to a padded state.

    out[y, x, post] = sum over (dy, dx, pre) of
        S_padded[y + dy, x + dx, pre] * kernel[dy, dx, post, pre]

    Post-synaptic channels are distributed over threads via prange.
    """
    out = np.zeros((N_y, N_x, K))
    for post in prange(K):
        for y in range(N_y):
            for x in range(N_x):
                total = 0.0
                for dy in range(kernel_size):
                    for dx in range(kernel_size):
                        for pre in range(K):
                            total += (
                                S_padded[y + dy, x + dx, pre]
                                * kernel[dy, dx, post, pre]
                            )
                out[y, x, post] = total
    return out

class FullModel:
    """V1 saliency model with pyramidal cells (state X) and interneurons (state Y).

    Follows "Understanding Vision" (Li Zhaoping, 2014). Each grid location holds K
    orientation-tuned units, and the state is integrated with the forward Euler method:

        dX/dt = -alpha_x X - g_y(Y) - Psi*g_y(Y) + J_o g_x(X) + J*g_x(X) + I + I_o(X)
        dY/dt = -alpha_y Y + g_x(X) + W*g_x(X) + I_c(I_top_down)

    Here ``*`` denotes the spatial/orientation summation over a connection kernel. The
    functions g_x, g_y, I_o and I_c are module-level functions defined elsewhere in this
    notebook.
    """

    def __init__(
        self,
        K=12,
        alpha_x=1.0,
        alpha_y=1.0,
        average_noise_height=0.1,
        average_noise_temporal_width=0.1,
        seed=None,
    ):
        """Initializes the full V1 model with pyramidal cells and interneurons

        Parameters:
            K (int): Number of orientation channels, default 12
            alpha_x (float): Decay rate (inverse time constant) of pyramidal cells, default 1.0
            alpha_y (float): Decay rate (inverse time constant) of interneurons, default 1.0
            average_noise_height (float): Standard deviation of noise amplitude, default 0.1
            average_noise_temporal_width (float): Mean duration of one noise sample (exponentially distributed), default 0.1
            seed (int or None): Random seed for noise generation, default None
        """
        self.alpha_x = alpha_x
        self.alpha_y = alpha_y

        # Preferred orientations of the K channels, evenly spaced in [0, pi)
        self.K = K
        angles = np.linspace(0, np.pi, self.K, endpoint=False)
        self.M = angles[np.newaxis, np.newaxis, :]  # shape (1, 1, K)

        self.J, self.W, self.Psi = compute_connection_kernel(K=K, verbose=False)
        self.J_o = 0.8  # self-excitation strength of pyramidal cells
        # nonlinearities and additional inputs (module-level functions)
        self.I_o = I_o
        self.I_c = I_c
        self.g_x = g_x
        self.g_y = g_y

        # Noise parameters
        assert average_noise_height >= 0, "average_noise_height must be non-negative"
        assert average_noise_temporal_width > 0, (
            "average_noise_temporal_width must be positive"
        )
        self.noise_std = average_noise_height
        self.noise_tau = average_noise_temporal_width
        self.rng = np.random.default_rng(seed)

    def get_input(
        self, A: np.ndarray, C: np.ndarray, verbose: bool = False
    ) -> np.ndarray:
        """Computes model input from visual input

        TODO: extend to multiple input bars per location (i.e. A and C of shape (N_y, N_x, L) where L is number of input bars per location)

        Parameters:
            A (np.ndarray): 2D array of angles (radians) of input bars, shape (N_y, N_x), values in [0, pi)
            C (np.ndarray): 2D array of contrasts of input bars, same shape as A, values in [1, 4] or 0 (no bar)
            verbose (bool): If True, visualize input and per-channel response

        Returns:
            I (np.ndarray): 3D array of model input, shape (N_y, N_x, K)

        """
        assert A.ndim == 2, "A must be a 2D array"
        assert C.shape == A.shape, "C must have the same shape as A"
        N_y, N_x = A.shape
        M = np.broadcast_to(self.M, (N_y, N_x, self.K)) % np.pi  # ensure M in [0, pi)

        A = A % np.pi  # ensure A in [0, pi)

        # broadcast each bar against all K preferred orientations
        I = C[:, :, np.newaxis] * tuning_curve(A[:, :, np.newaxis] - M)

        if verbose:
            # A and C are still 2D (N_y, N_x) here, the same form in which
            # visualize_input is called elsewhere in this notebook
            visualize_input(A, C, verbose=True)
            for k in range(I.shape[2]):
                print("==================================")
                print(f"Neurons {k}, tuned to {M[0, 0, k] / np.pi * 180}°")
                visualize_output(M[:, :, k], I[:, :, k], verbose=True)

        return I

    def update_noise(
        self, I_noise: np.ndarray, noise_duration: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Redraws the noise wherever its remaining duration has run out (in place)

        Parameters:
            I_noise (np.ndarray): Current noise input, any shape
            noise_duration (np.ndarray): Remaining duration of current noise input, same shape as I_noise

        Returns:
            I_noise (np.ndarray): Updated noise input, same shape as input
            noise_duration (np.ndarray): Updated remaining duration of current noise input, same shape as I_noise
        """
        expired = noise_duration <= 0

        # Full-size samples are drawn and then masked, so every call consumes the same
        # number of random variates.
        # amplitude follows a normal distribution
        I_noise[expired] = self.rng.normal(0, self.noise_std, size=I_noise.shape)[
            expired
        ]
        # duration follows an exponential distribution
        noise_duration[expired] = self.rng.exponential(
            self.noise_tau, size=noise_duration.shape
        )[expired]

        return I_noise, noise_duration

    def summation(
        self, kernel: np.ndarray, S: np.ndarray, mode: str = "symmetric"
    ) -> np.ndarray:
        """Computes the sum over the connection kernel for each post-synaptic neuron

        result[i, j, k_post] = sum_{ki, kj, k_pre} S_padded[i + ki, j + kj, k_pre] * kernel[ki, kj, k_post, k_pre]

        Parameters:
            kernel (np.ndarray): 4D array of shape (N_y_k, N_x_k, K, K), square with odd size
            S (np.ndarray): 3D array of shape (N_y, N_x, K)
            mode (str): Padding mode passed to np.pad, e.g. 'symmetric' or 'wrap', default 'symmetric'

        Returns:
            result (np.ndarray): 3D array of shape (N_y, N_x, K)
        """
        kernel_size = kernel.shape[0]  # assumes a square kernel
        delta_pad = (kernel_size - 1) // 2  # assumes an odd kernel size
        S_padded = np.pad(
            S, ((delta_pad, delta_pad), (delta_pad, delta_pad), (0, 0)), mode=mode
        )

        N_y, N_x = S.shape[0], S.shape[1]
        return summation_numba(S_padded, kernel, N_y, N_x, self.K, kernel_size)

    def derivative(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        I: np.ndarray,
        I_top_down: float = 0.0,
        mode: str = "symmetric",
    ) -> tuple[np.ndarray, np.ndarray]:
        """Computes the derivative dX/dt and dY/dt of the model at current state X and Y given input I

        Parameters:
            X (np.ndarray): Current state of pyramidal cells, shape (N_y, N_x, K)
            Y (np.ndarray): Current state of interneurons, shape (N_y, N_x, K)
            I (np.ndarray): Input, shape (N_y, N_x, K)
            I_top_down (float): Top-down input passed to I_c, default 0.0
            mode (str): Padding mode used for the lateral summations, default "symmetric"

        Returns:
            dXdt (np.ndarray): Derivative of pyramidal cells state, same shape as X
            dYdt (np.ndarray): Derivative of interneurons state, same shape as Y

        """
        g_X = self.g_x(X)
        g_Y = self.g_y(Y)

        dXdt = (
            - self.alpha_x * X
            - g_Y
            - self.summation(self.Psi, g_Y, mode=mode)
            + self.J_o * g_X
            + self.summation(self.J, g_X, mode=mode)
            + I
            + self.I_o(X)  # activated within, TODO: modular implementation
        )

        dYdt = (
            - self.alpha_y * Y
            + g_X
            + self.summation(self.W, g_X, mode=mode)
            + self.I_c(I_top_down)
        )

        return dXdt, dYdt

    def euler_method(
        self,
        I: np.ndarray,
        dt: float,
        T: float,
        noisy: bool = True,
        mode: str = "symmetric",
    ) -> tuple[np.ndarray, np.ndarray]:
        """Simulates the model over time given input I with the forward Euler method

        Parameters:
            I (np.ndarray): Input, shape (N_y, N_x, K)
            dt (float): Time step
            T (float): Total simulation time
            noisy (bool): If True, add noise to X and Y; default True
            mode (str): Padding mode used for the lateral summations, default "symmetric"

        Returns:
            X (np.ndarray): Pyramidal states over time, shape (steps, N_y, N_x, K) with steps = round(T / dt);
                X[0] is the (zero) initial state and X[t] the state at time t * dt
            Y (np.ndarray): Interneuron states over time, same shape as X
        """
        N_y, N_x, K = I.shape
        # round() guards against T / dt landing just below an integer (e.g. 0.3 / 0.1)
        steps = int(round(T / dt))

        # Pyramidal cells state over time
        X = np.zeros((steps, N_y, N_x, K))

        # Interneuron state over time
        Y = np.zeros((steps, N_y, N_x, K))

        # Noise initialization: zero durations make the first update draw fresh noise
        if noisy:
            I_noise = np.zeros((N_y, N_x, K, 2))  # last dim: 0: noise for X, 1: noise for Y
            noise_duration = np.zeros((N_y, N_x, K, 2))

        # Refresh the progress bar about every 0.05 s of simulated time. Use at least
        # every step, otherwise the modulo below would divide by zero for dt > 0.05.
        update_every = max(1, int(0.05 / dt))
        last_reported = 0
        with tqdm(total=max(steps - 1, 0), desc="Simulating", unit="step") as pbar:
            for t in range(1, steps):
                dXdt, dYdt = self.derivative(X[t - 1], Y[t - 1], I, mode=mode)
                X[t] = X[t - 1] + dt * dXdt
                Y[t] = Y[t - 1] + dt * dYdt

                if noisy:
                    noise_duration -= dt
                    I_noise, noise_duration = self.update_noise(I_noise, noise_duration)
                    X[t] += I_noise[..., 0] * dt
                    Y[t] += I_noise[..., 1] * dt

                # report exactly the steps done since the last update
                if t % update_every == 0 or t == steps - 1:
                    pbar.update(t - last_reported)
                    last_reported = t

        return X, Y

    def simulate(
        self,
        A: np.ndarray,
        C: np.ndarray,
        dt: float = 0.001,
        T: float = 12.0,
        verbose: bool = False,
        noisy: bool = True,
        mode: str = "symmetric",
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Runs the full simulation given angles A and contrasts C

        Parameters:
            A (np.ndarray): 2D array of angles (radians) of input bars, shape (N_y, N_x), values in [0, pi]
            C (np.ndarray): 2D array of contrasts of input bars, same shape as A, values in [1, 4] or 0 (no bar)
            dt (float): Time step
            T (float): Total simulation time
            verbose (bool): If True, visualize input; default False
            noisy (bool): If True, add noise to the simulation; default True
            mode (str): boundary condition of simulation (see np.pad); default "symmetric"

        Returns:
            X (np.ndarray): Pyramidal states over time, shape (steps, N_y, N_x, K)
            Y (np.ndarray): Interneuron states over time, same shape as X
            I (np.ndarray): Model input, shape (N_y, N_x, K)
        """
        I = self.get_input(A, C, verbose=verbose)
        X, Y = self.euler_method(I, dt, T, noisy=noisy, mode=mode)
        return X, Y, I
    
# Instantiate and test the FullModel on two neighboring textures
seed = 42
model = FullModel(alpha_x=2.0, alpha_y=2.0, seed=seed)
A, C = neighboring_textures(22, 60)

# simulate 2 s of model time with 1 ms Euler steps and periodic ("wrap") boundaries
T, dt = 2.0, 0.001
X_gen, Y_gen, I = model.simulate(
    A, C, T=T, dt=dt, mode="wrap", noisy=True, verbose=False
)

In [None]:
# Time points in seconds to plot
time_points = np.array([0.7, 0.9, 1.2, 1.4, 1.8])
# Round to the nearest step. Plain int() truncates, and e.g. 0.7 / 0.001 evaluates to
# 699.999..., which would select the step before the intended time point.
steps = np.rint(time_points / dt).astype(int)

# Model output g_x(x) averaged over rows, i.e. one value per texture column and time step.
# Columns :30 are read from channel 6 and columns 30: from channel 0 (90° and 0° for
# K = 12), presumably the bar orientations produced by neighboring_textures (TODO confirm).
# Shape: (time steps, texture columns)
split_column = 30
X_per_column = np.concatenate(
    [
        g_x(X_gen[:, :, :split_column, 6]).mean(axis=1),
        g_x(X_gen[:, :, split_column:, 0]).mean(axis=1),
    ],
    axis=1,
)

# Plot input
visualize_input(A, C, verbose=False)
visualize_output(A, I.max(axis=-1), verbose=True)

# Plot model dynamics, see fig. 5.21 in "Understanding Vision" (Li Zhaoping, 2014)
plt.figure(figsize=(10, 6))
x_axis = np.arange(X_per_column.shape[1])  # column indices

for t_s, step in zip(time_points, steps):
    # response profile across texture columns at this time point
    plt.plot(x_axis, X_per_column[step], label=f"t={t_s:.1f}s")

avg_response = X_per_column.mean(axis=0)
plt.plot(
    x_axis, avg_response, label="Temporal avg", linewidth=3, linestyle="--", color="k"
)

plt.xlabel("Texture column number")
plt.ylabel("Neural response")
plt.legend()
plt.grid(True)
plt.show()

- Dynamics too fast compared to fig. 5.21?
- Response profile not as smooth as in fig. 5.21?

In [None]:
# replicate fig. 5.18 in "Understanding Vision" (Li Zhaoping, 2014)

seed = 0
model = FullModel(seed=seed)
T = 12.0
dt = 0.001
# 9 x 9 region of interest plus a margin of 10 locations on each side (the reach of the
# connection kernels)
N_y_test, N_x_test = 9 + 2 * 10, 9 + 2 * 10

test_cases = {
    "A: Bar without\nsurround": bar_without_surround,
    "B: Iso-\norientation": iso_orientation,
    "C: Random\nbackground": random_background,
    "D: Cross-\norientation": cross_orientation,
    "E: Bar without\nsurround": bar_without_surround_low_contrast,
    "F: With one\nflanker": with_one_flanker,
    "G: With two\nflankers": with_two_flankers,
    "H: With flanking\nline and noise": with_flanking_line_and_noise,
}

# generate model response for all test cases
input_and_outputs = {}
for title, func in test_cases.items():
    # create input images
    A_in, C_in = func(N_y=N_y_test, N_x=N_x_test)  # A, C shape (N_y, N_x)

    # simulate model
    X, _, _ = model.simulate(A_in, C_in, dt=dt, T=T, verbose=False, noisy=True, mode="wrap")
    model_output = g_x(X).mean(axis=0)  # temporal average, N_y x N_x x K

    # strongest channel per location: its response sets the bar width and its
    # preferred orientation (index * pi / K) the bar angle
    C_out = model_output.max(axis=-1)  # N_y x N_x
    argmax_angle_indices = model_output.argmax(axis=-1)  # N_y x N_x
    A_out = np.pi / model.K * argmax_angle_indices  # N_y x N_x

    input_and_outputs[title] = (A_in, C_in, A_out, C_out)

In [None]:
output_to_input_bar_width_ratio = 1.3

# compute image size of one 9 x 9 panel
dpi = 500
N_y, N_x = 9, 9
l = 9
r = 1.3
d = l * r  # grid spacing
img_height = int(N_y * d)
img_width = int(N_x * d)

# layout: n_cols test cases per row; each test case uses an input and an output panel
n_cols = 4
n_rows = 2 * ((len(input_and_outputs) + n_cols - 1) // n_cols)


def _crop_center(arr, size):
    """Returns the central size x size window of a 2D array (assumes odd dimensions)."""
    y_mid, x_mid = (arr.shape[0] - 1) // 2, (arr.shape[1] - 1) // 2
    half = (size - 1) // 2
    return arr[y_mid - half : y_mid + half + 1, x_mid - half : x_mid + half + 1]


def _show_row_label(axis, label):
    """Turns an axis back on but hides everything except its y label."""
    axis.axis("on")
    for spine in axis.spines.values():
        spine.set_visible(False)
    axis.tick_params(axis="both", which="both", length=0, labelbottom=False, labelleft=False)
    axis.set_xlabel("")
    axis.set_ylabel(label)


# set up figure
plt.rcParams.update({"font.size": 6})
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(img_width / 100 * n_cols, img_height / 100 * n_rows),
    dpi=dpi,
    constrained_layout=True,
    squeeze=False,
)

for index, (title, (A_in, C_in, A_out, C_out)) in enumerate(input_and_outputs.items()):
    col = index % n_cols
    row_block = index // n_cols

    # crop images to the N_y x N_y central region
    A_in, C_in, A_out, C_out = (_crop_center(a, N_y) for a in (A_in, C_in, A_out, C_out))

    # plot input
    top_axis = axes[2 * row_block, col]
    top_axis.set_title(title)
    visualize_input(A_in, C_in, verbose=False, axis=top_axis, dpi=dpi)
    if col == 0:
        _show_row_label(top_axis, r"Input $\hat{I}_{i\theta}$")

    # plot output
    bottom_axis = axes[2 * row_block + 1, col]
    visualize_output(
        A_out, output_to_input_bar_width_ratio * C_out, verbose=False, axis=bottom_axis, dpi=dpi
    )
    center = (N_y - 1) // 2
    print(f"{' '.join(title.split())}: central output {C_out[center, center]}")
    if col == 0:
        _show_row_label(bottom_axis, "Model\noutput " + r"$g_x(x_{i\theta})$")

plt.show()

- Collinear facilitation increases less for lower bar contrast and more for higher bar contrast; maybe the total increase is even higher?
- The random background suppresses the bar less, maybe because of the specific random background used here (possibly slightly collinearly activated)?