# CoDA × Near/Far (US at site) — **Full Notebook using your *version that works***

**This notebook does the following:**
- Loads and executes your uploaded notebook: **`/mnt/data/version that works.ipynb`** to define the CoDA classes.
- Defines a minimal environment compatible with your agent, where **US=6** and **No‑US** is the **event site**:
  - NEAR: if `4→6` doesn’t happen, that 6 is dropped and the episode ends at the **second occurrence of 5** (treated as the no‑US terminal *for that episode*).
  - FAR:  if `5→6` doesn’t happen, the episode ends at the **second occurrence of 4**, before any 5 is shown (treated as the no‑US terminal *for that episode*).
- Trains over sessions with stochastic episodes (each trial is NEAR or FAR with 50% probability; within a trial, 4→6 or 5→6 occurs with 50% probability),
  and plots CSCG‑style **near×far correlation** panels plus **Fig. 4i/4j** analogues.

**Note:** We do **not** require `agent.run_episode`. The encoder below infers a latent path from the
learned transition matrix and your agent's observation↔state mapping if available.

In [None]:

# --- Load & execute your base notebook to bring CoDA classes into the kernel ---
import types, os, sys, traceback

BASE_NOTEBOOK = "/mnt/data/version that works.ipynb"


def _read_notebook_code_cells(path):
    """Return the source text of every code cell in the ``.ipynb`` file at ``path``.

    Uses ``nbformat`` when it is installed (it also upgrades older notebook
    formats to v4).  Otherwise the file is parsed directly as nbformat-4 JSON
    with the standard library, so loading never depends on an optional package.
    """
    try:
        import nbformat  # optional; the original code used it without importing it
    except ImportError:
        nbformat = None
    if nbformat is not None:
        nb = nbformat.read(path, as_version=4)
        return [cell.source for cell in nb.cells if cell.cell_type == "code"]

    import json
    with open(path, encoding="utf-8") as fh:
        nb = json.load(fh)
    sources = []
    for cell in nb.get("cells", []):
        if cell.get("cell_type") != "code":
            continue
        src = cell.get("source", "")
        # In raw .ipynb JSON the source is usually a list of lines.
        sources.append("".join(src) if isinstance(src, list) else src)
    return sources


def load_user_notebook(path=BASE_NOTEBOOK):
    """Execute every code cell of the notebook at ``path`` in this module's globals.

    A cell that raises is reported and skipped, so one broken cell does not
    stop the remaining ones from defining their names.

    Args:
        path: location of the ``.ipynb`` file (defaults to ``BASE_NOTEBOOK``).

    Returns:
        Number of cells executed successfully, or None when the file is
        missing or cannot be read.
    """
    if not os.path.exists(path):
        print(f"[WARN] Could not find {path}. If your CoDA classes are already defined, you can ignore this.")
        return None
    try:
        sources = _read_notebook_code_cells(path)
    except Exception as e:
        print(f"[WARN] Failed to read/execute {path}: {e}")
        return None

    g = globals()
    executed = 0
    for src in sources:
        try:
            exec(src, g)
            executed += 1
        except Exception as e:
            print(f"[NOTE] Skipping a cell due to error: {e}\n{traceback.format_exc(limit=1)}")
    print(f"[OK] Executed {executed} code cells from: {path}")
    return executed

load_user_notebook()


In [None]:

import numpy as np, random
import matplotlib.pyplot as plt
import matplotlib as mpl

random.seed(0)
np.random.seed(0)

# Canonical evaluation sequences (unchanged)
near_eval = [1,1,1,1,1,1,  2,2,2,2,  1,1,1,  4,6,  1,1,1,  5,5,  1,1,  7,  0,0,0]
far_eval  = [1,1,1,1,1,1,  3,3,3,3,  1,1,1,  4,4,  1,1,1,  5,6,  1,1,  7,  0,0,0]

# OSM analysis windows: the three positions just before the 4 (index 13)
# and just before the first 5 (index 18) in the evaluation sequences.
preR1_idx = [10, 11, 12]
preR2_idx = [15, 16, 17]


def block_indices(rows, cols):
    """Return every ``(row, col)`` pair of the ``rows`` × ``cols`` grid, row-major."""
    pairs = []
    for r in rows:
        pairs.extend((r, c) for c in cols)
    return pairs


# Cross-window pairs in both orientations, then the two within-window blocks.
offdiag_pairs = block_indices(preR1_idx, preR2_idx) + block_indices(preR2_idx, preR1_idx)
same_preR1_pairs = block_indices(preR1_idx, preR1_idx)
same_preR2_pairs = block_indices(preR2_idx, preR2_idx)


In [None]:

# --- Minimal environment matching CoDAAgent expectations ---
class NFEnv:
    """Minimal one-action environment exposing the attributes the CoDA agent reads.

    Args:
        n_states: number of base state ids; states are ``0 .. n_states-1``.
        reward: id of the rewarded terminal (the US), 6 by default.
    """

    def __init__(self, n_states: int, reward: int = 6):
        count = int(n_states)
        self.num_unique_states = count
        # Each state offers the same single action, id 0 (a fresh list per state).
        self.valid_actions = {sid: [0] for sid in range(count)}
        self.rewarded_terminals = [int(reward)]
        # Rebound by the training loop before every episode.
        self.unrewarded_terminals = []
        self.clone_dict, self.reverse_clone_dict = {}, {}

    def add_clone_dict(self, clone_id: int, successor: int):
        """Record ``clone_id -> successor`` in ``clone_dict`` (overwrites an existing entry).

        Presumably called by the agent when it creates a clone — confirm against CoDAAgent.
        """
        self.clone_dict.update({clone_id: successor})

    def add_reverse_clone_dict(self, new_clone: int, successor: int):
        """Record ``successor -> new_clone`` in ``reverse_clone_dict``.

        Only one clone is kept per successor; a later call overwrites the earlier one.
        """
        self.reverse_clone_dict.update({successor: new_clone})


In [None]:

# --- Stochastic NEAR/FAR generator where episodes end at the event site ---
near_base = [1,1,1,1,1,1,  2,2,2,2,  1,1,1,  4,6,  1,1,1,  5,5,  1,1,  7,  0,0,0]
far_base  = [1,1,1,1,1,1,  3,3,3,3,  1,1,1,  4,4,  1,1,1,  5,6,  1,1,  7,  0,0,0]


def find_nth(seq, value, n):
    """Return the index of the ``n``-th (1-based) occurrence of ``value`` in ``seq``, or None."""
    seen = 0
    for i, v in enumerate(seq):
        if v == value:
            seen += 1
            if seen == n:
                return i
    return None


def truncate_at(seq, idx):
    """Return ``seq`` up to and including ``idx``; ``seq`` itself when ``idx`` is None."""
    return seq if idx is None else seq[:idx + 1]


def sample_episode(p4=0.5, p5=0.5, rng=None, p_near=0.5):
    """
    Sample one truncated NEAR or FAR episode.

    Returns (states, tag) with tag in {'R@4','NR@2nd-5','R@5','NR@2nd-4'}.
    NEAR: if 4→6 happens (prob p4), end at that 6; else drop that 6 and end at
          the 2nd occurrence of 5 (no‑US site).
    FAR : if 5→6 happens (prob p5), end at that 6; else end at the 2nd
          occurrence of 4, i.e. before any 5 is shown (no‑US site).

    Args:
        p4, p5: reward probabilities for NEAR / FAR trials.
        rng: any object with a ``.random()`` method returning a float in
            [0, 1) (e.g. ``random.Random``); the global ``random`` is used
            when None.  The first draw picks NEAR vs FAR, the second decides
            the reward.
        p_near: probability that a trial is NEAR (default 0.5, the previous
            hard-coded value).
    """
    rnd = random.random if rng is None else rng.random
    if rnd() < p_near:  # NEAR
        seq = near_base.copy()
        i4  = seq.index(4)
        if rnd() < p4:
            term = i4+1; tag = 'R@4'         # end at the 6 after 4
        else:
            # remove that 6 if present, then end at 2nd 5
            if i4+1 < len(seq) and seq[i4+1]==6: seq.pop(i4+1)
            term = find_nth(seq, 5, 2); tag='NR@2nd-5'
        return truncate_at(seq, term), tag
    else:            # FAR
        seq = far_base.copy()
        i5  = seq.index(5)
        if rnd() < p5:
            term = i5+1; tag='R@5'          # end at the 6 after 5
        else:
            # The 6 removal cannot affect the result (we truncate before index i5),
            # but keeps the sequence 6-free in case the terminal rule changes.
            if i5+1 < len(seq) and seq[i5+1]==6: seq.pop(i5+1)
            term = find_nth(seq, 4, 2); tag='NR@2nd-4'
        return truncate_at(seq, term), tag

# Show a few examples
print("Examples:")
for _ in range(4):
    s, t = sample_episode()
    print(t, s)


In [None]:

# --- Helpers: call episode update with different method names; get T matrix; encode sequence ---
def agent_update_episode(agent, states, actions):
    """Feed one episode to ``agent`` via the first episode-update method it exposes.

    The candidate names are tried in order; the first one that is a callable
    attribute is invoked as ``method(states, actions)`` and its result returned.

    Raises:
        AttributeError: when none of the candidate names is callable on ``agent``.
    """
    candidate_names = ('update_with_episode', 'update_episode', 'learn_episode', 'update')
    lookups = (getattr(agent, name, None) for name in candidate_names)
    method = next((m for m in lookups if callable(m)), None)
    if method is None:
        raise AttributeError("Agent has no suitable episode update method: tried update_with_episode/update_episode/learn_episode/update")
    return method(states, actions)

def agent_maybe_split(agent):
    """Call ``agent.maybe_split()`` when it exists and is callable.

    Returns its result, or None when the agent has no callable ``maybe_split``.
    """
    splitter = getattr(agent, 'maybe_split', None)
    return splitter() if callable(splitter) else None

def agent_get_T(agent):
    """Return the agent's transition-probability tensor as a float ``np.ndarray``.

    Lookup order:
      1. ``agent.get_T()`` if it is callable;
      2. ``agent.transition_probs``;
      3. ``agent.transition_counts`` normalised over the last axis (rows whose
         total count is zero stay all-zero instead of dividing by zero).

    Callers in this notebook index the result as ``T[s, a, s_next]``, so a 3-D
    ``[S, A, S']`` layout is expected.  The result is always an ndarray, so
    ``.shape`` works even when the agent stores plain nested lists.

    Raises:
        AttributeError: if none of the three sources exists on ``agent``.
    """
    get_T = getattr(agent, 'get_T', None)
    if callable(get_T):
        T = get_T()
    elif hasattr(agent, 'transition_probs'):
        T = agent.transition_probs
    elif hasattr(agent, 'transition_counts'):
        # np.asarray also accepts list-of-lists counts (``.astype`` would not).
        cnt = np.asarray(agent.transition_counts, dtype=float)
        den = cnt.sum(axis=-1, keepdims=True)
        den[den == 0] = 1.0
        T = cnt / den
    else:
        raise AttributeError("Cannot find transition tensor on agent")
    return np.asarray(T, dtype=float)

def agent_obs_mapping(agent):
    """Return ``(obs_of_state, obs_to_states)`` describing the agent's state→observation map.

    ``obs_of_state`` comes from ``agent.obs_of_state`` when present. Otherwise
    it is built from an ``agent.state_obs`` dict ``{state_id: obs}`` as a list
    indexed by state id.  Ids absent from that dict get ``None``. Previously
    they were labelled observation 0, which is a real observation here and
    made unmapped states spurious candidates for it.

    ``obs_to_states`` inverts the map: ``{obs: [state ids, ascending]}``, with
    ``None`` entries skipped.  Returns ``(None, {})`` when the agent exposes
    neither attribute, and ``([], {})`` for an empty ``state_obs``.
    """
    obs_of_state = getattr(agent, 'obs_of_state', None)
    if obs_of_state is None and hasattr(agent, 'state_obs'):
        st_obs = agent.state_obs
        n_ids = (max(st_obs.keys()) + 1) if len(st_obs) else 0  # max() of empty would raise
        obs_of_state = [st_obs.get(i) for i in range(n_ids)]
    obs_to_states = {}
    if obs_of_state is not None:
        for sid, obs in enumerate(obs_of_state):
            if obs is None:
                continue
            obs_to_states.setdefault(obs, []).append(sid)
    return obs_of_state, obs_to_states

def encode_sequence(agent, obs_seq):
    """
    Infer a latent path given observations using learned T and (if available) obs↔state map.

    Greedy decoding with the single action 0:
      * start state: the smallest state id mapped to ``obs_seq[0]`` (within
        range of T); else ``obs_seq[0]`` itself if it is a valid integer state
        index; else 0.
      * each next step: among the states mapped to the next observation (all
        states when there is no mapping or the observation is unknown), pick
        the one with the largest ``T[s, 0, j]``; ties go to the highest state
        id.  If every mapped id is outside ``T``, fall back to ``argmax`` over
        the whole row.

    Returns:
        list of int state ids, one per observation ([] for an empty sequence).
    """
    T = agent_get_T(agent)
    S, _A, _ = T.shape
    a = 0  # single-action task: every transition is read from action 0
    _, obs_to_states = agent_obs_mapping(agent)

    if len(obs_seq) == 0:
        return []

    def in_range(ids):
        # Mapped ids can exceed T's size (or be negative); those rows don't exist.
        return [j for j in ids if 0 <= j < S]

    obs0 = obs_seq[0]
    start = in_range(obs_to_states.get(obs0, []))
    if start:
        s = min(start)
    elif isinstance(obs0, (int, np.integer)) and 0 <= obs0 < S:
        s = int(obs0)
    else:
        s = 0
    path = [s]

    for obs_next in obs_seq[1:]:
        row = T[s, a, :]
        if obs_to_states and obs_next in obs_to_states:
            cands = in_range(obs_to_states[obs_next])
        else:
            cands = list(range(S))
        if cands:
            # (prob, id) ordering: highest probability, ties to the highest id.
            s = max(cands, key=lambda j: (row[j], j))
        else:
            s = int(np.argmax(row))
        path.append(s)
    return path

def latent_onehot(agent, obs_seq):
    """One-hot encode the decoded latent path of ``obs_seq``.

    Returns an array of shape ``(len(path), S)`` where row ``t`` has 1.0 in
    column ``path[t]``.  ``S`` comes from ``agent.transition_counts`` when the
    agent has it, otherwise from the transition tensor the encoder itself
    uses.  Every sequence encoded for the same agent therefore gets the same
    width.  The previous ``max(lat)+1`` fallback gave near and far different
    widths and broke the correlation step.  Ids outside ``[0, S)`` leave their
    row all-zero.
    """
    lat = encode_sequence(agent, obs_seq)
    try:
        S = int(agent.transition_counts.shape[0])
    except AttributeError:
        S = int(np.shape(agent_get_T(agent))[0])
    X = np.zeros((len(lat), S))
    for t, sid in enumerate(lat):
        if 0 <= sid < S:
            X[t, sid] = 1.0
    return X

def near_far_corr(agent, near_seq, far_seq):
    A = latent_onehot(agent, near_seq)
    B = latent_onehot(agent, far_seq)
    C = np.zeros((A.shape[0], B.shape[0]))
    for i in range(A.shape[0]):
        for j in range(B.shape[0]):
            a = A[i]; b = B[j]
            a0 = a - a.mean(); b0 = b - b.mean()
            den = np.linalg.norm(a0)*np.linalg.norm(b0)
            C[i,j] = (a0@b0)/den if den>0 else 0.0
    return C


In [None]:

# --- Train once and collect session matrices ---
def train_once(seed=0, sessions=9, trials_per_session=120, p4=0.5, p5=0.5, cfg_kwargs=None):
    random.seed(seed); np.random.seed(seed)
    # Build env/agent
    n_states = max(max(near_eval), max(far_eval)) + 1
    env = NFEnv(n_states=n_states, reward=6)

    # Create config and agent from user's code
    if 'CoDAConfig' not in globals() or 'CoDAAgent' not in globals():
        raise RuntimeError("CoDAConfig/CoDAAgent not found. Ensure your base notebook was loaded, or paste the classes above.")

    # Defaults for easier propagation; override if requested
    kw = dict(gamma=0.9, lam=0.8, theta_split=0.85, theta_merge=0.50,
              n_threshold=3, min_presence_episodes=3, min_effective_exposure=8.0,
              confidence=0.90, alpha0=0.25, beta0=0.25,
              count_decay=1.0, trace_decay=1.0, retro_decay=1.0)
    if isinstance(cfg_kwargs, dict):
        kw.update(cfg_kwargs)
    cfg = CoDAConfig(**{k:v for k,v in kw.items() if k in CoDAConfig.__annotations__ or hasattr(CoDAConfig, k)})

    # Try common constructor patterns
    agent = None; err = None
    for ctor in [(CoDAAgent, (env, cfg), {}),
                 (CoDAAgent, (), {'env':env, 'cfg':cfg}),
                 (CoDAAgent, (env,), {}),
                 (CoDAAgent, (), {})]:
        Cls, args, kwargs = ctor
        try:
            agent = Cls(*args, **kwargs)
            break
        except Exception as e:
            err = e
    if agent is None:
        raise RuntimeError(f"Could not instantiate CoDAAgent with env/cfg. Last error: {err}")

    mats_by_session = {}
    THRESH = 0.3
    tt = {'offdiag': None, 'preR2': None, 'preR1': None}

    for session in range(1, sessions+1):
        for _ in range(trials_per_session):
            states, tag = sample_episode(p4=p4, p5=p5)

            # Dynamic binding of no‑US site for THIS episode
            # Reward is always 6
            if tag == 'NR@2nd-5':     # no-US at 2nd 5 (NEAR)
                env.unrewarded_terminals = [5]
            elif tag == 'NR@2nd-4':   # no-US at 2nd 4 (FAR)
                env.unrewarded_terminals = [4]
            else:
                env.unrewarded_terminals = []

            actions = [0] * (len(states)-1)
            agent_update_episode(agent, states, actions)
            agent_maybe_split(agent)   # optional merge off during acquisition

        # snapshot
        C = near_far_corr(agent, near_eval, far_eval)
        mats_by_session[session] = C.copy()

        # threshold bookkeeping
        def block_mean(C, pairs): 
            return float(np.mean([C[i,j] for (i,j) in pairs])) if pairs else np.nan
        b_off = block_mean(C, offdiag_pairs)
        b_r2  = block_mean(C, same_preR2_pairs)
        b_r1  = block_mean(C, same_preR1_pairs)
        if tt['offdiag'] is None and b_off < THRESH: tt['offdiag'] = session
        if tt['preR2']  is None and b_r2  < THRESH:  tt['preR2']  = session
        if tt['preR1']  is None and b_r1  < THRESH:  tt['preR1']  = session

    return agent, mats_by_session, tt


In [None]:

# --- Multi-run ---
N_RUNS = 6
SESSIONS = 9
TRIALS_PER_SESSION = 120
P4 = 0.5
P5 = 0.5

all_mats = {s: [] for s in range(1, SESSIONS+1)}
times = []
agents = []

for r in range(N_RUNS):
    agent_r, mats_r, tt_r = train_once(seed=1234+r, sessions=SESSIONS, trials_per_session=TRIALS_PER_SESSION,
                                       p4=P4, p5=P5, cfg_kwargs=None)
    agents.append(agent_r)
    for s in range(1, SESSIONS+1):
        all_mats[s].append(mats_r[s])
    times.append(tt_r)

agent = agents[0]  # choose one for any additional inspection

# Averages for selected sessions. Sessions beyond SESSIONS are dropped, so
# lowering SESSIONS does not raise a KeyError here or in the panel plot.
check_sessions = [s for s in (1, 3, 4, 9) if s <= SESSIONS]
mean_mats = {s: np.mean(np.stack(all_mats[s], axis=0), axis=0) for s in check_sessions}

# Normalize time-to-threshold for Fig. 4j analogue: session k -> k / SESSIONS;
# None (threshold never crossed) -> NaN, which the nan-aware stats below skip.
def norm(x): return (x/SESSIONS) if (x is not None) else np.nan
times_norm = np.array([[norm(tt['offdiag']), norm(tt['preR2']), norm(tt['preR1'])] for tt in times])


In [None]:

# --- CSCG-style correlation panels ---
# squeeze=False + ravel: `axes` is always a 1-D array, even for a single session.
fig, axes = plt.subplots(1, len(check_sessions), figsize=(4.8 * len(check_sessions), 4.8),
                         constrained_layout=True, sharex=True, sharey=True, squeeze=False)
axes = axes.ravel()
# `mpl.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# the colormap registry (Matplotlib >= 3.5) is the supported lookup.
try:
    cmap = mpl.colormaps['magma']
except AttributeError:  # Matplotlib < 3.5
    cmap = mpl.cm.get_cmap('magma')
vmin, vmax = -0.1, 1.0
last_im = None
for ax, s in zip(axes, check_sessions):
    M = mean_mats[s]
    last_im = ax.imshow(M, vmin=vmin, vmax=vmax, cmap=cmap,
                        origin='upper', aspect='equal', interpolation='nearest')
    ax.set_title(f"Session {s}", color='white')
    ax.set_xlabel("Far position index", color='white'); ax.set_ylabel("Near position index", color='white')
    ny, nx = M.shape
    # Minor ticks on cell borders so the dotted grid outlines each position.
    ax.set_xticks(np.arange(-0.5, nx, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, ny, 1), minor=True)
    ax.grid(which='minor', color='white', linestyle=':', linewidth=0.4)
    ax.tick_params(colors='white', labelsize=8)
    for spine in ax.spines.values(): spine.set_visible(False)
fig.patch.set_facecolor('black')
for ax in axes: ax.set_facecolor('black')
cbar = fig.colorbar(last_im, ax=axes, fraction=0.046, pad=0.04)
cbar.set_label("Correlation", color='white'); cbar.ax.yaxis.set_tick_params(color='white', labelcolor='white')
plt.show()


In [None]:

# --- Fig. 4i analogue: final block means ---
def block_mean(C, pairs):
    """Return the mean of ``C[i, j]`` over the index ``pairs`` (NaN when ``pairs`` is empty)."""
    return float(np.mean([C[i,j] for (i,j) in pairs])) if pairs else np.nan

# One row per run: [offdiag, preR2, preR1] block means of the final-session matrix.
final_blocks = []
for r in range(N_RUNS):
    C_final = all_mats[SESSIONS][r]
    final_blocks.append([
        block_mean(C_final, offdiag_pairs),
        block_mean(C_final, same_preR2_pairs),
        block_mean(C_final, same_preR1_pairs)
    ])
final_blocks = np.array(final_blocks)
means = np.nanmean(final_blocks, axis=0)
sems  = np.nanstd(final_blocks, axis=0, ddof=1) / np.sqrt(max(1, N_RUNS))

labels = ['offdiag','preR2','preR1']
x = np.arange(len(labels))
fig, ax = plt.subplots(figsize=(6.5,4.6))
bars = ax.bar(x, means, yerr=sems, capsize=4, color=['#777','#2f6db3','#d64b5a'])
ax.set_xticks(x); ax.set_xticklabels(labels)
# Centred one-hot correlations between *different* latent states are negative,
# so a fully decorrelated block can average below 0.  Extend the lower limit
# instead of clipping those bars (and their labels) out of the axes.
lows = np.concatenate([means, means - np.nan_to_num(sems)])
lows = lows[np.isfinite(lows)]
y_lo = min(0.0, float(lows.min()) - 0.1) if lows.size else 0.0
ax.set_ylim(y_lo, 1.05)
if y_lo < 0:
    ax.axhline(0.0, color='k', linewidth=0.6)
ax.set_ylabel("Mean correlation (final)")
ax.set_title("CoDA (US at site) — Fig. 4i analogue")
for b, val in zip(bars, means):
    xc = b.get_x()+b.get_width()/2
    if val > 0.85:
        # Tall bar: label inside the bar top so it stays within the y-limit.
        ax.text(xc, val-0.05, f"{val:.2f}", ha='center', va='top', color='white', fontsize=9)
    elif val < 0:
        ax.text(xc, val-0.02, f"{val:.2f}", ha='center', va='top', fontsize=9)
    else:
        ax.text(xc, min(1.02, val+0.05), f"{val:.2f}", ha='center', va='bottom', fontsize=9)
plt.show()


In [None]:

# --- Fig. 4j analogue: normalized time to first < 0.3 ---
# Bars: mean ± SEM across runs of the normalised session at which each block
# first fell below threshold; runs that never crossed (NaN) are skipped.
labels = ['offdiag','preR2','preR1']
n_runs_t = times_norm.shape[0]
means_t = np.nanmean(times_norm, axis=0)
sems_t = np.nanstd(times_norm, axis=0, ddof=1) / np.sqrt(max(1, n_runs_t))
x = np.arange(len(labels))
bar_colors = ['#777', '#2f6db3', '#d64b5a']

fig, ax = plt.subplots(figsize=(6.5, 4.6))
bars = ax.bar(x, means_t, yerr=sems_t, capsize=4, color=bar_colors)
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_ylim(0, 1.05)
ax.set_ylabel("Fraction of training (first corr < 0.3)")
ax.set_title("CoDA (US at site) — Fig. 4j analogue")

for rect, mean_val in zip(bars, means_t):
    label_y = min(1.02, mean_val + 0.05)  # keep labels inside the top of the axes
    centre_x = rect.get_x() + rect.get_width() / 2
    ax.text(centre_x, label_y, f"{mean_val:.2f}", ha='center', va='bottom', fontsize=9)
plt.show()
