# CoDA (uncertainty‑aware) — OSM Fig. 4c/4i/4j with **CSCG‑style color‑coded states**

**New:** CSCG‑style node colors (uniform gray arrows), three coloring modes (`blocks`, `state_id`, `obs_step`).

In [None]:

import os
import math, random
from dataclasses import dataclass
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch, Circle


In [None]:

def posterior_prob_p_greater_than(theta: float, success: float, failure: float, alpha0: float=0.5, beta0: float=0.5) -> float:
    """Return the posterior probability P(p > theta) for a Bernoulli rate p.

    The posterior is Beta(alpha0 + success, beta0 + failure); negative counts
    are clamped to zero before being added to the prior.

    Args:
        theta: Threshold on the success probability.
        success: Number (or effective mass) of successes.
        failure: Number (or effective mass) of failures.
        alpha0: Prior pseudo-count for successes (default 0.5).
        beta0: Prior pseudo-count for failures (default 0.5).

    Returns:
        1 - I_theta(a, b), where I is the regularized incomplete beta function.
        Returns 0.0 (the conservative answer) when SciPy is not installed.
    """
    a = alpha0 + max(0.0, float(success))
    b = beta0 + max(0.0, float(failure))

    # The support of p is [0, 1]; handle thresholds at/outside the boundary
    # explicitly so the special function is never asked for x outside [0, 1].
    if theta <= 0.0:
        return 1.0
    if theta >= 1.0:
        return 0.0

    # Imported locally so the optional dependency is resolved at call time
    # rather than through a module-level flag that may not be defined.
    try:
        from scipy.special import betainc
    except ImportError:
        return 0.0

    # scipy.special.betainc(a, b, x) is already the *regularized* function.
    cdf = float(betainc(a, b, theta))
    return float(1.0 - cdf)

def wilson_lower_bound(phat: float, n: float, confidence: float=0.95) -> float:
    """One-sided Wilson score lower bound for a Bernoulli proportion.

    Args:
        phat: Observed proportion of successes.
        n: Number of (effective) observations; values <= 0 yield 0.0.
        confidence: One-sided confidence level, strictly between 0 and 1.

    Returns:
        The lower end of the Wilson score interval at the requested level.

    Raises:
        statistics.StatisticsError: If ``confidence`` is not in (0, 1).
    """
    if n <= 0:
        return 0.0

    # Standard-library normal quantile: z such that P(Z <= z) = confidence.
    # Imported locally so the function has no module-level dependencies.
    from statistics import NormalDist
    z = NormalDist().inv_cdf(confidence)

    denom = 1.0 + (z*z)/n
    center = phat + (z*z)/(2.0*n)
    adj = z * ((phat*(1.0-phat) + (z*z)/(4.0*n))/n)**0.5
    return (center - adj)/denom


In [None]:

# Two 26-step observation streams. They are identical except at steps 6-9
# (symbol 2 vs 3), step 14 (6 vs 4) and step 19 (5 vs 6).
near = [1]*6 + [2]*4 + [1]*3 + [4, 6] + [1]*3 + [5, 5] + [1]*2 + [7] + [0]*3
far  = [1]*6 + [3]*4 + [1]*3 + [4, 4] + [1]*3 + [5, 6] + [1]*2 + [7] + [0]*3

# Step indices of the two three-step windows compared in the block analyses.
preR1_idx = [10, 11, 12]
preR2_idx = [15, 16, 17]


def block_indices(rows, cols):
    """Return every (row, col) pair, rows varying slowest."""
    pairs = []
    for r in rows:
        for c in cols:
            pairs.append((r, c))
    return pairs


# Cross-window pairs (both orientations) and the two same-window blocks.
offdiag_pairs     = block_indices(preR1_idx, preR2_idx) + block_indices(preR2_idx, preR1_idx)
same_preR1_pairs  = block_indices(preR1_idx, preR1_idx)
same_preR2_pairs  = block_indices(preR2_idx, preR2_idx)


In [None]:
# --- CoDA agent (trial-by-trial) ---
# This block replaces the placeholder "agent" with a runnable CoDA implementation
# consistent with the paper's idea: outcome-conditioned eligibility traces drive
# state-splitting (cloning) via prospective contingency, with optional merge via
# prospective×retrospective utility.

from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
import math
import numpy as np

@dataclass
class UncCfg:
    """Hyper-parameters for the uncertainty-aware CoDA agent."""

    # --- Eligibility traces (Algorithm 1) ---
    # Per-step trace decay is gamma * lam.
    gamma: float = 0.9
    lam: float = 0.8

    # --- Split criterion (Algorithm 2) ---
    # A state is split once its best US prediction is confidently above theta_split.
    theta_split: float = 0.8
    # One-sided confidence level for the Wilson lower bound.
    confidence: float = 0.95
    # Minimum effective evidence before the split test is attempted.
    n_threshold: int = 5
    # Minimum number of episodes in which the state must have appeared.
    min_presence_episodes: int = 3
    # Minimum accumulated eligibility mass for the state.
    min_effective_exposure: float = 10.0

    # --- Optional merge (utility = PC * RC; merge if below theta_merge) ---
    # Values > 0 enable merging; 0.0 keeps it disabled.
    theta_merge: float = 0.0

    # --- Context resets ---
    # Observation symbols that clear the context (e.g. ITI / end markers).
    reset_symbols: Tuple[int, ...] = (0,)

def _z_one_sided(confidence: float) -> float:
    """Small lookup for common one-sided z; defaults to 0.95 -> 1.64485."""
    # one-sided z such that P(Z <= z) = confidence
    table = {0.90: 1.2815515655446004,
             0.95: 1.6448536269514722,
             0.975: 1.959963984540054,
             0.99: 2.3263478740408408}
    # nearest key
    key = min(table.keys(), key=lambda k: abs(k - confidence))
    return table[key]

def wilson_lower_bound_local(phat: float, n: float, confidence: float=0.95) -> float:
    """Lower end of the one-sided Wilson score interval for a proportion.

    ``phat`` is the observed proportion, ``n`` the (possibly fractional)
    sample size and ``confidence`` the one-sided level passed to
    ``_z_one_sided``. With no evidence (``n <= 0``) the bound is 0.0.
    """
    if n <= 0:
        return 0.0

    z = _z_one_sided(confidence)
    z_sq = z*z

    variance_term = (phat*(1.0-phat) + z_sq/(4.0*n))/n
    shifted_center = phat + z_sq/(2.0*n)
    half_width = z * math.sqrt(variance_term)
    scale = 1.0 + z_sq/n
    return (shifted_center - half_width)/scale

class CoDAUncAgent:
    """
    CoDA-style state-splitting for a 1D symbol stream.

    Latent states are (obs, path) pairs:
      - original ("canonical") states have path=None
      - clones have path like "R1"/"R2" (context label) and are created for
        the observations that directly follow a salient cue

    Context update (inference):
      - if the current latent state is a salient cue, context = its path label
      - if the current obs is in reset_symbols, context = None (checked after
        the salient update, so a reset symbol always clears the context)

    Learning (per episode):
      - For each US class u (4 and 5), the eligibility-trace snapshot at every
        US occurrence is added to co_occ[s][u] (and to exposure[s]).
      - Prospective contingency is P(u | s) = co_occ[s][u] / sum_u co_occ[s][u].
      - If the Wilson lower bound on max_u P(u|s) exceeds theta_split (and the
        evidence gates pass), s is marked salient and its successor
        observations are cloned into the corresponding context path.
      - Optional merge can be enabled via theta_merge>0 using utility = PC*RC.
    """

    def __init__(self, obs_symbols, cfg: Optional["UncCfg"] = None):
        """Create one canonical latent state per entry of ``obs_symbols``.

        Args:
            obs_symbols: Iterable of observation symbols; state ids are
                assigned in iteration order starting at 0.
            cfg: Hyper-parameters. When omitted, a fresh ``UncCfg()`` is built
                for this agent.
        """
        # A `UncCfg()` default argument would be evaluated once at definition
        # time and shared by every agent; build the default per instance.
        self.cfg = UncCfg() if cfg is None else cfg
        self.reset_symbols: Set[int] = set(self.cfg.reset_symbols)

        # Define US classes used in this notebook (matches symbol design)
        self.us_classes = [4, 5]

        # Materialise once: the symbols are iterated twice below, which would
        # silently produce no states for a one-shot iterable.
        obs_symbols = list(obs_symbols)

        # Latent state storage
        self.states: Dict[int, Dict] = {}
        self.obs_to_state_ids: Dict[int, List[int]] = {o: [] for o in obs_symbols}
        for sid, o in enumerate(obs_symbols):
            self.states[sid] = {'obs': o, 'path': None, 'parent': None}
            self.obs_to_state_ids[o].append(sid)
        self._next_sid = len(obs_symbols)

        # Outcome-conditioned eligibility accumulators: co_occ[s][u] and exposure[s]
        self.co_occ: Dict[int, Dict[int, float]] = {s: {u: 0.0 for u in self.us_classes} for s in self.states}
        self.exposure: Dict[int, float] = {s: 0.0 for s in self.states}

        # Episode presence counters (for evidence gates + retrospective)
        self.presence_episodes: Dict[int, int] = {s: 0 for s in self.states}
        self.us_episode_counts: Dict[int, int] = {u: 0 for u in self.us_classes}
        self.cs_us_presence: Dict[int, Dict[int, int]] = {s: {u: 0 for u in self.us_classes} for s in self.states}

        # Salient cue map: sid -> path label ("R1"/"R2")
        self.salient: Dict[int, str] = {}

        # Optional bookkeeping: cue -> clones created (for cleanup/merging if desired)
        self.cue_to_clones: Dict[int, Set[int]] = {}

    def _ensure(self, sid: int) -> None:
        """Create zeroed accumulators for ``sid`` if it has none yet."""
        if sid not in self.co_occ:
            self.co_occ[sid] = {u: 0.0 for u in self.us_classes}
            self.exposure[sid] = 0.0
            self.presence_episodes[sid] = 0
            self.cs_us_presence[sid] = {u: 0 for u in self.us_classes}

    def _clone_state(self, orig_state_id: int, path: str) -> int:
        """Register a clone of ``orig_state_id`` under context ``path``; return its id."""
        orig = self.states[orig_state_id]
        cid = self._next_sid
        self.states[cid] = {'obs': orig['obs'], 'path': path, 'parent': orig_state_id}
        self.obs_to_state_ids[orig['obs']].append(cid)
        self._ensure(cid)
        self._next_sid += 1
        return cid

    def _select_state_for_obs(self, obs: int, context: Optional[str]) -> int:
        """Pick the latent state for ``obs``.

        Preference order: a state whose path equals ``context`` (when a
        context is active), then the canonical state (path=None), then the
        first registered state for that observation.
        """
        cands = self.obs_to_state_ids[obs]
        if context is not None:
            for sid in cands:
                if self.states[sid]['path'] == context:
                    return sid
        # prefer canonical (path=None)
        for sid in cands:
            if self.states[sid]['path'] is None:
                return sid
        return cands[0]

    def prospective(self, sid: int) -> Dict[int, float]:
        """Return P(u|sid) over us_classes (all zeros when there is no mass)."""
        tot = sum(self.co_occ.get(sid, {}).get(u, 0.0) for u in self.us_classes)
        if tot <= 0:
            return {u: 0.0 for u in self.us_classes}
        return {u: self.co_occ[sid][u] / tot for u in self.us_classes}

    def retrospective(self, sid: int, u: int) -> float:
        """Return P(sid present | US=u present) using episode counts."""
        denom = self.us_episode_counts.get(u, 0)
        if denom <= 0:
            return 0.0
        return float(self.cs_us_presence.get(sid, {}).get(u, 0)) / float(denom)

    def _maybe_merge(self) -> None:
        """Optional: merge by dropping salient status when utility falls below threshold."""
        if self.cfg.theta_merge <= 0:
            return
        # for each cue, compute utility = PC * RC for its associated US
        for cue in list(self.salient.keys()):
            path = self.salient[cue]
            u_star = 4 if path == 'R1' else 5
            P = self.prospective(cue)
            pc = P.get(u_star, 0.0)
            rc = self.retrospective(cue, u_star)
            if pc * rc < self.cfg.theta_merge:
                # Remove salient tag; clones will simply stop being selected because context no longer set.
                self.salient.pop(cue, None)

    def _infer_latents(self, obs_seq: List[int]) -> List[int]:
        """Assign one latent state id to each observation, tracking context."""
        context: Optional[str] = None
        latent_seq: List[int] = []
        for obs in obs_seq:
            sid = self._select_state_for_obs(obs, context)
            latent_seq.append(sid)

            # context update (reset is checked last, so it overrides a cue)
            if sid in self.salient:
                context = self.salient[sid]
            if obs in self.reset_symbols:
                context = None
        return latent_seq

    def _update_presence(self, obs_seq: List[int], visited: Set[int]) -> None:
        """Update per-episode presence counts for states and US classes."""
        for sid in visited:
            self.presence_episodes[sid] = self.presence_episodes.get(sid, 0) + 1

        # US present flags for retrospective counts
        for u in self.us_classes:
            if not any(o == u for o in obs_seq):
                continue
            self.us_episode_counts[u] = self.us_episode_counts.get(u, 0) + 1
            for sid in visited:
                self._ensure(sid)
                self.cs_us_presence[sid][u] = self.cs_us_presence[sid].get(u, 0) + 1

    def _accumulate_eligibility(self, obs_seq: List[int], latent_seq: List[int]) -> None:
        """Add eligibility-trace mass to co_occ/exposure at every US occurrence."""
        decay = self.cfg.gamma * self.cfg.lam
        trace: Dict[int, float] = {}
        snapshots: List[Dict[int, float]] = []
        for sid in latent_seq:
            # decay every active trace, dropping negligible entries
            for k in list(trace.keys()):
                trace[k] *= decay
                if trace[k] < 1e-12:
                    trace.pop(k, None)
            trace[sid] = trace.get(sid, 0.0) + 1.0
            snapshots.append(dict(trace))

        # The trace snapshot taken at the US step (which includes the US state
        # itself) is credited to that US class.
        for u in self.us_classes:
            positions = [i for i, o in enumerate(obs_seq) if o == u]
            for t_us in positions:
                for s_id, val in snapshots[t_us].items():
                    self._ensure(s_id)
                    self.co_occ[s_id][u] += float(val)
                    self.exposure[s_id] += float(val)

    def _discover_cues(self) -> List[Tuple[int, str]]:
        """Mark states that confidently predict one US as salient.

        Returns the (state id, path label) pairs that became salient now.
        """
        newly_salient: List[Tuple[int, str]] = []
        for s in list(self.states.keys()):
            if s in self.salient:
                continue

            # evidence gates
            tot = sum(self.co_occ[s][u] for u in self.us_classes)
            if tot < self.cfg.n_threshold:
                continue
            if self.presence_episodes.get(s, 0) < self.cfg.min_presence_episodes:
                continue
            if self.exposure.get(s, 0.0) < self.cfg.min_effective_exposure:
                continue

            P = self.prospective(s)
            u_star = max(self.us_classes, key=lambda u: P[u])
            phat = P[u_star]
            # Wilson lower bound gate for "confidently above theta_split"
            lb = wilson_lower_bound_local(phat, tot, self.cfg.confidence)
            if lb > self.cfg.theta_split:
                path = 'R1' if u_star == 4 else 'R2'
                self.salient[s] = path
                newly_salient.append((s, path))
        return newly_salient

    def _clone_successors(self, obs_seq: List[int], newly_salient: List[Tuple[int, str]]) -> None:
        """Clone the observation that follows each newly salient cue.

        "Successor" is the next observation symbol along the experienced
        trajectory. The episode is replayed with the current salient map and
        clone set to locate occurrences of each cue.
        """
        for cue, path in newly_salient:
            self.cue_to_clones.setdefault(cue, set())
            context = None
            for t, obs in enumerate(obs_seq[:-1]):
                sid = self._select_state_for_obs(obs, context)
                if sid == cue:
                    nxt_obs = obs_seq[t + 1]
                    cands = self.obs_to_state_ids[nxt_obs]
                    if not any(self.states[c]['path'] == path for c in cands):
                        # clone the canonical state for nxt_obs (prefer path=None)
                        base = next((c for c in cands if self.states[c]['path'] is None), cands[0])
                        clone_id = self._clone_state(base, path)
                        self.cue_to_clones[cue].add(clone_id)

                if sid in self.salient:
                    context = self.salient[sid]
                # NOTE(review): this replay clears the context when the *next*
                # observation is a reset symbol, whereas inference clears it on
                # the *current* one — confirm the difference is intended.
                if obs_seq[t + 1] in self.reset_symbols:
                    context = None

    def run_episode(self, obs_seq: List[int], learn: bool = True) -> List[int]:
        """Infer latent states for one episode and optionally learn from it.

        Args:
            obs_seq: Observation symbols of the episode.
            learn: When False only inference is performed and no agent state
                is modified.

        Returns:
            The latent state id assigned to each observation (assigned before
            any learning from this episode takes place).
        """
        latent_seq = self._infer_latents(obs_seq)
        if not learn:
            return latent_seq

        self._update_presence(obs_seq, set(latent_seq))
        self._accumulate_eligibility(obs_seq, latent_seq)

        # cue discovery & splitting (Algorithm 2 spirit)
        newly_salient = self._discover_cues()
        self._clone_successors(obs_seq, newly_salient)

        # Optional merge (disabled by default in this notebook)
        self._maybe_merge()

        return latent_seq

    def encode_sequence(self, obs_seq: List[int]) -> np.ndarray:
        """One-hot encode the inferred latent states; shape (len(obs_seq), n_states)."""
        lat = self.run_episode(obs_seq, learn=False)
        X = np.zeros((len(lat), self._next_sid), dtype=float)
        for t, sid in enumerate(lat):
            X[t, sid] = 1.0
        return X

    def near_far_corr(self, near_seq, far_seq) -> np.ndarray:
        """Pearson correlation between every near-step and far-step encoding.

        Entry (i, j) correlates the one-hot code at step i of ``near_seq``
        with the code at step j of ``far_seq``; degenerate vectors give 0.0.
        """
        A = self.encode_sequence(near_seq)
        B = self.encode_sequence(far_seq)
        C = np.zeros((A.shape[0], B.shape[0]))
        for i in range(A.shape[0]):
            for j in range(B.shape[0]):
                a = A[i]; b = B[j]
                if np.allclose(a, 0) or np.allclose(b, 0):
                    C[i, j] = 0.0
                    continue
                a0 = a - a.mean()
                b0 = b - b.mean()
                den = (np.linalg.norm(a0) * np.linalg.norm(b0))
                C[i, j] = (a0 @ b0) / den if den > 0 else 0.0
        return C


In [None]:

import copy

# --- Simulation settings ---
N_RUNS = 6
SESSIONS = 9
TRIALS_PER_SESSION = 80
THRESH = 0.3            # a block counts as "decorrelated" once its mean falls below this
checkpoint_sessions = [1, 3, 4, 9]


def block_mean(C, pairs):
    """Mean of C over the given (row, col) pairs; NaN when `pairs` is empty."""
    return float(np.mean([C[i, j] for (i, j) in pairs])) if pairs else np.nan


def norm(x):
    """Express a 1-based session index as a fraction of training (NaN if never reached)."""
    return x / SESSIONS if x is not None else np.nan


rng = np.random.default_rng(123)
all_final_blocks = []
all_time_to_thr = []
mat_by_session = {s: [] for s in checkpoint_sessions}
demo_checkpoints = {}

for run in range(N_RUNS):
    agent = CoDAUncAgent(obs_symbols=sorted(set(near) | set(far)), cfg=UncCfg())
    # First session at which each block mean dropped below THRESH (None = not yet).
    tt = {'offdiag': None, 'preR2': None, 'preR1': None}
    for session in range(1, SESSIONS + 1):
        # Equal numbers of near and far trials, presented in random order.
        episodes = [near] * (TRIALS_PER_SESSION // 2) + [far] * (TRIALS_PER_SESSION // 2)
        rng.shuffle(episodes)
        for ep in episodes:
            agent.run_episode(ep, learn=True)

        C = agent.near_far_corr(near, far)
        if session in mat_by_session:
            mat_by_session[session].append(C)

        if run == 0 and session in checkpoint_sessions:
            # Deep-copy the whole agent so the checkpoint carries every
            # accumulator (a field-by-field copy is easy to leave incomplete,
            # which would make the checkpoint unusable for further learning).
            snap = copy.deepcopy(agent)
            demo_checkpoints[session] = snap

        b_off = block_mean(C, offdiag_pairs)
        b_r2 = block_mean(C, same_preR2_pairs)
        b_r1 = block_mean(C, same_preR1_pairs)
        if tt['offdiag'] is None and b_off < THRESH:
            tt['offdiag'] = session
        if tt['preR2'] is None and b_r2 < THRESH:
            tt['preR2'] = session
        if tt['preR1'] is None and b_r1 < THRESH:
            tt['preR1'] = session

    C_final = agent.near_far_corr(near, far)
    all_final_blocks.append((
        run,
        block_mean(C_final, offdiag_pairs),
        block_mean(C_final, same_preR2_pairs),
        block_mean(C_final, same_preR1_pairs),
    ))
    all_time_to_thr.append((run, norm(tt['offdiag']), norm(tt['preR2']), norm(tt['preR1'])))

blocks_df = pd.DataFrame(all_final_blocks, columns=['run', 'offdiag', 'preR2', 'preR1']).set_index('run')
times_df = pd.DataFrame(all_time_to_thr, columns=['run', 'offdiag_t', 'preR2_t', 'preR1_t']).set_index('run')


In [None]:


def canonical_latents(agent, seq):
    """Return (t, obs, state_id, path) tuples for the prefix of `seq` before the first 0.

    Latent states come from a non-learning pass of `agent.run_episode`; the
    path of each state is looked up in `agent.states`.
    """
    latents = agent.run_episode(seq, learn=False)
    rows = []
    for step, pair in enumerate(zip(seq, latents)):
        symbol, state_id = pair
        if symbol == 0:
            # the first 0 ends the sequence; later steps are not reported
            return rows
        rows.append((step, symbol, state_id, agent.states[state_id]['path']))
    return rows


def ring_positions(T: int, r_base=1.0, r_near=1.12, r_far=0.88, start_angle=np.pi/2):
    """Lay out T time steps clockwise on three concentric rings.

    Step t sits at angle ``start_angle - 2*pi*t/max(T, 1)``. For every step
    the result holds that angle under 'theta' and the (x, y) point on the
    'base', 'R1' and 'R2' rings (radii r_base, r_near and r_far).
    """
    ring_radii = (('base', r_base), ('R1', r_near), ('R2', r_far))
    n_slots = max(T, 1)

    layout = {}
    for t in range(T):
        theta = start_angle - 2 * np.pi * (t / n_slots)
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)
        slot = {'theta': theta}
        for ring, radius in ring_radii:
            slot[ring] = (radius * cos_t, radius * sin_t)
        layout[t] = slot
    return layout


def build_color_assignments(scheme: str, near_lat, far_lat):
    """Map each (t, ring) node of the near/far sequences to a face colour.

    Args:
        scheme: 'state_id' colours by latent state id (tab20, modulo 20);
            'obs_step' colours by time step (hsv, period 26); any other value
            (e.g. 'blocks') colours by latent state id using a tab20 map
            resampled to the number of distinct ids, with fixed overrides for
            a few ids.
        near_lat: List of (t, obs, state_id, path) tuples for the near sequence.
        far_lat: Same for the far sequence.

    Returns:
        ``{'near': {(t, ring): color}, 'far': {(t, ring): color}}`` where ring
        is 'base' for path=None and the path label otherwise.
    """
    import matplotlib as mpl

    def load_cmap(name, lut=None):
        # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; prefer the
        # colormap registry and keep the old call only as a fallback for
        # versions that predate the registry.
        registry = getattr(mpl, 'colormaps', None)
        if registry is None:
            from matplotlib import cm
            return cm.get_cmap(name, lut)
        cmap = registry[name]
        if lut is None:
            return cmap
        # `resampled` is the public name; `_resample` is its older spelling.
        resample = getattr(cmap, 'resampled', None) or cmap._resample
        return resample(lut)

    colors = {'near': {}, 'far': {}}

    def set_color(target, lat_seq, fn):
        for (t, obs, sid, path) in lat_seq:
            ring = 'base' if path is None else path
            target[(t, ring)] = fn(t, obs, sid, ring)

    if scheme == 'state_id':
        tab20 = load_cmap('tab20')
        set_color(colors['near'], near_lat, lambda t, o, s, ring: tab20(s % 20))
        set_color(colors['far'],  far_lat,  lambda t, o, s, ring: tab20(s % 20))

    elif scheme == 'obs_step':
        hsv = load_cmap('hsv')
        set_color(colors['near'], near_lat, lambda t, o, s, ring: hsv((t % 26) / 26.0))
        set_color(colors['far'],  far_lat,  lambda t, o, s, ring: hsv((t % 26) / 26.0))

    else:
        # Default: colour by latent state id (third tuple element), with
        # hand-picked overrides for specific ids.
        unique_sids = sorted({sid for (_, _, sid, _) in near_lat + far_lat})
        cmap = load_cmap('tab20', max(1, len(unique_sids)))
        sid_to_color = {sid: cmap(i % cmap.N) for i, sid in enumerate(unique_sids)}
        overrides = {1: '#b3b3b3', 5: '#9467bd', 9: '#6a6a6a', 11: '#000000', 6: '#d62728', 8: '#d62728'}
        sid_to_color.update({sid: color for sid, color in overrides.items() if sid in unique_sids})

        set_color(colors['near'], near_lat, lambda t, o, s, ring: sid_to_color[s])
        set_color(colors['far'],  far_lat,  lambda t, o, s, ring: sid_to_color[s])

    return colors


def draw_cscg_style_ring(
    ax,
    agent,
    title=None,
    scheme='blocks',
    close_loop=True,
    edge_color='#5a5a5a',
    edge_lw=2.2,
):
    """Draw the near/far latent trajectories of `agent` as a ring graph on `ax`.

    Each time step is a slot on a ring; states with path None sit on the
    'base' ring and clones on the ring named after their path. Nodes are
    coloured via `build_color_assignments`; edges are uniformly styled arrows.

    Args:
        ax: Matplotlib axes to draw on.
        agent: Agent passed to `canonical_latents` for the module-level
            `near` and `far` sequences.
        title: Optional axes title.
        scheme: Colouring scheme forwarded to `build_color_assignments`.
        close_loop: If True, also draw an arrow from each sequence's last
            node back to its first.
        edge_color: Arrow colour.
        edge_lw: Arrow line width.
    """
    from collections import defaultdict

    def ring_of(path):
        return 'base' if path is None else path

    # --- build sequences and ring geometry ---
    near_lat = canonical_latents(agent, near)
    far_lat = canonical_latents(agent, far)
    T = max(len(near_lat), len(far_lat))
    pos = ring_positions(T)

    # --- node colors and occupancy ---
    node_colors = build_color_assignments(scheme, near_lat, far_lat)
    near_by_t = {t: sid for (t, _, sid, _) in near_lat}
    far_by_t = {t: sid for (t, _, sid, _) in far_lat}
    # (t, sid) pairs where both sequences occupy the same latent state
    sid_shared = {(t, sid) for t, sid in near_by_t.items() if far_by_t.get(t) == sid}
    # (t, ring) -> [(label, sid), ...] in near-then-far order
    occupancy = defaultdict(list)
    for label, lat_seq in (('near', near_lat), ('far', far_lat)):
        for (t, _obs, sid, path) in lat_seq:
            occupancy[(t, ring_of(path))].append((label, sid))

    node_positions = {}
    offset_mag = 0.07       # tangential spacing for the generic fallback layout
    radial_offset = 0.09    # radial displacement for a near/far pair
    prev_pos = {'near': None, 'far': None}
    all_times = sorted({t for (t, _) in occupancy.keys()})
    rings_order = ['base', 'R1', 'R2']
    force_swap_times = {6, 19}  # steps where the swapped near/far placement is forced
    arrow_mutation = 14.0

    def place(label, t, ring, coords):
        node_positions[(label, t, ring)] = coords
        prev_pos[label] = coords

    def radial_pair(entries, base_x, base_y, ux, uy, sign_map):
        # Displace each label along the radial unit vector by its signed offset.
        out = {}
        for label, _sid in entries:
            direction = radial_offset * sign_map[label]
            out[label] = (base_x + direction * ux, base_y + direction * uy)
        return out

    def travel_cost(mapped):
        # Squared distance from each label's previously placed node.
        cost = 0.0
        for label, (x, y) in mapped.items():
            prev = prev_pos.get(label)
            if prev is None:
                continue
            px, py = prev
            cost += (x - px) * (x - px) + (y - py) * (y - py)
        return cost

    for t in all_times:
        for ring in rings_order:
            key = (t, ring)
            if key not in occupancy:
                continue
            entries = occupancy[key]
            base_x, base_y = pos[t][ring]
            sids = {sid for (_, sid) in entries}
            shared_here = any((t, sid) in sid_shared for (_, sid) in entries)

            # A lone node, or near and far sharing one state: use the slot itself.
            if len(entries) == 1 or (len(sids) == 1 and shared_here):
                for label, _sid in entries:
                    place(label, t, ring, (base_x, base_y))
                continue

            labels_here = {label for (label, _) in entries}
            vx, vy = base_x, base_y
            norm = (vx * vx + vy * vy) ** 0.5
            if norm == 0.0:
                ux, uy = 1.0, 0.0
            else:
                ux, uy = vx / norm, vy / norm

            # Near and far in different states on the same slot: push one
            # outward and one inward, choosing the assignment that keeps each
            # trajectory closest to its previous node (unless a swap is forced).
            if len(entries) == 2 and labels_here == {'near', 'far'}:
                opt_default = radial_pair(entries, base_x, base_y, ux, uy, {'near': 1.0, 'far': -1.0})
                opt_swapped = radial_pair(entries, base_x, base_y, ux, uy, {'near': -1.0, 'far': 1.0})
                if t in force_swap_times:
                    chosen = opt_swapped
                elif travel_cost(opt_swapped) < travel_cost(opt_default):
                    chosen = opt_swapped
                else:
                    chosen = opt_default
                for label, coords in chosen.items():
                    place(label, t, ring, coords)
                continue

            # Generic fallback: spread the entries along the tangent direction.
            if norm == 0.0:
                px, py = 0.0, 1.0
            else:
                px, py = -vy / norm, vx / norm
            ordered = sorted(entries, key=lambda item: item[0])
            span = len(ordered) - 1
            for idx, (label, _sid) in enumerate(ordered):
                factor = idx - span / 2.0
                place(
                    label, t, ring,
                    (base_x + factor * offset_mag * px, base_y + factor * offset_mag * py),
                )

    def node_xy(label, t, ring):
        return node_positions[(label, t, ring)]

    def add_arrow(start, end, z):
        ax.add_patch(
            FancyArrowPatch(
                start,
                end,
                arrowstyle='-|>',
                mutation_scale=arrow_mutation,
                lw=edge_lw,
                color=edge_color,
                alpha=0.9,
                shrinkA=2.5,
                shrinkB=2.5,
                zorder=z,
            )
        )

    def draw_edges(lat_seq, label, z=1):
        for (t, _, _, p), (t2, _, _, p2) in zip(lat_seq, lat_seq[1:]):
            add_arrow(node_xy(label, t, ring_of(p)), node_xy(label, t2, ring_of(p2)), z)
        if close_loop and len(lat_seq) >= 2:
            t0, _, _, p0 = lat_seq[0]
            tL, _, _, pL = lat_seq[-1]
            add_arrow(node_xy(label, tL, ring_of(pL)), node_xy(label, t0, ring_of(p0)), z)

    def draw_nodes(lat_seq, label, z=3):
        for (t, _obs, _sid, path) in lat_seq:
            ring = ring_of(path)
            ax.add_patch(
                Circle(
                    node_xy(label, t, ring),
                    radius=0.065,
                    facecolor=node_colors[label].get((t, ring), '#bfbfbf'),
                    edgecolor='#3a3a3a',
                    lw=1.1,
                    alpha=0.98,
                    zorder=z,
                )
            )

    # Edges go below nodes; near nodes are drawn above far nodes.
    draw_edges(near_lat, 'near', z=1)
    draw_edges(far_lat, 'far', z=1)
    draw_nodes(far_lat, 'far', z=3)
    draw_nodes(near_lat, 'near', z=4)

    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_xlim(-1.55, 1.55)
    ax.set_ylim(-1.55, 1.55)
    if title:
        ax.set_title(title, fontsize=20)


In [None]:
# Inspect the (t, obs, state_id, path) assignment of the near sequence
# using the agent checkpointed after session 4.
near_lat = canonical_latents(demo_checkpoints[4], near)
near_lat

In [None]:
# Inspect the (t, obs, state_id, path) assignment of the far sequence
# using the agent checkpointed after session 4.
far_lat = canonical_latents(demo_checkpoints[4], far)
far_lat

In [None]:
import os
os.makedirs("figures/poster", exist_ok=True)

# (file extension, extra savefig keyword arguments) for every exported format
export_specs = [
    ("svg", {"format": "svg"}),
    ("pdf", {"format": "pdf"}),
    ("png", {"format": "png", "dpi": 400})
]


def _render_ring_figure(checkpoint, title, stem):
    """Draw one ring figure for `checkpoint`, save it in every format, then show it."""
    figure, axes = plt.subplots(figsize=(7.4, 4.6), constrained_layout=True)
    draw_cscg_style_ring(axes, checkpoint, title=title, scheme='blocks')
    for ext, kwargs in export_specs:
        target = os.path.join("figures", "poster", f"{stem}.{ext}")
        figure.savefig(target, bbox_inches='tight', **kwargs)
    plt.show()
    return figure, axes


# Latest available checkpoint among the candidate sessions.
final_session = max([s for s in [1, 3, 4, 9] if s in demo_checkpoints])
fig, ax = _render_ring_figure(
    demo_checkpoints[final_session],
    "CoDA — transition graph (CSCG-style colors)",
    "coda_transition_final_session",
)

# One figure per stored checkpoint, in session order.
for s in sorted(demo_checkpoints.keys()):
    fig_s, ax_s = _render_ring_figure(
        demo_checkpoints[s],
        f"Session {s}",
        f"coda_transition_session_{s}",
    )


In [None]:
# === Trajectory of decorrelation during learning (CSCG-style) ===
import matplotlib as mpl

check_sessions = [1, 3, 4, 9]
# Average the near-vs-far correlation matrices over runs, per checkpoint session.
mean_mats = {s: np.mean(mat_by_session[s], axis=0) for s in check_sessions}

# CSCG-style colormap (dark background, warm high values)
# 'magma' or 'inferno' gives similar look; both start dark→warm→bright
# `mpl.cm.get_cmap` was removed in Matplotlib 3.9; use the colormap registry
# and fall back to the old call only on versions without it.
try:
    cmap = mpl.colormaps['magma']
except AttributeError:
    cmap = mpl.cm.get_cmap('magma')

# Windows to highlight as (low, high) 0-based step indices. The outline runs
# from low - 0.5 to high - 0.5, so it encloses cells low .. high - 1.
highlight_windows = [(6, 10), (13, 15), (18, 20)]
dash_lines = [6, 10, 13, 15, 18, 20]

# Make independent figures for each checkpoint session
for s in check_sessions:
    fig, ax = plt.subplots(figsize=(4.8, 4.8), constrained_layout=True)

    last_im = ax.imshow(
        mean_mats[s],
        vmin=-0.1,
        vmax=1.0,
        cmap=cmap,
        origin='upper',
        aspect='equal',
        interpolation='nearest'
    )

    # Thin dotted cell grid on the minor ticks; major ticks and labels hidden.
    ny, nx = mean_mats[s].shape
    ax.set_xticks(np.arange(-0.5, nx, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, ny, 1), minor=True)
    ax.grid(which='minor', color='white', linestyle=':', linewidth=0.4)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.tick_params(colors='white', labelsize=0, length=0)
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Dashed guides on the cell boundaries that delimit the windows.
    for idx in dash_lines:
        if idx <= nx:
            ax.axvline(idx - 0.5, color='white', linewidth=1.2, linestyle=(0, (2, 4)))
        if idx <= ny:
            ax.axhline(idx - 0.5, color='white', linewidth=1.2, linestyle=(0, (2, 4)))

    # Solid square outline around each highlighted diagonal window.
    for low, high in highlight_windows:
        if high <= min(nx, ny):
            x0, x1 = low - 0.5, high - 0.5
            y0, y1 = low - 0.5, high - 0.5
            ax.plot(
                [x0, x1, x1, x0, x0],
                [y0, y0, y1, y1, y0],
                color='white',
                linewidth=2.5
            )

    fig.patch.set_facecolor('black')
    ax.set_facecolor('black')

    # Colorbar without label or ticks.
    cbar = fig.colorbar(last_im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("")
    cbar.set_ticks([])
    cbar.ax.tick_params(length=0)

    plt.show()


In [None]:

# Fig. 4i — mean final correlation per block (error bars: SEM across runs)
labels = ['offdiag','preR2','preR1']
means = blocks_df.mean()
ses = blocks_df.sem()
x = np.arange(len(labels))
y = [means[name] for name in labels]
yerr = [ses[name] for name in labels]

fig, ax = plt.subplots(figsize=(6.5,4.6))
bars = ax.bar(x, y, yerr=yerr, capsize=4, color=['#777','#2f6db3','#d64b5a'])
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_ylim(0, 1.05)
ax.set_ylabel("Mean correlation (final)")
ax.set_title("CoDA (uncertainty‑aware) — Fig. 4i analogue")
for b, val in zip(bars, y):
    # Tall bars get a white label inside the bar; others a black label above it.
    if val > 0.85:
        text_y, v_align, text_color = val - 0.05, 'top', 'white'
    else:
        text_y, v_align, text_color = min(1.02, val + 0.05), 'bottom', 'black'
    ax.text(b.get_x()+b.get_width()/2, text_y, f"{val:.2f}", ha='center',
            va=v_align, color=text_color, fontsize=9)
fig.tight_layout()
plt.show()

# Fig. 4j — fraction of training until each block first drops below threshold
labels_t = ['offdiag_t','preR2_t','preR1_t']
disp = ['offdiag','preR2','preR1']
means_t = times_df.mean(skipna=True)
ses_t = times_df.sem(skipna=True)
x = np.arange(len(labels_t))
y = [means_t[name] for name in labels_t]
yerr = [ses_t[name] for name in labels_t]

fig, ax = plt.subplots(figsize=(6.5,4.6))
bars = ax.bar(x, y, yerr=yerr, capsize=4, color=['#777','#2f6db3','#d64b5a'])
ax.set_xticks(x)
ax.set_xticklabels(disp)
ax.set_ylim(0, 1.05)
ax.set_ylabel("Fraction of training (first corr < 0.3)")
ax.set_title("CoDA (uncertainty‑aware) — Fig. 4j analogue")
for b, val in zip(bars, y):
    ax.text(b.get_x()+b.get_width()/2, min(1.02, val+0.05), f"{val:.2f}", ha='center', va='bottom', fontsize=9)
fig.tight_layout()
plt.show()
