# Sweep Training

We can perform a hyperparameter sweep directly on Colab.



```
# This is formatted as code
```

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/experiment_sweep.ipynb)

In [None]:
#@title Colab setup and imports
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'TPU'** in the dropdown.

#@markdown See [config_utils.py](https://github.com/google/brax/blob/main/brax/experimental/braxlines/common/config_utils.py)
#@markdown for the configuration format.
from datetime import datetime
import importlib
import os
import pprint

from IPython.display import HTML, clear_output

# Import brax; if it is not importable, install it from the `main` branch of
# the GitHub repository and import again. The `!pip` line is an IPython shell
# escape, so this fallback only works when run inside a notebook.
try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  # Clear the cell output produced so far (i.e. the pip log).
  clear_output()
  import brax

# Choose the training module to sweep over. 'composer' has its own module
# path; the other agents are sub-packages of brax.experimental.braxlines.
agent = 'vgcrl' # @param ['vgcrl', 'irl_smm', 'composer']
if agent == 'composer':
  agent_module = 'brax.experimental.composer.train'
else:
  agent_module = f'brax.experimental.braxlines.{agent}.train'
output_path = '' #@param{'type': 'string'}
start_count = 0 # @param{'type': 'integer'}
end_count = 100000000 # @param{'type': 'integer'}
experiment_path = '' #@param{'type': 'string'}
# Default the experiment sub-directory to the current timestamp when the form
# field is left empty.
experiment_path = experiment_path or datetime.now().strftime('%Y%m%d_%H%M%S')
# NOTE(review): with an empty `output_path` this evaluates to
# '/<experiment_path>', i.e. a directory at the filesystem root -- confirm
# that this is the intended default.
output_path = f'{output_path}/{experiment_path}'

# This import sits below the brax install fallback earlier in the cell.
from brax.experimental.braxlines.common import config_utils
# Resolve the selected agent's training module by its dotted path.
train = importlib.import_module(agent_module)

# PPO settings that are identical for every agent in this notebook. Only
# `num_timesteps` differs between agents; each branch below puts it first and
# then merges these shared values into a fresh dict.
common_ppo = {
    'reward_scaling': 10,
    'episode_length': 1000,
    'normalize_observations': True,
    'action_repeat': 1,
    'unroll_length': 5,
    'num_minibatches': 32,
    'num_update_epochs': 4,
    'discounting': 0.95,
    'learning_rate': 3e-4,
    'entropy_cost': 1e-2,
    'num_envs': 2048,
    'batch_size': 1024,
}

# `config` is a list of sweep specifications for the selected agent.
# NOTE(review): presumably the list-valued entries are the ones that
# config_utils expands into separate runs -- confirm against config_utils.py.
if agent == 'composer':
  config = [{
      'env_name': ['ant_push'],
      'desc_edits': {
          'components.cap1.reward_fns.goal.scale': [0.2, 1, 0.5],
          'components.cap1.reward_fns.goal.target_goal': 5,
      },
      'ppo_params': {'num_timesteps': int(2.5 * 1e8), **common_ppo},
  }]
elif agent == 'vgcrl':
  config = [{
      'env_name': ['ant'],
      'obs_indices': 'vel',
      'algo_name': ['gcrl', 'diayn', 'cdiayn', 'diayn_full'],
      'obs_scale': [5.0],
      'seed': [0],
      'normalize_obs_for_disc': False,
      'evaluate_mi': True,
      'evaluate_lgr': False,
      'env_reward_multiplier': 0.0,
      'spectral_norm': [True],
      'ppo_params': {'num_timesteps': int(2.5 * 1e8), **common_ppo},
  }]
elif agent == 'irl_smm':
  config = [{
      'env_name': ['ant'],
      'obs_indices': 'vel',
      'target_num_modes': [2],
      'obs_scale': [8],
      'reward_type': ['gail2', 'mle', 'airl'],
      'seed': [0],
      'normalize_obs_for_disc': False,
      'evaluate_dist': False,
      'env_reward_multiplier': 0.0,
      'spectral_norm': [True],
      'ppo_params': {'num_timesteps': int(1.5 * 1e8), **common_ppo},
  }]


# Attach to every sweep specification the key list that config_utils reports
# for it; the launch cell later pops `prefix_keys` to build the experiment name.
prefix_keys = config_utils.list_keys_to_expand(config)
for spec, keys in zip(config, prefix_keys):
  spec['prefix_keys'] = keys

# Clamp the requested index window to [0, total number of configurations].
config_count = config_utils.count_configuration(config)
if start_count < 0:
  start_count = 0
total_configs = sum(config_count)
if total_configs < end_count:
  end_count = total_configs

# Report what was loaded, one line per item.
summary_lines = (
    f'Loaded agent_module={agent_module}',
    f'Loaded {sum(config_count)}({config_count}) experiment configurations',
    f'Set start_count={start_count}, end_count={end_count}',
    f'Set prefix_keys={prefix_keys}',
    f'Set output_dir={output_path}',
)
for line in summary_lines:
  print(line)

# Only set up the TPU backend when the Colab TPU address variable is present.
if os.environ.get("COLAB_TPU_ADDR") is not None:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

In [None]:
#@title Launch experiments
ignore_errors = False # @param{'type': 'boolean'}

# Run the configurations with index in [start_count, end_count) one after
# another. When `ignore_errors` is set, a failing run is reported and the sweep
# continues with the next index; otherwise the exception propagates and stops
# the sweep.
total_count = sum(config_count)
return_dict = {}
for i in range(start_count, end_count):
  c, _ = config_utils.index_configuration(config, index=i, count=config_count)
  # The task name is computed while `prefix_keys` is still part of `c`.
  task_name = config_utils.get_compressed_name_from_keys(
      c, train.TASK_KEYS)
  # `prefix_keys` is popped here, so it is no longer in `c` when `c` is
  # printed and handed to train.train below.
  experiment_name = config_utils.get_compressed_name_from_keys(
      c, c.pop('prefix_keys'))
  output_dir = f'{output_path}/{task_name}/{experiment_name}'
  print(f'[{i+1}/{total_count}] Starting experiment...')
  print(f'\t config: {pprint.pformat(c, indent=2)}')
  print(f'\t output_dir={output_dir}')
  # `return_dict` still holds whatever the previous run stored in it; each
  # label now reads the matching key (they were swapped before).
  print(f'\t previous time_to_jit={return_dict.get("time_to_jit", None)}')
  print(f'\t previous time_to_train={return_dict.get("time_to_train", None)}')
  # Give the upcoming run a fresh dict to fill in.
  return_dict = {}
  try:
    train.train(c, output_dir=output_dir, return_dict=return_dict)
  except Exception as e:
    if not ignore_errors:
      raise
    # Python 3 exceptions have no `.message` attribute; format the exception
    # itself so that reporting a failure cannot raise AttributeError.
    print(f'[{i+1}/{total_count}] FAILED experiment {e.__class__.__name__}: {e}')