In [7]:
import numpy as np
from ROCRL_environment import LinearSEMGenerator, LinearUtility, ROCRLEnvironment
from algos import *
from metrics import *
from utilities import get_acyclicity_fallback_count, reset_acyclicity_fallback_count

Assume observations $X\in\mathbb{R}^d$ are generated by $X=GZ$ for some $G\in\mathbb{R}^{d\times n}$ and latent vector $Z=(Z[1],\ldots,Z[n])^T\in\mathbb{R}^n$. Assume the latent variables $Z[i]$ are causally connected in a graph $\mathcal{G}$ via the SEM $Z=BZ+\epsilon$ with $B\in \mathbb{R}^{n\times n}$ the edge-weight matrix and $\epsilon\in\mathbb{R}^n$ the exogenous noise vector with mean $E[\epsilon]:=\nu$. Allow two kinds of intervention mechanisms -- soft and hard -- both of which alter the linear causal mechanism as follows:

$$
\underbrace{Z[i] = [B]_iZ+\epsilon[i]}_{\text{observational}} \longrightarrow \underbrace{Z[i]=[B^*]_iZ+\epsilon^*[i]}_{\text{interventional}}
$$

$[B^*]_i$ denotes the vector of post-intervention edge weights and $\epsilon^*[i]$ denotes the post-intervention noise with mean $E[\epsilon^*[i]]:=\nu^*[i]$. Soft interventions introduce a new $[B^*]_i$ while hard interventions set $[B^*]_i=0$. Define a reward $U$ based on the latent variables by $U(Z):=\theta^TZ+\epsilon_U$ where $E[\epsilon_U]=0$. Our goal is to find the set of interventions $a\in\mathcal{A}:=2^{[n]}$ that maximizes the expected utility: 

$$
a^*:=\text{argmax}_{a\in\mathcal{A}}E_a[U(Z)]
$$

In [8]:
n_latents = 5
d_obs = 10
sem = LinearSEMGenerator(n = n_latents, d = d_obs, seed = 3, latent_noise_std = 1.0)
util = LinearUtility(n = n_latents, noise_std = 1.0, theta_dist = "rademacher", seed = 1)

In [9]:
env = ROCRLEnvironment(sem = sem, utility = util)
# ---- compute oracle over all actions ----
intervention_type = "hard"  
kind_of = lambda a: ("none" if len(a) == 0 else intervention_type)

actions = list(powerset_actions(n_latents))
vals = np.array([expected_utility_under_action(env, a, kind_of(a)) for a in actions])

best_idx = int(np.argmax(vals))
best_action = actions[best_idx]
best_val = float(vals[best_idx])

# show top-k
k = min(10, len(actions))
top_idx = np.argsort(vals)[::-1][:k]

print("=" * 80)
print("TRUE GRAPH (B[i,j] != 0 means j -> i)")
print((np.abs(env.sem.B) > 1e-10).astype(int).T)  # displayed as i->j
print("Theta:", env.utility.theta)
print(f"Computed exact E[U] for all {len(actions)} actions (kind='{intervention_type}' for non-empty).")
print("\nTop actions:")
for rank, idx in enumerate(top_idx):
    print(f"{rank:2d}: a={fmt_action(actions[idx])}   E[U]={vals[idx]:.6f}")

print("\nTRUE OPTIMAL ACTION:", fmt_action(best_action))
print("TRUE OPTIMAL EXPECTED UTILITY:", best_val)
print("=" * 80)

TRUE GRAPH (B[i,j] != 0 means j -> i)
[[0 0 1 0 1]
 [0 0 0 1 0]
 [0 0 0 0 1]
 [0 0 0 0 1]
 [0 0 0 0 0]]
Theta: [-1.  1.  1.  1. -1.]
Computed exact E[U] for all 32 actions (kind='hard' for non-empty).

Top actions:
 0: a={0,1,4}   E[U]=1.646177
 1: a={4}   E[U]=1.646177
 2: a={0,4}   E[U]=1.646177
 3: a={1,4}   E[U]=1.646177
 4: a={3,4}   E[U]=1.344436
 5: a={0,1,3,4}   E[U]=1.344436
 6: a={1,3,4}   E[U]=1.344436
 7: a={0,3,4}   E[U]=1.344436
 8: a={0,1,2,4}   E[U]=1.301741
 9: a={1,2,4}   E[U]=1.301741

TRUE OPTIMAL ACTION: {4}
TRUE OPTIMAL EXPECTED UTILITY: 1.6461772269022772


Hyperparameters:
- $\delta,\epsilon_{\text{max}}$
- $\zeta_t$


Notations:
- $H_t :=$ estimate of $G^{\dagger}$, the Moore-Penrose inverse of $G$, at time $t$
- $\hat{\mathcal{G}}_t:=$ estimate of $\mathcal{G}$ at time $t$
- $\text{pa}_t(i):=$ parents of node $i$ in $\hat{\mathcal{G}}_t$
- $\mathcal{H}_\text{H}(i)=\text{pa}(i)$, $\mathcal{H}_\text{S}(i)=\text{an}(i)$, and $\mathcal{H}_m(i)\in\{\mathcal{H}_\text{S}(i),\mathcal{H}_\text{H}(i)\}$
- $N(\epsilon,\delta):=C^2\text{max}\{\epsilon^{-2},\epsilon_{\text{max}}^{-2}\}(d+\log(1/\delta))$ with $\delta_t=6\delta/(\pi^2t^2)$, the delta schedule for under-sampling thresholding
- $N_{a,t}:=\sum_{s\in[t]}1\{a_s=a\}$, number of times intervention (set) $a$ is selected up to and including time $t$
- $\mu_{a,t}=\frac{1}{N_{a,t}}\sum_{s\in[t]}1\{a_s=a\}X_s$, empirical sample mean of $X$ under intervention $a$ up to and including time $t$
- $\Sigma_{a,t}:=\frac{1}{N_{a,t}}\sum_{s\in[t]}1\{a_s=a\}X_sX_s^T-\mu_{a,t}\mu_{a,t}^T$, empirical covariance matrix of $X$ under intervention $a$ up to and including time $t$
- $\Theta_{\{i\},t}:=(\Sigma_{a,t})^{\dagger}$, empirical precision matrix of $X$ under intervention $a$ up to and including time $t$
- $R_{i,t}:=\Theta_{\{i\},t}-\Theta_{\emptyset, t}$, empirical precision difference between single node $\{i\}$ interventional and observational data up to and including time $t$

R0-CRL Algorithm:
1. Forced exploration
2. Adaptive exploration step $t$
   1. Latent recovery (inputs: $X_t,R_{i,t}, a_t$)
      1. Update inverse transform estimate $H_t \leftarrow$ principal eigenvector of $R_{i,t}$
      2. Estimate latent variables with $\hat{Z}_t \leftarrow H_tX_t$
      3. Update graph estimate $\hat{\mathcal{G}}_t$
         1. Compute estimated latent precision differences $\hat{R}^Z_{i,t} = (H_t^\dagger)^TR_{i,t}H_t^\dagger$
         2. Assign edges according to $i\to j \in \hat{\mathcal{G}}_t\iff i\neq j\text{ and }||[\hat{R}^Z_{i,t}]_j||_2>\gamma$
         3. Set $\hat{\mathcal{G}}_t$ to the transitive closure
      4. If previous action $a_t=\{i\}$ was a hard intervention, refine latent recovery:
         1. Re-update inverse transform estimate $H_t$ by subtracting off residual correlation
            1. Compute estimated latent covariances matrices with $\hat{\Sigma}_{a,t}^Z:=H_t\Sigma_{a,t}H_t^T$
            2. Regress parents of node $i$ onto $i$ via $\Xi_t[i,\text{pa}_t(i)]:=\hat{\Sigma}_{a,t}^Z[i,\text{pa}_t(i)]\left(\hat{\Sigma}_{a,t}^Z[\text{pa}_t(i),\text{pa}_t(i)]\right)^{-1}$
            3. Subtract off remaining relationship and re-update $H_t\leftarrow (I-\Xi_t)H_t$
         2. Update estimated latent variables with $\hat{Z}_t \leftarrow H_tX_t$
         3. Re-update graph estimate using earlier method with updated $H_t$, only do not take transitive closure
   2. Sample from under-explored interventions
      1. Compute graph-informed $u$ for intervention type $m\in\{\text{S},\text{H}\}$ of $a_t$. 
      $$
      u_{m,i}=\begin{cases}
      0 & \text{if $i$ is a root node}\\
      \sum_{j\in\mathcal{H}_m(i)}u_{m,j}+\sqrt{|\mathcal{H}_m(i)} & \text{otherwise}
      \end{cases},\qquad u_m=\sum_{i=1}^nu_{m,i}+\sqrt{n}
      $$
      2. Compute threshold $f_t(\hat{\mathcal{G}}):=\text{max}\{d^{1/3}n^{-2/3}u^{2/3}t^{2/3}, N(\epsilon_{\text{max}},\delta_t)\}$
      3. Compute $\mathcal{A}_t^{UE}:=\{a\in\mathcal{A}_0|N_{a,t}<f_t(\hat{\mathcal{G}})\}$
      4. If $\mathcal{A}_t^{UE}$ is not empty, sample $a_{t+1}$ from $\mathcal{A}_t^{UE}$ and skip last two steps (parameter estimation and ucb selection)
   3. Parameter estimation
      1. Compute weight matrices
      2. Solve $[\hat{Z}_t]_i=[B]_i[\hat{Z}_t]_{\text{pa}_t(i)}+\epsilon$ for $[B]_i$ by ridge regressing $[\hat{Z}_t]_{\text{pa}_t(i)}$ on $[\hat{Z}_t]_i$ with computed weight matrices. Denote estimate $A_t$
      3. Solve $U_t=\theta^T\hat{Z}_t+\epsilon_U$ for $\theta$ by ridge regressing $\hat{Z}_t$ on $U_t$ with computed weight matrices
   4. UCB selection
      1. Compute confidence ellipsoids $\mathcal{C}_{a,t}$ under intervention $a\in\mathcal{A}$ for estimated parameters $A_t$, $A^*_t$, and $\theta_t$. 
      2. Compute $\text{UCB}_{a,t}:=\text{max}_{\{\tilde{A},\tilde{\theta}\}\in \mathcal{C}_{a,t}}\langle \tilde{\theta},\sum_{\ell=0}^{L_t}\tilde{A}^{\ell}\cdot\hat{\nu}_a\rangle$ for all $a\in\mathcal{A}$.
      3. Pull $a_{t+1}=\text{argmax}_{a\in\mathcal{A}}\text{UCB}_{a,t}$

In [10]:
def fmt_action(a: set) -> str:
    if len(a) == 0:
        return "∅"
    return "{" + ",".join(str(i) for i in sorted(list(a))) + "}"

def fmt_AUE(AUE, n):
    # AUE is list of action-ids in {0,...,n}
    # 0 -> ∅, k -> {k-1}
    as_actions = [id_to_action(a_id, n) for a_id in AUE]
    return [fmt_action(a) for a in as_actions]

def adj_edges(adj: np.ndarray):
    edges = []
    if adj is None:
        return edges
    n = adj.shape[0]
    for i in range(n):
        for j in range(n):
            if adj[i, j] != 0:
                edges.append((i, j))
    return edges

def compute_S_and_adj(learner, gamma=None):
    """
    Returns:
      S: (n,n) where S[i,j] = || [RhatZ_{i,t}]_{j,:} ||_2  (i is intervened node, j is candidate child)
      A: (n,n) adjacency induced by threshold gamma (if gamma is None, uses learner.gamma)
    Notes:
      - diagonal is set to 0
      - requires learner.H and at least some data for obs + singleton i
    """
    n = learner.n_latent
    if learner.H is None:
        return None, None

    gamma = learner.gamma if gamma is None else float(gamma)

    H = learner.H
    Hpinv = np.linalg.pinv(H)      # d x n
    Hpinv_T = Hpinv.T              # n x d

    # observational precision
    Theta0 = learner._precision_for_action_id(0)

    S = np.zeros((n, n), dtype=float)
    A = np.zeros((n, n), dtype=int)

    for i in range(n):
        # need at least some obs and singleton samples for i
        if learner.counts_A0.get(0, 0) < 2 or learner.counts_A0.get(i+1, 0) < 2:
            continue

        Thetai = learner._precision_for_action_id(i+1)
        Ri = Thetai - Theta0
        RhatZ = Hpinv_T @ Ri @ Hpinv      # n x n

        for j in range(n):
            if i == j:
                continue
            val = float(np.linalg.norm(RhatZ[j, :], 2))
            S[i, j] = val
            if val > gamma:
                A[i, j] = 1

    return S, A

import matplotlib.pyplot as plt

def plot_S_heatmap(S, title="S heatmap"):
    plt.figure()
    plt.imshow(S, aspect="auto")  # uses default colormap
    plt.colorbar()
    plt.title(title)
    plt.xlabel("j (row index whose norm is taken)")
    plt.ylabel("i (intervention index)")
    plt.show()

def plot_S_with_stars(S, gamma, title=None):
    """
    Plot heatmap of S (row-norm matrix) and overlay '*' where S[i,j] > gamma.
    Diagonal is ignored.
    """
    n = S.shape[0]

    plt.figure()
    plt.imshow(S, aspect="auto")
    plt.colorbar()

    # Overlay stars
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            if S[i, j] > gamma:
                plt.text(j, i, "*", ha="center", va="center")

    plt.xlabel("j")
    plt.ylabel("i")
    if title is not None:
        plt.title(title)
    else:
        plt.title(f"S with stars (gamma={gamma:.4g})")

    plt.show()

def plot_adj_heatmap(A, title="Adjacency (thresholded)"):
    plt.figure()
    plt.imshow(A, aspect="auto")
    plt.colorbar()
    plt.title(title)
    plt.xlabel("j")
    plt.ylabel("i")
    plt.show()

from matplotlib import animation
from IPython.display import HTML

def animate_S_robust(diag_log, A_true, gamma=None, interval=400,
                     clip_lo=1, clip_hi=99, log_scale=False):
    t_list = diag_log["t"]
    S_list = diag_log["S"]
    n = S_list[0].shape[0]
    has_gamma_series = ("gamma" in diag_log) and (diag_log["gamma"] is not None) and (len(diag_log["gamma"]) == len(t_list))

    pooled = []
    offdiag_mask = ~np.eye(n, dtype=bool)
    for S in S_list:
        vals = S[offdiag_mask]
        vals = vals[np.isfinite(vals)]
        pooled.append(vals)
    pooled = np.concatenate(pooled) if len(pooled) else np.array([0.0])

    vmin = np.percentile(pooled, clip_lo)
    vmax = np.percentile(pooled, clip_hi)
    if vmin == vmax:
        vmax = vmin + 1e-12

    def get_gamma(k):
        if gamma is not None:
            return float(gamma)
        if has_gamma_series:
            return float(diag_log["gamma"][k])
        vals = S_list[k][offdiag_mask]
        return float(np.median(vals))

    def transform(S):
        S_disp = S.copy()
        np.fill_diagonal(S_disp, 0.0)
        if log_scale:
            S_disp = np.log1p(np.maximum(S_disp, 0.0))
        return np.clip(S_disp, vmin if not log_scale else np.log1p(vmin), vmax if not log_scale else np.log1p(vmax))

    true_coords = np.argwhere(A_true == 1)

    fig, ax = plt.subplots()
    S0 = transform(S_list[0])
    im = ax.imshow(S0, aspect="auto")
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("j")
    ax.set_ylabel("i")

    # true edges overlay
    if true_coords.size > 0:
        xs = true_coords[:, 1]
        ys = true_coords[:, 0]
        ax.scatter(xs, ys, marker="o", s=200, facecolors="none",
                   edgecolors="white", linewidths=2.5, zorder=5)

    # ---- UPDATED STAR DRAWING: use scatter instead of text ----
    star_scatter = ax.scatter([], [], marker="*", s=100, edgecolors="white", linewidths=0.5, zorder=7, facecolors="red")

    def update(k):
        S = S_list[k]
        Sd = transform(S)
        im.set_data(Sd)

        g = get_gamma(k)
        ax.set_title(f"S at t={t_list[k]}  gamma={g:.4g}  "
                     f"{'log1p' if log_scale else ''}  clip=[{clip_lo},{clip_hi}]%")

        # stars based on original S
        xs = []
        ys = []
        for i in range(n):
            for j in range(n):
                if i == j:
                    continue
                if S[i, j] > g:
                    xs.append(j)
                    ys.append(i)

        if len(xs) == 0:
            star_scatter.set_offsets(np.empty((0, 2)))
        else:
            star_scatter.set_offsets(np.column_stack([xs, ys]))

        return (im, star_scatter)

    ani = animation.FuncAnimation(fig, update, frames=len(S_list), interval=interval, blit=False)
    plt.close(fig)
    return HTML(ani.to_jshtml())

def true_edge_coords_from_B(B, eps=1e-12):
    return np.argwhere(np.abs(B) > eps)
    
def animate_matrix_sequence(est_log, key, interval=400, clip_lo=1, clip_hi=99, log_scale=False,
                            true_matrix=None, title_prefix=None):
    """
    key in {"A", "Astar"}.
    true_matrix: optional, same shape, used for overlaying true edges.
    """
    t_list = np.array(est_log["t"], dtype=int)
    M_list = est_log[key]

    idx0 = next((k for k,M in enumerate(M_list) if M is not None), None)
    if idx0 is None:
        raise ValueError(f"No snapshots found for est_log['{key}'] (all None).")

    M0 = M_list[idx0]
    n = M0.shape[0]

    # pooled robust limits (over available entries)
    pooled = np.concatenate([M.flatten() for M in M_list if M is not None])
    vmin = np.percentile(pooled, clip_lo)
    vmax = np.percentile(pooled, clip_hi)
    if vmin == vmax:
        vmax = vmin + 1e-12

    def transform(M):
        Md = M.copy()
        np.fill_diagonal(Md, 0.0)
        if log_scale:
            # allow negative weights? If yes, use signed log1p
            Md = np.sign(Md) * np.log1p(np.abs(Md))
            vmin_d = np.sign(vmin) * np.log1p(abs(vmin))
            vmax_d = np.sign(vmax) * np.log1p(abs(vmax))
        else:
            vmin_d, vmax_d = vmin, vmax
        return np.clip(Md, vmin_d, vmax_d), vmin_d, vmax_d

    fig, ax = plt.subplots()
    Md0, vmin_d, vmax_d = transform(M0)
    im = ax.imshow(Md0, aspect="auto")
    im.set_clim(vmin=vmin_d, vmax=vmax_d)
    fig.colorbar(im, ax=ax)

    ax.set_xlabel("parent j")
    ax.set_ylabel("child i")
    prefix = title_prefix if title_prefix is not None else key
    ax.set_title(f"{prefix} at t={t_list[idx0]}")

    # overlay true edges as hollow circles (visible)
    if true_matrix is not None:
        coords = true_edge_coords_from_B(true_matrix)
        if coords.size > 0:
            xs = coords[:, 1]
            ys = coords[:, 0]
            ax.scatter(xs, ys, marker="o", s=160, facecolors="none",
                       edgecolors="white", linewidths=2.0, zorder=5)

    def update(k):
        M = M_list[k]
        if M is None:
            ax.set_title(f"{prefix} at t={t_list[k]} (missing)")
            return (im,)

        Md, _, _ = transform(M)
        im.set_data(Md)
        ax.set_title(f"{prefix} at t={t_list[k]}")
        return (im,)

    ani = animation.FuncAnimation(fig, update, frames=len(M_list), interval=interval, blit=False)
    plt.close(fig)
    return HTML(ani.to_jshtml())

def animate_action_frequency(action_log, n_latents, window=100, interval=300):
    t_list = np.array(action_log["t"], dtype=int)
    a_ids  = np.array(action_log["action_id"], dtype=int)
    policies = np.array(action_log["policy"])

    n_actions = n_latents + 1
    labels = ["∅"] + [f"{{{i}}}" for i in range(n_latents)]

    # --- build frame indices: every 100 until first ucb, then every step ---
    ucb_idxs = np.where(policies == "ucb")[0]
    if len(ucb_idxs) == 0:
        first_ucb = len(policies)  # no ucb at all
    else:
        first_ucb = int(ucb_idxs[0])

    pre = list(range(0, first_ucb, 100))
    # ensure we include the last pre-ucb step so you can see the transition
    if first_ucb > 0 and (first_ucb - 1) not in pre:
        pre.append(first_ucb - 1)

    post = list(range(first_ucb, len(policies)))
    frame_idxs = sorted(set(pre + post))
    if len(frame_idxs) == 0:
        frame_idxs = [0]

    fig, ax = plt.subplots()
    x = np.arange(n_actions)
    bars = ax.bar(x, np.zeros(n_actions))
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_ylim(0, 1.0)
    ax.set_ylabel(f"frequency (last {window})")

    def update(frame_pos):
        # frame_pos is an index into frame_idxs
        k = frame_idxs[frame_pos]  # actual round index in logs

        lo = max(0, k - window + 1)
        window_ids = a_ids[lo:k+1]
        counts = np.bincount(window_ids, minlength=n_actions).astype(float)
        freqs = counts / max(len(window_ids), 1)

        for b, h in zip(bars, freqs):
            b.set_height(float(h))

        cum = policies[:k+1]
        n_under = int(np.sum(cum == "under-sample"))
        n_ucb   = int(np.sum(cum == "ucb"))

        ax.set_title(
            f"Action frequency (last {min(window, k+1)} rounds) | "
            f"t={t_list[k]} | undersample={n_under} | ucb={n_ucb}"
        )
        return tuple(bars)

    ani = animation.FuncAnimation(
        fig,
        update,
        frames=len(frame_idxs),
        interval=interval,
        blit=False
    )
    plt.close(fig)
    return HTML(ani.to_jshtml())


In [13]:
rng = np.random.default_rng(0)

intervention_type = "soft"   # "soft" or "hard"
zeta_t = 1.0                 # weight parameter for doubly-weighted regressions
num_mc = 64                  # MC samples for UCB evaluation

reset_acyclicity_fallback_count()
learner = ROCRLLearner(
    n_latent=n_latents,
    d_obs=d_obs,
    gamma=100,
    epsmax=0.25,
    delta=0.05,
)

def true_adj_from_B_in_learner_orientation(B, eps=1e-12):
    # returns A_true where A_true[i,j]=1 means i -> j
    return (np.abs(B.T) > eps).astype(int)

A_true = true_adj_from_B_in_learner_orientation(env.sem.B)
B_true = env.sem.B
Bstar_soft_true = env.sem.B_star_soft
Bstar_hard_true = env.sem.B_star_hard

print("True adjacency (learner orientation i->j):\n", A_true)

# How many samples per action in A0 during forced exploration
m0 = 5   # try 5–20; larger -> better early estimates, slower start

A0_actions = [set()] + [{i} for i in range(n_latents)]

for a in A0_actions:
    kind = "none" if len(a) == 0 else intervention_type
    batch = env.step(num_samples=m0, action=a, kind=kind, return_latents=False)
    X = batch["X"]
    U = batch["U"]
    for s in range(X.shape[0]):
        learner.observe(X[s], float(U[s]), a)

print("Forced exploration done.")
print("Counts in A0:", {k: learner.counts_A0[k] for k in range(n_latents + 1)})

T = 10000
T_0 = m0 * len(A0_actions)

from typing import Set

def id_to_action(a_id: int, n: int) -> Set[int]:
    if a_id == 0:
        return set()
    i = a_id - 1
    if i < 0 or i >= n:
        raise ValueError("action id out of range")
    return {i}

regret = []
actions = []
policies = []   # "under-sample" vs "ucb"
AUE_sizes = []
thresholds = []
did_fit_flags = []

heatmaps = True
show_metrics = False

diag_log = {"t": [], "S": [], "gamma": []}
est_log = {"t": [], "A": [], "Astar": [], "theta": []}
action_log = {"t": [], "action_id": [], "policy": []}
ucb_count = 0

for t in range(T):
    if ucb_count > 10:
        continue
    # 1) Update learner state (H, Zhat, graph, under-explored set; and possibly fit params)
    AUE, thresh, did_fit = learner.learner_update(
        intervention_type=intervention_type,
        zeta_t=zeta_t
    )
        
    # 2) Choose action
    if len(AUE) > 0:
        # Under-sampling: choose uniformly from under-explored A0 actions
        pick_id = int(rng.choice(AUE))        # in {0,1,...,n}
        a = id_to_action(pick_id, n_latents)  # ∅ or {i}
        policy = "under-sample"
    else:
        for i in range(len(learner.pat)):
            print(learner.pat[i])
        est_log["t"].append(len(learner.X_hist))
        est_log["A"].append(None if learner.A is None else np.array(learner.A, dtype=float))
        est_log["Astar"].append(None if learner.Astar is None else np.array(learner.Astar, dtype=float))
        est_log["theta"].append(None if learner.theta is None else np.array(learner.theta, dtype=float))
        print(learner.G_adj)
        print(learner.A)
        ucb_count += 1
        if ucb_count == 1:
            print(f"First UCB at t={t + T_0}")
        else:
            if ucb_count % 5 == 0:
                print(f"UCB at t={t + T_0}")
        # Exploitation/exploration via UCB
        a = learner.choose_action_with_ucb(
            candidate_actions=None,           # defaults to [∅, {0},...,{n-1}]
            intervention_type=intervention_type,
            rng=rng,
            num_mc=num_mc,
        )
        policy = "ucb"
    
    
    action_log["t"].append(len(learner.X_hist))
    action_log["action_id"].append(action_to_id(a, n_latents))  # 0=∅, k={k-1}
    action_log["policy"].append(policy)

    # 3) Interact with environment
    kind = "none" if len(a) == 0 else intervention_type
    batch = env.step(num_samples=1, action=a, kind=kind, return_latents=False)
    x_t = batch["X"][0]
    u_t = float(batch["U"][0])

    

    # ---- diagnostics logging ----
    if heatmaps:
        if t % 100 == 0:
            

            diag_log["t"].append(t + T_0)
            diag_log["gamma"].append(learner.gamma)
            S_t, _ = compute_S_and_adj(learner)
            diag_log["S"].append(S_t)

            if len(AUE) == 0:
                print(f"UCB ")
            
    
    # 4) Observe
    learner.observe(x_t, u_t, a)         

    # 5) Logging
    actions.append(set(a))
    policies.append(policy)
    AUE_sizes.append(len(AUE))
    thresholds.append(int(thresh))
    did_fit_flags.append(bool(did_fit))
    regret.append(best_val - u_t)


print("Acyclicity fallback count:", get_acyclicity_fallback_count())
print("Cumulative regret (approx):", float(np.sum(regret)))
print("UCB rounds:", sum(p == "ucb" for p in policies), "Under-sample rounds:", sum(p == "under-sample" for p in policies))
animate_S_robust(diag_log, A_true, clip_lo=1, clip_hi=99, log_scale=True)



True adjacency (learner orientation i->j):
 [[0 0 1 0 1]
 [0 0 0 1 0]
 [0 0 0 0 1]
 [0 0 0 0 1]
 [0 0 0 0 0]]
Forced exploration done.
Counts in A0: {0: 5, 1: 5, 2: 5, 3: 5, 4: 5, 5: 5}
[]
[]
[]
[]
[]
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
First UCB at t=2826
[]
[]
[]
[]
[]
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
[]
[]
[]
[]
[]
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
[]
[]
[]
[]
[]
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
[]
[]
[]
[]
[]
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]
[[0. 0. 0. 0. 0.]
 [0. 0.

In [86]:
# Animate A (estimate of B)
animate_matrix_sequence(
    est_log, key="A", interval=400, clip_lo=1, clip_hi=99, log_scale=False,
    true_matrix=B_true, title_prefix="A (B_hat)"
)


In [87]:
# Animate Astar (estimate of B*)
animate_matrix_sequence(
    est_log, key="Astar", interval=400, clip_lo=1, clip_hi=99, log_scale=False,
    true_matrix=Bstar_soft_true, title_prefix="Astar (B*_hat, soft)"
)


In [99]:
full_window = len(action_log["t"])
animate_action_frequency(action_log, n_latents, window=100, interval=300)