In [1]:
%load_ext autoreload
%autoreload 2

import os
import shutil
import sys

import jax
from omegaconf import OmegaConf
from orbax import checkpoint as ocp
import wandb

from rl.config.config import Config
from rl.mappo import init_config, make_train
from utils import RunnerState, init_or_restore_run




In [3]:
# Build the experiment config, then start a fresh run or resume from the latest
# checkpoint (plus the matching W&B run) found under the experiment directory.
config = OmegaConf.create(Config())
init_config(config)

# `overwrite=True` wipes any previous experiment directory before restoring.
if config.overwrite:
    shutil.rmtree(config.exp_dir, ignore_errors=True)

# Orbax keeps only the 2 most recent checkpoints and creates ckpt_dir if missing.
options = ocp.CheckpointManagerOptions(max_to_keep=2, create=True)
checkpoint_manager = ocp.CheckpointManager(config.ckpt_dir, options=options)

# NOTE(review): this same `rng` is also passed to the training call later on;
# if both consume it directly, init and training randomness are correlated.
rng = jax.random.PRNGKey(config.SEED)
(
    runner_state,
    env,
    scenario,
    latest_update_step,
    wandb_run_id,
    wandb_resume,
) = init_or_restore_run(config, checkpoint_manager, checkpoint_manager.latest_step(), rng)

# `latest_step()` is None when no checkpoint exists yet; count from 0 instead.
if latest_update_step is None:
    latest_update_step = 0

os.makedirs(config.exp_dir, exist_ok=True)

# To log under a specific W&B team, also pass `entity=config.ENTITY`.
run = wandb.init(
    project=config.PROJECT,
    tags=["MAPPO", config.MAP_NAME],
    config=OmegaConf.to_container(config),
    mode=config.WANDB_MODE,
    dir=config.exp_dir,
    id=wandb_run_id,
    resume=wandb_resume,
)

# Persist the (possibly newly assigned) run id next to the experiment;
# presumably read back by `init_or_restore_run` on resume -- verify.
wandb_run_id = run.id
run_id_path = os.path.join(config.exp_dir, "wandb_run_id.txt")
with open(run_id_path, "w") as f:
    f.write(wandb_run_id)

{'LR': 0.001, 'BATCH_SIZE': 1024, 'EPOCHS': 10, 'NUM_WORKERS': 4, 'NUM_ENVS': 128, 'NUM_STEPS': 128, 'TOTAL_TIMESTEPS': 10000000.0, 'FC_DIM_SIZE': 128, 'HIDDEN_DIM': 128, 'UPDATE_EPOCHS': 4, 'NUM_MINIBATCHES': 4, 'GAMMA': 0.99, 'GAE_LAMBDA': 0.95, 'CLIP_EPS': 0.2, 'SCALE_CLIP_EPS': False, 'ENT_COEF': 0.0, 'VF_COEF': 0.5, 'MAX_GRAD_NORM': 0.25, 'ACTIVATION': 'relu', 'OBS_WITH_AGENT_ID': True, 'MAP_NAME': '2s3z', 'SEED': 0, 'ANNEAL_LR': False, 'overwrite': False, 'ckpt_freq': 50, 'render_freq': 1, 'WANDB_MODE': 'run', 'ENTITY': '', 'PROJECT': 'waymax_saphne', 'NUM_ACTORS': -1, 'MINIBATCH_SIZE': -1, 'NUM_UPDATES': -1, 'exp_dir': '', 'ckpt_dir': '', 'max_num_objects': 1}
--- TOTAL STARTUP COSTS (make data_iter + env) = 80.29227455891669 s ---
 of which 

--- DATA ITER COSTS = 8.169934153556824e-05 s ---
--- NEXT DATA ITER COSTS = 80.29162715747952 s ---


In [6]:
# Build the jitted MAPPO training function and run it end to end.
# Set DISABLE_JIT = True to run eagerly while debugging (much slower).
DISABLE_JIT = False

# `jax.lax.scan` requires the carry's pytree structure -- including the exact
# Python class of every node -- to be identical between input and output. If
# `utils` was reloaded (e.g. by %autoreload) after the setup cell created
# `runner_state`, the reloaded code builds instances of a *new* `RunnerState`
# class object with the same name, and scan fails with "their Python types
# differ". Detect exactly that same-name/different-object situation up front.
_given_cls = type(runner_state)
_fresh_cls = sys.modules["utils"].RunnerState
if (
    _given_cls is not _fresh_cls
    and _given_cls.__module__ == _fresh_cls.__module__
    and _given_cls.__qualname__ == _fresh_cls.__qualname__
):
    raise RuntimeError(
        "`runner_state` is an instance of a stale `utils.RunnerState` class "
        "(the `utils` module was reloaded after it was created). Re-run the "
        "setup cell to rebuild `runner_state`, or restart the kernel."
    )

with jax.disable_jit(DISABLE_JIT):
    train_jit = jax.jit(
        make_train(
            config,
            checkpoint_manager,
            env=env,
            scenario=scenario,
            latest_update_step=latest_update_step,
            wandb_run_id=run.id,
        )
    )
    out = train_jit(rng, runner_state=runner_state)

# `out["runner_state"]` is a sequence: first item is the final RunnerState,
# last item is the number of updates performed.
final_carry = out["runner_state"]
n_updates = final_carry[-1]
runner_state: RunnerState = final_carry[0]

TypeError: Scanned function carry input and carry output must have the same pytree structure, but they differ:
  * the input carry runner_state is a <class 'utils.RunnerState'> but the corresponding component of the carry output is a <class 'utils.RunnerState'>, so their Python types differ

Revise the scanned function so that its output is a pair where the first element has the same pytree structure as the first argument.

: 