In [None]:
# Install halite-iv-jax pinned to a specific commit for reproducibility (JAX itself is not pinned by this line).
%pip install git+https://github.com/mauricef/halite-iv-jax.git@d11c4634deb0d90a292b3fb2578b7a00146eac14

# Performance

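Benchmark setup: two random agents play on a 21x21 board for 400 steps per
episode. All times are per generated episode.

- Python reference implementation (GPU runtime): 1.5 s
- JAX, single episode (GPU): 27.3 s
- JAX, batched with `vmap` (GPU): 25 ms, using batches of 1000 episodes. This
  per-episode figure is the wall time of a whole batch divided by the batch
  size.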

## Python Performance

In [None]:
from kaggle_environments import make
from kaggle_environments.envs.halite.helpers import board_agent, ShipAction, ShipyardAction

# Reference (pure-Python) Halite environment from kaggle_environments, reset for two agents.
environment = make("halite")
_ = environment.reset(2)

Loading environment football failed: No module named 'gfootball'


In [None]:
# Time complete episodes of the reference implementation with two "random" agents.
%timeit _ = environment.run(["random", "random"])

1 loop, best of 5: 1.42 s per loop


In [None]:
# Render the environment inline; presumably this shows the last episode run by the %timeit cell.
environment.render(mode="ipython", width=500, height=450)

## JAX Performance

In [None]:
# `partial` comes from functools: `jax.partial` was only a re-export of it and has
# been removed from newer JAX releases (JAX is not pinned by the install cell).
from functools import partial

from jax import random, jit
import jax.numpy as np

from kaggle_environments import make
from kaggle_environments.envs.halite.helpers import ShipAction, ShipyardAction

from halite_jax import environment_to_initial_state, generate_episode, Action, \
    episode_to_environment, random_agent

rng = random.PRNGKey(42)

# The kaggle environment supplies the game configuration and the starting state;
# the episode itself is generated by halite_jax.generate_episode.
environment = make("halite")
configuration = environment.configuration
# Bind the static arguments (configuration, agent policies) so the jitted function
# only takes the per-episode inputs: the initial state and a PRNG key.
compiled_random_agent = partial(random_agent, configuration)
agents = [compiled_random_agent, compiled_random_agent]
compiled_generate_episode = jit(partial(generate_episode, configuration, agents))
_ = environment.reset(2)
initial_state = environment_to_initial_state(environment)

In [None]:
rng, r = random.split(rng)
_ = compiled_generate_episode(initial_state, r)[0].halite.block_until_ready()

In [None]:
# Generate an episode with the same key `r` as the warm-up cell (so it should
# reproduce that episode), convert it back into a kaggle environment, and render it.
# Note: this rebinds the module-level name `environment`.
episode = compiled_generate_episode(initial_state, r)
environment = episode_to_environment(configuration, episode)
environment.render(mode="ipython", width=500, height=450)

## JAX Batch Performance

In [None]:
# `partial` comes from functools: `jax.partial` was only a re-export of it and has
# been removed from newer JAX releases.
from functools import partial

from jax import random, jit, vmap
import jax.numpy as np

from kaggle_environments import make
from kaggle_environments.envs.halite.helpers import ShipAction, ShipyardAction

from halite_jax import environment_to_initial_state, generate_episode, Action, \
    episode_to_environment, random_agent

rng = random.PRNGKey(42)
batch_size = 1000

environment = make("halite")
configuration = environment.configuration
compiled_random_agent = partial(random_agent, configuration)
agents = [compiled_random_agent, compiled_random_agent]

# initial_state must exist before it is bound into the partial below. Computing it
# here keeps this section runnable on a fresh kernel and ensures the batch starts
# from this section's board rather than one left over from an earlier cell.
_ = environment.reset(2)
initial_state = environment_to_initial_state(environment)

# Every episode in the batch starts from the same initial_state; vmap maps only over
# the remaining PRNG-key argument, so each episode receives its own key.
compiled_generate_episode = jit(vmap(partial(generate_episode, configuration, agents, initial_state)))

In [None]:
rng, *rngs = random.split(rng, batch_size)
%timeit compiled_generate_episode(np.array(rngs))[0].halite.block_until_ready()

1 loop, best of 5: 22.2 s per loop
