# CoDA in Sun et al. (2025) 2ACDC near/far task — reproducing Fig. 4c (bonus), Fig. 4i and Fig. 4j

**What this notebook does**
- Builds the **2ACDC near/far** stimulus sequences exactly as used for the CSCG simulations in Sun et al. (2025).
- Runs **CoDA** (Contingency‑Dependent State Augmentation) on those sequences, using contextual eligibility traces and split/merge rules.
- Computes **near–far cross‑correlation matrices** from the model’s internal latent states over learning, then
  - **Fig. 4i analogue:** summarizes the final matrix by the same **key blocks** (off‑diagonal, pre‑R2, pre‑R1),
  - **Fig. 4j analogue:** measures the **time to reach a correlation threshold (0.3)** in each block to recover the decorrelation order (off‑diag → pre‑R2 → pre‑R1).
- **Bonus (Fig. 4c analogue):** plots CoDA’s final latent-state transition graph for the two branches.

> Citations for task definition and evaluation scheme are included inline below.

## References used to match Sun et al.'s environment

- **Task & CSCG simulation details** (symbol sequences, key blocks, correlation threshold): see _Learning produces an orthogonalized state machine in the hippocampus_, Sun et al., **Nature** (2025), Fig. 4 and Methods. We use their near/far sequences and the same 0.3 correlation threshold and three “key regions/blocks.”
- **CoDA algorithm** (contextual eligibility traces; split/merge via prospective/retrospective contingency): see Yoo et al., _Contingency‑dependent state augmentation as a normative learning rule for non‑Markovian tasks_ (paper provided).

> If you have the CSCG repository and source data available locally, there are optional hooks below to overlay animal/CSCG bars, but the core of this notebook produces **CoDA‑only** versions of Fig. 4i and 4j.

In [None]:

# ==== Imports & basic setup ====
import math, random, itertools, collections
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Any
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx

np.set_printoptions(suppress=True, linewidth=120)
random.seed(0); np.random.seed(0)


In [None]:

# (Optional) Use your own CoDA implementation if present
# If you have the CoDA zip (e.g., coda_minigrid_project_coda-fixed.zip) next to this notebook, uncomment and run:
# import os, zipfile, sys
# zip_path = "coda_minigrid_project_coda-fixed.zip"  # adjust if needed
# if os.path.exists(zip_path):
#     with zipfile.ZipFile(zip_path, 'r') as zf:
#         zf.extractall("coda_repo")
#     sys.path.insert(0, os.path.abspath("coda_repo"))
#     try:
#         from coda import CoDAAgent as RepoCoDAAgent
#         print("Using CoDAAgent from your repository (RepoCoDAAgent).")
#     except Exception as e:
#         print("Could not import RepoCoDAAgent from your repo:", e)
# else:
#     print("No local CoDA zip found; using the minimal CoDA defined in this notebook.")


## Sun et al. (2025) 2ACDC near/far sequences and block definitions

Following Sun et al. Methods (CSCG section), we use the exact discrete sequences of **sensory symbols** (each element is a 10 cm segment):

- **Symbol legend**: `1` = grey corridor, `2` = near indicator (Ind_near), `3` = far indicator (Ind_far), `4` = visual cue at near reward zone (R1_vis), `5` = visual cue at far reward zone (R2_vis), `6` = water reward (shared), `7` = brick wall (end), `0` = teleportation.

- **Near trial**: `[1×6, 2×4, 1×3, 4, 6, 1×3, 5×2, 1×2, 7, 0×3]`  
  Expanded: `[1,1,1,1,1,1, 2,2,2,2, 1,1,1, 4, 6, 1,1,1, 5,5, 1,1, 7, 0,0,0]`

- **Far trial**: `[1×6, 3×4, 1×3, 4×2, 1×3, 5, 6, 1×2, 7, 0×3]`  
  Expanded: `[1,1,1,1,1,1, 3,3,3,3, 1,1,1, 4,4, 1,1,1, 5, 6, 1,1, 7, 0,0,0]`

**Key blocks (Fig. 4i/j):**
- `preR1` = the grey region between indicator and **visual R1** (`4`): 0-based indices `10..12` (inclusive) in both sequences.
- `preR2` = the grey region between (R1_vis/water) and **visual R2** (`5`): 0-based indices `15..17` (inclusive) in both sequences.
- `offdiag` = the two off‑diagonal cross‑blocks between `preR1` (near) × `preR2` (far) **and** `preR2` (near) × `preR1` (far).
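
As a quick sanity check on the run-length notation and the block indices above, the following sketch expands `(symbol, repeat)` pairs and confirms that both key blocks fall on grey segments. The helper name `expand` and the variable names are illustrative assumptions, not part of Sun et al.'s code.

```python
from typing import List, Tuple

def expand(runs: List[Tuple[int, int]]) -> List[int]:
    """Expand (symbol, repeat) pairs into a flat symbol sequence."""
    return [sym for sym, n in runs for _ in range(n)]

# Run-length form of the two trial types, copied from the text above.
near_runs = [(1, 6), (2, 4), (1, 3), (4, 1), (6, 1), (1, 3), (5, 2), (1, 2), (7, 1), (0, 3)]
far_runs = [(1, 6), (3, 4), (1, 3), (4, 2), (1, 3), (5, 1), (6, 1), (1, 2), (7, 1), (0, 3)]

near_seq = expand(near_runs)
far_seq = expand(far_runs)

pre_r1 = range(10, 13)  # 0-based, inclusive 10..12
pre_r2 = range(15, 18)  # 0-based, inclusive 15..17

for seq in (near_seq, far_seq):
    assert len(seq) == 26
    assert all(seq[i] == 1 for i in pre_r1)  # grey before visual R1
    assert all(seq[i] == 1 for i in pre_r2)  # grey before visual R2
    assert seq[13] == 4 and seq[18] == 5     # the cues that close each block
```

Because both blocks consist of the same grey symbol in both trial types, any decorrelation between them has to come from the latent states the model assigns, not from the observations themselves.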

In [None]:

# ==== Task sequences & index helpers ====
near = [1,1,1,1,1,1, 2,2,2,2, 1,1,1, 4, 6, 1,1,1, 5,5, 1,1, 7, 0,0,0]
far  = [1,1,1,1,1,1, 3,3,3,3, 1,1,1, 4,4, 1,1,1, 5, 6, 1,1, 7, 0,0,0]
assert len(near)==len(far)==26

# Indices for key blocks (inclusive ranges)
preR1_idx = list(range(10,13))   # 10,11,12
preR2_idx = list(range(15,18))   # 15,16,17

def block_indices(rows, cols):
    return [(r,c) for r in rows for c in cols]

offdiag_pairs     = block_indices(preR1_idx, preR2_idx) + block_indices(preR2_idx, preR1_idx)
same_preR1_pairs  = block_indices(preR1_idx, preR1_idx)
same_preR2_pairs  = block_indices(preR2_idx, preR2_idx)

symbols = sorted(set(near) | set(far))
symbols


## Minimal CoDA implementation (tabular, with contextual eligibility traces)

We implement CoDA following the paper’s pseudocode:

- Contextual eligibility traces (reward‑conditioned) to compute a **prospective contingency** for each state, `P(US | CS)`.
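- **Split rule:** when `max_u P(u|s) > theta_split`, mark `s` as a cue for outcome `u` and **clone** the state that immediately follows `s` in the current episode, so that downstream states can encode the cue‑dependent history (the code below clones one successor per split rather than the whole downstream chain).
- **Merge rule:** not implemented in this minimal version. `theta_merge` is accepted and stored for parity with the paper, which merges via `P(CS | US)` and the utility product, but no merge is expected to trigger in this fixed environment.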
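- **Multi‑US generalization:** for the 2ACDC task, we treat the two **visual reward zone cues** as the salient outcomes (`US ∈ {4 (R1_vis), 5 (R2_vis)}`) rather than water `6`, which is shared by both trial types. Note that both `4` and `5` occur in both trial types (see the sequences above); the trial types differ in how many segments each cue spans and in which cue is followed by water. `P(u|s)` therefore separates states through the timing and number of cue segments relative to `s`, not through the presence or absence of a cue.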

Default parameters mirror the paper: `lambda=0.8`, `gamma=0.9`, `theta_split=0.9`, `theta_merge=0.5` (see Table 2 in the CoDA manuscript).
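
To make the contingency computation concrete, here is a minimal sketch of the normalized co-occurrence that the agent below uses as `P(u|s)`, evaluated on a single near trial before any split. It is an illustration under simplifying assumptions: observation symbols stand in for the unsplit base states, and the helper names `trace_at_us` and `contingency` are hypothetical, not taken from the CoDA paper.

```python
from typing import Dict, Sequence

NEAR = [1,1,1,1,1,1, 2,2,2,2, 1,1,1, 4, 6, 1,1,1, 5,5, 1,1, 7, 0,0,0]

def trace_at_us(obs_seq: Sequence[int], us_symbol: int,
                gamma: float = 0.9, lam: float = 0.8) -> Dict[int, float]:
    """Sum, over every occurrence of us_symbol, each symbol's decayed trace at that time."""
    decay = gamma * lam
    totals: Dict[int, float] = {}
    for t_us, obs in enumerate(obs_seq):
        if obs != us_symbol:
            continue
        trace: Dict[int, float] = {}
        for t in range(t_us + 1):
            # decay every existing entry, then bump the symbol seen at step t
            trace = {sym: val * decay for sym, val in trace.items()}
            trace[obs_seq[t]] = trace.get(obs_seq[t], 0.0) + 1.0
        for sym, val in trace.items():
            totals[sym] = totals.get(sym, 0.0) + val
    return totals

def contingency(obs_seq: Sequence[int], cue: int,
                us_classes: Sequence[int] = (4, 5)) -> Dict[int, float]:
    """Normalize the cue's co-occurrence mass across the US classes."""
    co_occ = {u: trace_at_us(obs_seq, u).get(cue, 0.0) for u in us_classes}
    total = sum(co_occ.values())
    return {u: (co_occ[u] / total if total > 0 else 0.0) for u in us_classes}

p_indicator = contingency(NEAR, cue=2)  # near indicator
p_water = contingency(NEAR, cue=6)      # water

print(p_indicator)
print(p_water)
```

On this single near trial the near indicator gets `P(4|2)` of roughly 0.75, which is below `theta_split=0.9`, while water gets `P(5|6) = 1` because in a near trial water is only ever followed by `5`. Under these assumptions, which states cross the split threshold first depends on cue timing within the trial and not only on the indicator.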

In [None]:

# ==== CoDA agent ====
from dataclasses import dataclass

@dataclass
class LatentState:
    id: int
    obs: int                     # emitted observation symbol (aliased)
    path: Optional[str] = None   # None, 'R1', or 'R2' history tag
    parent: Optional[int] = None # id of original parent (for merges/visualization)

class CoDAAgent:
    def __init__(self, obs_symbols, gamma=0.9, lam=0.8, theta_split=0.9, theta_merge=0.5):
        self.gamma = gamma
        self.lam = lam
        self.theta_split = theta_split
        self.theta_merge = theta_merge
        self.reset_symbols = {0}         # reset context at teleportation
        
        # Base latent states: one per observation
        self.states: Dict[int, LatentState] = {}
        self.obs_to_state_ids: Dict[int, List[int]] = {o: [] for o in obs_symbols}
        sid = 0
        for o in obs_symbols:
            st = LatentState(id=sid, obs=o, path=None, parent=None)
            self.states[sid] = st
            self.obs_to_state_ids[o].append(sid)
            sid += 1
        self._next_sid = sid
        
        # Edges: count transitions actually experienced (for viz)
        self.edge_counts = collections.Counter()  # (sid_from, sid_to) -> count
        
        # Contingency tracking (multi-US): co-occurrence eligibility sums per state × outcome
        self.us_classes = [4,5] # R1_vis, R2_vis
        self.co_occ = {s: {u: 0.0 for u in self.us_classes} for s in self.states}
        
        # For retrospective (optional): counts of state presence on episodes with/without a given US
        self.state_exposure = collections.Counter()  # s -> total exposure amount (eligibility mass)
        self.salient: Dict[int, str] = {}  # state_id -> 'R1' or 'R2'
        
    def _clone_state(self, orig_state_id: int, path: str) -> int:
        orig = self.states[orig_state_id]
        clone = LatentState(id=self._next_sid, obs=orig.obs, path=path, parent=orig_state_id)
        self.states[self._next_sid] = clone
        self.obs_to_state_ids[orig.obs].append(self._next_sid)
        self._next_sid += 1
        return clone.id
    
    def _select_state_for_obs(self, obs: int, current_context: Optional[str]) -> int:
        # Choose existing latent state for a given observation and current context history.
        candidates = self.obs_to_state_ids[obs]
        # Prefer a candidate with matching path (context), else fall back to base (path=None)
        if current_context is not None:
            for sid in candidates:
                if self.states[sid].path == current_context:
                    return sid
        # fallback to base (path=None); if several, pick the very first base (id order)
        for sid in candidates:
            if self.states[sid].path is None:
                return sid
        # else return first candidate
        return candidates[0]
    
    def run_episode(self, obs_seq: List[int], learn=True):
        # Run one episode. If learn=True, update contingencies and consider splits.
        # Returns the sequence of active latent state ids (one per time step).
        # Track context based on salient cues encountered
        context = None
        latent_seq = []
        
        # First, pass through the sequence to get latent states actually traversed.
        # We do not split during traversal; splits happen after contingency updates.
        for t, obs in enumerate(obs_seq):
            sid = self._select_state_for_obs(obs, context)
            latent_seq.append(sid)
            # Update context if current sid is a salient cue
            if sid in self.salient:
                context = self.salient[sid]
            if obs in self.reset_symbols:
                context = None
        
        if not learn:
            return latent_seq
        
        # === CONTINGENCY UPDATES via contextual eligibility ===
        # Identify positions of US events: visual R1 (4) and visual R2 (5)
        us_positions = {4: [i for i,o in enumerate(obs_seq) if o==4],
                        5: [i for i,o in enumerate(obs_seq) if o==5]}
        
        # For each US event, accumulate a decayed eligibility trace over prior states
        for u, positions in us_positions.items():
            for t_us in positions:
                e = np.zeros(len(self.states), dtype=float)
                # accumulate a decaying eligibility trace over steps 0..t_us
                for t in range(t_us+1):
                    sid = latent_seq[t]
                    # decay
                    e *= (self.gamma * self.lam)
                    e[sid] += 1.0
                    # store state exposure (for optional retrospective)
                    self.state_exposure[sid] += e[sid]
                # Add the snapshot at US time to co-occurrence
                for s_id, val in enumerate(e):
                    if val>0:
                        self.co_occ[s_id][u] += val
        
        # === SPLITTING (forward / prospective) ===
        # Compute P(u|s) as normalized co-occurrence over US classes
        P = {}
        for s in self.states:
            tot = sum(self.co_occ[s][u] for u in self.us_classes)
            if tot<=0:
                P[s] = {u: 0.0 for u in self.us_classes}
            else:
                P[s] = {u: self.co_occ[s][u]/tot for u in self.us_classes}
        
        # Mark new salient cues
        newly_salient = []
        for s in self.states:
            if s in self.salient:
                continue
            if sum(self.co_occ[s].values())<=0:
                continue
            # strongest US
            u_star = max(self.us_classes, key=lambda u: P[s][u])
            if P[s][u_star] > self.theta_split:
                path = 'R1' if u_star==4 else 'R2'
                self.salient[s] = path
                newly_salient.append((s, path))
        
        # Create clones of successors along the same path (propagate downstream splitting)
        # We look at actual transitions experienced in this episode and clone the "next" state under context
        for s, path in newly_salient:
            # find all indices where latent_seq[t]==s and t+1 exists
            idxs = [t for t, sid in enumerate(latent_seq[:-1]) if sid==s]
            for t in idxs:
                next_obs = obs_seq[t+1]
                # ensure a clone exists for next_obs with matching path
                next_candidates = self.obs_to_state_ids[next_obs]
                has_clone = any(self.states[cid].path==path for cid in next_candidates)
                if not has_clone:
                    self._clone_state(orig_state_id=next_candidates[0], path=path)
        
        # === Record edges (for viz) ===
        context = None
        for t in range(len(obs_seq)-1):
            sid = self._select_state_for_obs(obs_seq[t], context)
            nid = self._select_state_for_obs(obs_seq[t+1], context if sid not in self.salient else self.salient[sid])
            self.edge_counts[(sid, nid)] += 1
            if sid in self.salient:
                context = self.salient[sid]
            if obs_seq[t+1] in self.reset_symbols:
                context = None
        
        return latent_seq
    
    # ---------- Evaluation helpers ----------
    def encode_sequence(self, obs_seq: List[int]) -> np.ndarray:
        # Return T x S one-hot occupancy of latent states for this sequence with current splitting (no learning).
        lat = self.run_episode(obs_seq, learn=False)
        S = max(self.states)+1
        X = np.zeros((len(lat), S), dtype=float)
        for t, sid in enumerate(lat):
            X[t, sid] = 1.0
        return X[:, :len(self.states)]  # no-op while state ids are contiguous (0..S-1)
    
    def near_far_corr(self, near_seq, far_seq) -> np.ndarray:
        A = self.encode_sequence(near_seq)  # 26 x S
        B = self.encode_sequence(far_seq)   # 26 x S
        # Pearson correlation between rows of A and rows of B
        C = np.zeros((A.shape[0], B.shape[0]))
        for i in range(A.shape[0]):
            for j in range(B.shape[0]):
                a = A[i]; b = B[j]
                if np.allclose(a,0) or np.allclose(b,0):
                    C[i,j] = 0.0
                else:
                    # center
                    a0 = a - a.mean(); b0 = b - b.mean()
                    denom = (np.linalg.norm(a0)*np.linalg.norm(b0))
                    C[i,j] = (a0@b0)/denom if denom>0 else 0.0
        return C
    
    def graph_for_plot(self):
        # Return a NetworkX DiGraph of current latent states and the observed episode-derived edges.
        G = nx.DiGraph()
        for sid, st in self.states.items():
            label = f"{sid}\nobs={st.obs}\n{st.path or 'base'}"
            color = {None:'#cccccc','R1':'#ff88aa','R2':'#88ccff'}[st.path]
            G.add_node(sid, label=label, color=color, obs=st.obs, path=st.path)
        # only show edges that were experienced
        for (u,v), c in self.edge_counts.items():
            if u in self.states and v in self.states:
                G.add_edge(u,v, weight=c)
        return G
    
# --- Hot‑patch for CoDAAgent: initialize contingency bookkeeping for clones ---
def _ensure_state_keys(self, sid):
    # Make sure internal counters exist for any state id
    if sid not in self.co_occ:
        self.co_occ[sid] = {u: 0.0 for u in self.us_classes}
    if sid not in self.state_exposure:
        self.state_exposure[sid] = 0.0

def _clone_state_fixed(self, orig_state_id: int, path: str) -> int:
    orig = self.states[orig_state_id]
    clone_id = self._next_sid
    clone = LatentState(id=clone_id, obs=orig.obs, path=path, parent=orig_state_id)
    self.states[clone_id] = clone
    self.obs_to_state_ids[orig.obs].append(clone_id)
    # NEW: initialize contingency bookkeeping for the clone
    _ensure_state_keys(self, clone_id)
    self._next_sid += 1
    return clone_id

# Bind the helpers to the class (monkey-patch)
CoDAAgent._ensure_state_keys = _ensure_state_keys
CoDAAgent._clone_state = _clone_state_fixed

# (Optional, defensive) If you want to “repair” an agent that already exists:
def _repair_bookkeeping(self):
    for s in list(self.states.keys()):
        _ensure_state_keys(self, s)

CoDAAgent._repair_bookkeeping = _repair_bookkeeping    


## Train CoDA on randomized near/far episodes and evaluate

We simulate multiple runs of learning. Each run has several sessions; within a session, we shuffle a balanced set of near and far episodes (half each), mirroring the experiment (both trial types, randomized order). After every session we compute:

- The near–far cross‑correlation matrix.
- The block means for `offdiag`, `preR2`, `preR1`.
- The earliest session where each block’s mean drops below the 0.3 threshold (Fig. 4j).

You can adjust `N_RUNS`, `SESSIONS`, and `TRIALS_PER_SESSION` to match your compute budget.
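
For intuition about the values entering the block means above: each position is encoded as a one-hot vector over latent states, so the Pearson correlation between a near position and a far position takes only two values. The sketch below shows this with a hypothetical state count (`N_STATES = 10` is an assumption for illustration; the real count grows as states are cloned).

```python
import math
from typing import List

def one_hot(index: int, n_states: int) -> List[float]:
    """One-hot occupancy vector for a single active latent state."""
    return [1.0 if k == index else 0.0 for k in range(n_states)]

def pearson(a: List[float], b: List[float]) -> float:
    """Pearson correlation between two equal-length vectors (0.0 if either is constant)."""
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    a0 = [x - mean_a for x in a]
    b0 = [y - mean_b for y in b]
    denom = math.sqrt(sum(x * x for x in a0)) * math.sqrt(sum(y * y for y in b0))
    if denom == 0:
        return 0.0
    return sum(x * y for x, y in zip(a0, b0)) / denom

N_STATES = 10  # hypothetical number of latent states

same_state = pearson(one_hot(3, N_STATES), one_hot(3, N_STATES))   # shared latent state
split_state = pearson(one_hot(3, N_STATES), one_hot(7, N_STATES))  # different latent states

print(same_state, split_state)
```

A pair of positions that share a latent state correlates at 1, and a pair mapped to different states correlates at `-1/(N_STATES - 1)`. A block mean therefore tracks the fraction of position pairs that still share a state, and it falls below 0.3 once most pairs in the block have been separated.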

In [None]:

# ==== Training params ====
N_RUNS = 8
SESSIONS = 9
TRIALS_PER_SESSION = 80
THRESH = 0.3

def block_mean(C, pairs):
    if len(pairs)==0: return np.nan
    vals = [C[i,j] for (i,j) in pairs]
    return float(np.mean(vals))

rng = np.random.default_rng(123)

all_final_blocks = []   # (run, offdiag, preR2, preR1)
time_to_thresh = []     # (run, offdiag_t, preR2_t, preR1_t) in [0,1]

# Sessions to plot as example matrices, like Fig. 4d (optional); skip any beyond SESSIONS
checkpoint_sessions = [s for s in (1, 3, 4, 9) if s <= SESSIONS]
example_mats = {s: [] for s in checkpoint_sessions}

for run in range(N_RUNS):
    agent = CoDAAgent(obs_symbols=symbols, gamma=0.9, lam=0.8, theta_split=0.9, theta_merge=0.5)
    tt = {'offdiag': None, 'preR2': None, 'preR1': None}
    
    for session in range(1, SESSIONS+1):
        episodes = [near]*(TRIALS_PER_SESSION//2) + [far]*(TRIALS_PER_SESSION//2)
        rng.shuffle(episodes)
        for ep in episodes:
            agent.run_episode(ep, learn=True)
        
        C = agent.near_far_corr(near, far)
        if session in example_mats: example_mats[session].append(C)
        
        b_off = block_mean(C, offdiag_pairs)
        b_r2  = block_mean(C, same_preR2_pairs)
        b_r1  = block_mean(C, same_preR1_pairs)
        
        if tt['offdiag'] is None and b_off < THRESH: tt['offdiag'] = session
        if tt['preR2']  is None and b_r2  < THRESH: tt['preR2']  = session
        if tt['preR1']  is None and b_r1  < THRESH: tt['preR1']  = session
    
    C_final = agent.near_far_corr(near, far)
    all_final_blocks.append((run, block_mean(C_final, offdiag_pairs),
                                  block_mean(C_final, same_preR2_pairs),
                                  block_mean(C_final, same_preR1_pairs)))
    
    def norm_t(x): 
        return (x/SESSIONS) if x is not None else np.nan
    time_to_thresh.append((run, norm_t(tt['offdiag']), norm_t(tt['preR2']), norm_t(tt['preR1'])))

blocks_df = pd.DataFrame(all_final_blocks, columns=['run','offdiag','preR2','preR1']).set_index('run')
timet_df = pd.DataFrame(time_to_thresh, columns=['run','offdiag_t','preR2_t','preR1_t']).set_index('run')

display(blocks_df.describe())
display(timet_df.describe())


### Example near–far correlation matrices over learning (like Fig. 4d)

We show the mean across runs at sessions 1, 3, 4, 9 for illustration.

In [None]:

mean_mats = {s: np.mean(example_mats[s], axis=0) for s in checkpoint_sessions}

fig, axes = plt.subplots(1, len(checkpoint_sessions), figsize=(4.5*len(checkpoint_sessions),4), constrained_layout=True)
vmin, vmax = -0.1, 1.0
for ax, s in zip(axes, checkpoint_sessions):
    im = ax.imshow(mean_mats[s], vmin=vmin, vmax=vmax, origin='lower', aspect='auto')
    ax.set_title(f"Session {s}")
    ax.set_xlabel("Far position index")
    ax.set_ylabel("Near position index")
fig.colorbar(im, ax=axes, fraction=0.046, pad=0.04)
plt.show()


## Fig. 4i analogue — final matrix block quantification (CoDA)

Bars summarize the final near–far correlation matrix by **off‑diagonal**, **pre‑R2**, and **pre‑R1** blocks.
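
The bar heights and error bars are the across-run mean and standard error of the mean (SEM). As a sketch of that arithmetic, using made-up block values rather than results from this notebook, and a hypothetical helper `mean_and_sem`:

```python
import math
import statistics
from typing import Sequence, Tuple

def mean_and_sem(values: Sequence[float]) -> Tuple[float, float]:
    """Mean and standard error (sample std with n-1 in the denominator, divided by sqrt(n))."""
    n = len(values)
    mean = statistics.fmean(values)
    sem = statistics.stdev(values) / math.sqrt(n) if n > 1 else math.nan
    return mean, sem

# Made-up final off-diagonal block means from four hypothetical runs.
final_offdiag = [0.10, 0.14, 0.06, 0.10]
mean, sem = mean_and_sem(final_offdiag)
print(mean, sem)
```

This matches the default behaviour of pandas `DataFrame.sem()`, which also uses `ddof=1`. If every run ends with the same block value the SEM is zero and the error bars vanish.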

In [None]:

means = blocks_df.mean()
ses   = blocks_df.sem()

fig, ax = plt.subplots(figsize=(5,4))
labels = ['offdiag','preR2','preR1']
x = np.arange(len(labels))
y = [means[l] for l in labels]
yerr = [ses[l] for l in labels]
ax.bar(x, y, yerr=yerr, capsize=4)
ax.set_xticks(x); ax.set_xticklabels(labels)
ax.set_ylabel("Mean correlation (final)")
ax.set_title("CoDA — Fig. 4i analogue")
plt.show()


## Fig. 4j analogue — time to reach threshold (0.3)

Bars show the fraction of training (sessions normalized to 1.0) until each block’s mean correlation first drops below **0.3**.
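
The time-to-threshold measure can be sketched as follows. The per-session block means here are made up for illustration, and the helper names are assumptions rather than names used elsewhere in this notebook.

```python
import math
from typing import Optional, Sequence

def first_session_below(block_means: Sequence[float], thresh: float = 0.3) -> Optional[int]:
    """Return the 1-based session at which the block mean first drops below thresh."""
    for session, value in enumerate(block_means, start=1):
        if value < thresh:
            return session
    return None  # threshold never reached

def fraction_of_training(session: Optional[int], n_sessions: int) -> float:
    """Normalize a session number to (0, 1]; NaN if the threshold was never reached."""
    return session / n_sessions if session is not None else math.nan

# Made-up block means over four sessions.
block_trace = [1.0, 0.62, 0.28, 0.05]
hit = first_session_below(block_trace)
frac = fraction_of_training(hit, len(block_trace))
never = fraction_of_training(first_session_below([1.0, 0.9]), 2)
print(hit, frac, never)
```

Runs in which a block never drops below the threshold yield NaN and are skipped by `mean(skipna=True)`, so a bar can be computed from fewer runs than `N_RUNS`. It is worth checking the non-NaN count per block before comparing bars.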

In [None]:

means_t = timet_df.mean(skipna=True)
ses_t   = timet_df.sem(skipna=True)

fig, ax = plt.subplots(figsize=(5,4))
labels = ['offdiag_t','preR2_t','preR1_t']
disp_labels = ['offdiag','preR2','preR1']
x = np.arange(len(labels))
y = [means_t[l] for l in labels]
yerr = [ses_t[l] for l in labels]
ax.bar(x, y, yerr=yerr, capsize=4)
ax.set_xticks(x); ax.set_xticklabels(disp_labels)
ax.set_ylim(0, 1.05)
ax.set_ylabel("Fraction of training (to corr < 0.3)")
ax.set_title("CoDA — Fig. 4j analogue")
plt.show()


## Bonus — Fig. 4c analogue: CoDA’s final latent-state transition graph

We plot the final graph of a single freshly trained agent (seeded separately from the runs above). Nodes are colored by history path (base/R1/R2) and labeled with state id, observation, and path. Edges are those experienced during training (thicker = more traversals).

In [None]:

# Train a fresh agent (own seed, so not identical to run 0 above) and then plot its graph
agent = CoDAAgent(obs_symbols=symbols, gamma=0.9, lam=0.8, theta_split=0.9, theta_merge=0.5)
rng = np.random.default_rng(0)
for session in range(SESSIONS):
    episodes = [near]*(TRIALS_PER_SESSION//2) + [far]*(TRIALS_PER_SESSION//2)
    rng.shuffle(episodes)
    for ep in episodes:
        agent.run_episode(ep, learn=True)

G = agent.graph_for_plot()

# Simple layout: one row per path (base/R1/R2); x is the node's rank when sorted by (path, obs, id)
pos = {}
order_map = {None:0, 'R1':1, 'R2':2}
sorted_nodes = sorted(G.nodes(data=True), key=lambda kv: (order_map[kv[1]['path']], kv[1]['obs'], kv[0]))
for i,(nid,nd) in enumerate(sorted_nodes):
    pos[nid] = (i, order_map[nd['path']])

fig, ax = plt.subplots(figsize=(12,3), constrained_layout=True)
colors = [G.nodes[n]['color'] for n in G.nodes]
edge_widths = [max(0.5, G.edges[e].get('weight', 1) / 50.0) for e in G.edges]
nx.draw(G, pos, with_labels=False, node_color=colors, node_size=400,
        arrows=True, ax=ax, width=edge_widths)

# add small text labels
for n in G.nodes:
    lab = G.nodes[n]['label']
    ax.text(pos[n][0], pos[n][1]+0.1, lab, fontsize=7, ha='center')

ax.set_title("CoDA final latent-state graph (one run)")
ax.axis('off')
plt.show()


## (Optional) Overlay animal / CSCG bars

If you have Sun et al.’s **Source Data** CSV for Fig. 4 or the CSCG notebook outputs, you can place them in a local folder and point `SOURCE_DATA_CSV` to it to overlay comparison bars. By default this cell does nothing.

In [None]:

# Example scaffold for overlay (disabled by default).
SOURCE_DATA_CSV = None  # e.g., "source_data/figure4_source_data.csv"
if SOURCE_DATA_CSV:
    try:
        src = pd.read_csv(SOURCE_DATA_CSV)
        print("Loaded source data with", len(src), "rows")
        # Implement overlay if desired
    except Exception as e:
        print("Could not load source data:", e)
else:
    print("No source data provided; showing CoDA-only figures.")


---

### Notes & parameters
- **CoDA params:** `gamma=0.9`, `lambda=0.8`, `theta_split=0.9`, `theta_merge=0.5` (Table 2 in the CoDA manuscript).
- **Training harness:** `SESSIONS=9`, `TRIALS_PER_SESSION=80`, `N_RUNS=8` by default; increase for more stable error bars.
- **Threshold for Fig. 4j:** `0.3`, as in Sun et al.’s Methods.
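- **Generalization of US:** we use `US∈{4,5}` (visual reward zone cues) rather than the shared water (`6`), which avoids the trivial case where water follows every state. Both cues occur in both trial types, so the indicator predicts which cue is paired with water, not which cue appears. The symbol sequences themselves are those of the CSCG simulation setup.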