# Balzax environments

***To be run on Collab***

## Imports and modules

### Install Flax

In [9]:
!pip install flax



### Install Balzax

In [10]:
!pip install git+https://github.com/charlypg/Balzax

Collecting git+https://github.com/charlypg/Balzax
  Cloning https://github.com/charlypg/Balzax to /tmp/pip-req-build-oyzev2ol
  Running command git clone -q https://github.com/charlypg/Balzax /tmp/pip-req-build-oyzev2ol


### Imports

In [11]:
import jax
import jax.numpy as jnp
import flax
from functools import partial
import matplotlib.pyplot as plt
from time import time 

from balzax import BallsEnv, BallsEnvGoal

### Devices

In [12]:
print(jax.devices())

[GpuDevice(id=0, process_index=0)]


## Testing vectorized BallEnv

In [13]:
OBS_TYPE = 'position' # @param ['position', 'image']
SEED = 0
NUM_ENV = 10000

NB_ITER = 10000

ACTION_0 = jnp.zeros((NUM_ENV,))
ACTION_1 = jnp.ones((NUM_ENV,))/2.

key = jax.random.PRNGKey(SEED)
keys = jax.random.split(key, num=NUM_ENV)

#print(keys)

env = BallsEnv(obs_type=OBS_TYPE)

vmap_env_reset = jax.jit(jax.vmap(env.reset))  # jax.vmap(env.reset)
vmap_env_step = jax.jit(jax.vmap(env.step))  # jax.vmap(env.step)

obs_list = []

print()
print("Observation type : {}".format(OBS_TYPE))
print("Seed : {}".format(SEED))
print("Number of envs : {}".format(NUM_ENV))
print()

print("Number of iterations : {}".format(NB_ITER))
print()

t0 = time()
env_states = vmap_env_reset(keys)
print("Time du reset (jit+exec) : {}".format(time()-t0))
print()

#print(env_states)

observations = env_states.state.obs
print("observations : {}".format(observations.shape))
print()

t0 = time()
env_states = vmap_env_reset(keys)
print("Time du reset (second time exec) : {}".format(time()-t0))
print()

obs_list.append(env_states.state.obs)

t0 = time()
env_states = vmap_env_step(env_states, ACTION_0)
print("First step (jit+exec) : {}".format(time()-t0))
print()

obs_list.append(env_states.state.obs)

t0 = time()
env_states = vmap_env_step(env_states, ACTION_1)
print("Second step (exec) : {}".format(time()-t0))
print()

obs_list.append(env_states.state.obs)

t0 = time()
for _ in range(NB_ITER):
    env_states = vmap_env_step(env_states, ACTION_1)
    obs_list.append(env_states.state.obs)
print("{0} iterations in {1}s".format(NB_ITER, time()-t0))
print()

pulse = 2*jnp.pi / NB_ITER * jnp.ones((NUM_ENV,))
t0 = time()
for i in range(NB_ITER):
    env_states = vmap_env_step(env_states, jnp.sin(pulse*i))
    obs_list.append(env_states.state.obs)
print("{0} iterations in {1}s".format(NB_ITER, time()-t0))
print()


Observation type : position
Seed : 0
Number of envs : 10000

Number of iterations : 10000

Time du reset (jit+exec) : 16.14885687828064

observations : (10000, 8)

Time du reset (second time exec) : 0.0036156177520751953

First step (jit+exec) : 1.298271894454956

Second step (exec) : 1.4462976455688477

10000 iterations in 4.561103105545044s

10000 iterations in 7.61394476890564s



### Testing vectorized GoalEnv

In [14]:
OBS_TYPE = 'image'  # @param ['position', 'image']
SEED = 0
NUM_ENV = 500

NB_ITER = 500

ACTION_0 = jnp.zeros((NUM_ENV,))
ACTION_1 = jnp.ones((NUM_ENV,))/2.

key = jax.random.PRNGKey(SEED)
keys = jax.random.split(key, num=NUM_ENV)

#print(keys)

env = BallsEnvGoal(obs_type=OBS_TYPE)

vmap_env_reset = jax.jit(jax.vmap(env.reset))  # jax.vmap(env.reset)
vmap_env_step = jax.jit(jax.vmap(env.step))  # jax.vmap(env.step)

obs_list = []

print()
print("Observation type : {}".format(OBS_TYPE))
print("Seed : {}".format(SEED))
print("Number of envs : {}".format(NUM_ENV))
print()

print("Number of iterations : {}".format(NB_ITER))
print()

t0 = time()
env_states = vmap_env_reset(keys)
print("Time du reset (jit+exec) : {}".format(time()-t0))
print()

#print(env_states)

observations = env_states.state.obs
print("observations : {}".format(observations.shape))
print()

t0 = time()
env_states = vmap_env_reset(keys)
print("Time du reset (second time exec) : {}".format(time()-t0))
print()

obs_list.append(env_states.state.obs)

t0 = time()
env_states = vmap_env_step(env_states, ACTION_0)
print("First step (jit+exec) : {}".format(time()-t0))
print()

obs_list.append(env_states.state.obs)

t0 = time()
env_states = vmap_env_step(env_states, ACTION_1)
print("Second step (exec) : {}".format(time()-t0))
print()

obs_list.append(env_states.state.obs)

t0 = time()
for _ in range(NB_ITER):
    env_states = vmap_env_step(env_states, ACTION_1)
    #obs_list.append(env_states.state.obs)
print("{0} iterations in {1}s".format(NB_ITER, time()-t0))
print()

pulse = 2*jnp.pi / NB_ITER * jnp.ones((NUM_ENV,))
t0 = time()
for i in range(NB_ITER):
    env_states = vmap_env_step(env_states, jnp.sin(pulse*i))
    #obs_list.append(env_states.state.obs)
print("{0} iterations in {1}s".format(NB_ITER, time()-t0))
print()


Observation type : image
Seed : 0
Number of envs : 500

Number of iterations : 500

Time du reset (jit+exec) : 27.19322109222412

observations : (500, 224, 224, 1)

Time du reset (second time exec) : 0.007956266403198242

First step (jit+exec) : 1.3358988761901855

Second step (exec) : 1.4006242752075195

500 iterations in 4.382855176925659s

500 iterations in 4.593937158584595s

