# `gymnax`: Classic Gym Environments in JAX
### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/gymnax/blob/main/examples/getting_started.ipynb)

Welcome to `gymnax`, the one-stop shop for fast classic Reinforcement Learning environments powered by JAX.

## Basic API: `gymnax.make()`, `env.reset()`, `env.step()`

In [1]:
import os

# Set the number of (emulated) host devices
#num_devices = 8
#os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={num_devices}"

import jax
import jax.numpy as jnp
import gymnax
import networkx as nx
import json
import dataclasses
from xlron.environments.env_funcs import *
from xlron.environments.rsa import *
from xlron.environments.vone import *

# Number of devices JAX can see; the pmap cells further down size their inputs with it.
num_devices = jax.device_count()

# `jax.default_device` is a context-manager factory. Assigning a device to that
# attribute only clobbers it and does not change where arrays are placed, so
# select the default device through the config option instead.
jax.config.update("jax_default_device", jax.devices()[0])

jax.device_count(), jax.devices()

(1, [CpuDevice(id=0)])

In [2]:
# Fixed seed, then one dedicated key per purpose (init / reset / policy / step).
rng = jax.random.PRNGKey(0)
rng, key_init, key_reset, key_policy, key_step = jax.random.split(rng, 5)

# Alternative ways of constructing an environment:
#env = RSAEnv(key_init, graph, env_params)
#env = VONEEnv(key_init, graph, env_params)
#env, env_params = make_rsa_env()

# NOTE(review): the recorded output of this cell is a TypeError saying
# `topology_name` is an unexpected keyword - confirm the keyword names against
# the installed `make_vone_env` signature.
vone_settings = dict(
    topology_name="nsfnet",
    max_node_resources=2,
    min_slots=2,
    max_slots=4,
    load=60,
)
env, env_params = make_vone_env(**vone_settings)

# Inspect default environment settings
env_params
#dataclasses.fields(env_params)

TypeError: make_vone_env() got an unexpected keyword argument 'topology_name'

In [None]:
# Reset the environment with its dedicated key; this yields the first
# observation together with the environment state.
obs, state = env.reset(key_reset, env_params)
#obs, state = env.reset(key_reset, env_params.path_link_array)
obs, state

In [None]:
# Sample a random action from the action space and apply it for one transition.
action_space = env.action_space(env_params)
action = action_space.sample(key_policy)
print(action)
step_result = env.step(key_step, state, action, env_params)
# The fifth element of the step result is not needed here.
n_obs, n_state, reward, done, _ = step_result
n_obs, n_state, reward, done

`gymnax` provides fully functional environment dynamics that can leverage the full power of JAX's function transformations. For example, one common RL use case is the parallel rollout of multiple workers. Using a `vmap` across random seeds (one per worker) allows us to implement such a parallelization on a single machine:

In [None]:
# One PRNG key per parallel worker.
num_envs = 10
vmap_keys = jax.random.split(rng, num_envs)

# Batch over the leading axis of keys (and of state/action for `step`);
# `env_params` is shared by all workers (in_axes=None).
vmap_reset = jax.vmap(env.reset, in_axes=(0, None))
vmap_step = jax.vmap(env.step, in_axes=(0, 0, 0, None))

obs, state = vmap_reset(vmap_keys, env_params)
if isinstance(env, VONEEnv):
    # VONEEnv is stepped with a 3-tuple action, each component batched over workers.
    zero_action = (jnp.zeros(num_envs),) * 3
    n_obs, n_state, reward, done, _ = vmap_step(vmap_keys, state, zero_action, env_params)
elif isinstance(env, RSAEnv):
    # RSAEnv is stepped with a single action per worker.
    zero_action = jnp.zeros(num_envs)
    n_obs, n_state, reward, done, _ = vmap_step(vmap_keys, state, zero_action, env_params)
print(n_obs.shape)

Similarly, you can also choose to `pmap` over rollout workers ("actors") across multiple devices:

In [None]:
# One worker per device: keys (and state/action) are split across devices,
# while `env_params` is passed as a static, broadcast argument.
pmap_reset = jax.pmap(env.reset, in_axes=(0, None), static_broadcasted_argnums=1)
pmap_step = jax.pmap(env.step, in_axes=(0, 0, 0, None), static_broadcasted_argnums=3)

pmap_keys = jax.random.split(rng, num_devices)
obs, state = pmap_reset(pmap_keys, env_params)
if isinstance(env, VONEEnv):
    # 3-tuple action, one entry per device in each component.
    zero_action = (jnp.zeros(num_devices),) * 3
    n_obs, n_state, reward, done, _ = pmap_step(pmap_keys, state, zero_action, env_params)
elif isinstance(env, RSAEnv):
    zero_action = jnp.zeros(num_devices)
    n_obs, n_state, reward, done, _ = pmap_step(pmap_keys, state, zero_action, env_params)
print(n_obs.shape)

The code above has executed each worker-specific environment transition on a separate device, but we can also chain `vmap` and `pmap` to execute multiple workers on a single device and at the same time across multiple devices:

In [None]:
# pmap over devices on the outside, vmap over workers on the inside.
map_reset = jax.pmap(vmap_reset, in_axes=(0, None), static_broadcasted_argnums=1)
map_step = jax.pmap(vmap_step, in_axes=(0, 0, 0, None), static_broadcasted_argnums=3)

# Tiling repeats the same per-worker keys on every device.
map_keys = jnp.tile(vmap_keys, (num_devices, 1, 1))
obs, state = map_reset(map_keys, env_params)

batch_shape = (num_devices, num_envs)
if isinstance(env, VONEEnv):
    zero_action = (jnp.zeros(batch_shape),) * 3
    n_obs, n_state, reward, done, _ = map_step(map_keys, state, zero_action, env_params)
elif isinstance(env, RSAEnv):
    zero_action = jnp.zeros(batch_shape)
    n_obs, n_state, reward, done, _ = map_step(map_keys, state, zero_action, env_params)
print(n_obs.shape)

We can now easily leverage massive accelerator parallelism to churn through millions/billions of environment transitions when training 'sentient' agents. Note that in the code snippet above every device executes the same `num_envs` environment workers, since we tiled/repeated the same keys across the device axis. In general, `pmap`-ing will require you to pay special attention to the shapes of the arrays that come out of your operations.

## Jitted Episode Rollouts via `lax.scan`

Let's now walk through an example of using `gymnax` with one of the common neural network libraries to parametrize a simple policy: `flax`. 

In [None]:
from flax import linen as nn


class MLP(nn.Module):
    """Fully connected network: ReLU hidden layers followed by a linear output layer."""

    num_hidden_units: int
    num_hidden_layers: int
    num_output_units: int

    @nn.compact
    def __call__(self, x, rng):
        # `rng` is accepted but never used in this forward pass.
        for _ in range(self.num_hidden_layers):
            x = nn.relu(nn.Dense(features=self.num_hidden_units)(x))
        return nn.Dense(features=self.num_output_units)(x)
    

# Policy network: 64 hidden units, 2 hidden layers, 3 output units.
model = MLP(64, 2, 3)
# Initialise the parameters by tracing the network on a dummy flat input.
# NOTE(review): presumably 2128 is the flattened observation length of the
# environment built above (and 319904 that of another configuration) - confirm.
policy_params = model.init(rng, jnp.zeros(2128), None)
#policy_params = model.init(rng, jnp.zeros(319904), None)
# obs = env.get_obs(state)
# model_action = model.apply(policy_params, obs, key_step)
# sample_action = env.action_space(env_params).sample(key_policy)
# model_action, sample_action

In [None]:
def rollout(rng_input, policy_params, env_params, steps_in_episode):
    """Run the policy for a fixed number of steps with `jax.lax.scan`.

    Uses the notebook-level `env` and `model`. The environment is reset once,
    then stepped `steps_in_episode` times.

    Returns:
        Tuple `(obs, action, reward, next_obs, done)` holding the per-step
        values stacked by the scan.
    """
    key_reset, key_episode = jax.random.split(rng_input)
    first_obs, first_state = env.reset(key_reset, env_params)

    def policy_step(carry, _):
        """One scan iteration: query the policy, then step the environment."""
        obs_t, state_t, params, key = carry
        key, key_env, key_net = jax.random.split(key, 3)
        # Squeeze works for RSAEnv!
        action_t = jnp.squeeze(model.apply(params, obs_t, key_net))
        obs_next, state_next, reward_t, done_t, _info = env.step(
            key_env, state_t, action_t, env_params
        )
        next_carry = [obs_next, state_next, params, key]
        return next_carry, [obs_t, action_t, reward_t, obs_next, done_t]

    # No per-step inputs: the scan runs purely for `steps_in_episode` iterations.
    init_carry = [first_obs, first_state, policy_params, key_episode]
    _, trajectory = jax.lax.scan(policy_step, init_carry, (), steps_in_episode)

    obs, action, reward, next_obs, done = trajectory
    return obs, action, reward, next_obs, done

In [None]:
# Jit-Compiled Episode Rollout
# `env_params` (arg 2) and the number of steps (arg 3) are marked static for jit.
jit_rollout = jax.jit(rollout, static_argnums=(2, 3,))
# Roll out 10000 steps and report the trajectory shapes and the total reward.
obs, action, reward, next_obs, done = jit_rollout(rng, policy_params, env_params, 10000)
obs.shape, reward.shape, jnp.sum(reward)

Again, you can wrap this `rollout` function with the magic of JAX, and this works for all implemented RL environments. But we also provide a simple wrapper that does so for you:

In [None]:
import jax
import jax.numpy as jnp
import gymnax
from functools import partial
from typing import Optional
from gymnax.environments import environment
import timeit
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# TODO: Add RNN forward with init_carry/hidden
# TODO: Add pmap utilities if multi-device
# TODO: Use as backend in `GymFitness` or keep separated?


class RolloutWrapper(object):
    """Jitted single, batched, population and multi-device rollouts for one env.

    Every rollout resets the environment and then steps it `num_env_steps`
    times inside a `jax.lax.scan`.
    """

    def __init__(
        self,
        model_forward=None,
        env: environment.Environment = None,
        num_env_steps: Optional[int] = None,
        env_params: EnvParams = None,
    ):
        """Store the environment, its parameters and the policy forward function.

        Args:
            model_forward: Callable invoked as
                `model_forward(policy_params, obs, rng)`. When `None`, actions
                are sampled from the environment's action space instead.
            env: Environment providing `reset`, `step` and `action_space`.
            num_env_steps: Number of scan steps per rollout. Defaults to
                `env_params.max_requests` when `None`.
            env_params: Parameters passed to every `reset`/`step` call.
        """
        self.env = env
        # Define the RL environment & network forward function
        self.env_params = env_params
        self.model_forward = model_forward

        if num_env_steps is None:
            self.num_env_steps = self.env_params.max_requests
        else:
            self.num_env_steps = num_env_steps

    @partial(jax.jit, static_argnums=(0,))
    def population_rollout(self, rng_eval, policy_params):
        """Run `batch_rollout` for every member of a stacked parameter set.

        The same batch of keys `rng_eval` is shared by all members, while
        `policy_params` is mapped over its leading axis.
        """
        # vmap over params only: same rng, different params
        pop_rollout = jax.vmap(self.batch_rollout, in_axes=(None, 0))
        return pop_rollout(rng_eval, policy_params)

    @partial(jax.jit, static_argnums=(0,))
    def batch_rollout(self, rng_eval, policy_params):
        """Run `single_rollout` once per key in `rng_eval` with shared params."""
        # vmap over the keys only; the parameters are broadcast
        batch_rollout = jax.vmap(self.single_rollout, in_axes=(0, None))
        return batch_rollout(rng_eval, policy_params)

    def pmap_batch_rollout(self, rng_eval, policy_params, device_count, batch_size):
        """Run `batch_size` rollouts on each of `device_count` devices.

        Args:
            rng_eval: `device_count * batch_size` keys, one per rollout.
            policy_params: Parameters of a single network (shared by all rollouts).
            device_count: Size of the pmapped (device) axis.
            batch_size: Size of the vmapped (per-device env) axis.
        """
        # Broadcast the params to each device and env (identical for each)
        broadcast = lambda x: jnp.broadcast_to(x, (device_count, batch_size) + x.shape)
        # `jax.tree_map` is deprecated and removed in recent JAX releases;
        # `jax.tree_util.tree_map` is the stable spelling.
        params = jax.tree_util.tree_map(broadcast, policy_params)
        # Reshape the rngs so that each env on each device has a unique rng
        reshape = lambda x: x.reshape((device_count, batch_size) + x.shape[1:])
        pmap_rngs = reshape(jnp.stack(rng_eval))  # add dimension to pmap over.
        # in_axes is left at its default (0), so both arguments are split along
        # their leading axis rather than broadcast (which would be in_axes=None).
        # That is why the params were explicitly broadcast above.
        batch_rollout = jax.vmap(self.single_rollout, axis_name="envs")
        batch_rollout = jax.pmap(batch_rollout, axis_name="devices")
        return batch_rollout(pmap_rngs, params)

    @partial(jax.jit, static_argnums=(0,))
    def single_rollout(self, rng_input, policy_params):
        """Roll out `num_env_steps` environment steps with `jax.lax.scan`.

        Returns:
            `(obs, action, reward, next_obs, done, cum_return)`. The first five
            are the per-step values stacked by the scan. `cum_return` is the
            sum of rewards up to and including the first step whose `done`
            flag is set (later rewards are masked out).
        """
        # Reset the environment
        rng_reset, rng_episode = jax.random.split(rng_input)
        obs, state = self.env.reset(rng_reset, self.env_params)

        def policy_step(state_input, tmp):
            """lax.scan compatible step transition in jax env."""
            obs, state, policy_params, rng, cum_reward, valid_mask = state_input
            rng, rng_step, rng_net = jax.random.split(rng, 3)
            if self.model_forward is not None:
                action = jnp.squeeze(self.model_forward(policy_params, obs, rng_net))
            else:
                # No policy given: fall back to random actions
                action = self.env.action_space(self.env_params).sample(rng_net)
            next_obs, next_state, reward, done, _ = self.env.step(
                rng_step, state, action, self.env_params
            )
            # Only count rewards while the mask is still 1; the mask drops to
            # 0 after the first `done` and stays there.
            new_cum_reward = cum_reward + reward * valid_mask
            new_valid_mask = valid_mask * (1 - done)
            carry = [
                next_obs,
                next_state,
                policy_params,
                rng,
                new_cum_reward,
                new_valid_mask,
            ]
            y = [obs, action, reward, next_obs, done]
            return carry, y

        # Scan over episode step loop
        carry_out, scan_out = jax.lax.scan(
            policy_step,
            [
                obs,
                state,
                policy_params,
                rng_episode,
                jnp.array([0.0]),  # cumulative reward
                jnp.array([1.0]),  # valid mask
            ],
            (),
            self.num_env_steps,
        )
        # Return the sum of rewards accumulated by agent in episode rollout
        obs, action, reward, next_obs, done = scan_out
        cum_return = carry_out[-2]
        return obs, action, reward, next_obs, done, cum_return

    @property
    def input_shape(self):
        """Shape of the observation returned by a reset with a fixed key."""
        rng = jax.random.PRNGKey(0)
        obs, state = self.env.reset(rng, self.env_params)
        return obs.shape

class TimeIt():
    """Context manager that prints how long its body took to run.

    The message starts with `tag`. When `frames` is truthy, a frames-per-second
    figure (`frames / elapsed seconds`) is appended. The elapsed time stays
    available as `elapsed_secs` after the block exits.
    """

    def __init__(self, tag, frames=None):
        self.tag, self.frames = tag, frames

    def __enter__(self):
        self.start = timeit.default_timer()
        return self

    def __exit__(self, *args):
        self.elapsed_secs = timeit.default_timer() - self.start
        pieces = [': Elapsed time=%.2fs' % self.elapsed_secs]
        if self.frames:
            pieces.append(', FPS=%.2e' % (self.frames / self.elapsed_secs))
        print(self.tag + ''.join(pieces))


In [None]:
# Define rollout manager for env
# Rollout manager bound to the current env and policy network
manager = RolloutWrapper(model.apply, env=env, env_params=env_params)

# First call: includes jit compilation
with TimeIt(tag='COMPILATION'):
    manager.single_rollout(rng, policy_params)

# Second call: runs the already-compiled function
episode_frames = env_params.max_requests
with TimeIt(tag='EXECUTION', frames=episode_frames):
    rollout_out = manager.single_rollout(rng, policy_params)
    obs, action, reward, next_obs, done, cum_ret = rollout_out
    # Wait for the asynchronous computation so the timing is meaningful
    reward.block_until_ready()


In [None]:
# Shapes of everything returned by the single rollout above.
obs.shape, action.shape, reward.shape, next_obs.shape, done.shape, cum_ret.shape

In [None]:
# What about pmapped rollouts?
# Create a random key for every env on every device
# One random key for every env on every device
num_envs_per_device = 10
total_envs = num_envs_per_device * num_devices
rng_envs = jax.random.split(rng, total_envs)

# First call: includes compilation
with TimeIt(tag='COMPILATION'):
    manager.pmap_batch_rollout(rng_envs, policy_params, num_devices, num_envs_per_device)

# Second call: timed execution of the compiled rollout
with TimeIt(tag='EXECUTION', frames=env_params.max_requests * total_envs):
    pmap_out = manager.pmap_batch_rollout(
        rng_envs, policy_params, num_devices, num_envs_per_device
    )
    obs, action, reward, next_obs, done, cum_ret = pmap_out
    reward.block_until_ready()
# Observations of the first env on the first device
obs[0][0]

In [None]:
#print(jax.lax.pmean(jax.lax.pmean(cum_ret, axis_name="envs"), axis_name="devices"))

In [None]:
# Multiple rollouts for same network (different rng, e.g. eval)
# Several rollouts of the same network, each with its own key (e.g. for evaluation)
num_envs = 10
rng_batch = jax.random.split(rng, num_envs)

# First call: includes compilation
with TimeIt(tag='COMPILATION'):
    manager.batch_rollout(rng_batch, policy_params)

# Second call: timed execution
batch_frames = env_params.max_requests * num_envs
with TimeIt(tag='EXECUTION', frames=batch_frames):
    batch_out = manager.batch_rollout(rng_batch, policy_params)
    obs, action, reward, next_obs, done, cum_ret = batch_out
    reward.block_until_ready()
obs.shape, action.shape, reward.shape, next_obs.shape, done.shape, cum_ret.shape

In [None]:
# Multiple rollouts for different networks + rng (e.g. for ES)
# Number of networks in the (demo) population: identical copies of one network.
num_networks = 2
# Stack each parameter leaf along a new leading "population" axis.
# `jax.tree_map` is deprecated and removed in recent JAX releases;
# `jax.tree_util.tree_map` is the stable spelling.
batch_params = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * num_networks), policy_params
)

with TimeIt(tag='COMPILATION'):
    manager.population_rollout(rng_batch, batch_params)  # compiles

# Every network is rolled out once per key in `rng_batch`, so the frame count
# must include the population size for the reported FPS to be right.
population_frames = env_params.max_requests * num_envs * num_networks
with TimeIt(tag='EXECUTION', frames=population_frames):
    obs, action, reward, next_obs, done, cum_ret = manager.population_rollout(
        rng_batch, batch_params
    )
    reward.block_until_ready()
obs.shape, action.shape, reward.shape, next_obs.shape, done.shape, cum_ret.shape

## Visualizing Episode Rollouts

In [None]:
from gymnax.visualize import Visualizer

# Collect the states and rewards of one random-action episode (at most 50 steps).
state_seq, reward_seq = [], []
rng, rng_reset = jax.random.split(rng)
obs, env_state = env.reset(rng_reset, env_params)
t_counter = 0
while True:
    state_seq.append(env_state)
    rng, rng_act, rng_step = jax.random.split(rng, 3)
    action = env.action_space(env_params).sample(rng_act)
    transition = env.step(rng_step, env_state, action, env_params)
    next_obs, next_env_state, reward, done, info = transition
    reward_seq.append(reward)
    t_counter += 1
    # Stop at episode end or once the step budget is used up
    if done or t_counter >= 50:
        break
    obs, env_state = next_obs, next_env_state

# Animate the visited states alongside the running return
cum_rewards = jnp.cumsum(jnp.array(reward_seq))
vis = Visualizer(env, env_params, state_seq, cum_rewards)
vis.animate("anim.gif")

In [None]:
from IPython.display import Image
# Display the animation written to 'anim.gif' by the cell above.
Image(url='anim.gif')

Benchmarking against CPU VoneEnv

In [None]:
import gymnasium as gym
from heuristics import *
import pandas as pd
import numpy as np
import os
from pathlib import Path
import argparse
import env.envs
import yaml
from datetime import datetime
from util_funcs import make_env, parse_args, choose_schedule
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

# `SubprocVecEnv` is used below but was never imported in this notebook,
# which made the cell fail with a NameError.
from stable_baselines3.common.vec_env import SubprocVecEnv

num_envs = 50   # number of parallel CPU environments
num_steps = 25  # vectorised steps to time

args = parse_args()
conf = yaml.safe_load(Path("./config/agent_conus.yaml").read_text())

# One entry per sub-process, each with its own seed.
# NOTE(review): presumably `make_env` returns a callable, as SubprocVecEnv
# expects a list of env constructors - confirm against util_funcs.
env = [
    make_env(conf["env_name"], seed=i, **conf["env_args"])
    for i in range(num_envs)
]
env = SubprocVecEnv(env, start_method="fork")

agent_kwargs = dict(
    verbose=0,
    device="cuda",
    gamma=args.gamma,
    learning_rate=choose_schedule(args.lr_schedule, args.learning_rate),
    gae_lambda=args.gae_lambda,
    n_steps=args.n_steps,
    batch_size=10,
    clip_range=choose_schedule(args.clip_range_schedule, args.clip_range),
    clip_range_vf=choose_schedule(args.clip_range_vf_schedule, args.clip_range_vf),
    n_epochs=args.n_epochs,
    ent_coef=args.ent_coef,
    policy_kwargs={"net_arch": dict(pi=[64, 64], vf=[64, 64])},
)
if args.multistep_masking:
    agent_kwargs.update(
        multistep_masking=args.multistep_masking,
        multistep_masking_attr=args.multistep_masking_attr,
        multistep_masking_n_steps=args.multistep_masking_n_steps,
        action_interpreter=args.action_interpreter,
    )
agent_args = ("MultiInputPolicy", env)

model = PPO(*agent_args, **agent_kwargs)

print('start')
with TimeIt(tag='EXECUTION', frames=num_envs*num_steps):
    for _ in range(num_steps):
        # Sample one action and repeat it for every sub-environment. The extra
        # `sample()` call whose result was immediately overwritten has been
        # dropped so it no longer inflates the timed loop.
        action = np.tile(env.action_space.sample(), (num_envs, 1))
        env.step(action)
