## 1. Notation and key variables

We simulate **two populations** (Left and Right), each with **K neurons** arranged on a ring.

- `s(t) ∈ R^{2K}`: network state (activity), concatenated as `[s_L(t), s_R(t)]`
- `s_L(t) ∈ R^{K}`: left ring activity
- `s_R(t) ∈ R^{K}`: right ring activity

At each step we compute:

- **Velocity gains** *(scalars)*:
  - `v_L(t) = 1 - β_vel * v(t)`
  - `v_R(t) = 1 + β_vel * v(t)`

- **Recurrent input** *(vectors in R^K)*:
  - `g_LL = W_LL s_L`, `g_LR = W_LR s_R`
  - `g_RR = W_RR s_R`, `g_RL = W_RL s_L`

- **Pre-activation current** *(vectors in R^K)*:
  - `G_L = v_L * (g_LL + g_LR + FF_global) + landmark_input`
  - `G_R = v_R * (g_RR + g_RL + FF_global) + landmark_input`
  - `G = [G_L, G_R] ∈ R^{2K}`

- **Activation / firing rate** *(ReLU)*:
  - `F = max(G, 0)`

- **Synaptic low-pass update**:
  - `s(t+dt) = s(t) + (F - s(t)) * dt / τ_s`
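
The update equations above can be collected into a single Euler step. The sketch below is a minimal NumPy version under stated assumptions: the four weight blocks are passed in as `K×K` arrays, `FF_global` is a scalar (or length-`K` vector), and `landmark_input` is added to both populations as written. The name `can_step` and its signature are hypothetical, not the API of `CANNetwork`/`CANSimulator`.

```python
import numpy as np

def can_step(s, v, W_LL, W_LR, W_RL, W_RR, beta_vel, ff_global,
             landmark_input, dt, tau_s):
    """One Euler step of the two-ring CAN (hypothetical sketch of Section 1)."""
    K = W_LL.shape[0]
    s_L, s_R = s[:K], s[K:]
    # Velocity gains (scalars): opposite signs for the two populations
    v_L = 1.0 - beta_vel * v
    v_R = 1.0 + beta_vel * v
    # Recurrent input from within- and across-population weights
    g_LL, g_LR = W_LL @ s_L, W_LR @ s_R
    g_RR, g_RL = W_RR @ s_R, W_RL @ s_L
    # Pre-activation currents; landmark drive is added outside the velocity gain
    G_L = v_L * (g_LL + g_LR + ff_global) + landmark_input
    G_R = v_R * (g_RR + g_RL + ff_global) + landmark_input
    G = np.concatenate([G_L, G_R])
    F = np.maximum(G, 0.0)              # ReLU firing rate
    return s + (F - s) * dt / tau_s     # synaptic low-pass update
```

With zero weights, `FF_global = 1` and `v > 0`, one step moves the left ring toward `v_L` and the right ring toward `v_R` by a fraction `dt / τ_s` of the gap.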

### Bump center (phase)
The continuous attractor network (CAN) forms a localized “bump” of activity on each ring. We track its center as a **scalar index**:
- `nn_state[t]`: estimated bump center neuron index (0…K-1 in the code below)

This tracked `nn_state` is the **1D path-integrated variable** used for:
- stopping the trial when the bump reaches an `end_phase`
- triggering internal landmark inputs when the bump crosses learned landmark phases
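
Step 3 of Algorithm 2.2 below tracks the bump center with a local-maximum search around the previous center. A minimal sketch of such a search, assuming `activity` is the length-`K` activity of one ring and using a hypothetical window half-width `half_width` (not a parameter from the reference code), wraps the search window around the ring so the tracker survives the K-1 → 0 crossing:

```python
import numpy as np

def track_bump_center(activity, prev_center, half_width=10):
    """Return the index of the largest activity within +/- half_width of prev_center."""
    K = activity.shape[0]
    offsets = np.arange(-half_width, half_width + 1)
    window = (prev_center + offsets) % K   # candidate indices, wrapped around the ring
    best = int(np.argmax(activity[window]))
    return int(window[best])
```

Restricting the search to a window keeps the estimate from jumping to a distant, transient peak (for example a landmark-driven one) in a single step.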


## 2. Algorithms

### 2.1 Initialization (`init_state`)  [BF09-like + small static bump]
Goal: let the ring settle into a stable bump attractor.

**Algorithm**
1. Create a small noisy velocity input (constant over the init block).
2. Add a weak static Gaussian “seed” input at a chosen ring location (just to pick a phase).
3. Integrate the CAN dynamics for `T_init` seconds.
4. Return the final state `s_init ∈ R^{2K}`.
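
The four steps above can be sketched as a small driver loop. This is an illustration under assumptions, not the reference `init_state`: the dynamics are supplied as a hypothetical callable `step_fn(s, v, seed_input)` returning the next state, the seed is a circular Gaussian of hypothetical width `seed_std` and height `seed_amp`, and the velocity noise scale `v_noise_std` is a placeholder value.

```python
import numpy as np

def init_state_sketch(step_fn, K, T_init, dt, seed_center,
                      seed_std=5.0, seed_amp=0.1, v_noise_std=0.01, rng=None):
    """Settle a two-ring CAN into a bump (hypothetical sketch of init_state)."""
    rng = np.random.default_rng() if rng is None else rng
    # 1. Small noisy velocity, held constant over the whole init block
    v = rng.normal(0.0, v_noise_std)
    # 2. Weak static Gaussian seed at seed_center (circular distance on the ring)
    idx = np.arange(K)
    d = np.abs(idx - seed_center) % K
    d = np.minimum(d, K - d)
    seed_input = seed_amp * np.exp(-0.5 * (d / seed_std) ** 2)
    # 3. Integrate the dynamics for T_init seconds, starting from silence
    s = np.zeros(2 * K)
    for _ in range(int(round(T_init / dt))):
        s = step_fn(s, v, seed_input)
    # 4. Final state s_init in R^{2K}
    return s
```

With `dt = 1/2000 s` and `T_init = 10 s` the loop runs 20000 steps.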

### 2.2 Trial simulation (`run_trial`)  [BF09-like + optional NFJ24 landmark resets]
Goal: simulate one timing/path-integration trial.

**Algorithm**
1. Initialize `s[:,0] = s_init`.
2. Draw one noisy scalar velocity `v` for the whole trial (Weber noise).
3. For each timestep:
   - compute recurrent inputs and pre-activation currents
   - if internal landmarks are enabled:
     - check which landmark phase has been passed
     - if it is the first entry, store entry time
     - compute a time-envelope amplitude `amp(Δt)`, a Gaussian in the time Δt elapsed since entry
     - inject a spatial Gaussian bump centered at that landmark (or slightly shifted)
   - apply ReLU and low-pass update
   - track the bump center via a local-maximum search around the previous center
4. Stop when the bump phase reaches the stopping threshold.
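
The landmark bookkeeping in step 3 (which landmark was passed, first-entry time, Gaussian time envelope) can be sketched as follows. Assumptions, all hypothetical: `landmark_phases` are sorted ascending and in the same units as `phase`, the envelope peaks at the moment of entry, and `amp_max` and `sigma_t` are placeholder values, not parameters from the reference code.

```python
import numpy as np

def landmark_amplitude(phase, t, landmark_phases, entry_times,
                       amp_max=1.0, sigma_t=0.05):
    """Return (landmark index, envelope amplitude), or (None, 0.0) before any landmark."""
    passed = np.flatnonzero(np.asarray(landmark_phases) <= phase)
    if passed.size == 0:
        return None, 0.0
    j = int(passed[-1])              # most recently passed landmark
    if j not in entry_times:         # first entry: store the entry time
        entry_times[j] = t
    elapsed = t - entry_times[j]     # time since entry
    amp = amp_max * np.exp(-0.5 * (elapsed / sigma_t) ** 2)
    return j, float(amp)
```

The returned amplitude would then scale a spatial Gaussian centered at `landmark_phases[j]`; `entry_times` is a dict that persists across time steps of one trial.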

Outputs per trial:
- `nn_state(t)` trajectory
- trial duration / reaction time `RT = (#steps)*dt`
- optional state matrix `s` for debugging/visualization


In [None]:
# Imports
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Make sure we can import the local modules in this notebook environment
sys.path.insert(0, "/mnt/data")

from CAN_network import CANNetwork
from CAN_simulator import CANSimulator

np.random.seed(0)  # reproducibility for the notebook figures


## 3. Build the CAN network (parameters)

We mirror the reference MATLAB defaults:

- `dt = 1/2000 s`
- `tau_s = 40 ms`
- Mexican-hat kernel parameters as in the reference code
- Asymmetric shifts: `W_LL` shifted by +1, `W_RR` shifted by -1

> We keep `K=364` to match the reference ring size.  
> When we want to interpret phases in degrees (0–360), we’ll use helper conversions.
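
To make the "Mexican-hat kernel with asymmetric shifts" concrete, here is a sketch of a circulant difference-of-Gaussians weight matrix. The parameter values `a`, `sigma_e`, `b`, `sigma_i` are placeholders, not the reference MATLAB values, and implementing "shifted by +1" as a shift along the postsynaptic (row) axis is an assumption about the convention used in `CANNetwork`.

```python
import numpy as np

def mexican_hat_weights(K, shift=0, a=1.0, sigma_e=5.0, b=0.5, sigma_i=15.0):
    """Circulant difference-of-Gaussians weights on a K-neuron ring (sketch).

    W[i, j] is the weight from presynaptic j to postsynaptic i. A nonzero
    `shift` moves the excitatory peak so neuron j most strongly drives j + shift.
    """
    idx = np.arange(K)
    diff = (idx[:, None] - idx[None, :] - shift) % K   # shifted offset in [0, K)
    d = np.minimum(diff, K - diff)                     # shortest distance on the ring
    return a * np.exp(-0.5 * (d / sigma_e) ** 2) - b * np.exp(-0.5 * (d / sigma_i) ** 2)
```

Under this convention, `W_LL = mexican_hat_weights(K, shift=+1)` and `W_RR = mexican_hat_weights(K, shift=-1)` push the bump in opposite directions, and the velocity gains decide which push wins.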


In [None]:
# Network + simulator
K = 364
net = CANNetwork(K=K)           # architecture + weights
sim = CANSimulator(net)         # dynamics + landmark logic

print("K =", net.K, "dt =", net.dt, "tau_s =", net.tau_s)


In [None]:
# Helper conversions between (0..K-1) indices and degrees (0..360)
def idx_to_deg(idx, K):
    return (np.asarray(idx) * 360.0) / K

def deg_to_idx(deg, K):
    return (np.asarray(deg) * K) / 360.0

# Sanity check
print("Index 0 -> deg:", idx_to_deg(0, K))
print("Deg 180 -> idx:", deg_to_idx(180, K))
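
If a tracked index trajectory wraps from K-1 back to 0, converting it point by point with `idx_to_deg` produces a sawtooth. A minimal sketch of a continuous conversion (the name `unwrap_idx_to_deg` is hypothetical) maps indices to radians and removes the 2π jumps with `np.unwrap`; it assumes the bump moves less than K/2 indices between consecutive samples, which holds for small `dt`.

```python
import numpy as np

def unwrap_idx_to_deg(idx_traj, K):
    """Convert a ring-index trajectory that may wrap at K into continuous degrees."""
    angles = np.asarray(idx_traj, dtype=float) * 2.0 * np.pi / K  # index -> radians
    return np.degrees(np.unwrap(angles))                            # remove 2*pi jumps
```

For example, indices `[360, 362, 0, 2]` with `K=364` become monotonically increasing degrees ending just above 360.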


## 4. Initialization sanity check: do we get a stable bump?

We run `init_state()` and plot the final activity in both populations.
You should see:
- a localized bump in each ring
- similar shapes, typically with a small constant shift between L and R (expected in this architecture)


In [None]:
init = sim.init_state(T=10.0)

plt.figure(figsize=(8,3))
plt.plot(np.arange(K), init[:K], label="L")
plt.plot(np.arange(K), init[K:], label="R")
plt.title("Initialization: stable bump attractor (two populations)")
plt.xlabel("Neuron index (0..K-1)")
plt.ylabel("Activity")
plt.legend()
plt.tight_layout()
plt.show()

print("Peak idx L:", int(np.argmax(init[:K])), "Peak idx R:", int(np.argmax(init[K:])))
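
The printed peak indices can be turned into a signed L/R offset that respects the ring topology, so that peaks at 363 and 1 count as 2 apart rather than 362. A minimal sketch (hypothetical helper name), assuming each population has a single dominant peak:

```python
import numpy as np

def circular_peak_offset(act_L, act_R):
    """Signed peak offset R - L on the ring, in (-K/2, K/2]."""
    K = act_L.shape[0]
    diff = (int(np.argmax(act_R)) - int(np.argmax(act_L))) % K
    return diff - K if diff > K // 2 else diff
```

Applied to `init[:K]` and `init[K:]`, a small, trial-independent value is consistent with the constant L/R shift described above.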


## 5. Landmark input profile sanity check

`generate_landmark_input(centers, std, ampl_scaling)` returns a *vector of length K*,
with Gaussian bump(s) centered at the given ring indices.


In [None]:
# One landmark bump
lm = sim.generate_landmark_input(centers=[100], std=5.0, ampl_scaling=10.0)

plt.figure(figsize=(8,2.5))
plt.plot(np.arange(K), lm)
plt.title("Example landmark input (single Gaussian bump)")
plt.xlabel("Neuron index (0..K-1)")
plt.ylabel("Landmark drive")
plt.tight_layout()
plt.show()
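
For reference, a landmark profile of this kind can be written as a sum of circular Gaussians. This is a hypothetical reimplementation, not the code of `generate_landmark_input`: it assumes `std` is measured in neuron indices, `ampl_scaling` is the peak height of each bump, and the profile wraps around the ring; the real function may normalize or truncate differently.

```python
import numpy as np

def landmark_profile(K, centers, std, ampl_scaling):
    """Sum of circular Gaussian bumps on a K-neuron ring (hypothetical sketch)."""
    idx = np.arange(K)
    out = np.zeros(K)
    for c in centers:
        d = np.abs(idx - c) % K
        d = np.minimum(d, K - d)          # shortest distance around the ring
        out += ampl_scaling * np.exp(-0.5 * (d / std) ** 2)
    return out
```

Comparing `landmark_profile(K, [100], 5.0, 10.0)` against `lm` is a quick way to check whether the simulator's profile wraps and how it is scaled.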


## 6. Single-trial trajectories: with vs without internal landmarks (Fig. 4c-style)

We simulate two single trials:
- `landmarkpresent=False`: pure integrator (drift accumulates)
- `landmarkpresent=True`: internal landmark correction (resets/corrections at learned phases)

We define internal landmark phases as in the MATLAB demo:
- phases at 60°, 120°, 180°, 240°, 300° on a 360° cycle

We convert those degrees to ring indices (0..K-1).


In [None]:
# Landmark phases in degrees (paper / MATLAB style)
lm_deg = np.array([60, 120, 180, 240, 300], dtype=float)
lm_idx = deg_to_idx(lm_deg, K)     # fractional ring indices in [0, K)
print("Landmarks deg:", lm_deg)
print("Landmarks idx:", lm_idx)

# One common init state for fair comparison
init = sim.init_state(T=10.0)

trial_wo = sim.run_trial(
    init_condition=init,
    initial_phase=int(deg_to_idx(30, K)),
    end_phase=float(deg_to_idx(360, K)),
    landmarkpresent=False,
    landmark_input_loc=lm_idx,
    wolm_speed=0.35,
    wlm_speed=0.42,
    wm=0.05,
)

trial_w = sim.run_trial(
    init_condition=init,
    initial_phase=int(deg_to_idx(30, K)),
    end_phase=float(deg_to_idx(360, K)),
    landmarkpresent=True,
    landmark_input_loc=lm_idx,
    wolm_speed=0.35,
    wlm_speed=0.42,
    wm=0.05,
)

# Plot trajectories in degrees
t_wo = np.arange(len(trial_wo["nn_state"])) * net.dt
t_w  = np.arange(len(trial_w["nn_state"]))  * net.dt

plt.figure(figsize=(9,3))
plt.plot(t_wo, idx_to_deg(trial_wo["nn_state"], K), label="No internal landmarks")
plt.plot(t_w,  idx_to_deg(trial_w["nn_state"], K),  label="With internal landmarks")
plt.axhline(360, linestyle="--")
plt.title("Single-trial bump phase trajectory")
plt.xlabel("Time (s)")
plt.ylabel("Phase (deg)")
plt.legend()
plt.tight_layout()
plt.show()

RT_wo = len(trial_wo["nn_state"]) * net.dt
RT_w  = len(trial_w["nn_state"])  * net.dt
print("RT no landmarks:", RT_wo, "s")
print("RT with landmarks:", RT_w, "s")


## 7. Reproducing Fig. 4c: multiple example trajectories

Panel 4c shows a handful of example trajectories (with reset vs without reset).
Here we:
- simulate many trials for each condition
- plot a small subset (e.g., 4) of trajectories per condition


In [None]:
def run_many_trials(num_trials, landmarkpresent, seed=0, **kwargs):
    # Assumes the simulator draws its noise from NumPy's global RNG, so we
    # reseed it to make each batch (including its init state) reproducible.
    np.random.seed(seed)

    init = sim.init_state(T=10.0)

    trials = []
    RTs = []
    for _ in range(num_trials):
        out = sim.run_trial(init_condition=init, landmarkpresent=landmarkpresent, **kwargs)
        trials.append(out)
        RTs.append(len(out["nn_state"]) * net.dt)  # RT = (#steps) * dt
    return trials, np.array(RTs)

common_kwargs = dict(
    initial_phase=int(deg_to_idx(30, K)),
    end_phase=float(deg_to_idx(360, K)),
    landmark_input_loc=lm_idx,
    wolm_speed=0.35,
    wlm_speed=0.42,
    wm=0.05,
    T_max=60.0,
)

trials_w, RT_w = run_many_trials(50, True, seed=1, **common_kwargs)
trials_wo, RT_wo = run_many_trials(50, False, seed=2, **common_kwargs)

print("RT with landmarks: mean", RT_w.mean(), "std", RT_w.std())
print("RT no landmarks:  mean", RT_wo.mean(), "std", RT_wo.std())
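
Because `T_max=60.0` caps each trial, a trial whose bump never reaches `end_phase` still contributes an RT close to `T_max` and can inflate the mean and std. A minimal sketch that reports how many trials hit the cap and summarizes the rest, assuming (not confirmed by the simulator code) that a capped trial returns about `T_max / dt` steps:

```python
import numpy as np

def summarize_rts(RTs, T_max, dt):
    """Mean/std of RT over trials that did not hit the T_max cap, plus the cap count."""
    RTs = np.asarray(RTs, dtype=float)
    timed_out = RTs >= T_max - 0.5 * dt   # assumption: capped trials have RT ~ T_max
    kept = RTs[~timed_out]
    mean = float(kept.mean()) if kept.size else float("nan")
    std = float(kept.std()) if kept.size else float("nan")
    return mean, std, int(timed_out.sum())
```

Usage: `summarize_rts(RT_w, common_kwargs["T_max"], net.dt)`; a nonzero third value means the plain `mean`/`std` printed above include capped trials.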


In [None]:
# Plot a few example trajectories (Fig. 4c-like)
n_show = 4
plt.figure(figsize=(10,3))

for i in range(n_show):
    traj = idx_to_deg(trials_w[i]["nn_state"], K)
    t = np.arange(len(traj)) * net.dt
    plt.plot(t, traj, linewidth=1)

plt.title("Example trajectories WITH internal landmarks (reset-like)")
plt.xlabel("Time (s)")
plt.ylabel("Phase (deg)")
plt.tight_layout()
plt.show()

plt.figure(figsize=(10,3))
for i in range(n_show):
    traj = idx_to_deg(trials_wo[i]["nn_state"], K)
    t = np.arange(len(traj)) * net.dt
    plt.plot(t, traj, linewidth=1)

plt.title("Example trajectories WITHOUT internal landmarks")
plt.xlabel("Time (s)")
plt.ylabel("Phase (deg)")
plt.tight_layout()
plt.show()
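
The visual impression that drift accumulates without landmarks can be quantified as the across-trial spread of phase at fixed times. A minimal sketch (hypothetical helper), which skips trials that have already stopped by the requested time:

```python
import numpy as np

def phase_spread_at_times(trajs_deg, dt, times):
    """Across-trial std of phase at each requested time (NaN if no trial is still running)."""
    out = []
    for t in times:
        k = int(round(t / dt))                       # time -> step index
        vals = [traj[k] for traj in trajs_deg if k < len(traj)]
        out.append(float(np.std(vals)) if vals else float("nan"))
    return np.array(out)
```

Usage: `trajs = [idx_to_deg(tr["nn_state"], K) for tr in trials_wo]`, then `phase_spread_at_times(trajs, net.dt, [0.5, 1.0, 1.5])`; the example times are arbitrary. Since trials stop at different times, later entries are computed over fewer trials.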


## 8. Reproducing Fig. 4d: RT variance across target durations

The paper reports mean and std of reaction time (RT) for multiple base times (0.65s, 1.3s, ...).

In the CAN simulation, **RT** is the time it takes for the bump to reach the stopping phase.
To study variability vs duration, we need to define a family of tasks with different target durations.

A simple way to do this while keeping the dynamics identical:
- for each target base time, **calibrate** a baseline speed `v_base` so that the mean RT in the no-landmark condition matches that base time
- reuse the same `v_base` in the landmark condition
- measure the RT variability (std) across trials in both conditions at that base time

The calibration below is a bisection on `v_base`. It assumes that mean RT decreases monotonically as `v_base` increases (a faster bump reaches the stopping phase sooner).

This is not the only possible choice, but it is sufficient for the *qualitative* Fig. 4d comparison: RT variance with internal landmarks should be lower than without.
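
Why should the no-landmark std grow with base time? A toy pure integrator makes the reasoning explicit: if RT = distance / speed and the speed carries multiplicative noise, the coefficient of variation of RT stays roughly constant, so std(RT) scales with mean RT. The sketch below assumes, without confirmation from the simulator code, that `wm` acts as this Weber fraction; `weber_rt_samples` is a hypothetical toy, not the CAN.

```python
import numpy as np

def weber_rt_samples(target_rt, wm, n_trials, rng):
    """Toy pure integrator: RT = distance / speed, with Weber noise on the speed."""
    distance = 1.0                          # arbitrary phase distance to cover
    v_base = distance / target_rt           # noise-free speed that hits target_rt
    v = v_base * (1.0 + wm * rng.standard_normal(n_trials))
    v = np.clip(v, 1e-9, None)              # guard against non-positive speeds
    return distance / v
```

With `wm=0.05`, the CV of the samples stays near 0.05 at every target time, so doubling the target roughly doubles the std. This linear growth is the baseline that internal landmarks are expected to reduce.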


In [None]:
def mean_rt_for_vbase(v_base, landmarkpresent, n_trials=30, seed=0):
    np.random.seed(seed)
    init = sim.init_state(T=10.0)
    RTs = []
    for _ in range(n_trials):
        out = sim.run_trial(
            init_condition=init,
            initial_phase=int(deg_to_idx(30, K)),
            end_phase=float(deg_to_idx(360, K)),
            landmarkpresent=landmarkpresent,
            landmark_input_loc=lm_idx,
            wolm_speed=v_base,   # we pass v_base in both slots; run_trial picks based on landmarkpresent
            wlm_speed=v_base,
            wm=0.05,
            T_max=60.0
        )
        RTs.append(len(out["nn_state"]) * net.dt)
    return float(np.mean(RTs)), float(np.std(RTs))

def calibrate_vbase(target_time, v_low=0.05, v_high=1.0, n_iter=12, seed=0):
    # We assume: larger v_base -> faster bump -> smaller RT (monotonic)
    lo, hi = v_low, v_high
    for _ in range(n_iter):
        mid = 0.5*(lo+hi)
        mean_rt, _ = mean_rt_for_vbase(mid, landmarkpresent=False, n_trials=25, seed=seed)
        if mean_rt > target_time:
            # too slow -> increase v
            lo = mid
        else:
            # too fast -> decrease v
            hi = mid
    return 0.5*(lo+hi)

base_times = np.array([0.65, 1.3, 1.95, 2.6, 3.25], dtype=float)

rows = []
for bt in base_times:
    v_star = calibrate_vbase(bt, v_low=0.05, v_high=2.0, n_iter=10, seed=int(bt*100))
    m_wo, s_wo = mean_rt_for_vbase(v_star, landmarkpresent=False, n_trials=50, seed=int(bt*1000)+1)
    m_w,  s_w  = mean_rt_for_vbase(v_star, landmarkpresent=True,  n_trials=50, seed=int(bt*1000)+2)
    rows.append((bt, v_star, m_w, s_w, m_wo, s_wo))

df = pd.DataFrame(rows, columns=[
    "base_time", "v_base_calibrated",
    "meanRT_withReset", "stdRT_withReset",
    "meanRT_woReset", "stdRT_woReset",
])
df


In [None]:
# Fig. 4d-like plot: mean RT ± std vs base_time
plt.figure(figsize=(7,4))
plt.errorbar(df["base_time"], df["meanRT_withReset"], yerr=df["stdRT_withReset"], capsize=3, label="with internal landmarks")
plt.errorbar(df["base_time"], df["meanRT_woReset"],  yerr=df["stdRT_woReset"],  capsize=3, label="without internal landmarks")
plt.xlabel("Target base time (s)")
plt.ylabel("Produced time (s): mean ± std")
plt.title("Fig. 4d-like: RT variability vs target time")
plt.legend()
plt.tight_layout()
plt.show()
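
A single number per condition can summarize how RT variability grows with target time: the least-squares slope of std versus base time for a line through the origin. A minimal sketch (hypothetical helper); a pure integrator with Weber noise gives a slope near the Weber fraction, and a smaller slope with landmarks would match the qualitative Fig. 4d comparison.

```python
import numpy as np

def slope_through_origin(base_times, stds):
    """Least-squares slope of std vs base time for a line forced through the origin."""
    x = np.asarray(base_times, dtype=float)
    y = np.asarray(stds, dtype=float)
    return float(np.dot(x, y) / np.dot(x, x))
```

Usage: `slope_through_origin(df["base_time"], df["stdRT_withReset"])` versus `slope_through_origin(df["base_time"], df["stdRT_woReset"])`.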


## 9. What you should save for the report / reproducibility

For Fig. 4c reproduction:
- save a few trajectories per condition: `time`, `nn_state` (converted to degrees)

For Fig. 4d reproduction:
- save `base_time`, `meanRT_withReset`, `stdRT_withReset`, `meanRT_woReset`, `stdRT_woReset`

The `df` table above already matches that structure.
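
For Fig. 4c, one option is a long-format table with one row per time step, written with the standard-library `csv` module. The helper names and file names below are hypothetical; for the Fig. 4d table, pandas' `df.to_csv("fig4d_summary.csv", index=False)` is enough.

```python
import csv
import numpy as np

FIELDS = ["condition", "trial", "time", "phase_deg"]

def trajectory_rows(trials, dt, K, condition, n_keep=4):
    """Long-format rows (one per time step) for the first n_keep trials."""
    rows = []
    for i, out in enumerate(trials[:n_keep]):
        phase_deg = np.asarray(out["nn_state"], dtype=float) * 360.0 / K  # index -> degrees
        for k, ph in enumerate(phase_deg):
            rows.append({"condition": condition, "trial": i,
                         "time": k * dt, "phase_deg": float(ph)})
    return rows

def write_rows_csv(path, rows):
    """Write the rows to a CSV file with a header line."""
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        writer.writeheader()
        writer.writerows(rows)
```

Usage: `rows = trajectory_rows(trials_w, net.dt, K, "with_landmarks") + trajectory_rows(trials_wo, net.dt, K, "without_landmarks")`, then `write_rows_csv("fig4c_trajectories.csv", rows)`.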
