In [1]:
import numpy as np
from brax import envs
from brax.envs.wrappers import gym as gym_wrapper
from time import perf_counter

batch_size = 512
env = envs.create("halfcheetah", batch_size=batch_size)
env = gym_wrapper.VectorGymWrapper(env)

action_low = -1
action_high = 1

t = perf_counter()
env.reset()
env.step(np.random.uniform(low=action_low, high=action_high, size=env.action_space.shape))
print(f"Time to jit: {perf_counter() - t}")

for i in range(20):
    t = perf_counter()
    action = np.random.uniform(low=action_low, high=action_high, size=env.action_space.shape)
    obs, reward, done, info = env.step(action)
    print(f"Time to step: {perf_counter() - t}")
    print(f"Time per step: {(perf_counter() - t) / batch_size}")
    print("")


2023-10-23 12:45:30.694314: W external/xla/xla/service/gpu/nvptx_compiler.cc:673] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.52). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


Time to jit: 11.63927029899969
Time to step: 0.023467693999918993
Time per step: 4.5879714843266584e-05

Time to step: 0.02212195299944142
Time per step: 4.324233007579892e-05

Time to step: 0.023632163000002038
Time per step: 4.619971289088198e-05

Time to step: 0.023825573998692562
Time per step: 4.657916796801942e-05

Time to step: 0.020513552000920754
Time per step: 4.010707421997495e-05

Time to step: 0.02183429199976672
Time per step: 4.274037499740757e-05

Time to step: 0.028414775999408448
Time per step: 5.555128124967723e-05

Time to step: 0.01963811200039345
Time per step: 3.841189843711845e-05

Time to step: 0.020617651000065962
Time per step: 4.033840039241454e-05

Time to step: 0.021142352999959257
Time per step: 4.134600195371263e-05

Time to step: 0.01887301099850447
Time per step: 3.689138867102315e-05

Time to step: 0.02134436199958145
Time per step: 4.1726390627161436e-05

Time to step: 0.036306120000517694
Time per step: 7.095253906541643e-05

Time to step: 0.0322568

In [2]:
import gymnasium as gym

batch_size = 512
env = gym.vector.AsyncVectorEnv([
    lambda: gym.make("HalfCheetah-v4") for _ in range(batch_size)])
observation, info = env.reset()

for _ in range(20):
    t = perf_counter()
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    print(f"Time to step: {perf_counter() - t}")
    print(f"Time per step: {(perf_counter() - t) / batch_size}")
    print("")



Time to step: 0.030944047000957653
Time per step: 6.071587500144915e-05

Time to step: 0.020289121001042076
Time per step: 3.96688105475107e-05

Time to step: 0.018549810998592875
Time per step: 3.6266583983035616e-05

Time to step: 0.01891779100151325
Time per step: 3.6986935548100064e-05

Time to step: 0.019148780998875736
Time per step: 3.743820507651208e-05

Time to step: 0.018943430000945227
Time per step: 3.703669921861774e-05

Time to step: 0.019184560998837696
Time per step: 3.75074648424345e-05

Time to step: 0.01872877000096196
Time per step: 3.6618789064135626e-05

Time to step: 0.01814410099905217
Time per step: 3.547437695061717e-05

Time to step: 0.01915184099925682
Time per step: 3.744972851649209e-05

Time to step: 0.023946753000927856
Time per step: 4.6817916018682126e-05

Time to step: 0.02065223200042965
Time per step: 4.044324609608907e-05

Time to step: 0.01930884100147523
Time per step: 3.7751857423273805e-05

Time to step: 0.02063681199979328
Time per step: 4.036