# CoDA (uncertainty‑aware) — OSM Fig. 4c/4i/4j with **CSCG‑style color‑coded states**

**New:** CSCG‑style node colors (uniform gray arrows), three coloring modes (`blocks`, `state_id`, `obs_step`).

In [None]:

import os
import math, random
from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch, Circle


In [None]:

def posterior_prob_p_greater_than(theta: float, success: float, failure: float, alpha0: float=0.5, beta0: float=0.5) -> float:
    """Return the posterior probability that a Bernoulli rate exceeds ``theta``.

    The rate gets a Beta(``alpha0``, ``beta0``) prior, so after observing
    ``success`` / ``failure`` (possibly fractional) counts the posterior is
    Beta(alpha0 + success, beta0 + failure).  Negative counts are clamped to 0.

    Args:
        theta: Threshold on the rate; values outside [0, 1] are clamped.
        success: Observed (effective) number of successes.
        failure: Observed (effective) number of failures.
        alpha0: Prior pseudo-count for successes.
        beta0: Prior pseudo-count for failures.

    Returns:
        P(rate > theta) under the posterior, or 0.0 when no incomplete-beta
        backend can be imported (the conservative fallback: "no evidence").
    """
    # The incomplete beta function is imported here so the function does not
    # depend on module-level names (`_HAS_MPMATH`, `betainc`) that no notebook
    # cell defines.
    try:
        from scipy.special import betainc
    except ImportError:
        return 0.0
    a = alpha0 + max(0.0, float(success))
    b = beta0 + max(0.0, float(failure))
    # A probability threshold outside [0, 1] is equivalent to the nearest bound.
    x = min(1.0, max(0.0, float(theta)))
    # scipy's betainc is already regularized, i.e. it is the Beta CDF at x.
    cdf = betainc(a, b, x)
    return float(1.0 - cdf)

def wilson_lower_bound(phat: float, n: float, confidence: float=0.95) -> float:
    """Return the one-sided Wilson score lower bound for a proportion.

    Args:
        phat: Observed proportion of successes.
        n: (Effective) sample size; non-positive values yield 0.0.
        confidence: One-sided confidence level, strictly between 0 and 1.

    Returns:
        The lower confidence bound on the true proportion.

    Raises:
        statistics.StatisticsError: If ``confidence`` is not in (0, 1).
    """
    if n <= 0:
        return 0.0
    # The exact normal quantile comes from the standard library, so the bound
    # honours `confidence` in every case and needs no optional backend.
    from statistics import NormalDist
    z = NormalDist().inv_cdf(confidence)
    denom = 1.0 + (z*z)/n
    center = phat + (z*z)/(2.0*n)
    adj = z * ((phat*(1.0-phat) + (z*z)/(4.0*n))/n)**0.5
    return (center - adj)/denom


In [None]:

# The two 26-step observation streams, written run by run.  They differ only at
# steps 6-9 (symbol 2 vs 3), step 14 (6 vs 4) and step 19 (5 vs 6); both end
# with three 0 symbols.
near = [1] * 6 + [2] * 4 + [1] * 3 + [4, 6] + [1] * 3 + [5, 5] + [1] * 2 + [7] + [0] * 3
far = [1] * 6 + [3] * 4 + [1] * 3 + [4, 4] + [1] * 3 + [5, 6] + [1] * 2 + [7] + [0] * 3

# Step indices of the two three-step windows of symbol 1 that sit just before
# steps 13-14 and steps 18-19 respectively.
preR1_idx = [10, 11, 12]
preR2_idx = [15, 16, 17]


def block_indices(rows, cols):
    """Return every (row, col) pair of ``rows`` x ``cols`` in row-major order."""
    pairs = []
    for r in rows:
        for c in cols:
            pairs.append((r, c))
    return pairs


# Index pairs selecting blocks of a (near step) x (far step) matrix.
offdiag_pairs = [
    *block_indices(preR1_idx, preR2_idx),
    *block_indices(preR2_idx, preR1_idx),
]
same_preR1_pairs = block_indices(preR1_idx, preR1_idx)
same_preR2_pairs = block_indices(preR2_idx, preR2_idx)


In [None]:
# --- CoDA agent (trial-by-trial; paper-faithful) ---
# This notebook-local agent is designed to follow the *key concepts* in the CoDA paper:
#   1) Outcome-conditioned eligibility traces (Alg. 1) accumulate evidence for cues.
#   2) Prospective contingency (PC) triggers *state-space augmentation* via successor cloning:
#        when a cue becomes salient, we clone its outgoing successor states and rewire edges.
#   3) Utility = (prospective × retrospective) can trigger merging (Alg. 2).
# And it also includes the *practical stabilizers* used in coda_trial_by_trial_util.py:
#   - confidence-gated splitting (Wilson lower bound)
#   - evidence / exposure gates
#   - optional decay (counts / traces / retrospective)
#   - edge-mass-based merging safeguard
#
# NOTE: This agent is action-free because the notebook is a 1D symbol stream.
# We therefore learn a deterministic "cognitive graph" over latent states by keying edges
# on the next *observation symbol* (parent_sid, next_obs) -> child_sid, with counts for evidence.

from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
import math
import numpy as np

# -----------------------------
# Config (mirrors key knobs in paper + python implementation)
# -----------------------------

@dataclass
class UncCfg:
    """Hyper-parameters read by CoDAUncAgent (traces, split/merge gates, decay)."""

    # Eligibility traces (Alg. 1): the per-step trace decay is gamma * lam.
    gamma: float = 0.9
    lam: float = 0.8

    # Split / merge thresholds (Alg. 2)
    theta_split: float = 0.7              # split when the Wilson lower bound exceeds this
    theta_merge: float = 0.5              # merge when utility drops below this; <= 0 disables merging

    # Evidence gates (all three must pass before a state may split)
    n_threshold: float = 2.0              # minimum effective "n" for contingency estimate
    min_presence_episodes: int = 2        # must appear in >= this many episodes
    min_effective_exposure: float = 2.0   # must accumulate >= this much eligibility exposure

    # Confidence level handed to the Wilson lower bound used by the split gate.
    # NOTE: the z lookup in `_z_one_sided` snaps to the nearest tabulated level
    # (0.90 / 0.95 / 0.975 / 0.99), so 0.75 currently uses the 0.90 z-value.
    split_confidence: float = 0.75

    # Optional forgetting / decay (stabilizers, as in python file); each factor
    # is applied once per learning episode and 1.0 means "no decay".
    count_decay: float = 1.0              # transition-count decay per episode (1.0 = no decay)
    trace_decay: float = 1.0              # eligibility accumulator decay per episode
    retro_decay: float = 1.0              # retrospective EMA decay per episode

    # Merge safeguard: if cue's outgoing mass to cue-created clones is tiny, merge
    edge_eps: float = 1e-6

    # Semantics for this notebook
    # NOTE(review): reset_symbols is not read by the agent code in this
    # notebook section - confirm whether another cell uses it.
    reset_symbols: Tuple[int, ...] = (0,)  # symbols that reset "episode context"
    # The first US class is labelled "R1" when a cue splits, any other "R2".
    us_classes: Tuple[int, int] = (4, 5)   # two US event symbols used in the notebook


# -----------------------------
# Helpers: confidence gating (Wilson LB)
# -----------------------------

def _z_one_sided(confidence: float) -> float:
    # common one-sided z's (enough for notebook usage)
    table = {0.90: 1.2815515655446004,
             0.95: 1.6448536269514722,
             0.975: 1.959963984540054,
             0.99: 2.3263478740408408}
    key = min(table.keys(), key=lambda k: abs(k - confidence))
    return table[key]

def wilson_lower_bound(phat: float, n: float, confidence: float = 0.95) -> float:
    """Compute the one-sided Wilson score lower bound of a proportion.

    Args:
        phat: Observed success proportion.
        n: (Effective) number of trials; a non-positive ``n`` returns 0.0.
        confidence: One-sided confidence level, resolved to a z-value by
            ``_z_one_sided``.

    Returns:
        The lower bound on the underlying success probability.
    """
    if n <= 0:
        return 0.0
    z = _z_one_sided(confidence)
    z_sq = z * z
    shrink = 1.0 + z_sq / n
    midpoint = phat + z_sq / (2.0 * n)
    spread = (phat * (1.0 - phat) + z_sq / (4.0 * n)) / n
    margin = z * math.sqrt(spread)
    return (midpoint - margin) / shrink


# -----------------------------
# CoDA agent
# -----------------------------

class CoDAUncAgent:
    """
    Paper-faithful CoDA for a symbol stream.

    Latent states live in a "cognitive graph" G:
      edges[parent_sid][next_obs] = child_sid

    Each latent node stores:
      - obs: the aliased observation symbol
      - path: None for original; "R1"/"R2" for clones (label derived from which US it predicts)
      - original: original state id for this observation (if clone)
      - root_cue: which cue caused this clone lineage (for merging descendants)
      - alive: whether this node is active (merged nodes are kept but deactivated)

    Learning:
      - Maintain outcome-conditioned eligibility accumulators:
          E[sid][u]  (eligibility mass credited at times US=u occurs)
          exposure[sid]  (total eligibility exposure mass)
          C[sid]         (effective count; here = sum_u E[sid][u])
      - Retrospective:
          us_episode_ema[u] and cs_us_presence_ema[sid][u]
      - Split:
          if max_u (E[sid][u]/C[sid]) has Wilson-LB > theta_split
          and evidence gates pass -> mark sid salient with label ("R1"/"R2"),
          then clone *all outgoing successors* of sid and rewire edges sid->clone(succ).
      - Merge:
          For a salient cue sid:
             PC = max_u E[sid][u] / sum_u E[sid][u]      (current statistics)
             RC = cs_us_presence_ema[sid][u*] / us_episode_ema[u*]
                  with u* the US recorded when sid became salient
             utility = PC * RC
          If utility < theta_merge OR edge_mass_to_clones < edge_eps -> merge descendants.

    This reproduces the *core components* of the paper while staying compatible with the notebook
    (encode_sequence + near_far_corr + states[sid]['path'] for plotting).
    """

    def __init__(self, obs_symbols: List[int], cfg: Optional["UncCfg"] = None):
        """Create one original latent state per observation symbol.

        Args:
            obs_symbols: Observation symbols; state ids are assigned in the
                order given, starting at 0.
            cfg: Hyper-parameters.  ``None`` builds a fresh ``UncCfg()`` so
                agents never share one default config object.
        """
        # A default evaluated in the signature would be a single instance
        # shared (and mutable) across every agent created without a config.
        self.cfg = UncCfg() if cfg is None else cfg
        self.us_classes = list(self.cfg.us_classes)

        # --- latent node store ---
        self.states: Dict[int, Dict] = {}
        self.obs_to_original: Dict[int, int] = {}
        self.obs_to_state_ids: Dict[int, List[int]] = {o: [] for o in obs_symbols}

        # create originals
        sid = 0
        for o in obs_symbols:
            self.states[sid] = dict(obs=o, path=None, original=sid, root_cue=None, alive=True)
            self.obs_to_original[o] = sid
            self.obs_to_state_ids[o].append(sid)
            sid += 1
        self._next_sid = sid

        # --- deterministic graph edges + counts ---
        self.edges: Dict[int, Dict[int, int]] = {s: {} for s in self.states}         # parent -> next_obs -> child
        self.edge_counts: Dict[int, Dict[int, float]] = {s: {} for s in self.states} # parent -> next_obs -> count

        # --- outcome-conditioned eligibility (Alg 1) ---
        self.E: Dict[int, Dict[int, float]] = {s: {u: 0.0 for u in self.us_classes} for s in self.states}
        # backward-compat: notebook checkpoint code expects `co_occ`
        self.co_occ = self.E
        self.exposure: Dict[int, float] = {s: 0.0 for s in self.states}  # total eligibility exposure
        self.C: Dict[int, float] = {s: 0.0 for s in self.states}         # effective evidence count proxy

        # --- retrospective EMA (for utility / merge) ---
        self.us_episode_ema: Dict[int, float] = {u: 0.0 for u in self.us_classes}
        self.cs_us_presence_ema: Dict[int, Dict[int, float]] = {s: {u: 0.0 for u in self.us_classes} for s in self.states}

        # evidence gate: how many episodes did this state appear in?
        self.presence_episodes: Dict[int, int] = {s: 0 for s in self.states}

        # salient cues: sid -> ("R1"/"R2", winning_us)
        self.salient: Dict[int, Tuple[str, int]] = {}

        # bookkeeping: cue -> set of clone ids it created (direct) and that are still active
        self.cue_to_clones: Dict[int, Set[int]] = {}

    # ---------------------
    # utilities
    # ---------------------

    def _ensure_state_structs(self, sid: int) -> None:
        """Create zeroed statistics and empty edge tables for ``sid`` if missing."""
        if sid not in self.E:
            self.E[sid] = {u: 0.0 for u in self.us_classes}
            self.exposure[sid] = 0.0
            self.C[sid] = 0.0
            self.cs_us_presence_ema[sid] = {u: 0.0 for u in self.us_classes}
            self.presence_episodes[sid] = 0
            self.edges.setdefault(sid, {})
            self.edge_counts.setdefault(sid, {})

    def _alloc_sid(self) -> int:
        """Return the next unused state id and advance the counter."""
        sid = self._next_sid
        self._next_sid += 1
        return sid

    def _clone_node(self, base_sid: int, *, root_cue: int, path_label: str) -> int:
        """Clone a successor state as in Alg. 2: same obs, copy outgoing edges.

        The clone starts with zeroed learning statistics; only the outgoing
        edges and their counts are copied from ``base_sid``.  Returns the new id.
        """
        base = self.states[base_sid]
        new_sid = self._alloc_sid()
        obs = base['obs']
        self.states[new_sid] = dict(
            obs=obs,
            path=path_label,
            original=base['original'],
            root_cue=root_cue,
            alive=True,
        )
        self._ensure_state_structs(new_sid)

        # register by observation
        self.obs_to_state_ids.setdefault(obs, []).append(new_sid)

        # copy outgoing edges + counts (shallow copies: the clone owns its tables)
        self.edges[new_sid] = dict(self.edges.get(base_sid, {}))
        self.edge_counts[new_sid] = dict(self.edge_counts.get(base_sid, {}))

        return new_sid

    def _is_descendant_of(self, sid: int, cue: int) -> bool:
        """Return True if ``sid`` is a known state whose ``root_cue`` is ``cue``."""
        st = self.states.get(sid, None)
        return (st is not None) and (st.get('root_cue', None) == cue)

    # ---------------------
    # Prospective / retrospective
    # ---------------------

    def prospective_dist(self, sid: int) -> Dict[int, float]:
        """P(US=u | CS=sid) over u in us_classes (normalized E); all zeros without evidence."""
        self._ensure_state_structs(sid)
        tot = sum(self.E[sid][u] for u in self.us_classes)
        if tot <= 0:
            return {u: 0.0 for u in self.us_classes}
        return {u: self.E[sid][u] / tot for u in self.us_classes}

    def _winning_us_and_pc(self, sid: int) -> Tuple[int, float, float]:
        """
        Return (u_star, phat, n_eff) where phat = max_u E[sid][u]/sum_u E[sid][u]
        and n_eff = sum_u E[sid][u] (effective sample size).

        Without evidence the first US class is returned with phat = n_eff = 0.
        """
        self._ensure_state_structs(sid)
        n_eff = sum(self.E[sid][u] for u in self.us_classes)
        if n_eff <= 0:
            return self.us_classes[0], 0.0, 0.0
        u_star = max(self.us_classes, key=lambda u: self.E[sid][u])
        phat = self.E[sid][u_star] / n_eff
        return u_star, phat, n_eff

    def retrospective(self, sid: int, u: int) -> float:
        """P(CS=sid present | US=u present), EMA-based; 0.0 if US ``u`` was never seen."""
        self._ensure_state_structs(sid)
        denom = self.us_episode_ema.get(u, 0.0)
        if denom <= 0:
            return 0.0
        return self.cs_us_presence_ema[sid].get(u, 0.0) / denom

    def utility(self, sid: int) -> float:
        """utility = PC * RC for the cue.

        PC is always recomputed from the current statistics.  For a salient
        cue, RC is taken for the US recorded when the cue became salient (as
        long as the cue still has evidence); otherwise the currently winning
        US is used.  This is the quantity the merge rule thresholds.
        """
        u_now, phat, _ = self._winning_us_and_pc(sid)
        u_use = u_now
        # keep u_star from salience unless its evidence vanished
        if sid in self.salient and self.C.get(sid, 0.0) > 0:
            _, u_use = self.salient[sid]
        return phat * self.retrospective(sid, u_use)

    # ---------------------
    # Graph rollout / episode update
    # ---------------------

    def rollout_latent(self, obs_seq: List[int], create_missing: bool = True) -> List[int]:
        """Map observation sequence to latent states by following edges.

        The walk starts at the original state of the first observation.  When
        the current state has no edge for the next observation, the walk falls
        back to that observation's original state and, if ``create_missing``
        is true, records the edge.

        Raises:
            KeyError: If an observation was not in ``obs_symbols`` at construction.
        """
        if not obs_seq:
            return []
        # start at original for first obs
        cur = self.obs_to_original[obs_seq[0]]
        latent = [cur]

        for nxt_obs in obs_seq[1:]:
            out = self.edges.get(cur, {})
            if nxt_obs in out:
                nxt = out[nxt_obs]
            else:
                nxt = self.obs_to_original[nxt_obs]
                if create_missing:
                    out[nxt_obs] = nxt
                    self.edges[cur] = out
            cur = nxt
            latent.append(cur)
        return latent

    def update_with_episode(self, obs_seq: List[int], learn: bool = True) -> List[int]:
        """
        Runs an episode, optionally learning:
          - update transition counts
          - update eligibility E
          - update retrospective EMA
          - run split/merge
        Returns latent state sequence.

        With ``learn=False`` the agent is left untouched (no edges are created).
        """
        latent = self.rollout_latent(obs_seq, create_missing=learn)
        if not learn:
            return latent

        # 0) decay (stabilizers); each factor of exactly 1.0 is skipped
        if self.cfg.count_decay != 1.0:
            for p in list(self.edge_counts.keys()):
                for o in list(self.edge_counts[p].keys()):
                    self.edge_counts[p][o] *= self.cfg.count_decay
                    if self.edge_counts[p][o] < 1e-12:
                        self.edge_counts[p].pop(o, None)
        if self.cfg.trace_decay != 1.0:
            for s in list(self.E.keys()):
                for u in self.us_classes:
                    self.E[s][u] *= self.cfg.trace_decay
                self.exposure[s] *= self.cfg.trace_decay
                self.C[s] *= self.cfg.trace_decay
        if self.cfg.retro_decay != 1.0:
            for u in self.us_classes:
                self.us_episode_ema[u] *= self.cfg.retro_decay
            for s in list(self.cs_us_presence_ema.keys()):
                for u in self.us_classes:
                    self.cs_us_presence_ema[s][u] *= self.cfg.retro_decay

        # 1) presence counts (once per episode per state) and transition counts
        visited = set(latent)
        for sid in visited:
            self.presence_episodes[sid] = self.presence_episodes.get(sid, 0) + 1

        for t in range(len(latent) - 1):
            p = latent[t]
            nxt_obs = obs_seq[t + 1]
            self.edge_counts.setdefault(p, {})
            self.edge_counts[p][nxt_obs] = self.edge_counts[p].get(nxt_obs, 0.0) + 1.0

        # 2) retrospective EMA: which US occurred in this episode?
        us_present = {u: any(o == u for o in obs_seq) for u in self.us_classes}
        for u, pres in us_present.items():
            if pres:
                self.us_episode_ema[u] = self.us_episode_ema.get(u, 0.0) + 1.0
                for sid in visited:
                    self._ensure_state_structs(sid)
                    self.cs_us_presence_ema[sid][u] = self.cs_us_presence_ema[sid].get(u, 0.0) + 1.0

        # 3) contextual eligibility traces (Alg 1): snapshot trace at each time
        decay = self.cfg.gamma * self.cfg.lam
        trace: Dict[int, float] = {}
        snapshots: List[Dict[int, float]] = []
        for sid in latent:
            # decay every existing trace, dropping negligible entries
            for k in list(trace.keys()):
                trace[k] *= decay
                if trace[k] < 1e-12:
                    trace.pop(k, None)
            # accumulating trace: a revisited state adds on top of its residue
            trace[sid] = trace.get(sid, 0.0) + 1.0
            snapshots.append(dict(trace))

        # 4) add snapshot mass at each US event time to E[*][u]
        for t, o in enumerate(obs_seq):
            if o not in self.us_classes:
                continue
            snap = snapshots[t]
            u = o
            for k, v in snap.items():
                self._ensure_state_structs(k)
                self.E[k][u] = self.E[k].get(u, 0.0) + v
                self.exposure[k] = self.exposure.get(k, 0.0) + v
                # evidence proxy tracks total mass credited to any US
                self.C[k] = self.C.get(k, 0.0) + v

        # 5) split / merge updates (Alg 2)
        self._maybe_split()
        self._maybe_merge()

        return latent

    # ---------------------
    # Split / merge (Alg 2)
    # ---------------------

    def _maybe_split(self) -> None:
        """Mark confident, well-evidenced states as salient cues and split them."""
        # consider all currently alive states (snapshot: splitting adds states)
        for sid in list(self.states.keys()):
            if not self.states[sid].get('alive', True):
                continue
            if sid in self.salient:
                continue

            # evidence gates
            if self.presence_episodes.get(sid, 0) < self.cfg.min_presence_episodes:
                continue
            if self.exposure.get(sid, 0.0) < self.cfg.min_effective_exposure:
                continue
            if self.C.get(sid, 0.0) < self.cfg.n_threshold:
                continue

            u_star, phat, n_eff = self._winning_us_and_pc(sid)
            lb = wilson_lower_bound(phat, n_eff, confidence=self.cfg.split_confidence)
            if lb <= self.cfg.theta_split:
                continue

            # Become salient cue. Label path by winning US.
            label = "R1" if u_star == self.us_classes[0] else "R2"
            self.salient[sid] = (label, u_star)

            # Split: clone *all outgoing successors* and rewire edges sid->clone(succ)
            self._split_cue_successors(sid, label)

    def _split_cue_successors(self, cue_sid: int, label: str) -> None:
        """Replace every successor of ``cue_sid`` by a clone labelled ``label``."""
        out = self.edges.get(cue_sid, {})
        if cue_sid not in self.cue_to_clones:
            self.cue_to_clones[cue_sid] = set()

        # iterate over snapshot because we mutate
        for nxt_obs, child_sid in list(out.items()):
            # avoid repeated cloning if already a cue-descendant clone
            if self._is_descendant_of(child_sid, cue_sid):
                continue
            clone_sid = self._clone_node(child_sid, root_cue=cue_sid, path_label=label)
            out[nxt_obs] = clone_sid
            self.cue_to_clones[cue_sid].add(clone_sid)

        self.edges[cue_sid] = out

    def _edge_mass_to_descendants(self, cue_sid: int) -> float:
        """Fraction of cue's outgoing transition count mass that goes to descendants of cue."""
        counts = self.edge_counts.get(cue_sid, {})
        if not counts:
            return 0.0
        tot = sum(counts.values())
        if tot <= 0:
            return 0.0
        mass = 0.0
        out = self.edges.get(cue_sid, {})
        for nxt_obs, cnt in counts.items():
            child = out.get(nxt_obs, None)
            if child is not None and self._is_descendant_of(child, cue_sid):
                mass += cnt
        return mass / tot

    def _maybe_merge(self) -> None:
        """Merge the clones of salient cues whose utility or edge mass collapsed."""
        if self.cfg.theta_merge <= 0:
            return

        for cue_sid in list(self.salient.keys()):
            if not self.states[cue_sid].get('alive', True):
                self.salient.pop(cue_sid, None)
                continue

            # PC * RC, with RC taken for the US recorded at salience time
            util = self.utility(cue_sid)

            # extra merge safeguard from python file
            edge_mass = self._edge_mass_to_descendants(cue_sid)

            if (util < self.cfg.theta_merge) or (edge_mass < self.cfg.edge_eps):
                self._merge_descendants_of_cue(cue_sid)
                self.salient.pop(cue_sid, None)

    def _merge_descendants_of_cue(self, cue_sid: int) -> None:
        """
        Merge all descendants (clones with root_cue==cue_sid) back into originals:
          - redirect edges that point to descendants back to descendant.original
          - mark descendant nodes as inactive (alive=False)
          - drop them from the cue's ``cue_to_clones`` bookkeeping
        """
        descendants = {sid for sid, st in self.states.items()
                       if st.get('alive', True) and st.get('root_cue', None) == cue_sid}
        if not descendants:
            return

        # Redirect all incoming edges
        for p in list(self.edges.keys()):
            for nxt_obs, child in list(self.edges[p].items()):
                if child in descendants:
                    orig = self.states[child].get('original', child)
                    self.edges[p][nxt_obs] = orig

        # Deactivate descendants (keep ids stable for notebook)
        for sid in descendants:
            self.states[sid]['alive'] = False

        # Keep the bookkeeping in step with the graph: merged clones are no
        # longer clones of this cue.
        clones = self.cue_to_clones.get(cue_sid)
        if clones is not None:
            clones.difference_update(descendants)

    # ---------------------
    # Notebook interface
    # ---------------------

    def run_episode(self, obs_seq, learn: bool = True):
        """Run ``update_with_episode`` on a list copy of ``obs_seq``."""
        return self.update_with_episode(list(obs_seq), learn=learn)

    def encode_sequence(self, obs_seq) -> np.ndarray:
        """One-hot encode the latent trajectory of ``obs_seq`` without learning.

        Returns an array of shape (len(obs_seq), number of allocated state
        ids); row t has a 1.0 in the column of the latent state at step t.
        """
        lat = self.run_episode(obs_seq, learn=False)
        S = self._next_sid
        X = np.zeros((len(lat), S), dtype=float)
        for t, sid in enumerate(lat):
            if self.states.get(sid, {}).get('alive', True):
                X[t, sid] = 1.0
            else:
                # If a merged-out state appears (should be rare), fall back to its original
                orig = self.states.get(sid, {}).get('original', sid)
                if orig < S:
                    X[t, orig] = 1.0
        return X

    def near_far_corr(self, near_seq, far_seq) -> np.ndarray:
        """Pearson correlation between every near-step and far-step encoding.

        Entry (i, j) correlates the one-hot row of near step i with that of
        far step j; all-zero rows and zero-variance rows give 0.0.
        """
        A = self.encode_sequence(near_seq)
        B = self.encode_sequence(far_seq)
        C = np.zeros((A.shape[0], B.shape[0]))
        for i in range(A.shape[0]):
            for j in range(B.shape[0]):
                a = A[i]; b = B[j]
                if np.allclose(a, 0) or np.allclose(b, 0):
                    C[i, j] = 0.0
                    continue
                a0 = a - a.mean()
                b0 = b - b.mean()
                den = (np.linalg.norm(a0) * np.linalg.norm(b0))
                C[i, j] = (a0 @ b0) / den if den > 0 else 0.0
        return C


In [None]:

# Train N_RUNS independent agents on shuffled near/far episodes and record how
# the near-vs-far representation decorrelates session by session.
import copy

N_RUNS = 6
SESSIONS = 9
TRIALS_PER_SESSION = 80
THRESH = 0.3


def block_mean(C, pairs):
    """Mean of C[i, j] over ``pairs``; NaN when ``pairs`` is empty."""
    return float(np.mean([C[i,j] for (i,j) in pairs])) if pairs else np.nan


def norm(x):
    """Express a session index as a fraction of SESSIONS (NaN if never reached)."""
    return x/SESSIONS if x is not None else np.nan


rng = np.random.default_rng(123)
checkpoint_sessions = [1, 3, 4, 9]
all_final_blocks = []
all_time_to_thr = []
mat_by_session = {s: [] for s in checkpoint_sessions}
demo_checkpoints = {}

for run in range(N_RUNS):
    agent = CoDAUncAgent(obs_symbols=sorted(set(near)|set(far)), cfg=UncCfg())
    # first session at which each block mean dropped below THRESH
    tt = {'offdiag': None, 'preR2': None, 'preR1': None}
    for session in range(1, SESSIONS+1):
        episodes = [near]*(TRIALS_PER_SESSION//2) + [far]*(TRIALS_PER_SESSION//2)
        rng.shuffle(episodes)
        for ep in episodes:
            agent.run_episode(ep, learn=True)
        C = agent.near_far_corr(near, far)
        if session in mat_by_session:
            mat_by_session[session].append(C)
        if run == 0 and session in checkpoint_sessions:
            # A full deep copy is required: the plotting code replays sequences
            # through the snapshot's `edges`, so a partial field-by-field copy
            # that omits the graph would always fall back to original states.
            demo_checkpoints[session] = copy.deepcopy(agent)
        b_off = block_mean(C, offdiag_pairs)
        b_r2 = block_mean(C, same_preR2_pairs)
        b_r1 = block_mean(C, same_preR1_pairs)
        if tt['offdiag'] is None and b_off < THRESH:
            tt['offdiag'] = session
        if tt['preR2'] is None and b_r2 < THRESH:
            tt['preR2'] = session
        if tt['preR1'] is None and b_r1 < THRESH:
            tt['preR1'] = session
    C_final = agent.near_far_corr(near, far)
    all_final_blocks.append((run, block_mean(C_final, offdiag_pairs), block_mean(C_final, same_preR2_pairs), block_mean(C_final, same_preR1_pairs)))
    all_time_to_thr.append((run, norm(tt['offdiag']), norm(tt['preR2']), norm(tt['preR1'])))

    print("salient cues:", agent.salient)   # should become non-empty
    # Clones are the nodes that carry a `root_cue`; count only the active ones.
    print("num clones:", sum(1 for s in agent.states.values()
                             if s.get("root_cue") is not None and s.get("alive", True)))
    for cue, clones in getattr(agent, "cue_to_clones", {}).items():
        print("cue", cue, "-> #clones", len(clones))


blocks_df = pd.DataFrame(all_final_blocks, columns=['run','offdiag','preR2','preR1']).set_index('run')
times_df = pd.DataFrame(all_time_to_thr, columns=['run','offdiag_t','preR2_t','preR1_t']).set_index('run')


In [None]:


def canonical_latents(agent, seq):
    """Replay ``seq`` through ``agent`` without learning, up to the first 0.

    Returns a list of ``(step, observation, state_id, path)`` tuples, where
    ``path`` is ``agent.states[state_id]['path']``.  The walk stops at the
    first observation equal to 0, which is not included.
    """
    state_ids = agent.run_episode(seq, learn=False)
    records = []
    step = 0
    for symbol, state_id in zip(seq, state_ids):
        if symbol == 0:
            return records
        records.append((step, symbol, state_id, agent.states[state_id]['path']))
        step += 1
    return records


def ring_positions(T: int, r_base=1.0, r_near=1.12, r_far=0.88, start_angle=np.pi/2):
    """Lay out ``T`` time steps clockwise on three concentric rings.

    Step t sits at angle ``start_angle - 2*pi*t/T``.  The result maps each t to
    a dict holding that angle under 'theta' and the (x, y) point at radius
    ``r_base`` / ``r_near`` / ``r_far`` under 'base' / 'R1' / 'R2'.
    """
    radii = (('base', r_base), ('R1', r_near), ('R2', r_far))
    n_steps = max(T, 1)
    layout = {}
    for t in range(T):
        theta = start_angle - 2 * np.pi * (t / n_steps)
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)
        slot = {'theta': theta}
        for ring, radius in radii:
            slot[ring] = (radius * cos_t, radius * sin_t)
        layout[t] = slot
    return layout


def build_color_assignments(scheme: str, near_lat, far_lat):
    """Assign a face color to every (step, ring) node of both trajectories.

    Args:
        scheme: 'state_id' colors by latent state id (tab20, modulo 20);
            'obs_step' colors by time step (hsv, period 26); any other value
            (e.g. 'blocks') gives each distinct state id its own tab20 color
            and then applies a few fixed per-id overrides.
        near_lat: ``(step, obs, state_id, path)`` tuples of the near stream.
        far_lat: The same for the far stream.

    Returns:
        ``{'near': {(step, ring): color}, 'far': {(step, ring): color}}`` where
        ring is 'base' for a ``None`` path and the path label otherwise.
    """
    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `pyplot.get_cmap`
    # is still available and accepts the same (name, lut) arguments.
    from matplotlib import pyplot

    if scheme == 'state_id':
        tab20 = pyplot.get_cmap('tab20')

        def color_of(t, obs, sid, ring):
            return tab20(sid % 20)

    elif scheme == 'obs_step':
        hsv = pyplot.get_cmap('hsv')

        def color_of(t, obs, sid, ring):
            return hsv((t % 26) / 26.0)

    else:
        # Default: color-code by the "third element" (latent state id) with overrides
        unique_sids = sorted({sid for (_, _, sid, _) in [*near_lat, *far_lat]})
        cmap = pyplot.get_cmap('tab20', max(1, len(unique_sids)))
        sid_to_color = {sid: cmap(i % cmap.N) for i, sid in enumerate(unique_sids)}
        overrides = {1: '#b3b3b3', 5: '#9467bd', 9: '#6a6a6a', 11: '#000000', 6: '#d62728', 8: '#d62728'}
        sid_to_color.update({sid: color for sid, color in overrides.items() if sid in unique_sids})

        def color_of(t, obs, sid, ring):
            return sid_to_color[sid]

    colors = {'near': {}, 'far': {}}
    for stream, lat_seq in (('near', near_lat), ('far', far_lat)):
        for (t, obs, sid, path) in lat_seq:
            ring = 'base' if path is None else path
            colors[stream][(t, ring)] = color_of(t, obs, sid, ring)
    return colors


def draw_cscg_style_ring(
    ax,
    agent,
    title=None,
    scheme='blocks',
    close_loop=True,
    edge_color='#5a5a5a',
    edge_lw=2.2,
):
    """Draw the near and far latent trajectories of ``agent`` as a ring graph.

    Both module-level streams (``near`` and ``far``) are replayed through the
    agent without learning.  Each visited step becomes a circle on the ring
    named by the state's path ('base', 'R1' or 'R2'); consecutive steps are
    joined by uniformly colored arrows.

    Args:
        ax: Matplotlib axes to draw on.
        agent: Agent exposing ``run_episode`` and ``states``.
        title: Optional axes title.
        scheme: Node coloring mode passed to ``build_color_assignments``.
        close_loop: If true, also draw an arrow from the last node of each
            stream back to its first node.
        edge_color: Color of every arrow.
        edge_lw: Line width of every arrow.
    """
    from collections import defaultdict

    def ring_of(path):
        return 'base' if path is None else path

    # --- build sequences and ring geometry ---
    near_lat = canonical_latents(agent, near)
    far_lat = canonical_latents(agent, far)
    pos = ring_positions(max(len(near_lat), len(far_lat)))

    # --- node colors and occupancy ---
    node_colors = build_color_assignments(scheme, near_lat, far_lat)
    near_by_t = {t: sid for (t, _, sid, _) in near_lat}
    far_by_t = {t: sid for (t, _, sid, _) in far_lat}
    # (step, state id) pairs at which both streams are in the same latent state
    sid_shared = {(t, sid) for t, sid in near_by_t.items() if far_by_t.get(t) == sid}
    # (step, ring) -> [(stream label, state id), ...] sharing that slot
    occupancy = defaultdict(list)
    for label, lat_seq in (('near', near_lat), ('far', far_lat)):
        for (t, _, sid, path) in lat_seq:
            occupancy[(t, ring_of(path))].append((label, sid))

    node_positions = {}
    offset_mag = 0.07        # tangential spacing of the generic fan-out
    radial_offset = 0.09     # radial spacing when near and far split one slot
    prev_pos = {'near': None, 'far': None}
    rings_order = ['base', 'R1', 'R2']
    # Steps at which the near/far radial order is always swapped (hard-coded
    # layout choice; the two streams' observations differ at these steps).
    force_swap_times = {6, 19}
    arrow_mutation = 14.0

    def place(label, t, ring, coords):
        """Record a node position and remember it as the stream's latest one."""
        node_positions[(label, t, ring)] = coords
        prev_pos[label] = coords

    def travel_cost(candidate):
        """Summed squared distance from each stream's previous node."""
        cost = 0.0
        for label, (x, y) in candidate.items():
            prev = prev_pos.get(label)
            if prev is None:
                continue
            qx, qy = prev
            cost += (x - qx) * (x - qx) + (y - qy) * (y - qy)
        return cost

    def radial_split(entries, base_xy, unit, sign_map):
        """Offset each stream along the radial unit vector by its sign."""
        bx, by = base_xy
        ux, uy = unit
        placed = {}
        for label, _ in entries:
            direction = radial_offset * sign_map[label]
            placed[label] = (bx + direction * ux, by + direction * uy)
        return placed

    for t in sorted({t for (t, _) in occupancy}):
        for ring in rings_order:
            entries = occupancy.get((t, ring))
            if not entries:
                continue
            base_x, base_y = pos[t][ring]
            sids = {sid for (_, sid) in entries}
            shared_here = any((t, sid) in sid_shared for (_, sid) in entries)
            # A lone node, or both streams in the very same state: no offset.
            if len(entries) == 1 or (len(sids) == 1 and shared_here):
                for label, _ in entries:
                    place(label, t, ring, (base_x, base_y))
                continue

            norm = (base_x * base_x + base_y * base_y) ** 0.5
            if norm == 0.0:
                ux, uy = 1.0, 0.0
                px, py = 0.0, 1.0
            else:
                ux, uy = base_x / norm, base_y / norm      # radial direction
                px, py = -base_y / norm, base_x / norm     # tangential direction

            if len(entries) == 2 and {label for (label, _) in entries} == {'near', 'far'}:
                # Different states in one slot: push one stream outwards and
                # the other inwards, picking the order with the shorter travel
                # from the previous nodes unless a swap is forced.
                opt_default = radial_split(entries, (base_x, base_y), (ux, uy), {'near': 1.0, 'far': -1.0})
                opt_swapped = radial_split(entries, (base_x, base_y), (ux, uy), {'near': -1.0, 'far': 1.0})
                if t in force_swap_times:
                    chosen = opt_swapped
                elif travel_cost(opt_swapped) < travel_cost(opt_default):
                    chosen = opt_swapped
                else:
                    chosen = opt_default
                for label, coords in chosen.items():
                    place(label, t, ring, coords)
                continue

            # Generic fallback: fan the entries out along the tangent.
            ordered = sorted(entries, key=lambda item: item[0])
            span = len(ordered) - 1
            for idx, (label, _) in enumerate(ordered):
                factor = idx - span / 2.0
                place(
                    label, t, ring,
                    (base_x + factor * offset_mag * px, base_y + factor * offset_mag * py),
                )

    def node_xy(label, t, ring):
        return node_positions[(label, t, ring)]

    def add_arrow(start, end, z):
        """Add one uniformly styled arrow from ``start`` to ``end``."""
        ax.add_patch(
            FancyArrowPatch(
                start,
                end,
                arrowstyle='-|>',
                mutation_scale=arrow_mutation,
                lw=edge_lw,
                color=edge_color,
                alpha=0.9,
                shrinkA=2.5,
                shrinkB=2.5,
                zorder=z,
            )
        )

    def draw_edges(lat_seq, label, z=1):
        for (t, _, _, p), (t2, _, _, p2) in zip(lat_seq, lat_seq[1:]):
            add_arrow(node_xy(label, t, ring_of(p)), node_xy(label, t2, ring_of(p2)), z)
        if close_loop and len(lat_seq) >= 2:
            t0, _, _, p0 = lat_seq[0]
            tL, _, _, pL = lat_seq[-1]
            add_arrow(node_xy(label, tL, ring_of(pL)), node_xy(label, t0, ring_of(p0)), z)

    def draw_nodes(lat_seq, label, z=3):
        for (t, _, _, path) in lat_seq:
            ring = ring_of(path)
            ax.add_patch(
                Circle(
                    node_xy(label, t, ring),
                    radius=0.065,
                    facecolor=node_colors[label].get((t, ring), '#bfbfbf'),
                    edgecolor='#3a3a3a',
                    lw=1.1,
                    alpha=0.98,
                    zorder=z,
                )
            )

    # Arrows go underneath; near nodes are drawn above far nodes.
    draw_edges(near_lat, 'near', z=1)
    draw_edges(far_lat, 'far', z=1)
    draw_nodes(far_lat, 'far', z=3)
    draw_nodes(near_lat, 'near', z=4)

    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_xlim(-1.55, 1.55)
    ax.set_ylim(-1.55, 1.55)
    if title:
        ax.set_title(title, fontsize=20)


In [None]:
# Inspect the near stream's (step, obs, state_id, path) tuples at the
# session-4 checkpoint; the bare expression displays them in the notebook.
near_lat = canonical_latents(demo_checkpoints[4], near)
near_lat

In [None]:
# Inspect the far stream's (step, obs, state_id, path) tuples at the
# session-4 checkpoint; the bare expression displays them in the notebook.
far_lat = canonical_latents(demo_checkpoints[4], far)
far_lat

In [None]:
import os

# Every exported figure lands in this directory; create it up front.
poster_dir = os.path.join("figures", "poster")
os.makedirs(poster_dir, exist_ok=True)

# (file extension, extra `savefig` keyword arguments) for each export format.
export_specs = [
    ("svg", {"format": "svg"}),
    ("pdf", {"format": "pdf"}),
    ("png", {"format": "png", "dpi": 400})
]


def _save_in_all_formats(figure, stem):
    """Save ``figure`` as ``<poster_dir>/<stem>.<ext>`` for every export format."""
    for ext, kwargs in export_specs:
        figure.savefig(
            os.path.join(poster_dir, f"{stem}.{ext}"),
            bbox_inches='tight',
            **kwargs
        )


# Prefer the latest of the reference sessions (1, 3, 4, 9). If none of them was
# checkpointed, fall back to the latest checkpoint available instead of
# failing with an opaque "max() arg is an empty sequence".
preferred_sessions = [s for s in (1, 3, 4, 9) if s in demo_checkpoints]
if preferred_sessions:
    final_session = max(preferred_sessions)
elif demo_checkpoints:
    final_session = max(demo_checkpoints)
else:
    raise ValueError("demo_checkpoints is empty; there is no session to plot")

# Transition graph for the final session.
fig, ax = plt.subplots(figsize=(7.4, 4.6), constrained_layout=True)
draw_cscg_style_ring(
    ax,
    demo_checkpoints[final_session],
    title="CoDA — transition graph (CSCG-style colors)",
    scheme='blocks'
)
_save_in_all_formats(fig, "coda_transition_final_session")
plt.show()

# One transition graph per checkpointed session, in ascending key order.
for s in sorted(demo_checkpoints.keys()):
    fig_s, ax_s = plt.subplots(figsize=(7.4, 4.6), constrained_layout=True)
    draw_cscg_style_ring(
        ax_s,
        demo_checkpoints[s],
        title=f"Session {s}",
        scheme='blocks'
    )
    _save_in_all_formats(fig_s, f"coda_transition_session_{s}")
    plt.show()


In [None]:
# === Trajectory of decorrelation during learning (CSCG-style) ===
# One heat map per checkpoint session, showing the mean (over axis 0) of
# `mat_by_session[s]`.
# `mpl` is no longer used in this cell; the import is kept so the name stays
# defined for any other cell that relies on it.
import matplotlib as mpl

check_sessions = [1, 3, 4, 9]
mean_mats = {s: np.mean(mat_by_session[s], axis=0) for s in check_sessions}

# CSCG-style colormap (dark background, warm high values)
# 'magma' or 'inferno' gives similar look; both start dark→warm→bright
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `pyplot.get_cmap` is the lookup that works on both old and new releases.
cmap = plt.get_cmap('magma')

# Pre/indicator windows to highlight. Each (low, high) pair is 0-based and
# half-open: the outline runs from low - 0.5 to high - 0.5, so it encloses
# cells low .. high - 1.
highlight_windows = [(6, 10), (13, 15), (18, 20)]
# Cell boundaries marked with dashed guide lines (drawn at index - 0.5).
dash_lines = [6, 10, 13, 15, 18, 20]

# Make independent figures for each checkpoint session
for s in check_sessions:
    fig, ax = plt.subplots(figsize=(4.8, 4.8), constrained_layout=True)

    last_im = ax.imshow(
        mean_mats[s],
        vmin=-0.1,
        vmax=1.0,
        cmap=cmap,
        origin='upper',
        aspect='equal',
        interpolation='nearest'
    )

    # Minor ticks sit on the cell edges so the dotted grid outlines each cell;
    # major ticks and spines are removed to keep the panel clean.
    ny, nx = mean_mats[s].shape
    ax.set_xticks(np.arange(-0.5, nx, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, ny, 1), minor=True)
    ax.grid(which='minor', color='white', linestyle=':', linewidth=0.4)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.tick_params(colors='white', labelsize=0, length=0)
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Dashed guides; a boundary is skipped when it lies outside the matrix.
    for idx in dash_lines:
        if idx <= nx:
            ax.axvline(idx - 0.5, color='white', linewidth=1.2, linestyle=(0, (2, 4)))
        if idx <= ny:
            ax.axhline(idx - 0.5, color='white', linewidth=1.2, linestyle=(0, (2, 4)))

    # Solid square outline around each diagonal window that fits in the matrix.
    for low, high in highlight_windows:
        if high <= min(nx, ny):
            x0, x1 = low - 0.5, high - 0.5
            y0, y1 = low - 0.5, high - 0.5
            ax.plot(
                [x0, x1, x1, x0, x0],
                [y0, y0, y1, y1, y0],
                color='white',
                linewidth=2.5
            )

    fig.patch.set_facecolor('black')
    ax.set_facecolor('black')

    # Colorbar without label or ticks: it only conveys the colour gradient.
    cbar = fig.colorbar(last_im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("")
    cbar.set_ticks([])
    cbar.ax.tick_params(length=0)

    plt.show()


In [None]:

# Fig. 4i
# Bar chart of the per-column mean of `blocks_df`, with the standard error of
# the mean as error bars, for the three block columns listed in `labels`.
labels = ['offdiag', 'preR2', 'preR1']
means = blocks_df.mean()
ses = blocks_df.sem()
x = np.arange(len(labels))
y = [means[name] for name in labels]
yerr = [ses[name] for name in labels]

fig, ax = plt.subplots(figsize=(6.5, 4.6))
bars = ax.bar(x, y, yerr=yerr, capsize=4, color=['#777', '#2f6db3', '#d64b5a'])
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_ylim(0, 1.05)
ax.set_ylabel("Mean correlation (final)")
ax.set_title("CoDA (uncertainty‑aware) — Fig. 4i analogue")
for b, val in zip(bars, y):
    # Tall bars get a white label just inside the top of the bar; shorter bars
    # get a black label above it, capped at 1.02 so it stays within the y-limit.
    if val > 0.85:
        text_y, valign, text_color = val - 0.05, 'top', 'white'
    else:
        text_y, valign, text_color = min(1.02, val + 0.05), 'bottom', 'black'
    ax.text(
        b.get_x() + b.get_width() / 2,
        text_y,
        f"{val:.2f}",
        ha='center',
        va=valign,
        color=text_color,
        fontsize=9,
    )
fig.tight_layout()
plt.show()

# Fig. 4j
# Same layout for `times_df`: NaN-skipping column means and standard errors of
# the `*_t` columns, displayed under the shorter names in `disp`.
labels_t = ['offdiag_t', 'preR2_t', 'preR1_t']
disp = ['offdiag', 'preR2', 'preR1']
means_t = times_df.mean(skipna=True)
ses_t = times_df.sem(skipna=True)
x = np.arange(len(labels_t))
y = [means_t[name] for name in labels_t]
yerr = [ses_t[name] for name in labels_t]

fig, ax = plt.subplots(figsize=(6.5, 4.6))
bars = ax.bar(x, y, yerr=yerr, capsize=4, color=['#777', '#2f6db3', '#d64b5a'])
ax.set_xticks(x)
ax.set_xticklabels(disp)
ax.set_ylim(0, 1.05)
ax.set_ylabel("Fraction of training (first corr < 0.3)")
ax.set_title("CoDA (uncertainty‑aware) — Fig. 4j analogue")
for b, val in zip(bars, y):
    # Label sits above the bar, capped at 1.02 so it stays within the y-limit.
    ax.text(
        b.get_x() + b.get_width() / 2,
        min(1.02, val + 0.05),
        f"{val:.2f}",
        ha='center',
        va='bottom',
        fontsize=9,
    )
fig.tight_layout()
plt.show()
