# Create Environments with Braxlines Composer

[Braxlines Composer](https://github.com/google/brax/blob/main/brax/experimental/composer) allows modular composition of Brax environments. Let's try it out!

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/composer.ipynb)

In [None]:
#@title Colab setup and imports
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'TPU'** in the dropdown.
from datetime import datetime
import functools
import os
import pprint
import jax
import jax.numpy as jnp
from IPython.display import HTML, clear_output
import matplotlib.pyplot as plt

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax.io import html
from brax.experimental.composer import composer
from brax.experimental.composer.training import mappo
from brax.experimental.braxlines import experiments
from brax.experimental.braxlines.common import evaluators
from brax.experimental.braxlines.common import logger_utils
from brax.experimental.braxlines.training import ppo
from brax.experimental.braxlines import experiments

# Only attempt the Colab TPU setup when the runtime advertises a TPU address.
if os.environ.get("COLAB_TPU_ADDR") is not None:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

def show_env(env, mode):
  """Prints information about `env`, or renders its reset state.

  Args:
    env: environment to inspect. It must provide `reset`, `step`,
      `action_size`, `sys` and `unwrapped`, depending on the chosen mode.
    mode: 'print_obs' pretty-prints the observation dict shapes,
      'print_sys' pretty-prints the composer config JSON, 'print_step' runs
      one reset and one zero-action step and prints observation shapes and
      rewards. Any other value (e.g. 'viewer') renders the reset state.

  Returns:
    An `HTML` object for the viewer mode, otherwise None.
  """
  if mode == 'print_obs':
    obs_shapes = composer.get_env_obs_dict_shape(env.unwrapped)
    pprint.pprint(obs_shapes)
    return None

  if mode == 'print_sys':
    pprint.pprint(env.unwrapped.composer.metadata.config_json)
    return None

  # Both remaining modes start from a jitted reset with a fixed PRNG seed.
  reset_fn = jax.jit(env.reset)

  if mode == 'print_step':
    step_fn = jax.jit(env.step)
    initial = reset_fn(rng=jax.random.PRNGKey(seed=0))
    zero_action = jnp.zeros((env.action_size,))
    after_step = step_fn(initial, zero_action)
    print(f'obs0={initial.obs.shape}')
    print(f'obs1={after_step.obs.shape}')
    print(f'rew0={initial.reward}')
    print(f'rew1={after_step.reward}')
    print(f'action0={(env.action_size,)}')
    return None

  # Fallback (viewer): render the single reset state as an HTML widget.
  initial = reset_fn(rng=jax.random.PRNGKey(seed=0))
  clear_output(wait=True)
  return HTML(html.render(env.sys, [initial.qp]))

In [None]:
# @title List registered environments
#@markdown See [composer/envs](https://github.com/google/brax/blob/main/brax/experimental/composer/envs)
#@markdown for registered `env_name`'s.
# Query the composer registry and print its size plus the first five entries.
env_list = composer.list_env()
print(f'{len(env_list)} registered envs, e.g. {env_list[:5]}...')

In [None]:
#@title Create a custom env
env_name = 'sumo' # @param ['squidgame', 'sumo', 'follow', 'chase', 'pro_ant_run', 'ant_run', 'ant_chase', 'ant_push']
env_params = None # @param{'type': 'raw'}
mode = 'viewer'# @param ['print_step', 'print_obs', 'print_sys', 'viewer']
output_path = '' # @param {type: 'string'}
if output_path:
  # Nest outputs under <output_path>/<YYYYMMDD>/<env_name>.
  output_path = f'{output_path}/{datetime.now().strftime("%Y%m%d")}'
  output_path = f'{output_path}/{env_name}'
  print(f'Saving outputs to {output_path}')

# Validate the user-supplied params against what the env declares. The check
# is skipped when `inspect_env` reports that arbitrary kwargs are supported.
env_params = env_params or {}
supported_params, support_kwargs = composer.inspect_env(env_name=env_name)
if not support_kwargs:
  assert all(
      name in supported_params for name in env_params
  ), f'invalid {env_params} for {supported_params}'

# Build the env factory, instantiate one env, and inspect it in the chosen mode.
env_fn = composer.create_fn(env_name=env_name, **env_params)
env = env_fn()
show_env(env, mode)

In [None]:
#@title Training the custom env
num_timesteps_multiplier =   3# @param {type: 'number'}
seed = 0 # @param{type: 'integer'}
skip_training = False # @param {type: 'boolean'}

# `log_path` stays empty when no `output_path` was set in the env-creation
# cell; otherwise it points at a CSV file inside that directory.
log_path = output_path
if log_path:
  log_path = f'{log_path}/training_curves.csv'
tab = logger_utils.Tabulator(output_path=log_path,
    append=False)

# Multi-agent envs train with MAPPO, single-agent envs with PPO. The PPO
# hyperparameters are the 'ant' defaults regardless of `env_name`.
ppo_lib = mappo if env.is_multiagent else ppo
ppo_params = experiments.defaults.get_ppo_params(
    'ant', num_timesteps_multiplier)
train_fn = functools.partial(ppo_lib.train, **ppo_params)

# `times` is handed to the progress callback; the timing printout below
# assumes the callback appends a timestamp on every call -- TODO confirm.
times = [datetime.now()]
plotpatterns = ['eval/episode_reward', 'eval/episode_score']

progress, _, _, _ = experiments.get_progress_fn(
    plotpatterns, times, tab=tab, max_ncols=5,
    xlim=[0, train_fn.keywords['num_timesteps']],
    pre_plot_fn = lambda : clear_output(wait=True),
    post_plot_fn = plt.show)

if skip_training:
  # Build freshly initialised params and a jitted inference function so the
  # visualization cell can still run without training.
  action_size = (env.group_action_shapes if 
    env.is_multiagent else env.action_size)
  params, inference_fn = ppo_lib.make_params_and_inference_fn(
    env.observation_size, action_size,
    normalize_observations=True)
  inference_fn = jax.jit(inference_fn)
else:
  inference_fn, params, _ = train_fn(
    environment_fn=env_fn, seed=seed,
    extra_step_kwargs=False, progress_fn=progress)
  print(f'time to jit: {times[1] - times[0]}')
  print(f'time to train: {times[-1] - times[1]}')
  # Only claim that logs were saved when a log path was actually configured;
  # with the default empty `output_path` there is no file to point at.
  if log_path:
    print(f'Saved logs to {log_path}')



In [None]:
#@title Visualizing a trajectory of the learned inference function
eval_seed = 0  # @param {'type': 'integer'}
batch_size =  0# @param {type: 'integer'}

# Run the evaluator with the current params; it returns the env it used and
# the sequence of states that are rendered below.
env, states = evaluators.visualize_env(
    env_fn=env_fn,
    inference_fn=inference_fn,
    params=params,
    batch_size=batch_size,
    seed=eval_seed,
    output_path=output_path,
    verbose=True,
)
system = env.sys
trajectory_qps = [step_state.qp for step_state in states]
HTML(html.render(system, trajectory_qps))

In [None]:
#@title Plot information of the trajectory
# The first and last states of the rollout are left out of the plots.
interior_states = states[1:-1]
experiments.plot_states(interior_states, max_ncols=5)
plt.show()