# CoDA (uncertainty‑aware) — OSM Fig. 4c/4i/4j with **CSCG‑style color‑coded states**

**New:** CSCG‑style node colors (uniform gray arrows), three coloring modes (`blocks`, `state_id`, `obs_step`).

In [None]:

import os
import math, random
from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch, Circle


In [None]:

def posterior_prob_p_greater_than(theta: float, success: float, failure: float, alpha0: float=0.5, beta0: float=0.5) -> float:
    """Return the Beta-posterior probability that a rate exceeds ``theta``.

    The posterior is Beta(alpha0 + success, beta0 + failure); negative
    ``success`` / ``failure`` values are clamped to zero before being added
    to the prior pseudo-counts.

    Returns 0.0 whenever ``_HAS_MPMATH`` is falsy.

    NOTE(review): ``_HAS_MPMATH`` and ``betainc`` are not defined in the
    visible part of this file; the call signature used below looks like
    mpmath's ``betainc(a, b, x1, x2, regularized=True)`` -- confirm that both
    names are set up elsewhere before calling this function.
    """
    if _HAS_MPMATH:
        post_alpha = alpha0 + max(0.0, float(success))
        post_beta = beta0 + max(0.0, float(failure))
        # Regularized incomplete beta over [0, theta] == posterior mass at or below theta.
        mass_below = betainc(post_alpha, post_beta, 0, theta, regularized=True)
        return float(1.0 - mass_below)
    return 0.0

def wilson_lower_bound(phat: float, n: float, confidence: float=0.95) -> float:
    """Return the one-sided Wilson score lower bound for a proportion.

    Args:
        phat: Observed proportion of successes.
        n: Number of (possibly fractional / effective) observations.
        confidence: One-sided confidence level, strictly between 0 and 1.

    Returns:
        The lower bound, or 0.0 when ``n <= 0`` (no evidence).

    Raises:
        statistics.StatisticsError: (a ``ValueError`` subclass) if
            ``confidence`` is not strictly between 0 and 1.

    The critical value is the standard-normal quantile at ``confidence``,
    which equals ``sqrt(2) * erfcinv(2 * (1 - confidence))``.  It is taken
    from the standard library so that ``confidence`` is always honoured and
    no optional arbitrary-precision dependency is needed (0.95 -> 1.6449).
    """
    from statistics import NormalDist

    if n <= 0:
        return 0.0
    z = NormalDist().inv_cdf(confidence)
    z_sq = z * z
    denom = 1.0 + z_sq / n
    center = phat + z_sq / (2.0 * n)
    adj = z * ((phat * (1.0 - phat) + z_sq / (4.0 * n)) / n) ** 0.5
    return (center - adj) / denom


In [None]:

# Observation sequences for the two trial types, written segment by segment.
# They differ at positions 6-9 (2 vs 3), position 14 (6 vs 4) and position 19
# (5 vs 6); the trailing zeros are padding.
near = [1]*6 + [2]*4 + [1]*3 + [4, 6] + [1]*3 + [5, 5] + [1]*2 + [7] + [0]*3
far  = [1]*6 + [3]*4 + [1]*3 + [4, 4] + [1]*3 + [5, 6] + [1]*2 + [7] + [0]*3

# Time-step windows (positions 10-12 and 15-17) used as row/column blocks below.
preR1_idx = [10, 11, 12]
preR2_idx = [15, 16, 17]


def block_indices(rows, cols):
    """Return every (row, col) pair, rows varying slowest."""
    pairs = []
    for r in rows:
        for c in cols:
            pairs.append((r, c))
    return pairs


# Cross-window pairs (both orientations) and the two within-window blocks.
offdiag_pairs     = block_indices(preR1_idx, preR2_idx) + block_indices(preR2_idx, preR1_idx)
same_preR1_pairs  = block_indices(preR1_idx, preR1_idx)
same_preR2_pairs  = block_indices(preR2_idx, preR2_idx)


In [None]:
# --- CoDA agent (from coda_trial_by_trial_util.py) ---
# Goal: use the core CoDA implementation *as-is* from coda_trial_by_trial_util.py,
# and wrap it with the minimal API the notebook expects (run_episode / encode_sequence / near_far_corr).

import sys, os, pathlib, importlib.util
import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Optional

# 1) Locate and load coda_trial_by_trial_util.py (prefer repo-local copy)
def _find_file_upwards(fname: str, start: Optional[pathlib.Path]=None, max_up: int=6) -> Optional[pathlib.Path]:
    start = start or pathlib.Path.cwd()
    cur = start.resolve()
    for _ in range(max_up+1):
        cand = cur / fname
        if cand.exists():
            return cand
        cur = cur.parent
    return None

_coda_path = _find_file_upwards("coda_trial_by_trial_util.py")
if _coda_path is None:
    # Not found in the working directory or its ancestors: probe a couple of
    # fixed fallback locations (e.g., this sandbox) and take the first that exists.
    _fallback_locations = (
        pathlib.Path("/mnt/data/coda_trial_by_trial_util.py"),
        pathlib.Path.home() / "coda_trial_by_trial_util.py",
    )
    _coda_path = next((loc for loc in _fallback_locations if loc.exists()), None)
if _coda_path is None:
    raise FileNotFoundError("Could not find coda_trial_by_trial_util.py. Put it in the notebook folder or a parent folder.")

# Put the module's own directory at the front of sys.path so that its sibling
# modules (util.py / spatial_environments.py) can be imported by it.
_coda_dir = str(_coda_path.parent)
if _coda_dir not in sys.path:
    sys.path.insert(0, _coda_dir)

# 2) Import the core CoDA implementation
spec = importlib.util.spec_from_file_location("coda_trial_by_trial_util", str(_coda_path))
if spec is None or spec.loader is None:
    raise ImportError(f"Cannot build an import spec for {_coda_path}")
coda_mod = importlib.util.module_from_spec(spec)


def _install_util_fallback():
    """Register a minimal stand-in ``util`` module in ``sys.modules``."""
    import types as _types
    import numpy as _np

    def compute_eligibility_traces(states, n_states, gamma=0.9, lam=0.8):
        # Decaying trace per state plus cumulative visit counts, one row per time step.
        decay = float(gamma) * float(lam)
        e = _np.zeros(n_states, dtype=float)
        T = len(states)
        E = _np.zeros((T, n_states), dtype=float)
        counts = _np.zeros((T, n_states), dtype=float)
        seen = _np.zeros(n_states, dtype=float)
        for t, s in enumerate(states):
            e *= decay
            if 0 <= int(s) < n_states:
                e[int(s)] += 1.0
                seen[int(s)] += 1.0
            E[t] = e
            counts[t] = seen
        return E, counts

    def accumulate_conditioned_eligibility_traces(E_r, E_nr, C, states, sprime=None, sprime2=None, n_states=0, lam=0.8, gamma=0.9):
        if len(states) == 0: return E_r, E_nr, C
        term = states[-1]
        decay = float(gamma) * float(lam)
        e = _np.zeros(n_states, dtype=float)
        mass = _np.zeros(n_states, dtype=float)
        for s in states[:-1]:
            e *= decay
            if 0 <= int(s) < n_states:
                e[int(s)] += 1.0
            mass += e
        # evidence: count visits (nonterminal)
        for s in states[:-1]:
            if 0 <= int(s) < n_states: C[0, int(s)] += 1.0
        # Credit the accumulated trace mass to whichever terminal ended the episode.
        if sprime is not None and term == sprime:
            E_r[0, :n_states] += mass
        elif sprime2 is not None and term == sprime2:
            E_nr[0, :n_states] += mass
        return E_r, E_nr, C

    _m = _types.ModuleType('util')
    _m.compute_eligibility_traces = compute_eligibility_traces
    _m.accumulate_conditioned_eligibility_traces = accumulate_conditioned_eligibility_traces
    sys.modules['util'] = _m


def _install_spatial_environments_fallback():
    """Register a stand-in ``spatial_environments`` module with empty placeholder classes."""
    import types as _types

    class GridEnvRightDownNoSelf: pass
    class GridEnvRightDownNoCue: pass

    _m = _types.ModuleType('spatial_environments')
    _m.GridEnvRightDownNoSelf = GridEnvRightDownNoSelf
    _m.GridEnvRightDownNoCue = GridEnvRightDownNoCue
    sys.modules['spatial_environments'] = _m


# If you're running this notebook outside the full repo, you may be missing
# util.py and/or spatial_environments.py. Each failed import installs the
# matching minimal fallback and retries, so *both* modules may be absent.
# Every fallback is tried at most once; any other missing module is re-raised.
_fallback_installers = {
    'util': _install_util_fallback,
    'spatial_environments': _install_spatial_environments_fallback,
}
while True:
    try:
        spec.loader.exec_module(coda_mod)
        break
    except ModuleNotFoundError as e:
        _installer = _fallback_installers.pop(e.name, None)
        if _installer is None:
            raise
        _installer()

CoDAAgent  = coda_mod.CoDAAgent
CoDAConfig = coda_mod.CoDAConfig

# 3) Minimal environment that satisfies CoDAAgent's interface for *sequence* data used in this notebook
class SequenceEnv:
    """
    Bare-bones environment exposing the attributes/methods the CoDAAgent
    wrapper in this notebook passes along.

    Every observation symbol doubles as its own base state-id (sid = obs).
    Two extra synthetic terminal ids are supplied by the caller:
      - rewarded_terminal: appended to "near" episodes
      - unrewarded_terminal: appended to "far" episodes
    """
    def __init__(self, obs_symbols: List[int], rewarded_terminal: int, unrewarded_terminal: int):
        self.rewarded_terminals   = [int(rewarded_terminal)]
        self.unrewarded_terminals = [int(unrewarded_terminal)]

        # Clone bookkeeping in both directions.
        self.clone_dict: Dict[int, int] = {}          # clone_id -> successor/original
        self.reverse_clone_dict: Dict[int, int] = {}  # successor/original -> clone_id

        # State ids run from 0 up to the largest id seen (symbols or terminals).
        largest_symbol = max(obs_symbols, default=0)
        highest_id = max(largest_symbol, rewarded_terminal, unrewarded_terminal)
        self.num_unique_states = int(highest_id + 1)

        # A single dummy action (0) is valid in every state; each state gets its own list.
        self.valid_actions = {}
        for sid in range(self.num_unique_states):
            self.valid_actions[sid] = [0]

    def add_clone_dict(self, new_clone: int, successor: int):
        """Record that ``new_clone`` is a clone of ``successor``."""
        target = int(successor)
        self.clone_dict[int(new_clone)] = target

    def add_reverse_clone_dict(self, new_clone: int, successor: int):
        """Record the reverse lookup: ``successor`` -> its clone id."""
        clone_id = int(new_clone)
        self.reverse_clone_dict[int(successor)] = clone_id


# 4) Notebook-facing wrapper with the same API as the old CoDAUncAgent
@dataclass
class UncCfg:
    """Notebook-side hyperparameter bundle for CoDAUncAgent.

    CoDAUncAgent.__init__ forwards every field below, under the same keyword
    name, to CoDAConfig when it builds the core agent.
    """
    # Mirror CoDAConfig defaults (keep notebook calls unchanged)
    # NOTE(review): CoDAConfig lives in coda_trial_by_trial_util.py, which is not
    # visible here -- confirm these values still match its defaults.
    gamma: float = 0.9  # multiplied with `lam` to form the trace decay in the fallback util helpers
    lam: float = 0.8
    # presumably thresholds for splitting / merging latent states -- verify against CoDAConfig
    theta_split: float = 0.9
    theta_merge: float = 0.5
    n_threshold: int = 10
    eps: float = 1e-9
    # presumably minimum-evidence gates applied before a decision -- verify against CoDAConfig
    min_presence_episodes: int = 5
    min_effective_exposure: float = 20.0
    confidence: float = 0.95
    # presumably Beta-prior pseudo-counts -- verify against CoDAConfig
    alpha0: float = 0.5
    beta0: float = 0.5
    # presumably 1.0 means "no decay" for each of these -- verify against CoDAConfig
    count_decay: float = 1.0
    trace_decay: float = 1.0
    retro_decay: float = 1.0

class CoDAUncAgent:
    """
    Notebook-facing adapter around coda_trial_by_trial_util.CoDAAgent.

    Public surface:
      - run_episode(seq, learn=True/False) -> latent ids per timestep
      - encode_sequence(seq) -> one-hot matrix (timesteps x latent states)
      - near_far_corr(near, far) -> timestep-by-timestep correlation matrix
      - states[sid]['path'] -> label used only for coloring/inspection
      - clone() -> deep copy for checkpointing
    """
    def __init__(self, obs_symbols: List[int], cfg: Optional[UncCfg]=None):
        self.cfg = cfg if cfg else UncCfg()

        # Observation symbols double as base state ids; 0 is padding/end and
        # therefore never counts as a base state.
        self.obs_symbols = sorted({int(sym) for sym in obs_symbols})
        highest_base = max((sym for sym in self.obs_symbols if sym != 0), default=0)

        # Two synthetic terminal ids placed just above the largest base state.
        self.rewarded_terminal   = highest_base + 1
        self.unrewarded_terminal = highest_base + 2

        env_symbols = self.obs_symbols + [self.rewarded_terminal, self.unrewarded_terminal]
        self.env = SequenceEnv(env_symbols, self.rewarded_terminal, self.unrewarded_terminal)

        # Every notebook-side setting is handed to the core config under the same name.
        forwarded = (
            "gamma", "lam",
            "theta_split", "theta_merge",
            "n_threshold", "eps",
            "min_presence_episodes", "min_effective_exposure",
            "confidence", "alpha0", "beta0",
            "count_decay", "trace_decay", "retro_decay",
        )
        core_cfg = CoDAConfig(**{name: getattr(self.cfg, name) for name in forwarded})
        self.core = CoDAAgent(env=self.env, cfg=core_cfg)

        # Snapshot-style views read by the notebook (plotting / checkpoint inspection).
        self.states: Dict[int, Dict[str, Optional[str]]] = {}   # sid -> {'path': ...}
        self._next_sid: int = self.core.n_states
        self.obs_to_state_ids: Dict[int, List[int]] = {}
        self.co_occ: Dict[int, Dict[str, float]] = {}
        self.exposure: Dict[int, float] = {}
        self.presence_episodes: Dict[int, float] = {}
        self.salient: Dict[int, bool] = {}

        self._refresh_public_views()

    def _refresh_public_views(self):
        """Rebuild every public bookkeeping dict from the core agent's current state."""
        core = self.core
        n_latent = int(core.n_states)
        self._next_sid = n_latent

        # Path labels (coloring only): originals -> None, clones -> 'after_<cue>'
        # (or plain 'clone' when no creating cue is recorded).
        labels = self.states = {}
        for sid in range(n_latent):
            if sid not in core.clone_parent:
                labels[sid] = {"path": None}
                continue
            cue = core.created_by_cue.get(sid, None)
            labels[sid] = {"path": "clone" if cue is None else f"after_{cue}"}

        # Observation symbol -> all latent ids standing for it.  A clone is
        # filed under its clone_parent entry, any other id under itself.
        by_obs = self.obs_to_state_ids = {int(o): [] for o in self.obs_symbols if int(o) != 0}
        for sid in range(n_latent):
            source = int(core.clone_parent[sid]) if sid in core.clone_parent else int(sid)
            if source != 0:
                by_obs.setdefault(source, []).append(sid)

        # Exposure / presence / co-occurrence views (snapshotting and debugging only).
        total_trace = (core.E_r + core.E_nr).reshape(-1)
        has_presence = hasattr(core, "presence_episodes")
        presence = core.presence_episodes.reshape(-1) if has_presence else None
        self.exposure = {int(i): float(total_trace[i]) for i in range(min(len(total_trace), n_latent))}
        if presence is None:
            self.presence_episodes = {}
        else:
            self.presence_episodes = {int(i): float(presence[i]) for i in range(min(len(presence), n_latent))}

        n_cols = min(n_latent, core.E_r.shape[1])
        self.co_occ = {
            int(i): {"E_r": float(core.E_r[0, i]), "E_nr": float(core.E_nr[0, i])}
            for i in range(n_cols)
        }
        self.salient = {int(cue): True for cue in core.salient_cues}

    def _episode_label_is_near(self, seq: List[int]) -> bool:
        """Return True when ``seq`` contains symbol 2 and does not contain symbol 3.

        In this notebook "near" contains symbol 2 and "far" contains symbol 3;
        the result selects the binary outcome handed to the core CoDAAgent.
        """
        if 2 not in seq:
            return False
        return 3 not in seq

    def _map_obs_to_latent(self, obs_seq: List[int]) -> List[int]:
        """
        Translate observations into latent state-ids under the *current* split
        mapping: when the previous latent id is a salient cue and the next
        symbol has a clone registered, the clone id is used instead.
        """
        # Keep everything before the first padding symbol (0).
        trimmed = []
        for raw in obs_seq:
            sym = int(raw)
            if sym == 0:
                break
            trimmed.append(sym)
        if not trimmed:
            return []

        salient = self.core.salient_cues
        reroute = self.env.reverse_clone_dict
        latent = trimmed[:1]
        for upcoming in trimmed[1:]:
            # Split routing rule: successor -> clone if the current state is a salient cue.
            if latent[-1] in salient and upcoming in reroute:
                latent.append(int(reroute[upcoming]))
            else:
                latent.append(int(upcoming))
        return latent

    def run_episode(self, obs_seq: List[int], learn: bool=True) -> List[int]:
        """Map one observation sequence to latent ids, optionally updating the core agent.

        Returns the latent ids (without the synthetic terminal); an empty list
        when the sequence has no symbols before padding.
        """
        latent = self._map_obs_to_latent(obs_seq)
        if not latent:
            return []

        # The core agent expects a binary terminal at the end of each episode.
        unpadded = [o for o in obs_seq if int(o)!=0]
        if self._episode_label_is_near(unpadded):
            terminal = self.rewarded_terminal
        else:
            terminal = self.unrewarded_terminal

        if learn:
            trajectory = latent + [terminal]
            # One dummy action (0) per transition.
            self.core.update_with_episode(trajectory, [0] * len(latent))
            self.core.maybe_split()
            self.core.maybe_merge()
            self._refresh_public_views()

        return latent

    def encode_sequence(self, obs_seq: List[int]) -> np.ndarray:
        """One-hot encode the latent trajectory of ``obs_seq`` without learning."""
        latent = self.run_episode(obs_seq, learn=False)
        width = self._next_sid
        onehot = np.zeros((len(latent), width), dtype=float)
        for row, sid in enumerate(latent):
            # Ids outside the known range leave their row all-zero.
            if 0 <= sid < width:
                onehot[row, sid] = 1.0
        return onehot[:, :self._next_sid]

    def near_far_corr(self, near_seq, far_seq) -> np.ndarray:
        """Correlate every timestep encoding of ``near_seq`` with every one of ``far_seq``.

        Entry (i, j) is the correlation between the mean-centered one-hot rows;
        it is 0.0 when either row is all-zero or has zero variance.
        """
        A = self.encode_sequence(near_seq)
        B = self.encode_sequence(far_seq)

        def centered_rows(M):
            # Per row: None for an all-zero row, else (mean-centered row, its norm).
            prepared = []
            for row in M:
                if np.allclose(row, 0):
                    prepared.append(None)
                else:
                    dev = row - row.mean()
                    prepared.append((dev, np.linalg.norm(dev)))
            return prepared

        a_rows = centered_rows(A)
        b_rows = centered_rows(B)
        C = np.zeros((A.shape[0], B.shape[0]))
        for i, a_info in enumerate(a_rows):
            if a_info is None:
                continue
            a_dev, a_norm = a_info
            for j, b_info in enumerate(b_rows):
                if b_info is None:
                    continue
                b_dev, b_norm = b_info
                den = a_norm * b_norm
                C[i, j] = (a_dev @ b_dev) / den if den > 0 else 0.0
        return C

    def clone(self):
        """Return a deep copy of this agent (used by the notebook for checkpoints)."""
        import copy
        return copy.deepcopy(self)


In [None]:

# Experiment size and the correlation threshold used for "time to decorrelate".
N_RUNS = 6
SESSIONS = 9
TRIALS_PER_SESSION = 80
THRESH = 0.3


def block_mean(C, pairs):
    """Average ``C[i, j]`` over the given (i, j) pairs; NaN when ``pairs`` is empty."""
    if not pairs:
        return np.nan
    entries = [C[i, j] for (i, j) in pairs]
    return float(np.mean(entries))
rng = np.random.default_rng(123)

# Sessions at which correlation matrices are collected (all runs) and at
# which the run-0 agent is checkpointed for the transition-graph figures.
checkpoint_sessions = [1, 3, 4, 9]
mat_by_session = {s: [] for s in [1, 3, 4, 9]}
demo_checkpoints = {}
all_final_blocks = []
all_time_to_thr = []


def norm(x):
    """Express a session number as a fraction of SESSIONS; None (never reached) -> NaN."""
    return x/SESSIONS if x is not None else np.nan


for run in range(N_RUNS):
    agent = CoDAUncAgent(obs_symbols=sorted(set(near)|set(far)), cfg=UncCfg())
    # First session at which each block mean dropped below THRESH (None = not yet).
    tt = {'offdiag': None, 'preR2': None, 'preR1': None}
    for session in range(1, SESSIONS+1):
        # Equal numbers of near and far trials, shuffled within the session.
        half = TRIALS_PER_SESSION // 2
        episodes = [near]*half + [far]*half
        rng.shuffle(episodes)
        for ep in episodes:
            agent.run_episode(ep, learn=True)

        C = agent.near_far_corr(near, far)
        if session in mat_by_session:
            mat_by_session[session].append(C)
        if run == 0 and session in checkpoint_sessions:
            demo_checkpoints[session] = agent.clone()

        block_now = {
            'offdiag': block_mean(C, offdiag_pairs),
            'preR2': block_mean(C, same_preR2_pairs),
            'preR1': block_mean(C, same_preR1_pairs),
        }
        for key, value in block_now.items():
            if tt[key] is None and value < THRESH:
                tt[key] = session

    C_final = agent.near_far_corr(near, far)
    all_final_blocks.append((
        run,
        block_mean(C_final, offdiag_pairs),
        block_mean(C_final, same_preR2_pairs),
        block_mean(C_final, same_preR1_pairs),
    ))
    all_time_to_thr.append((run, norm(tt['offdiag']), norm(tt['preR2']), norm(tt['preR1'])))

blocks_df = pd.DataFrame(all_final_blocks, columns=['run','offdiag','preR2','preR1']).set_index('run')
times_df  = pd.DataFrame(all_time_to_thr, columns=['run','offdiag_t','preR2_t','preR1_t']).set_index('run')


In [None]:


def canonical_latents(agent, seq):
    """List ``(t, obs, sid, path)`` for each step of ``seq`` before its first 0.

    The latent ids come from ``agent.run_episode(seq, learn=False)`` and the
    path label from ``agent.states[sid]['path']``.
    """
    latent_ids = agent.run_episode(seq, learn=False)
    paired = list(zip(seq, latent_ids))
    # Index of the first padding observation, or the full length if there is none.
    stop = next((i for i, (obs, _) in enumerate(paired) if obs == 0), len(paired))
    return [
        (t, obs, sid, agent.states[sid]['path'])
        for t, (obs, sid) in enumerate(paired[:stop])
    ]


def ring_positions(T: int, r_base=1.0, r_near=1.12, r_far=0.88, start_angle=np.pi/2):
    """Lay out ``T`` time steps on three concentric circles.

    Step ``t`` sits at angle ``start_angle - 2*pi*t/T``.  Each entry of the
    returned dict maps ``t`` to ``{'theta', 'base', 'R1', 'R2'}``, where
    'base', 'R1' and 'R2' are (x, y) points at radius ``r_base``, ``r_near``
    and ``r_far`` respectively.
    """
    radii = (('base', r_base), ('R1', r_near), ('R2', r_far))
    span = max(T, 1)
    layout = {}
    for t in range(T):
        theta = start_angle - 2 * np.pi * (t / span)
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        entry = {'theta': theta}
        for ring_name, radius in radii:
            entry[ring_name] = (radius * cos_t, radius * sin_t)
        layout[t] = entry
    return layout


def build_color_assignments(scheme: str, near_lat, far_lat):
    """Choose a face color for every node of the near and far trajectories.

    Args:
        scheme: ``'state_id'`` colors by latent id (tab20, id modulo 20);
            ``'obs_step'`` colors by time step (hsv, step modulo 26); any
            other value (e.g. ``'blocks'``) gives each distinct latent id its
            own tab20 entry, with fixed overrides for a few ids.
        near_lat: ``(t, obs, sid, path)`` tuples for the near trajectory.
        far_lat: Same for the far trajectory.

    Returns:
        ``{'near': {...}, 'far': {...}}`` where each inner dict maps
        ``(t, ring)`` to a color and ``ring`` is ``'base'`` when ``path`` is
        ``None``, else the path label itself.
    """
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # offers the same (name, lut) lookup and exists in old and new releases.
    import matplotlib.pyplot as plt

    if scheme == 'state_id':
        tab20 = plt.get_cmap('tab20')

        def color_for(t, obs, sid):
            return tab20(sid % 20)

    elif scheme == 'obs_step':
        hsv = plt.get_cmap('hsv')

        def color_for(t, obs, sid):
            return hsv((t % 26) / 26.0)

    else:
        # Default: color-code by the "third element" (latent state id) with overrides
        unique_sids = sorted({sid for (_, _, sid, _) in list(near_lat) + list(far_lat)})
        cmap = plt.get_cmap('tab20', max(1, len(unique_sids)))
        sid_to_color = {sid: cmap(i % cmap.N) for i, sid in enumerate(unique_sids)}
        overrides = {1: '#b3b3b3', 5: '#9467bd', 9: '#6a6a6a', 11: '#000000', 6: '#d62728', 8: '#d62728'}
        sid_to_color.update({sid: color for sid, color in overrides.items() if sid in unique_sids})

        def color_for(t, obs, sid):
            return sid_to_color[sid]

    def assign(lat_seq):
        # Later entries win if two tuples share the same (t, ring) key.
        return {
            (t, 'base' if path is None else path): color_for(t, obs, sid)
            for (t, obs, sid, path) in lat_seq
        }

    return {'near': assign(near_lat), 'far': assign(far_lat)}


def draw_cscg_style_ring(
    ax,
    agent,
    title=None,
    scheme='blocks',
    close_loop=True,
    edge_color='#5a5a5a',
    edge_lw=2.2,
):
    """Draw the near and far latent trajectories as a CSCG-style ring graph.

    Args:
        ax: Matplotlib axes to draw into.
        agent: Object accepted by ``canonical_latents`` (needs ``run_episode``
            and ``states``).
        title: Optional axes title.
        scheme: Coloring mode passed to ``build_color_assignments``.
        close_loop: Also draw an arrow from each trajectory's last node back
            to its first node.
        edge_color: Color shared by all arrows.
        edge_lw: Arrow line width.

    Layout: time steps are placed around a circle; un-cloned states sit on the
    'base' ring and cloned states on the 'R1' / 'R2' rings.  Path labels that
    are not themselves ring names (e.g. 'after_<cue>' or 'clone') are assigned
    alternately to 'R1' and 'R2' in sorted order, so every node gets a
    position regardless of how the agent labels its clones.
    """
    from collections import defaultdict

    # --- build sequences and ring geometry ---
    near_lat = canonical_latents(agent, near)
    far_lat = canonical_latents(agent, far)
    trajectories = (('near', near_lat), ('far', far_lat))
    pos = ring_positions(max(len(near_lat), len(far_lat)))
    node_colors = build_color_assignments(scheme, near_lat, far_lat)

    # --- layout constants ---
    rings_order = ['base', 'R1', 'R2']
    offset_mag = 0.07        # tangential spacing for the generic overlap fallback
    radial_offset = 0.09     # radial spacing when near and far differ at one slot
    force_swap_times = {6, 19}  # steps where the inner/outer assignment is forced to swap
    arrow_mutation = 14.0

    # --- map every path label onto a ring that has coordinates ---
    unknown_paths = sorted(
        {
            path
            for _, lat_seq in trajectories
            for (_, _, _, path) in lat_seq
            if path is not None and path not in rings_order
        },
        key=str,
    )
    ring_for_path = {path: ('R1', 'R2')[i % 2] for i, path in enumerate(unknown_paths)}

    def ring_of(path):
        if path is None:
            return 'base'
        return ring_for_path.get(path, path)

    # --- occupancy: who sits at each (time step, ring) slot ---
    near_by_t = {t: sid for (t, _, sid, _) in near_lat}
    far_by_t = {t: sid for (t, _, sid, _) in far_lat}
    # (t, sid) pairs where both trajectories are in the same latent state.
    sid_shared = {(t, sid) for t, sid in near_by_t.items() if far_by_t.get(t) == sid}

    occupancy = defaultdict(list)
    for label, lat_seq in trajectories:
        for (t, _obs, sid, path) in lat_seq:
            occupancy[(t, ring_of(path))].append((label, sid))

    # --- node placement ---
    node_positions = {}                       # (label, t) -> (x, y)
    prev_pos = {'near': None, 'far': None}    # last placed point per trajectory

    def place(label, t, coords):
        node_positions[(label, t)] = coords
        prev_pos[label] = coords

    def travel_cost(candidate):
        # Squared distance each trajectory would move from its previous node.
        cost = 0.0
        for label, (x, y) in candidate.items():
            prev = prev_pos.get(label)
            if prev is None:
                continue
            px, py = prev
            cost += (x - px) * (x - px) + (y - py) * (y - py)
        return cost

    for t in sorted({t for (t, _) in occupancy}):
        for ring in rings_order:
            entries = occupancy.get((t, ring))
            if not entries:
                continue
            base_x, base_y = pos[t][ring]
            sids = {sid for (_, sid) in entries}
            shared_here = any((t, sid) in sid_shared for sid in sids)

            # A lone node, or both trajectories in the same state: use the slot itself.
            if len(entries) == 1 or (len(sids) == 1 and shared_here):
                for label, _sid in entries:
                    place(label, t, (base_x, base_y))
                continue

            radius = (base_x * base_x + base_y * base_y) ** 0.5
            if radius == 0.0:
                ux, uy = 1.0, 0.0
            else:
                ux, uy = base_x / radius, base_y / radius

            if len(entries) == 2 and {label for (label, _) in entries} == {'near', 'far'}:
                # Different states in the same slot: push one outward and one
                # inward, choosing the assignment that moves the nodes least.
                def radial_layout(sign_map):
                    out = {}
                    for label, _sid in entries:
                        direction = radial_offset * sign_map[label]
                        out[label] = (base_x + direction * ux, base_y + direction * uy)
                    return out

                opt_default = radial_layout({'near': 1.0, 'far': -1.0})
                opt_swapped = radial_layout({'near': -1.0, 'far': 1.0})
                if t in force_swap_times or travel_cost(opt_swapped) < travel_cost(opt_default):
                    chosen = opt_swapped
                else:
                    chosen = opt_default
                for label, coords in chosen.items():
                    place(label, t, coords)
                continue

            # Generic fallback: spread the nodes along the tangent, centered on the slot.
            if radius == 0.0:
                tx, ty = 0.0, 1.0
            else:
                tx, ty = -base_y / radius, base_x / radius
            ordered = sorted(entries, key=lambda item: item[0])
            span = len(ordered) - 1
            for idx, (label, _sid) in enumerate(ordered):
                factor = idx - span / 2.0
                place(label, t, (base_x + factor * offset_mag * tx, base_y + factor * offset_mag * ty))

    # --- drawing helpers ---
    def add_arrow(start, end, z):
        ax.add_patch(
            FancyArrowPatch(
                start,
                end,
                arrowstyle='-|>',
                mutation_scale=arrow_mutation,
                lw=edge_lw,
                color=edge_color,
                alpha=0.9,
                shrinkA=2.5,
                shrinkB=2.5,
                zorder=z,
            )
        )

    def draw_edges(lat_seq, label, z=1):
        times = [t for (t, _, _, _) in lat_seq]
        for t_from, t_to in zip(times, times[1:]):
            add_arrow(node_positions[(label, t_from)], node_positions[(label, t_to)], z)
        if close_loop and len(times) >= 2:
            add_arrow(node_positions[(label, times[-1])], node_positions[(label, times[0])], z)

    def draw_nodes(lat_seq, label, z=3):
        for (t, _obs, _sid, path) in lat_seq:
            # Colors are keyed by the original path label, not by the drawn ring.
            color_key = (t, 'base' if path is None else path)
            ax.add_patch(
                Circle(
                    node_positions[(label, t)],
                    radius=0.065,
                    facecolor=node_colors[label].get(color_key, '#bfbfbf'),
                    edgecolor='#3a3a3a',
                    lw=1.1,
                    alpha=0.98,
                    zorder=z,
                )
            )

    # Edges first, then far nodes, then near nodes on top.
    draw_edges(near_lat, 'near', z=1)
    draw_edges(far_lat, 'far', z=1)
    draw_nodes(far_lat, 'far', z=3)
    draw_nodes(near_lat, 'near', z=4)

    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_xlim(-1.55, 1.55)
    ax.set_ylim(-1.55, 1.55)
    if title:
        ax.set_title(title, fontsize=20)


In [None]:
# Inspect the (t, obs, sid, path) rows of the near sequence under the
# checkpoint stored for session 4 (checkpoints are taken from run 0 only).
near_lat = canonical_latents(demo_checkpoints[4], near)
near_lat

In [None]:
# Same inspection for the far sequence, using the session-4 checkpoint.
far_lat = canonical_latents(demo_checkpoints[4], far)
far_lat

In [None]:
import os
os.makedirs("figures/poster", exist_ok=True)

# (file extension, savefig keyword arguments) for each format every figure is written in.
export_specs = [
    ("svg", {"format": "svg"}),
    ("pdf", {"format": "pdf"}),
    ("png", {"format": "png", "dpi": 400})
]

# Latest checkpointed session that is actually available.
final_session = max(s for s in (1, 3, 4, 9) if s in demo_checkpoints)

fig, ax = plt.subplots(figsize=(7.4, 4.6), constrained_layout=True)
draw_cscg_style_ring(
    ax,
    demo_checkpoints[final_session],
    title="CoDA — transition graph (CSCG-style colors)",
    scheme='blocks'
)
for ext, kwargs in export_specs:
    target = os.path.join("figures", "poster", f"coda_transition_final_session.{ext}")
    fig.savefig(target, bbox_inches='tight', **kwargs)
plt.show()

# One figure per checkpointed session, in ascending session order.
for s in sorted(demo_checkpoints):
    fig_s, ax_s = plt.subplots(figsize=(7.4, 4.6), constrained_layout=True)
    draw_cscg_style_ring(
        ax_s,
        demo_checkpoints[s],
        title=f"Session {s}",
        scheme='blocks'
    )
    for ext, kwargs in export_specs:
        target = os.path.join("figures", "poster", f"coda_transition_session_{s}.{ext}")
        fig_s.savefig(target, bbox_inches='tight', **kwargs)
    plt.show()


In [None]:
# === Trajectory of decorrelation during learning (CSCG-style) ===
import matplotlib as mpl

check_sessions = [1, 3, 4, 9]
# Average the near/far correlation matrices over runs, per checkpoint session.
mean_mats = {s: np.mean(mat_by_session[s], axis=0) for s in check_sessions}

# CSCG-style colormap (dark background, warm high values)
# 'magma' or 'inferno' gives similar look; both start dark→warm→bright
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap works
# on old and new releases alike.
cmap = plt.get_cmap('magma')

# Pre/indicator windows to highlight, as 0-based half-open (low, high) index
# ranges: the box drawn below covers cells low .. high-1.
highlight_windows = [(6, 10), (13, 15), (18, 20)]
dash_lines = [6, 10, 13, 15, 18, 20]

# Make independent figures for each checkpoint session
for s in check_sessions:
    fig, ax = plt.subplots(figsize=(4.8, 4.8), constrained_layout=True)

    last_im = ax.imshow(
        mean_mats[s],
        vmin=-0.1,
        vmax=1.0,
        cmap=cmap,
        origin='upper',
        aspect='equal',
        interpolation='nearest'
    )

    # Faint dotted cell grid on the minor ticks; major ticks and spines are hidden.
    ny, nx = mean_mats[s].shape
    ax.set_xticks(np.arange(-0.5, nx, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, ny, 1), minor=True)
    ax.grid(which='minor', color='white', linestyle=':', linewidth=0.4)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.tick_params(colors='white', labelsize=0, length=0)
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Dashed guides on the cell boundaries just before each listed index.
    for idx in dash_lines:
        if idx <= nx:
            ax.axvline(idx - 0.5, color='white', linewidth=1.2, linestyle=(0, (2, 4)))
        if idx <= ny:
            ax.axhline(idx - 0.5, color='white', linewidth=1.2, linestyle=(0, (2, 4)))

    # Solid boxes around the diagonal blocks given by highlight_windows.
    for low, high in highlight_windows:
        if high <= min(nx, ny):
            x0, x1 = low - 0.5, high - 0.5
            y0, y1 = low - 0.5, high - 0.5
            ax.plot(
                [x0, x1, x1, x0, x0],
                [y0, y0, y1, y1, y0],
                color='white',
                linewidth=2.5
            )

    fig.patch.set_facecolor('black')
    ax.set_facecolor('black')

    # Colorbar without label or ticks (color scale shown for reference only).
    cbar = fig.colorbar(last_im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("")
    cbar.set_ticks([])
    cbar.ax.tick_params(length=0)

    plt.show()


In [None]:

# Fig. 4i -- mean final correlation per block (error bars: SEM across runs)
labels = ['offdiag','preR2','preR1']
means = blocks_df.mean()
ses = blocks_df.sem()
x = np.arange(len(labels))
y = [means[l] for l in labels]
yerr = [ses[l] for l in labels]

fig, ax = plt.subplots(figsize=(6.5,4.6))
bars = ax.bar(x, y, yerr=yerr, capsize=4, color=['#777','#2f6db3','#d64b5a'])
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_ylim(0, 1.05)
ax.set_ylabel("Mean correlation (final)")
ax.set_title("CoDA (uncertainty‑aware) — Fig. 4i analogue")
for b, val in zip(bars, y):
    # Tall bars get a white label inside the bar; the rest a black label above it.
    if val > 0.85:
        text_y, anchor, text_color = val - 0.05, 'top', 'white'
    else:
        text_y, anchor, text_color = min(1.02, val + 0.05), 'bottom', 'black'
    ax.text(b.get_x()+b.get_width()/2, text_y, f"{val:.2f}", ha='center',
            va=anchor, color=text_color, fontsize=9)
fig.tight_layout()
plt.show()

# Fig. 4j -- fraction of training until each block mean first drops below 0.3
labels_t = ['offdiag_t','preR2_t','preR1_t']
disp = ['offdiag','preR2','preR1']
means_t = times_df.mean(skipna=True)
ses_t = times_df.sem(skipna=True)
x = np.arange(len(labels_t))
y = [means_t[l] for l in labels_t]
yerr = [ses_t[l] for l in labels_t]

fig, ax = plt.subplots(figsize=(6.5,4.6))
bars = ax.bar(x, y, yerr=yerr, capsize=4, color=['#777','#2f6db3','#d64b5a'])
ax.set_xticks(x)
ax.set_xticklabels(disp)
ax.set_ylim(0, 1.05)
ax.set_ylabel("Fraction of training (first corr < 0.3)")
ax.set_title("CoDA (uncertainty‑aware) — Fig. 4j analogue")
for b, val in zip(bars, y):
    ax.text(b.get_x()+b.get_width()/2, min(1.02, val+0.05), f"{val:.2f}", ha='center', va='bottom', fontsize=9)
fig.tight_layout()
plt.show()
