# Create Environments with Braxlines Composer

[Braxlines Composer](https://github.com/google/brax/blob/main/brax/experimental/composer) allows modular composition of Brax environments. Let's try it out! 




[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/composer.ipynb)

In [None]:
#@title Colab setup and imports
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'TPU'** in the dropdown.
import os
from datetime import datetime
import functools
import pprint
import jax
import jax.numpy as jnp
from IPython.display import HTML, clear_output
import matplotlib.pyplot as plt

# Install Brax only when it is not already importable. The install source is
# the `features/amorpheus` branch of the frt03/brax fork, not upstream
# google/brax.
# NOTE(review): this only probes `import brax`. If some other Brax build is
# already installed, the fork is NOT installed, and later imports such as
# `create_modular_fn` / `make_transformers` may then fail if they exist only in
# the fork — confirm.
try:
  import brax
except ImportError:
  # `!` runs a shell command (IPython/Colab syntax, not plain Python).
  !pip install git+https://github.com/frt03/brax.git@features/amorpheus
  # Hide the pip log from the cell output.
  clear_output()
  import brax

from brax.io import html
from brax.experimental.composer import composer
from brax.experimental.composer import component_editor
from brax.experimental.composer import register_default_components
from brax.experimental.braxlines.common import evaluators
from brax.experimental.braxlines.common import logger_utils
from brax.experimental.braxlines.training import ppo
from brax.experimental.braxlines.training.utils import create_modular_fn
from brax.training.networks import make_transformers


# Select the Colab TPU backend before any further library call: JAX picks its
# backend when it is first used, so TPU setup should happen as early as
# possible (here, ahead of component registration, which may touch JAX).
if "COLAB_TPU_ADDR" in os.environ:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

# Register composer's built-in components; presumably this is what makes names
# such as 'mod_ant' (used in env_descs below) resolvable — confirm.
register_default_components()

def show_env(env, mode, seed=0):
  """Inspect a composed environment in one of three ways.

  Args:
    env: Environment instance created by an env_fn from composer.
    mode: One of
      'print_obs' - pretty-print composer.get_env_obs_dict_shape(env);
      'print_sys' - pretty-print env.unwrapped.composer.metadata.config_json;
      'viewer'    - reset the env under jax.jit and render the reset state
                    as an HTML viewer.
    seed: Integer seed for the reset PRNG key in 'viewer' mode.

  Returns:
    An IPython HTML object for 'viewer' mode, otherwise None.

  Raises:
    ValueError: If `mode` is not one of the values above.
  """
  if mode == 'print_obs':
    pprint.pprint(composer.get_env_obs_dict_shape(env))
    return None
  if mode == 'print_sys':
    pprint.pprint(env.unwrapped.composer.metadata.config_json)
    return None
  if mode != 'viewer':
    # Previously any unknown mode silently fell through to the viewer.
    raise ValueError(
        f"Unknown mode {mode!r}; expected 'print_obs', 'print_sys' or 'viewer'.")
  jit_env_reset = jax.jit(env.reset)
  state = jit_env_reset(rng=jax.random.PRNGKey(seed=seed))
  clear_output(wait=True)
  return HTML(html.render(env.sys, [state.qp]))

In [None]:
#@title Create a custom env
#@markdown See [env_descs.py](https://github.com/google/brax/blob/main/brax/experimental/composer/env_descs.py)
#@markdown for more supported `env_name`.
env_name = 'mod_ant_run'  # @param ['mod_ant_run', 'ant_run', 'ant_chase', 'ant_push']
mode = 'print_obs'  # @param ['print_obs', 'print_sys', 'viewer']
output_path = ''  # @param {type: 'string'}
# A non-empty output_path is expanded to <output_path>/<YYYYMMDD>/<env_name>;
# left empty, it stays '' and no directory name is built.
if output_path:
  output_path = '/'.join(
      (output_path, datetime.now().strftime("%Y%m%d"), env_name))
  print(f'Saving outputs to {output_path}')

# Custom environment descriptions keyed by `env_name`. Names found here are
# built with `create_modular_fn(env_desc=...)`; any other name is passed to
# `composer.create_fn(env_name=...)`.
env_descs = {
    'mod_ant_run': {
        'components': {
            # A single 'mod_ant' component placed at the origin.
            'ant1': {
                'component': 'mod_ant',
                'pos': (0, 0, 0),
                # Reward: 'root_goal' on the 'vel' state component, indices
                # (0, 1), target (4, 0), offset 5.
                # NOTE(review): presumably x/y root velocity toward a target of
                # 4 along x; the meaning of `offset` is not visible here.
                'reward_fns': {
                    'goal': {
                        'reward_type': 'root_goal',
                        'sdcomp': 'vel',
                        'indices': (0, 1),
                        'offset': 5,
                        'target_goal': (4, 0),
                    },
                },
                # Score: the same goal spec as the reward, without the offset.
                'score_fns': {
                    'goal': {
                        'reward_type': 'root_goal',
                        'sdcomp': 'vel',
                        'indices': (0, 1),
                        'target_goal': (4, 0),
                    },
                },
            },
        },
    },
}


# Use a locally defined description when one exists; otherwise let composer
# build the environment by name (see env_descs.py linked above).
if env_name not in env_descs:
  env_fn = composer.create_fn(env_name=env_name)
else:
  env_desc = env_descs[env_name]
  env_fn = create_modular_fn(env_desc=env_desc)

# Instantiate once and inspect it according to `mode`. As the cell's last
# expression, the HTML returned in 'viewer' mode renders inline.
env = env_fn()
show_env(env, mode)

In [None]:
#@title Training the custom env
num_timesteps_multiplier = 3  # @param {type: 'number'}
skip_training = False  # @param {type: 'boolean'}

# Training curves go next to the other outputs; with an empty output_path the
# log path stays '' as well.
log_path = f'{output_path}/training_curves.csv' if output_path else output_path
tab = logger_utils.Tabulator(output_path=log_path, append=False)

# Hyperparameters were determined offline. `n` scales the step budget
# (num_timesteps = 50M * n).
n = num_timesteps_multiplier
train_fn = functools.partial(
    ppo.train,
    # Budget and logging.
    num_timesteps=int(50_000_000 * n),
    log_frequency=20,
    # Environment / rollout settings.
    num_envs=2048,
    episode_length=1000,
    action_repeat=1,
    unroll_length=5,
    # Observation and reward preprocessing.
    normalize_observations=True,
    reward_scaling=10,
    # Optimization.
    batch_size=1024,
    num_minibatches=32,
    num_update_epochs=4,
    discounting=0.95,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    # Model architecture.
    make_models_fn=make_transformers,
    # NOTE(review): semantics of extra_step_kwargs not visible here — confirm.
    extra_step_kwargs=False,
)

# Wall-clock timestamps: times[0] is taken here, and every progress() call
# appends one more, so times[1] marks the first callback.
times = [datetime.now()]
# Metric history: {metric_name: {'x': [num_steps, ...], 'y': [value, ...]}}.
plotdata = {}
# Only metrics whose name contains one of these substrings are plotted.
plotpatterns = ['eval/episode_reward', 'eval/episode_score']


def progress(num_steps, metrics, params):
  """Training callback: records metrics, logs them, and redraws the curves.

  Args:
    num_steps: Environment steps taken so far (used as the plot x value).
    metrics: Mapping from metric name to a scalar value for this report.
    params: Current training parameters (unused; required by the callback
      signature).

  Raises:
    AssertionError: If any metric value is NaN.
  """
  times.append(datetime.now())
  plotkeys = []
  for key, value in metrics.items():
    assert not jnp.isnan(value), f'{key} {num_steps} NaN'
    history = plotdata.setdefault(key, dict(x=[], y=[]))
    history['x'].append(num_steps)
    history['y'].append(value)
    if any(pattern in key for pattern in plotpatterns):
      plotkeys.append(key)
  # Reports at num_steps == 0 are plotted but not written to the tabulator.
  if num_steps > 0:
    tab.add(num_steps=num_steps, **metrics)
    tab.dump()
  clear_output(wait=True)
  # One panel per plotted metric. squeeze=False keeps `axs` 2-D even for a
  # single column, so no placeholder panel is needed to make it indexable.
  num_figs = max(len(plotkeys), 1)
  fig, axs = plt.subplots(
      ncols=num_figs, figsize=(3.5 * num_figs, 3), squeeze=False)
  for ax, key in zip(axs[0], plotkeys):
    ax.plot(plotdata[key]['x'], plotdata[key]['y'])
    ax.set(xlabel='# environment steps', ylabel=key)
    ax.set_xlim([0, train_fn.keywords['num_timesteps']])
  fig.tight_layout()
  plt.show()
  # progress() runs repeatedly during training; release each figure.
  plt.close(fig)

if skip_training:
  # Build params and an inference fn without training, so the visualization
  # cell can still run.
  # NOTE(review): training uses make_models_fn=make_transformers, while this
  # call relies on make_params_and_inference_fn's defaults — confirm the
  # architectures are meant to differ.
  core_env = env_fn()
  params, inference_fn = ppo.make_params_and_inference_fn(
      core_env.observation_size,
      core_env.action_size,
      normalize_observations=True)
  inference_fn = jax.jit(inference_fn)
else:
  inference_fn, params, _ = train_fn(
      environment_fn=env_fn,
      progress_fn=progress)
  # times[0] precedes training and progress() appends one timestamp per
  # report. The first interval is reported as compile ("jit") time. Guard
  # against progress() never having been called.
  if len(times) > 1:
    print(f'time to jit: {times[1] - times[0]}')
    print(f'time to train: {times[-1] - times[1]}')
  # Only claim logs were saved when a log path was actually configured.
  if log_path:
    print(f'Saved logs to {log_path}')



In [None]:
#@title Visualizing a trajectory of the learned inference function
eval_seed = 0  # @param {'type': 'integer'}

# Roll out the policy (trained, or freshly built when training was skipped)
# and render the visited states in the Brax HTML viewer.
# NOTE(review): batch_size=0 presumably requests an unbatched rollout — confirm.
env, states = evaluators.visualize_env(
    env_fn=env_fn,
    inference_fn=inference_fn,
    params=params,
    batch_size=0,
    seed=eval_seed,
    output_path=output_path,
    verbose=True,
)
HTML(html.render(env.sys, [s.qp for s in states]))