# Sweep Training

We can perform a hyperparameter sweep directly on Colab.




[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/experiment_sweep.ipynb)

In [None]:
#@title Colab setup and imports
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'TPU'** in the dropdown.

#@markdown See [config_utils.py](https://github.com/google/brax/blob/main/brax/experimental/braxlines/common/config_utils.py)
#@markdown for the configuration format.
#@markdown See [experiments/](https://github.com/google/brax/blob/main/brax/experimental/braxlines/experiments)
#@markdown for the example configurations.
from datetime import datetime
import importlib
import os
import pprint
from IPython.display import HTML, clear_output

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

experiment = 'custom'# @param ['custom', 'mimax_sweep', 'ant_push_sweep', 'dmin_sweep']
output_path = '/tmp' #@param{'type': 'string'}
start_count = 0 # @param{'type': 'integer'}
end_count = 100000000 # @param{'type': 'integer'}
experiment_path = '' #@param{'type': 'string'}
experiment_path=experiment_path or datetime.now().strftime('%Y%m%d_%H%M%S')
output_path = f'{output_path}/{experiment_path}'

custom_agent_module = f'brax.experimental.braxlines.vgcrl.train'
custom_config = [
    dict(
        env_name = ['ant'],
        obs_indices = 'vel',
        algo_name = ['gcrl', 'diayn', 'cdiayn', 'diayn_full'],
        obs_scale = [5.0],
        seed = [0],
        normalize_obs_for_disc = False,
        evaluate_mi = False,
        evaluate_lgr = False,
        env_reward_multiplier = 0.0,
        spectral_norm = [True],
        ppo_params = dict(
          num_timesteps=int(2.5 * 1e8),
          reward_scaling=10,
          episode_length=1000,
          normalize_observations=True,
          action_repeat=1,
          unroll_length=5,
          num_minibatches=32,
          num_update_epochs=4,
          discounting=0.95,
          learning_rate=3e-4,
          entropy_cost=1e-2,
          num_envs=2048,
          batch_size=1024,)
    ),
  ]

from brax.experimental.braxlines.common import config_utils
from brax.experimental.braxlines.experiments import load_experiment
from brax.experimental.braxlines.experiments import run_experiment
if experiment == 'custom':
  config = custom_config
  agent_module = custom_agent_module
else:
  agent_module, config = load_experiment(experiment)
agent_module = importlib.import_module(agent_module)

if "COLAB_TPU_ADDR" in os.environ:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

In [None]:
#@title Launch experiments
ignore_errors = False # @param{'type': 'boolean'}
# Launch the sweep for the agent module and config resolved in the setup cell.
# start_count / end_count presumably bound which sweep entries are run
# (see run_experiment); outputs go under output_path.
run_experiment(
  agent_module=agent_module,
  config=config,
  experiment_name=experiment,
  output_path=output_path,
  start_count=start_count,
  end_count=end_count,
  ignore_errors=ignore_errors,
)