<a href="https://colab.research.google.com/github/google/brax/blob/main/notebooks/pytorch.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"></a>

# Using Brax with PyTorch

Brax integrates with [PyTorch](https://pytorch.org/) via a [Gym interface](https://github.com/openai/gym) without sacrificing performance. In this notebook we demonstrate how to drive a batched Brax environment with PyTorch tensors and benchmark the resulting step rate.

Author: @lebrice (Fabrice Normandin)

In [1]:
#@title Set up Colab dependencies
#@markdown ### ⚠️ PLEASE NOTE:
#@markdown Brax and PyTorch can share GPUs but not TPUs.  To access a GPU runtime, select Runtime > Change Runtime Type and set 'Hardware Accelerator' to `GPU`.
#@markdown
#@markdown Using TPU is OK too, but then PyTorch should run on CPU.

import os

try:
  import brax
except ImportError:
  from IPython.display import clear_output 
  !pip install git+https://github.com/google/brax.git@main
  clear_output()

if 'COLAB_TPU_ADDR' in os.environ:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

In [2]:
import time
import torch
from functools import partial

import gym
import tqdm
import numpy as np

from brax.envs.to_torch import JaxToTorchWrapper
from brax.envs import _envs, create_gym_env

# Registering the Brax envs in Gym (so we can use `gym.make` as usual):
for env_name, env_class in _envs.items():
    env_id = f"brax_{env_name}-v0"
    entry_point = partial(create_gym_env, env_name=env_name)
    if env_id not in gym.envs.registry.env_specs:
        print(f"Registering brax's '{env_name}' env under id '{env_id}'.")
        gym.register(env_id, entry_point=entry_point)

# Simple utility function for benchmarking below.
def tick(name: str = ""):
    global _times
    _times.append(time.time())
    elapsed = _times[-1] - _times[-2]
    if name:
        print(f"Time for {name}: {elapsed}", flush=True)
    return elapsed

CUDA = torch.cuda.is_available()
if CUDA:
    # BUG: (@lebrice): Getting a weird "CUDA error: out of memory" RuntimeError
    # during JIT, which can be "fixed" by first creating a dummy cuda tensor!
    v = torch.ones(1, device="cuda")


Registering brax's 'fetch' env under id 'brax_fetch-v0'.
Registering brax's 'ant' env under id 'brax_ant-v0'.
Registering brax's 'grasp' env under id 'brax_grasp-v0'.
Registering brax's 'halfcheetah' env under id 'brax_halfcheetah-v0'.
Registering brax's 'humanoid' env under id 'brax_humanoid-v0'.
Registering brax's 'ur5e' env under id 'brax_ur5e-v0'.
Registering brax's 'reacher' env under id 'brax_reacher-v0'.
Registering brax's 'reacherangle' env under id 'brax_reacherangle-v0'.
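
The registration loop above binds `env_name` with `functools.partial` rather than a bare `lambda`. A likely reason is Python's late-binding closures: a lambda created in a loop looks up the loop variable only when it is called. The sketch below illustrates the difference with a toy registry. `make_env` is a hypothetical stand-in for `create_gym_env`, and no Gym or Brax code is involved.

```python
from functools import partial


def make_env(env_name: str) -> str:
    # Hypothetical stand-in for `create_gym_env`: returns a label, not an env.
    return f"env:{env_name}"


names = ["ant", "fetch", "reacher"]

# `partial` stores the current value of `name` at each iteration.
with_partial = {}
for name in names:
    with_partial[name] = partial(make_env, env_name=name)

# A bare lambda reads `name` only when called, after the loop has finished,
# so every entry ends up seeing the last value of the loop variable.
with_lambda = {}
for name in names:
    with_lambda[name] = lambda: make_env(env_name=name)

partial_results = [with_partial[n]() for n in names]
lambda_results = [with_lambda[n]() for n in names]

print(partial_results)  # one distinct label per registered name
print(lambda_results)   # the same label repeated three times
```

With `partial`, each entry point keeps its own environment name, which is what a per-environment registry needs.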


## Benchmarking the PyTorch-wrapped Brax Gym environment
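
The benchmark in this section times each call with the `tick` helper, which appends to a global `_times` list that has to exist before the first call. As an illustration only, here is a minimal sketch of the same idea without global state. The `Stopwatch` class and its method names are hypothetical and are not part of Brax or of this notebook, and it uses `time.perf_counter` as an assumption about what suits short intervals.

```python
import time


class Stopwatch:
    """Records timestamps and reports the time between consecutive ticks."""

    def __init__(self) -> None:
        self.times = [time.perf_counter()]

    def tick(self, name: str = "") -> float:
        # Append a new timestamp and return the time since the previous one.
        self.times.append(time.perf_counter())
        elapsed = self.times[-1] - self.times[-2]
        if name:
            print(f"Time for {name}: {elapsed:.4f} s", flush=True)
        return elapsed

    def reset(self) -> None:
        # Drop earlier timestamps, e.g. to exclude warm-up steps.
        self.times = [time.perf_counter()]

    def intervals(self) -> list:
        # Elapsed time between each pair of consecutive timestamps.
        return [b - a for a, b in zip(self.times, self.times[1:])]


watch = Stopwatch()
time.sleep(0.01)
first = watch.tick("sleeping 10 ms")

watch.reset()
for _ in range(3):
    watch.tick()
print(len(watch.intervals()))  # number of timed intervals since the reset
```

Calling `reset()` plays the role of the `_times.clear()` followed by `_times.append(time.time())` used below to discard the slow first calls.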

In [4]:
_times = [time.time()]

# Number of parallel environments
batch_size = 2048  #@param { type:"slider", min:0, max:4096, step: 1 }

# Number of steps to take in the batched env.
n_steps = 1000

env = gym.make("brax_halfcheetah-v0", batch_size=batch_size)
tick("creating the env")

env = JaxToTorchWrapper(env)
tick("wrapping the env")

obs = env.reset()  # this can be relatively slow (~10 secs)
tick("first reset") 

obs, reward, done, info = env.step(env.action_space.sample())
tick("first step")  # this can be relatively slow (~10 secs)

obs, reward, done, info = env.step(env.action_space.sample())
tick("second step")  # this can be relatively slow (~10 secs)

_times.clear()
_times.append(time.time())

# create a progress bar that displays step rate
pbar = tqdm.tqdm(range(n_steps), unit_scale=batch_size)
for i in pbar:
    # for GPU, we create the pytorch data directly on device in order to avoid
    # expensive cross-device copying
    if CUDA:
        action = torch.rand(env.action_space.shape, device="cuda") * 2 - 1
    else:
        action = env.action_space.sample()
    obs, rewards, done, info = env.step(action)
    tick()


elapsed_times = [_times[i+1] - _times[i] for i in range(len(_times)-1)]

time_per_step_avg = np.mean(elapsed_times)
time_per_step_std = np.std(elapsed_times)

frequency = 1 / time_per_step_avg
effective_frequency = (batch_size or 1) * frequency
if isinstance(obs, torch.Tensor):
    device = obs.device
else:
    device = obs.device_buffer.device()

print(f"\nDevice used: {device}")
print(f"Number of parallel environments: {batch_size}")
print(f"Average time per batched step:  {time_per_step_avg} ± {time_per_step_std} seconds")
print(f"Frequency (after first two steps): ~{frequency:.3f} batched steps / second.")
print(f"Effective Frequency (after first two steps): ~{effective_frequency:.3f} steps / second.")

Time for creating the env: 7.492677688598633
Time for wrapping the env: 0.0012292861938476562
Time for first reset: 8.5595543384552
Time for first step: 4.626008033752441
Time for second step: 4.190889835357666


100%|██████████| 2048000/2048000 [00:05<00:00, 364426.28it/s]


Device used: cuda:0
Number of parallel environments: 2048
Average time per batched step:  0.005624113082885743 ± 0.0010372177569917425 seconds
Frequency (after first two steps): ~177.806 batched steps / second.
Effective Frequency (after first two steps): ~364146.305 steps / second.
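
The two frequency figures follow directly from the average step time. The short check below recomputes them from the values printed above. Only the variable names are taken from the benchmark cell, and the numbers are copied from this particular run, so they will differ on other hardware.

```python
# Values copied from the printed output of the run above.
batch_size = 2048
n_steps = 1000
time_per_step_avg = 0.005624113082885743  # seconds per batched step

# One batched step advances all `batch_size` environments at once.
frequency = 1 / time_per_step_avg
effective_frequency = batch_size * frequency

# Total wall time spent in the timed loop.
total_time = n_steps * time_per_step_avg

print(f"{frequency:.3f} batched steps / second")
print(f"{effective_frequency:.3f} steps / second")
print(f"{total_time:.2f} seconds for {n_steps} batched steps")
```

The total of roughly 5.6 seconds is consistent with the `00:05` shown by the progress bar, whose own rate estimate differs slightly because it is derived from its total elapsed time rather than from the mean of the per-step intervals.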



