# CoDA — Graph snapshots **only when changed** (uncertainty-aware)

This notebook runs acquisition → extinction with the uncertainty-aware split rule, but **only draws a graph snapshot when the graph changes** relative to the previous episode.

**Change criterion** (fast, deterministic):
- Any change in transition tensor shape (e.g., new clone added).
- After aggregating over actions, any change in the **thresholded** adjacency (same threshold as the plotter).
- Any change in `clone_dict` mapping (detects clone re-wiring/merge).

This produces the same “Graph at episode XX” figures as before, but skips any episode whose graph is identical to the previous snapshot, so no redundant frames are drawn.


In [None]:

import sys, numpy as np
# sys.path.append('/mnt/data')
import matplotlib.pyplot as plt
from typing import List, Optional
from spatial_environments import GridEnvRightDownNoSelf, GridEnvRightDownNoCue, GridEnvRightDownExtinction
from coda_trial_by_trial_util import CoDAAgent, CoDAConfig
# from spatial_environments import GridEnvRightDownNoSelf, GridEnvRightDownNoCue
from util import generate_dataset, generate_dataset_post_augmentation, compute_transition_entropies, find_stochastic_state_actions_by_entropy, get_successor_states

# Rewritten metrics functions to ensure correct handling of fixed reference shapes
EPS = 1e-12

def _safe_row_norm(x: np.ndarray, axis: int = -1, eps: float = EPS) -> np.ndarray:
    """Return a float copy of ``x`` normalised to sum to one along ``axis``.

    Slices whose sum is below ``eps`` are divided by 1.0 instead, so they are
    returned unchanged rather than producing a division by (near) zero.
    """
    normalized = x.astype(float, copy=True)
    totals = normalized.sum(axis=axis, keepdims=True)
    # Swap tiny totals for 1.0 so those slices pass through untouched.
    divisors = np.where(totals < eps, 1.0, totals)
    np.divide(normalized, divisors, out=normalized)
    return normalized

def _pad_to_shape(A: np.ndarray, shape: tuple) -> np.ndarray:
    """Copy a 3-D array into a new zero-filled float array of ``shape``.

    Only the region common to both shapes is copied: axes shorter than the
    target are zero-padded, axes longer than the target are truncated.
    A fresh array is always returned, even when the shapes already match.
    """
    src_s, src_a, src_s2 = A.shape
    dst_s, dst_a, dst_s2 = shape
    padded = np.zeros(shape, dtype=float)
    overlap = (
        slice(min(src_s, dst_s)),
        slice(min(src_a, dst_a)),
        slice(min(src_s2, dst_s2)),
    )
    padded[overlap] = A[overlap]
    return padded

def _aggregate_actions(T: np.ndarray) -> np.ndarray:
    """Collapse the action axis of ``T`` and row-normalise the result.

    ``T`` is summed over axis 1 and the resulting matrix is normalised along
    axis 1 with ``_safe_row_norm``, giving an estimate of P(s'|s).
    """
    summed_over_actions = T.sum(axis=1)
    return _safe_row_norm(summed_over_actions, axis=1)

def _kl_row(p: np.ndarray, q: np.ndarray, eps: float = EPS) -> float:
    """Return KL(p || q) in nats for two rows.

    Both rows are clipped to ``[eps, 1.0]`` and renormalised to sum to one
    before the divergence is evaluated, so zero entries never reach the log.
    """
    p_clipped = np.clip(p, eps, 1.0)
    p_hat = p_clipped / p_clipped.sum()
    q_clipped = np.clip(q, eps, 1.0)
    q_hat = q_clipped / q_clipped.sum()
    log_ratio = np.log(p_hat) - np.log(q_hat)
    return float(np.sum(p_hat * log_ratio))

def _js_row(p: np.ndarray, q: np.ndarray, eps: float = EPS) -> float:
    """Return the Jensen-Shannon divergence (nats) between two rows.

    Each row is clipped to ``[eps, 1.0]`` and renormalised, the midpoint
    distribution is formed, and the result is the average of the two KL
    divergences (via ``_kl_row``) from each row to that midpoint.
    """
    def _clip_and_normalise(row):
        clipped = np.clip(row, eps, 1.0)
        return clipped / clipped.sum()

    p_hat = _clip_and_normalise(p)
    q_hat = _clip_and_normalise(q)
    midpoint = 0.5 * (p_hat + q_hat)
    kl_p_mid = _kl_row(p_hat, midpoint, eps)
    kl_q_mid = _kl_row(q_hat, midpoint, eps)
    return 0.5 * kl_p_mid + 0.5 * kl_q_mid

def _entropy_row(p: np.ndarray, eps: float = EPS) -> float:
    """Return the Shannon entropy (nats) of one row.

    The row is clipped to ``[eps, 1.0]`` and renormalised to sum to one
    before the entropy is computed.
    """
    probs = np.clip(p, eps, 1.0)
    probs = probs / probs.sum()
    p_log_p = probs * np.log(probs)
    return float(-np.sum(p_log_p))

def kl_over_time(T_series: List[np.ndarray],
                 T_ref_fn,
                 weights: Optional[np.ndarray] = None,
                 use_js: bool = False,
                 base_states_only: bool = False,
                 threshold: float = 0.0) -> np.ndarray:
    """
    Compute KL (or JS) divergence over time between learned T and reference.
    The reference function should return a fixed-shape reference tensor.
    Both T and T_ref are padded to a consistent comparison shape for fair comparison.

    If base_states_only=True, only compute KL over base states (first num_unique_states),
    excluding clone states. This is useful for extinction where reference has zeros for clones.

    If threshold > 0, only consider transitions above threshold in the adjacency matrix.
    This focuses the comparison on significant transitions and ignores small/noise transitions.

    Args:
        T_series: learned transition tensors, one per time step; each is indexed
            as [state, action, next_state] by the code below.
        T_ref_fn: callable taking a learned tensor and returning the reference
            tensor. It is called once on ``T_series[0]`` to fix the comparison
            shape and then once more for every element of ``T_series``.
        weights: optional per-state weights. They are truncated / zero-padded to
            the number of compared states and normalised to sum to one; when
            omitted every compared state gets equal weight.
        use_js: use ``_js_row`` instead of ``_kl_row`` for the per-row score.
        base_states_only: restrict the comparison to the leading ``n_base`` rows
            and columns (see the note where ``n_base`` is derived).
        threshold: when > 0, entries below it in both matrices are zeroed and
            each row is renormalised before scoring.

    Returns:
        1-D array with one weighted score per element of ``T_series`` (empty
        array for an empty series). The per-row KL is taken as KL(learned || reference).

    Raises:
        AssertionError: if ``T_ref_fn`` returns a tensor whose shape differs
            from the shape it returned for ``T_series[0]``.
    """
    if len(T_series) == 0:
        return np.array([])

    # Get reference shape from first call - this is our fixed comparison shape
    T_ref_sample = T_ref_fn(T_series[0])
    ref_shape = T_ref_sample.shape

    # Always use reference shape for comparison (not max shape)
    # This ensures consistent comparison throughout extinction
    comparison_shape = ref_shape

    # If base_states_only, determine base state count from reference
    # (assuming reference has base states in first num_unique_states positions)
    if base_states_only:
        # Find base state count by looking for first row with all zeros in reference
        T_ref_check = _pad_to_shape(T_ref_sample, comparison_shape)
        Q_check = _aggregate_actions(T_ref_check)
        # Base states are those that have non-zero transitions in reference
        # (clone states in extinction reference are all zeros)
        # NOTE(review): ANY all-zero row ends the "base" range, so a reference
        # whose terminal rows are empty would cut n_base short at the first
        # terminal - confirm this is intended for the references in use.
        base_mask = Q_check.sum(axis=1) > EPS
        if base_mask.any():
            n_base = int(np.where(~base_mask)[0][0]) if (~base_mask).any() else Q_check.shape[0]
        else:
            # Reference is entirely empty: fall back to comparing every row.
            n_base = Q_check.shape[0]
    else:
        n_base = None

    scores = []
    for T in T_series:
        T_ref = T_ref_fn(T)
        # Ensure reference is at its fixed shape
        assert T_ref.shape == ref_shape, f"Reference shape changed: {T_ref.shape} != {ref_shape}"

        # Pad both to the reference shape for consistent comparison
        # (a learned tensor larger than the reference is truncated by _pad_to_shape)
        T_padded = _pad_to_shape(T, comparison_shape)
        T_ref_padded = _pad_to_shape(T_ref, comparison_shape)

        # Normalize T_padded to ensure it's a proper probability distribution
        # (padding might have introduced zeros, so renormalize)
        for s in range(T_padded.shape[0]):
            for a in range(T_padded.shape[1]):
                row_sum = T_padded[s, a, :].sum()
                # Rows with (near) zero mass are left as-is instead of dividing by ~0.
                if row_sum > EPS:
                    T_padded[s, a, :] /= row_sum

        # Aggregate over actions
        P = _aggregate_actions(T_padded)
        Q = _aggregate_actions(T_ref_padded)

        # If base_states_only, only compare base states
        if base_states_only and n_base is not None:
            P = P[:n_base, :n_base]
            Q = Q[:n_base, :n_base]
            S = n_base
        else:
            S = P.shape[0]

        # Apply threshold if specified: only consider transitions above threshold
        if threshold > 0:
            # Create mask for significant transitions (above threshold in either P or Q)
            significant_mask = (P >= threshold) | (Q >= threshold)

            # For each row, only consider transitions that are significant
            # Mask out non-significant transitions before computing KL
            P_masked = P.copy()
            Q_masked = Q.copy()
            for i in range(S):
                # Zero out non-significant transitions
                P_masked[i, ~significant_mask[i, :]] = 0.0
                Q_masked[i, ~significant_mask[i, :]] = 0.0
                # Renormalize
                p_sum = P_masked[i, :].sum()
                q_sum = Q_masked[i, :].sum()
                if p_sum > EPS:
                    P_masked[i, :] /= p_sum
                if q_sum > EPS:
                    Q_masked[i, :] /= q_sum
            P = P_masked
            Q = Q_masked

        if weights is None:
            # Uniform weighting over the compared states.
            w = np.ones(S, dtype=float)/S
        else:
            # Fit the user weights to S states (truncate or zero-pad), then normalise.
            w = np.zeros(S, dtype=float)
            w[:min(S, weights.shape[0])] = weights[:min(S, weights.shape[0])]
            w = w / max(w.sum(), EPS)

        if use_js:
            row_scores = np.array([_js_row(P[i], Q[i]) for i in range(S)])
        else:
            row_scores = np.array([_kl_row(P[i], Q[i]) for i in range(S)])
        # Weighted average of the per-state divergences for this time step.
        scores.append(float(np.sum(w * row_scores)))
    return np.array(scores)

def entropy_over_time(T_series: List[np.ndarray]) -> np.ndarray:
    """Return the mean next-state entropy H(S'|S) for every tensor in the series.

    Each tensor is aggregated over actions, the entropy of every resulting row
    is computed, and the rows are averaged with equal weight.
    """
    def _mean_row_entropy(T):
        P = _aggregate_actions(T)
        row_entropies = np.array([_entropy_row(row) for row in P])
        return float(np.mean(row_entropies))

    return np.array([_mean_row_entropy(T) for T in T_series])

def markovization_score(T: np.ndarray, eps: float = EPS) -> float:
    """Return one minus the mean row entropy, normalised by the maximum entropy.

    The maximum is ``log(max(2, n_next_states))``, so the score is 1.0 when
    every aggregated row is deterministic and decreases as rows spread out.
    """
    P = _aggregate_actions(T)
    row_entropies = np.array([_entropy_row(row, eps) for row in P])
    max_entropy = np.log(max(2, P.shape[1]))
    normalized_entropy = np.mean(row_entropies) / max_entropy
    return float(1.0 - normalized_entropy)

# Keep these imports for other uses
from coda_metrics import ref_empirical_from_rollouts, greedy_right_down_policy


In [None]:
# Helper function to create fixed reference function
def _ref_fn_fixed_from_T(T_fixed: np.ndarray):
    """Wrap a fixed transition tensor in the reference-function interface.

    The returned callable ignores its argument and hands back a fresh copy of
    ``T_fixed`` on every call, so callers may modify the result freely.
    """
    def _constant_reference(T_learned):
        # The learned tensor is deliberately unused: the reference never changes.
        reference_copy = T_fixed.copy()
        return reference_copy

    return _constant_reference

# Helper function for degradation GT (uses _build_base_T which is defined later)
def build_gt_degradation_no_clones(env_always_reward) -> np.ndarray:
    """
    Ground truth for contingency degradation (reward independent of cue/history).

    No clones are expected in this condition, so the structural ground truth is
    just the base transition tensor produced by ``_build_base_T`` for the
    given environment.
    """
    base_graph = _build_base_T(env_always_reward)
    return base_graph

# -------- Latent inhibition: seed-level --------
def run_latent_inhibition_seed(seed:int,
                               cfg:CoDAConfig,
                               pre_episodes:int = 500,
                               acq_episodes:int = 1000,
                               max_steps:int = 20,
                               cue:int = 5,
                               threshold: float = 0.3):
    """
    Run one latent-inhibition experiment for a single seed.

    Phase 1 (pre-exposure): ``pre_episodes`` episodes on ``GridEnvRightDownNoCue``;
                            the agent is updated but never asked to split.
    Phase 2 (acquisition):  ``acq_episodes`` episodes on ``GridEnvRightDownNoSelf``
                            with the same agent; ``maybe_split`` is called after
                            every episode, and once it reports a split the
                            episodes are generated from the learned tensor.

    The transition tensor is snapshotted after every episode of both phases and
    all metrics are computed over the whole run against the acquisition ground
    truth with clones. Returns a dict with keys
    ``T_series, T_ref, KL, JS, H, MS``.
    """
    np.random.seed(seed)

    snapshots = []

    # ---- Phase 1: pre-exposure (no split attempts, tensors still recorded) ----
    env_pre = GridEnvRightDownNoCue(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    agent = CoDAAgent(env_pre, cfg)
    for _ in range(pre_episodes):
        states, actions = generate_dataset(env_pre, n_episodes=1, max_steps=max_steps)[0]
        agent.update_with_episode(states, actions)
        snapshots.append(agent.get_T().copy())

    # ---- Phase 2: acquisition on the cued task, reusing the same agent ----
    env_acq = GridEnvRightDownNoSelf(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    agent.env = env_acq  # the agent's accumulated state carries over

    has_split = False
    for _ in range(acq_episodes):
        if has_split:
            episode = generate_dataset_post_augmentation(env_acq, agent.get_T(), n_episodes=1, max_steps=max_steps)[0]
        else:
            episode = generate_dataset(env_acq, n_episodes=1, max_steps=max_steps)[0]
        states, actions = episode

        agent.update_with_episode(states, actions)
        if agent.maybe_split():
            has_split = True

        snapshots.append(agent.get_T().copy())

    # ---- Evaluation against the acquisition ground truth (with clones) ----
    T_ref = build_gt_acquisition_with_clones(env_acq, cue_state=cue)
    ref_fn = _ref_fn_fixed_from_T(T_ref)

    kl_curve = kl_over_time(snapshots, ref_fn, use_js=False, threshold=threshold)
    js_curve = kl_over_time(snapshots, ref_fn, use_js=True, threshold=threshold)
    entropy_curve = entropy_over_time(snapshots)
    markov_curve = np.array([markovization_score(T) for T in snapshots])

    return {
        "T_series": snapshots,
        "T_ref": T_ref,
        "KL": kl_curve,
        "JS": js_curve,
        "H": entropy_curve,
        "MS": markov_curve,
    }

# -------- Normal acquisition (for comparison with latent inhibition) --------
def run_normal_acquisition_seed(seed:int,
                                cfg:CoDAConfig,
                                acq_episodes:int = 1000,
                                max_steps:int = 20,
                                cue:int = 5,
                                threshold: float = 0.3):
    """
    Run plain acquisition (no pre-exposure) for a single seed.

    Serves as the baseline for the latent-inhibition runs: a fresh agent is
    trained for ``acq_episodes`` episodes on ``GridEnvRightDownNoSelf``,
    ``maybe_split`` is called after every episode, and once it reports a split
    the episodes are generated from the learned tensor. Returns a dict with keys
    ``T_series, T_ref, KL, JS, H, MS``.
    """
    np.random.seed(seed)

    env_acq = GridEnvRightDownNoSelf(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    agent = CoDAAgent(env_acq, cfg)

    snapshots = []
    has_split = False
    for _ in range(acq_episodes):
        if has_split:
            episode = generate_dataset_post_augmentation(env_acq, agent.get_T(), n_episodes=1, max_steps=max_steps)[0]
        else:
            episode = generate_dataset(env_acq, n_episodes=1, max_steps=max_steps)[0]
        states, actions = episode

        agent.update_with_episode(states, actions)
        if agent.maybe_split():
            has_split = True

        snapshots.append(agent.get_T().copy())

    # Ground truth: acquisition graph with clones
    T_ref = build_gt_acquisition_with_clones(env_acq, cue_state=cue)
    ref_fn = _ref_fn_fixed_from_T(T_ref)

    kl_curve = kl_over_time(snapshots, ref_fn, use_js=False, threshold=threshold)
    js_curve = kl_over_time(snapshots, ref_fn, use_js=True, threshold=threshold)
    entropy_curve = entropy_over_time(snapshots)
    markov_curve = np.array([markovization_score(T) for T in snapshots])

    return {
        "T_series": snapshots,
        "T_ref": T_ref,
        "KL": kl_curve,
        "JS": js_curve,
        "H": entropy_curve,
        "MS": markov_curve,
    }

# -------- Contingency degradation: seed-level --------
def run_contingency_degradation_seed(seed:int,
                                     cfg:CoDAConfig,
                                     acq_episodes:int = 1000,
                                     degr_episodes:int = 1000,
                                     max_steps:int = 20,
                                     cue:int = 5,
                                     threshold: float = 0.3,
                                     wash_in:int = 50,
                                     edge_eps_early:float = 1e-4,
                                     edge_eps_late:float  = 1e-6):
    """
    Run acquisition followed by contingency degradation for a single seed.

    Phase 1 (acquisition):      normal cued task (splits form).
    Phase 2 (degradation):      reward at terminal regardless of cue/history (always reward).
                                The same agent continues and ``maybe_merge`` is
                                called after every episode.
    Metrics are reported separately for acq (vs acq GT-with-clones) and degr
    (vs degr GT no-clones).

    Args:
        seed: value passed to ``np.random.seed``.
        cfg: agent configuration handed to ``CoDAAgent``.
        acq_episodes / degr_episodes: number of episodes in each phase.
        max_steps: per-episode step limit passed to the dataset generators.
        cue: cue state index.
        threshold: adjacency threshold forwarded to ``kl_over_time``.
        wash_in: number of initial degradation episodes that run with the
            temporary merge-friendly config overrides and ``edge_eps_early``.
            ``wash_in <= 0`` disables the wash-in entirely.
        edge_eps_early / edge_eps_late: values written to
            ``agent._edge_eps_override`` during / after the wash-in.

    Returns:
        dict with the two environments, both tensor series, both references
        and the ``KL/JS/H/MS`` curves for each phase (``*_acq`` / ``*_deg``).

    The temporary overrides on ``agent.cfg`` are always undone before this
    function returns, including when the wash-in is longer than the
    degradation phase or an exception interrupts the loop, so a config object
    shared between seeds is never left modified.
    """
    np.random.seed(seed)

    # -------- Acquisition --------
    env_acq = GridEnvRightDownNoSelf(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    agent   = CoDAAgent(env_acq, cfg)

    T_series_acq = []
    with_clones = False

    for _ in range(acq_episodes):
        if with_clones:
            (states, actions) = generate_dataset_post_augmentation(env_acq, agent.get_T(), n_episodes=1, max_steps=max_steps)[0]
        else:
            (states, actions) = generate_dataset(env_acq, n_episodes=1, max_steps=max_steps)[0]

        agent.update_with_episode(states, actions)
        if agent.maybe_split():
            with_clones = True

        T_series_acq.append(agent.get_T().copy())

    T_ref_acq = build_gt_acquisition_with_clones(env_acq, cue_state=cue)
    ref_fn_acq = _ref_fn_fixed_from_T(T_ref_acq)

    # -------- Degradation (always reward; no reset) --------
    # NOTE(review): GridEnvRightDownAlwaysReward is not among the names imported
    # at the top of this file - confirm it is defined/imported before this runs.
    env_deg = GridEnvRightDownAlwaysReward(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    # carry learned clones into degradation env so merges are meaningful
    env_deg.clone_dict = dict(getattr(env_acq, "clone_dict", {}))
    env_deg.reverse_clone_dict = dict(getattr(env_acq, "reverse_clone_dict", {}))
    agent.env = env_deg

    # short wash-in to encourage structural merges early (optional)
    wash_in_overrides = dict(
        count_decay=0.98,
        trace_decay=0.98,
        retro_decay=0.98,
        theta_merge=0.60,
        confidence=0.99,
        min_presence_episodes=agent.cfg.min_presence_episodes + 3,
        min_effective_exposure=int(agent.cfg.min_effective_exposure * 1.5),
    )
    # Remember the values being replaced so they can be put back afterwards.
    saved_cfg = {key: getattr(agent.cfg, key) for key in wash_in_overrides}

    def _restore_cfg():
        for key, val in saved_cfg.items():
            setattr(agent.cfg, key, val)

    # With no wash-in requested the overrides must not be applied at all;
    # otherwise they would stay active for the whole phase and beyond.
    overrides_active = wash_in > 0
    if overrides_active:
        for key, val in wash_in_overrides.items():
            setattr(agent.cfg, key, val)

    T_series_deg = []
    try:
        for k in range(degr_episodes):
            (states, actions) = generate_dataset_post_augmentation(env_deg, agent.get_T(), n_episodes=1, max_steps=max_steps)[0]
            agent.update_with_episode(states, actions)
            agent._edge_eps_override = edge_eps_early if k < wash_in else edge_eps_late
            agent.maybe_merge()
            T_series_deg.append(agent.get_T().copy())

            if overrides_active and k == wash_in - 1:
                # restore original cfg after wash-in
                _restore_cfg()
                overrides_active = False
    finally:
        # Covers wash_in > degr_episodes and early exits caused by exceptions.
        if overrides_active:
            _restore_cfg()

    T_ref_deg = build_gt_degradation_no_clones(env_deg)
    ref_fn_deg = _ref_fn_fixed_from_T(T_ref_deg)

    # Metrics (separate for each phase) with threshold
    KL_acq = kl_over_time(T_series_acq, ref_fn_acq, use_js=False, threshold=threshold)
    JS_acq = kl_over_time(T_series_acq, ref_fn_acq, use_js=True, threshold=threshold)
    H_acq  = entropy_over_time(T_series_acq)
    MS_acq = np.array([markovization_score(T) for T in T_series_acq])

    KL_deg = kl_over_time(T_series_deg, ref_fn_deg, use_js=False, threshold=threshold)
    JS_deg = kl_over_time(T_series_deg, ref_fn_deg, use_js=True, threshold=threshold)
    H_deg  = entropy_over_time(T_series_deg)
    MS_deg = np.array([markovization_score(T) for T in T_series_deg])

    return dict(
        env_acq=env_acq, env_deg=env_deg,
        T_series_acq=T_series_acq, T_series_deg=T_series_deg,
        T_ref_acq=T_ref_acq, T_ref_deg=T_ref_deg,
        KL_acq=KL_acq, JS_acq=JS_acq, H_acq=H_acq, MS_acq=MS_acq,
        KL_deg=KL_deg, JS_deg=JS_deg, H_deg=H_deg, MS_deg=MS_deg
    )

# -------- Latent inhibition: multi-seed --------
def run_latent_inhibition_many(cfg:CoDAConfig,
                               seeds:list,
                               pre_episodes:int=500,
                               acq_episodes:int=1000,
                               max_steps:int=20,
                               cue:int=5,
                               threshold: float = 0.3):
    """Run the latent-inhibition experiment once per seed and gather the curves.

    Returns a dict holding the per-seed result dicts under ``"runs"`` plus, for
    each of ``KL``, ``JS``, ``H`` and ``MS``, the list of per-seed curves.
    """
    runs = []
    for s in seeds:
        runs.append(run_latent_inhibition_seed(s, cfg, pre_episodes, acq_episodes, max_steps, cue, threshold))

    summary = {"runs": runs}
    for metric in ("KL", "JS", "H", "MS"):
        summary[metric] = [r[metric] for r in runs]
    return summary

# -------- Normal acquisition: multi-seed (for comparison) --------
def run_normal_acquisition_many(cfg:CoDAConfig,
                                seeds:list,
                                acq_episodes:int=1000,
                                max_steps:int=20,
                                cue:int=5,
                                threshold: float = 0.3):
    """Run plain acquisition once per seed and gather the metric curves.

    Returns a dict holding the per-seed result dicts under ``"runs"`` plus, for
    each of ``KL``, ``JS``, ``H`` and ``MS``, the list of per-seed curves.
    """
    runs = []
    for s in seeds:
        runs.append(run_normal_acquisition_seed(s, cfg, acq_episodes, max_steps, cue, threshold))

    summary = {"runs": runs}
    for metric in ("KL", "JS", "H", "MS"):
        summary[metric] = [r[metric] for r in runs]
    return summary

def plot_latent_inhibition_summary(res_li, pre_episodes:int, title_suffix=""):
    """Plot mean ± SE bands of KL, JS, entropy and Markovization for latent inhibition.

    One figure is produced per metric, styled like ``plot_band``. Each figure
    carries a dashed vertical line at ``pre_episodes`` marking the switch from
    pre-exposure to acquisition.

    The marker is drawn before ``plt.show()`` is called. Adding it afterwards
    (by looking figures up through ``plt.get_fignums()``) comes too late for
    the displayed output and may touch unrelated figures.

    Args:
        res_li: dict with per-seed curve lists under ``"KL"``, ``"JS"``,
            ``"H"`` and ``"MS"``.
        pre_episodes: x position of the phase-transition marker.
        title_suffix: text appended inside each figure title.
    """
    panels = (
        ("KL", f"KL (latent inhibition{title_suffix})", "KL (nats)"),
        ("JS", f"JS (latent inhibition{title_suffix})", "JS"),
        ("H",  f"Avg H(S'|S) (latent inhibition{title_suffix})", "nats"),
        ("MS", f"Markovization (latent inhibition{title_suffix})", "[0,1]"),
    )
    for key, title, ylabel in panels:
        y_runs = res_li[key]
        mean, se = mean_se(y_runs)
        x = np.arange(len(mean))
        plt.figure(figsize=(10, 4))
        plt.plot(x, mean, lw=2.2, label=f"mean ({len(y_runs)} seeds)")
        plt.fill_between(x, mean-se, mean+se, alpha=0.25, label="±1 SE")
        # vertical line to show transition from pre-exposure to acquisition
        plt.axvline(pre_episodes, ls="--", alpha=0.4)
        plt.title(title)
        plt.xlabel("Episode")
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(alpha=0.2)
        plt.show()

def plot_latent_inhibition_comparison(res_li, res_normal, pre_episodes:int, acq_episodes:int, title_suffix=""):
    """
    Overlay the KL curves of normal acquisition and of latent inhibition.

    Normal acquisition is drawn from episode 0. For latent inhibition only the
    part of each series after the first ``pre_episodes`` entries is used, and
    it is drawn starting at x = ``pre_episodes``; series that are not longer
    than ``pre_episodes`` are dropped. A dashed line marks the end of
    pre-exposure. ``acq_episodes`` is accepted but not used here.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    def _draw_band(xs, center, spread, fmt, color, label):
        # Mean curve plus a shaded ±SE envelope in the matching colour.
        ax.plot(xs, center, fmt, lw=2.2, label=label)
        ax.fill_between(xs, center - spread, center + spread, alpha=0.25, color=color)

    # Normal acquisition: starts learning immediately.
    mean_normal, se_normal = mean_se(res_normal["KL"])
    _draw_band(np.arange(len(mean_normal)), mean_normal, se_normal,
               'b-', 'blue', f"Normal acquisition (n={len(res_normal['KL'])})")

    # Latent inhibition: keep only the acquisition part of each run.
    li_acq_runs = [series[pre_episodes:] for series in res_li["KL"] if len(series) > pre_episodes]
    mean_li, se_li = mean_se(li_acq_runs)
    shifted_x = np.arange(len(mean_li)) + pre_episodes
    _draw_band(shifted_x, mean_li, se_li,
               'r-', 'red', f"Latent inhibition (n={len(res_li['KL'])})")

    # Mark where pre-exposure ends.
    ax.axvline(pre_episodes, ls="--", color='gray', alpha=0.5, label="Pre-exposure end")

    ax.set_xlabel("Episode")
    ax.set_ylabel("KL (nats)")
    ax.set_title(f"Learning Speed Comparison: Latent Inhibition vs Normal Acquisition{title_suffix}")
    ax.legend()
    ax.grid(alpha=0.2)
    plt.show()

# -------- Degradation: multi-seed --------
def run_degradation_many(cfg:CoDAConfig,
                         seeds:list,
                         acq_episodes:int=1000,
                         degr_episodes:int=1000,
                         max_steps:int=20,
                         cue:int=5,
                         threshold: float = 0.3):
    """Run the contingency-degradation experiment once per seed and gather the curves.

    Returns a dict holding the per-seed result dicts under ``"runs"`` plus the
    list of per-seed curves for every metric of both phases
    (``KL/JS/H/MS`` with ``_acq`` and ``_deg`` suffixes).
    """
    runs = []
    for s in seeds:
        runs.append(run_contingency_degradation_seed(s, cfg, acq_episodes, degr_episodes, max_steps, cue, threshold))

    summary = {"runs": runs}
    for phase in ("acq", "deg"):
        for metric in ("KL", "JS", "H", "MS"):
            key = f"{metric}_{phase}"
            summary[key] = [r[key] for r in runs]
    return summary

def plot_degradation_summary(res_deg, acq_episodes:int, title_suffix=""):
    """Plot mean ± SE bands for both phases of the contingency-degradation runs.

    Four figures (KL, JS, entropy, Markovization) are drawn for the acquisition
    phase via ``plot_band``, followed by four for the degradation phase.

    The degradation figures are drawn here directly so that the dashed
    transition marker can be added before ``plt.show()``; adding it afterwards
    comes too late for the displayed output. Their x axis is offset by
    ``acq_episodes`` so the marker sits at the first degradation episode, the
    same convention used by ``plot_degradation_kl_over_time``.

    Args:
        res_deg: dict with per-seed curve lists under ``KL/JS/H/MS`` keys
            suffixed ``_acq`` and ``_deg``.
        acq_episodes: number of acquisition episodes preceding degradation.
        title_suffix: text appended inside each figure title.
    """
    plot_band(res_deg["KL_acq"], f"KL (acquisition — degradation runs{title_suffix})", "KL (nats)")
    plot_band(res_deg["JS_acq"], f"JS (acquisition — degradation runs{title_suffix})", "JS")
    plot_band(res_deg["H_acq"],  f"Avg H(S'|S) (acquisition — degradation runs{title_suffix})", "nats")
    plot_band(res_deg["MS_acq"], f"Markovization (acquisition — degradation runs{title_suffix})", "[0,1]")

    deg_panels = (
        ("KL_deg", f"KL (contingency degradation{title_suffix})", "KL (nats)"),
        ("JS_deg", f"JS (contingency degradation{title_suffix})", "JS"),
        ("H_deg",  f"Avg H(S'|S) (contingency degradation{title_suffix})", "nats"),
        ("MS_deg", f"Markovization (contingency degradation{title_suffix})", "[0,1]"),
    )
    for key, title, ylabel in deg_panels:
        y_runs = res_deg[key]
        mean, se = mean_se(y_runs)
        # Degradation episodes continue the episode count of the acquisition phase.
        x = np.arange(len(mean)) + acq_episodes
        plt.figure(figsize=(10, 4))
        plt.plot(x, mean, lw=2.2, label=f"mean ({len(y_runs)} seeds)")
        plt.fill_between(x, mean-se, mean+se, alpha=0.25, label="±1 SE")
        # vertical line to show transition from acquisition to degradation
        plt.axvline(acq_episodes, ls="--", alpha=0.4)
        plt.title(title)
        plt.xlabel("Episode")
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(alpha=0.2)
        plt.show()

def plot_degradation_kl_over_time(res_deg, acq_episodes:int, title_suffix=""):
    """
    Draw the acquisition and degradation KL curves on one shared episode axis.

    The acquisition curve starts at episode 0; the degradation curve is shifted
    right by ``acq_episodes`` and a dashed line marks the switch between phases.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    def _draw_band(xs, center, spread, fmt, color, label):
        # Mean curve plus a shaded ±SE envelope in the matching colour.
        ax.plot(xs, center, fmt, lw=2.2, label=label)
        ax.fill_between(xs, center - spread, center + spread, alpha=0.25, color=color)

    # Acquisition phase
    mean_acq, se_acq = mean_se(res_deg["KL_acq"])
    _draw_band(np.arange(len(mean_acq)), mean_acq, se_acq,
               'b-', 'blue', f"Acquisition (n={len(res_deg['KL_acq'])})")

    # Degradation phase, continuing the episode count
    mean_deg, se_deg = mean_se(res_deg["KL_deg"])
    _draw_band(np.arange(len(mean_deg)) + acq_episodes, mean_deg, se_deg,
               'r-', 'red', f"Degradation (n={len(res_deg['KL_deg'])})")

    # Phase boundary
    ax.axvline(acq_episodes, ls="--", color='gray', alpha=0.5, label="Acquisition → Degradation")

    ax.set_xlabel("Episode")
    ax.set_ylabel("KL (nats)")
    ax.set_title(f"KL Over Time: Contingency Degradation{title_suffix}")
    ax.legend()
    ax.grid(alpha=0.2)
    plt.show()

# Helper function for mean ± SE (if not already defined)
def mean_se(arrs):
    """Compute the per-position mean and standard error across runs.

    Runs may have different lengths; shorter runs simply do not contribute to
    positions beyond their end (NaN values inside a run are ignored the same
    way).

    Args:
        arrs: iterable of 1-D sequences, one per run.

    Returns:
        ``(mean, se)`` as two 1-D float arrays whose length is that of the
        longest run. For each position with ``n`` contributing values the
        standard error is ``std(ddof=1) / sqrt(n)``; it is 0.0 where only one
        value contributes and NaN (like the mean) where none does. Two empty
        arrays are returned when ``arrs`` is empty.
    """
    rows = [np.asarray(a, dtype=float) for a in arrs]
    if not rows:
        return np.array([]), np.array([])

    width = max(len(r) for r in rows)
    M = np.full((len(rows), width), np.nan)
    for i, r in enumerate(rows):
        M[i, :len(r)] = r

    valid = ~np.isnan(M)
    n = valid.sum(axis=0)  # number of runs contributing to each position

    mean = np.full(width, np.nan)
    np.divide(np.where(valid, M, 0.0).sum(axis=0), n, out=mean, where=n > 0)

    # Sample variance needs at least two observations; with one it is taken as 0.
    squared_dev = np.where(valid, M - mean, 0.0) ** 2
    var = np.zeros(width)
    np.divide(squared_dev.sum(axis=0), n - 1, out=var, where=n > 1)

    # Divide by the number of values actually present at each position, not by
    # the total number of runs.
    se = np.sqrt(var / np.maximum(n, 1))
    se[n == 0] = np.nan
    return mean, se

def plot_band(y_runs, title, ylabel):
    """Draw the across-run mean with a shaded ±1 SE envelope in a new figure and show it."""
    center, spread = mean_se(y_runs)
    episodes = np.arange(len(center))

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(episodes, center, lw=2.2, label=f"mean ({len(y_runs)} seeds)")
    ax.fill_between(episodes, center - spread, center + spread, alpha=0.25, label="±1 SE")
    ax.set_title(title)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(alpha=0.2)
    plt.show()

In [None]:
# import numpy as np

def sanitize_for_plot(env, T, eps=1e-12):
    """
    Prune clone bookkeeping on ``env`` so it matches the tensor about to be plotted.

    A clone id is removed from ``env.clone_dict`` (in place) when it lies
    outside ``T`` or when its combined inbound and outbound mass in ``T`` is
    not above ``eps``. ``env.reverse_clone_dict`` is then rebuilt from the
    surviving entries. Nothing happens when ``T`` is ``None`` or not 3-D.
    """
    is_tensor = T is not None and getattr(T, "ndim", 0) == 3
    if not is_tensor:
        return

    n_states = T.shape[0]

    # Total mass touching each state: leaving it (over actions and targets)
    # plus entering it (over sources and actions).
    touching_mass = T.sum(axis=(1, 2)) + T.sum(axis=(0, 1))
    is_active = touching_mass > eps

    stale_clones = [cl for cl in list(env.clone_dict.keys())
                    if cl >= n_states or not is_active[cl]]
    for cl in stale_clones:
        env.clone_dict.pop(cl, None)

    # parent -> clone; with several clones per parent the last one iterated wins
    rebuilt = {}
    for cl, parent in env.clone_dict.items():
        rebuilt[parent] = cl
    env.reverse_clone_dict = rebuilt

def make_terminals_absorbing_for_plot(T, terminals):
    """Return a copy of ``T`` in which every outgoing transition of a terminal is zeroed.

    Terminal ids that are not smaller than the number of states in ``T`` are
    ignored. The input tensor is left untouched.
    """
    result = T.copy()
    n_states = result.shape[0]
    for terminal in terminals:
        if terminal >= n_states:
            continue
        result[terminal] = 0.0
    return result
def thresh_adj(T, thr=0.3):
    """Return the binary (uint8) adjacency of ``T``: 1 where the action-summed mass is >= ``thr``."""
    mass_per_edge = T.sum(axis=1)  # [S,S]
    return np.greater_equal(mass_per_edge, thr).astype(np.uint8)

def clone_dict_tuple(d):
    """Return the items of ``d`` as a sorted tuple of ``(key, value)`` pairs."""
    pairs = list(d.items())
    pairs.sort()
    return tuple(pairs)

def graph_changed(prev_T, prev_map, curr_T, curr_map, thr=0.3):
    """Report whether the graph differs from the previous snapshot.

    It counts as changed when there is no previous tensor, when the tensor
    shape changed, when the thresholded adjacency (see ``thresh_adj``)
    changed, or when the clone mapping changed.
    """
    if prev_T is None:
        return True
    if prev_T.shape != curr_T.shape:
        return True

    before = thresh_adj(prev_T, thr)
    after = thresh_adj(curr_T, thr)
    if before.shape != after.shape or np.any(before != after):
        return True

    # Same edges: only a different clone mapping can still make it a change.
    return prev_map != curr_map



In [None]:

def _n_actions(env):
    """Return the size of the action axis: the largest action id in ``env.valid_actions`` plus one."""
    all_action_ids = [a for acts in env.valid_actions.values() for a in acts]
    return max(all_action_ids) + 1

def _base_successor(env, s, a):
    """Return the state reached from ``s`` by the displacement of action ``a``.

    If the displaced position is not a key of ``env.pos_to_state`` the state
    is returned unchanged.
    """
    row, col = env.state_to_pos[s]
    d_row, d_col = env.base_actions[a]
    target = (row + d_row, col + d_col)
    return env.pos_to_state.get(target, s)

def _build_base_T(env):
    """Build the deterministic base transition tensor (no clones) for ``env``.

    The result has shape ``(num_unique_states, n_actions, num_unique_states)``.
    Every valid action of a non-terminal state puts probability 1.0 on its
    base successor; rows of rewarded and unrewarded terminals stay all-zero.
    """
    n_states = env.num_unique_states
    n_actions = _n_actions(env)
    base = np.zeros((n_states, n_actions, n_states), dtype=float)
    terminal_states = set(env.rewarded_terminals).union(env.unrewarded_terminals)

    for state, actions in env.valid_actions.items():
        if state not in terminal_states:
            for action in actions:
                base[state, action, _base_successor(env, state, action)] = 1.0
    return base

def _descendants_until_terminal(env, start, terminals):
    """Return the set of non-terminal states reachable from ``start``.

    The search follows base successors of valid actions breadth-first and never
    enters or expands a state in ``terminals``. ``start`` itself is included
    only if some path leads back to it.
    """
    successors = {}
    for s in range(env.num_unique_states):
        successors[s] = [_base_successor(env, s, a) for a in env.valid_actions.get(s, [])]

    reached = set()
    frontier = [start]
    cursor = 0  # read position in the FIFO frontier
    while cursor < len(frontier):
        current = frontier[cursor]
        cursor += 1
        for nxt in successors.get(current, []):
            if nxt in terminals or nxt in reached:
                continue
            reached.add(nxt)
            frontier.append(nxt)
    return reached

def build_gt_acquisition_with_clones(env, cue_state: int) -> np.ndarray:
    """Build the acquisition ground-truth tensor in which the post-cue path is cloned.

    Starting from the base tensor, every non-terminal state reachable from
    ``cue_state`` receives a clone appended after the base states (clone ids
    follow the sorted order of the original ids). The cue's transitions are
    rerouted into the cloned path, clones lead to the rewarded terminal(s),
    and the original (un-cued) copies of those states are rerouted from a
    rewarded terminal to the unrewarded terminal at the same list position.

    Args:
        env: environment exposing ``num_unique_states``, ``valid_actions``,
            ``rewarded_terminals`` and ``unrewarded_terminals``.
        cue_state: index of the cue state.

    Returns:
        Array of shape ``(S, A, S)`` with ``S = num_unique_states + number of clones``.
    """
    S0 = env.num_unique_states
    A  = _n_actions(env)
    T0 = _build_base_T(env)

    terminals = set(env.rewarded_terminals) | set(env.unrewarded_terminals)
    # assumes both are lists aligned by position (``.index`` and positional
    # lookup are used below) - TODO confirm against the environment class
    rewT = env.rewarded_terminals
    unrewT = env.unrewarded_terminals

    # Non-terminal states downstream of the cue; each one gets its own clone.
    D = _descendants_until_terminal(env, cue_state, terminals)
    clone_of = {orig: S0 + k for k, orig in enumerate(sorted(D))}
    S = S0 + len(clone_of)
    T = np.zeros((S, A, S), dtype=float)
    # Base dynamics occupy the top-left block; clone rows/columns start empty.
    T[:S0, :, :S0] = T0

    # Leaving the cue now enters the cloned copy of the successor.
    for a in env.valid_actions[cue_state]:
        sp = _base_successor(env, cue_state, a)
        if sp in D:
            T[cue_state, a, sp] = 0.0
            T[cue_state, a, clone_of[sp]] = 1.0

    for orig, cl in clone_of.items():
        # Clone dynamics: stay inside the cloned path, and end in reward.
        for a in env.valid_actions[orig]:
            sp = _base_successor(env, orig, a)
            if sp in terminals:
                # NOTE(review): with more than one rewarded terminal this row
                # receives 1.0 for each of them and no longer sums to 1 - confirm
                # only single-terminal environments are used.
                for t in rewT:   T[cl, a, t] = 1.0
                for t in unrewT: T[cl, a, t] = 0.0
            else:
                T[cl, a, clone_of[sp] if sp in clone_of else sp] = 1.0

        # Original (un-cued) copy: a step into a rewarded terminal is redirected
        # to the unrewarded terminal with the same list index.
        for a in env.valid_actions[orig]:
            sp = _base_successor(env, orig, a)
            if sp in rewT:
                T[orig, a, sp] = 0.0
                idx = rewT.index(sp)
                T[orig, a, unrewT[idx]] = 1.0
    return T

def build_gt_extinction_no_clones(env2) -> np.ndarray:
    """Build the extinction ground truth at base size (no clones).

    Starts from the base tensor and redirects every transition that would
    enter a rewarded terminal to the unrewarded terminal stored at the same
    list position.
    """
    T = _build_base_T(env2)
    rewarded = env2.rewarded_terminals
    for state in range(env2.num_unique_states):
        for action in env2.valid_actions.get(state, []):
            successor = _base_successor(env2, state, action)
            if successor not in rewarded:
                continue
            replacement = env2.unrewarded_terminals[rewarded.index(successor)]
            T[state, action, successor] = 0.0
            T[state, action, replacement] = 1.0
    return T

def build_gt_extinction_with_clone_shape(env2, T_acq_shape: tuple) -> np.ndarray:
    """
    Build the extinction reference embedded in a tensor of the acquisition shape.

    The base-size extinction reference fills the leading block; every entry
    belonging to a clone index (rows and columns from ``num_unique_states``
    upwards) stays zero. Base rows with positive mass are renormalised to sum
    to one.
    """
    n_states, n_actions, n_next = T_acq_shape
    n_base = env2.num_unique_states

    base_extinction = build_gt_extinction_no_clones(env2)

    expanded = np.zeros((n_states, n_actions, n_next), dtype=float)
    expanded[:n_base, :, :n_base] = base_extinction

    # Only base rows can carry mass; clone rows are intentionally left empty.
    for s, a in np.ndindex(n_base, n_actions):
        total = expanded[s, a, :].sum()
        if total > 0:
            expanded[s, a, :] /= total

    return expanded

def _pad3(arr: np.ndarray, shape: tuple) -> np.ndarray:
    """Zero-pad or truncate a 3-D array to ``shape``.

    When the shape already matches, the input array itself is returned (no
    copy); otherwise a new float array holding the overlapping region is built.
    """
    src_s, src_a, src_s2 = arr.shape
    dst_s, dst_a, dst_s2 = shape
    if src_s == dst_s and src_a == dst_a and src_s2 == dst_s2:
        return arr

    keep_s, keep_a, keep_s2 = min(src_s, dst_s), min(src_a, dst_a), min(src_s2, dst_s2)
    padded = np.zeros((dst_s, dst_a, dst_s2), dtype=float)
    padded[:keep_s, :keep_a, :keep_s2] = arr[:keep_s, :keep_a, :keep_s2]
    return padded

def _ref_fn_fixed(T_fixed: np.ndarray):
    """
    Build a reference function that always yields ``T_fixed`` at its own shape.

    The callable ignores the learned tensor it is given and returns a fresh
    copy each time; matching the learned tensor to this shape is left to the
    caller.
    """
    def _constant_reference(T_learned):
        # No reshaping here on purpose: the reference keeps its original shape.
        reference_copy = T_fixed.copy()
        return reference_copy

    return _constant_reference


In [None]:

def ref_builder_factory(env, policy_fn, nroll=300, max_steps=20):
    """Return a reference function that calls ``ref_empirical_from_rollouts`` on every invocation.

    The learned tensor passed to the returned callable is ignored; ``nroll``
    and ``max_steps`` are forwarded as ``n_episodes`` and ``max_steps``.
    """
    rollout_kwargs = dict(n_episodes=nroll, max_steps=max_steps)

    def _sample_reference(T_learned):
        return ref_empirical_from_rollouts(env, policy_fn, **rollout_kwargs)

    return _sample_reference

def collapse_to_base_size(T, env):
    """
    Collapse a transition tensor onto the base (non-clone) states.

    Every positive entry ``T[s, a, s']`` is re-attributed to the base state of
    ``s`` and ``s'``: indices below ``env.num_unique_states`` map to
    themselves, and clone indices map to their parent in ``env.clone_dict``.
    Mass whose source or target has no base mapping is dropped. That covers a
    clone missing from ``clone_dict`` and a clone whose recorded parent is not
    a valid base index (for example another clone).

    Args:
        T: array of shape ``(S, A, S')``.
        env: object exposing ``num_unique_states`` (int) and ``clone_dict``
            (mapping clone_id -> parent_id).

    Returns:
        Float array of shape ``(num_unique_states, A, num_unique_states)``
        in which every non-empty ``(s, a)`` row is rescaled to sum to 1.
    """
    S_base = env.num_unique_states
    n_src, A, n_dst = T.shape
    T_collapsed = np.zeros((S_base, A, S_base), dtype=float)

    # Lookup table: state index -> base index, with -1 meaning "no base state".
    n_ids = max(n_src, n_dst)
    base_of = np.full(n_ids, -1, dtype=np.intp)
    n_own = min(S_base, n_ids)
    base_of[:n_own] = np.arange(n_own)
    for clone_id, parent_id in env.clone_dict.items():
        # Only real clone slots with a valid base parent are redirected.
        if S_base <= clone_id < n_ids and 0 <= parent_id < S_base:
            base_of[clone_id] = parent_id

    # "not (mass <= 0)" rather than "mass > 0" so NaN entries still propagate
    # into the result instead of being silently discarded.
    src, act, dst = np.nonzero(~(T <= 0))
    src_base, dst_base = base_of[src], base_of[dst]
    keep = (src_base >= 0) & (dst_base >= 0)

    # Unbuffered scatter-add: several clones may land on the same base cell.
    np.add.at(
        T_collapsed,
        (src_base[keep], act[keep], dst_base[keep]),
        T[src[keep], act[keep], dst[keep]],
    )

    # Renormalise non-empty rows so each (s, a) stays a distribution.
    row_sums = T_collapsed.sum(axis=2, keepdims=True)
    np.divide(T_collapsed, row_sums, out=T_collapsed, where=row_sums > 0)

    return T_collapsed

def run_one_seed(seed: int):
    """
    Run acquisition followed by extinction for one seed and compute metrics.

    Relies on notebook-level globals: ``CUE``, ``cfg``, ``N_ACQ``, ``N_EXT``,
    ``MAX_STEPS``, ``THRESH`` and ``SEED0`` (debug output is printed only for
    ``seed == SEED0``).

    Args:
        seed: value passed to ``np.random.seed``.

    Returns:
        dict with per-episode series ``KL_acq``, ``JS_acq``, ``H_acq``,
        ``MS_acq``, ``KL_ext``, ``JS_ext``, ``H_ext``, ``MS_ext`` and the last
        transition snapshots ``T_acq_final`` / ``T_ext_final``. A final
        snapshot is ``None`` when its phase ran zero episodes.
    """
    np.random.seed(seed)

    env = GridEnvRightDownNoSelf(cue_states=[CUE], env_size=(4,4), rewarded_terminal=[15])

    T_ref_acq_fixed = build_gt_acquisition_with_clones(env, cue_state=CUE)
    ref_fn_acq = _ref_fn_fixed(T_ref_acq_fixed)

    agent = CoDAAgent(env, cfg)

    # ---- Acquisition: one episode per iteration, snapshot T after each ----
    T_series_acq: List[np.ndarray] = []
    with_clones = False
    for _ in range(1, N_ACQ+1):
        if with_clones:
            (states, actions) = generate_dataset_post_augmentation(env, agent.get_T(), n_episodes=1, max_steps=MAX_STEPS)[0]
        else:
            (states, actions) = generate_dataset(env, n_episodes=1, max_steps=MAX_STEPS)[0]
        agent.update_with_episode(states, actions)
        if agent.maybe_split():
            with_clones = True
        T_series_acq.append(agent.get_T().copy())

    # ---- Extinction: swap the environment, carrying over the clone maps ----
    env2 = GridEnvRightDownNoCue(cue_states=[CUE], env_size=(4,4), rewarded_terminal=[15])
    env2.clone_dict = dict(getattr(env, "clone_dict", {}))
    env2.reverse_clone_dict = dict(getattr(env, "reverse_clone_dict", {}))
    agent.env = env2

    # Extinction reference shares the acquisition reference's shape; clone slots stay zero.
    T_ref_ext_fixed = build_gt_extinction_with_clone_shape(env2, T_ref_acq_fixed.shape)
    assert T_ref_ext_fixed.shape == T_ref_acq_fixed.shape, f"Extinction ref shape {T_ref_ext_fixed.shape} != acquisition ref shape {T_ref_acq_fixed.shape}"
    ref_fn_ext = _ref_fn_fixed(T_ref_ext_fixed)

    T_series_ext: List[np.ndarray] = []
    for _ in range(N_ACQ+1, N_ACQ+N_EXT+1):
        (states, actions) = generate_dataset_post_augmentation(env2, agent.get_T(), n_episodes=1, max_steps=MAX_STEPS)[0]
        agent.update_with_episode(states, actions)
        agent.maybe_merge()
        # Keep the full-size tensor (no collapsing) so the series stays comparable over time.
        T_series_ext.append(agent.get_T().copy())

    # ref_fn_acq = ref_builder_factory(env,  greedy_right_down_policy, nroll=N_ROLL_REF, max_steps=MAX_STEPS)
    # ref_fn_ext = ref_builder_factory(env2, greedy_right_down_policy, nroll=N_ROLL_REF, max_steps=MAX_STEPS)

    # Acquisition metrics are thresholded at THRESH (passed through to kl_over_time).
    KL_acq = kl_over_time(T_series_acq, ref_fn_acq, use_js=False, threshold=THRESH)
    JS_acq = kl_over_time(T_series_acq, ref_fn_acq, use_js=True, threshold=THRESH)
    H_acq  = entropy_over_time(T_series_acq)
    MS_acq = np.array([markovization_score(T) for T in T_series_acq])

    # Sanity check: the extinction ref function must return the acquisition reference's shape.
    if len(T_series_ext) > 0:
        test_ref = ref_fn_ext(T_series_ext[0])
        assert test_ref.shape == T_ref_acq_fixed.shape, f"Extinction ref function returns shape {test_ref.shape}, expected {T_ref_acq_fixed.shape}"
        if seed == SEED0:  # Only print for first seed to avoid spam
            print(f"Debug (seed {seed}): T_ref_ext_fixed shape = {T_ref_ext_fixed.shape}")
            print(f"Debug (seed {seed}): T_ref_acq_fixed shape = {T_ref_acq_fixed.shape}")
            print(f"Debug (seed {seed}): First T_ext shape = {T_series_ext[0].shape}")
            print(f"Debug (seed {seed}): Last T_ext shape = {T_series_ext[-1].shape}")
            print(f"Debug (seed {seed}): ref_fn_ext returns shape = {test_ref.shape}")

    # Extinction metrics are restricted to base states because the reference
    # has all-zero clone rows.
    KL_ext = kl_over_time(T_series_ext, ref_fn_ext, use_js=False, base_states_only=True)
    JS_ext = kl_over_time(T_series_ext, ref_fn_ext, use_js=True, base_states_only=True)
    H_ext  = entropy_over_time(T_series_ext)
    MS_ext = np.array([markovization_score(T) for T in T_series_ext])

    # Debug: print KL values and inspect transitions
    if seed == SEED0 and len(KL_ext) > 0:  # Only print for first seed
        def _fmt6(value):
            # A numeric format spec cannot be applied to the "N/A" placeholder,
            # so missing values are rendered separately.
            return "N/A" if value is None else f"{value:.6f}"

        kl_25 = KL_ext[25] if len(KL_ext) > 25 else None
        delta_25 = None if kl_25 is None else kl_25 - KL_ext[0]
        print(f"Debug (seed {seed}): KL_ext[0] = {KL_ext[0]:.6f}, KL_ext[25] = {_fmt6(kl_25)}, KL_ext[-1] = {KL_ext[-1]:.6f}")
        print(f"Debug (seed {seed}): KL_ext change (0->25) = {_fmt6(delta_25)}, (0->end) = {KL_ext[-1] - KL_ext[0]:.6f}")

        # Inspect action-aggregated transitions out of the cue state
        if len(T_series_ext) > 0:
            T_start = T_series_ext[0]
            T_ref = ref_fn_ext(T_start)
            cue_state = CUE
            if cue_state < T_start.shape[0] and cue_state < T_ref.shape[0]:
                P_start = _aggregate_actions(_pad_to_shape(T_start, T_ref.shape))
                Q_ref = _aggregate_actions(T_ref)
                print(f"Debug (seed {seed}): Start - P[{cue_state}, 15] (rewarded) = {P_start[cue_state, 15]:.4f}, P[{cue_state}, 11] (unrewarded) = {P_start[cue_state, 11]:.4f}")
                print(f"Debug (seed {seed}): Ref   - Q[{cue_state}, 15] (rewarded) = {Q_ref[cue_state, 15]:.4f}, Q[{cue_state}, 11] (unrewarded) = {Q_ref[cue_state, 11]:.4f}")

            if len(T_series_ext) > 25:
                T_ep25 = T_series_ext[25]
                P_ep25 = _aggregate_actions(_pad_to_shape(T_ep25, T_ref.shape))
                print(f"Debug (seed {seed}): Ep25  - P[{cue_state}, 15] (rewarded) = {P_ep25[cue_state, 15]:.4f}, P[{cue_state}, 11] (unrewarded) = {P_ep25[cue_state, 11]:.4f}")

    return dict(KL_acq=KL_acq, JS_acq=JS_acq, H_acq=H_acq, MS_acq=MS_acq,
                KL_ext=KL_ext, JS_ext=JS_ext, H_ext=H_ext, MS_ext=MS_ext,
                # None (not IndexError) when a phase ran zero episodes.
                T_acq_final=T_series_acq[-1] if T_series_acq else None,
                T_ext_final=T_series_ext[-1] if T_series_ext else None)


In [None]:
# Always-reward variant: terminal delivers reward regardless of cue/history
class GridEnvRightDownAlwaysReward(GridEnvRightDownNoSelf):
    """Grid variant whose reward is decided solely by reaching a terminal state."""

    def step(self, action):
        """
        Advance one step using the parent dynamics and override the reward.

        Returns ``(next_state, 1, True)`` when the parent reports the episode
        as done and ``self.is_terminal(next_state)`` holds; otherwise
        ``(next_state, 0, done)``. If the instance has no ``is_terminal``
        attribute the reward is always 0.
        """
        next_state, _parent_reward, finished = super().step(action)
        reached_terminal = (
            finished
            and hasattr(self, "is_terminal")
            and self.is_terminal(next_state)
        )
        if reached_terminal:
            return next_state, 1, True
        return next_state, 0, finished

In [None]:

# --- Config ---
# Notebook-wide settings. Later cells and run_one_seed read these as globals,
# so re-running this cell changes their behaviour.
THRESH = 0.5               # adjacency threshold; must match the one passed to env.plot_graph
cfg = CoDAConfig(
    theta_split=0.6, theta_merge=0.3,
    n_threshold=8, min_presence_episodes=3, min_effective_exposure=5.0,
    confidence=0.8,
    count_decay=0.9,
    trace_decay=0.99,    # author's note: "makes PC recent" (PC presumably = prospective contingency; confirm)
    # retro_decay=0.9     # author's note: "makes RC recent" (left disabled)
)
N_SEEDS   = 30             # seeds used by the multi-seed summary cells
SEED0     = 0              # first seed; run_one_seed prints debug output only for this one
CUE       = 5              # index of the cue state

# cfg.theta_split = 0.85   # alternative split threshold, left disabled
N_ACQ, N_EXT = 250, 300    # episodes in the acquisition / extinction phases
MAX_STEPS = 20             # per-episode step cap passed to the dataset generators
N_ROLL_REF = 300           # rollouts for the (currently disabled) empirical reference builder
# Shared environment/agent pair used by the interactive acquisition and extinction cells.
env = GridEnvRightDownNoSelf(cue_states=[CUE], env_size=(4,4), rewarded_terminal=[15])
agent = CoDAAgent(env, cfg)


In [None]:

def thresh_adj(T, thr=0.3):
    """
    Binary adjacency obtained by summing ``T`` over axis 1 and thresholding.

    Returns a ``uint8`` matrix with 1 where the summed value is ``>= thr``,
    or ``None`` when the summed array is not 2-D (empty/malformed input).
    """
    summed = T.sum(axis=1)   # [S,S] for a well-formed [S,A,S] tensor
    if summed.ndim == 2:
        return (summed >= thr).astype(np.uint8)
    return None

def clone_dict_tuple(d):
    """Return the mapping's ``(key, value)`` pairs as a sorted tuple, usable for equality checks."""
    pairs = list(d.items())
    pairs.sort()  # orders by key first (clone id), then value
    return tuple(pairs)


In [None]:

def graph_changed(prev_T, prev_clone_map, curr_T, curr_clone_map, thr=THRESH):
    """
    Decide whether the graph differs from the previously recorded snapshot.

    Returns True when there is no previous tensor, when the tensor shapes
    differ, when either thresholded adjacency is unavailable or the two
    adjacencies differ, or when the clone maps differ. Returns False otherwise.
    The default ``thr`` is the value of ``THRESH`` at definition time.
    """
    # No baseline yet, or the state/action space changed size.
    if prev_T is None or prev_T.shape != curr_T.shape:
        return True

    adj_before = thresh_adj(prev_T, thr=thr)
    adj_after = thresh_adj(curr_T, thr=thr)
    if adj_before is None or adj_after is None:
        return True
    if adj_before.shape != adj_after.shape or np.any(adj_before != adj_after):
        return True

    # Same edges: only a clone re-wiring can still count as a change.
    return bool(prev_clone_map != curr_clone_map)


In [None]:

# --- Run loops; only plot when changed ---
# Uses the global `env`, `agent`, `N_ACQ`, `MAX_STEPS` and `THRESH` from the config cell.
with_clones = False        # flips to True after the first successful split
prev_T = None              # last plotted transition tensor (None = nothing plotted yet)
prev_map = None            # clone map that accompanied prev_T
changed_episodes = []      # episodes at which a snapshot was drawn

# Acquisition
for ep in range(1, N_ACQ+1):
    # After a split, episodes come from generate_dataset_post_augmentation,
    # which is given the agent's current transition tensor.
    if with_clones:
        (states, actions) = generate_dataset_post_augmentation(env, agent.get_T(), n_episodes=1, max_steps=MAX_STEPS)[0]
    else:
        (states, actions) = generate_dataset(env, n_episodes=1, max_steps=MAX_STEPS)[0]
    # Stochastic (state, action) pairs are computed from T *before* this
    # episode's update is applied.
    transition_probs = agent.get_T()
    entropies = compute_transition_entropies(transition_probs)
    stochastic_pairs = find_stochastic_state_actions_by_entropy(entropies, eps=1e-9)  # iterated below as (s, a) pairs
    agent.update_with_episode(states, actions)
    # NOTE(review): everything below, including the split attempt and the
    # snapshot check, runs only when stochastic pairs exist. Consequently
    # T_curr is never assigned if no episode has stochastic pairs. Confirm
    # this nesting is intended.
    if stochastic_pairs:
        unique_outcomes = set()

        # Collect the distinct successor pairs of the stochastic (s, a) pairs ...
        for (s,a) in stochastic_pairs:
            sprime, sprime2 = get_successor_states(agent.transition_counts, s, a)
            unique_outcomes.add((sprime, sprime2))
        # ... and update eligibility traces once per distinct pair.
        for (sprime,sprime2) in unique_outcomes:
            agent.update_eligibility_traces(states, sprime, sprime2)
        new = agent.maybe_split()
        if new:
            with_clones = True

        T_curr = agent.get_T().copy()
        map_curr = clone_dict_tuple(env.clone_dict)

        # Draw (and remember) a snapshot only when it differs from the last drawn one.
        if graph_changed(prev_T, prev_map, T_curr, map_curr, thr=THRESH):
            sanitize_for_plot(env, T_curr)
            env.plot_graph(T_curr, niter=ep, threshold=THRESH, save=False, savename=f'graph_ep{ep}.png')
            changed_episodes.append(ep)
            prev_T, prev_map = T_curr, map_curr

print("Changed episodes (acquisition):", changed_episodes[:20], "... total:", len(changed_episodes))


In [None]:
# Re-draw the most recent snapshot. Relies on `T_curr` and `ep` left over from the acquisition loop.
env.plot_graph(T_curr, niter=ep, threshold=THRESH, save=False, savename=f'graph_ep{ep}.png')

In [None]:
# Notebook display only: show the agent's `salient_cues` attribute.
agent.salient_cues

In [None]:

# Extinction / degradation
# Continues from the acquisition cell: the same `agent` keeps learning in an
# extinction environment, and a graph snapshot is drawn only when the graph
# differs from the last drawn snapshot (same rule as the acquisition loop).
THRESH = 0.5
prev_T, prev_map = prev_T, prev_map  # no-op; fails fast with NameError if the acquisition cell was not run
changed_ext = []

# An earlier approach (kept for reference only) reused GridEnvRightDownNoSelf
# and monkey-patched env2.step to force reward = -1 and redirect rewarded
# terminals to the matching unrewarded terminal. GridEnvRightDownExtinction
# replaces that patch.
env2 = GridEnvRightDownExtinction(cue_states=[CUE], env_size=(4,4), rewarded_terminal=[15])
# Carry the learned clone structure over to the new environment.
env2.clone_dict = dict(agent.env.clone_dict)
env2.reverse_clone_dict = dict(agent.env.reverse_clone_dict)
agent.env = env2

for ep in range(N_ACQ+1, N_ACQ+N_EXT+1):
    (states, actions) = generate_dataset_post_augmentation(env2, agent.get_T(), n_episodes=1, max_steps=MAX_STEPS)[0]
    agent.update_with_episode(states, actions)
    agent.maybe_merge()

    T_curr  = agent.get_T().copy()
    map_curr = clone_dict_tuple(env2.clone_dict)

    # Only draw and record episodes whose graph actually changed. Without this
    # gate every one of the N_EXT episodes is plotted and reported as "changed".
    if graph_changed(prev_T, prev_map, T_curr, map_curr, thr=THRESH):
        # Clean terminals/clones for the figure only; T_curr itself is what gets compared.
        T_vis = make_terminals_absorbing_for_plot(T_curr, env2.rewarded_terminals + env2.unrewarded_terminals)
        sanitize_for_plot(env2, T_vis)
        env2.plot_graph(T_vis, niter=ep, threshold=THRESH, save=False, savename=f'graph_ep{ep}.png')

        changed_ext.append(ep)
        prev_T, prev_map = T_curr, map_curr

print("Extinction changed episodes:", changed_ext[:30], "... total:", len(changed_ext))


In [None]:
# Notebook display only: entry 5 of agent.retrospective() (5 is the value of CUE in the config cell).
agent.retrospective()[5]

In [None]:
# Notebook display only: entry 5 of agent.prospective() (5 is the value of CUE in the config cell).
agent.prospective()[5]

In [None]:
# ============================================================================
# Run all experiments and plot KL over time: Acquisition, Extinction,
# Latent Inhibition, and Contingency Degradation
# ============================================================================
# NOTE: run_one_seed, mean_se, run_latent_inhibition_many and
# run_degradation_many must already be defined when this cell runs.
# The cell (re)assigns the globals KL_acq_runs, KL_ext_runs, res_li and res_deg.

# Parameters
N_LI_PRE = 300
N_LI_ACQ = 300
N_CD_ACQ = 300
N_CD_DEG = 300
N_SEEDS_POSTER = 100  # Use 100 seeds for poster
SEEDS_ALL = list(range(N_SEEDS_POSTER))


def _poster_band(ax, x, mean, se, line_fmt, color, label):
    """Draw a mean curve plus a shaded mean ± SE band in poster line weight."""
    ax.plot(x, mean, line_fmt, lw=5, label=label)
    ax.fill_between(x, mean - se, mean + se, alpha=0.3, color=color)


def _finish_poster_axes(ax, title, **legend_kwargs):
    """Apply the shared poster styling to ``ax``, then lay out and show the figure."""
    ax.set_xlabel("Episode", fontsize=30, fontweight='bold')
    ax.set_ylabel("KL Divergence (nats)", fontsize=30, fontweight='bold')
    ax.set_title(title, fontsize=36, fontweight='bold', pad=20)
    ax.tick_params(axis='both', which='major', labelsize=26)
    ax.legend(fontsize=26, frameon=True, fancybox=True, shadow=True, **legend_kwargs)
    ax.grid(alpha=0.3, linestyle='--')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.tight_layout()
    plt.show()


print("Running all experiments...")
print(f"Seeds: {len(SEEDS_ALL)}")

# 1. Acquisition & Extinction (both phases are produced by run_one_seed)
print("\n1. Running Acquisition & Extinction...")
results_acq_ext = []
for k in range(N_SEEDS_POSTER):
    res = run_one_seed(SEED0 + k)
    results_acq_ext.append(res)

KL_acq_runs = [r["KL_acq"] for r in results_acq_ext]
KL_ext_runs = [r["KL_ext"] for r in results_acq_ext]

# 2. Latent Inhibition
print("2. Running Latent Inhibition...")
res_li = run_latent_inhibition_many(cfg, SEEDS_ALL,
                                     pre_episodes=N_LI_PRE,
                                     acq_episodes=N_LI_ACQ,
                                     max_steps=MAX_STEPS,
                                     cue=CUE,
                                     threshold=THRESH)

# 3. Contingency Degradation
print("3. Running Contingency Degradation...")
res_deg = run_degradation_many(cfg, SEEDS_ALL,
                                acq_episodes=N_CD_ACQ,
                                degr_episodes=N_CD_DEG,
                                max_steps=MAX_STEPS,
                                cue=CUE,
                                threshold=THRESH)

print("\nAll experiments complete! Plotting KL curves...")

# ============================================================================
# Plot all KL curves - 4 separate plots for poster
# ============================================================================

# Set poster-style font sizes - much larger for poster
plt.rcParams.update({'font.size': 26, 'axes.titlesize': 36, 'axes.labelsize': 30,
                     'xtick.labelsize': 26, 'ytick.labelsize': 26, 'legend.fontsize': 26,
                     'figure.titlesize': 36})

# 1. Acquisition
fig, ax = plt.subplots(figsize=(14, 6))
mean_acq, se_acq = mean_se(KL_acq_runs)
x_acq = np.arange(len(mean_acq))
_poster_band(ax, x_acq, mean_acq, se_acq, 'b-', 'blue', f"Acquisition (n={len(KL_acq_runs)})")
_finish_poster_axes(ax, "Acquisition")

# 2. Extinction (acquisition curve followed by the extinction curve)
fig, ax = plt.subplots(figsize=(14, 6))
# Acquisition phase: reuses mean_acq / se_acq / x_acq computed for panel 1
_poster_band(ax, x_acq, mean_acq, se_acq, 'b-', 'blue', f"Acquisition Phase (n={len(KL_acq_runs)})")

# Extinction phase, shifted so it starts where acquisition ends
mean_ext, se_ext = mean_se(KL_ext_runs)
x_ext = np.arange(len(mean_ext)) + N_ACQ
_poster_band(ax, x_ext, mean_ext, se_ext, 'r-', 'red', f"Extinction Phase (n={len(KL_ext_runs)})")
ax.axvline(N_ACQ, ls="--", color='gray', alpha=0.6, linewidth=4, label="Acquisition → Extinction")
_finish_poster_axes(ax, "Extinction", loc='best')

# 3. Latent Inhibition
# Plot only the acquisition phase, starting from episode 0 (treating acquisition start as episode 0)
fig, ax = plt.subplots(figsize=(14, 6))

# Extract acquisition phase only (after pre-exposure); shorter series are skipped
KL_li_acq = [kl_series[N_LI_PRE:] for kl_series in res_li["KL"] if len(kl_series) > N_LI_PRE]

mean_li, se_li = mean_se(KL_li_acq)
x_li = np.arange(len(mean_li))
_poster_band(ax, x_li, mean_li, se_li, 'r-', 'red', f"Latent Inhibition (n={len(KL_li_acq)})")
_finish_poster_axes(ax, "Latent Inhibition")

# 4. Contingency Degradation
fig, ax = plt.subplots(figsize=(14, 6))
# Acquisition phase
mean_deg_acq, se_deg_acq = mean_se(res_deg["KL_acq"])
x_deg_acq = np.arange(len(mean_deg_acq))
_poster_band(ax, x_deg_acq, mean_deg_acq, se_deg_acq, 'b-', 'blue', f"Acquisition Phase (n={len(res_deg['KL_acq'])})")

# Degradation phase, shifted so it starts where acquisition ends
mean_deg, se_deg = mean_se(res_deg["KL_deg"])
x_deg = np.arange(len(mean_deg)) + N_CD_ACQ
_poster_band(ax, x_deg, mean_deg, se_deg, 'r-', 'red', f"Degradation Phase (n={len(res_deg['KL_deg'])})")
ax.axvline(N_CD_ACQ, ls="--", color='gray', alpha=0.6, linewidth=4, label="Acquisition → Degradation")
_finish_poster_axes(ax, "Contingency Degradation", loc='best')

print("\nAll plots complete!")


In [None]:

# Run N_SEEDS independent seeds (SEED0, SEED0+1, ...) and keep each run's result dict.
results = []
for k in range(N_SEEDS):
    res = run_one_seed(SEED0 + k)
    results.append(res)

def _pad_stack(arrs):
    """
    Stack 1-D series of possibly different lengths into one matrix.

    Returns a float array of shape ``(len(arrs), longest_length)`` where
    shorter series are right-padded with NaN. An empty ``arrs`` gives a
    ``(0, 0)`` array.
    """
    L = max((len(a) for a in arrs), default=0)
    M = np.full((len(arrs), L), np.nan)
    for i, a in enumerate(arrs):
        M[i, :len(a)] = a
    return M

def mean_se(arrs):
    """
    Per-episode mean and standard error across runs, ignoring NaNs.

    The standard error at each episode uses the number of runs that actually
    have a value there: sample standard deviation (ddof=1) divided by the
    square root of that count. Episodes covered by a single run get SE 0;
    episodes with no valid value get NaN.

    Args:
        arrs: sequence of 1-D series (one per run); lengths may differ.

    Returns:
        ``(mean, se)`` arrays of length ``max(len(a) for a in arrs)``; two
        empty arrays when there is nothing to average.
    """
    M = _pad_stack(arrs)
    if M.shape[1] == 0:
        return np.zeros(0), np.zeros(0)

    valid = ~np.isnan(M)
    n = valid.sum(axis=0)               # runs contributing to each episode
    mean = np.nanmean(M, axis=0)

    # Sum of squared deviations over the valid entries only.
    dev = np.where(valid, M - mean, 0.0)
    ss = np.square(dev).sum(axis=0)

    se = np.zeros(M.shape[1], dtype=float)
    multi = n > 1
    se[multi] = np.sqrt(ss[multi] / (n[multi] - 1)) / np.sqrt(n[multi])
    se[n == 0] = np.nan                 # nothing observed at this episode
    return mean, se

# Split the per-seed result dicts into one list per metric and phase.
# Each list has one per-episode series per seed.
KL_acq_runs = [r["KL_acq"] for r in results]
JS_acq_runs = [r["JS_acq"] for r in results]
H_acq_runs  = [r["H_acq"]  for r in results]
MS_acq_runs = [r["MS_acq"] for r in results]

KL_ext_runs = [r["KL_ext"] for r in results]
JS_ext_runs = [r["JS_ext"] for r in results]
H_ext_runs  = [r["H_ext"]  for r in results]
MS_ext_runs = [r["MS_ext"] for r in results]


In [None]:

def plot_band(y_runs, title, ylabel):
    """
    Plot the across-run mean of ``y_runs`` with a shaded ±1 SE band.

    Draws on the current matplotlib figure; the caller creates and shows it.
    """
    center, half_width = mean_se(y_runs)
    episodes = np.arange(len(center))
    plt.plot(episodes, center, lw=2.2, label=f"mean ({len(y_runs)} seeds)")
    plt.fill_between(episodes, center - half_width, center + half_width, alpha=0.25, label="±1 SE")
    plt.title(title)
    plt.xlabel("Episode")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(alpha=0.2)

# One figure per metric for the acquisition phase ...
plt.figure(figsize=(10,4)); plot_band(KL_acq_runs, "KL (acquisition)", "KL (nats)"); plt.show()
plt.figure(figsize=(10,4)); plot_band(JS_acq_runs, "JS (acquisition)", "JS"); plt.show()
plt.figure(figsize=(10,4)); plot_band(H_acq_runs,  "Avg H(S'|S) (acquisition)", "nats"); plt.show()
plt.figure(figsize=(10,4)); plot_band(MS_acq_runs, "Markovization (acquisition)", "[0,1]"); plt.show()

# ... and for the extinction phase.
plt.figure(figsize=(10,4)); plot_band(KL_ext_runs, "KL (extinction)", "KL (nats)"); plt.show()
plt.figure(figsize=(10,4)); plot_band(JS_ext_runs, "JS (extinction)", "JS"); plt.show()
plt.figure(figsize=(10,4)); plot_band(H_ext_runs,  "Avg H(S'|S) (extinction)", "nats"); plt.show()
plt.figure(figsize=(10,4)); plot_band(MS_ext_runs, "Markovization (extinction)", "[0,1]"); plt.show()

# Visualize ref_fn_acq/ref_fn_ext target transition structures
# Fresh environments are built here so the references do not depend on any agent state.
env_acq_vis = GridEnvRightDownNoSelf(cue_states=[CUE], env_size=(4, 4), rewarded_terminal=[15])
T_ref_acq_fixed = build_gt_acquisition_with_clones(env_acq_vis, cue_state=CUE)

env_ext_vis = GridEnvRightDownNoCue(cue_states=[CUE], env_size=(4, 4), rewarded_terminal=[15])
# Copy the clone maps from the acquisition environment (empty dicts if absent).
env_ext_vis.clone_dict = dict(getattr(env_acq_vis, "clone_dict", {}))
env_ext_vis.reverse_clone_dict = dict(getattr(env_acq_vis, "reverse_clone_dict", {}))
# Build extinction reference with same shape as acquisition (with clones, but zeros for clones)
T_ref_ext_fixed = build_gt_extinction_with_clone_shape(env_ext_vis, T_ref_acq_fixed.shape)

ref_fn_acq = _ref_fn_fixed(T_ref_acq_fixed)
ref_fn_ext = _ref_fn_fixed(T_ref_ext_fixed)

# The fixed ref functions ignore their argument, so a dummy tensor is enough.
T_ref_acq = ref_fn_acq(np.zeros((1, 1, 1)))
T_ref_ext = ref_fn_ext(np.zeros((1, 1, 1)))

print("ref_fn_acq output shape:", T_ref_acq.shape)
print("ref_fn_ext output shape:", T_ref_ext.shape)

# Sum over the action axis to get one [S, S'] matrix per reference.
adj_acq = T_ref_acq.sum(axis=1)
adj_ext = T_ref_ext.sum(axis=1)

# Side-by-side heatmaps of the two reference structures.
fig, axes = plt.subplots(1, 2, figsize=(12, 5), constrained_layout=True)
for ax, mat, title in zip(
    axes,
    [adj_acq, adj_ext],
    ["Reference transitions (acquisition)", "Reference transitions (extinction)"]
):
    im = ax.imshow(mat, cmap="magma", vmin=0.0, vmax=1.0)
    ax.set_xlabel("next state")
    ax.set_ylabel("current state")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.show()


In [None]:
# Ensure results include the final transition tensors before plotting.
# Re-run all seeds if `results` is empty or its last entry lacks either final-T key.
# NOTE(review): raises NameError if `results` was never defined in this session.
if not results or "T_acq_final" not in results[-1] or "T_ext_final" not in results[-1]:
    results = []
    for k in range(N_SEEDS):
        results.append(run_one_seed(SEED0 + k))

def plot_band(y_runs, title, ylabel):
    """
    Plot the across-run mean of ``y_runs`` with a shaded ±1 SE band.

    Draws on the current matplotlib figure; the caller creates and shows it.
    """
    center, half_width = mean_se(y_runs)
    episodes = np.arange(len(center))
    plt.plot(episodes, center, lw=2.2, label=f"mean ({len(y_runs)} seeds)")
    plt.fill_between(episodes, center - half_width, center + half_width, alpha=0.25, label="±1 SE")
    plt.title(title)
    plt.xlabel("Episode")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(alpha=0.2)

# NOTE(review): the *_runs lists plotted here are whatever an earlier cell last
# assigned; they are not rebuilt from `results` in this cell.
# One figure per metric for the acquisition phase ...
plt.figure(figsize=(10,4)); plot_band(KL_acq_runs, "KL (acquisition)", "KL (nats)"); plt.show()
plt.figure(figsize=(10,4)); plot_band(JS_acq_runs, "JS (acquisition)", "JS"); plt.show()
plt.figure(figsize=(10,4)); plot_band(H_acq_runs,  "Avg H(S'|S) (acquisition)", "nats"); plt.show()
plt.figure(figsize=(10,4)); plot_band(MS_acq_runs, "Markovization (acquisition)", "[0,1]"); plt.show()

# ... and for the extinction phase.
plt.figure(figsize=(10,4)); plot_band(KL_ext_runs, "KL (extinction)", "KL (nats)"); plt.show()
plt.figure(figsize=(10,4)); plot_band(JS_ext_runs, "JS (extinction)", "JS"); plt.show()
plt.figure(figsize=(10,4)); plot_band(H_ext_runs,  "Avg H(S'|S) (extinction)", "nats"); plt.show()
plt.figure(figsize=(10,4)); plot_band(MS_ext_runs, "Markovization (extinction)", "[0,1]"); plt.show()

# Visualize ref_fn_acq/ref_fn_ext target transition structures
# Fresh environments are built here so the references do not depend on any agent state.
env_acq_vis = GridEnvRightDownNoSelf(cue_states=[CUE], env_size=(4, 4), rewarded_terminal=[15])
T_ref_acq_fixed = build_gt_acquisition_with_clones(env_acq_vis, cue_state=CUE)

env_ext_vis = GridEnvRightDownNoCue(cue_states=[CUE], env_size=(4, 4), rewarded_terminal=[15])
# Copy the clone maps from the acquisition environment (empty dicts if absent).
env_ext_vis.clone_dict = dict(getattr(env_acq_vis, "clone_dict", {}))
env_ext_vis.reverse_clone_dict = dict(getattr(env_acq_vis, "reverse_clone_dict", {}))
# Build extinction reference with same shape as acquisition (with clones, but zeros for clones)
T_ref_ext_fixed = build_gt_extinction_with_clone_shape(env_ext_vis, T_ref_acq_fixed.shape)

ref_fn_acq = _ref_fn_fixed(T_ref_acq_fixed)
ref_fn_ext = _ref_fn_fixed(T_ref_ext_fixed)

# The fixed ref functions ignore their argument, so a dummy tensor is enough.
T_ref_acq = ref_fn_acq(np.zeros((1, 1, 1)))
T_ref_ext = ref_fn_ext(np.zeros((1, 1, 1)))

print("ref_fn_acq output shape:", T_ref_acq.shape)
print("ref_fn_ext output shape:", T_ref_ext.shape)

# Sum over the action axis to get one [S, S'] matrix per reference.
adj_acq = T_ref_acq.sum(axis=1)
adj_ext = T_ref_ext.sum(axis=1)

# Side-by-side heatmaps of the two reference structures.
fig, axes = plt.subplots(1, 2, figsize=(12, 5), constrained_layout=True)
for ax, mat, title in zip(
    axes,
    [adj_acq, adj_ext],
    ["Reference transitions (acquisition)", "Reference transitions (extinction)"]
):
    im = ax.imshow(mat, cmap="magma", vmin=0.0, vmax=1.0)
    ax.set_xlabel("next state")
    ax.set_ylabel("current state")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.show()

def plot_transition_tensor(T, title_prefix):
    """
    Show a transition tensor as heatmaps.

    First figure: one panel per action (``T[:, a, :]``). Second figure: the
    mean over actions next to its binary adjacency (mean ``>= THRESH``).
    Prints a notice and returns without plotting when ``T`` is None.
    """
    if T is None:
        print(f"No transition matrix available for {title_prefix}.")
        return

    def _decorate(ax, title):
        # Shared title and axis labelling for every panel.
        ax.set_title(title)
        ax.set_xlabel("next state")
        ax.set_ylabel("current state")

    # --- Figure 1: per-action slices ---
    n_actions = T.shape[1]
    fig, axes = plt.subplots(1, n_actions, figsize=(4 * n_actions, 4), constrained_layout=True)
    panels = [axes] if n_actions == 1 else axes  # subplots returns a bare Axes for a single panel
    for a_idx, ax in enumerate(panels):
        im = ax.imshow(T[:, a_idx, :], cmap="magma", vmin=0.0, vmax=1.0)
        _decorate(ax, f"{title_prefix} — action {a_idx}")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.show()

    # --- Figure 2: action-averaged probabilities and thresholded adjacency ---
    avg_prob = T.mean(axis=1)
    adj_mask = (avg_prob >= THRESH).astype(int)
    fig2, (ax_prob, ax_adj) = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)

    im_prob = ax_prob.imshow(avg_prob, cmap="magma", vmin=0.0, vmax=1.0)
    _decorate(ax_prob, f"{title_prefix} — mean over actions")
    fig2.colorbar(im_prob, ax=ax_prob, fraction=0.046, pad=0.04)

    im_adj = ax_adj.imshow(adj_mask, cmap="gray_r", vmin=0, vmax=1)
    _decorate(ax_adj, f"{title_prefix} — adjacency (≥ {THRESH})")
    fig2.colorbar(im_adj, ax=ax_adj, fraction=0.046, pad=0.04)
    plt.show()

# Show the final learned tensors of the last seed in `results`.
plot_transition_tensor(results[-1]["T_acq_final"], "Learned acquisition T (final episode)")
plot_transition_tensor(results[-1]["T_ext_final"], "Learned extinction T (final episode)")

In [None]:
# --- Parameters (paper-style) ---
# NOTE: these overwrite the N_LI_* / N_CD_* values set by the poster cell.
N_LI_PRE = 500
N_LI_ACQ = 1000
N_CD_ACQ = 1000
N_CD_DEG = 1000
SEEDS    = list(range(30))  # 30 seeds

# Both experiments use a default-constructed CoDAConfig here, not the tuned
# `cfg` from the config cell.
# NOTE(review): the author's comments say these should match the "v2"
# acquisition settings; confirm CoDAConfig() defaults are what is intended.
cfg_li = CoDAConfig()  # latent inhibition
cfg_cd = CoDAConfig()  # contingency degradation

# -------- Latent inhibition --------
res_li = run_latent_inhibition_many(cfg_li, SEEDS, pre_episodes=N_LI_PRE, acq_episodes=N_LI_ACQ, max_steps=MAX_STEPS, cue=CUE)
plot_latent_inhibition_summary(res_li, pre_episodes=N_LI_PRE)

# -------- Contingency degradation --------
res_deg = run_degradation_many(cfg_cd, SEEDS, acq_episodes=N_CD_ACQ, degr_episodes=N_CD_DEG, max_steps=MAX_STEPS, cue=CUE)
# plot_degradation_summary(res_deg)

# ===== Contingency degradation: plots in the same style as acquisition/extinction =====

# Unpack runs
# NOTE: the *_acq_runs names below replace the acquisition/extinction lists
# created by earlier cells with the degradation experiment's acquisition phase.
KL_acq_runs = res_deg["KL_acq"]
JS_acq_runs = res_deg["JS_acq"]
H_acq_runs  = res_deg["H_acq"]
MS_acq_runs = res_deg["MS_acq"]

KL_deg_runs = res_deg["KL_deg"]
JS_deg_runs = res_deg["JS_deg"]
H_deg_runs  = res_deg["H_deg"]
MS_deg_runs = res_deg["MS_deg"]

# Reuse your mean±SE band helper
def mean_se(arrs):
    """
    Per-episode mean and standard error across runs, ignoring NaNs.

    Series of different lengths are right-padded with NaN before averaging.
    The standard error at each episode uses the number of runs that actually
    have a value there: sample standard deviation (ddof=1) divided by the
    square root of that count. Episodes covered by a single run get SE 0;
    episodes with no valid value get NaN.

    Args:
        arrs: sequence of 1-D series (one per run); lengths may differ.

    Returns:
        ``(mean, se)`` arrays of length ``max(len(a) for a in arrs)``; two
        empty arrays when there is nothing to average.
    """
    L = max((len(a) for a in arrs), default=0)
    if L == 0:
        return np.zeros(0), np.zeros(0)

    M = np.full((len(arrs), L), np.nan)
    for i, a in enumerate(arrs):
        M[i, :len(a)] = a

    valid = ~np.isnan(M)
    n = valid.sum(axis=0)               # runs contributing to each episode
    mean = np.nanmean(M, axis=0)

    # Sum of squared deviations over the valid entries only.
    dev = np.where(valid, M - mean, 0.0)
    ss = np.square(dev).sum(axis=0)

    se = np.zeros(L, dtype=float)
    multi = n > 1
    se[multi] = np.sqrt(ss[multi] / (n[multi] - 1)) / np.sqrt(n[multi])
    se[n == 0] = np.nan                 # nothing observed at this episode
    return mean, se

def plot_band(y_runs, title, ylabel):
    """
    Plot the across-run mean of ``y_runs`` with a shaded ±1 SE band.

    Draws on the current matplotlib figure; the caller creates and shows it.
    """
    center, half_width = mean_se(y_runs)
    episodes = np.arange(len(center))
    plt.plot(episodes, center, lw=2.2, label=f"mean ({len(y_runs)} seeds)")
    plt.fill_between(episodes, center - half_width, center + half_width, alpha=0.25, label="±1 SE")
    plt.title(title)
    plt.xlabel("Episode")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(alpha=0.2)

# Acquisition portion (degradation runs, pre-switch)
# The *_acq_runs lists here hold the degradation experiment's acquisition phase.
plt.figure(figsize=(10,4)); plot_band(KL_acq_runs, "KL (acquisition)", "KL (nats)"); plt.show()
plt.figure(figsize=(10,4)); plot_band(JS_acq_runs, "JS (acquisition)", "JS"); plt.show()
plt.figure(figsize=(10,4)); plot_band(H_acq_runs,  "Avg H(S'|S) (acquisition)", "nats"); plt.show()
plt.figure(figsize=(10,4)); plot_band(MS_acq_runs, "Markovization (acquisition)", "[0,1]"); plt.show()

# Degradation portion (post-switch)
plt.figure(figsize=(10,4)); plot_band(KL_deg_runs, "KL (contingency degradation)", "KL (nats)"); plt.show()
plt.figure(figsize=(10,4)); plot_band(JS_deg_runs, "JS (contingency degradation)", "JS"); plt.show()
plt.figure(figsize=(10,4)); plot_band(H_deg_runs,  "Avg H(S'|S) (contingency degradation)", "nats"); plt.show()
plt.figure(figsize=(10,4)); plot_band(MS_deg_runs, "Markovization (contingency degradation)", "[0,1]"); plt.show()


In [None]:
# PATCH: Fixed version that populates clone_dict for proper graph visualization
def build_gt_acquisition_with_clones(env, cue_state: int) -> np.ndarray:
    """
    Build the acquisition ground-truth tensor with one clone per state that
    `_descendants_until_terminal` returns for `cue_state`.

    The result is [S0 + n_clones, A, S0 + n_clones]: the base tensor occupies
    the top-left [S0, A, S0] corner and clone `k` (in sorted-parent order)
    gets index S0 + k.

    Side effect: writes clone -> parent entries into `env.clone_dict` and
    rebuilds `env.reverse_clone_dict` from it, so the plotter can overlay
    clones on their parents.
    """
    n_base = env.num_unique_states
    n_actions = _n_actions(env)
    base_T = _build_base_T(env)

    terminal_set = set(env.rewarded_terminals) | set(env.unrewarded_terminals)
    rewarded = env.rewarded_terminals
    unrewarded = env.unrewarded_terminals

    downstream = _descendants_until_terminal(env, cue_state, terminal_set)
    clone_of = {}
    for offset, parent in enumerate(sorted(downstream)):
        clone_of[parent] = n_base + offset

    n_total = n_base + len(clone_of)
    T = np.zeros((n_total, n_actions, n_total), dtype=float)
    T[:n_base, :, :n_base] = base_T

    # Cue edges that lead into a cloned state are re-pointed at the clone.
    for a in env.valid_actions[cue_state]:
        nxt = _base_successor(env, cue_state, a)
        if nxt in downstream:
            T[cue_state, a, nxt] = 0.0
            T[cue_state, a, clone_of[nxt]] = 1.0

    for parent, clone in clone_of.items():
        # Clone rows: stay inside the cloned chain where possible; an edge
        # into any terminal is sent to the rewarded terminal(s).
        for a in env.valid_actions[parent]:
            nxt = _base_successor(env, parent, a)
            if nxt not in terminal_set:
                T[clone, a, clone_of.get(nxt, nxt)] = 1.0
                continue
            for t in rewarded:
                T[clone, a, t] = 1.0
            for t in unrewarded:
                T[clone, a, t] = 0.0

        # Parent rows: an edge into a rewarded terminal is moved to the
        # unrewarded terminal at the same list position.
        for a in env.valid_actions[parent]:
            nxt = _base_successor(env, parent, a)
            if nxt in rewarded:
                T[parent, a, nxt] = 0.0
                paired = unrewarded[rewarded.index(nxt)]
                T[parent, a, paired] = 1.0

    # Register clones on the env (clone_id -> original_state_id) and derive
    # the reverse map from the full clone_dict.
    for parent, clone in clone_of.items():
        env.clone_dict[clone] = parent
    reverse = {}
    for clone, parent in env.clone_dict.items():
        reverse[parent] = clone
    env.reverse_clone_dict = reverse

    return T

In [None]:

# Record one transition-tensor snapshot per episode for both phases.
# Skip this cell if the two lists were already recorded earlier.
T_series_acq = []
T_series_ext = []

# ---- Acquisition: switch to clone-aware rollouts after the first split ----
with_clones = False
for ep in range(1, N_ACQ+1):
    (states, actions) = (
        generate_dataset_post_augmentation(env, agent.get_T(), n_episodes=1, max_steps=MAX_STEPS)
        if with_clones
        else generate_dataset(env, n_episodes=1, max_steps=MAX_STEPS)
    )[0]
    agent.update_with_episode(states, actions)
    # maybe_split() is called every episode; the flag only ever turns on.
    with_clones = bool(agent.maybe_split()) or with_clones
    T_series_acq.append(agent.get_T().copy())

# ---- Extinction: same agent, no-cue env that inherits the clone maps ----
env2 = GridEnvRightDownNoCue(cue_states=[CUE], env_size=(4,4), rewarded_terminal=[15])
env2.clone_dict = dict(getattr(env, "clone_dict", {}))
env2.reverse_clone_dict = dict(getattr(env, "reverse_clone_dict", {}))
agent.env = env2

for ep in range(N_ACQ+1, N_ACQ+N_EXT+1):
    (states, actions) = generate_dataset_post_augmentation(
        env2, agent.get_T(), n_episodes=1, max_steps=MAX_STEPS
    )[0]
    agent.update_with_episode(states, actions)
    agent.maybe_merge()
    T_series_ext.append(agent.get_T().copy())


In [None]:
# Compare learned vs ground-truth representations.
#
# Takes the last snapshot of each phase, builds the analytic ground-truth
# tensors, prints three comparison metrics per phase, draws the four graphs
# (GT / learned x acquisition / extinction) and a bar-chart summary.

# Guard: the snapshot lists must already exist in this namespace.
if 'T_series_acq' not in locals() or 'T_series_ext' not in locals():
    print("ERROR: Please run cell 19 first to generate T_series_acq and T_series_ext")
    print("Cell 19 is the one that starts with: # Collect T snapshots during acquisition and extinction.")
else:
    # Final learned representations (copied so the series stay untouched)
    T_learned_acq = T_series_acq[-1].copy()
    T_learned_ext = T_series_ext[-1].copy()

    # Ground-truth representations.
    # NOTE(review): which build_gt_* definition runs depends on which cell
    # defining that name was executed last — confirm the intended one.
    T_gt_acq = build_gt_acquisition_with_clones(env, cue_state=CUE)
    T_gt_ext = build_gt_extinction_no_clones(env2)

    print(f"Learned acquisition graph shape: {T_learned_acq.shape}")
    print(f"Ground-truth acquisition graph shape: {T_gt_acq.shape}")
    print(f"Learned extinction graph shape: {T_learned_ext.shape}")
    print(f"Ground-truth extinction graph shape: {T_gt_ext.shape}")

    # Comparison metrics
    def compare_graphs(T_learned, T_gt, name=""):
        """
        Print and return edge-match rate, mean absolute difference and a
        KL-style sum between the action-summed adjacencies of two tensors.

        Both tensors are first brought to the element-wise max of their
        shapes. Returns a dict with keys "edge_match", "l1" and "kl".
        """
        # Bring both tensors to a common shape
        max_shape = tuple(max(s1, s2) for s1, s2 in zip(T_learned.shape, T_gt.shape))
        T_l = _pad3(T_learned, max_shape)
        T_g = _pad3(T_gt, max_shape)
        
        # Sum over the action axis -> [S, S] adjacency (not row-normalised)
        A_l = T_l.sum(axis=1)
        A_g = T_g.sum(axis=1)
        
        # Binary adjacency at the global THRESH
        B_l = (A_l >= THRESH).astype(int)
        B_g = (A_g >= THRESH).astype(int)
        
        # Metrics
        # edge_match counts agreeing cells, including cells absent in both.
        edge_match = np.sum(B_l == B_g) / B_l.size
        l1_dist = np.mean(np.abs(A_l - A_g))
        # NOTE(review): A_l / A_g are action sums, not probability rows, and
        # cells where only one side is positive contribute 0 here, so this is
        # not a true KL divergence — confirm this is the intended quantity.
        kl_div = np.sum(np.where((A_g > 0) & (A_l > 0), 
                                 A_g * np.log((A_g + 1e-10) / (A_l + 1e-10)), 0))
        
        print(f"\n{name} Comparison:")
        print(f"  Edge match rate: {edge_match:.3f}")
        print(f"  L1 distance (adjacency): {l1_dist:.4f}")
        print(f"  KL divergence: {kl_div:.4f}")
        
        return {"edge_match": edge_match, "l1": l1_dist, "kl": kl_div}

    metrics_acq = compare_graphs(T_learned_acq, T_gt_acq, "Acquisition")
    metrics_ext = compare_graphs(T_learned_ext, T_gt_ext, "Extinction")

    # Visualize learned vs ground-truth (separate figures)

    # Acquisition: Ground-truth, drawn on a fresh env that gets copies of
    # env's clone maps.
    print("\n=== Ground-Truth (Acquisition) ===")
    T_vis_gt_acq = make_terminals_absorbing_for_plot(T_gt_acq, env.rewarded_terminals + env.unrewarded_terminals)
    env_temp = GridEnvRightDownNoSelf(cue_states=[CUE], env_size=(4,4), rewarded_terminal=[15])
    env_temp.clone_dict = dict(env.clone_dict)  # copy, so the temp env shares no state with env
    env_temp.reverse_clone_dict = dict(env.reverse_clone_dict)  # same for the reverse map
    env_temp.plot_graph(T_vis_gt_acq, niter=N_ACQ, threshold=THRESH, save=False, title=f"Ground-Truth (Acquisition, ep={N_ACQ})")
    plt.show()

    # Acquisition: Learned
    print("\n=== Learned (Acquisition) ===")
    T_vis_learned_acq = make_terminals_absorbing_for_plot(T_learned_acq, env.rewarded_terminals + env.unrewarded_terminals)
    # Return value is not used here — presumably works in place; TODO confirm.
    sanitize_for_plot(env, T_vis_learned_acq)
    env.plot_graph(T_vis_learned_acq, niter=N_ACQ, threshold=THRESH, save=False, title=f"Learned (Acquisition, ep={N_ACQ})")
    plt.show()

    # Extinction: Ground-truth, again on a fresh env with copied clone maps
    print("\n=== Ground-Truth (Extinction) ===")
    T_vis_gt_ext = make_terminals_absorbing_for_plot(T_gt_ext, env2.rewarded_terminals + env2.unrewarded_terminals)
    env2_temp = GridEnvRightDownNoCue(cue_states=[CUE], env_size=(4,4), rewarded_terminal=[15])
    env2_temp.clone_dict = dict(env2.clone_dict)  # copy of env2's clone map
    env2_temp.reverse_clone_dict = dict(env2.reverse_clone_dict)  # copy of env2's reverse map
    env2_temp.plot_graph(T_vis_gt_ext, niter=N_ACQ+N_EXT, threshold=THRESH, save=False, title=f"Ground-Truth (Extinction, ep={N_ACQ+N_EXT})")
    plt.show()



    # Extinction: Learned
    print("\n=== Learned (Extinction) ===")
    T_vis_learned_ext = make_terminals_absorbing_for_plot(T_learned_ext, env2.rewarded_terminals + env2.unrewarded_terminals)
    sanitize_for_plot(env2, T_vis_learned_ext)
    env2.plot_graph(T_vis_learned_ext, niter=N_ACQ+N_EXT, threshold=THRESH, save=False, title=f"Learned (Extinction, ep={N_ACQ+N_EXT})")
    plt.show()

    # Bar charts: one panel per metric, acquisition vs extinction side by side
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    metrics_list = ['edge_match', 'l1', 'kl']
    titles = ['Edge Match Rate', 'L1 Distance', 'KL Divergence']
    ylabels = ['Match Rate', 'L1', 'KL (nats)']

    for idx, (metric, title, ylabel) in enumerate(zip(metrics_list, titles, ylabels)):
        ax = axes[idx]
        values = [metrics_acq[metric], metrics_ext[metric]]
        bars = ax.bar(['Acquisition', 'Extinction'], values, color=['#1f77b4', '#ff7f0e'], alpha=0.7)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(axis='y', alpha=0.3)
        
        # Print each bar's value just above it
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.3f}', ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

# Now, degradation and latent inhibition

In [None]:
# Always-reward variant: terminal delivers reward regardless of cue/history
class GridEnvRightDownAlwaysReward(GridEnvRightDownNoSelf):
    """Variant of the right/down grid whose step() reports reward 1 on a terminal step."""

    def step(self, action):
        """
        Advance with the parent dynamics and override the reward.

        Returns (next_state, 1, True) when the parent reports `done` and
        `self.is_terminal(next_state)` exists and is truthy; otherwise
        (next_state, 0, done). The parent's own reward is discarded.
        """
        next_state, _parent_reward, done = super().step(action)
        reached_terminal = (
            done
            and hasattr(self, "is_terminal")
            and self.is_terminal(next_state)
        )
        if reached_terminal:
            return next_state, 1, True
        return next_state, 0, done

In [None]:
# ==========================================================
# Analytic ground-truth builders (ACQUISITION / EXTINCTION / DEGRADATION)
# ==========================================================
import numpy as np

def _base_T_from_env(env):
    """
    Deterministic [S, A, S] tensor from grid geometry (no clones).

    Reads env.num_unique_states, env.valid_actions (state -> action ids),
    env.state_to_pos, env.pos_to_state, env.base_actions (action -> (di, dj))
    and the two terminal lists. A is the largest listed action id + 1, or 0
    when no state lists any action (including when every action list is
    empty, which previously raised ValueError from max()).

    Rows of terminal states are zeroed, i.e. terminals have no outgoing mass.
    """
    S = int(env.num_unique_states)
    # default=-1 makes "no actions anywhere" yield A == 0 instead of raising.
    n_actions = max(
        (a for acts in env.valid_actions.values() for a in acts),
        default=-1,
    ) + 1
    T = np.zeros((S, n_actions, S), dtype=float)

    # deterministic moves: successor = position + action offset
    for s, acts in env.valid_actions.items():
        for a in acts:
            i, j = env.state_to_pos[s]
            di, dj = env.base_actions[a]
            sp = env.pos_to_state[(i + di, j + dj)]
            T[s, a, sp] = 1.0

    # Zero the rows of terminals; unpacking accepts list or tuple containers.
    for t in (*env.rewarded_terminals, *env.unrewarded_terminals):
        if 0 <= t < S:
            T[t, :, :] = 0.0
    return T

def gt_acquisition_with_clones(env, cue_states):
    """
    Acquisition ground truth: clone each immediate successor of every cue.

    Starting from the base tensor, each distinct successor reached from a cue
    gets one new state appended at the end (shared if several cue edges reach
    the same successor). The clone's row is a copy of the successor's
    outgoing transitions, and the cue edge is moved from the successor to the
    clone. Terminal rows are zeroed again at the end.

    Returns [S_gt, A, S_gt] with S_gt >= env.num_unique_states.
    """
    T = _base_T_from_env(env)
    clone_index = {}  # successor state -> index of its clone

    for cue in cue_states:
        for a in env.get_valid_actions(cue):
            row, col = env.state_to_pos[cue]
            d_row, d_col = env.base_actions[a]
            successor = env.pos_to_state[(row + d_row, col + d_col)]

            clone = clone_index.get(successor)
            if clone is None:
                # New clone takes the next free index; grow both state axes.
                clone = T.shape[0]
                T = np.pad(T, ((0, 1), (0, 0), (0, 1)), mode='constant')
                # Copy the successor's outgoing row (all pre-existing columns).
                T[clone, :, :-1] = T[successor, :, :-1]
                clone_index[successor] = clone

            # Move the cue -> successor mass onto cue -> clone.
            moved = T[cue, a, successor]
            T[cue, a, successor] = 0.0
            T[cue, a, clone] = moved

    # Safety pass: terminals keep no outgoing mass.
    n_states = T.shape[0]
    for t in (env.rewarded_terminals + env.unrewarded_terminals):
        if 0 <= t < n_states:
            T[t, :, :] = 0.0
    return T

def gt_extinction_no_clones(env2):
    """
    Extinction ground truth: base graph, no clones, reward terminal unreachable.

    For each rewarded terminal that has an unrewarded terminal at the same
    list position, every positive transition into the rewarded terminal is
    moved to that unrewarded terminal.
    """
    T = _base_T_from_env(env2)
    # zip stops at the shorter list, so unpaired rewarded terminals are skipped.
    for rt, nt in zip(env2.rewarded_terminals, env2.unrewarded_terminals):
        incoming = T[:, :, rt] > 0  # [S, A] mask of (state, action) pairs entering rt
        T[incoming, nt] = T[incoming, rt]
        T[incoming, rt] = 0.0
    return T

def gt_degradation_no_clones(env_always_reward):
    """
    Contingency-degradation ground truth.

    Only the transition structure is modelled, so this is the base tensor
    built from the env's geometry: no clones, terminal rows zeroed.
    """
    T_base = _base_T_from_env(env_always_reward)
    return T_base

# ------------------------------------------
# Backwards compatibility: provide 'build_*' names
# so existing code that calls build_gt_* keeps working.
# ------------------------------------------
def build_gt_acquisition_with_clones(env, cue_state):
    """Back-compat wrapper: single-cue call into gt_acquisition_with_clones."""
    # NOTE(review): this reuses the name of the earlier clone_dict-populating
    # builder; whichever cell ran last wins — confirm that is intended.
    single_cue = [cue_state]
    return gt_acquisition_with_clones(env, cue_states=single_cue)

def build_gt_extinction_no_clones(env2):
    """Back-compat wrapper around gt_extinction_no_clones."""
    T_ext = gt_extinction_no_clones(env2)
    return T_ext

def build_gt_degradation_no_clones(env_always_reward):
    """Back-compat wrapper around gt_degradation_no_clones."""
    T_deg = gt_degradation_no_clones(env_always_reward)
    return T_deg

In [None]:
def gt_degradation_no_clones(env_always_reward):
    """
    Degradation ground truth: the base grid transitions, without clones.

    Always delivering reward does not alter the transition *structure*, so
    the tensor is exactly what base_T_from_env builds for this env
    (terminals have no outgoing edges; incoming edges are untouched).
    """
    return base_T_from_env(env_always_reward)

In [None]:
def run_latent_inhibition_seed(seed:int,
                               cfg:CoDAConfig,
                               pre_episodes:int = 500,
                               acq_episodes:int = 1000,
                               max_steps:int = 20,
                               cue:int = 5):
    """
    Run one latent-inhibition seed: pre-exposure, then acquisition.

    Phase 1 (pre-exposure): `pre_episodes` episodes in GridEnvRightDownNoCue;
                            the agent is only updated, never asked to split.
    Phase 2 (acquisition):  `acq_episodes` episodes in GridEnvRightDownNoSelf
                            with the same agent; maybe_split() runs every
                            episode and rollouts become clone-aware after the
                            first split.
    One T snapshot is stored per episode across both phases, and all metrics
    are computed over that full series against the acquisition GT-with-clones.
    References: paper pp. 3–4, Fig. 3b; Methods p.8 (protocol & rationale).

    Seeds NumPy's global RNG. Returns a dict with keys
    T_series, T_ref, KL, JS, H, MS.
    """
    np.random.seed(seed)

    # -------- Pre-exposure --------
    env_pre = GridEnvRightDownNoCue(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    agent = CoDAAgent(env_pre, cfg)

    T_series = []
    for _ in range(pre_episodes):
        states, actions = generate_dataset(env_pre, n_episodes=1, max_steps=max_steps)[0]
        agent.update_with_episode(states, actions)
        T_series.append(agent.get_T().copy())

    # -------- Acquisition: same agent object, env swapped --------
    env_acq = GridEnvRightDownNoSelf(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    agent.env = env_acq

    with_clones = False
    for _ in range(acq_episodes):
        if with_clones:
            episode = generate_dataset_post_augmentation(
                env_acq, agent.get_T(), n_episodes=1, max_steps=max_steps)[0]
        else:
            episode = generate_dataset(env_acq, n_episodes=1, max_steps=max_steps)[0]
        states, actions = episode

        agent.update_with_episode(states, actions)
        if agent.maybe_split():
            with_clones = True

        T_series.append(agent.get_T().copy())

    # Reference: acquisition graph with clones
    T_ref = gt_acquisition_with_clones(env_acq, cue_states=[cue])

    # Metrics over the whole series vs. the acquisition reference
    KL = kl_over_time_fixed(T_series, T_ref, use_js=False)
    JS = kl_over_time_fixed(T_series, T_ref, use_js=True)
    H = entropy_over_time(T_series)
    MS = np.array([markovization_score(T) for T in T_series])

    return {
        "T_series": T_series,
        "T_ref": T_ref,
        "KL": KL,
        "JS": JS,
        "H": H,
        "MS": MS,
    }

In [None]:
def run_contingency_degradation_seed(seed:int,
                                     cfg:CoDAConfig,
                                     acq_episodes:int = 1000,
                                     degr_episodes:int = 1000,
                                     max_steps:int = 20,
                                     cue:int = 5,
                                     wash_in:int = 50,
                                     edge_eps_early:float = 1e-4,
                                     edge_eps_late:float  = 1e-6):
    """
    Run one contingency-degradation seed: acquisition, then degradation.

    Phase 1 (acquisition): normal cued task; maybe_split() runs every episode.
    Phase 2 (degradation): GridEnvRightDownAlwaysReward with the same agent
                           (no reset); maybe_merge() runs every episode.
    Metrics are reported separately: acquisition vs. the GT-with-clones,
    degradation vs. the GT without clones.
    References: paper p.4 (Contingency degradation; Eq. 3); Methods p.8.

    Wash-in: for the first `wash_in` degradation episodes several fields of
    `agent.cfg` are temporarily overridden and `edge_eps_early` is used for
    `agent._edge_eps_override`; afterwards the saved values are restored and
    `edge_eps_late` is used. The saved values are restored in every case
    (including `degr_episodes < wash_in` and an exception mid-phase), and with
    `wash_in <= 0` no override is applied, so a config object shared across
    seeds is never left modified.

    Seeds NumPy's global RNG. Returns a dict with the two envs, both T series,
    both references and the eight metric arrays.
    """
    np.random.seed(seed)

    # -------- Acquisition --------
    env_acq = GridEnvRightDownNoSelf(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    agent   = CoDAAgent(env_acq, cfg)

    T_series_acq = []
    with_clones = False

    for _ in range(acq_episodes):
        if with_clones:
            (states, actions) = generate_dataset_post_augmentation(env_acq, agent.get_T(), n_episodes=1, max_steps=max_steps)[0]
        else:
            (states, actions) = generate_dataset(env_acq, n_episodes=1, max_steps=max_steps)[0]

        agent.update_with_episode(states, actions)
        if agent.maybe_split():
            with_clones = True

        T_series_acq.append(agent.get_T().copy())

    T_ref_acq = gt_acquisition_with_clones(env_acq, cue_states=[cue])

    # -------- Degradation (always reward; no reset) --------
    env_deg = GridEnvRightDownAlwaysReward(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    # carry learned clones into degradation env so merges are meaningful
    env_deg.clone_dict = dict(getattr(env_acq, "clone_dict", {}))
    env_deg.reverse_clone_dict = dict(getattr(env_acq, "reverse_clone_dict", {}))
    agent.env = env_deg

    # Snapshot of every cfg field the wash-in touches.
    washed_fields = ("count_decay", "trace_decay", "retro_decay",
                     "theta_merge", "confidence",
                     "min_presence_episodes", "min_effective_exposure")
    orig = {name: getattr(agent.cfg, name) for name in washed_fields}

    def _restore_cfg():
        # Put the pre-wash-in values back on the agent's config.
        for key, val in orig.items():
            setattr(agent.cfg, key, val)

    # short wash-in to encourage structural merges early (optional)
    wash_active = wash_in > 0
    if wash_active:
        agent.cfg.count_decay = 0.98
        agent.cfg.trace_decay = 0.98
        agent.cfg.retro_decay = 0.98
        agent.cfg.theta_merge = 0.60
        agent.cfg.confidence  = 0.99
        agent.cfg.min_presence_episodes += 3
        agent.cfg.min_effective_exposure = int(agent.cfg.min_effective_exposure * 1.5)

    T_series_deg = []
    try:
        for k in range(degr_episodes):
            (states, actions) = generate_dataset_post_augmentation(env_deg, agent.get_T(), n_episodes=1, max_steps=max_steps)[0]
            agent.update_with_episode(states, actions)
            agent._edge_eps_override = edge_eps_early if k < wash_in else edge_eps_late
            agent.maybe_merge()
            T_series_deg.append(agent.get_T().copy())

            if wash_active and k == wash_in - 1:
                # restore original cfg after wash-in
                _restore_cfg()
                wash_active = False
    finally:
        # Phase ended (or failed) before the wash-in window closed.
        if wash_active:
            _restore_cfg()

    T_ref_deg = gt_degradation_no_clones(env_deg)

    # Metrics (separate for each phase)
    KL_acq = kl_over_time_fixed(T_series_acq, T_ref_acq, use_js=False)
    JS_acq = kl_over_time_fixed(T_series_acq, T_ref_acq, use_js=True)
    H_acq  = entropy_over_time(T_series_acq)
    MS_acq = np.array([markovization_score(T) for T in T_series_acq])

    KL_deg = kl_over_time_fixed(T_series_deg, T_ref_deg, use_js=False)
    JS_deg = kl_over_time_fixed(T_series_deg, T_ref_deg, use_js=True)
    H_deg  = entropy_over_time(T_series_deg)
    MS_deg = np.array([markovization_score(T) for T in T_series_deg])

    return dict(
        env_acq=env_acq, env_deg=env_deg,
        T_series_acq=T_series_acq, T_series_deg=T_series_deg,
        T_ref_acq=T_ref_acq, T_ref_deg=T_ref_deg,
        KL_acq=KL_acq, JS_acq=JS_acq, H_acq=H_acq, MS_acq=MS_acq,
        KL_deg=KL_deg, JS_deg=JS_deg, H_deg=H_deg, MS_deg=MS_deg
    )

In [None]:
# -------- Latent inhibition: multi-seed --------
def run_latent_inhibition_many(cfg:CoDAConfig,
                               seeds:list,
                               pre_episodes:int=500,
                               acq_episodes:int=1000,
                               max_steps:int=20,
                               cue:int=5):
    """
    Run the latent-inhibition protocol once per seed with a shared `cfg`.

    Returns a dict holding the raw per-seed results under "runs" and, under
    "KL", "JS", "H" and "MS", the list of that metric's series (one per seed).
    """
    runs = []
    for s in seeds:
        runs.append(
            run_latent_inhibition_seed(s, cfg, pre_episodes, acq_episodes, max_steps, cue)
        )

    summary = {"runs": runs}
    for metric in ("KL", "JS", "H", "MS"):
        summary[metric] = [r[metric] for r in runs]
    return summary

def plot_latent_inhibition_summary(res_li, pre_episodes:int, title_suffix=""):
    """
    Plot KL, JS, entropy and Markovization bands for latent-inhibition runs.

    Each metric gets its own new figure (plot_band only draws on the current
    axes, so without this all four bands would share one plot), and a dashed
    vertical line at `pre_episodes` marks the pre-exposure -> acquisition
    switch on that same figure. Figures are left open for the caller or the
    notebook to show.
    """
    panels = (
        ("KL", f"KL (latent inhibition{title_suffix})", "KL (nats)"),
        ("JS", f"JS (latent inhibition{title_suffix})", "JS"),
        ("H",  f"Avg H(S'|S) (latent inhibition{title_suffix})", "nats"),
        ("MS", f"Markovization (latent inhibition{title_suffix})", "[0,1]"),
    )
    for key, title, ylabel in panels:
        plt.figure(figsize=(10, 4))
        plot_band(res_li[key], title, ylabel)
        # vertical line to show transition from pre-exposure to acquisition
        plt.axvline(pre_episodes, ls="--", alpha=0.4)

# -------- Degradation: multi-seed --------
def run_degradation_many(cfg:CoDAConfig,
                         seeds:list,
                         acq_episodes:int=1000,
                         degr_episodes:int=1000,
                         max_steps:int=20,
                         cue:int=5):
    """
    Run the contingency-degradation protocol once per seed with a shared `cfg`.

    Returns a dict holding the raw per-seed results under "runs" plus, for
    each of KL/JS/H/MS and each phase suffix (_acq, _deg), the list of that
    metric's series (one per seed).
    """
    runs = []
    for s in seeds:
        runs.append(
            run_contingency_degradation_seed(s, cfg, acq_episodes, degr_episodes, max_steps, cue)
        )

    summary = {"runs": runs}
    for phase in ("acq", "deg"):
        for metric in ("KL", "JS", "H", "MS"):
            key = f"{metric}_{phase}"
            summary[key] = [r[key] for r in runs]
    return summary

def plot_degradation_summary(res_deg, title_suffix=""):
    """
    Plot the eight metric bands of the contingency-degradation runs.

    Four panels cover the acquisition phase and four the degradation phase.
    Each panel is drawn on its own new figure (plot_band only draws on the
    current axes, so without this all eight bands would share one plot).
    Figures are left open for the caller or the notebook to show.
    """
    acq_tag = f"acquisition — degradation runs{title_suffix}"
    deg_tag = f"contingency degradation{title_suffix}"
    panels = (
        ("KL_acq", f"KL ({acq_tag})", "KL (nats)"),
        ("JS_acq", f"JS ({acq_tag})", "JS"),
        ("H_acq",  f"Avg H(S'|S) ({acq_tag})", "nats"),
        ("MS_acq", f"Markovization ({acq_tag})", "[0,1]"),
        ("KL_deg", f"KL ({deg_tag})", "KL (nats)"),
        ("JS_deg", f"JS ({deg_tag})", "JS"),
        ("H_deg",  f"Avg H(S'|S) ({deg_tag})", "nats"),
        ("MS_deg", f"Markovization ({deg_tag})", "[0,1]"),
    )
    for key, title, ylabel in panels:
        plt.figure(figsize=(10, 4))
        plot_band(res_deg[key], title, ylabel)

In [None]:
# --- Experiment sizes ---
N_LI_PRE, N_LI_ACQ = 500, 1000   # latent inhibition: pre-exposure / acquisition episodes
N_CD_ACQ, N_CD_DEG = 1000, 1000  # contingency degradation: acquisition / degradation episodes
SEEDS = [*range(30)]             # 30 seeds

# Separate default-constructed configs, one per experiment.
cfg_li = CoDAConfig()
cfg_cd = CoDAConfig()

# -------- Latent inhibition --------
res_li = run_latent_inhibition_many(
    cfg_li,
    SEEDS,
    pre_episodes=N_LI_PRE,
    acq_episodes=N_LI_ACQ,
    max_steps=MAX_STEPS,
    cue=CUE,
)
plot_latent_inhibition_summary(res_li, pre_episodes=N_LI_PRE)

# -------- Contingency degradation --------
res_deg = run_degradation_many(
    cfg_cd,
    SEEDS,
    acq_episodes=N_CD_ACQ,
    degr_episodes=N_CD_DEG,
    max_steps=MAX_STEPS,
    cue=CUE,
)
plot_degradation_summary(res_deg)

In [None]:
# ===== Contingency degradation: plots in the same style as acquisition/extinction =====

# Pull the per-seed metric series out of the degradation results,
# acquisition phase first, then the degradation phase.
KL_acq_runs, JS_acq_runs, H_acq_runs, MS_acq_runs = (
    res_deg[key] for key in ("KL_acq", "JS_acq", "H_acq", "MS_acq")
)

KL_deg_runs, JS_deg_runs, H_deg_runs, MS_deg_runs = (
    res_deg[key] for key in ("KL_deg", "JS_deg", "H_deg", "MS_deg")
)

# Reuse your mean±SE band helper
def mean_se(arrs):
    """
    Per-episode mean and standard error across runs of possibly unequal length.

    Runs are right-padded with NaN to the longest run; NaNs are ignored.
    For each episode the SE is the sample standard deviation (ddof=1) divided
    by sqrt(n), where n is the number of runs that actually have a value at
    that episode. With a single contributing run the SE is 0 (not NaN), so a
    one-seed experiment still plots. Episodes with no valid value give NaN for
    both outputs. An empty `arrs` returns two empty arrays.

    Returns (mean, se) as 1D float arrays of length max(len(run)).
    """
    runs = [np.asarray(a, dtype=float).ravel() for a in arrs]
    if not runs:
        return np.array([]), np.array([])

    width = max(len(r) for r in runs)
    M = np.full((len(runs), width), np.nan)
    for i, r in enumerate(runs):
        M[i, :len(r)] = r

    valid = ~np.isnan(M)
    n = valid.sum(axis=0)  # number of contributing runs per episode

    total = np.where(valid, M, 0.0).sum(axis=0)
    mean = np.full(width, np.nan)
    np.divide(total, n, out=mean, where=n > 0)

    # Unbiased variance; the max(.., 1) guards keep n == 1 at variance 0
    # instead of dividing by zero.
    sq_dev = np.where(valid, (M - mean) ** 2, 0.0).sum(axis=0)
    var = sq_dev / np.maximum(n - 1, 1)
    se = np.sqrt(var / np.maximum(n, 1))
    se[n == 0] = np.nan
    return mean, se

def plot_band(y_runs, title, ylabel):
    """
    Draw the across-run mean as a line with a shaded ±1 SE band.

    Plots onto the current matplotlib figure/axes (it does not create one),
    then sets the title, axis labels, legend and a light grid.
    """
    center, spread = mean_se(y_runs)
    episodes = np.arange(len(center))

    plt.plot(episodes, center, lw=2.2, label=f"mean ({len(y_runs)} seeds)")
    plt.fill_between(episodes, center - spread, center + spread,
                     alpha=0.25, label="±1 SE")

    plt.title(title)
    plt.xlabel("Episode")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(alpha=0.2)

# One figure per metric: the acquisition portion (pre-switch) first,
# then the degradation portion (post-switch).
for _runs, _title, _ylabel in (
    (KL_acq_runs, "KL (acquisition)", "KL (nats)"),
    (JS_acq_runs, "JS (acquisition)", "JS"),
    (H_acq_runs,  "Avg H(S'|S) (acquisition)", "nats"),
    (MS_acq_runs, "Markovization (acquisition)", "[0,1]"),
    (KL_deg_runs, "KL (contingency degradation)", "KL (nats)"),
    (JS_deg_runs, "JS (contingency degradation)", "JS"),
    (H_deg_runs,  "Avg H(S'|S) (contingency degradation)", "nats"),
    (MS_deg_runs, "Markovization (contingency degradation)", "[0,1]"),
):
    plt.figure(figsize=(10, 4))
    plot_band(_runs, _title, _ylabel)
    plt.show()

# Mark N_CD_ACQ on the last four figures that are still registered.
# NOTE(review): this runs after every plt.show(); whether those figures are
# still open depends on the backend, and the degradation x-axis restarts at
# 0 — confirm the marker shows up where intended.
for fig_num in plt.get_fignums()[-4:]:
    plt.figure(fig_num)
    plt.axvline(N_CD_ACQ, ls="--", color="tab:cyan", alpha=0.4)

# real

In [None]:
# ==========================================================
# Core utilities + Analytic Ground-Truth builders + Metrics
# ==========================================================
import numpy as np

# ---------------------------
# Small numeric helpers
# ---------------------------
EPS = 1e-12

def _safe_row_norm(x: np.ndarray, axis: int = -1, eps: float = EPS) -> np.ndarray:
    """
    Divide `x` by its sum along `axis`, returning a new float array.

    Slices whose sum is below `eps` are divided by 1 instead, so all-zero
    slices stay zero rather than producing NaN/inf.
    """
    values = x.astype(float, copy=True)
    totals = values.sum(axis=axis, keepdims=True)
    divisors = np.where(totals < eps, 1.0, totals)
    return values / divisors

def _pad3(A: np.ndarray, shape):
    """
    Fit a 3D tensor into a zero-filled float array of exactly `shape`.

    The overlapping leading block is copied over: axes where `shape` is
    larger are zero-padded at the end, axes where it is smaller are cropped.
    """
    src_s, src_a, src_s2 = A.shape
    dst_s, dst_a, dst_s2 = shape
    overlap = (
        slice(0, min(src_s, dst_s)),
        slice(0, min(src_a, dst_a)),
        slice(0, min(src_s2, dst_s2)),
    )
    fitted = np.zeros((dst_s, dst_a, dst_s2), dtype=float)
    fitted[overlap] = A[overlap]
    return fitted

def _agg(T: np.ndarray) -> np.ndarray:
    """Sum T over its action axis (axis 1) and row-normalise the [S,S] result."""
    return _safe_row_norm(T.sum(axis=1), axis=1)

def _row_kl(p: np.ndarray, q: np.ndarray, eps: float = EPS) -> float:
    """KL(p || q) in nats after clipping both vectors to [eps, 1] and renormalising."""
    p_clipped = np.clip(p, eps, 1.0)
    q_clipped = np.clip(q, eps, 1.0)
    p_clipped /= p_clipped.sum()
    q_clipped /= q_clipped.sum()
    log_gap = np.log(p_clipped) - np.log(q_clipped)
    return float(np.sum(p_clipped * log_gap))

def _row_js(p: np.ndarray, q: np.ndarray, eps: float = EPS) -> float:
    """Jensen–Shannon divergence (nats) of two vectors, clipped to [eps, 1] and renormalised."""
    p_clipped = np.clip(p, eps, 1.0)
    p_clipped /= p_clipped.sum()
    q_clipped = np.clip(q, eps, 1.0)
    q_clipped /= q_clipped.sum()
    midpoint = 0.5*(p_clipped+q_clipped)
    kl_p = _row_kl(p_clipped, midpoint, eps)
    kl_q = _row_kl(q_clipped, midpoint, eps)
    return 0.5*kl_p + 0.5*kl_q

def _row_H(p: np.ndarray, eps: float = EPS) -> float:
    """Shannon entropy (nats) of a vector after clipping to [eps, 1] and renormalising."""
    probs = np.clip(p, eps, 1.0)
    probs /= probs.sum()
    return float(-np.sum(probs * np.log(probs)))

# ---------------------------
# Deterministic base transitions from geometry
# ---------------------------
def base_T_from_env(env):
    """
    Deterministic [S, A, S] tensor from grid geometry (no clones).

    Reads env.num_unique_states, env.valid_actions (state -> action ids),
    env.state_to_pos, env.pos_to_state, env.base_actions (action -> (di, dj))
    and the two terminal lists. A is the largest listed action id + 1, or 0
    when no state lists any action (including when every action list is
    empty, which previously raised ValueError from max()).

    Rows of terminal states are zeroed, i.e. terminals have no outgoing mass.
    """
    S = int(env.num_unique_states)
    # default=-1 makes "no actions anywhere" yield A == 0 instead of raising.
    n_actions = max(
        (a for acts in env.valid_actions.values() for a in acts),
        default=-1,
    ) + 1
    T = np.zeros((S, n_actions, S), dtype=float)

    # deterministic moves: successor = position + action offset
    for s, acts in env.valid_actions.items():
        for a in acts:
            i, j = env.state_to_pos[s]
            di, dj = env.base_actions[a]
            sp = env.pos_to_state[(i + di, j + dj)]
            T[s, a, sp] = 1.0

    # Zero the rows of terminals; unpacking accepts list or tuple containers.
    for t in (*env.rewarded_terminals, *env.unrewarded_terminals):
        if 0 <= t < S:
            T[t, :, :] = 0.0
    return T

# internal alias used by some earlier snippets
_base_T_from_env = base_T_from_env

# ---------------------------
# Analytic ground-truth builders
# ---------------------------
def gt_acquisition_with_clones(env, cue_states):
    """
    Acquisition ground truth: clone each immediate successor of every cue.

    Starting from the base tensor, each distinct successor reached from a cue
    gets one new state appended at the end (shared if several cue edges reach
    the same successor). The clone's row is a copy of the successor's
    outgoing transitions, and the cue edge is moved from the successor to the
    clone. Terminal rows are zeroed again at the end.
    """
    T = base_T_from_env(env)
    clone_index = {}  # successor state -> index of its clone

    for cue in cue_states:
        for a in env.get_valid_actions(cue):
            row, col = env.state_to_pos[cue]
            d_row, d_col = env.base_actions[a]
            successor = env.pos_to_state[(row + d_row, col + d_col)]

            clone = clone_index.get(successor)
            if clone is None:
                # New clone takes the next free index; grow both state axes.
                clone = T.shape[0]
                T = np.pad(T, ((0, 1), (0, 0), (0, 1)), mode='constant')
                # Copy the successor's outgoing row (all pre-existing columns).
                T[clone, :, :-1] = T[successor, :, :-1]
                clone_index[successor] = clone

            # Move the cue -> successor mass onto cue -> clone.
            moved = T[cue, a, successor]
            T[cue, a, successor] = 0.0
            T[cue, a, clone] = moved

    # Safety pass: terminals keep no outgoing mass.
    n_states = T.shape[0]
    for t in (env.rewarded_terminals + env.unrewarded_terminals):
        if 0 <= t < n_states:
            T[t, :, :] = 0.0
    return T

def gt_extinction_no_clones(env2):
    """
    Extinction ground truth: base graph, no clones, reward terminal unreachable.

    For each rewarded terminal that has an unrewarded terminal at the same
    list position, every positive transition into the rewarded terminal is
    moved to that unrewarded terminal.
    """
    T = base_T_from_env(env2)
    # zip stops at the shorter list, so unpaired rewarded terminals are skipped.
    for rt, nt in zip(env2.rewarded_terminals, env2.unrewarded_terminals):
        incoming = T[:, :, rt] > 0  # [S, A] mask of (state, action) pairs entering rt
        T[incoming, nt] = T[incoming, rt]
        T[incoming, rt] = 0.0
    return T

def gt_degradation_no_clones(env_always_reward):
    """
    Contingency-degradation ground truth.

    Only the transition structure is modelled, so this is the base tensor
    built from the env's geometry: no clones, terminal rows zeroed.
    """
    T_base = base_T_from_env(env_always_reward)
    return T_base

# Back-compat names some cells call:
def build_gt_acquisition_with_clones(env, cue_state):
    """Back-compat wrapper: single-cue call into gt_acquisition_with_clones."""
    single_cue = [cue_state]
    return gt_acquisition_with_clones(env, cue_states=single_cue)

def build_gt_extinction_no_clones(env2):
    """Back-compat wrapper around gt_extinction_no_clones."""
    T_ext = gt_extinction_no_clones(env2)
    return T_ext

def build_gt_degradation_no_clones(env_always_reward):
    """Back-compat wrapper around gt_degradation_no_clones."""
    T_deg = gt_degradation_no_clones(env_always_reward)
    return T_deg

# ---------------------------
# Metrics (fixed reference)
# ---------------------------
def kl_over_time_fixed(T_series, T_ref_fixed, use_js: bool = False):
    """
    Per-episode divergence between each learned tensor and a fixed reference.

    Every T in `T_series` is fitted to the reference's shape with `_pad3`,
    both sides are aggregated over actions with `_agg`, and the row-wise KL
    (or JS when `use_js` is True) is averaged over rows.
    Returns a 1D array with one value per entry of `T_series`.
    """
    row_divergence = _row_js if use_js else _row_kl
    ref_shape = T_ref_fixed.shape
    Q = _agg(T_ref_fixed)

    per_episode = []
    for T in T_series:
        P = _agg(_pad3(T, ref_shape))
        per_row = [row_divergence(p_row, q_row) for p_row, q_row in zip(P, Q)]
        per_episode.append(float(np.mean(per_row)))
    return np.array(per_episode)

def entropy_over_time(T_series):
    """Return, per episode, the mean of ``_row_H`` over the rows of ``_agg(T)``.

    This is the average next-state entropy H(S'|S) tracked across episodes.
    """
    def _mean_row_entropy(T):
        P = _agg(T)
        return float(np.mean([_row_H(row) for row in P]))

    return np.array([_mean_row_entropy(T) for T in T_series])

def markovization_score(T: np.ndarray) -> float:
    """
    One minus the normalized mean row entropy of the aggregated transitions.

    Row entropies come from ``_row_H`` applied to each row of ``_agg(T)``;
    the normalizer is ``log(max(2, number_of_columns))``. Higher values mean
    more deterministic next-state distributions.

    NOTE(review): the result lies in [0, 1] provided ``_row_H`` returns a
    natural-log entropy of a proper distribution -- confirm against ``_row_H``.
    """
    P = _agg(T)
    row_entropies = np.array([_row_H(row) for row in P])
    # Floor of 2 keeps the normalizer strictly positive for tiny graphs.
    max_entropy = np.log(max(2, P.shape[1]))
    normalized = np.mean(row_entropies) / max_entropy
    return float(1.0 - normalized)

In [None]:
class GridEnvRightDownAlwaysReward(GridEnvRightDownNoSelf):
    """Variant of ``GridEnvRightDownNoSelf`` whose reward ignores the base reward.

    The parent's ``step`` supplies the next state and the done flag; the
    reward it returns is discarded and replaced as described in ``step``.
    """

    def step(self, action):
        """Advance one step and return ``(next_state, reward, done)``.

        The reward is 1 when the episode is done and ``is_terminal`` reports
        the next state as terminal; otherwise it is 0.
        """
        next_state, _base_reward, done = super().step(action)
        # NOTE(review): if the environment has no `is_terminal` attribute the
        # reward is never delivered -- confirm the parent class provides it.
        reached_terminal = (
            done and hasattr(self, "is_terminal") and self.is_terminal(next_state)
        )
        if reached_terminal:
            return next_state, 1, True
        return next_state, 0, done

In [None]:
def gt_degradation_no_clones(env_always_reward):
    """
    Degradation ground truth: the base transition graph, without clones.

    Always delivering the reward changes the reward signal, not the
    transition structure, so the reference is exactly what
    ``base_T_from_env`` returns for this environment.

    NOTE(review): the original comment states that terminals have no
    outgoing edges and incoming edges are deterministic; that depends on
    ``base_T_from_env`` and is not enforced here.
    """
    return base_T_from_env(env_always_reward)


In [None]:
def run_latent_inhibition_seed(seed:int,
                               cfg:CoDAConfig,
                               pre_episodes:int = 500,
                               acq_episodes:int = 1000,
                               max_steps:int = 20,
                               cue:int = 5):
    """
    Run one seed of the latent-inhibition protocol and score it.

    Phase 1 (pre-exposure): ``pre_episodes`` episodes in
    ``GridEnvRightDownNoCue``; the agent is updated but never asked to split.
    Phase 2 (acquisition): the same agent is switched to
    ``GridEnvRightDownNoSelf`` for ``acq_episodes`` episodes and
    ``maybe_split`` is called after every update. Once a split is reported,
    episodes are generated with ``generate_dataset_post_augmentation``.

    A copy of the agent's transition tensor is stored after every episode
    of both phases, and the whole series is scored against the
    acquisition ground truth with clones.
    (Protocol reference: paper pp. 3-4, Fig. 3b; Methods p. 8.)

    Returns a dict with keys ``T_series``, ``T_ref``, ``KL``, ``JS``, ``H``, ``MS``.
    """
    np.random.seed(seed)

    history = []

    # -------- Phase 1: pre-exposure --------
    env_pre = GridEnvRightDownNoCue(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    agent = CoDAAgent(env_pre, cfg)

    for _ in range(pre_episodes):
        states, actions = generate_dataset(env_pre, n_episodes=1, max_steps=max_steps)[0]
        agent.update_with_episode(states, actions)
        # No split attempt in this phase; the tensor is still recorded.
        history.append(agent.get_T().copy())

    # -------- Phase 2: acquisition --------
    env_acq = GridEnvRightDownNoSelf(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    # Same agent object: only its environment is swapped.
    agent.env = env_acq

    clones_exist = False
    for _ in range(acq_episodes):
        if clones_exist:
            episode = generate_dataset_post_augmentation(env_acq, agent.get_T(), n_episodes=1, max_steps=max_steps)[0]
        else:
            episode = generate_dataset(env_acq, n_episodes=1, max_steps=max_steps)[0]
        states, actions = episode

        agent.update_with_episode(states, actions)
        if agent.maybe_split():
            clones_exist = True

        history.append(agent.get_T().copy())

    # Reference for the whole run: acquisition graph with clones.
    T_ref = gt_acquisition_with_clones(env_acq, cue_states=[cue])

    return dict(
        T_series=history,
        T_ref=T_ref,
        KL=kl_over_time_fixed(history, T_ref, use_js=False),
        JS=kl_over_time_fixed(history, T_ref, use_js=True),
        H=entropy_over_time(history),
        MS=np.array([markovization_score(T) for T in history]),
    )


In [None]:
def run_contingency_degradation_seed(seed:int,
                                     cfg:CoDAConfig,
                                     acq_episodes:int = 1000,
                                     degr_episodes:int = 1000,
                                     max_steps:int = 20,
                                     cue:int = 5,
                                     wash_in:int = 50,
                                     edge_eps_early:float = 1e-4,
                                     edge_eps_late:float  = 1e-6):
    """
    Run one seed of the contingency-degradation protocol and score it.

    Phase 1 (acquisition): cued task in ``GridEnvRightDownNoSelf``;
    ``maybe_split`` is called after every episode.
    Phase 2 (degradation): the same agent continues in
    ``GridEnvRightDownAlwaysReward``; ``maybe_merge`` is called after every
    episode. Metrics are reported separately per phase: acquisition vs. the
    acquisition ground truth with clones, degradation vs. the no-clone
    degradation ground truth.
    (Protocol reference: paper p. 4, Eq. 3; Methods p. 8.)

    Args:
        seed: value passed to ``np.random.seed``.
        cfg: agent configuration handed to ``CoDAAgent``.
        acq_episodes: number of acquisition episodes.
        degr_episodes: number of degradation episodes.
        max_steps: per-episode step cap for the dataset generators.
        cue: cue state index.
        wash_in: number of initial degradation episodes that use
            ``edge_eps_early``; later episodes use ``edge_eps_late``.
        edge_eps_early: value written to ``agent._edge_eps_override`` during wash-in.
        edge_eps_late: value written to ``agent._edge_eps_override`` afterwards.

    Returns:
        dict with both environments, both tensor series, both references and
        the KL / JS / H / MS arrays for each phase (``*_acq`` and ``*_deg``).

    Side effects:
        Prints and plots the final graph of each phase. The degradation-phase
        config overrides are undone before returning, so the configuration is
        not left modified for a later call (relevant when the agent keeps a
        reference to the caller's ``cfg`` -- presumably the case).
    """
    np.random.seed(seed)

    # -------- Acquisition --------
    env_acq = GridEnvRightDownNoSelf(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    agent   = CoDAAgent(env_acq, cfg)

    T_series_acq = []
    with_clones = False

    for ep in range(1, acq_episodes+1):
        if with_clones:
            (states, actions) = generate_dataset_post_augmentation(env_acq, agent.get_T(), n_episodes=1, max_steps=max_steps)[0]
        else:
            (states, actions) = generate_dataset(env_acq, n_episodes=1, max_steps=max_steps)[0]

        agent.update_with_episode(states, actions)
        new = agent.maybe_split()
        if new:
            with_clones = True

        T_curr = agent.get_T().copy()
        T_series_acq.append(T_curr)

        # === plot last iteration graph ===
        if ep == acq_episodes:
            # Labelled as the acquisition phase of this (degradation) experiment.
            print(f"Contingency degradation — final acquisition graph (seed {seed})")
            T_vis = make_terminals_absorbing_for_plot(T_curr, env_acq.rewarded_terminals + env_acq.unrewarded_terminals)
            sanitize_for_plot(env_acq, T_vis)
            env_acq.plot_graph(T_vis, niter=ep, threshold=0.3, save=False, title=f"Contingency Degradation — Final Acquisition Graph (seed={seed})")
            plt.show()

    T_ref_acq = gt_acquisition_with_clones(env_acq, cue_states=[cue])

    # -------- Degradation (always reward; no reset) --------
    env_deg = GridEnvRightDownAlwaysReward(cue_states=[cue], env_size=(4,4), rewarded_terminal=[15])
    # carry learned clones into degradation env so merges are meaningful
    env_deg.clone_dict = dict(getattr(env_acq, "clone_dict", {}))
    env_deg.reverse_clone_dict = dict(getattr(env_acq, "reverse_clone_dict", {}))
    agent.env = env_deg

    # Snapshot every config field overridden below so it can be restored.
    # Two overrides are relative (+= 3, * 1.5) and would otherwise compound
    # each time this function runs with the same config object.
    tuned_fields = ("count_decay", "trace_decay", "retro_decay",
                    "theta_merge", "confidence",
                    "min_presence_episodes", "min_effective_exposure")
    orig = {name: getattr(agent.cfg, name) for name in tuned_fields}

    T_series_deg = []
    try:
        # Degradation-phase settings intended to encourage structural merges.
        agent.cfg.count_decay = 0.98
        agent.cfg.trace_decay = 0.98
        agent.cfg.retro_decay = 0.98
        agent.cfg.theta_merge = 0.60
        agent.cfg.confidence  = 0.99
        agent.cfg.min_presence_episodes += 3
        agent.cfg.min_effective_exposure = int(agent.cfg.min_effective_exposure * 1.5)

        for k in range(degr_episodes):
            (states, actions) = generate_dataset_post_augmentation(env_deg, agent.get_T(), n_episodes=1, max_steps=max_steps)[0]
            agent.update_with_episode(states, actions)
            # Looser edge epsilon during the first `wash_in` episodes.
            agent._edge_eps_override = edge_eps_early if k < wash_in else edge_eps_late
            agent.maybe_merge()

            T_curr = agent.get_T().copy()
            T_series_deg.append(T_curr)

            # === plot last iteration graph ===
            if k == degr_episodes - 1:
                print(f"Contingency degradation — final graph (seed {seed})")
                T_vis = make_terminals_absorbing_for_plot(T_curr, env_deg.rewarded_terminals + env_deg.unrewarded_terminals)
                sanitize_for_plot(env_deg, T_vis)
                env_deg.plot_graph(T_vis, niter=k+1, threshold=0.3, save=False, title=f"Contingency Degradation — Final Graph (seed={seed})")
                plt.show()
    finally:
        # Undo the overrides even if an episode raised.
        for name, value in orig.items():
            setattr(agent.cfg, name, value)

    T_ref_deg = gt_degradation_no_clones(env_deg)

    # Metrics (separate for each phase)
    KL_acq = kl_over_time_fixed(T_series_acq, T_ref_acq, use_js=False)
    JS_acq = kl_over_time_fixed(T_series_acq, T_ref_acq, use_js=True)
    H_acq  = entropy_over_time(T_series_acq)
    MS_acq = np.array([markovization_score(T) for T in T_series_acq])

    KL_deg = kl_over_time_fixed(T_series_deg, T_ref_deg, use_js=False)
    JS_deg = kl_over_time_fixed(T_series_deg, T_ref_deg, use_js=True)
    H_deg  = entropy_over_time(T_series_deg)
    MS_deg = np.array([markovization_score(T) for T in T_series_deg])

    return dict(
        env_acq=env_acq, env_deg=env_deg,
        T_series_acq=T_series_acq, T_series_deg=T_series_deg,
        T_ref_acq=T_ref_acq, T_ref_deg=T_ref_deg,
        KL_acq=KL_acq, JS_acq=JS_acq, H_acq=H_acq, MS_acq=MS_acq,
        KL_deg=KL_deg, JS_deg=JS_deg, H_deg=H_deg, MS_deg=MS_deg
    )


In [None]:
# -------- Latent inhibition: multi-seed --------
def run_latent_inhibition_many(cfg:CoDAConfig,
                               seeds:list,
                               pre_episodes:int=500,
                               acq_episodes:int=1000,
                               max_steps:int=20,
                               cue:int=5):
    """Run ``run_latent_inhibition_seed`` once per seed and collect the metrics.

    Returns a dict with the raw per-seed results under ``runs`` and, under
    ``KL``, ``JS``, ``H`` and ``MS``, a list holding that metric's array for
    each seed (in the order of ``seeds``).
    """
    runs = []
    for s in seeds:
        runs.append(run_latent_inhibition_seed(s, cfg, pre_episodes, acq_episodes, max_steps, cue))

    summary = {"runs": runs}
    for metric in ("KL", "JS", "H", "MS"):
        summary[metric] = [r[metric] for r in runs]
    return summary

def plot_latent_inhibition_summary(res_li, pre_episodes:int, title_suffix=""):
    """Plot the four latent-inhibition metrics, one figure per metric.

    Args:
        res_li: dict with per-seed metric lists under ``KL``, ``JS``, ``H``
            and ``MS`` (as returned by ``run_latent_inhibition_many``).
        pre_episodes: episode index at which pre-exposure ends; drawn as a
            dashed vertical line on every figure.
        title_suffix: text inserted before the closing parenthesis of each title.

    Each metric gets its own figure because ``plot_band`` draws on the
    current axes; without that the four curves would overlay on one plot
    and the phase marker could land on unrelated figures. Figures are left
    open (no ``plt.show()``) so the caller decides how to display or save them.
    """
    panels = (
        ("KL", f"KL (latent inhibition{title_suffix})", "KL (nats)"),
        ("JS", f"JS (latent inhibition{title_suffix})", "JS"),
        ("H",  f"Avg H(S'|S) (latent inhibition{title_suffix})", "nats"),
        ("MS", f"Markovization (latent inhibition{title_suffix})", "[0,1]"),
    )
    for key, title, ylabel in panels:
        plt.figure(figsize=(10, 4))
        plot_band(res_li[key], title, ylabel)
        # vertical line to show transition from pre-exposure to acquisition
        plt.axvline(pre_episodes, ls="--", alpha=0.4)

# -------- Degradation: multi-seed --------
def run_degradation_many(cfg:CoDAConfig,
                         seeds:list,
                         acq_episodes:int=1000,
                         degr_episodes:int=1000,
                         max_steps:int=20,
                         cue:int=5):
    """Run ``run_contingency_degradation_seed`` once per seed and collect the metrics.

    Returns a dict with the raw per-seed results under ``runs`` and one list
    per metric and phase (``KL_acq`` ... ``MS_deg``), each holding that
    metric's array for every seed in the order of ``seeds``.
    """
    runs = [
        run_contingency_degradation_seed(s, cfg, acq_episodes, degr_episodes, max_steps, cue)
        for s in seeds
    ]

    collected = {"runs": runs}
    metric_keys = ("KL_acq", "JS_acq", "H_acq", "MS_acq",
                   "KL_deg", "JS_deg", "H_deg", "MS_deg")
    for key in metric_keys:
        collected[key] = [r[key] for r in runs]
    return collected

def plot_degradation_summary(res_deg, title_suffix=""):
    """Plot the degradation-experiment metrics, one figure per metric and phase.

    Args:
        res_deg: dict with per-seed metric lists under ``KL_acq`` ... ``MS_deg``
            (as returned by ``run_degradation_many``).
        title_suffix: text inserted before the closing parenthesis of each title.

    Each of the eight curves gets its own figure because ``plot_band`` draws
    on the current axes; without that they would all overlay on one plot.
    Figures are left open (no ``plt.show()``) so the caller decides how to
    display or save them.
    """
    panels = (
        ("KL_acq", f"KL (acquisition — degradation runs{title_suffix})", "KL (nats)"),
        ("JS_acq", f"JS (acquisition — degradation runs{title_suffix})", "JS"),
        ("H_acq",  f"Avg H(S'|S) (acquisition — degradation runs{title_suffix})", "nats"),
        ("MS_acq", f"Markovization (acquisition — degradation runs{title_suffix})", "[0,1]"),
        ("KL_deg", f"KL (contingency degradation{title_suffix})", "KL (nats)"),
        ("JS_deg", f"JS (contingency degradation{title_suffix})", "JS"),
        ("H_deg",  f"Avg H(S'|S) (contingency degradation{title_suffix})", "nats"),
        ("MS_deg", f"Markovization (contingency degradation{title_suffix})", "[0,1]"),
    )
    for key, title, ylabel in panels:
        plt.figure(figsize=(10, 4))
        plot_band(res_deg[key], title, ylabel)


In [None]:
# --- Parameters (paper-style) ---
# Episode counts per phase: LI = latent inhibition, CD = contingency degradation.
N_LI_PRE = 500
N_LI_ACQ = 1000
N_CD_ACQ = 1000
N_CD_DEG = 1000
SEEDS    = list(range(30))  # 30 seeds

# Both experiments use a default-constructed CoDAConfig here.
# NOTE(review): the original note says to reuse the tuned "v2" acquisition
# config set earlier; confirm whether the defaults are really intended.
cfg_li = CoDAConfig()  # latent inhibition relies on v2-like gating & uncertainty
cfg_cd = CoDAConfig()  # degradation uses same; merges depend on retrospective fall

# -------- Latent inhibition --------
# NOTE(review): MAX_STEPS and CUE are not defined in this cell; they must
# already exist from an earlier cell.
res_li = run_latent_inhibition_many(cfg_li, SEEDS, pre_episodes=N_LI_PRE, acq_episodes=N_LI_ACQ, max_steps=MAX_STEPS, cue=CUE)
# NOTE(review): this calls plot_band, which this cell only defines further
# down; it presumably relies on a definition from an earlier cell.
plot_latent_inhibition_summary(res_li, pre_episodes=N_LI_PRE)

# -------- Contingency degradation --------
res_deg = run_degradation_many(cfg_cd, SEEDS, acq_episodes=N_CD_ACQ, degr_episodes=N_CD_DEG, max_steps=MAX_STEPS, cue=CUE)
# plot_degradation_summary(res_deg)

# ===== Contingency degradation: plots in the same style as acquisition/extinction =====

# Unpack the per-seed metric lists (one array per seed) for each phase.
KL_acq_runs = res_deg["KL_acq"]
JS_acq_runs = res_deg["JS_acq"]
H_acq_runs  = res_deg["H_acq"]
MS_acq_runs = res_deg["MS_acq"]

KL_deg_runs = res_deg["KL_deg"]
JS_deg_runs = res_deg["JS_deg"]
H_deg_runs  = res_deg["H_deg"]
MS_deg_runs = res_deg["MS_deg"]

# Reuse your mean±SE band helper
def mean_se(arrs):
    """Per-episode mean and standard error across runs of possibly unequal length.

    Args:
        arrs: sequence of 1D array-likes, one per run. Shorter runs are
            NaN-padded on the right; NaN entries are ignored.

    Returns:
        ``(mean, se)``: two 1D arrays as long as the longest run. For each
        episode, ``se`` is the sample standard deviation (ddof=1) divided by
        the square root of the number of runs that actually cover that
        episode. Where only one run covers an episode the standard error is
        0 (zero-width band) instead of NaN; where no run does, both values
        are NaN. An empty ``arrs`` yields two empty arrays.
    """
    runs = [np.atleast_1d(np.asarray(a, dtype=float)).ravel() for a in arrs]
    if not runs:
        return np.array([]), np.array([])

    L = max(len(r) for r in runs)
    M = np.full((len(runs), L), np.nan)
    for i, r in enumerate(runs):
        M[i, :len(r)] = r

    valid = ~np.isnan(M)
    n = valid.sum(axis=0)  # runs contributing to each episode

    mean = np.full(L, np.nan)
    np.divide(np.where(valid, M, 0.0).sum(axis=0), n, out=mean, where=n > 0)

    # Sum of squared deviations over the contributing runs only.
    sq_dev = np.where(valid, (M - mean) ** 2, 0.0).sum(axis=0)

    se = np.full(L, np.nan)
    se[n == 1] = 0.0  # a single observation has no spread to report
    multi = n > 1
    se[multi] = np.sqrt(sq_dev[multi] / (n[multi] - 1) / n[multi])
    return mean, se

def plot_band(y_runs, title, ylabel):
    """Draw the across-seed mean curve with a ±1 SE band on the current axes.

    ``y_runs`` is a sequence of per-seed metric arrays; mean and standard
    error come from ``mean_se``. No figure is created and nothing is shown:
    the caller sets up the figure beforehand and displays it afterwards.
    """
    center, spread = mean_se(y_runs)
    episodes = np.arange(len(center))
    n_seeds = len(y_runs)

    plt.plot(episodes, center, lw=2.2, label=f"mean ({n_seeds} seeds)")
    plt.fill_between(episodes, center - spread, center + spread, alpha=0.25, label="±1 SE")

    plt.title(title)
    plt.xlabel("Episode")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(alpha=0.2)

# One figure per metric: first the acquisition portion of the degradation
# runs (pre-switch), then the degradation portion (post-switch).
for _runs, _title, _ylabel in (
    (KL_acq_runs, "KL (acquisition)", "KL (nats)"),
    (JS_acq_runs, "JS (acquisition)", "JS"),
    (H_acq_runs,  "Avg H(S'|S) (acquisition)", "nats"),
    (MS_acq_runs, "Markovization (acquisition)", "[0,1]"),
    (KL_deg_runs, "KL (contingency degradation)", "KL (nats)"),
    (JS_deg_runs, "JS (contingency degradation)", "JS"),
    (H_deg_runs,  "Avg H(S'|S) (contingency degradation)", "nats"),
    (MS_deg_runs, "Markovization (contingency degradation)", "[0,1]"),
):
    plt.figure(figsize=(10,4))
    plot_band(_runs, _title, _ylabel)
    plt.show()


Set `save=True` in `plot_graph` to export PNGs for only those episodes in which the graph changed.

## Metrics: KL/JS vs episode, Entropy, and Markovization

In [None]:

# Compute metrics using the module we prepared
from coda_metrics import kl_over_time, entropy_over_time, markovization_score, ref_empirical_from_rollouts, greedy_right_down_policy
import numpy as np

def ref_builder_factory(env, policy_fn, nroll=300, max_steps=20):
    """Return a one-argument callable that builds an empirical reference.

    The returned function calls ``ref_empirical_from_rollouts`` with the
    captured environment and policy, ``nroll`` episodes and ``max_steps``
    steps. Its argument (the learned tensor) is accepted so the callable
    fits the reference-builder interface, but it is not used.
    """
    rollout_kwargs = {"n_episodes": nroll, "max_steps": max_steps}

    def _build_reference(T_learned):
        # T_learned is intentionally ignored; the reference depends only on
        # the captured env/policy and rollout settings.
        return ref_empirical_from_rollouts(env, policy_fn, **rollout_kwargs)

    return _build_reference

# Build episode-wise empirical references: one builder per phase, each
# running 300 rollouts of up to 20 steps under the greedy right/down policy.
# NOTE(review): `env`, `env2`, `T_series_acq` and `T_series_ext` are not
# defined in this cell; they must come from earlier cells.
ref_fn_acq = ref_builder_factory(env,  greedy_right_down_policy, nroll=300, max_steps=20)
ref_fn_ext = ref_builder_factory(env2, greedy_right_down_policy, nroll=300, max_steps=20)

# Acquisition-phase metrics (KL, JS, mean entropy, Markovization score).
KL_acq = kl_over_time(T_series_acq, ref_fn_acq, use_js=False)
JS_acq = kl_over_time(T_series_acq, ref_fn_acq, use_js=True)
H_acq  = entropy_over_time(T_series_acq)
MS_acq = np.array([markovization_score(T) for T in T_series_acq])

# Extinction-phase metrics, computed the same way against the extinction reference.
KL_ext = kl_over_time(T_series_ext, ref_fn_ext, use_js=False)
JS_ext = kl_over_time(T_series_ext, ref_fn_ext, use_js=True)
H_ext  = entropy_over_time(T_series_ext)
MS_ext = np.array([markovization_score(T) for T in T_series_ext])


In [None]:

# Plot (one metric per figure)
import matplotlib.pyplot as plt
import numpy as np

def _offset_plot(ax, y1, y2, label1, label2):
    """Plot two series end to end on ``ax``.

    ``y1`` is drawn against its own indices; ``y2`` is shifted right by
    ``len(y1)`` so it continues where the first series stops. A legend and
    the x-axis label "episode" are added.
    """
    start = len(y1)
    x_second = np.arange(len(y2)) + start

    ax.plot(y1, label=label1)
    ax.plot(x_second, y2, label=label2)
    ax.legend()
    ax.set_xlabel("episode")

# One figure per metric, acquisition followed by extinction on a shared x-axis.
for _title, _acq, _ext, _ylabel in (
    ("KL (learned || empirical reference)", KL_acq, KL_ext, "KL"),
    ("JS distance", JS_acq, JS_ext, "JS"),
    ("Avg next-state entropy H(S'|S)", H_acq, H_ext, "nats"),
    ("Markovization score (1 - normalized H)", MS_acq, MS_ext, "[0,1]"),
):
    fig, ax = plt.subplots()
    ax.set_title(_title)
    _offset_plot(ax, _acq, _ext, "acq", "ext")
    ax.set_ylabel(_ylabel)
    plt.show()


## Plots (separate panels with mean ± SE shading)

In [None]:

import numpy as np
import matplotlib.pyplot as plt

def _pad_runs(runs):
    """Stack runs of unequal length into one float matrix, NaN-padded on the right.

    Returns an array of shape ``(len(runs), longest_run_length)``.
    """
    lengths = [len(r) for r in runs]
    width = max(lengths)
    padded = np.full((len(runs), width), np.nan, dtype=float)
    for row, (run, n) in enumerate(zip(runs, lengths)):
        padded[row, :n] = run
    return padded

def _plot_with_band(ax, runs, title, ylabel):
    """Plot the across-run mean with a ±1 SE band on ``ax``.

    Args:
        ax: axes-like object providing ``plot``, ``fill_between``,
            ``set_title``, ``set_xlabel``, ``set_ylabel`` and ``legend``.
        runs: either a non-empty list/tuple of per-run sequences (padded to
            equal length with ``_pad_runs``) or a single array-like, which
            is promoted to 2D with ``np.atleast_2d``.
        title: axes title.
        ylabel: y-axis label.

    The standard error of each episode uses the number of runs that cover
    it (NaN entries are ignored). With a single covering run the band has
    zero width rather than being NaN, so single-run input plots cleanly.
    """
    is_run_list = (
        isinstance(runs, (list, tuple))
        and len(runs) > 0
        and isinstance(runs[0], (list, np.ndarray))
    )
    M = _pad_runs(runs) if is_run_list else np.atleast_2d(runs)
    M = np.asarray(M, dtype=float)

    valid = ~np.isnan(M)
    n = valid.sum(axis=0)  # runs contributing to each episode

    mean = np.full(M.shape[1], np.nan)
    np.divide(np.where(valid, M, 0.0).sum(axis=0), n, out=mean, where=n > 0)

    # Sample variance (ddof=1) needs at least two contributing runs.
    sq_dev = np.where(valid, (M - mean) ** 2, 0.0).sum(axis=0)
    se = np.full(M.shape[1], np.nan)
    se[n == 1] = 0.0
    multi = n > 1
    se[multi] = np.sqrt(sq_dev[multi] / (n[multi] - 1) / n[multi])

    x = np.arange(len(mean))
    ax.plot(x, mean, lw=2.0, label="mean")
    ax.fill_between(x, mean - se, mean + se, alpha=0.2, label="±1 SE")
    ax.set_title(title)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.legend()

# Normalize every metric to a list of runs: values that are already a
# list/tuple of runs pass through, a single-run array becomes [array].
KL_acq_runs = KL_acq if isinstance(KL_acq, (list, tuple)) else [KL_acq]
JS_acq_runs = JS_acq if isinstance(JS_acq, (list, tuple)) else [JS_acq]
H_acq_runs  = H_acq  if isinstance(H_acq,  (list, tuple)) else [H_acq]
MS_acq_runs = MS_acq if isinstance(MS_acq, (list, tuple)) else [MS_acq]

KL_ext_runs = KL_ext if isinstance(KL_ext, (list, tuple)) else [KL_ext]
JS_ext_runs = JS_ext if isinstance(JS_ext, (list, tuple)) else [JS_ext]
H_ext_runs  = H_ext  if isinstance(H_ext,  (list, tuple)) else [H_ext]
MS_ext_runs = MS_ext if isinstance(MS_ext, (list, tuple)) else [MS_ext]

# Acquisition-only figure: 2x2 grid filled row by row (KL, JS / H, Markovization).
fig, axes = plt.subplots(2, 2, figsize=(10,6), constrained_layout=True)
for _ax, (_runs, _title, _ylabel) in zip(axes.flat, (
    (KL_acq_runs, "KL (acquisition)", "KL (nats)"),
    (JS_acq_runs, "JS (acquisition)", "JS"),
    (H_acq_runs,  "Avg H(S'|S) (acquisition)", "nats"),
    (MS_acq_runs, "Markovization (acquisition)", "[0,1]"),
)):
    _plot_with_band(_ax, _runs, _title, _ylabel)
plt.show()

# Extinction-only figure: same layout.
fig, axes = plt.subplots(2, 2, figsize=(10,6), constrained_layout=True)
for _ax, (_runs, _title, _ylabel) in zip(axes.flat, (
    (KL_ext_runs, "KL (extinction)", "KL (nats)"),
    (JS_ext_runs, "JS (extinction)", "JS"),
    (H_ext_runs,  "Avg H(S'|S) (extinction)", "nats"),
    (MS_ext_runs, "Markovization (extinction)", "[0,1]"),
)):
    _plot_with_band(_ax, _runs, _title, _ylabel)
plt.show()
