# CoDA trial-by-trial with **uncertainty-aware** split & online merge

This single notebook runs acquisition → split and then extinction → merge, with diagnostics.

In [None]:

import sys, numpy as np
sys.path.append('/mnt/data')
from coda_trial_by_trial import (
    CoDAAgent, CoDAConfig,
    GridEnvRightDownNoSelf, GridEnvRightDownNoCue,
    generate_episode, posterior_prob_p_greater_than
)
import matplotlib.pyplot as plt


In [None]:

# Agent hyperparameters, grouped by role and then unpacked into CoDAConfig.
learning_params = dict(gamma=0.9, lam=0.8)
contingency_thresholds = dict(theta_split=0.9, theta_merge=0.5)

# Uncertainty gates configured for splitting.
evidence_gates = dict(
    n_threshold=8,                # minimum step-level evidence
    min_presence_episodes=8,      # minimum episode-level presence
    min_effective_exposure=25.0,  # minimum E_r + E_nr
    confidence=0.95,              # require P(p > theta_split | data) >= 0.95
)

# alpha0 = beta0 = 0.5 is the Jeffreys prior.
beta_prior = dict(alpha0=0.5, beta0=0.5)

cfg = CoDAConfig(
    **learning_params,
    **contingency_thresholds,
    **evidence_gates,
    **beta_prior,
)

# Acquisition environment: env_size=(4, 4), cue at state 5, rewarded terminal 15.
env = GridEnvRightDownNoSelf(
    cue_states=[5],
    env_size=(4, 4),
    rewarded_terminal=[15],
)
agent = CoDAAgent(env, cfg)


In [None]:

# Acquisition phase: after every episode the agent is updated and then given
# the chance to split (agent.maybe_split applies the uncertainty gates).
split_eps = []        # episode index, repeated once per item maybe_split returned
with_clones = False   # flips to True after the first non-empty split result
for ep in range(1, 400):
    # Until a split has happened, episodes are generated with T=None;
    # afterwards the agent's own T (agent.get_T()) is passed in.
    if with_clones:
        episode_T = agent.get_T()
    else:
        episode_T = None
    states, actions = generate_episode(env, T=episode_T, max_steps=20)
    agent.update_with_episode(states, actions)

    new = agent.maybe_split()
    if not new:
        continue

    with_clones = True
    split_eps += [ep] * len(new)
    # Optional: add a `break` here to stop at the first split.

print("Split episodes:", split_eps)


In [None]:

# Diagnostics: prospective contingency (PC), posterior tail probability of
# exceeding theta_split, and retrospective contingency (RC), one value per state.
pc = agent.prospective()
rc = agent.retrospective()
exposure = (agent.E_r + agent.E_nr).reshape(-1)

# Tail probability for each state index, computed from row 0 of E_r / E_nr
# under the configured prior (alpha0, beta0).
post_tail = np.array([
    posterior_prob_p_greater_than(
        cfg.theta_split,
        agent.E_r[0, s],
        agent.E_nr[0, s],
        cfg.alpha0,
        cfg.beta0,
    )
    for s in range(agent.n_states)
])

# One figure per diagnostic: (title, values, y-axis label, reference line).
# A reference line of None means no dashed horizontal line is drawn.
panels = [
    ("Prospective contingency P(US|CS)", pc, "PC", cfg.theta_split),
    (f"Posterior tail Pr[P(US|CS)>={cfg.theta_split}]", post_tail, "tail prob", cfg.confidence),
    ("Retrospective contingency P(CS|US)", rc, "RC", None),
]

for panel_title, values, y_label, reference in panels:
    plt.figure()
    plt.title(panel_title)
    plt.plot(values, marker='o')
    if reference is not None:
        plt.axhline(reference, linestyle='--')
    plt.xlabel("state index")
    plt.ylabel(y_label)
    plt.show()


In [None]:

# Extinction-like phase. The environment is swapped for the "NoCue" variant;
# the stated expectation is that reward no longer depends on the cue, so
# informativeness drops and merges should occur.
env2 = GridEnvRightDownNoCue(
    cue_states=[5],
    env_size=(4, 4),
    rewarded_terminal=[15],
)

# Carry the clone bookkeeping over from the acquisition environment as shallow
# copies, then point the agent at the new environment.
acquisition_env = agent.env
env2.clone_dict = dict(acquisition_env.clone_dict)
env2.reverse_clone_dict = dict(acquisition_env.reverse_clone_dict)
agent.env = env2

merge_eps = []  # episode index, repeated once per item maybe_merge returned
for ep in range(400, 900):
    states, actions = generate_episode(env2, T=agent.get_T(), max_steps=20)
    agent.update_with_episode(states, actions)

    merged = agent.maybe_merge()
    if not merged:
        continue
    merge_eps += [ep] * len(merged)

print("Merge episodes:", merge_eps)
