[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/gym_pytorch.ipynb)

# Using Brax with Gym + PyTorch

Author: @lebrice (Fabrice Normandin)

In [1]:
# Imports:
import os
import time
import torch
from functools import partial

import gym
import tqdm
import numpy as np

from brax.envs.to_torch import JaxToTorchWrapper
from brax.envs import _envs, create_gym_env

# Detect if running in Colab or not.
COLAB = "google.colab" in str(get_ipython())

# Configure JAX to run on the Colab TPU, but only when this Colab runtime actually has
# one: `colab_tpu.setup_tpu()` reads the `COLAB_TPU_ADDR` environment variable, which is
# only set on TPU runtimes (on CPU/GPU runtimes the call would fail).
if COLAB and "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu
    colab_tpu.setup_tpu()

# Register every Brax env with Gym under the id `brax_<name>-v0`, so that it can be
# created with `gym.make(...)` as usual. Ids that are already registered are skipped,
# which keeps this cell safe to re-run.
for env_name in _envs:
    env_id = f"brax_{env_name}-v0"
    entry_point = partial(create_gym_env, env_name=env_name)
    if env_id not in gym.envs.registry.env_specs:
        print(f"Registering brax's '{env_name}' env under id '{env_id}'.")
        gym.register(env_id, entry_point=entry_point)

# Simple utility function for benchmarking below.
def tick(name: str = "", times=None) -> float:
    """Record the current time and return the seconds elapsed since the previous record.

    Args:
        name: If non-empty, the elapsed time is also printed as "Time for {name}: ...".
        times: Optional list of `time.time()` timestamps to append to. Defaults to the
            notebook-level `_times` list, which must exist (e.g. `_times = [time.time()]`)
            before the first call that relies on the default.

    Returns:
        The seconds between the newly appended timestamp and the previous one, or 0.0
        if `times` was empty (in that case the call only starts the clock).
    """
    if times is None:
        # Mutated in place, so no `global` declaration is needed.
        times = _times
    times.append(time.time())
    elapsed = times[-1] - times[-2] if len(times) >= 2 else 0.0
    if name:
        print(f"Time for {name}: {elapsed}")
    return elapsed

# True when PyTorch can see a CUDA device. The benchmark cell uses this flag to decide
# whether random actions are created directly on the GPU.
CUDA = bool(torch.cuda.is_available())

if CUDA:
    # Workaround (@lebrice): JIT-compiling the env was observed to fail with a
    # "CUDA error: out of memory" RuntimeError unless PyTorch had already allocated a
    # CUDA tensor, so a one-element dummy tensor is allocated up front.
    # NOTE(review): presumably this works because it makes PyTorch set up its CUDA
    # context before JAX preallocates most of the GPU memory — TODO confirm.
    v = torch.ones((1,), device=torch.device("cuda"))


Registring brax's 'fetch' env under id 'brax_fetch-v0'.
Registring brax's 'ant' env under id 'brax_ant-v0'.
Registring brax's 'grasp' env under id 'brax_grasp-v0'.
Registring brax's 'halfcheetah' env under id 'brax_halfcheetah-v0'.
Registring brax's 'humanoid' env under id 'brax_humanoid-v0'.
Registring brax's 'ur5e' env under id 'brax_ur5e-v0'.
Registring brax's 'reacher' env under id 'brax_reacher-v0'.
Registring brax's 'reacherangle' env under id 'brax_reacherangle-v0'.


## Benchmarking the performance of the PyTorch-wrapped Brax Gym env

In [2]:
_times = [time.time()]

# Number of parallel environments. The slider starts at 1 because a batched env needs a
# positive batch size.
batch_size = 1024  #@param { type:"slider", min:1, max:4096, step: 1 }

# Number of steps to take in the batched env.
n_steps = 1000

env = gym.make("brax_halfcheetah-v0", batch_size=batch_size)
tick("Creating the env")

env = JaxToTorchWrapper(env)
tick("Wrapping the env")

# The first reset and the first couple of steps can each take several seconds
# (presumably JAX compiling them on first use), so they are timed individually here and
# excluded from the averages computed below.
obs = env.reset()
tick("First reset")

obs, reward, done, info = env.step(env.action_space.sample())
tick("First step")

obs, reward, done, info = env.step(env.action_space.sample())
tick("Second step")

# Restart the clock so that only the steps in the loop below are measured.
_times.clear()
_times.append(time.time())

# Create a progress bar. With `unit_scale=batch_size`, the rate it displays is the
# effective number of (single-env) steps per second.
pbar = tqdm.tqdm(range(n_steps), unit_scale=batch_size)
for i in pbar:
    # NOTE: Could use the `action_space` property like so even with CUDA, but we'd need
    # to move the numpy array from CPU to GPU/TPU, which would kind-of defeat the
    # purpose of this demo! Here we instead create a Tensor which is already on the GPU.
    if CUDA:
        action = torch.rand(env.action_space.shape, device="cuda") * 2 - 1
    else:
        action = env.action_space.sample()
    obs, rewards, done, info = env.step(action)
    # NOTE(review): JAX and CUDA dispatch work asynchronously; if `env.step` returns
    # before its computation has finished, these wall-clock times may undercount the
    # true step time — TODO confirm.
    tick()

# Duration (in seconds) of each batched step taken in the loop above.
elapsed_times = np.diff(_times)

time_per_step_avg = np.mean(elapsed_times)
time_per_step_std = np.std(elapsed_times)

frequency = 1 / time_per_step_avg
effective_frequency = (batch_size or 1) * frequency

print(f"Device used: {obs.device}")
print(f"Number of parallel environments: {batch_size}")
print(f"Average time per batched step:  {time_per_step_avg} ± {time_per_step_std} seconds")
print(f"Frequency (after first two steps): ~{frequency:.3f} batched steps / second.")
print(f"Effective Frequency (after first two steps): ~{effective_frequency:.3f} steps / second.")

Time for Creating the env: 10.294265508651733
Time for Wrapping the env: 0.00015664100646972656
Time for First reset: 3.978239059448242
Time for First step: 2.3907294273376465
Time for Second step: 2.2790026664733887


100%|██████████| 1024000/1024000 [00:03<00:00, 296106.44it/s]

Device used: cuda:0
Number of parallel environments: 1024
Average time per batched step:  0.0034599435329437257 ± 0.00036434224775974217 seconds
Frequency (after first two steps): ~289.022 batched steps / second.
Effective Frequency (after first two steps): ~295958.587 steps / second.



