# CoDA × Near/Far (stochastic US at site) — **Full Notebook**

**What this notebook does**
- Uses your existing **CoDAAgent / CoDAConfig** (no changes) and only adapts the **environment + data pipeline**.
- **Reward (US) = 6**; on **no-reward (NR)** trials the US never appears and the episode instead ends at a **backup site**:
  - On **NEAR** trials, if `4→6` doesn’t occur, the episode ends at the **second 5** (tag `NR@2nd-5`).
  - On **FAR** trials, if `5→6` doesn’t occur, the episode ends at the **second 4** (tag `NR@2nd-4`).
- Trains across sessions, then plots CSCG‑style **near×far correlation** panels and **Fig. 4i/4j analogues**.

**How to use**
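1) The first code cell tries to import `CoDAAgent` / `CoDAConfig` from `/mnt/data/coda_trial_by_trial_util.py`. If that fails and the classes are not already defined in this kernel, paste them into the cell headed *(Optional) Paste your `CoDAAgent` / `CoDAConfig` here if not already defined*.
2) The evaluation code reads out latent states with `agent.run_episode(obs_seq, learn=False)`. This call must return a sequence of integer latent-state ids, one per position of `obs_seq`. Make sure your `CoDAAgent` provides it: the saved run below stops with `AttributeError: 'CoDAAgent' object has no attribute 'run_episode'`.
3) Run the notebook top to bottom.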


In [1]:
import sys

# Directory expected to contain coda_trial_by_trial_util.py.
CODA_MODULE_DIR = '/mnt/data'
if CODA_MODULE_DIR not in sys.path:  # avoid appending a duplicate entry on every re-run
    sys.path.append(CODA_MODULE_DIR)

try:
    from coda_trial_by_trial_util import CoDAAgent, CoDAConfig
    print('Imported CoDAAgent/CoDAConfig from coda_trial_by_trial_util.py')
except ImportError as exc:
    # Only a missing module/name falls back to "classes pasted into this kernel";
    # any other exception raised while executing the module is a real bug and
    # is allowed to propagate instead of being hidden behind this note.
    print('Note: CoDAAgent/CoDAConfig not imported from file; assuming they are already defined in this session.')
    print(f'      (import error: {exc})')


Imported CoDAAgent/CoDAConfig from coda_trial_by_trial_util.py


### (Optional) Paste your `CoDAAgent` / `CoDAConfig` here if not already defined

In [2]:
# Paste your class definitions here if needed
# class CoDAConfig: ...
# class CoDAAgent:  ...

In [3]:
import random

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# Seed both global generators: episode sampling draws from `random`,
# while NumPy's legacy global state is seeded for any np.random use.
random.seed(0)
np.random.seed(0)


In [4]:
# Canonical sequences for evaluation (kept fixed)
near_eval = [1,1,1,1,1,1,  2,2,2,2,  1,1,1,  4,6,  1,1,1,  5,5,  1,1,  7,  0,0,0]
far_eval  = [1,1,1,1,1,1,  3,3,3,3,  1,1,1,  4,4,  1,1,1,  5,6,  1,1,  7,  0,0,0]

# Analysis windows as in OSM Fig. 4
preR1_idx = list(range(10,13))   # 10..12
preR2_idx = list(range(15,18))   # 15..17

def block_indices(rows, cols): return [(r,c) for r in rows for c in cols]
offdiag_pairs     = block_indices(preR1_idx, preR2_idx) + block_indices(preR2_idx, preR1_idx)
same_preR1_pairs  = block_indices(preR1_idx, preR1_idx)
same_preR2_pairs  = block_indices(preR2_idx, preR2_idx)


In [5]:
# -------- Minimal environment that matches CoDAAgent expectations --------
class NFEnv:
    """Minimal single-action environment carrying the attributes the agent is built from.

    Parameters
    ----------
    n_states : number of observation codes; states are 0 .. n_states-1.
    reward : observation code stored as the (only) rewarded terminal.
    noreward_candidates : initial unrewarded terminals; the training loop in this
        notebook overwrites ``unrewarded_terminals`` for every episode.
    """

    def __init__(self, n_states: int, reward: int = 6, noreward_candidates=(4,5)):
        self.num_unique_states = int(n_states)
        # Each state offers exactly one action, id 0 (a fresh list per state).
        self.valid_actions = {}
        for state in range(self.num_unique_states):
            self.valid_actions[state] = [0]
        self.rewarded_terminals = [int(reward)]
        self.unrewarded_terminals = list(map(int, noreward_candidates))
        self.clone_dict = {}
        self.reverse_clone_dict = {}

    def add_clone_dict(self, clone_id: int, successor: int):
        """Store the mapping ``clone_id -> successor`` in ``clone_dict``."""
        self.clone_dict[clone_id] = successor

    def add_reverse_clone_dict(self, new_clone: int, successor: int):
        """Store ``successor -> new_clone`` in ``reverse_clone_dict`` (keyed by successor)."""
        self.reverse_clone_dict[successor] = new_clone


In [6]:
# -------- Stochastic NEAR/FAR episode generator --------
# Base sequences the generator perturbs (same layout as the evaluation sequences).
near_base = [1,1,1,1,1,1,  2,2,2,2,  1,1,1,  4,6,  1,1,1,  5,5,  1,1,  7,  0,0,0]
far_base  = [1,1,1,1,1,1,  3,3,3,3,  1,1,1,  4,4,  1,1,1,  5,6,  1,1,  7,  0,0,0]

def find_nth(seq, value, n):
    """Return the index of the n-th (1-based) occurrence of `value` in `seq`, or None."""
    seen = 0
    for i, v in enumerate(seq):
        if v == value:
            seen += 1
            if seen == n:
                return i
    return None

def truncate_at_index(seq, idx):
    """Return `seq` up to and including position `idx`; `seq` itself when `idx` is None."""
    return seq if idx is None else seq[:idx + 1]

def _sample_branch(base, branch, site, p_us, nr_value, nr_tag, draw):
    """Sample one episode of a single branch (shared by NEAR and FAR).

    With probability `p_us` the episode ends on the US (6) right after the first
    `site`. Otherwise that US is removed and the episode ends at the 2nd `nr_value`.
    """
    seq = base.copy()
    site_idx = seq.index(site)
    if draw() < p_us:
        # Rewarded: terminate at the 6 that follows the site.
        return truncate_at_index(seq, site_idx + 1), (branch, "R")
    if site_idx + 1 < len(seq) and seq[site_idx + 1] == 6:
        seq.pop(site_idx + 1)           # remove the US
    nr_idx = find_nth(seq, nr_value, 2)
    if nr_idx is None:
        # Without this guard the whole sequence would be returned under an NR tag.
        raise ValueError(f"{branch} base sequence has no 2nd {nr_value} to end a no-reward episode")
    return truncate_at_index(seq, nr_idx), (branch, nr_tag)

def sample_episode(p_4_to_6=0.5, p_5_to_6=0.5, rng=None, p_near=0.5):
    """Sample one NEAR or FAR training episode with a stochastic US.

    Parameters
    ----------
    p_4_to_6 : probability that a NEAR episode ends on the US (6) after the 4.
    p_5_to_6 : probability that a FAR episode ends on the US (6) after the 5.
    rng : optional object with a ``random()`` method returning floats in [0, 1);
        defaults to the global ``random`` module.
    p_near : probability of drawing the NEAR branch (default 0.5, as before).

    Returns: (states, (branch, tag))
      branch in {"near","far"}
      tag in {"R", "NR@2nd-5", "NR@2nd-4"}

    Exactly two draws are made per call, in a fixed order (branch, then US), so
    seeded runs stay reproducible.
    """
    draw = random.random if rng is None else rng.random
    if draw() < p_near:
        return _sample_branch(near_base, "near", 4, p_4_to_6, 5, "NR@2nd-5", draw)
    return _sample_branch(far_base, "far", 5, p_5_to_6, 4, "NR@2nd-4", draw)

# Show a few draws from the generator (this consumes the global `random` stream).
print('Example sampled episodes (branch, tag):')
for _ in range(4):
    s, lab = sample_episode(p_4_to_6=0.5, p_5_to_6=0.5)
    print(lab, s)


Example sampled episodes (branch, tag):
('far', 'NR@2nd-4') [1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 4, 4]
('near', 'R') [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 4, 6]
('far', 'R') [1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 4, 4, 1, 1, 1, 5, 6]
('far', 'R') [1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 4, 4, 1, 1, 1, 5, 6]


In [7]:
# -------- Utilities: latent one-hot + near×far correlation --------
def latent_onehot(agent, obs_seq, decode=None):
    """One-hot encode the latent state the agent assigns at each step of `obs_seq`.

    Parameters
    ----------
    agent : trained agent. Unless `decode` is given, it must provide
        ``run_episode(obs_seq, learn=False)`` returning a sequence of latent ids.
    obs_seq : sequence of observation codes.
    decode : optional callable ``decode(agent, obs_seq) -> latent ids``. Use it
        for agents that expose latent decoding under a different API.

    Returns
    -------
    np.ndarray of shape (len(latent ids), S). S is ``agent.n_states`` if present,
    else ``agent.transition_counts.shape[0]``, else ``max(id) + 1``.
    Ids outside [0, S) get an all-zero row, and a warning is emitted.

    Raises
    ------
    AttributeError if `decode` is None and the agent has no ``run_episode``.
    """
    import warnings

    if decode is None:
        run_episode = getattr(agent, 'run_episode', None)
        if run_episode is None:
            raise AttributeError(
                f"{type(agent).__name__} has no run_episode(obs_seq, learn=False); "
                "add one to the agent or pass decode=callable(agent, obs_seq) returning latent ids")
        lat = run_episode(obs_seq, learn=False)
    else:
        lat = decode(agent, obs_seq)
    # Normalise to a list of ints: `if lat` on a numpy array would raise.
    lat = [int(sid) for sid in lat]

    S = getattr(agent, 'n_states', None)
    if S is None:
        try:
            S = agent.transition_counts.shape[0]
        except (AttributeError, TypeError, IndexError):
            S = max(lat) + 1 if lat else 1

    X = np.zeros((len(lat), S))
    dropped = 0
    for t, sid in enumerate(lat):
        if 0 <= sid < S:  # negative ids would otherwise wrap around under numpy indexing
            X[t, sid] = 1.0
        else:
            dropped += 1
    if dropped:
        warnings.warn(f'latent_onehot: {dropped} latent id(s) outside [0, {S}) left as all-zero rows')
    return X

def near_far_corr(agent, near_seq, far_seq):
    """Pearson correlation between every near-step and every far-step latent one-hot.

    Returns
    -------
    C : np.ndarray of shape (len(near latents), len(far latents)).
        ``C[i, j] = corr(onehot_near[i], onehot_far[j])``; pairs where either
        vector has zero variance get 0.0.
    """
    A = latent_onehot(agent, near_seq)
    B = latent_onehot(agent, far_seq)
    # When latent_onehot falls back to max(latent id)+1, the two encodings can
    # differ in width. Pad with zero columns so the row dot products line up
    # (previously a shape-mismatch crash).
    width = max(A.shape[1], B.shape[1])
    A = np.pad(A, ((0, 0), (0, width - A.shape[1])))
    B = np.pad(B, ((0, 0), (0, width - B.shape[1])))

    # Row-wise centring, then all pairwise dot products / norm products at once.
    A0 = A - A.mean(axis=1, keepdims=True)
    B0 = B - B.mean(axis=1, keepdims=True)
    num = A0 @ B0.T
    den = np.outer(np.linalg.norm(A0, axis=1), np.linalg.norm(B0, axis=1))
    C = np.zeros((A.shape[0], B.shape[0]))
    np.divide(num, den, out=C, where=den > 0)
    return C


In [8]:
# -------- Train once and collect session snapshots --------
def train_once(seed=0, sessions=9, trials_per_session=120, p4=0.5, p5=0.5,
               cfg_override=None, thresh=0.3):
    """Train one CoDA agent on stochastic NEAR/FAR episodes and snapshot it each session.

    Parameters
    ----------
    seed : seeds both ``random`` (used by sample_episode) and ``np.random``.
    sessions, trials_per_session : training schedule.
    p4, p5 : probability of the US after 4 (NEAR) / after 5 (FAR).
    cfg_override : CoDAConfig to use instead of the default configuration below.
    thresh : block-mean correlation threshold for the time-to-threshold
        bookkeeping (Fig. 4j analogue); previously hard-coded to 0.3.

    Returns
    -------
    (agent, mats_by_session, tt)
        mats_by_session maps session number (1-based) to the near x far correlation
        matrix on the canonical evaluation sequences after that session.
        tt maps 'offdiag' / 'preR2' / 'preR1' to the first session whose block
        mean fell below `thresh`, or None if it never did.

    Raises
    ------
    RuntimeError if CoDAConfig / CoDAAgent are not defined in this session.
    Errors raised by their constructors propagate unchanged, instead of being
    mislabelled as "not defined".
    """
    random.seed(seed); np.random.seed(seed)
    # Build env/agent
    n_states = max(max(near_eval), max(far_eval)) + 1
    env = NFEnv(n_states=n_states, reward=6, noreward_candidates=(4,5))
    if cfg_override is None:
        try:
            config_cls = CoDAConfig
        except NameError as e:
            raise RuntimeError('CoDAConfig is not defined. Please paste your CoDA classes above.') from e
        cfg = config_cls(
            gamma=0.9, lam=0.8,
            theta_split=0.85, theta_merge=0.50,
            n_threshold=3, min_presence_episodes=3, min_effective_exposure=8.0,
            confidence=0.90, alpha0=0.25, beta0=0.25,
            count_decay=1.0, trace_decay=1.0, retro_decay=1.0
        )
    else:
        cfg = cfg_override
    try:
        agent_cls = CoDAAgent
    except NameError as e:
        raise RuntimeError('CoDAAgent is not defined. Please paste your CoDA classes above.') from e
    agent = agent_cls(env, cfg)

    def _block_mean(C, pairs):
        # Mean correlation over the (i, j) pairs of one analysis block; NaN if empty.
        return float(np.mean([C[i, j] for (i, j) in pairs])) if pairs else np.nan

    blocks = {'offdiag': offdiag_pairs, 'preR2': same_preR2_pairs, 'preR1': same_preR1_pairs}
    tt = {name: None for name in blocks}
    # Unrewarded terminal to bind for each NR tag; rewarded episodes get none.
    nr_terminals = {'NR@2nd-5': [5], 'NR@2nd-4': [4]}
    mats_by_session = {}

    for session in range(1, sessions + 1):
        for _ in range(trials_per_session):
            states, (_branch, tag) = sample_episode(p_4_to_6=p4, p_5_to_6=p5)
            env.unrewarded_terminals = list(nr_terminals.get(tag, []))
            actions = [0] * (len(states) - 1)  # single-action env: one action per transition
            agent.update_with_episode(states, actions)
            agent.maybe_split()
            # agent.maybe_merge()  # keep off during acquisition
        # snapshot correlation matrix
        C = near_far_corr(agent, near_eval, far_eval)
        mats_by_session[session] = C.copy()

        # threshold bookkeeping (like Fig. 4j); NaN block means never count as crossed
        for name, pairs in blocks.items():
            if tt[name] is None and _block_mean(C, pairs) < thresh:
                tt[name] = session

    return agent, mats_by_session, tt


In [9]:
# -------- Multi-run to compute means/SEMs and time-to-threshold --------
N_RUNS = 6
SESSIONS = 9
TRIALS_PER_SESSION = 120
P4, P5 = 0.5, 0.5

# Per-session list of correlation matrices, one entry per run.
all_mats = {session: [] for session in range(1, SESSIONS + 1)}
times = []
agents = []

# Each run uses its own seed (1000 + run index): independent but reproducible.
for r in range(N_RUNS):
    agent_r, mats_r, tt_r = train_once(
        seed=1000 + r,
        sessions=SESSIONS,
        trials_per_session=TRIALS_PER_SESSION,
        p4=P4,
        p5=P5,
        cfg_override=None,
    )
    agents.append(agent_r)
    for s in range(1, SESSIONS + 1):
        all_mats[s].append(mats_r[s])
    times.append(tt_r)

agent = agents[0]  # agent from the first run
check_sessions = [1, 3, 4, 9]
# Run-averaged correlation matrix for each session that gets plotted.
mean_mats = {s: np.stack(all_mats[s], axis=0).mean(axis=0) for s in check_sessions}

def norm(x):
    """Session number -> fraction of training; None (threshold never reached) -> NaN."""
    return np.nan if x is None else x / SESSIONS

times_norm = np.array([[norm(tt[key]) for key in ('offdiag', 'preR2', 'preR1')] for tt in times])


AttributeError: 'CoDAAgent' object has no attribute 'run_episode'

In [None]:
# -------- CSCG-style correlation panels (sessions 1,3,4,9) --------
fig, axes = plt.subplots(1, len(check_sessions), figsize=(4.8 * len(check_sessions), 4.8),
                         constrained_layout=True, sharex=True, sharey=True, squeeze=False)
axes = axes.ravel()  # always a 1-D array of Axes, even when only one session is plotted
# matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; pyplot.get_cmap remains.
cmap = plt.get_cmap('magma')
vmin, vmax = -0.1, 1.0
last_im = None

fig.patch.set_facecolor('black')
for ax, s in zip(axes, check_sessions):
    M = mean_mats[s]
    last_im = ax.imshow(M, vmin=vmin, vmax=vmax, cmap=cmap,
                        origin='upper', aspect='equal', interpolation='nearest')
    ax.set_facecolor('black')
    ax.set_title(f'Session {s}', color='white')
    ax.set_xlabel('Far position index', color='white'); ax.set_ylabel('Near position index', color='white')
    # Minor ticks on cell boundaries so the dotted grid outlines each matrix entry.
    ny, nx = M.shape
    ax.set_xticks(np.arange(-0.5, nx, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, ny, 1), minor=True)
    ax.grid(which='minor', color='white', linestyle=':', linewidth=0.4)
    ax.tick_params(colors='white', labelsize=8)
    for spine in ax.spines.values():
        spine.set_visible(False)
cbar = fig.colorbar(last_im, ax=axes, fraction=0.046, pad=0.04)
cbar.set_label('Correlation', color='white'); cbar.ax.yaxis.set_tick_params(color='white', labelcolor='white')
plt.show()


In [None]:
# -------- Fig. 4i analogue: final block correlations (means ± s.e.m.) --------
def block_mean(C, pairs):
    """Mean of C[i, j] over the given (i, j) index pairs; NaN when `pairs` is empty."""
    return float(np.mean([C[i,j] for (i,j) in pairs])) if pairs else np.nan

# One row per run: [offdiag, preR2, preR1] block means of the final-session matrix.
final_blocks = np.array([
    [block_mean(C_final, offdiag_pairs),
     block_mean(C_final, same_preR2_pairs),
     block_mean(C_final, same_preR1_pairs)]
    for C_final in all_mats[SESSIONS][:N_RUNS]
])
means = np.nanmean(final_blocks, axis=0)
# SEM uses the number of runs that produced a value for each bar. With fewer than
# two such runs it is undefined, so a zero-length error bar is drawn instead of
# NaN (and a ddof warning).
n_valid = np.isfinite(final_blocks).sum(axis=0)
sems = np.zeros(final_blocks.shape[1])
has_spread = n_valid > 1
sems[has_spread] = (np.nanstd(final_blocks[:, has_spread], axis=0, ddof=1)
                    / np.sqrt(n_valid[has_spread]))

labels = ['offdiag','preR2','preR1']
x = np.arange(len(labels))
fig, ax = plt.subplots(figsize=(6.5,4.6))
bars = ax.bar(x, means, yerr=sems, capsize=4, color=['#777','#2f6db3','#d64b5a'])
ax.set_xticks(x); ax.set_xticklabels(labels)
# Block means can be negative: centred one-hots of different latent states correlate
# below zero. Extend the axis downward so such bars stay visible.
lowest = float(np.nanmin(means - sems)) if np.isfinite(means).any() else 0.0
ax.set_ylim(min(0.0, lowest - 0.05), 1.05)
ax.set_ylabel('Mean correlation (final)')
ax.set_title('CoDA (US@site) — Fig. 4i analogue')
for b, val in zip(bars, means):
    if not np.isfinite(val):
        continue  # nothing to annotate for an undefined bar
    xc = b.get_x() + b.get_width()/2
    if val > 0.85:
        ax.text(xc, val-0.05, f'{val:.2f}', ha='center', va='top', color='white', fontsize=9)
    elif val < 0:
        ax.text(xc, val-0.02, f'{val:.2f}', ha='center', va='top', fontsize=9)
    else:
        ax.text(xc, min(1.02, val+0.05), f'{val:.2f}', ha='center', va='bottom', fontsize=9)
plt.show()


In [None]:
# -------- Fig. 4j analogue: normalized time to first < 0.3 --------
labels = ['offdiag','preR2','preR1']
# Runs whose block never dropped below threshold contribute NaN. Summarise only
# the runs that did reach it, and report that count under each bar.
n_reached = np.isfinite(times_norm).sum(axis=0)
means_t = np.full(times_norm.shape[1], np.nan)
reached = n_reached > 0
means_t[reached] = np.nanmean(times_norm[:, reached], axis=0)
# SEM divides by the per-bar count of reaching runs (not the total run count);
# it is undefined with fewer than two, so those bars get a zero-length error bar.
sems_t = np.zeros(times_norm.shape[1])
spread = n_reached > 1
sems_t[spread] = np.nanstd(times_norm[:, spread], axis=0, ddof=1) / np.sqrt(n_reached[spread])

x = np.arange(len(labels))
fig, ax = plt.subplots(figsize=(6.5,4.6))
bars = ax.bar(x, means_t, yerr=sems_t, capsize=4, color=['#777','#2f6db3','#d64b5a'])
ax.set_xticks(x)
ax.set_xticklabels([f'{lab}\n(n={k}/{times_norm.shape[0]})' for lab, k in zip(labels, n_reached)])
ax.set_ylim(0, 1.05)
ax.set_ylabel('Fraction of training (first corr < 0.3)')
ax.set_title('CoDA (US@site) — Fig. 4j analogue')
for b, val in zip(bars, means_t):
    if not np.isfinite(val):
        continue  # no run reached threshold for this block
    ax.text(b.get_x()+b.get_width()/2, min(1.02, val+0.05), f'{val:.2f}', ha='center', va='bottom', fontsize=9)
plt.show()
