## Imports

In [None]:
import math
import random
from collections import deque, namedtuple
from typing import Callable

import flax
import flax.linen as nn
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax.training import train_state

## Stochastic MDP

In [None]:
class StochasticMDP:
  """Six-state stochastic chain MDP used for the h-DQN experiment.

  States are numbered 1..num_states and the agent starts in state 2.
  Action 0 always moves one state left. Action 1 moves right with
  probability ``p_right`` and left otherwise; in the rightmost state it
  always moves left. Reaching state 1 ends the episode. The reward is 1.0
  if the rightmost state was visited during the episode, otherwise 1/100.
  Observations are one-hot vectors of length ``num_states`` with index
  ``current_state - 1`` set to 1.
  """

  def __init__(self):
    self.end = False          # becomes True once the rightmost state is visited
    self.current_state = 2    # 1-based position on the chain
    self.num_actions = 2
    self.num_states = 6
    self.p_right = 0.5

  def _observation(self):
    """Return the one-hot encoding of ``current_state`` as a JAX array."""
    # JAX arrays are immutable: ``.at[...].set`` returns a NEW array, so the
    # result must be kept (discarding it leaves an all-zero observation).
    return jnp.zeros(self.num_states).at[self.current_state - 1].set(1)

  def reset(self):
    """Restart the episode in state 2 and return its one-hot observation."""
    self.end = False
    self.current_state = 2
    return self._observation()

  def step(self, action):
    """Apply ``action`` (0 = left, 1 = try right).

    Returns:
      ``(observation, reward, done, info)``. ``done`` is True exactly when
      the agent is in state 1. ``info`` is always an empty dict.
    """
    if self.current_state != 1:
      if action == 1:
        if random.random() < self.p_right and self.current_state < self.num_states:
          self.current_state += 1
        else:
          self.current_state -= 1
      if action == 0:
        self.current_state -= 1
      if self.current_state == self.num_states:
        self.end = True

    state = lax.stop_gradient(self._observation())
    if self.current_state == 1:
      # Terminal: large reward only if the far end was visited first.
      reward = 1.00 if self.end else 1/100
      return state, reward, True, {}
    return state, 0.0, False, {}


## Replay Buffer

In [None]:
# Mini-batch of transitions. ReplayBuffer.sample fills each field with the
# sampled values concatenated along a leading batch axis.
Batch = namedtuple("Batch", "state action reward next_state done")

def to_onehot(x, num_classes=6):
  """Return a length-``num_classes`` one-hot JAX vector with index ``x`` set to 1.

  Args:
    x: 0-based index (Python int or integer JAX scalar). In this file it is a
      goal index from ``argmax``/``randrange``, both of which are 0-based.
      The encoding therefore lines up with the state one-hot, where goal
      ``g`` is reached when ``argmax(state) == g``.
    num_classes: length of the returned vector. Defaults to 6, the number of
      states in ``StochasticMDP``.
  """
  # ``.at[...].set`` is functional in JAX; the returned array must be used.
  return jnp.zeros(num_classes).at[x].set(1)

class ReplayBuffer:
  """Fixed-capacity FIFO store of transitions with uniform random sampling.

  Once ``capacity`` transitions are stored, each new push evicts the oldest.
  """

  def __init__(self, capacity):
    self.capacity = capacity
    self.buffer = deque(maxlen = capacity)

  def push(self, state, action, reward, next_state, done):
    """Store one transition.

    ``state`` and ``next_state`` get a leading axis of size 1, so that
    ``sample`` can concatenate them into a batch along axis 0.
    """
    state = jnp.expand_dims(state, 0)
    next_state = jnp.expand_dims(next_state, 0)
    self.buffer.append((state, action, reward, next_state, done))

  def expand(self, x):
    """Return the elements of ``x`` as JAX arrays with a leading axis of size 1.

    Elements go through ``jnp.asarray`` first, so plain Python or NumPy
    scalars work as well as JAX arrays. Indexing a Python ``int`` with
    ``jnp.newaxis`` would raise ``TypeError``.
    """
    return [jnp.asarray(i)[jnp.newaxis] for i in x]

  def sample(self, batch_size):
    """Draw ``batch_size`` distinct stored transitions uniformly at random.

    Returns:
      A ``Batch`` whose fields are the sampled values concatenated along axis 0.

    Raises:
      ValueError: from ``random.sample`` if ``batch_size`` exceeds ``len(self)``.
    """
    state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
    return Batch(
        state = jnp.concatenate(state),
        action = jnp.concatenate(self.expand(action)),
        reward = jnp.concatenate(self.expand(reward)),
        next_state = jnp.concatenate(next_state),
        done = jnp.concatenate(self.expand(done))
    )

  def __len__(self):
    return len(self.buffer)

## High and Low Level Policies

In [None]:
class Actor(nn.Module):
  """Two-layer MLP with one output per discrete choice.

  ``hDQN`` uses it for both the controller (``out_dim`` = number of actions)
  and the meta-controller (``out_dim`` = number of goals). Its outputs are
  trained as Q-values.

  Attributes:
    hidden_dim: width of the single hidden layer.
    out_dim: number of outputs.
    activation: nonlinearity applied after the hidden layer.
  """
  hidden_dim: int
  out_dim: int
  activation: Callable = nn.relu

  @nn.compact
  def __call__(self, x, train = True):
    """Return the outputs, or with ``train=False`` the argmax over the last axis."""
    x = nn.Dense(self.hidden_dim)(x)
    x = self.activation(x)
    x = nn.Dense(self.out_dim)(x)
    if not train:
      # Reduce over the output axis only. A 1-D input gives a scalar index;
      # a batched input gives one index per row (a flat argmax would mix rows).
      x = jnp.argmax(x, axis = -1)
    return x


## Create Train State

In [None]:
def create_train_state(rng, model, input_dim, learning_rate):
  """Initialise ``model`` and pair it with an Adam optimiser.

  Args:
    rng: PRNG key used for parameter initialisation.
    model: a Flax module, initialised on a ``(1, input_dim)`` array of ones.
    input_dim: size of the model's input feature vector.
    learning_rate: Adam step size.

  Returns:
    A ``TrainState`` holding ``model.apply``, the initial parameters and the
    optimiser.
  """
  variables = model.init(rng, jnp.ones((1, input_dim)))
  optimizer = optax.adam(learning_rate)
  return train_state.TrainState.create(
      apply_fn = model.apply,
      params = variables["params"],
      tx = optimizer,
  )

## hDQN Update Steps

In [None]:
@jax.jit
def get_action(train_state, state):
  """Return the controller's choice for ``state`` without gradient tracking.

  Applies ``train_state.apply_fn`` with ``train=False``; for ``Actor`` this is
  the index of the largest output. The result is wrapped in
  ``lax.stop_gradient``.
  """
  variables = {"params": train_state.params}
  chosen = train_state.apply_fn(variables, state, train = False)
  return lax.stop_gradient(chosen)

@jax.jit
def get_goal(train_state, state):
  """Return the meta-controller's goal for ``state`` without gradient tracking.

  The network in ``train_state`` is evaluated with ``train=False``; for
  ``Actor`` this is the index of the best-scoring goal. The result is wrapped
  in ``lax.stop_gradient``.
  """
  params_dict = {"params": train_state.params}
  picked = train_state.apply_fn(params_dict, state, train = False)
  return lax.stop_gradient(picked)

@jax.jit
def update_train_state(train_state, batch, gamma = 0.99):
  """Take one Q-learning gradient step on ``batch`` and return the new state.

  The TD target is ``reward + gamma * max_a' Q(next_state, a') * (1 - done)``.
  It uses the same (online) parameters and is wrapped in
  ``lax.stop_gradient``, so only the predicted Q-values receive gradients.
  The loss is the mean squared TD error.

  Args:
    train_state: TrainState-like pytree exposing ``params``, ``apply_fn`` and
      ``apply_gradients``.
    batch: a ``Batch`` of transitions stacked along axis 0. ``action`` holds
      integer indices into the model's output axis.
    gamma: discount factor (previously hard-coded as 0.99).

  Returns:
    ``train_state`` after applying the gradients.
  """
  state, action, reward, next_state, done = batch.state, batch.action, batch.reward, batch.next_state, batch.done

  def loss_fn(params):
    # Must use the ``params`` argument, not ``train_state.params``. Otherwise
    # the loss does not depend on the differentiated input, every gradient
    # is zero, and the network never learns.
    outs = train_state.apply_fn({"params": params}, state)
    q_vals = outs[jnp.arange(action.shape[0]), action]
    next_outs = train_state.apply_fn({"params": params}, next_state)
    next_q_vals = jnp.max(next_outs, axis = 1)
    not_done = 1.0 - done.astype(jnp.float32)
    exp_q_vals = lax.stop_gradient(reward + gamma * next_q_vals * not_done)
    return jnp.mean(jnp.square(q_vals - exp_q_vals))

  grads = jax.grad(loss_fn)(train_state.params)
  return train_state.apply_gradients(grads = grads)

## hDQN Learner

In [None]:
class hDQN:
  """Hierarchical DQN learner: a goal-conditioned controller plus a meta-controller.

  ``model`` (the controller) takes an input of size ``2 * num_goals`` and
  outputs ``num_actions`` values. In ``main`` that input is a state
  concatenated with a one-hot goal. ``meta_model`` takes an input of size
  ``num_goals`` (a state) and outputs ``num_goals`` values, one per goal.
  """
  def __init__(self, num_goals, num_actions, lr, hidden_dim, batch_size, seed = 42):
    """Build both networks and their Adam train states.

    Args:
      num_goals: number of selectable goals; also the assumed state size.
      num_actions: number of primitive actions.
      lr: learning rate shared by both optimisers.
      hidden_dim: hidden width of both networks.
      batch_size: mini-batch size used by the update methods.
      seed: integer seed for parameter initialisation. The default 42
        reproduces the previously hard-coded value.
    """
    self.num_goals = num_goals
    self.num_actions = num_actions
    self.lr = lr
    self.hidden_dim = hidden_dim
    self.batch_size = batch_size

    rng = jrandom.PRNGKey(seed)
    # Independent keys for the two networks; the remainder is kept on self.rng.
    self.rng, model_key, meta_key = jrandom.split(rng, 3)
    self.model = Actor(hidden_dim = self.hidden_dim, out_dim = self.num_actions)
    self.model_state = create_train_state(model_key, self.model, 2 * self.num_goals, self.lr)
    self.meta_model = Actor(hidden_dim = self.hidden_dim, out_dim = self.num_goals)
    self.meta_model_state = create_train_state(meta_key, self.meta_model, self.num_goals, self.lr)

  def update_learner(self, buffer):
    """Sample a mini-batch from ``buffer`` and take one step on the controller."""
    batch = buffer.sample(self.batch_size)
    self.model_state = update_train_state(self.model_state, batch)

  def update_meta_learner(self, meta_buffer):
    """Sample a mini-batch from ``meta_buffer`` and take one step on the meta-controller."""
    batch = meta_buffer.sample(self.batch_size)
    self.meta_model_state = update_train_state(self.meta_model_state, batch)


## Train

In [None]:
def main():
  """Train an h-DQN agent on ``StochasticMDP`` and plot its learning curve.

  The meta-controller picks a goal state. The controller then acts until the
  goal is reached or the episode ends, learning from an intrinsic reward of
  1.0 when it reaches the goal. The meta-controller learns from the extrinsic
  reward accumulated while each goal was pursued. Both use epsilon-greedy
  exploration. Epsilon decays exponentially from ``frame_start`` towards
  ``frame_end`` with time constant ``frame_decay`` steps.
  """
  hidden_dim = 256
  batch_size = 32
  learning_rate = 1e-3
  eval_interval = 500
  frame_start = 1.0
  frame_end = 0.01
  frame_decay = 500
  steps = 20000

  env = StochasticMDP()
  num_goals = env.num_states
  num_actions = env.num_actions
  replay_buffer = ReplayBuffer(10000)
  meta_replay_buffer = ReplayBuffer(10000)

  state = env.reset()
  done = False
  all_rewards = []
  episode_reward = 0
  frame_idx = 1
  epsilon_by_frame = lambda frame_idx: frame_end + (frame_start - frame_end) * math.exp(-1 * frame_idx / frame_decay)
  hdqn = hDQN(num_goals, num_actions, learning_rate, hidden_dim, batch_size)

  while frame_idx < steps:
    eps = epsilon_by_frame(frame_idx)
    # Meta-controller: pick a goal state epsilon-greedily.
    goal = get_goal(hdqn.meta_model_state, state)
    if random.random() < eps:
      goal = jnp.array(random.randrange(num_goals))
    onehot_goal = to_onehot(goal)
    meta_state = state
    extrinsic_reward = 0
    # Controller: act until the goal state is reached or the episode ends.
    while not done and goal != np.argmax(state):
      goal_state = np.concatenate([state, onehot_goal])
      action = get_action(hdqn.model_state, goal_state)
      if random.random() < eps:
        action = jnp.array(random.randrange(num_actions))
      next_state, reward, done, _ = env.step(action)
      episode_reward += reward
      extrinsic_reward += reward
      # Intrinsic reward: 1.0 when the controller lands on the chosen goal.
      # Both branches are floats so every stored reward has the same dtype.
      intrinsic_reward = jnp.array(1.0 if goal == np.argmax(next_state) else 0.0)

      replay_buffer.push(
          goal_state, action, intrinsic_reward, np.concatenate([next_state, onehot_goal]), jnp.array(done)
      )
      state = next_state
      if len(replay_buffer) >= batch_size:
        hdqn.update_learner(replay_buffer)
      if len(meta_replay_buffer) >= batch_size:
        hdqn.update_meta_learner(meta_replay_buffer)
      frame_idx += 1
      if frame_idx % eval_interval == 0:
        # No episode may have finished yet, so guard the lookup.
        last_return = all_rewards[-1] if all_rewards else float("nan")
        print(f"Step {frame_idx} | Return {last_return}")
    meta_replay_buffer.push(meta_state, goal, jnp.array(extrinsic_reward), state, jnp.array(done))
    if done:
      state = env.reset()
      done = False
      all_rewards.append(episode_reward)
      episode_reward = 0

  # Average the returns over consecutive chunks of n episodes.
  n = 100
  ret_mean = [np.mean(all_rewards[i:i + n]) for i in range(0, len(all_rewards), n)]
  plt.figure(figsize = (10, 5))
  plt.title("StochasticMDP")
  plt.plot(ret_mean)
  plt.ylabel("Average Return")
  plt.xlabel(f"Episodes (x {n})")
  plt.show()

if __name__ == "__main__":
  # Run training and plotting only when executed as a script, not on import.
  main()