[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/to_torch/notebooks/pytorch.ipynb)

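# Using Brax with Gym + PyTorch

This notebook registers the Brax environments with Gym and wraps them so they return PyTorch tensors. It then benchmarks the step throughput of a batched environment.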

In [5]:
# Imports:
import time
import torch
from functools import partial
import time

import gym
import tqdm
import numpy as np

from brax.envs.to_torch import JaxToTorchWrapper
from brax.envs import _envs, create_gym_env

# Detect whether this notebook is running in Colab. `get_ipython` is only
# defined under IPython, so fall back to False when run as a plain script.
try:
    COLAB = "google.colab" in str(get_ipython())
except NameError:
    COLAB = False

if COLAB:
    from jax.tools import colab_tpu
    # Configure JAX to run on the Colab TPU.
    colab_tpu.setup_tpu()

# Register the Brax envs in Gym (so we can use `gym.make` as usual).
# Older gym releases keep registered specs in `registry.env_specs`; newer ones
# make `registry` itself a mapping keyed by env id, so fall back to it.
_registered_specs = getattr(gym.envs.registry, "env_specs", gym.envs.registry)
for env_name, env_class in _envs.items():
    env_id = f"brax_{env_name}-v0"
    # `create_gym_env` is bound to this env name so Gym can call it lazily.
    entry_point = partial(create_gym_env, env_name=env_name)
    if env_id not in _registered_specs:
        print(f"Registering brax's '{env_name}' env under id '{env_id}'.")
        gym.register(env_id, entry_point=entry_point)

## Benchmarking the performance of the PyTorch-wrapped Brax Gym env

In [6]:
# Number of parallel environments
batch_size = 1024  #@param { type:"slider", min:0, max:4096, step: 1 }

# Number of steps to take in the batched env.
n_steps = 1000

# Whether a CUDA device is available; decides where benchmark actions are created.
CUDA = torch.cuda.is_available()

# Timestamps recorded by `tick`; consecutive differences are the measured durations.
_times = [time.time()]


def _reset_clock() -> None:
    """Discard all recorded timestamps and start timing from now."""
    _times.clear()
    _times.append(time.time())


def tick(name: str = "") -> float:
    """Record the current time and return the seconds since the previous timestamp.

    Args:
        name: Optional label. When non-empty, the elapsed time is printed as
            "Time for {name}: {elapsed}".

    Returns:
        Seconds elapsed between the previous entry of `_times` and now.
    """
    _times.append(time.time())
    elapsed = _times[-1] - _times[-2]
    if name:
        print(f"Time for {name}: {elapsed}")
    return elapsed


if CUDA:
    # BUG: (@lebrice): Getting a weird "CUDA error: out of memory" RuntimeError during
    # JIT, which can be "fixed" by first creating a dummy cuda tensor!
    v = torch.ones(1, device="cuda")

# Start timing here so the first measurement covers only the env creation.
_reset_clock()

env = gym.make("brax_halfcheetah-v0", batch_size=batch_size)
tick("Creating the env")

env = JaxToTorchWrapper(env)
tick("Wrapping the env")

obs = env.reset()  # this can be relatively slow (~10 secs)
tick("First reset")

obs, reward, done, info = env.step(env.action_space.sample())
tick("First step")  # this can be relatively slow (~10 secs)

obs, reward, done, info = env.step(env.action_space.sample())
tick("Second step")  # this can be relatively slow (~10 secs)

# Exclude the warm-up above from the steady-state statistics below.
_reset_clock()

# Create a Progress bar. NOTE: This Pbar effectively shows the step rate as well!
pbar = tqdm.tqdm(range(n_steps), unit_scale=batch_size)
for i in pbar:
    # NOTE: Could use the `action_space` property like so even with CUDA, but we'd need
    # to move the numpy array from CPU to GPU/TPU, which would kind-of defeat the
    # purpose of this demo! Here we instead create a Tensor which is already on the GPU.
    if not CUDA:
        action = env.action_space.sample()
    else:
        # Uniform samples in [-1, 1); assumes the action space is bounded by
        # [-1, 1] — TODO confirm for envs other than halfcheetah.
        action = torch.rand(env.action_space.shape, device="cuda") * 2 - 1
    obs, rewards, done, info = env.step(action)
    tick()

# Duration of each batched step: differences between consecutive timestamps.
elapsed_times = [later - earlier for earlier, later in zip(_times, _times[1:])]

time_for_n_steps_avg = np.mean(elapsed_times)
time_for_n_steps_std = np.std(elapsed_times)

frequency = 1 / time_for_n_steps_avg
# Each batched step advances `batch_size` environments at once.
effective_frequency = (batch_size or 1) * frequency

print(f"Device used: {obs.device}")
print(f"Number of parallel environments: {batch_size}")
print(f"Average time per batched step:  {time_for_n_steps_avg} ± {time_for_n_steps_std} seconds")
print(f"Frequency (after first two steps): ~{frequency:.3f} batched steps / second.")
print(f"Effective Frequency (after first two steps): ~{effective_frequency:.3f} steps / second.")

Time for Creating the env: 0.0038437843322753906
Time for Wrapping the env: 0.0038437843322753906
Time for First reset: 0.0038437843322753906
Time for First step: 0.0038437843322753906
Time for Second step: 0.0038437843322753906


100%|██████████| 1024000/1024000 [00:03<00:00, 289150.23it/s]

Device used: cuda:0
Number of parallel environments: 1024
Average time per batched step:  0.0039951517013366335 ± 0.00026153933497828803 seconds
Frequency: ~250.303 batched steps / second.
Effective Frequency: ~256310.668 steps / second.



