<a href="https://colab.research.google.com/github/azzeddineCH/RL-environment-in-JAX/blob/main/RL_env.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RL env with JAX and XLA
This notebook holds an example of a simple RL environment implemented in JAX that can be compiled to XLA code.

The following code is inspired by the blog post [Writing an RL Environment in JAX](https://medium.com/@ngoodger_7766/writing-an-rl-environment-in-jax-9f74338898ba) by Nikolaj Goodger.



In [None]:
import jax
from jax import numpy as jnp
import numpy as np
from typing import NamedTuple, Tuple
from functools import partial

## Abstract Environment

The following is an abstract implementation of an RL environment in JAX, which should follow JAX's pattern for stateful computation.

The main idea is to carry the internal state of the environment across method calls and avoid side effects inside the jitted methods `env.reset` and `env.step`.

In [None]:

class AbstractEnv:
  """Minimal template for a functional (side-effect free) JAX environment.

  No episode state is stored on ``self``. Every call receives and returns
  ``env_state``, a ``(state, key)`` pair holding the current scalar state and
  the PRNG key to use for the next reset, so ``reset`` and ``step`` can be
  traced by JAX transformations.
  """

  def __init__(self):
    # Initial states are drawn uniformly from [-randim_limit, randim_limit).
    # The attribute keeps its original spelling so existing users still work.
    self.randim_limit = 0.05

  def _get_obs(self, state):
    """Return the observation for ``state`` (here the state itself)."""
    return state

  def _maybe_reset(self, env_state, done):
    """Return a freshly reset ``env_state`` if ``done``, else ``env_state``.

    ``jax.lax.cond`` is used instead of a Python ``if`` so the choice also
    works when ``done`` is a traced value.
    """
    _, key = env_state
    return jax.lax.cond(done, self._reset, lambda _: env_state, key)

  def _reset(self, key):
    """Sample a new initial state and return ``(new_state, new_key)``.

    The incoming key is split first: one half is consumed by the sampler and
    the other half is handed back for future use, so no key is ever used
    twice.
    """
    new_key, sample_key = jax.random.split(key)
    new_state = jax.random.uniform(
        sample_key, minval=-self.randim_limit, maxval=self.randim_limit)
    return new_state, new_key

  def step(self, env_state, action):
    """Advance the environment by one step.

    Args:
      env_state: ``(state, key)`` pair returned by ``reset`` or ``step``.
      action: value added to the current state.

    Returns:
      ``(env_state, obs, reward, done, info)`` where ``env_state`` is the
      (possibly reset) state to pass to the next call.
    """
    state, key = env_state
    new_state = state + action

    # Placeholder dynamics: constant reward, episodes never terminate.
    reward, done, info = 1.0, False, None

    env_state = self._maybe_reset((new_state, key), done)
    new_state, _ = env_state

    return env_state, self._get_obs(new_state), reward, done, info

  def reset(self, key):
    """Start a new episode from ``key``; return ``(env_state, obs)``."""
    env_state = self._reset(key)
    new_state, _ = env_state
    return env_state, self._get_obs(new_state)


## MountainCar environment implementation

The following is a simple implementation of the [MountainCar gym environment](https://www.gymlibrary.dev/environments/classic_control/mountain_car/).

We then implement a jittable rollout script that can be parallelized across different devices using `pmap` and vectorized inside each device using `vmap`.

In [None]:
class EnvState(NamedTuple):
  """Complete state of one MountainCar environment, carried between calls.

  Being a NamedTuple it is a JAX pytree and can be passed through ``jit``,
  ``vmap`` and ``lax`` control-flow primitives.
  """

  # PRNG key reserved for the next reset. Annotated as ``jax.Array`` because
  # the ``jax.random.PRNGKeyArray`` alias no longer exists in current JAX and
  # NamedTuple annotations are evaluated when the class is created.
  key: jax.Array
  # Current position of the car (scalar).
  position: jax.Array
  # Current velocity of the car (scalar).
  velocity: jax.Array

class StepResult(NamedTuple): 
  """Per-step outputs of an environment transition (everything except the state)."""
  # Observation returned for the step.
  obs: jax.Array
  # Scalar reward for the step.
  reward: jax.Array
  # Termination flag for the step.
  done: jax.Array
  # Auxiliary information; presumably a placeholder array so the tuple stays a
  # valid JAX pytree -- confirm against the producing environment.
  info: jax.Array

In [None]:
class MountainCarEnv:
  """Functional JAX implementation of the MountainCar control task.

  The environment holds only constants on ``self``. All episode state lives
  in an ``EnvState`` that is passed into and returned from ``reset`` and
  ``step``, which keeps both methods free of side effects and jittable.
  """

  def __init__(self):
    self.min_position = -1.2
    self.max_position = 0.6
    self.max_speed = 0.07
    self.goal_position = 0.5
    self.goal_velocity = 0

    self.force = 0.001
    self.gravity = 0.0025

    # Lower bound of the [position, velocity] observation vector.
    self.low = jnp.array(np.array([self.min_position, -self.max_speed], dtype=np.float32))

  # ``self`` is marked static so the instance is treated as a compile-time
  # constant (hashed by identity) rather than traced.
  @partial(jax.jit, static_argnums=(0,))
  def reset(self, key: jax.Array) -> EnvState:
    """Start a new episode from ``key`` and return its initial state."""
    return self._reset(key)

  def _reset(self, key: jax.Array) -> EnvState:
    """Sample an initial state: position in [-0.6, -0.4), zero velocity.

    The key is split before sampling so the key stored in the returned state
    is never the one that was consumed by the sampler.
    """
    next_key, sample_key = jax.random.split(key)
    position = jax.random.uniform(key=sample_key, minval=-0.6, maxval=-0.4)
    velocity = jnp.array(0, dtype=jnp.float32)

    return EnvState(next_key, position, velocity)

  def _maybe_reset(self, env_state: EnvState, done: bool) -> EnvState:
    """Return a freshly reset state if ``done``, otherwise ``env_state``.

    ``jax.lax.cond`` keeps the branch traceable when ``done`` is an array.
    """
    return jax.lax.cond(done, self._reset, lambda _: env_state, env_state.key)

  def _get_obs(self, env_state: EnvState):
    """Return the observation ``[position, velocity]`` for ``env_state``."""
    return jnp.array([env_state.position, env_state.velocity])

  @partial(jax.jit, static_argnums=(0,))
  def step(self, env_state: EnvState, action: jax.Array) -> Tuple[EnvState, StepResult]:
    """Apply ``action`` for one time step.

    Args:
      env_state: state returned by ``reset`` or a previous ``step``.
      action: scalar action; ``action - 1`` is the push direction, so
        0 / 1 / 2 push left / not at all / right.

    Returns:
      ``(env_state, StepResult)``. When the goal is reached the returned
      state, and the observation derived from it, already belong to the next
      episode, while ``done`` reports the termination of the finished one.
    """
    key, position, velocity = env_state

    new_velocity = velocity + ((action - 1) * self.force + jnp.cos(3 * position) * (-self.gravity))
    new_velocity = jnp.clip(new_velocity, -self.max_speed, self.max_speed)

    new_position = position + new_velocity
    new_position = jnp.clip(new_position, self.min_position, self.max_position)

    # Hitting the left wall while still moving left stops the car.
    hit_left_wall = (new_position == self.min_position) & (new_velocity < 0)
    new_velocity = jnp.where(hit_left_wall, jnp.zeros_like(new_velocity), new_velocity)

    done = (new_position >= self.goal_position) & (new_velocity >= self.goal_velocity)

    # Constant -1 reward per step; ``info`` is a placeholder array.
    reward = jnp.array(-1.0, dtype=jnp.float32)
    info = jnp.array(0.0, dtype=jnp.float32)

    env_state = self._maybe_reset(EnvState(key, new_position, new_velocity), done)

    return env_state, StepResult(self._get_obs(env_state), reward, done, info)


In [None]:
def fori_body(i, episode_data):
  """Body of the rollout ``fori_loop``: take one random action and record it.

  Args:
    i: index of the current time step; row ``i`` of each buffer is written.
    episode_data: carry tuple
      ``(env_state, action_key, all_obs, all_reward, all_done)``.

  Returns:
    The updated carry tuple with the same structure.

  Uses the module-level ``env`` to step the environment.
  """
  env_state, action_key, all_obs, all_reward, all_done = episode_data

  # Split before sampling so the key carried to the next iteration is never
  # the one consumed here.
  action_key, sample_key = jax.random.split(action_key)
  # maxval is exclusive: sample uniformly from the three actions {0, 1, 2}.
  action = jax.random.randint(sample_key, (), 0, 3)

  env_state, (obs, reward, done, _) = env.step(env_state, action)

  # JAX arrays are immutable: ``.at[i].set`` returns a new array, which has
  # to be rebound or the update is lost.
  all_obs = all_obs.at[i].set(obs)
  all_reward = all_reward.at[i].set(reward)
  all_done = all_done.at[i].set(done)

  return env_state, action_key, all_obs, all_reward, all_done

@jax.pmap
@jax.vmap
def rollout(key):
  """Run one random-policy rollout for the environment seeded by ``key``.

  The function is written for a single environment; ``vmap`` batches it over
  the environments of one device and ``pmap`` over devices, so it is called
  with keys carrying two leading axes (devices, environments per device).

  Args:
    key: PRNG key for this environment.

  Returns:
    ``(all_obsv, all_reward, all_done)`` with one row per time step.

  Uses the module-level ``env`` and ``fori_body``.
  """
  # NOTE: the buffers below are allocated in full for every environment, so
  # memory use grows linearly with TIMESTEPS.
  TIMESTEPS = 100_000_000
  all_obsv = jnp.zeros(shape=(TIMESTEPS, 2))
  all_reward = jnp.zeros(shape=(TIMESTEPS, 1))
  all_done = jnp.zeros(shape=(TIMESTEPS, 1), dtype=jnp.bool_)

  # Derive the action key from this environment's own key. A fixed
  # PRNGKey(0) would make every environment follow the same action sequence.
  reset_key, action_key = jax.random.split(key)

  env_state = env.reset(reset_key)
  carry = (env_state, action_key, all_obsv, all_reward, all_done)
  carry = jax.lax.fori_loop(0, TIMESTEPS, fori_body, carry)
  *_, all_obsv, all_reward, all_done = carry

  return all_obsv, all_reward, all_done

In [None]:
# Total number of environments to simulate, across all devices.
NUM_ENV = 4
# Number of devices the environments are spread over.
# NOTE(review): presumably must not exceed the available device count -- confirm.
NUM_DEVICES = 1

In [None]:
# Root PRNG key from which all per-environment keys are derived.
seed = 0
key = jax.random.PRNGKey(seed)

In [None]:
# One key per environment, arranged as (devices, environments per device, key data)
# to match the pmap-over-vmap layout of `rollout`.
# NUM_ENV has to be divisible by NUM_DEVICES, otherwise the reshape fails.
keys = jax.random.split(key, NUM_ENV).reshape(NUM_DEVICES, NUM_ENV // NUM_DEVICES, -1)

In [None]:
# Global environment instance looked up by `fori_body` and `rollout`.
env = MountainCarEnv()

In [None]:
# Run all rollouts; each result carries leading (device, environment) axes
# followed by the per-time-step data.
all_obsv, all_reward, all_done = rollout(keys)