When running JAX training loops, how can I make ballpark estimates of GPU memory usage? GPU memory is one of the main bottlenecks in end-to-end deep RL training with JAX.

(Specifically for the case of `train_single_task_pushworld_all_hparam_lr.py`)

The different kinds of data that take up memory:

- PushWorld puzzles
- Transitions
- Model weights


In [1]:
import os
import shutil
import time
from dataclasses import asdict, dataclass, field
from functools import partial
from typing import List, Optional

import imageio
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optax
import orbax
import pyrallis
import wandb
from flax.jax_utils import replicate, unreplicate
from flax.training import orbax_utils
from flax.training.train_state import TrainState
import xminigrid.envs.pushworld as pushworld
from xminigrid.envs.pushworld.benchmarks import BenchmarkAll

from train_single_task_pushworld_all_hparam_lr import TrainConfig
from train_single_task_pushworld_all_hparam_lr import make_states
from utils_pushworld_all import Transition, calculate_gae, rollout

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration for the memory measurements below: 4096 parallel envs,
# 100 steps per rollout, bf16 enabled.
# num_train / num_test are deliberately not overridden here; the commented
# values show how a puzzle subset would be requested.
config = TrainConfig(
    enable_bf16=True,
    num_steps=100,
    num_envs=4096,
    total_timesteps=1_000_000,
    benchmark_id="level0_transformed_all",
    # num_train=1000,
    # num_test=100,
)


Num devices: 1, Num updates: 2


In [3]:
# Load the benchmark, then optionally cut its train/test puzzle sets down to
# the sizes requested in the config.
benchmark = pushworld.load_all_benchmark(config.benchmark_id)

# One key per split so the train and test subsets are drawn independently.
puzzle_rng = jax.random.key(config.puzzle_seed)
train_rng, test_rng = jax.random.split(puzzle_rng)

if config.num_train is None:
    # No limit requested: record the full train-set size on the config.
    config.num_train = benchmark.num_train_puzzles()
else:
    assert (
        config.num_train <= benchmark.num_train_puzzles()
    ), "num_train is larger than num train available in benchmark"
    # Keep the first num_train entries of a random permutation.
    perm = jax.random.permutation(train_rng, benchmark.num_train_puzzles())
    idxs = perm[: config.num_train]
    benchmark = benchmark.replace(train_puzzles=benchmark.train_puzzles[idxs])

if config.num_test is None:
    # No limit requested: record the full test-set size on the config.
    config.num_test = benchmark.num_test_puzzles()
else:
    assert (
        config.num_test <= benchmark.num_test_puzzles()
    ), "num_test is larger than num test available in benchmark"
    # Same subsampling scheme as above, using the test key.
    perm = jax.random.permutation(test_rng, benchmark.num_test_puzzles())
    idxs = perm[: config.num_test]
    benchmark = benchmark.replace(test_puzzles=benchmark.test_puzzles[idxs])

In [4]:
def pytree_megabytes(tree):
    """Return the combined size of the array leaves of ``tree`` in megabytes.

    Only leaves that are JAX or NumPy arrays are counted, each as its element
    count times the dtype's item size; any other leaf contributes nothing.
    The result is in decimal megabytes (bytes / 1e6).
    """
    array_types = (jnp.ndarray, np.ndarray)
    total_bytes = 0
    for leaf in jtu.tree_leaves(tree):
        if isinstance(leaf, array_types):
            total_bytes += leaf.size * leaf.dtype.itemsize
    return total_bytes / 1e6

### PushWorld puzzles


In [5]:
# Size in MB (1e6 bytes) of all array leaves held by the benchmark object.
pytree_megabytes(benchmark)

14.2208

### Transitions


In [6]:
# COLLECT TRAJECTORIES
def _env_step(runner_state, _):
    """Advance every environment by one step; body of the rollout ``jax.lax.scan``.

    ``runner_state`` is the scan carry ``(rng, train_state, prev_timestep,
    prev_action, prev_reward, prev_hstate)``. The second argument is the
    per-iteration scan input, which is unused. Returns the updated carry
    together with the ``Transition`` recorded for this step.

    Reads ``env`` and ``puzzle_env_params`` from the enclosing (notebook)
    scope.
    """
    key, agent, last_timestep, last_action, last_reward, last_hstate = runner_state

    # SELECT ACTION
    key, sample_key = jax.random.split(key)
    # inputs get a seq_len=1 axis: [batch_size, seq_len=1, ...]
    policy_inputs = {
        "obs": last_timestep.observation[:, None],
        "prev_action": last_action[:, None],
        "prev_reward": last_reward[:, None],
    }
    dist, value, next_hstate = agent.apply_fn(agent.params, policy_inputs, last_hstate)
    action, log_prob = dist.sample_and_log_prob(seed=sample_key)
    # drop the seq_len axis again
    action = action.squeeze(1)
    value = value.squeeze(1)
    log_prob = log_prob.squeeze(1)

    # STEP ENV
    # env params are shared; timesteps and actions are mapped over the batch
    batched_step = jax.vmap(env.step, in_axes=(None, 0, 0))
    next_timestep = batched_step(puzzle_env_params, last_timestep, action)

    # The transition stores the inputs the policy saw (previous observation,
    # action and reward) next to the outcome of the step.
    transition = Transition(
        done=next_timestep.last(),
        action=action,
        value=value,
        reward=next_timestep.reward,
        log_prob=log_prob,
        obs=last_timestep.observation,
        prev_action=last_action,
        prev_reward=last_reward,
    )
    next_runner_state = (key, agent, next_timestep, action, next_timestep.reward, next_hstate)
    return next_runner_state, transition


# Build the RNG key, environment, env params, benchmark, initial hidden state,
# network and network params from the config.
# NOTE(review): this rebinds `benchmark`, replacing the one prepared in the
# subsampling cell above -- confirm make_states applies the same selection.
rng, env, env_params, benchmark, init_hstate, network, network_params = make_states(config)


def linear_schedule(count):
    """Return the learning rate for step ``count``, scaled down from ``config.lr``.

    ``count`` is floor-divided by the number of optimizer steps per update
    (``num_minibatches * update_epochs``), so the rate only changes once per
    update and reaches zero after ``config.num_updates`` updates.
    """
    steps_per_update = config.num_minibatches * config.update_epochs
    updates_done = count // steps_per_update
    remaining = 1.0 - updates_done / config.num_updates
    return config.lr * remaining


# Optimizer: global-norm gradient clipping followed by Adam, with the learning
# rate supplied by linear_schedule through inject_hyperparams.
tx = optax.chain(
    optax.clip_by_global_norm(config.max_grad_norm),
    optax.inject_hyperparams(optax.adam)(learning_rate=linear_schedule, eps=1e-8),  # eps=1e-5
)
train_state = TrainState.create(apply_fn=network.apply, params=network_params, tx=tx)

# Reset one environment per reset key; the env params are shared across the batch.
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, config.num_envs_per_device)
puzzle_env_params = env_params.replace(benchmark=benchmark)
timestep = jax.vmap(env.reset, in_axes=(None, 0))(puzzle_env_params, reset_rng)
# Zero "previous" action/reward to feed the network on the first step.
prev_action = jnp.zeros(config.num_envs_per_device, dtype=jnp.int32)
prev_reward = jnp.zeros(config.num_envs_per_device)


# Carry for _env_step; the scan runs it config.num_steps times and stacks the
# per-step Transition outputs along a new leading axis.
runner_state = (rng, train_state, timestep, prev_action, prev_reward, init_hstate)
# transitions: [seq_len, batch_size, ...]
runner_state, transitions = jax.lax.scan(_env_step, runner_state, None, config.num_steps)

In [7]:
# Size in MB of the stacked rollout (all Transition fields for every step and env).
# With 4096 envs x 100 steps, the recorded 1320.1408 MB works out to 3223 bytes
# per env step -- assuming the output comes from this configuration.
pytree_megabytes(transitions)

1320.1408

In [8]:
# Unpack the carry left by the rollout scan.
rng, train_state, timestep, prev_action, prev_reward, hstate = runner_state
# calculate value of the last step for bootstrapping
# (inputs get a seq_len=1 axis via [:, None]; it is squeezed off last_val below)
_, last_val, _ = train_state.apply_fn(
    train_state.params,
    {
        "obs": timestep.observation[:, None],
        "prev_action": prev_action[:, None],
        "prev_reward": prev_reward[:, None],
    },
    hstate,
)
# Advantages and value targets from the rollout and the bootstrap value.
advantages, targets = calculate_gae(transitions, last_val.squeeze(1), config.gamma, config.gae_lambda)

In [9]:
# Size in MB of the advantages array. The recorded 0.8192 MB for
# 4096 envs x 100 steps corresponds to 2 bytes per element.
pytree_megabytes(advantages)

0.8192

In [10]:
# Size in MB of the value-target array (same recorded size as the advantages).
pytree_megabytes(targets)

0.8192

In [11]:
# Give the initial hidden state a leading length-1 axis so it can sit in the
# same [seq_len, batch_size, ...] batch as the rollout data.
# The rank check keeps this cell idempotent: an unconditional
# `init_hstate = init_hstate[None, :]` rebinds a notebook global, so every
# re-run of the cell would stack one more axis. A 3-D state is the
# un-expanded form (it is 4-D after expansion and the swap below, see the
# shape notes at the end of the cell).
if init_hstate.ndim == 3:
    init_hstate = init_hstate[None, :]

# Random ordering of the environments, used later to shuffle the batch.
rng, _rng = jax.random.split(rng)
permutation = jax.random.permutation(_rng, config.num_envs_per_device)
# [seq_len, batch_size, ...]
batch = (init_hstate, transitions, advantages, targets)
# [batch_size, seq_len, ...], as our model assumes
batch = jtu.tree_map(lambda x: x.swapaxes(0, 1), batch)

# init_hstate: num_envs, 1, 1, 1024
# advantages: num_envs, 100

In [12]:
# Size in MB of the full batch: initial hidden state, transitions, advantages
# and targets together.
pytree_megabytes(batch)

1330.167808

In [20]:
# Reorder every leaf along axis 0 (the env axis after the swap) with the same
# permutation, so all entries of the batch stay aligned.
shuffled_batch = jtu.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)

In [21]:
# Size in MB of the shuffled copy of the batch.
# NOTE(review): the recorded value differs from the one recorded for `batch`;
# the execution counts jump between these cells, so the outputs may come from
# different kernel states -- re-run in order to confirm.
pytree_megabytes(shuffled_batch)

1341.014016

In [22]:
# Split the leading axis of every leaf into [num_minibatches, minibatch_size];
# the remaining axes are kept as they are.
minibatches = jtu.tree_map(lambda x: jnp.reshape(x, (config.num_minibatches, -1) + x.shape[1:]), shuffled_batch)

In [23]:
# Size in MB of the minibatched view; the reshape keeps the element count, so
# this matches the size recorded for the shuffled batch.
pytree_megabytes(minibatches)

1341.014016

In [11]:
# Shuffle-and-split pipeline in a single cell: build the batch, move the env
# axis to the front, permute the envs, then split into minibatches.
# NOTE(review): unlike the earlier cell that builds `batch`, this one does not
# add the leading axis to init_hstate; it presumably relies on that cell having
# already run in this kernel -- confirm before running it on its own.
rng, _rng = jax.random.split(rng)
permutation = jax.random.permutation(_rng, config.num_envs_per_device)
# [seq_len, batch_size, ...]
batch = (init_hstate, transitions, advantages, targets)
# [batch_size, seq_len, ...], as our model assumes
batch = jtu.tree_map(lambda x: x.swapaxes(0, 1), batch)

shuffled_batch = jtu.tree_map(lambda x: jnp.take(x, permutation, axis=0), batch)
# [num_minibatches, minibatch_size, ...]
minibatches = jtu.tree_map(lambda x: jnp.reshape(x, (config.num_minibatches, -1) + x.shape[1:]), shuffled_batch)

In [16]:
# Evaluation rollouts over the training puzzles.
eval_reset_rng = jax.random.key(config.eval_seed)
# eval_test_rng is not used in this cell; only the train half is consumed here.
eval_test_rng, eval_train_rng = jax.random.split(eval_reset_rng)

# One reset key per training puzzle.
eval_train_reset_rng = jax.random.split(eval_train_rng, num=config.num_train)
eval_train_puzzles = benchmark.get_train_puzzles()
# vmap over the reset keys (arg 0) and the puzzles (arg 3); every other
# argument is shared across the mapped calls.
eval_train_stats = jax.vmap(rollout, in_axes=(0, None, None, 0, None, None, None))(
    eval_train_reset_rng,
    env,
    puzzle_env_params,
    eval_train_puzzles,
    train_state,
    # TODO: make this as a static method mb?
    # zero hidden state for a single env: (1, rnn_num_layers, rnn_hidden_dim)
    jnp.zeros((1, config.rnn_num_layers, config.rnn_hidden_dim), dtype=jnp.bfloat16 if config.enable_bf16 else None),
    # NOTE(review): presumably the number of episodes per puzzle -- confirm against rollout's signature
    1,
)

In [17]:
# Size in MB of the evaluation statistics returned for the training puzzles.
pytree_megabytes(eval_train_stats)

0.256