<a href="https://colab.research.google.com/gist/epignatelli/da5c0f63b8c4a189ae261232121ae446/navix_profiling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [38]:
!pip install -q git+https://github.com/epignatelli/navix gymnasium minigrid

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m925.5/925.5 kB[0m [31m19.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.1/103.1 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m74.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [57]:
import jax
import jax.numpy as jnp
import navix as nx


# Number of environment steps per rollout, shared by the navix and MiniGrid benchmarks below.
N_TIMESTEPS: int = 10 ** 4


def profile_navix(seed):
    """Roll out N_TIMESTEPS uniformly random actions in a navix Room environment.

    Written to be wrapped in ``jax.jit(jax.vmap(...))`` so that many seeds run in
    parallel; the whole rollout is a single ``jax.lax.scan``.

    Args:
        seed: integer seed (a traced scalar under ``vmap``) from which the PRNG key
            for the environment reset and for the action sequence is derived.

    Returns:
        The final timestep produced by ``env.step`` after N_TIMESTEPS steps.
    """
    # NOTE(review): positional Room arguments; presumably height, width and a third
    # setting (step limit?) — confirm against the navix API when comparing to MiniGrid.
    env = nx.environments.Room(16, 16, 8)
    key = jax.random.PRNGKey(seed)
    # Give the reset and the action sampler independent keys: reusing one key for
    # both makes their random draws statistically correlated.
    reset_key, actions_key = jax.random.split(key)
    timestep = env.reset(reset_key)
    # randint's upper bound is exclusive, so actions are drawn from {0, ..., 5}.
    actions = jax.random.randint(actions_key, (N_TIMESTEPS,), 0, 6, dtype=jnp.int32)

    def body_fun(timestep, action):
        # Scan carry is the current timestep; the per-step input is one action.
        return env.step(timestep, action), ()

    final_timestep, _ = jax.lax.scan(body_fun, timestep, actions)
    return final_timestep


# Batched, compiled benchmark: maps `profile_navix` over the leading axis of a seed
# array (one independent rollout per seed) and JIT-compiles the result.
f = jax.jit(jax.vmap(profile_navix, in_axes=0))

In [None]:
# running 10_000 seeds in parallel
%timeit -n 5 -r 1 f(jnp.arange(10_000)).state.grid.block_until_ready()

In [None]:
import gymnasium as gym
import minigrid
import random


def profile_minigrid(seed):
    """Roll out N_TIMESTEPS random actions in MiniGrid-Empty-16x16, one step at a time.

    Sequential, single-seed Python baseline for ``profile_navix``. The environment is
    reset whenever an episode terminates or is truncated.

    Args:
        seed: integer seed passed to the environment's first ``reset`` and used to
            seed the action sampler, so the rollout depends on ``seed``.

    Returns:
        The last observation returned by ``env.step`` or ``env.reset``.
    """
    # Per-call RNG: drawing from the unseeded global `random` made rollouts
    # irreproducible and independent of `seed`.
    rng = random.Random(seed)
    env = gym.make("MiniGrid-Empty-16x16-v0", render_mode=None)
    try:
        observation, _ = env.reset(seed=seed)
        for _ in range(N_TIMESTEPS):
            # Same half-open action range as profile_navix's randint(..., 0, 6): {0, ..., 5}.
            action = rng.randrange(6)
            observation, _reward, terminated, truncated, _info = env.step(action)
            if terminated or truncated:
                observation, _ = env.reset()
    finally:
        # Release the environment even if stepping raises.
        env.close()
    return observation

In [None]:
# running 1 seed
%timeit -n 5 -r 1 profile_minigrid(0)