In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from cliff_jax import Cliff

In [3]:
# Construct the Cliff environment through the `default()` factory from cliff_jax.
env = Cliff.default()

In [4]:
# Start an episode: `reset()` returns the env's internal state and `observe`
# extracts what the agent sees from it (a scalar int array in the output below).
internal_state = env.reset()
env.observe(internal_state)

Array(16, dtype=int32)

In [5]:
# Take one transition with action 1; returns (next_state, reward, done).
env.step(internal_state, 1)

(InternalState(state=Array(16, dtype=int32), step=Array(1, dtype=int32, weak_type=True)),
 Array(-1, dtype=int32, weak_type=True),
 Array(False, dtype=bool))

In [129]:
import typing
import jax, jax.numpy as jnp, jax.random as jr
# Root PRNG key for the notebook; later cells consume it via `key, sub = jr.split(key)`.
key = jr.key(0)

class Embedding(typing.NamedTuple):
  """Lookup-table model: row ``x`` of ``params`` holds the logits for input ``x``.

  Fields:
    num_emb: number of table rows (distinct inputs).
    emb_dim: length of each row (number of logits per input).
  """
  num_emb: int
  emb_dim: int

  def __call__(self, params, _key, x):
    """Return the table row(s) selected by ``x``; the PRNG key is ignored."""
    row = params[x]
    return row

  def logp(self, params, x, y):
    """Log-softmax probability of output ``y`` given input ``x``."""
    log_probs = jax.nn.log_softmax(self(params, None, x))
    return log_probs[y]

  def init(self, key):
    """Draw a ``(num_emb, emb_dim)`` table from a standard normal."""
    shape = (self.num_emb, self.emb_dim)
    return jr.normal(key, shape)

In [130]:
class EpsGreedy(typing.NamedTuple):
  """Epsilon-greedy policy wrapped around an action-scoring model.

  Fields:
    model: object providing ``__call__(params, key, *args) -> scores``,
      ``logp(params, x, y)`` and ``init(key)`` (e.g. ``Embedding``).
    num_actions: size of the discrete action space.
    eps: probability of taking a uniformly random action.
  """
  model: typing.Any
  num_actions: int
  eps: float = .1

  def __call__(self, params, key, *args):
    """Sample an action: argmax of the model's scores w.p. ``1 - eps``, else uniform."""
    key, eps_key = jr.split(key)
    # use cond to run the model only when it's necessary
    return jax.lax.cond(
      jr.uniform(eps_key) >= self.eps,
      # take model's prediction (greedy)
      lambda: self.model(params, key, *args).argmax(axis=-1),
      # select a random action
      lambda: jr.choice(key, self.num_actions),
    )

  def logp(self, params, x, y):
    """Log of ``eps / num_actions + (1 - eps) * exp(model.logp(params, x, y))``.

    NOTE: the mixture uses ``model.logp``, whereas ``__call__`` acts on the argmax
    of the model's output; this is a differentiable surrogate rather than the exact
    probability of the sampled action.

    Evaluated in log space with ``logaddexp`` so a very negative model log-prob
    does not underflow to ``log(0) = -inf`` (and a NaN gradient), e.g. when eps == 0.
    """
    uniform_term = jnp.log(self.eps / self.num_actions)  # -inf when eps == 0, handled by logaddexp
    model_term = jnp.log1p(-self.eps) + self.model.logp(params, x, y)
    return jnp.logaddexp(uniform_term, model_term)

  def init(self, key):
    """Initialise parameters by delegating to the wrapped model."""
    return self.model.init(key)

In [131]:
# Tabular policy: one row of 4 action scores per observation, wrapped in eps-greedy.
# NOTE(review): 24 is presumably the number of distinct observations of the default
# Cliff env and 4 its number of actions — confirm against cliff_jax.
action_tab = Embedding(num_emb=24, emb_dim=4)
eps_greedy = EpsGreedy(model=action_tab, num_actions=4)
key, model_key = jr.split(key)
params = eps_greedy.init(model_key)

In [132]:
# Smoke-test the policy: roll out at most 3 steps from a fresh episode.
state = env.reset()
for _ in range(3):
  obs = env.observe(state) # extract the observation from the state
  key, action_key = jr.split(key) # fresh subkey per action; `key` carries over to later cells
  action = eps_greedy(params, action_key, obs)
  state, reward, done = env.step(state, action)
  if done:
    break

In [133]:
class Buf(typing.NamedTuple):
  """Flat, fixed-capacity episode buffer.

  Transitions are stored back to back in 1-D arrays; ``State.ep_ends`` marks the
  (exclusive) end offset of each completed episode, so episode ``i`` covers
  ``[ep_starts[i], ep_ends[i])``.
  """
  max_num_eps: int # capacity of State.ep_ends (max number of completed episodes)
  max_episode_len: int # free transition slots required before a new episode may start

  class State(typing.NamedTuple):
    offset: int # next write position == number of stored transitions
    num_eps: int # number of episodes that is contained in this buffer
    ep_ends: jax.Array # end markers for episodes in the buffer (exclusive)
    observations: typing.Any
    actions: typing.Any
    rewards: typing.Any

    @property
    def ep_starts(self) -> jax.Array:
      """Start offset (inclusive) of each episode: 0 for the first, the previous end otherwise."""
      starts = jnp.zeros_like(self.ep_ends)
      return starts.at[1:].set(self.ep_ends[:-1])

    @property
    def buf_size(self) -> int:
      """Transition capacity (length of the storage arrays)."""
      return len(self.observations)

    def reset(self) -> "Buf.State":
      """Logically empty the buffer while reusing the allocated arrays.

      Array contents (including ``ep_ends``) are not cleared; they are
      overwritten by later appends / episode ends.
      """
      return self._replace(offset=0, num_eps=0)

  def can_append_episode(self, state: "State") -> bool:
    """Return whether one more episode of up to ``max_episode_len`` steps is guaranteed to fit.

    Requires a free slot in ``ep_ends`` and at least ``max_episode_len`` free
    transition slots. Uses Python ``and``, so it needs concrete (non-traced) values.
    """
    free = state.buf_size - state.offset
    return (state.num_eps < self.max_num_eps) and (free >= self.max_episode_len)

  def append(self, state: "State", obs, action, reward) -> "State":
    """Store one transition at ``state.offset`` and advance the offset by one.

    If the buffer is already full (``offset == buf_size``) the transition is
    dropped and the offset stays put.
    """
    # The slot being written must be in range. Testing `offset + 1 < buf_size`
    # instead would refuse the last slot, so a max-length episode admitted by
    # can_append_episode could lose its final step.
    in_bounds = state.offset < state.buf_size

    def write(arr, value):
      return jnp.where(in_bounds, arr.at[state.offset].set(value), arr)

    return Buf.State(
      offset=jnp.where(in_bounds, state.offset + 1, state.offset),
      num_eps=state.num_eps,
      ep_ends=state.ep_ends,
      observations=write(state.observations, obs),
      actions=write(state.actions, action),
      rewards=write(state.rewards, reward),
    )

  def end_episode(self, state: "State") -> "State":
    """Close the current episode: record ``offset`` as its end and bump ``num_eps``."""
    # Safety: calling can_append_episode before appending/ending an episode ensures that there is enough space
    # (if num_eps were already max_num_eps, JAX's scatter would drop the out-of-range write).
    return Buf.State(
      offset=state.offset,
      num_eps=state.num_eps + 1,
      ep_ends=state.ep_ends.at[state.num_eps].set(state.offset),
      observations=state.observations,
      actions=state.actions,
      rewards=state.rewards,
    )

  def empty(self, buf_size: typing.Optional[int] = None) -> "State":
    """Allocate a zeroed state; ``buf_size`` defaults to ``max_num_eps * max_episode_len``."""
    if buf_size is None:
      buf_size = self.max_num_eps * self.max_episode_len
    return Buf.State(
      offset=0,
      num_eps=0,
      ep_ends=jnp.zeros(self.max_num_eps, dtype=int),
      observations=jnp.zeros(buf_size, dtype=int),
      actions=jnp.zeros(buf_size, dtype=int),
      rewards=jnp.zeros(buf_size),
    )

In [134]:
# Up to 20 episode markers but only 30 transition slots (smaller than the default
# capacity max_num_eps * max_episode_len whenever env.max_steps > 1), so collection
# stops as soon as fewer than env.max_steps slots remain.
buf = Buf(max_num_eps=20, max_episode_len=env.max_steps)
buf_state = buf.empty(buf_size=30)

In [137]:
# Fill the buffer with complete episodes until another max-length episode might
# not fit. NB: reset() keeps the arrays, so ep_ends entries past num_eps can hold
# stale values from an earlier run (visible in the output below).
buf_state = buf_state.reset()
while buf.can_append_episode(buf_state):
  state = env.reset()
  while True: # NB: the environment is responsible for terminating after a couple of steps
    obs = env.observe(state) # extract the observation from the state
    key, action_key = jr.split(key)
    action = eps_greedy(params, action_key, obs)
    state, reward, done = env.step(state, action)
    buf_state = buf.append(buf_state, obs, action, reward) # store (s_t, a_t, r_t)
    if done:
      break
  buf_state = buf.end_episode(buf_state) # record where this episode ends
buf_state

State(offset=Array(11, dtype=int32, weak_type=True), num_eps=9, ep_ends=Array([ 2,  3,  4,  5,  7,  8,  9, 10, 11, 11, 11,  0,  0,  0,  0,  0,  0,
        0,  0,  0], dtype=int32), observations=Array([16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0], dtype=int32), actions=Array([3, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0], dtype=int32), rewards=Array([ -1., -50., -50., -50., -50.,  -1., -50., -50., -50., -50., -50.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.], dtype=float32))

In [139]:
# Sanity check: end offset of the first episode vs. the first stored rewards.
buf_state.ep_ends[0], buf_state.rewards[:10]

(Array(2, dtype=int32),
 Array([ -1., -50., -50., -50., -50.,  -1., -50., -50., -50., -50.],      dtype=float32))

In [152]:
# compute cumulative reward per episode
# Eager, step-by-step reference version; the reduce_episodes/accumulate_rewards
# cells below produce the same per-episode totals with lax.fori_loop.
eprew = jnp.zeros(buf.max_num_eps)
epidx, offset = 0, 0
while epidx < buf_state.num_eps:
  # accumulate rewards
  reward = buf_state.rewards[offset]
  eprew = eprew.at[epidx].set(eprew[epidx] + reward) # TODO: introduce a discount factor

  # update loop variables
  offset += 1
  # move to the next episode once offset reaches the current episode's (exclusive) end
  epidx = jnp.where(offset >= buf_state.ep_ends[epidx], epidx + 1, epidx)
eprew

Array([-51., -50., -50., -50., -51., -50., -50., -50., -50.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],      dtype=float32)

In [147]:
def accumulate_rewards(buf_state, ep_idx, offset, ep_rew):
  """Add the reward stored at ``offset`` to the running total of episode ``ep_idx``.

  Intended as the ``fn`` argument of ``reduce_episodes``; returns a new array of
  per-episode totals.
  """
  running_total = ep_rew[ep_idx]
  step_reward = buf_state.rewards[offset]
  updated_total = running_total + step_reward
  return ep_rew.at[ep_idx].set(updated_total)

In [173]:
def reduce_episodes(fn, carry_init, buf_state: "Buf.State"):
  """Fold ``fn`` over every stored transition while tracking its episode index.

  Iterates offsets ``0 .. buf_state.buf_size - 1`` with ``jax.lax.fori_loop``
  (the bounds are concrete Python ints, which keeps the loop differentiable with
  ``jax.grad``). For each offset below ``buf_state.offset`` it calls
  ``fn(buf_state, epidx, offset, carry) -> carry``; later offsets leave the carry
  untouched. The episode index advances when the next offset reaches
  ``buf_state.ep_ends[epidx]``.

  Args:
    fn: per-transition update function.
    carry_init: initial accumulator; any pytree of arrays (a single array, or e.g.
      a tuple of arrays to compute several reductions in one pass).
    buf_state: buffer contents (a ``Buf.State``).

  Returns:
    The final carry, with the same structure as ``carry_init``.

  NOTE: ``epidx`` advances at most once per offset, so zero-length episodes
  (repeated equal ``ep_ends``) are not handled.
  """
  class LoopState(typing.NamedTuple):
    carry: typing.Any
    epidx: jax.Array # int

  def body(offset: int, state: LoopState):
    not_done = offset < buf_state.offset # buf's current level is not reached yet
    # advance the episode index when the *next* offset leaves the current episode
    next_offset_overflow = (offset + 1) >= buf_state.ep_ends[state.epidx]
    updated = fn(buf_state, state.epidx, offset, state.carry)
    # select leaf-wise so pytree carries (tuples, dicts, ...) work as well as arrays
    carry = jax.tree_util.tree_map(
      lambda new, old: jnp.where(not_done, new, old), updated, state.carry
    )
    return LoopState(
      carry=carry,
      epidx=jnp.where(not_done & next_offset_overflow, state.epidx + 1, state.epidx),
    )

  state = LoopState(carry_init, jnp.asarray(0))
  state = jax.lax.fori_loop(0, buf_state.buf_size, body, state)
  return state.carry
# Per-episode returns via the fori_loop fold; matches the eager loop's result above.
weights = reduce_episodes(accumulate_rewards, jnp.zeros(buf.max_num_eps), buf_state)
weights

Array([-51., -50., -50., -50., -51., -50., -50., -50., -50.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],      dtype=float32)

In [174]:
def accumulate_logps(buf_state: Buf.State, ep_idx, offset, logps):
  """Add log pi(a_t | s_t) of the transition at ``offset`` to episode ``ep_idx``'s total.

  Reads the notebook-level ``eps_greedy`` policy and ``params`` from the enclosing
  scope (``params`` is not an argument here).
  """
  step_logp = eps_greedy.logp(
    params,
    buf_state.observations[offset],
    buf_state.actions[offset],
  )
  updated_total = logps[ep_idx] + step_logp
  return logps.at[ep_idx].set(updated_total)

In [175]:
# Per-episode sum of log-probs of the taken actions under the current params.
logps = reduce_episodes(accumulate_logps, jnp.zeros(buf.max_num_eps), buf_state)
logps

Array([-2.094716  , -0.87386847, -0.87386847, -0.87386847, -2.5958457 ,
       -0.87386847, -0.87386847, -0.87386847, -0.87386847,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ],      dtype=float32)

In [176]:
def expected_reward(params, buf_state: Buf.State) -> jax.Array:
  """REINFORCE surrogate objective over the episodes stored in ``buf_state``.

  Computes ``(1/N) * sum_i R_i * sum_t log pi(a_t | s_t)``, where ``R_i`` is the
  undiscounted return of episode ``i`` and ``N = buf_state.num_eps``. The rewards
  come from the buffer and do not depend on ``params``, so ``jax.grad`` of this
  value w.r.t. ``params`` is the score-function (REINFORCE) policy-gradient
  estimate; the value itself is not the expected reward.

  Args:
    params: policy parameters passed to the notebook-level ``eps_greedy.logp``.
    buf_state: buffer holding the collected episodes.

  Returns:
    Scalar surrogate; 0 when the buffer holds no episodes (instead of 0/0 = nan).
  """
  # One accumulator slot per possible episode, derived from the state itself so
  # the function does not depend on the global `buf`.
  num_slots = buf_state.ep_ends.shape[0]

  # compute sum of rewards for each episode
  def accumulate_rewards(buf_state, ep_idx, offset, ep_rew):
    reward = buf_state.rewards[offset]
    return ep_rew.at[ep_idx].set(ep_rew[ep_idx] + reward)

  # compute sum of logps for each episode; closes over the differentiated `params`
  def accumulate_logps(buf_state: Buf.State, ep_idx, offset, logps):
    obs, act = buf_state.observations[offset], buf_state.actions[offset]
    logp = eps_greedy.logp(params, obs, act) # log pi(a_t | s_t)
    return logps.at[ep_idx].set(logps[ep_idx] + logp)

  weights = reduce_episodes(accumulate_rewards, jnp.zeros(num_slots), buf_state)
  logps = reduce_episodes(accumulate_logps, jnp.zeros(num_slots), buf_state)
  # weight logps by returns and average over episodes; guard against num_eps == 0
  return jnp.sum(weights * logps) / jnp.maximum(buf_state.num_eps, 1)

In [181]:
# Policy gradient w.r.t. the table: Embedding only reads params[x], so rows for
# observations never stored in the buffer receive zero gradient.
jax.grad(expected_reward)(params, buf_state)

Array([[  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  0.      ,   0.      ,   0.      ,   0.      ],
       [  5.343898,   4.90797 , -22.246414,  11.994543],
       [  0.      ,   0.      ,