In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"

from gymnax.environments.classic_control.cartpole import EnvParams, CartPole
from rejax import get_algo
from rejax.evaluate import evaluate
from tqdm.notebook import tqdm

import _pickle as pickle
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import yaml


In [2]:
config_path = "./configs/custom/cartpole.yaml"

# Parse the experiment YAML. safe_load reads straight from the open stream,
# and the file is closed as soon as parsing finishes.
with open(config_path, "r") as config_file:
    config = yaml.safe_load(config_file)

In [3]:
# Experiment settings.
algo_name = "ppo"
num_seeds = 500  # agents per output file (one sampled gravity each in the training loop)
num_files = 5
starting_seed_id = 0
max_steps_in_episode = 200

# Narrow the loaded YAML down to this algorithm's hyperparameters and attach
# the environment instance so it is forwarded to `algo_cls.create(**config)`.
# NOTE: this rebinds `config` to the sub-dict, so re-running this cell on its
# own raises KeyError; re-run the YAML-loading cell first.
config = config[algo_name]
env = CartPole()
config["env"] = env

In [4]:
# Display the resolved algorithm config (including the `env` entry assigned above).
config

{'agent_kwargs': {'activation': 'tanh'},
 'num_envs': 1,
 'num_steps': 100,
 'num_epochs': 5,
 'num_minibatches': 5,
 'learning_rate': 0.00075,
 'max_grad_norm': 0.5,
 'total_timesteps': 250000,
 'eval_freq': 5000,
 'gamma': 0.99,
 'gae_lambda': 0.95,
 'clip_eps': 0.2,
 'ent_coef': 0.01,
 'vf_coef': 0.5,
 'buffer_size': 250000,
 'env': <gymnax.environments.classic_control.cartpole.CartPole at 0x7680c4ac1f90>}

In [None]:
# Train `num_seeds` agents per output file, each with its own sampled CartPole
# gravity, then save the stored training buffer, the config and rollouts of the
# final policies to one pickle per file.

# TODO: make this configurable instead of an absolute, user-specific path.
save_dir = "/home/bryanpu1/projects/aaai_2026/data/cartpole"
os.makedirs(save_dir, exist_ok=True)

# Everything below is identical for every file, so it is built once. Reusing the
# same function objects (including the callback stored on `algo`) lets the jitted
# train function hit its compilation cache instead of re-tracing per file.
algo_cls = get_algo(algo_name)


def eval_callback(algo, ts, rng):
    """Roll out the policy in `ts` for 50 episodes under `algo.env_params`."""
    act = algo.make_act(ts)
    return evaluate(act, rng, env, algo.env_params, 50, max_steps_in_episode)


def get_returns(algo, ts, rng):
    """In-training evaluation callback: keep only episode lengths and returns."""
    eval_info = eval_callback(algo, ts, rng)
    return eval_info.length, eval_info.return_


vmap_train = jax.jit(jax.vmap(algo_cls.train, in_axes=(0, 0)))

for file_i in tqdm(range(num_files)):
    seed_id = starting_seed_id + file_i

    # Split the file's root key once into independent streams. A JAX key must
    # not be consumed by more than one random operation.
    key = jax.random.PRNGKey(seed_id)
    train_key, gravity_key, eval_key = jax.random.split(key, 3)
    keys = jax.random.split(train_key, num_seeds)
    eval_keys = jax.random.split(eval_key, num_seeds)

    gravities = jax.random.uniform(gravity_key, shape=(num_seeds,))

    # One algorithm instance per gravity (batched along the leading axis).
    algo = jax.vmap(
        lambda gravity: algo_cls.create(
            **config,
            env_params=EnvParams(
                gravity=gravity,
                max_steps_in_episode=max_steps_in_episode,
            ),
        )
    )(gravities)
    algo = algo.replace(eval_callback=get_returns)

    # Train
    ts, (_, returns) = vmap_train(algo, keys)
    returns.block_until_ready()

    # "Expert data" from the last PPO iteration, with keys not used for training.
    eval_info = jax.vmap(eval_callback)(algo, ts, eval_keys)

    # Save data
    record = {
        "buffer_info": {k: v for k, v in ts.store_buffer.__dict__.items() if k != "data"},
        "data": {k: np.array(v) for k, v in ts.store_buffer.data._asdict().items()},
        "algorithm": {
            "algo": algo_name,
            **{k: v for k, v in config.items() if k != "env"},
        },
        "env": type(config["env"]).__name__,
        "env_params": np.array(gravities),
        # Spaces are built from default EnvParams(), not the per-agent gravities.
        "observation_space": env.observation_space(EnvParams()),
        "action_space": env.action_space(EnvParams()),
        "expert_data": eval_info.trajectory,
    }
    save_path = os.path.join(
        save_dir,
        "learning_hist-cartpole-num_tasks_{}-seed_{}-{}.pkl".format(
            num_seeds,
            seed_id,
            algo_name,
        ),
    )
    # `with` guarantees the file is flushed and closed even if pickling fails.
    with open(save_path, "wb") as out_file:
        pickle.dump(record, out_file)


  0%|          | 0/5 [00:00<?, ?it/s]

  return lax_numpy.astype(self, dtype, copy=copy, device=device)


In [None]:
# Intentional stop for "Run All". The cells below are exploratory plots over
# the `ts` / `returns` left behind by the last iteration of the training loop.
# An explicit `raise` is used because a bare `assert` is stripped under
# `python -O`, which would let those cells run.
raise AssertionError("Stopping before the exploratory analysis cells below.")

In [None]:
# Cumulative regret of the stored training transitions for the first 5 seeds of
# the last file (leading axis comes from the vmapped train).
# NOTE(review): the original referenced `env_params`, which no cell defines
# (NameError on a fresh kernel). The per-step optimum is now an explicit
# constant. This assumes CartPole pays at most 1.0 reward per step; confirm.
max_step_reward = 1.0

rewards = np.asarray(ts.store_buffer.data.reward)
# Use the last axis for both the step count and the cumsum so their shapes agree.
steps = np.arange(1, rewards.shape[-1] + 1)
regrets = steps * max_step_reward - np.cumsum(rewards, axis=-1)

fig, ax = plt.subplots()
for regret in regrets[:5]:
    ax.plot(np.arange(len(regret)), regret)
ax.set(
    xlabel="Stored transition index",
    ylabel="Cumulative regret",
    title="Cumulative regret during training (first 5 seeds)",
)
plt.show()


In [None]:
# Per-evaluation return regret (mean ± standard error over evaluation episodes)
# for the first 2 seeds of the last file.
# NOTE(review): the original used `env_params`, which no cell defines. The best
# achievable return is now taken as `max_steps_in_episode`, which assumes at most
# 1.0 reward per step; confirm.
max_return = max_steps_in_episode

fig, ax = plt.subplots()
for env_returns in returns[:2]:
    env_returns = np.asarray(env_returns)
    xrange = np.arange(len(env_returns)) * config["eval_freq"]
    regret = max_return - env_returns
    mean = np.mean(regret, axis=-1)
    sem = np.std(regret, axis=-1) / np.sqrt(regret.shape[-1])
    ax.plot(xrange, mean)
    ax.fill_between(xrange, mean - sem, mean + sem, alpha=0.2)
ax.set(
    xlabel="Environment steps",
    ylabel="Return regret (mean ± s.e.)",
    title="Evaluation return regret (first 2 seeds)",
)
plt.show()


In [None]:
# Evaluation return curves (mean ± standard error over evaluation episodes) for
# the first 10 seeds of the last file. x-axis is the eval index scaled by eval_freq.
fig, ax = plt.subplots()
for env_returns in returns[:10]:
    env_returns = np.asarray(env_returns)
    xrange = np.arange(len(env_returns)) * config["eval_freq"]
    mean = np.mean(env_returns, axis=-1)
    # Standard error of the mean, not the raw standard deviation.
    sem = np.std(env_returns, axis=-1) / np.sqrt(env_returns.shape[-1])
    ax.plot(xrange, mean)
    ax.fill_between(xrange, mean - sem, mean + sem, alpha=0.2)
ax.set(
    xlabel="Environment steps",
    ylabel="Evaluation return (mean ± s.e.)",
    title="Evaluation returns during training (first 10 seeds)",
)
plt.show()
