In [None]:
!pip install git+https://github.com/instadeepai/jumanji.git

## Run a random policy on BinPack and render the bin

In [None]:
import jax
import jax.numpy as jnp
import jumanji
from jumanji.wrappers import AutoResetWrapper
# Seed the JAX PRNG and build the toy BinPack environment.
key = jax.random.PRNGKey(0)
env = jumanji.make("BinPack-toy-v0")
key, reset_key = jax.random.split(key)
state, timestep = jax.jit(env.reset)(reset_key)

# The action is an (ems_id, item_id) pair. The spec does not change between
# steps, so read it once instead of inside the loop.
num_ems, num_items = env.action_spec().num_values

NUM_RANDOM_STEPS = 20

# Take up to NUM_RANDOM_STEPS random *valid* actions (sampled through the
# action mask) and render the bin after each one. A fresh subkey per step keeps
# the draws independent (JAX keys must never be reused). Stop as soon as the
# episode is over: presumably no valid action is left to sample from then.
for _ in range(NUM_RANDOM_STEPS):
    if timestep.last():
        break
    key, action_key = jax.random.split(key)
    ems_item_id = jax.random.choice(
        key=action_key,
        a=num_ems * num_items,
        p=timestep.observation.action_mask.flatten(),
    )
    ems_id, item_id = jnp.divmod(ems_item_id, num_items)

    # Wrap the action as a jax array of shape (2,)
    action = jnp.array([ems_id, item_id])
    state, timestep = env.step(state, action)
    env.render(state)

## `timestep` contains the observation, discount, and reward

In [None]:
# Number of EMS slots and number of items in the action space,
# e.g. (40, 20) for BinPack-toy-v0.
(num_ems, num_items)

In [None]:
# Size of the flattened joint (ems, item) action space, e.g. 800 for
# BinPack-toy-v0; a neural network can pick one id out of this range.
num_items * num_ems

## PPO
### Reference implementation
https://medium.com/analytics-vidhya/coding-ppo-from-scratch-with-pytorch-part-1-4-613dfc1b14c8 (follow the first 3 parts to understand the code)
https://github.com/ericyangyu/PPO-for-Beginners/blob/9abd435771aa84764d8d0d1f737fa39118b74019/ppo.py#L260 (accompanying GitHub repo)

### CleanRL videos explain PPO internals well, but they rely heavily on Gym wrappers, which I didn't find in Jumanji, so I went with the reference above instead.
https://docs.cleanrl.dev/rl-algorithms/ppo/

### Reference for RL concepts
https://spinningup.openai.com/en/latest/spinningup/rl_intro.html (core RL concepts; important background for adapting PPO's internals to 3D bin packing)


### Jumanji functions we can use
https://github.com/instadeepai/jumanji/blob/main/jumanji/environments/combinatorial/binpack/env.py
Reference Snake example: https://colab.research.google.com/github/instadeepai/jumanji/blob/main/examples/anakin_snake.ipynb#scrollTo=mNd-1Zgp5MGZ
Note that the Snake example is 2D, whereas our bin-packing problem is 3D.

### Gym wrapper concept
https://www.gymlibrary.dev/api/wrappers/

### DQN 
https://jaromiru.com/2016/09/27/lets-make-a-dqn-theory/



In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
class FeedForwardNN(nn.Module):
  """Three-layer MLP (two ReLU hidden layers, linear output) for actor and critic.

  Args:
    in_dim: size of the input feature vector.
    out_dim: size of the output vector.
    hidden_dim: width of both hidden layers (defaults to 64).
  """

  def __init__(self, in_dim, out_dim, hidden_dim=64):
    super().__init__()
    self.layer1 = nn.Linear(in_dim, hidden_dim)
    self.layer2 = nn.Linear(hidden_dim, hidden_dim)
    self.layer3 = nn.Linear(hidden_dim, out_dim)

  def forward(self, obs):
    """Map ``obs`` of shape (..., in_dim) to an output of shape (..., out_dim).

    Anything that is not already a torch tensor (numpy array, JAX array,
    nested list) is converted to a float32 tensor first.
    """
    if not isinstance(obs, torch.Tensor):
      obs = torch.tensor(np.asarray(obs), dtype=torch.float)
    activation1 = F.relu(self.layer1(obs))
    activation2 = F.relu(self.layer2(activation1))
    return self.layer3(activation2)

In [None]:
### Use the observation (not the environment state) as the network input

from torch.distributions import MultivariateNormal
from torch.optim import Adam
class PPO:
  """Clipped-objective PPO agent for a Jumanji BinPack-style environment.

  Follows the structure of ericyangyu/PPO-for-Beginners, but because the
  BinPack action is a discrete (ems_id, item_id) pair, the policy is a
  categorical distribution over the flattened joint id
  ``ems_id * num_items + item_id`` instead of a Gaussian. Ids that the
  observation's ``action_mask`` marks invalid get (effectively) zero
  probability, so every sampled action is legal.
  """

  def __init__(self, env, seed=0, network_cls=None):
    """Create actor/critic networks and their optimizers for ``env``.

    Args:
      env: environment exposing ``reset(key) -> (state, timestep)``,
        ``step(state, action) -> (state, timestep)`` and
        ``action_spec().num_values`` == ``(num_ems, num_items)``.
      seed: seed of the JAX PRNG key used for environment resets.
      network_cls: callable ``(in_dim, out_dim) -> nn.Module`` used for both
        actor and critic; defaults to ``FeedForwardNN``.
    """
    self._init_hyperparameters()
    self.env = env
    self.key = jax.random.PRNGKey(seed)

    # One categorical choice over every (ems, item) pair.
    self.num_ems, self.num_items = (int(n) for n in env.action_spec().num_values)
    self.act_dim = self.num_ems * self.num_items

    # Probe the real size of the flattened observation with one reset rather
    # than guessing it from environment attributes.
    _, timestep = self.env.reset(self._next_key())
    self.obs_dim = self._obs_to_vector(timestep.observation).shape[0]

    if network_cls is None:
      network_cls = FeedForwardNN
    self.actor = network_cls(self.obs_dim, self.act_dim)
    self.critic = network_cls(self.obs_dim, 1)

    self.actor_optim = Adam(self.actor.parameters(), lr=self.lr)
    self.critic_optim = Adam(self.critic.parameters(), lr=self.lr)

  def _init_hyperparameters(self):
    """Set default hyperparameters (will likely need tuning)."""
    self.timesteps_per_batch = 4800        # env steps collected per rollout batch
    self.max_timesteps_per_episode = 1600  # hard cap on steps in one episode
    self.gamma = 0.95                      # discount factor
    self.n_updates_per_iteration = 5       # gradient epochs per batch
    self.clip = 0.2                        # PPO clip epsilon, as recommended by the paper
    self.lr = 0.005                        # Adam learning rate (actor and critic)
    # Subtracted from every env reward. NOTE(review): a positive per-step
    # penalty makes shorter episodes look better, which in bin packing likely
    # means placing fewer items -- confirm this shaping is intended.
    self.step_penalty = 1.0

  def _next_key(self):
    """Split ``self.key`` and return a fresh subkey (JAX keys must not be reused)."""
    self.key, subkey = jax.random.split(self.key)
    return subkey

  @staticmethod
  def _obs_to_vector(observation):
    """Flatten an observation into a 1-D float32 numpy vector.

    Layout: every array leaf of ``observation`` (JAX pytree order), raveled
    and concatenated, followed by ``observation.action_mask`` raveled once
    more. Keeping the mask as the last ``act_dim`` entries lets ``_policy``
    recover it from (batches of) observation vectors. Assumes the observation
    is a JAX pytree whose leaves are numeric/bool arrays.
    """
    leaves = jax.tree_util.tree_leaves(observation)
    parts = [np.asarray(leaf, dtype=np.float32).ravel() for leaf in leaves]
    parts.append(np.asarray(observation.action_mask, dtype=np.float32).ravel())
    return np.concatenate(parts)

  def _policy(self, obs):
    """Masked categorical distribution over joint action ids.

    Args:
      obs: float tensor of shape (obs_dim,) or (batch, obs_dim) laid out as
        produced by ``_obs_to_vector``.
    """
    logits = self.actor(obs)
    valid = obs[..., -self.act_dim:] > 0.5
    # The most negative finite float (not -inf) keeps a row with no valid
    # action finite instead of turning it into NaNs.
    logits = logits.masked_fill(~valid, torch.finfo(logits.dtype).min)
    return torch.distributions.Categorical(logits=logits)

  def compute_rtgs(self, batch_rews):
    """Discounted rewards-to-go for every timestep of every episode.

    Args:
      batch_rews: list of per-episode reward lists.
    Returns:
      1-D float tensor with one entry per timestep in the batch, episodes
      concatenated in their original order.
    """
    batch_rtgs = []
    # Accumulate backwards so each return costs O(1), then flip once.
    for ep_rews in reversed(batch_rews):
      discounted_reward = 0.0
      for rew in reversed(ep_rews):
        discounted_reward = rew + discounted_reward * self.gamma
        batch_rtgs.append(discounted_reward)
    batch_rtgs.reverse()
    return torch.tensor(batch_rtgs, dtype=torch.float)

  def rollout(self):
    """Collect at least ``timesteps_per_batch`` env steps with the current policy.

    Returns:
      batch_obs: float tensor (T, obs_dim) of flattened observations.
      batch_acts: long tensor (T,) of joint action ids.
      batch_log_probs: float tensor (T,) of log pi(a|s) at collection time.
      batch_rtgs: float tensor (T,) of discounted rewards-to-go.
      batch_lens: list of episode lengths (sums to T).
    """
    batch_obs = []
    batch_acts = []
    batch_log_probs = []
    batch_rews = []
    batch_lens = []
    t = 0
    while t < self.timesteps_per_batch:
      ep_rews = []
      # Fresh key per episode so episodes actually differ.
      state, timestep = self.env.reset(self._next_key())

      for ep_t in range(self.max_timesteps_per_episode):
        t += 1
        obs = self._obs_to_vector(timestep.observation)
        with torch.no_grad():
          dist = self._policy(torch.as_tensor(obs))
          action_id = dist.sample()
          log_prob = dist.log_prob(action_id)

        # Decode the joint id back into the env's (ems_id, item_id) action.
        ems_id, item_id = divmod(int(action_id), self.num_items)
        action = jnp.array([ems_id, item_id], dtype=jnp.int32)
        state, timestep = self.env.step(state, action)

        batch_obs.append(obs)
        batch_acts.append(int(action_id))
        batch_log_probs.append(float(log_prob))
        ep_rews.append(float(timestep.reward) - self.step_penalty)
        if timestep.last():
          break

      batch_lens.append(ep_t + 1)  # ep_t is 0-based
      batch_rews.append(ep_rews)

    batch_obs = torch.as_tensor(np.stack(batch_obs), dtype=torch.float)
    batch_acts = torch.as_tensor(batch_acts, dtype=torch.long)
    batch_log_probs = torch.as_tensor(batch_log_probs, dtype=torch.float)
    batch_rtgs = self.compute_rtgs(batch_rews)
    return batch_obs, batch_acts, batch_log_probs, batch_rtgs, batch_lens

  def learn(self, total_timesteps):
    """Alternate rollouts and clipped-PPO updates until ``total_timesteps`` steps.

    Whole rollout batches are collected, so at least ``timesteps_per_batch``
    env steps run whenever ``total_timesteps > 0``.
    """
    t_so_far = 0
    while t_so_far < total_timesteps:
      batch_obs, batch_acts, batch_log_probs, batch_rtgs, batch_lens = self.rollout()
      t_so_far += int(np.sum(batch_lens))

      # Advantage A = rewards-to-go - V(s), with V from the pre-update critic.
      V, _ = self.evaluate(batch_obs, batch_acts)
      A_k = batch_rtgs - V.detach()
      # Normalize for stability; std of a single sample is undefined (NaN).
      if A_k.numel() > 1:
        A_k = (A_k - A_k.mean()) / (A_k.std() + 1e-10)

      for _ in range(self.n_updates_per_iteration):
        V, curr_log_probs = self.evaluate(batch_obs, batch_acts)
        # pi_theta(a|s) / pi_theta_old(a|s), computed in log space.
        ratios = torch.exp(curr_log_probs - batch_log_probs)
        surr1 = ratios * A_k
        surr2 = torch.clamp(ratios, 1 - self.clip, 1 + self.clip) * A_k
        actor_loss = (-torch.min(surr1, surr2)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        critic_loss = nn.MSELoss()(V, batch_rtgs)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

  def evaluate(self, batch_obs, batch_acts):
    """Critic values and current-policy log-probs for a batch.

    Args:
      batch_obs: float tensor (T, obs_dim) as returned by ``rollout``.
      batch_acts: tensor (T,) of joint action ids.
    Returns:
      (V, log_probs), each a tensor of shape (T,).
    """
    # squeeze(-1) (not squeeze()) so a batch of one keeps its batch dim.
    V = self.critic(batch_obs).squeeze(-1)
    log_probs = self._policy(batch_obs).log_prob(batch_acts)
    return V, log_probs
 
  

In [None]:
# Build the agent on the BinPack env and train it; learn() keeps collecting
# whole rollout batches until the requested number of env steps is reached.
model = PPO(env)
model.learn(total_timesteps=10)