# Balzax environments

***To be run on Collab***

## Imports and modules

### Install Flax

In [1]:
!pip install flax



### Install Balzax

In [2]:
!pip install git+https://github.com/charlypg/Balzax

Collecting git+https://github.com/charlypg/Balzax
  Cloning https://github.com/charlypg/Balzax to /tmp/pip-req-build-0_q81v_l
  Running command git clone -q https://github.com/charlypg/Balzax /tmp/pip-req-build-0_q81v_l


### Imports

In [3]:
import jax
import jax.numpy as jnp
import flax
from functools import partial
from time import time 

from balzax import BallsEnv, BallsEnvGoal

### Devices

In [4]:
print(jax.devices())

[GpuDevice(id=0, process_index=0)]


## Testing vectorized BallsEnv

In [5]:
OBS_TYPE = 'image' # @param ['position', 'image']
MAX_TIMESTEPS = 3
SEED = 0
NUM_ENV = 500

NB_ITER_1 = 1
NB_ITER_2 = 500

ACTION_0 = jnp.zeros((NUM_ENV,))
ACTION_1 = jnp.ones((NUM_ENV,))/2.

key = jax.random.PRNGKey(SEED)
keys = jax.random.split(key, num=NUM_ENV)

env = BallsEnv(obs_type=OBS_TYPE, max_timestep=MAX_TIMESTEPS)

vmap_env_reset_done = jax.jit(jax.vmap(env.reset_done)) 
vmap_env_reset = jax.jit(jax.vmap(env.reset))  
vmap_env_step = jax.jit(jax.vmap(env.step))  

print()
print("Observation type : {}".format(OBS_TYPE))
print("Seed : {}".format(SEED))
print("Number of envs : {}".format(NUM_ENV))
print()

t0 = time()
env_states = vmap_env_reset(keys)
print("Time of reset (jit+exec) : {}".format(time()-t0))
print()

print("observations : {}".format(env_states.obs.shape))
print()

t0 = time()
for _ in range(NB_ITER_1):
    env_states = vmap_env_step(env_states, ACTION_1)
    env_states = vmap_env_reset_done(env_states)
print("{0} iterations in {1}s".format(NB_ITER_1, time()-t0))
print("step and reset_done : first call reflecting compilation time")
print()

pulse = 2*jnp.pi / NB_ITER_2 * jnp.ones((NUM_ENV,))
t0 = time()
for i in range(NB_ITER_2):
    env_states = vmap_env_step(env_states, jnp.sin(pulse*i))
    env_states = vmap_env_reset_done(env_states)
print("{0} iterations in {1}s".format(NB_ITER_2, time()-t0))
print("step and reset_done")
print()


Observation type : image
Seed : 0
Number of envs : 500

Time of reset (jit+exec) : 11.572967767715454

observations : (500, 224, 224, 1)

1 iterations in 10.62477731704712s
step and reset_done : first call reflecting compilation time

500 iterations in 6.302776575088501s
step and reset_done



### Testing vectorized BallsEnvGoal

In [6]:
OBS_TYPE = 'position' # @param ['position', 'image']
MAX_TIMESTEPS = 50
SEED = 0
NUM_ENV = 10_000

NB_ITER_1 = 1
NB_ITER_2 = 10_000

ACTION_0 = jnp.zeros((NUM_ENV,))
ACTION_1 = jnp.ones((NUM_ENV,))/2.

key = jax.random.PRNGKey(SEED)
keys = jax.random.split(key, num=NUM_ENV)

env = BallsEnvGoal(obs_type=OBS_TYPE, max_timestep=MAX_TIMESTEPS)

vmap_env_reset_done = jax.jit(jax.vmap(env.reset_done)) 
vmap_env_reset = jax.jit(jax.vmap(env.reset))  
vmap_env_step = jax.jit(jax.vmap(env.step))  

print()
print("Observation type : {}".format(OBS_TYPE))
print("Seed : {}".format(SEED))
print("Number of envs : {}".format(NUM_ENV))
print()

t0 = time()
env_states = vmap_env_reset(keys)
print("Time of reset (jit+exec) : {}".format(time()-t0))
print()

for field, value in zip(env_states.goalobs.keys(), env_states.goalobs.values()):
  print("{0} : {1}".format(field, value.shape))
print()

t0 = time()
for _ in range(NB_ITER_1):
    env_states = vmap_env_step(env_states, ACTION_1)
    env_states = vmap_env_reset_done(env_states)
print("{0} iterations in {1}s".format(NB_ITER_1, time()-t0))
print("step and reset_done : first call reflecting compilation time")
print()

pulse = 2*jnp.pi / NB_ITER_2 * jnp.ones((NUM_ENV,))
t0 = time()
for i in range(NB_ITER_2):
    env_states = vmap_env_step(env_states, jnp.sin(pulse*i))
    env_states = vmap_env_reset_done(env_states)
print("{0} iterations in {1}s".format(NB_ITER_2, time()-t0))
print("step and reset_done")
print()


Observation type : position
Seed : 0
Number of envs : 10000

Time of reset (jit+exec) : 17.60360336303711

achieved_goal : (10000, 8)
desired_goal : (10000, 8)
observation : (10000, 8)

1 iterations in 20.688915729522705s
step and reset_done : first call reflecting compilation time

10000 iterations in 31.756733179092407s
step and reset_done

