In [14]:
import numpy as np
import jax
import jax.numpy as jnp
import jumanji
from jumanji.wrappers import AutoResetWrapper
import flax.linen as nn

In [15]:
def flatten_jax(obs):
    return jnp.concatenate([obs.ems.x1,obs.ems.x2,
                        obs.ems.y1,obs.ems.y2,
                        obs.ems.z1,obs.ems.z2,
                        obs.ems_mask.flatten(),obs.items.x_len,
                        obs.items.y_len,obs.items.z_len,
                        obs.items_mask.flatten(),obs.items_placed.flatten()])

In [16]:
class SimpleClassifierCompact(nn.Module):
    num_hidden : int   # Number of hidden neurons
    num_outputs : int  # Number of output neurons

    @nn.compact  # Tells Flax to look for defined submodules
    def __call__(self, x):
        # Perform the calculation of the model to determine the prediction
        # while defining necessary layers
        x = nn.Dense(features=self.num_hidden)(x)
        x = nn.tanh(x)
        x = nn.Dense(features=self.num_outputs)(x)
        return x


In [17]:
# Check if works
rng = jax.random.PRNGKey(0)
model = SimpleClassifierCompact(num_hidden=64, num_outputs=1)
# Printing the model shows its attributes
# print(model)
rng, init_rng = jax.random.split(rng, 2)
inp = jnp.arange(380)  # Batch size 8, input size 2
# Initialize the model
params = model.init(init_rng, inp)
# print(params)
model.apply(params, inp)

Array([-0.21167278], dtype=float32)

In [18]:
def get_action_jax(obs, action_jnp):
    #flatten observation  p = flatten(timestep.observation)
    # inside critic and actor by converting it to np.array(p) by actor you will get action

    #initiate actor and critic
     #actor = FeedForwardNN(380,800)
     key = jax.random.PRNGKey(0)
     key, subkey = jax.random.split(key)
     
     actor = SimpleClassifierCompact(num_hidden=64, num_outputs=800)
     
     params = actor.init(init_rng, inp)
    
     logits = actor.apply(params,obs)
     valid_indices = jnp.nonzero(action_jnp)  # getting valid indicies
     valid_logits = logits[valid_indices]  # getting proper valid action probablities
     valid_logits_array = jnp.reshape(valid_logits, (-1,))  # reshapping it to make it 1D
     index_mapping = [i for i, include in enumerate(action_jnp) if include] # mapping of valid actions on whole set of actions
     # Gumbel's trick
     key, subkey = jax.random.split(key)
     u = jax.random.uniform(subkey, shape=valid_logits_array.shape) # generates random uniform values
     
     probs = valid_logits_array - jnp.log(-jnp.log(u)) # logits + random uniform noise
     action = jnp.argmax(probs) # argmax of probs -> action id in filtered array from valid actions
     action_id = index_mapping[action] # action index in the 800 size array from actor output
     log_prob_action = jnp.log(action) # log probability of selected action
     
     return action_id, action, log_prob_action

In [19]:
# Default values for hyperparameters, will need to change later.

timesteps_per_batch = 2  # timesteps per batch
#self.max_timesteps_per_episode = 1600      # timesteps per episode
gamma = 0.95
n_updates_per_iteration = 5
clip = 0.2     # As recommended by (Schulman, 2017)
lr = 0.005

In [20]:
def compute_rtgs_jax(batch_rews): 
    # The rewards-to-go (rtg) per episode per batch to return.
    # The shape will be (num timesteps per episode)
    
    batch_rtgs = jnp.array([])
    
    # Iterate through each episode backwards to maintain same order in batch_rtgs
    for ep_rews in reversed(batch_rews):
      discounted_reward = 0 # The discounted reward so far
      for rew in reversed(ep_rews):
        discounted_reward = rew + discounted_reward * gamma
        batch_rtgs = jnp.insert(batch_rtgs, 0, discounted_reward)
        
        #batch_rtgs.append(discounted_reward)
      #batch_rtgs = jnp.array(batch_rtgs[::-1])
      
    return batch_rtgs

In [21]:
env = jumanji.make("BinPack-toy-v0")


# Rollout function
batch_obs = jnp.array([])        # batch observations
batch_acts = jnp.array([])       # batch actions
batch_log_probs = jnp.array([])  # log probs of each action
batch_rews = jnp.array([])       # batch rewards
batch_rtgs = jnp.array([])       # batch rewards-to-go
batch_lens = jnp.array([])       # episodic lengths in batch

step_fn = jax.jit(env.step)
reset_fn = jax.jit(env.reset)
t = jnp.array(0, dtype=jnp.int32)

while t < timesteps_per_batch:
    # Rewards this episode
    ep_rews = jnp.array([])
    key = jax.random.PRNGKey(0)
    
    #jax.jit(env.reset)(key)
    state, timestep = jax.jit(env.reset)(key)
    
    ep_t = jnp.array(0, dtype=jnp.int32)
    rew = jnp.array(0.0, dtype=jnp.int32)

    while rew == 0.0:
        # Increment timesteps ran this batch so far
        t += 1
        
        obs = flatten_jax(timestep.observation)  # Collect observation
        batch_obs = jnp.append(batch_obs, obs)
    
        num_ems, num_items = env.action_spec().num_values
        action_mask = timestep.observation.action_mask.flatten()
        action_jnp = jnp.array(action_mask, dtype=jnp.float32)

        ems_item_id, action_,log_prob  = get_action_jax(obs,action_jnp)
        ems_id, item_id = jnp.divmod(ems_item_id, num_items)

        # Wrap the action as a jax array of shape (2,)
        action = jnp.array([ems_id, item_id])

        step_fn = jax.jit(env.step)
        reset_fn = jax.jit(env.reset)
        state, timestep = step_fn(state, action)
        
        rew = jnp.array(timestep.reward.flatten())[0]
        ep_rews = jnp.append(ep_rews, rew)
        batch_acts = jnp.append(batch_acts,action_)
        batch_log_probs = jnp.append(batch_log_probs, log_prob)
        
        ep_t += 1
    
    # Collect episodic length and rewards
    batch_lens = jnp.append(batch_lens, ep_t + 1) # plus 1 because timestep starts at 0
    batch_rews = jnp.append(batch_rews, ep_rews)

batch_rtgs = compute_rtgs_jax(batch_rews)

#return batch_obs, batch_acts, batch_log_probs, batch_rtgs, batch_lens, rew


  return asarray(x, dtype=self.dtype)
  return asarray(x, dtype=self.dtype)


TypeError: len() of unsized object

In [35]:
def evaluate(self, batch_obs, batch_acts):
    # Query critic network for a value V for each obs in batch_obs.
    V = self.critic(batch_obs).squeeze()
    #print("in evaluate , value function after critic ", V)
    
    # Calculate the log probabilities of batch actions using most recent actor network.
    # This segment of code is similar to that in get_action()
    batch_logits = self.actor(batch_obs)
    log_softmax_probs = jax.nn.log_softmax(batch_logits)

    log_probs = log_softmax_probs[batch_acts]
    # Return predicted values V and log probs log_probs
    return V, log_probs

In [34]:
from torch.distributions import Categorical

actor = SimpleClassifierCompact(num_hidden=64, num_outputs=800)   
mean = actor.init(init_rng, inp)

# Create our Multivariate Normal Distribution
dist = Categorical(mean)
# Sample an action from the distribution and get its log prob
action = dist.sample()
log_prob = dist.log_prob(action)

AttributeError: 'FrozenDict' object has no attribute 'dim'