# CoDA in the 2ACDC near/far task — Reproductions of OSM Fig. 4c, 4i, 4j

This notebook mirrors the code structure of the official OSM figure notebooks to:

- Simulate **CoDA** on the 2ACDC near/far symbol sequences used in the CSCG simulation.
- Reproduce the **Fig. 4c** transition graph (clear layout and colored branches).
- Reproduce the **Fig. 4i** final correlation block quantification.
- Reproduce the **Fig. 4j** decorrelation order (time-to-threshold).

References to the original notebooks:
- Transition graph: `fig_4/fig_4_CSCG/fig_4_c_Transition_graph.ipynb`.
- Final correlation quantification: `fig_4/fig_4i_Final_correlation_quantification.ipynb`.
- Decorrelation order: `fig_4/fig_4j_decorr_order.ipynb`.


In [None]:

# Core imports
import math, random, itertools, collections
from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Graphing
import networkx as nx

# Reproducibility: seed the global `random` and `np.random` generators.
# NOTE: the training cells further down shuffle episodes with their own
# `np.random.default_rng(...)` instances, which are seeded separately.
random.seed(0); np.random.seed(0)


## Task definition (2ACDC near/far symbol sequences)

We use the exact discrete sequences used by the CSCG simulation in the OSM repo. Symbols: `1`=grey corridor, `2`=near-indicator, `3`=far-indicator, `4`=R1 visual, `5`=R2 visual, `6`=water, `7`=wall, `0`=teleport.

In [None]:

# Each trial is 26 symbols long; the sequences are written as runs of repeated
# symbols so the corridor segments are easy to compare between trial types.
near = [1]*6 + [2]*4 + [1]*3 + [4] + [6] + [1]*3 + [5]*2 + [1]*2 + [7] + [0]*3
far = [1]*6 + [3]*4 + [1]*3 + [4]*2 + [1]*3 + [5] + [6] + [1]*2 + [7] + [0]*3
assert len(near) == 26 and len(far) == 26

# Position indices of the grey-corridor stretches analysed in the Fig. 4 blocks.
preR1_idx = [10, 11, 12]   # corridor between the indicator and the R1 visual
preR2_idx = [15, 16, 17]   # corridor between (R1/water) and the R2 visual

def block_indices(rows, cols):
    """Return every ``(row, col)`` index pair of a matrix block, row-major.

    Args:
        rows: Iterable of row indices.
        cols: Iterable of column indices.

    Returns:
        List of ``(r, c)`` tuples ordered by ``rows`` first, then ``cols``.
        Empty if either input is empty.
    """
    # product() materialises both inputs once, so one-shot iterators (e.g. a
    # generator for ``cols``) still pair with every row instead of being
    # exhausted after the first one.
    return list(itertools.product(rows, cols))

# Within-segment blocks (same corridor stretch on both axes).
same_preR1_pairs = block_indices(preR1_idx, preR1_idx)
same_preR2_pairs = block_indices(preR2_idx, preR2_idx)

# Cross-segment blocks: preR1 x preR2 followed by preR2 x preR1.
offdiag_pairs = [
    *block_indices(preR1_idx, preR2_idx),
    *block_indices(preR2_idx, preR1_idx),
]

# Every distinct symbol occurring in either trial type, ascending.
symbols = sorted({*near, *far})
symbols


## CoDA (minimal, tabular) with contextual eligibility traces

This implementation follows the state-splitting logic in the CoDA manuscript (the merge threshold `theta_merge` is stored, but merging is not implemented here) and adds a bookkeeping fix so that cloned states get their own contingency counters.

In [None]:

@dataclass
class LatentState:
    """A single latent state of the CoDA agent, tied to one observation symbol."""
    id: int                        # key of this state in the agent's `states` dict
    obs: int                       # observation symbol this state stands for
    path: Optional[str] = None     # context label: None for base states, else 'R1' or 'R2'
    parent: Optional[int] = None   # id of the state this one was cloned from (None for base states)

class CoDAAgent:
    """Minimal tabular CoDA agent that splits latent states by predicted outcome.

    Every observation symbol starts with one "base" latent state. While
    learning, an eligibility trace is accumulated up to each occurrence of a
    US symbol and added to per-state contingency counters. A state whose
    normalised contingency for one US class exceeds ``theta_split`` becomes
    *salient*: from then on, visiting it sets the episode context to that
    US's path label, and the states that followed it in the episode are
    cloned for that path.

    Args:
        obs_symbols: Iterable of observation symbols. One base state is
            created per symbol, with ids assigned in iteration order.
        gamma: Discount factor of the eligibility trace.
        lam: Trace-decay factor; the per-step decay is ``gamma * lam``.
        theta_split: Contingency threshold above which a state becomes salient.
        theta_merge: Stored on the instance; no method of this class reads it.
        us_classes: Observation symbols treated as US events. The k-th entry
            (1-based) is given the path label ``'R<k>'``.
        reset_symbols: Observation symbols that clear the context after they
            have been processed.
    """

    def __init__(self, obs_symbols, gamma=0.9, lam=0.8, theta_split=0.9, theta_merge=0.5,
                 *, us_classes=(4, 5), reset_symbols=(0,)):
        self.gamma = gamma
        self.lam = lam
        self.theta_split = theta_split
        self.theta_merge = theta_merge
        self.reset_symbols = set(reset_symbols)

        # Materialise once so that one-shot iterables are not exhausted by the
        # first of the two passes below.
        obs_symbols = list(obs_symbols)

        # Base latent states
        self.states: Dict[int, LatentState] = {}
        self.obs_to_state_ids: Dict[int, List[int]] = {o: [] for o in obs_symbols}
        for sid, o in enumerate(obs_symbols):
            self.states[sid] = LatentState(id=sid, obs=o, path=None, parent=None)
            self.obs_to_state_ids[o].append(sid)
        self._next_sid = len(obs_symbols)

        # Transition counts (for graph)
        self.edge_counts = collections.Counter()

        # Contingency storage
        self.us_classes = list(us_classes)
        # US symbol -> path label; the defaults give {4: 'R1', 5: 'R2'}.
        self.us_paths = {u: f'R{k}' for k, u in enumerate(self.us_classes, start=1)}
        self.co_occ = {s: {u: 0.0 for u in self.us_classes} for s in self.states}
        self.state_exposure = collections.Counter()
        self.salient: Dict[int, str] = {}

    # --- helpers ---
    def _ensure_state_keys(self, sid):
        """Create zeroed contingency and exposure entries for ``sid`` if missing."""
        if sid not in self.co_occ:
            self.co_occ[sid] = {u: 0.0 for u in self.us_classes}
        if sid not in self.state_exposure:
            self.state_exposure[sid] = 0.0

    def _clone_state(self, orig_state_id: int, path: str) -> int:
        """Add a copy of ``orig_state_id`` tagged with ``path`` and return its id."""
        orig = self.states[orig_state_id]
        clone_id = self._next_sid
        self.states[clone_id] = LatentState(
            id=clone_id, obs=orig.obs, path=path, parent=orig_state_id)
        self.obs_to_state_ids[orig.obs].append(clone_id)
        # Clones need their own counters, otherwise they could never become
        # salient themselves.
        self._ensure_state_keys(clone_id)
        self._next_sid += 1
        return clone_id

    def _select_state_for_obs(self, obs: int, current_context: Optional[str]) -> int:
        """Pick the latent state that represents ``obs`` under ``current_context``.

        Preference order: a state whose path equals the context, then the
        base (path ``None``) state, then the first registered state.

        Raises:
            KeyError: If ``obs`` is not one of the agent's observation symbols.
        """
        candidates = self.obs_to_state_ids[obs]
        if current_context is not None:
            for sid in candidates:
                if self.states[sid].path == current_context:
                    return sid
        for sid in candidates:
            if self.states[sid].path is None:
                return sid
        return candidates[0]

    def _infer_latents(self, obs_seq: List[int]) -> List[int]:
        """Map an observation sequence to latent-state ids using the current states."""
        context = None
        latents = []
        for obs in obs_seq:
            sid = self._select_state_for_obs(obs, context)
            latents.append(sid)
            # A salient state switches the context for everything after it.
            context = self.salient.get(sid, context)
            # The reset applies only after the reset symbol itself was mapped.
            if obs in self.reset_symbols:
                context = None
        return latents

    # --- learning ---
    def _update_contingencies(self, obs_seq: List[int], latent_seq: List[int]) -> None:
        """Add the eligibility trace at every US occurrence to ``co_occ``."""
        decay = self.gamma * self.lam
        n_states = max(self.states) + 1
        for u in self.us_classes:
            for t_us, obs in enumerate(obs_seq):
                if obs != u:
                    continue
                # Fresh accumulating trace over the episode up to and
                # including the US position.
                e = np.zeros(n_states, dtype=float)
                for sid in latent_seq[:t_us + 1]:
                    e *= decay
                    e[sid] += 1.0
                    self.state_exposure[sid] += e[sid]
                # Snapshot of the trace at US time
                for s_id in np.flatnonzero(e > 0):
                    s_id = int(s_id)
                    if s_id in self.states:
                        self._ensure_state_keys(s_id)
                        self.co_occ[s_id][u] += float(e[s_id])

    def _mark_salient_states(self):
        """Flag states whose contingency exceeds ``theta_split``.

        Returns:
            List of ``(state_id, path)`` for the states flagged by this call.
        """
        newly_salient = []
        for s in self.states:
            self._ensure_state_keys(s)
            if s in self.salient:
                continue
            counts = self.co_occ[s]
            total = sum(counts[u] for u in self.us_classes)
            if total <= 0:
                continue
            # Prospective contingency P(u|s)
            p = {u: counts[u] / total for u in self.us_classes}
            u_star = max(self.us_classes, key=p.__getitem__)
            if p[u_star] > self.theta_split:
                path = self.us_paths[u_star]
                self.salient[s] = path
                newly_salient.append((s, path))
        return newly_salient

    def _propagate_splits(self, obs_seq, latent_seq, newly_salient) -> None:
        """Clone the immediate successors of newly salient states along their path."""
        for s, path in newly_salient:
            for t, sid in enumerate(latent_seq[:-1]):
                if sid != s:
                    continue
                cands = self.obs_to_state_ids[obs_seq[t + 1]]
                if not any(self.states[c].path == path for c in cands):
                    self._clone_state(cands[0], path)

    def _record_edges(self, obs_seq: List[int]) -> None:
        """Count transitions along the latent path inferred after splitting."""
        # Use the same inference as run_episode so that the recorded edges
        # always chain together, including across reset symbols (a separate
        # pass with different reset timing would record base->base edges where
        # the path actually goes clone->base).
        latents = self._infer_latents(obs_seq)
        self.edge_counts.update(zip(latents, latents[1:]))

    def run_episode(self, obs_seq: List[int], learn=True):
        """Infer the latent path of one episode and optionally learn from it.

        Args:
            obs_seq: Observation symbols of the episode, in order.
            learn: If false, only inference is performed and no agent state
                is modified.

        Returns:
            Latent-state ids, one per observation, as inferred *before* any
            learning triggered by this episode.
        """
        latent_seq = self._infer_latents(obs_seq)
        if not learn:
            return latent_seq

        self._update_contingencies(obs_seq, latent_seq)
        newly_salient = self._mark_salient_states()
        self._propagate_splits(obs_seq, latent_seq, newly_salient)
        self._record_edges(obs_seq)
        return latent_seq

    # --- analysis ---
    def encode_sequence(self, obs_seq: List[int]) -> np.ndarray:
        """Return a ``(len(obs_seq), n_states)`` one-hot matrix of the latent path."""
        lat = self.run_episode(obs_seq, learn=False)
        X = np.zeros((len(lat), max(self.states) + 1), dtype=float)
        X[np.arange(len(lat)), np.asarray(lat, dtype=int)] = 1.0
        return X

    def near_far_corr(self, near_seq, far_seq) -> np.ndarray:
        """Pearson correlation between every near and every far position code.

        Returns:
            Array ``C`` with ``C[i, j]`` the correlation between the one-hot
            code at near position ``i`` and far position ``j``; entries whose
            denominator is zero are set to 0.
        """
        A = self.encode_sequence(near_seq)
        B = self.encode_sequence(far_seq)
        A0 = A - A.mean(axis=1, keepdims=True)
        B0 = B - B.mean(axis=1, keepdims=True)
        denom = np.outer(np.linalg.norm(A0, axis=1), np.linalg.norm(B0, axis=1))
        C = np.zeros_like(denom)
        np.divide(A0 @ B0.T, denom, out=C, where=denom > 0)
        return C


## Train CoDA and gather matrices across sessions (as in Fig. 4i & 4j)

We randomize the order of near/far episodes within each session. After each session we compute the near–far correlation matrix, then extract the three key block averages used by the OSM notebooks (`offdiag`, `preR2`, `preR1`).

In [None]:

N_RUNS = 8                 # independent agents, each trained from scratch
SESSIONS = 9               # training sessions per run
TRIALS_PER_SESSION = 80    # episodes per session (half near, half far)
THRESH = 0.3               # correlation level used for the time-to-threshold measure

def block_mean(C, pairs):
    """Average the entries of ``C`` at the given ``(row, col)`` index pairs.

    Returns NaN when ``pairs`` is empty.
    """
    if pairs:
        entries = [C[row, col] for row, col in pairs]
        return float(np.mean(entries))
    return np.nan

rng = np.random.default_rng(123)

# Block name -> index pairs, in the column order shared by both result tables.
block_pairs = {
    'offdiag': offdiag_pairs,
    'preR2': same_preR2_pairs,
    'preR1': same_preR1_pairs,
}

# Storage
all_final_blocks = []     # rows of (run, offdiag, preR2, preR1)
all_time_to_thr = []      # rows of (run, t_off, t_preR2, t_preR1), as fraction of SESSIONS
mat_by_session = {1: [], 3: [], 4: [], 9: []}   # correlation matrices kept for these sessions


def norm(x):
    """Express a session number as a fraction of SESSIONS (NaN if never reached)."""
    return x/SESSIONS if x is not None else np.nan


for run in range(N_RUNS):
    agent = CoDAAgent(obs_symbols=sorted(set(near) | set(far)))
    # First session at which each block mean fell below THRESH (None = not yet).
    tt = dict.fromkeys(block_pairs)
    for session in range(1, SESSIONS + 1):
        episodes = [near] * (TRIALS_PER_SESSION // 2) + [far] * (TRIALS_PER_SESSION // 2)
        rng.shuffle(episodes)
        for ep in episodes:
            agent.run_episode(ep, learn=True)

        # Evaluate after the session
        C = agent.near_far_corr(near, far)
        if session in mat_by_session:
            mat_by_session[session].append(C)
        for block_name, pairs in block_pairs.items():
            if tt[block_name] is None and block_mean(C, pairs) < THRESH:
                tt[block_name] = session

    # Final block means and normalised crossing times for this run
    C_final = agent.near_far_corr(near, far)
    all_final_blocks.append(
        (run, *(block_mean(C_final, pairs) for pairs in block_pairs.values())))
    all_time_to_thr.append(
        (run, *(norm(tt[block_name]) for block_name in block_pairs)))

blocks_df = pd.DataFrame(
    all_final_blocks, columns=['run', 'offdiag', 'preR2', 'preR1']).set_index('run')
times_df = pd.DataFrame(
    all_time_to_thr, columns=['run', 'offdiag_t', 'preR2_t', 'preR1_t']).set_index('run')

display(blocks_df.describe())
display(times_df.describe())


## Fig. 4c (transition graph) — OSM‑style clean layout

We render **two colored trajectories** (near=R1 branch, far=R2 branch) on a layered layout:
- y=0: base (pre‑split) latent states, y=+1: R1 path (near), y=−1: R2 path (far).
- x‑position is the **mean step index** where a latent state appears across the two trajectories (gives left‑to‑right flow).
- Only edges along **canonical near/far trajectories** are drawn (no cross‑episode shortcuts), with arc offsets to avoid overlap.

In [None]:

def latent_path(agent, seq):
    """Return ``[(t, state_id), ...]`` for one inference-only pass over ``seq``."""
    inferred = agent.run_episode(seq, learn=False)
    return [(step, state_id) for step, state_id in enumerate(inferred)]

# Train one fresh agent over all sessions to obtain a representative final graph.
agent = CoDAAgent(obs_symbols=sorted(set(near) | set(far)))
rng = np.random.default_rng(0)
for session in range(SESSIONS):
    episodes = [near] * (TRIALS_PER_SESSION // 2) + [far] * (TRIALS_PER_SESSION // 2)
    rng.shuffle(episodes)
    for ep in episodes:
        agent.run_episode(ep, learn=True)

# Canonical latent trajectories for the two trial types
near_path = latent_path(agent, near)
far_path = latent_path(agent, far)

# Collect, per latent state, every step index at which it occurs (near first,
# then far); the mean of these gives the node's x position.
from collections import defaultdict
time_acc = defaultdict(list)
for t, sid in itertools.chain(near_path, far_path):
    time_acc[sid].append(t)

pos = {}
def y_level(st):
    """Return the plot row for a latent state: base 0, 'R1' +1, 'R2' -1."""
    rows_by_path = {'R1': 1.0, 'R2': -1.0, None: 0.0}
    return rows_by_path[st.path]

for sid, st in agent.states.items():
    # States that never occur on either canonical path are parked at x = -1,
    # i.e. on the left edge of the plot.
    x = np.mean(time_acc[sid]) if sid in time_acc else -1.0
    pos[sid] = (x, y_level(st))

# Edges are the consecutive state pairs of each canonical path only
near_edges = [(src, dst) for (_, src), (_, dst) in zip(near_path, near_path[1:])]
far_edges = [(src, dst) for (_, src), (_, dst) in zip(far_path, far_path[1:])]

# Draw
fig, ax = plt.subplots(figsize=(12,3), constrained_layout=True)

# Nodes, coloured by the path tag of their latent state
palette = {None: '#cfcfcf', 'R1': '#ff9bb0', 'R2': '#8fbff5'}
node_colors = [palette[st.path] for st in agent.states.values()]

nx.draw_networkx_nodes(
    nx.DiGraph(),
    pos,
    nodelist=list(agent.states),
    node_color=node_colors,
    node_size=420,
    ax=ax,
)

# Draw edges with slight opposite curvature for near vs far
def draw_edges(edge_list, color, rad):
    """Draw each ``(u, v)`` edge as a curved arrow on ``ax``.

    Edges with an endpoint missing from ``pos`` are skipped. ``rad`` sets the
    curvature of the ``arc3`` connection style.
    """
    arrow_props = dict(arrowstyle='-|>', lw=1.6, color=color, alpha=0.9,
                       connectionstyle=f'arc3,rad={rad}')
    for src, dst in edge_list:
        if src in pos and dst in pos:
            # Hand every annotation its own copy of the properties.
            ax.annotate('', xy=pos[dst], xytext=pos[src],
                        arrowprops=dict(arrow_props))

# Opposite curvature keeps the near (R1) and far (R2) arrows from overlapping.
for branch_edges, branch_color, curvature in (
    (near_edges, '#d64b5a', 0.18),    # near (R1)
    (far_edges, '#2f6db3', -0.18),    # far  (R2)
):
    draw_edges(branch_edges, color=branch_color, rad=curvature)

# Label only the nodes of informative symbols (indicators, visuals, water, wall)
informative_obs = (2, 3, 4, 5, 6, 7)
for sid, st in agent.states.items():
    lab_obs = st.obs
    if lab_obs not in informative_obs:
        continue
    node_x, node_y = pos[sid]
    ax.text(node_x, node_y+0.12, f'obs={lab_obs}\n{st.path or "base"}',
            fontsize=8, ha='center', va='bottom')

ax.set_title('CoDA transition graph (OSM-style clean layout)')
ax.set_xlim(-1, 26)
ax.set_ylim(-1.5, 1.5)
ax.axis('off')
plt.show()


### Example near–far correlation matrices across sessions (sanity check)
We show the mean across runs at sessions 1, 3, 4, 9 (style similar to OSM matrices).

In [None]:

check_sessions = [1,3,4,9]

# Mean correlation matrix across runs for each inspected session
mean_mats = {}
for s in check_sessions:
    mean_mats[s] = np.mean(mat_by_session[s], axis=0)

n_panels = len(check_sessions)
fig, axes = plt.subplots(1, n_panels, figsize=(4.5*n_panels, 4), constrained_layout=True)
vmin, vmax = -0.1, 1.0
for ax, s in zip(axes, check_sessions):
    # Common colour limits so that the panels are directly comparable
    im = ax.imshow(mean_mats[s], origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
    ax.set_title(f"Session {s}")
    ax.set_xlabel("Far position index")
    ax.set_ylabel("Near position index")

# One shared colour bar, taken from the last panel's image
fig.colorbar(im, ax=axes, fraction=0.046, pad=0.04)
plt.show()


## Fig. 4i — Final correlation quantification (bars with mean ± s.e.m.)
This follows the OSM pipeline: compute block means from the final near–far matrix per run, then aggregate across runs.

In [None]:

labels = ['offdiag','preR2','preR1']

# Across-run mean and standard error of every block
means = blocks_df.mean()
ses = blocks_df.sem()
y = [means[name] for name in labels]
yerr = [ses[name] for name in labels]
x = np.arange(len(labels))

fig, ax = plt.subplots(figsize=(5,4))
ax.bar(x, y, yerr=yerr, capsize=4, color=['#777','#2f6db3','#d64b5a'])
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_ylabel("Mean correlation (final)")
ax.set_title("CoDA — Final correlation quantification (Fig. 4i analogue)")
ax.set_ylim(0,1.0)

# Print each bar's value just above it
for bar_x, bar_height in zip(x, y):
    ax.text(bar_x, bar_height+0.03, f"{bar_height:.2f}", ha='center', fontsize=9)
plt.show()


## Fig. 4j — Decorrelation order (time to reach correlation < 0.3)
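For each block we find the first session at which the mean correlation drops below 0.3, then plot that session number as a fraction of the total number of training sessions. Runs in which a block never crosses the threshold are recorded as NaN and excluded from the mean and s.e.m.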

In [None]:

labels = ['offdiag_t','preR2_t','preR1_t']
disp_labels = ['offdiag','preR2','preR1']

# Across-run mean and standard error; NaN entries are skipped
means_t = times_df.mean(skipna=True)
ses_t = times_df.sem(skipna=True)
y = [means_t[name] for name in labels]
yerr = [ses_t[name] for name in labels]
x = np.arange(len(labels))

fig, ax = plt.subplots(figsize=(5,4))
ax.bar(x, y, yerr=yerr, capsize=4, color=['#777','#2f6db3','#d64b5a'])
ax.set_xticks(x)
ax.set_xticklabels(disp_labels)
ax.set_ylim(0, 1.05)
ax.set_ylabel("Fraction of training (to corr < 0.3)")
ax.set_title("CoDA — Decorrelation order (Fig. 4j analogue)")

# Value labels, capped so that they stay inside the axes
for bar_x, bar_height in zip(x, y):
    ax.text(bar_x, min(1.02, bar_height+0.04), f"{bar_height:.2f}", ha='center', fontsize=9)
plt.show()
