# CoDA (uncertainty‑aware) in Sun et al.’s near/far task — Fig. 4c/4i/4j + graph progression

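This notebook applies an **uncertainty‑aware split rule** to CoDA. A state is marked salient, and its successors are split, only when a Beta‑posterior test (or, if `mpmath` is unavailable, a Wilson lower bound) shows that its prospective contingency P(US | state) exceeds a threshold. The notebook reproduces OSM‑style **Fig. 4c** (transition graph), **Fig. 4i** (final block quantification), and **Fig. 4j** (decorrelation order). It also shows a **progression** of transition graphs across sessions, illustrating how splits propagate.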

In [None]:

import math, random, collections
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch, Circle

# Optional posterior functions (fallback to Wilson if unavailable)
try:
    from mpmath import betainc, erfcinv
    _HAS_MPMATH = True
except Exception:
    _HAS_MPMATH = False

# Seed the stdlib and legacy NumPy global RNGs for reproducibility.
# NOTE(review): in the code visible here nothing draws from these global generators;
# the training loop uses its own np.random.default_rng(123) instead.
random.seed(0); np.random.seed(0)


In [None]:

def posterior_prob_p_greater_than(theta: float, success: float, failure: float, alpha0: float=0.5, beta0: float=0.5) -> float:
    """Return P(p > theta) under the Beta(alpha0 + success, beta0 + failure) posterior.

    Negative ``success``/``failure`` counts are clipped to zero before being added
    to the prior pseudo-counts ``alpha0``/``beta0``.

    When mpmath is not installed this returns exactly 0.0; callers treat that
    value as the signal to fall back to the Wilson lower-bound test.
    """
    if _HAS_MPMATH:
        a_post = alpha0 + max(0.0, float(success))
        b_post = beta0 + max(0.0, float(failure))
        # The regularized incomplete beta over [0, theta] is the posterior CDF at theta.
        return float(1.0 - betainc(a_post, b_post, 0, theta, regularized=True))
    return 0.0

def wilson_lower_bound(phat: float, n: float, confidence: float=0.95) -> float:
    """One-sided Wilson score lower bound for a binomial proportion.

    Args:
        phat: observed success proportion (expected in [0, 1]).
        n: number of trials (may be a fractional effective count); n <= 0 yields 0.0.
        confidence: one-sided confidence level in (0, 1); 0.95 gives z ~= 1.645.

    Returns:
        The lower end of the Wilson score interval at the requested confidence.

    Raises:
        statistics.StatisticsError: if ``confidence`` is not strictly inside (0, 1).
    """
    if n <= 0:
        return 0.0
    # One-sided normal quantile from the stdlib: honours ``confidence`` for every
    # level, instead of silently pinning z at the 95% value when mpmath is absent.
    from statistics import NormalDist
    z = NormalDist().inv_cdf(confidence)
    z2 = z * z
    denom = 1.0 + z2 / n
    center = phat + z2 / (2.0 * n)
    adj = z * math.sqrt((phat * (1.0 - phat) + z2 / (4.0 * n)) / n)
    return (center - adj) / denom


## Sun et al. near/far sequences and Fig. 4 block indices

In [None]:

# Observation sequences for the two trial types (26 steps each).
# NOTE(review): symbol meanings inferred from comments/labels elsewhere in this
# notebook - confirm against the paper: 1 = corridor, 2 / 3 = near / far indicator,
# 4 = R1 visual, 5 = R2 visual, 0 = trailing padding (the agent's default reset symbol).
near = [1,1,1,1,1,1, 2,2,2,2, 1,1,1, 4, 6, 1,1,1, 5,5, 1,1, 7, 0,0,0]
far  = [1,1,1,1,1,1, 3,3,3,3, 1,1,1, 4,4, 1,1,1, 5, 6, 1,1, 7, 0,0,0]
assert len(near)==len(far)==26

# Position indices of the corridor segments compared in Fig. 4i/4j.
preR1_idx = list(range(10,13))   # positions 10-12: between indicator and R1 visual
preR2_idx = list(range(15,18))   # positions 15-17: between (R1/water) and R2 visual

def block_indices(rows, cols):
    """Return every (row, col) pair of rows x cols, row-major (rows outer, cols inner)."""
    pairs = []
    for r_idx in rows:
        for c_idx in cols:
            pairs.append((r_idx, c_idx))
    return pairs
# (near_position, far_position) index pairs into the near x far correlation matrix:
#   offdiag     - pre-R1 positions against pre-R2 positions, in both orientations
#   same_preR1  - pre-R1 segment in both trial types
#   same_preR2  - pre-R2 segment in both trial types
offdiag_pairs     = block_indices(preR1_idx, preR2_idx) + block_indices(preR2_idx, preR1_idx)
same_preR1_pairs  = block_indices(preR1_idx, preR1_idx)
same_preR2_pairs  = block_indices(preR2_idx, preR2_idx)

# Sorted set of all observation symbols in either trial type (echoed as cell output).
symbols = sorted(set(near)|set(far))
symbols


## CoDA (uncertainty‑aware split rule)

In [None]:

@dataclass
class UncCfg:
    """Hyper-parameters for CoDAUncAgent.

    gamma, lam             -- their product is the per-step eligibility-trace decay.
    theta_split            -- threshold the dominant US contingency must exceed.
    confidence             -- required posterior probability / Wilson confidence level.
    alpha0, beta0          -- Beta prior pseudo-counts for the posterior test.
    n_threshold            -- minimum total US co-occurrence mass before testing.
    min_presence_episodes  -- minimum number of episodes in which the state was visited.
    min_effective_exposure -- minimum accumulated eligibility mass at US times.
    reset_symbols          -- observations after which the trial context is cleared.
    """
    # eligibility trace
    gamma: float = 0.9
    lam: float = 0.8
    # split test
    theta_split: float = 0.9
    confidence: float = 0.95
    alpha0: float = 0.5
    beta0: float = 0.5
    # evidence gates applied before the split test
    n_threshold: int = 5
    min_presence_episodes: int = 3
    min_effective_exposure: float = 10.0
    # context control
    reset_symbols: tuple = (0,)

class CoDAUncAgent:
    """Uncertainty-aware CoDA with contextual eligibility traces and uncertainty-gated splitting.

    Each observation symbol starts with one unsplit latent state. After every
    learning episode, states whose dominant US contingency passes the uncertainty
    test become *salient* and are tagged 'R1' (US 4) or 'R2' (US 5). The observation
    that immediately follows a newly salient state is cloned into a path-tagged
    copy (at most one clone per observation and path). While a salient state's path
    is the active context, observations that have a clone on that path map to it.
    """

    def __init__(self, obs_symbols, cfg: Optional[UncCfg] = None):
        """Create one unsplit latent state per observation symbol.

        Args:
            obs_symbols: iterable of observation symbols the agent may encounter.
            cfg: hyper-parameters; a fresh ``UncCfg()`` is built when omitted, so
                agents never share one mutable default config instance.
        """
        self.cfg = cfg if cfg is not None else UncCfg()
        self.reset_symbols = set(self.cfg.reset_symbols)
        obs_symbols = list(obs_symbols)  # iterated twice below; allow generators
        # Latent states: id -> {'obs': int, 'path': None/'R1'/'R2', 'parent': Optional[int]}
        self.states: Dict[int, Dict] = {}
        self.obs_to_state_ids: Dict[int, List[int]] = {o: [] for o in obs_symbols}
        sid = 0
        for o in obs_symbols:
            self.states[sid] = {'obs': o, 'path': None, 'parent': None}
            self.obs_to_state_ids[o].append(sid)
            sid += 1
        self._next_sid = sid

        # Contingency / exposure bookkeeping
        self.us_classes = [4, 5]  # visual R1, visual R2
        self.co_occ = {s: {u: 0.0 for u in self.us_classes} for s in self.states}  # per-state US eligibility mass
        self.exposure = {s: 0.0 for s in self.states}                               # total eligibility mass at US times
        self.presence_episodes = {s: 0 for s in self.states}                        # episodes containing the state
        self.salient = {}  # s -> 'R1'/'R2'

    def _ensure(self, sid):
        """Create zeroed bookkeeping entries for ``sid`` if it has none yet."""
        if sid not in self.co_occ:
            self.co_occ[sid] = {u: 0.0 for u in self.us_classes}
            self.exposure[sid] = 0.0
            self.presence_episodes[sid] = 0

    def _clone_state(self, orig_state_id: int, path: str) -> int:
        """Add a new latent state for the same observation, tagged with ``path``; return its id."""
        orig = self.states[orig_state_id]
        cid = self._next_sid
        self.states[cid] = {'obs': orig['obs'], 'path': path, 'parent': orig_state_id}
        self.obs_to_state_ids[orig['obs']].append(cid)
        self._ensure(cid)
        self._next_sid += 1
        return cid

    def _select_state_for_obs(self, obs: int, context: Optional[str]) -> int:
        """Pick the latent state for ``obs``: a clone on ``context`` if any, else the
        unsplit state, else the first registered candidate."""
        cands = self.obs_to_state_ids[obs]
        if context is not None:
            for sid in cands:
                if self.states[sid]['path'] == context:
                    return sid
        for sid in cands:
            if self.states[sid]['path'] is None:
                return sid
        return cands[0]

    def _next_context(self, sid: int, obs: int, context: Optional[str]) -> Optional[str]:
        """Context after visiting latent state ``sid`` for observation ``obs``.

        A salient state switches the context to its path; a reset symbol then
        clears it. Used by both the forward pass and the clone-propagation replay
        so the two follow identical context dynamics.
        """
        if sid in self.salient:
            context = self.salient[sid]
        if obs in self.reset_symbols:
            context = None
        return context

    def _trace_latents(self, obs_seq: List[int]) -> List[int]:
        """Map each observation to a latent state, threading the trial context."""
        context = None
        latent_seq = []
        for obs in obs_seq:
            sid = self._select_state_for_obs(obs, context)
            latent_seq.append(sid)
            context = self._next_context(sid, obs, context)
        return latent_seq

    def run_episode(self, obs_seq: List[int], learn=True):
        """Encode one episode into latent states and, if ``learn``, update the model.

        Learning (in order): count episode presence, accumulate eligibility at
        each US onset, mark states that pass the uncertainty-gated split test as
        salient, then clone the successors of newly salient states.

        Returns:
            The latent-state id chosen at each time step (computed before learning).
        """
        latent_seq = self._trace_latents(obs_seq)
        if not learn:
            return latent_seq

        # presence gate: each distinct state counts once per episode
        for sid in set(latent_seq):
            self.presence_episodes[sid] = self.presence_episodes.get(sid, 0) + 1

        self._accumulate_eligibility(obs_seq, latent_seq)
        newly_salient = self._detect_new_salient()
        self._propagate_clones(obs_seq, newly_salient)
        return latent_seq

    def _accumulate_eligibility(self, obs_seq: List[int], latent_seq: List[int]) -> None:
        """Credit each latent state's eligibility at every US onset to that US class.

        One accumulating trace (decayed by gamma*lam each step, +1 for the visited
        state) runs over the episode. Its value at each US position is added to
        ``co_occ[s][u]`` and ``exposure[s]``, US class by US class in time order.
        """
        decay = self.cfg.gamma * self.cfg.lam
        trace = np.zeros(self._next_sid, dtype=float)
        snapshots = {}
        for t, (obs, sid) in enumerate(zip(obs_seq, latent_seq)):
            trace *= decay
            trace[sid] += 1.0
            if obs in self.us_classes:
                snapshots[t] = trace.copy()

        for u in self.us_classes:
            for t_us, obs in enumerate(obs_seq):
                if obs != u:
                    continue
                e = snapshots[t_us]
                for s_id in np.flatnonzero(e > 0):
                    s_id = int(s_id)
                    if s_id not in self.states:
                        continue
                    self._ensure(s_id)
                    self.co_occ[s_id][u] += e[s_id]
                    self.exposure[s_id] += e[s_id]

    def _detect_new_salient(self) -> List[Tuple[int, str]]:
        """Run the uncertainty-gated split test on every non-salient state.

        A state is tested only after it passes the evidence gates (co-occurrence
        mass, presence episodes, exposure). Passing states are recorded in
        ``self.salient`` and returned as (state_id, path) pairs.
        """
        cfg = self.cfg
        newly_salient = []
        for s in list(self.states):
            if s in self.salient:
                continue
            tot = sum(self.co_occ[s][u] for u in self.us_classes)
            if tot < cfg.n_threshold:
                continue
            if self.presence_episodes.get(s, 0) < cfg.min_presence_episodes:
                continue
            if self.exposure.get(s, 0.0) < cfg.min_effective_exposure:
                continue
            probs = {u: (self.co_occ[s][u] / tot if tot > 0 else 0.0) for u in self.us_classes}
            u_star = max(self.us_classes, key=probs.__getitem__)
            success = self.co_occ[s][u_star]
            failure = tot - success
            phat = success / (tot if tot > 0 else 1.0)
            post = posterior_prob_p_greater_than(cfg.theta_split, success, failure,
                                                 cfg.alpha0, cfg.beta0)
            if post == 0.0:
                # 0.0 doubles as the "mpmath unavailable" sentinel: use the Wilson bound.
                lb = wilson_lower_bound(phat, tot, confidence=cfg.confidence)
                passed = lb > cfg.theta_split
            else:
                passed = post >= cfg.confidence
            if passed:
                path = 'R1' if u_star == 4 else 'R2'
                self.salient[s] = path
                newly_salient.append((s, path))
        return newly_salient

    def _propagate_clones(self, obs_seq: List[int], newly_salient: List[Tuple[int, str]]) -> None:
        """Clone the observation that follows each occurrence of a newly salient state.

        The episode is replayed with the current salience (clones created during
        the replay are visible to later steps). The successor gets at most one
        clone per path.
        """
        for s, path in newly_salient:
            context = None
            for t, obs in enumerate(obs_seq[:-1]):
                sid = self._select_state_for_obs(obs, context)
                if sid == s:
                    cands = self.obs_to_state_ids[obs_seq[t + 1]]
                    if not any(self.states[c]['path'] == path for c in cands):
                        self._clone_state(cands[0], path)
                context = self._next_context(sid, obs, context)

    # analysis helpers
    def encode_sequence(self, obs_seq: List[int]) -> np.ndarray:
        """One-hot encode the latent path for ``obs_seq`` without learning.

        Returns an array of shape (len(obs_seq), number_of_latent_states).
        """
        lat = self.run_episode(obs_seq, learn=False)
        X = np.zeros((len(lat), self._next_sid), dtype=float)
        for t, sid in enumerate(lat):
            X[t, sid] = 1.0
        return X

    def near_far_corr(self, near_seq, far_seq) -> np.ndarray:
        """Pearson correlation between every near-position and far-position latent code.

        Entry [i, j] correlates the one-hot latent vector at near position i with
        the one at far position j. All-zero codes and zero-variance pairs give 0.
        """
        A = self.encode_sequence(near_seq)
        B = self.encode_sequence(far_seq)
        A0 = A - A.mean(axis=1, keepdims=True)
        B0 = B - B.mean(axis=1, keepdims=True)
        num = A0 @ B0.T
        denom = np.outer(np.linalg.norm(A0, axis=1), np.linalg.norm(B0, axis=1))
        C = np.divide(num, denom, out=np.zeros_like(num), where=denom > 0)
        # Codes that are (numerically) all zero are defined as uncorrelated.
        C[np.all(np.isclose(A, 0), axis=1), :] = 0.0
        C[:, np.all(np.isclose(B, 0), axis=1)] = 0.0
        return C


## Train across sessions; cache matrices (for Fig. 4i/4j) and checkpoints (for progression)

In [None]:

N_RUNS   = 8
SESSIONS = 9
TRIALS_PER_SESSION = 80
THRESH   = 0.3  # a block counts as decorrelated once its mean correlation drops below this

# Sessions whose near/far matrices are cached (all runs) and whose agent is
# snapshotted (run 0 only). One list feeds both caches so they cannot drift apart.
checkpoint_sessions = [1,3,4,9]

def block_mean(C, pairs):
    """Mean of C over the (row, col) index pairs; NaN when ``pairs`` is empty."""
    return float(np.mean([C[i,j] for (i,j) in pairs])) if pairs else np.nan

def norm(x):
    """Map a session number to a fraction of training; None (never reached) -> NaN."""
    return x/SESSIONS if x is not None else np.nan

def snapshot_agent(agent):
    """Copy an agent's learned state into a fresh CoDAUncAgent.

    Containers are copied one level deep; their contents are ints, floats,
    strings or None, so further training of ``agent`` does not leak into the copy.
    """
    snap = CoDAUncAgent(obs_symbols=sorted(set(near)|set(far)), cfg=agent.cfg)
    snap.states = {k: v.copy() for k,v in agent.states.items()}
    snap._next_sid = agent._next_sid
    snap.obs_to_state_ids = {k: list(v) for k,v in agent.obs_to_state_ids.items()}
    snap.co_occ = {k: dict(v) for k,v in agent.co_occ.items()}
    snap.exposure = dict(agent.exposure)
    snap.presence_episodes = dict(agent.presence_episodes)
    snap.salient = dict(agent.salient)
    return snap

rng = np.random.default_rng(123)

all_final_blocks = []
all_time_to_thr  = []
mat_by_session   = {s: [] for s in checkpoint_sessions}
demo_checkpoints = {}

for run in range(N_RUNS):
    agent = CoDAUncAgent(obs_symbols=sorted(set(near)|set(far)), cfg=UncCfg())
    tt = {'offdiag': None, 'preR2': None, 'preR1': None}  # first session below THRESH
    for session in range(1, SESSIONS+1):
        episodes = [near]*(TRIALS_PER_SESSION//2) + [far]*(TRIALS_PER_SESSION//2)
        rng.shuffle(episodes)
        for ep in episodes:
            agent.run_episode(ep, learn=True)
        C = agent.near_far_corr(near, far)
        if session in mat_by_session:
            mat_by_session[session].append(C)
        if run==0 and session in checkpoint_sessions:
            demo_checkpoints[session] = snapshot_agent(agent)

        b_off = block_mean(C, offdiag_pairs)
        b_r2  = block_mean(C, same_preR2_pairs)
        b_r1  = block_mean(C, same_preR1_pairs)
        if tt['offdiag'] is None and b_off < THRESH: tt['offdiag'] = session
        if tt['preR2']  is None and b_r2  < THRESH: tt['preR2']  = session
        if tt['preR1']  is None and b_r1  < THRESH: tt['preR1']  = session

    # final bars (read-out only; no learning happens here)
    C_final = agent.near_far_corr(near, far)
    all_final_blocks.append((run, block_mean(C_final, offdiag_pairs),
                                  block_mean(C_final, same_preR2_pairs),
                                  block_mean(C_final, same_preR1_pairs)))
    all_time_to_thr.append((run, norm(tt['offdiag']), norm(tt['preR2']), norm(tt['preR1'])))

blocks_df = pd.DataFrame(all_final_blocks, columns=['run','offdiag','preR2','preR1']).set_index('run')
times_df  = pd.DataFrame(all_time_to_thr,  columns=['run','offdiag_t','preR2_t','preR1_t']).set_index('run')

display(blocks_df.describe())
display(times_df.describe())


## Fig. 4c — ring layout (clean) and **progression** of graphs

In [None]:

def canonical_latents(agent, seq):
    """Latent path of ``seq`` up to, but not including, the first 0 observation.

    Returns a list of (t, obs, latent_id, path) tuples, where ``path`` is the
    latent state's 'R1'/'R2' tag or None.
    """
    steps = zip(seq, agent.run_episode(seq, learn=False))
    rows = []
    for t, (obs, sid) in enumerate(steps):
        if obs == 0:
            return rows
        rows.append((t, obs, sid, agent.states[sid]['path']))
    return rows

def ring_positions(T: int, r_base=1.0, r_near=1.12, r_far=0.88, start_angle=np.pi/2):
    """Place T time steps clockwise around three concentric rings.

    Returns ``{t: {'base': (x, y), 'R1': (x, y), 'R2': (x, y)}}`` with 'base' on
    radius ``r_base``, 'R1' on ``r_near`` and 'R2' on ``r_far``. Step 0 sits at
    ``start_angle``.
    """
    radii = {'base': r_base, 'R1': r_near, 'R2': r_far}
    n_slots = max(T, 1)

    def _polar(radius, angle):
        return (radius*np.cos(angle), radius*np.sin(angle))

    layout = {}
    for step in range(T):
        angle = start_angle - 2*np.pi*(step / n_slots)
        layout[step] = {ring: _polar(rad, angle) for ring, rad in radii.items()}
    return layout

def draw_ring_graph(ax, agent, title=None, alpha_grey=0.35, key_only=True):
    """Draw the near/far latent paths of ``agent`` on concentric rings (Fig. 4c style).

    Unsplit states sit on the middle ring, 'R1' clones on the outer ring and 'R2'
    clones on the inner ring. Near-trial transitions are drawn in red, far-trial
    ones in blue. Corridor observations (1) are faded to ``alpha_grey``. With
    ``key_only`` set, observations 2-7 are labelled (2 -> "2N", 3 -> "3F").
    """
    near_lat = canonical_latents(agent, near)
    far_lat  = canonical_latents(agent, far)
    T = max(len(near_lat), len(far_lat))
    r_base, r_near, r_far = 1.0, 1.12, 0.88
    pos = ring_positions(T, r_base=r_base, r_near=r_near, r_far=r_far)

    # Palettes
    col_node = {None:'#cfcfcf','R1':'#ff9bb0','R2':'#8fbff5'}
    col_edge_near = '#d64b5a'
    col_edge_far  = '#2f6db3'
    special_labels = {2: "2N", 3: "3F"}

    def ring_of(path):
        return 'base' if path is None else path

    # Nodes on three rings; grey corridor nodes at low alpha
    def draw_nodes(lat_seq):
        for (t, obs, sid, path) in lat_seq:
            x, y = pos[t][ring_of(path)]
            alpha = alpha_grey if obs==1 else 0.95
            circ = Circle((x,y), radius=0.045, facecolor=col_node[path], edgecolor='white', lw=0.7, alpha=alpha)
            ax.add_patch(circ)
            if key_only and obs in (2,3,4,5,6,7):
                lbl = special_labels.get(obs, str(obs))
                ax.text(x, y+0.08, lbl, ha='center', va='bottom', fontsize=7, color='k', alpha=0.9)

    draw_nodes(near_lat)
    draw_nodes(far_lat)

    # Edges as short arrows between consecutive time indices
    def draw_edges(lat_seq, edge_color):
        for (t, _, _, path), (t2, _, _, path2) in zip(lat_seq, lat_seq[1:]):
            x1,y1 = pos[t][ring_of(path)];   x2,y2 = pos[t2][ring_of(path2)]
            arr = FancyArrowPatch((x1,y1),(x2,y2), arrowstyle='-|>', mutation_scale=6,
                                  lw=1.4, color=edge_color, alpha=0.9, shrinkA=2, shrinkB=2)
            ax.add_patch(arr)

    draw_edges(near_lat, col_edge_near)
    draw_edges(far_lat,  col_edge_far)

    # add_patch() does not autoscale the view, so without explicit limits the axes
    # keep their default [0, 1] range and clip most of the ring.
    lim = max(r_base, r_near, r_far) + 0.25  # headroom for labels above outer nodes
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_aspect('equal')
    ax.axis('off')
    if title: ax.set_title(title)

# Single snapshot: the latest checkpoint recorded during training
final_session = max(demo_checkpoints)
fig, ax = plt.subplots(figsize=(8,3.2), constrained_layout=True)
draw_ring_graph(ax, demo_checkpoints[final_session], title="CoDA (uncertainty‑aware) — final transition graph (Fig. 4c style)")
plt.show()

# Progression panels; squeeze=False keeps `axes` iterable even with one checkpoint
progression_sessions = sorted(demo_checkpoints)
fig, axes = plt.subplots(1, len(progression_sessions), figsize=(4.5*len(progression_sessions),3.2),
                         constrained_layout=True, squeeze=False)
for ax, s in zip(axes.ravel(), progression_sessions):
    draw_ring_graph(ax, demo_checkpoints[s], title=f"Session {s}")
plt.show()


### Near–far correlation matrices (mean across runs)

In [None]:

# Only sessions that actually have cached matrices (checkpoints beyond SESSIONS stay empty).
check_sessions = [s for s in sorted(mat_by_session) if mat_by_session[s]]
mean_mats = {s: np.mean(mat_by_session[s], axis=0) for s in check_sessions}

fig, axes = plt.subplots(1, len(check_sessions), figsize=(4.5*len(check_sessions),4),
                         constrained_layout=True, squeeze=False)
vmin, vmax = -0.1, 1.0
for ax, s in zip(axes.ravel(), check_sessions):
    # rows = near positions, columns = far positions (see near_far_corr)
    im = ax.imshow(mean_mats[s], vmin=vmin, vmax=vmax, origin='lower', aspect='auto')
    ax.set_title(f"Session {s}")
    ax.set_xlabel("Far position index")
    ax.set_ylabel("Near position index")
fig.colorbar(im, ax=axes.ravel().tolist(), fraction=0.046, pad=0.04)
plt.show()


## Fig. 4i and Fig. 4j — OSM‑style bars

In [None]:

def _plot_bar_panel(means, sems, keys, tick_labels, ylabel, title, ytop, label_offset, label_cap=None):
    """Draw one OSM-style bar panel (mean +/- SEM) with numeric value labels.

    NaN means (e.g. a threshold never crossed in any run) are left unlabelled.
    The lower y-limit drops below 0 when a bar is negative, so decorrelated
    (negative) correlations stay visible.
    """
    x = np.arange(len(keys))
    y = np.array([means[k] for k in keys], dtype=float)
    yerr = np.array([sems[k] for k in keys], dtype=float)
    fig, ax = plt.subplots(figsize=(5,4))
    ax.bar(x, y, yerr=yerr, capsize=4, color=['#777','#2f6db3','#d64b5a'])
    ax.set_xticks(x); ax.set_xticklabels(tick_labels)
    finite = y[np.isfinite(y)]
    ybottom = min(0.0, float(finite.min()) - 0.05) if finite.size else 0.0
    ax.set_ylim(ybottom, ytop)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    for i, val in enumerate(y):
        if not np.isfinite(val):
            continue
        ypos = val + label_offset
        if label_cap is not None:
            ypos = min(label_cap, ypos)
        ax.text(i, ypos, f"{val:.2f}", ha='center', fontsize=9)
    plt.show()

# Fig. 4i
means = blocks_df.mean(); ses = blocks_df.sem()
labels = ['offdiag','preR2','preR1']
_plot_bar_panel(means, ses, labels, labels,
                ylabel="Mean correlation (final)",
                title="CoDA (uncertainty‑aware) — Fig. 4i analogue",
                ytop=1.0, label_offset=0.03)

# Fig. 4j
means_t = times_df.mean(skipna=True); ses_t = times_df.sem(skipna=True)
labels_t = ['offdiag_t','preR2_t','preR1_t']; disp = ['offdiag','preR2','preR1']
_plot_bar_panel(means_t, ses_t, labels_t, disp,
                ylabel=f"Fraction of training (first corr < {THRESH})",
                title="CoDA (uncertainty‑aware) — Fig. 4j analogue",
                ytop=1.05, label_offset=0.04, label_cap=1.02)
