In [None]:
import os
# Request the CPU platform for JAX. This cell runs ahead of the `import jax`
# further down, presumably so the variable is in place before JAX initialises.
# NOTE(review): newer JAX releases appear to prefer `JAX_PLATFORMS` over
# `JAX_PLATFORM_NAME` — confirm against the installed version.
os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [None]:
import dill

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Global seaborn styling for every figure in this notebook.
# NOTE(review): `sns.set` is documented as an alias of `sns.set_theme`, so the
# second call (all defaults) looks like it replaces the "ticks" style requested
# by the first — confirm which style is actually intended and drop the other.
sns.set(style="ticks")
sns.set_theme()

In [None]:
# Root directory of the results tree; runs are addressed below as
# <base_path>/<algo_name>/<run_name>.
base_path = "/home/bryanpu1/projects/aaai_2026/scaling_jax/results"

# Candidate (algo_name, run_name) pairs kept around for quick switching.
# Only the LAST entry is selected; move the run of interest to the bottom.
_candidate_runs = (
    ("bandit_ad", "adamw-06-09-25_10_17_25-dd55f7aa-c8f9-49f9-b58c-a8af9d8e6d69"),
    ("bandit_dpt", "adamw-06-09-25_10_12_16-0accc7c0-d4f9-42ae-b70e-8b3c590d90e1"),
    # ("bandit_dpt", "adamw-06-10-25_14_11_46-4a5795d8-b591-4af2-96b1-1d247753dc83"),
    ("bandit_dpt", "adamw-06-10-25_22_15_07-a934d939-7314-4334-925d-4751dfed506c"),
    ("ns_bandit_ad", "adamw-06-10-25_17_34_41-9f12ef85-c567-4d61-a1f2-dec104e22df1"),
)

algo_name, run_name = _candidate_runs[-1]

In [None]:
# Load the `eval_info.dill` file stored under the selected algorithm/run
# directory. The `with` block closes the file handle once loading finishes
# (the bare `open(...)` call left it open).
# NOTE(review): dill, like pickle, can execute arbitrary code while
# deserialising — only load files produced by trusted runs.
eval_info_path = f"{base_path}/{algo_name}/{run_name}/eval_info.dill"
with open(eval_info_path, "rb") as eval_info_file:
    data = dill.load(eval_info_file)

In [None]:
# List the top-level entries available in the loaded evaluation data.
data.keys()

In [None]:
# Segment length used by the plots below to split each environment's rounds
# into consecutive tasks (read from the stored evaluation config).
switch_freq = data["eval_config"]["switch_freq"]

In [None]:

# Plot the running-average regret of every evaluation environment: one panel
# per environment, one curve per task segment of `switch_freq` rounds.
(num_envs, num_eps) = data["episode_lengths"].shape

num_cols = 5
# Derive the row count from `num_cols` so the grid stays consistent if the
# column count is ever changed.
num_rows = int(np.ceil(num_envs / num_cols))
# Number of task segments; the final one may be shorter than `switch_freq`.
num_tasks = int(np.ceil(num_eps / switch_freq))

fig, axes = plt.subplots(
    num_rows,
    num_cols,
    figsize=(5 * num_cols, 5 * num_rows),
    layout="constrained",
    squeeze=False,  # always return a 2-D array, whatever the grid shape
)
axes = axes.ravel()

for env_i in range(num_envs):
    ax = axes[env_i]
    ax.set_title(f"Regret for Env {env_i}")
    ax.set_xlabel("Rounds")
    ax.set_ylabel("Regret")

    for task_i in range(num_tasks):
        start = task_i * switch_freq
        # Cumulative return within this task segment only.
        rews = np.cumsum(
            data["episode_returns"][env_i, start:start + switch_freq],
            axis=-1,
        )
        rounds = np.arange(len(rews))
        # Largest entry of this task's parameters along the last axis —
        # presumably the best arm's expected reward; confirm the layout of
        # `env_params`.
        opt = np.max(data["env_params"][env_i, task_i], axis=-1)
        # Optimum minus the running mean reward since the start of the segment.
        regret = opt - rews / (rounds + 1)
        ax.plot(start + rounds, regret)

        # Dashed line at the first round of each segment; only the very first
        # line carries a label so the legend shows a single entry.
        ax.axvline(
            x=start,
            label="Task i" if task_i == 0 and env_i == 0 else "",
            linestyle="--",
            color="black",
        )
    ax.set_ylim(0.0, 1.0)

# Hide grid cells that have no environment to show.
for unused_ax in axes[num_envs:]:
    unused_ax.set_visible(False)

fig.legend(
    bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
    loc="lower center",
    ncols=5,
    borderaxespad=0.0,
    frameon=True,
    fontsize=8,
)
# `plt.show()` rather than `fig.show()`: the latter is meant for GUI backends
# and warns under non-interactive notebook backends.
plt.show()

In [None]:
# Print the `action_counts` entry of the evaluation data for manual inspection.
print(data["action_counts"])

In [None]:
# Index of the largest entry along the last axis of `env_params`, i.e. the
# position whose value the regret plot uses as the optimum — presumably the
# best arm per environment/task; confirm the array layout.
print(np.argmax(data["env_params"], axis=-1))

In [None]:
# Display the raw `env_params` array for manual inspection.
data["env_params"]

## Check policy stochasticity

In [None]:
import optax
import jax

In [None]:
# Policy entropy from the stored logits: the cross-entropy of a softmax
# distribution against itself is that distribution's entropy.
logits = data["logits"]
action_probs = jax.nn.softmax(logits)
entropies = optax.safe_softmax_cross_entropy(logits, action_probs)

In [None]:
# Largest and smallest policy entropy across all stored logits.
np.max(entropies), np.min(entropies)

In [None]:
# Plot the policy entropy of every evaluation environment: one panel per
# environment, one curve per task segment of `switch_freq` rounds.
(num_envs, num_eps) = data["episode_lengths"].shape

num_cols = 5
# Derive the row count from `num_cols` so the grid stays consistent if the
# column count is ever changed.
num_rows = int(np.ceil(num_envs / num_cols))
# Number of task segments; the final one may be shorter than `switch_freq`.
num_tasks = int(np.ceil(num_eps / switch_freq))

fig, axes = plt.subplots(
    num_rows,
    num_cols,
    figsize=(5 * num_cols, 5 * num_rows),
    layout="constrained",
    squeeze=False,  # always return a 2-D array, whatever the grid shape
)
axes = axes.ravel()

for env_i in range(num_envs):
    ax = axes[env_i]
    ax.set_title(f"Policy Entropy for Env {env_i}")
    ax.set_xlabel("Rounds")
    ax.set_ylabel("Entropy")

    for task_i in range(num_tasks):
        start = task_i * switch_freq
        entropy = entropies[env_i, start:start + switch_freq]
        rounds = np.arange(len(entropy))
        ax.plot(start + rounds, entropy)

        # Dashed line at the first round of each segment; only the very first
        # line carries a label so the legend shows a single entry.
        ax.axvline(
            x=start,
            label="Task i" if task_i == 0 and env_i == 0 else "",
            linestyle="--",
            color="black",
        )

# Hide grid cells that have no environment to show.
for unused_ax in axes[num_envs:]:
    unused_ax.set_visible(False)

fig.legend(
    bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
    loc="lower center",
    ncols=5,
    borderaxespad=0.0,
    frameon=True,
    fontsize=8,
)
# `plt.show()` rather than `fig.show()`: the latter is meant for GUI backends
# and warns under non-interactive notebook backends.
plt.show()

In [None]:
# Display the raw `env_params` array again (same inspection as the earlier
# cell) for comparison with the entropy plots.
data["env_params"]