In [3]:
import torch
import jax
import jax.numpy as jnp
import time
from env_jax import ParticleEnvJAX

n_steps = 100
n_env = 10000
n_particles = 50
dt = 0.01
key = jax.random.PRNGKey(0)


In [4]:
# ---------------- JAX ---------
env_jax = ParticleEnvJAX(key, n_env=n_env, n_particles=n_particles, dt=dt)
start = time.time()
for _ in range(n_steps):
    env_jax.step()
jax.device_get(env_jax.pos)  # sync computation
print("JAX:", time.time() - start, "s")
del env_jax

JAX: 0.6312322616577148 s


In [None]:

# ---------------- PyTorch standard ----------------
torch.cuda.empty_cache()
env = ParticleEnv(n_env=n_env, n_particles=n_particles, device='cuda')
start = time.time()
for _ in range(n_steps):
    env.step(dt=dt)
torch.cuda.synchronize()
print("PyTorch standard:", time.time()-start, "s")
del env; torch.cuda.empty_cache()

# ---------------- PyTorch prealloc ----------------



In [None]:
env_pre = ParticleEnvPreAlloc(n_env=n_env, n_particles=n_particles, device='cuda')
start = time.time()
for _ in range(n_steps):
    env_pre.step(dt=dt)
torch.cuda.synchronize()
print("PyTorch prealloc:", time.time()-start, "s")
del env_pre; torch.cuda.empty_cache()


KeyboardInterrupt: 

In [3]:
import importlib.util

for pkg in ["taichi", "pybullet", "mujoco", "jax", "hoomd"]:
    print(pkg, importlib.util.find_spec(pkg) is not None)

taichi False
pybullet False
mujoco False
jax True
hoomd False
