<a href="https://colab.research.google.com/github/kaiamj/deep-reinforcement-learning-jumanji/blob/main/ppo_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install git+https://github.com/instadeepai/jumanji.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/instadeepai/jumanji.git
  Cloning https://github.com/instadeepai/jumanji.git to /tmp/pip-req-build-aqoxg89g
  Running command git clone --filter=blob:none --quiet https://github.com/instadeepai/jumanji.git /tmp/pip-req-build-aqoxg89g
  Resolved https://github.com/instadeepai/jumanji.git to commit 10958866909d434ba50edc1915247e4cebc3cb3e
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting matplotlib>=3.3.4
  Downloading matplotlib-3.6.3-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (9.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.4/9.4 MB[0m [31m60.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dm-env>=1.5
  Downloading dm_env-1.6-py3-none-any.whl (26 kB)
Collecting brax>=0.0.10
  Downloa

In [2]:
import numpy as np
import jax
from jax import vmap
import jax.numpy as jnp
import jumanji
from jumanji.wrappers import AutoResetWrapper
import flax.linen as nn

In [3]:

def flatten_jax(obs):
    return jnp.concatenate([obs.ems.x1,obs.ems.x2,
                        obs.ems.y1,obs.ems.y2,
                        obs.ems.z1,obs.ems.z2,
                        obs.ems_mask.flatten(),obs.items.x_len,
                        obs.items.y_len,obs.items.z_len,
                        obs.items_mask.flatten(),obs.items_placed.flatten()])

# flax nn

In [4]:
from flax import linen as nn 
import optax

class SimpleClassifierCompact(nn.Module):
    num_hidden : int   # Number of hidden neurons
    num_outputs : int  # Number of output neurons

    @nn.compact  # Tells Flax to look for defined submodules
    def __call__(self, x):
        # Perform the calculation of the model to determine the prediction
        # while defining necessary layers
        x = nn.Dense(features=self.num_hidden)(x)
        x = nn.tanh(x)
        x = nn.Dense(features=self.num_outputs)(x)
        return x
def critic_calculate_loss( state, V,batch_rtgs):
    #logits = critic_state.apply_fn(params, data).squeeze(axis=-1)
    loss = optax.sigmoid_binary_cross_entropy(V, batch_rtgs).mean()
    #print("loss of critic ", loss)
    return loss

@jax.jit  # Jit the function for efficiency
def critic_train_step(state, V,batch_rtgs):
    # Gradient function
    grad_fn = jax.value_and_grad(critic_calculate_loss,  # Function to calculate the loss
                                 argnums=0  # Parameters are second argument of the function
                                 #has_aux=False  # Function has additional outputs, here accuracy
                                )
    # Determine gradients for current model, parameters and batch
    loss, grads = grad_fn(state.params,V,batch_rtgs)
    #print("gradient of critic ",grads," type ", type(grads))

    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state
def actor_calculate_loss( state, surr1,surr2):
    #logits = critic_state.apply_fn(params, data).squeeze(axis=-1)
    actor_loss = (-jnp.minimum(surr1, surr2)).mean()
    #print("actor loss ",actor_loss)
    return actor_loss

@jax.jit  # Jit the function for efficiency
def actor_train_step(state,  surr1,surr2):
    # Gradient function
    grad_fn = jax.value_and_grad(actor_calculate_loss,  # Function to calculate the loss
                                 argnums=0  # Parameters are second argument of the function
                                 #has_aux=False  # Function has additional outputs, here accuracy
                                )
    # Determine gradients for current model, parameters and batch
    loss, grads = grad_fn(state.params, surr1,surr2)
    #print("gradient of actor ",grads," type ", type(grads))
    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state
#critic_state, loss = critic_train_step(critic_state, data)

In [5]:
from flax.training import train_state
import pandas as pd
import time


class PPO:
  def __init__(self,env):
    self._init_hyperparameters()
    self.env = env
    
    #initiate actor and critic =========================================================
    

    self.optimizer = optax.sgd(learning_rate=self.lr)

    self.actor = SimpleClassifierCompact(num_hidden=64, num_outputs=self.act_dim)
    self.critic = SimpleClassifierCompact(num_hidden=64, num_outputs=1)
    
    self.params = self.actor.init(self.subkey, jnp.arange(self.obs_dim))
    self.cparams = self.critic.init(self.subkey, jnp.arange(self.obs_dim))

    self.actor_state = train_state.TrainState.create(apply_fn=self.actor.apply,
                                            params=self.params,
                                            tx=self.optimizer)
    self.critic_state = train_state.TrainState.create(apply_fn=self.critic.apply,
                                            params=self.cparams,
                                            tx=self.optimizer)
     
    #====================================================================================
    
    
    
  def _init_hyperparameters(self):
    # Default values for hyperparameters, will need to change later.
    self.timesteps_per_batch = 2  #4800            # timesteps per batch
    #self.max_timesteps_per_episode = 1600      # timesteps per episode
    self.gamma = 0.95
    self.n_updates_per_iteration = 1
    self.clip = 0.2 # As recommended by the paper
    self.lr = 0.1

    self.obs_dim = 380
    self.act_dim = 800

    self.key = jax.random.PRNGKey(0)
    self.key, self.subkey = jax.random.split(self.key)


  def compute_rtgs_jax(self, batch_rews): 
    # The rewards-to-go (rtg) per episode per batch to return.
    # The shape will be (num timesteps per episode)
    
    batch_rtgs = jnp.array([])
    #print("batch_rews ",batch_rews)
    # Iterate through each episode backwards to maintain same order in batch_rtgs
    for ep_rews in reversed(batch_rews):
      discounted_reward = 0 # The discounted reward so far
      #print(" ep_rews ",ep_rews)
      discounted_reward = ep_rews + discounted_reward * self.gamma
      #print("discounted_reward ",discounted_reward)
      batch_rtgs = jnp.insert(batch_rtgs, 0, discounted_reward)
    #print("batch_rtgs",batch_rtgs)
    return batch_rtgs


 
  def get_action_jax(self, obs, action_jnp):
    #flatten observation  p = flatten(timestep.observation)
    # inside critic and actor by converting it to np.array(p) by actor you will get action

    #initiate actor and critic
     #actor = FeedForwardNN(380,800)
     logits = self.actor_state.apply_fn(self.actor_state.params, obs)
    
     valid_indices = jnp.nonzero(action_jnp)  # getting valid indicies
     valid_logits = logits[valid_indices]  # getting proper valid action probablities
     valid_logits_array = jnp.reshape(valid_logits, (-1,))  # reshapping it to make it 1D
     index_mapping = [i for i, include in enumerate(action_jnp) if include] # mapping of valid actions on whole set of actions
     # Gumbel's trick
     
     u = jax.random.uniform(self.subkey, shape=valid_logits_array.shape) # generates random uniform values
     
     probs = valid_logits_array - jnp.log(-jnp.log(u)) # logits + random uniform noise
     action = jnp.argmax(probs) # argmax of probs -> action id in filtered array from valid actions
     action_id = index_mapping[action] # action index in the 800 size array from actor output
     log_prob_action = jnp.log(action) # log probability of selected action
     
     return action_id, action, log_prob_action
  
  def rollout(self):
    # batch observations, batch actions, log probs of each action, batch rewards,batch rewards-to-go,episodic lengths in batch
    batch_obs, batch_acts, batch_log_probs, batch_rews, batch_rtgs, batch_lens = jnp.array([]), jnp.array([]), jnp.array([]), jnp.array([]), jnp.array([]),jnp.array([])
    batch_states = []
    step_fn = jax.jit(self.env.step)
    reset_fn = jax.jit(self.env.reset)
    t = 0 
    
    while t < self.timesteps_per_batch: # Number of timesteps run so far this batch
      # Rewards this episode
      ep_rews = jnp.array([])
      
      
      state, timestep = reset_fn(self.key)
      ep_t = 1
      rew = 0.0
      while rew == 0.0:
        batch_states.append(state)
        obs = flatten_jax(timestep.observation)  # Collect observation
        #print("obs  flat ",len(obs))
        if t == 0 and ep_t == 1:
          batch_obs = jnp.append(batch_obs, obs)
        else:
          batch_obs = jnp.vstack([batch_obs, obs])
        #print("batch  ",batch_obs)
        num_ems, num_items = self.env.action_spec().num_values
        action_mask = timestep.observation.action_mask.flatten()
        action_jnp = jnp.array(action_mask, dtype=jnp.float32)

        ems_item_id, action_,log_prob  = self.get_action_jax(obs,action_jnp)
        ems_id, item_id = jnp.divmod(ems_item_id, num_items)

        action = jnp.array([ems_id, item_id])  # Wrap the action as a jax array of shape (2,)
        #batch_states = jnp.append(batch_states, state)

        state,timestep = step_fn(state, action)
        rew = jnp.array(timestep.reward.flatten())[0]
        #print(" rew ", rew," type ", type(rew))
        ep_rews = jnp.append(ep_rews, rew)
        #print(" ep_rews ",ep_rews )
        batch_acts = jnp.append(batch_acts, action_)
        batch_log_probs = jnp.append(batch_log_probs, log_prob)
        ep_t += 1 # Increment timesteps ran this batch so far

      t += ep_t
      batch_rews = jnp.append(batch_rews, ep_rews) 
    
    batch_rtgs = self.compute_rtgs_jax(batch_rews)
    return batch_obs, batch_acts,batch_log_probs, batch_rtgs, t ,rew,batch_states

  def learn(self, total_timesteps):
    t_so_far = 0 # Timesteps simulated so far
    episode_reward = jnp.array([])
    while t_so_far < total_timesteps:              # ALG STEP 
      batch_obs, batch_acts, batch_log_probs, batch_rtgs, t, rew,batch_states = self.rollout()
      # print("len  of batch obs and obs of episode itself",len(batch_obs),batch_obs)
      # print("batch_ act ", batch_acts)
      # print("batch_log_prob ",batch_log_probs)
      # print("batch_rtg ", batch_rtgs)
      
      episode_reward = jnp.append(episode_reward, rew)
      t_so_far += t # Calculate how many timesteps we collected this batch   
      V, _ = self.evaluate(batch_obs, batch_acts)
      A_k = batch_rtgs - V # ALG STEP 5 Calculate advantage
      A_k = (A_k - A_k.mean()) / (A_k.std() + 1e-10) # Normalize advantages
      for i in range(self.n_updates_per_iteration):
        V, curr_log_probs = self.evaluate(batch_obs, batch_acts)
        ratios = jax.lax.exp(curr_log_probs - batch_log_probs)   # Calculate ratios
        surr1 = ratios * A_k  # Calculate surrogate losses
        surr2 = jax.lax.clamp( 1 - self.clip, ratios, 1 + self.clip) * A_k
        #print(surr1,surr2,"  surrrrrr  ","  ratios ",ratios, " ak ",A_k)
        #======================================================================================================
        self.actor_state = actor_train_step(self.actor_state, surr1, surr2)
        self.critic_state = critic_train_step(self.critic_state, V, batch_rtgs)
        #=====================================================================================================
        
        
    return batch_states, episode_reward
    
  def evaluate(self, batch_obs,batch_acts):
    # Query critic network for a value V for each obs in batch_obs.
    #V = self.critic(batch_obs).squeeze()
#vmap(sum_samples)(data)
    #V = vmap(self.critic_state.apply_fn)(self.critic_state.params, batch_obs)#.squeeze(axis=-1)
    V = jnp.array([])
    #batch_logits = jnp.array([])
    for i in batch_obs:
      
      # print("len of actor param ", len(self.actor_state.params),self.actor_state.params)

      #batch_logits = jnp.append(batch_logits,self.actor_state.apply_fn(self.actor_state.params, i))
      # print("len of out actor ",len(outa), outa )
      # print("critic param ",len(self.critic_state.params),self.critic_state.params)
      # print("obs ", len(i))
      V = jnp.append(V,self.critic_state.apply_fn(self.critic_state.params, i))
    #print("Value func ", V)
    #print("batch_logits ",batch_logits)
    # Calculate the log probabilities of batch actions using most 
    # recent actor network.  # This segment of code is similar to that in get_action()
    #batch_logits = self.actor_state.apply_fn(self.actor_state.params, batch_obs).squeeze(axis=-1)
    # rescaling 
    #log_softmax_probs = jax.nn.log_softmax(batch_logits)
    # print("log_softmax_probs len ",type(log_softmax_probs),log_softmax_probs)
    # print("batch_ actions ", type(batch_acts),batch_acts)
    log_probs = jnp.log(batch_acts)

    return V, log_probs  # Return predicted values V and log probs log_probs
 
  

In [6]:

env = jumanji.make("BinPack-toy-v0")
model = PPO(env)
#batch_obs, batch_acts, batch_log_probs, batch_rtgs, t, rew = model.rollout()
  
%time batch_states, episode_rewards = model.learn(1)
#### with jit env , reset one episode 
# CPU times: user 7.4 s, sys: 103 ms, total: 7.5 s
# Wall time: 7.1 s

  return asarray(x, dtype=self.dtype)
  return asarray(x, dtype=self.dtype)
  return asarray(x, dtype=self.dtype)


CPU times: user 21.7 s, sys: 796 ms, total: 22.5 s
Wall time: 46.9 s
