In [None]:
import dill

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Which experiment to inspect: the algorithm family and the run directory name.
algo_name = "bandit_ad"
run_name = "adamw-06-06-25_09_56_26-8373a959-1e98-4fe9-bedb-bd9a3e42a6b5"  # No sink + weight decay 1e-4

# Root directory that holds one sub-folder per algorithm, each with one folder per run.
# NOTE(review): machine-specific absolute path -- adjust when running elsewhere.
base_path = "/home/bryanpu1/projects/aaai_2026/scaling_jax/results"

In [None]:
# Load the evaluation info saved for the selected run.
# The context manager closes the file handle even if deserialization fails.
# NOTE: dill (like pickle) can execute arbitrary code while loading -- only
# open result files that come from a trusted source.
with open(f"{base_path}/{algo_name}/{run_name}/eval_info.dill", "rb") as eval_file:
    data = dill.load(eval_file)

In [None]:
# List the top-level entries available in the loaded evaluation info.
data.keys()

In [None]:
# Plot evaluation results; the layout depends on the algorithm family in `algo_name`.
if algo_name.startswith("xland"):
    # One figure per entry of `data`, plotting its values against their index.
    for k, v in data.items():
        plt.figure()
        plt.plot(range(len(v)), v, label=k)
        plt.title(f"Distribution of returns for {k}")
        plt.xlabel("Episode")
        plt.ylabel("Return")
        plt.show()
elif algo_name.startswith("bandit"):
    # One subplot per environment showing cumulative regret; the curve restarts
    # every `switch_freq` rounds, where a new task segment begins.
    (num_envs, num_eps) = data["episode_lengths"].shape
    switch_freq = int(data["switch_freq"])
    num_tasks = int(np.ceil(num_eps / switch_freq))

    num_cols = 5
    # Derive the row count from `num_cols` so the grid stays consistent if the
    # column count is ever changed.
    num_rows = int(np.ceil(num_envs / num_cols))

    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(5 * num_cols, 5 * num_rows),
        layout="constrained",
        squeeze=False,  # always return a 2-D array so flatten() works for any grid shape
    )

    axes = axes.flatten()

    for env_i in range(num_envs):
        ax = axes[env_i]
        ax.set_title(f"Regret for Env {env_i}")
        ax.set_xlabel("Rounds")
        ax.set_ylabel("Regret")

        for task_i in range(num_tasks):
            task_start = task_i * switch_freq

            # Cumulative reward collected within this task segment (the last
            # segment may be shorter than `switch_freq`).
            rews = np.cumsum(
                data["episode_returns"][env_i, task_start:task_start + switch_freq],
                axis=-1,
            )
            rounds = np.arange(len(rews))

            # Cumulative reward of the best entry of env_params for this task.
            # NOTE(review): presumably env_params holds per-arm expected rewards -- confirm.
            opt = (rounds + 1) * np.max(data["env_params"][env_i, task_i], axis=-1)
            regret = opt - rews
            ax.plot(task_start + rounds, regret)

            # Dashed line at the start of each task segment; only the very first
            # one carries a label so the figure legend has a single entry.
            ax.axvline(
                x=task_start,
                label="Task i" if task_i == 0 and env_i == 0 else "",
                linestyle="--",
                color="black",
            )

    # Hide grid cells that have no environment assigned to them.
    for unused_ax in axes[num_envs:]:
        unused_ax.set_visible(False)

    fig.legend(
        bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
        loc="lower center",
        ncols=5,
        borderaxespad=0.0,
        frameon=True,
        fontsize="8",
    )
    # plt.show() rather than fig.show(): the latter needs a GUI backend and is a
    # no-op or a warning on the non-GUI backends notebooks use.
    plt.show()

In [None]:
# Deliberate stop: halt a top-to-bottom run here for xland runs, so the cells
# below are only reached for other algorithms.
assert not algo_name.startswith("xland")

In [None]:
# Show the "action_counts" entry of the evaluation info.
print(data["action_counts"])

## Check policy stochasticity

In [None]:
import optax
import jax

In [None]:
# Normalise the stored logits over their last axis to get action probabilities.
action_probs = jax.nn.softmax(data["logits"], axis=-1)

# The cross-entropy between a softmax distribution and itself is that
# distribution's entropy, so this yields one entropy value per logits vector.
entropies = optax.safe_softmax_cross_entropy(
    logits=data["logits"],
    labels=action_probs,
)

In [None]:
# Spot-check: probability at index 0 of the last (softmax) axis, for the first
# entry on axis 0, over the first 52 positions on axis 1.
# NOTE(review): presumably env 0 / action 0 over the first 52 rounds; the
# choice of 52 is not explained here -- confirm.
action_probs[0, :52, 0]

In [None]:
# Largest and smallest entropy over all entries, as a quick check of the range.
np.max(entropies), np.min(entropies)

In [None]:
# One subplot per environment showing the policy entropy per round, drawn as a
# separate curve for each `switch_freq`-long task segment.
(num_envs, num_eps) = data["episode_lengths"].shape
switch_freq = int(data["switch_freq"])
num_tasks = int(np.ceil(num_eps / switch_freq))

num_cols = 5
# Derive the row count from `num_cols` so the grid stays consistent if the
# column count is ever changed.
num_rows = int(np.ceil(num_envs / num_cols))

fig, axes = plt.subplots(
    num_rows,
    num_cols,
    figsize=(5 * num_cols, 5 * num_rows),
    layout="constrained",
    squeeze=False,  # always return a 2-D array so flatten() works for any grid shape
)

axes = axes.flatten()

for env_i in range(num_envs):
    ax = axes[env_i]
    ax.set_title(f"Policy Entropy for Env {env_i}")
    ax.set_xlabel("Rounds")
    ax.set_ylabel("Entropy")

    for task_i in range(num_tasks):
        task_start = task_i * switch_freq

        # Entropies for this task segment (the last segment may be shorter
        # than `switch_freq`).
        entropy = entropies[env_i, task_start:task_start + switch_freq]
        rounds = np.arange(len(entropy))
        ax.plot(task_start + rounds, entropy)

        # Dashed line at the start of each task segment; only the very first
        # one carries a label so the figure legend has a single entry.
        ax.axvline(
            x=task_start,
            label="Task i" if task_i == 0 and env_i == 0 else "",
            linestyle="--",
            color="black",
        )

# Hide grid cells that have no environment assigned to them.
for unused_ax in axes[num_envs:]:
    unused_ax.set_visible(False)

fig.legend(
    bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
    loc="lower center",
    ncols=5,
    borderaxespad=0.0,
    frameon=True,
    fontsize="8",
)
# plt.show() rather than fig.show(): the latter needs a GUI backend and is a
# no-op or a warning on the non-GUI backends notebooks use.
plt.show()