In [1]:
import jax
import jax.lax as lax
from jax.tree_util import tree_multimap
import jax.numpy as np
import jax.random as random
import seaborn as sns
from typing import NamedTuple

In [2]:
# Windy gridworld layout: rows x columns, with the wind in each column pushing
# the agent toward row 0 by `upwind[column]` rows (see wind_at_position).
grid_shape = (7, 10)
upwind = np.array([0, 0, 0, 1, 1, 1, 2, 2, 1, 0])
start = np.array([3, 0])
goal = np.array([3, 7])

# Largest valid (row, col) index, derived from grid_shape so the two cannot drift.
max_position = np.array(grid_shape) - 1
actions = np.array([
  [-1, 0], # up
  [1, 0], # down
  [0, -1], # left
  [0, 1], # right
])
action_count = len(actions)
# Trailing Q-table axis; derived from `actions` instead of a hard-coded 4 so
# adding or removing an action keeps the Q table shape consistent.
action_shape = (action_count,)

# Fail fast if the layout is edited inconsistently.
assert upwind.shape == (grid_shape[1],), "upwind needs one entry per column"
assert bool(np.all((start >= 0) & (start <= max_position))), "start is off the grid"
assert bool(np.all((goal >= 0) & (goal <= max_position))), "goal is off the grid"

def wind_at_position(position):
  """Return the (row, col) displacement the wind applies at `position`.

  The wind strength depends only on the column, position[1], and it moves the
  agent toward row 0, so the displacement is (-upwind[col], 0).
  """
  column = position[1]
  strength = upwind[column]
  return np.array([-strength, 0])

def move(position, action):
  """Apply the wind and then `action` to `position`, clamped to the grid.

  The wind is looked up at the starting position, before the action is
  applied. Each coordinate of the result is clipped to [0, max_position].

  Args:
    position: current (row, col) position.
    action: integer index into `actions`.

  Returns:
    The new (row, col) position as a new array; `position` is not modified.
  """
  # Build a fresh array instead of `position += ...`: if a mutable NumPy array
  # is passed in, augmented assignment would overwrite the caller's value.
  displaced = position + wind_at_position(position) + actions[action]
  return np.clip(displaced, 0, max_position)

def is_terminal_state(S):
  """Return a boolean (array) that is true only when S is exactly the goal cell."""
  row, col = S[0], S[1]
  goal_row, goal_col = goal[0], goal[1]
  return (row == goal_row) & (col == goal_col)
  
def policy(rng, Q, S, epsilon):
  """Sample an epsilon-greedy action for state S.

  Let n be the size of Q's last (action) axis. Every action gets probability
  epsilon / n. The greedy action, argmax of Q at S with the lowest index
  winning ties, gets an extra 1 - epsilon.

  Args:
    rng: PRNG key used for the draw.
    Q: action-value table; Q[tuple(S)] must be the vector of action values.
    S: state used to index Q, e.g. a (row, col) position.
    epsilon: exploration rate; the probabilities are only valid for
      0 <= epsilon <= 1.

  Returns:
    A scalar integer action index in [0, n).
  """
  # Use Q's own action axis rather than a global count so Q and the
  # distribution can never disagree in length.
  n_actions = Q.shape[-1]
  greedy = np.argmax(Q[tuple(S)])
  p = np.full((n_actions,), epsilon / n_actions)
  # `.at[...].add` replaces jax.ops.index_add, which newer JAX releases removed.
  p = p.at[greedy].add(1 - epsilon)
  return random.choice(rng, np.arange(n_actions), p=p)
  
@jax.jit
def run_episode(rng, Q, epsilon, alpha, gamma=1.0):
  """Run one SARSA episode from `start` until the goal, updating Q online.

  Args:
    rng: PRNG key driving the epsilon-greedy action choices.
    Q: action-value table indexed as Q[row, col, action].
    epsilon: exploration rate passed to `policy`.
    alpha: step size of the SARSA update.
    gamma: discount factor. The default of 1.0 reproduces the original
      undiscounted update.

  Returns:
    (Q, i): the updated table and the number of moves taken in the episode.
  """

  def cond(state):
    _, S, _, _, _ = state
    return ~is_terminal_state(S)

  def body(state):
    Q, S, A, i, rng = state
    Sp = move(S, A)
    done = is_terminal_state(Sp)
    # Reward is -1 per move and 0 on the move that reaches the goal.
    R = np.where(done, 0, -1)
    rng, r = random.split(rng)
    Ap = policy(r, Q, Sp, epsilon)
    SA = tuple(S) + (A,)
    SAp = tuple(Sp) + (Ap,)
    # A terminal state has no future return, so do not bootstrap from it.
    # This no longer relies on Q happening to be zero at the goal.
    q_next = np.where(done, 0.0, Q[SAp])
    td_error = R + gamma * q_next - Q[SA]
    # `.at[...].add` replaces jax.ops.index_add, which newer JAX releases removed.
    Q = Q.at[SA].add(alpha * td_error)
    return Q, Sp, Ap, i + 1, rng

  S = start
  # Same split sequence as before, so seeded runs reproduce earlier results.
  rng, r = random.split(rng)
  A = policy(r, Q, S, epsilon)
  rng, r = random.split(rng)
  initial_state = (Q, S, A, 0, r)
  Q, S, A, i, rng = lax.while_loop(cond, body, initial_state)
  return Q, i



In [3]:
# Training configuration.
episodes = 1000           # number of episodes run by lax.scan below
epsilon = 0.1             # exploration rate of the epsilon-greedy policy
alpha = 0.5               # SARSA step size
rng = random.PRNGKey(42)  # fixed seed so the recorded outputs reproduce
Q = np.zeros((*grid_shape, *action_shape))  # Q[row, col, action], all zeros

def scan(state, _):
  """Single lax.scan step: train on one episode and record the result.

  The carry is (Q, rng). The per-step output is (Q after the episode, number
  of moves in the episode). The second argument is the scanned element, which
  is unused because the scan runs with xs=None.
  """
  Q_before, key = state
  key, episode_key = random.split(key)
  Q_after, episode_len = run_episode(episode_key, Q_before, epsilon=epsilon, alpha=alpha)
  carry = (Q_after, key)
  return carry, (Q_after, episode_len)

# Split off a fresh key for training, then run all episodes in one lax.scan.
# Final carry -> (Q, rng). Stacked per-episode outputs -> Qs (the Q table after
# each episode) and steps (the number of moves taken in each episode).
rng, r = random.split(rng)
(Q, rng), (Qs, steps) = lax.scan(scan, (Q, r), xs=None, length=episodes)

In [4]:
# Greedy action per grid cell: 0=up, 1=down, 2=left, 3=right.
Q.argmax(axis=-1)

DeviceArray([[3, 0, 3, 3, 3, 3, 3, 3, 3, 1],
             [2, 3, 0, 3, 0, 3, 3, 0, 0, 1],
             [3, 3, 3, 0, 3, 3, 3, 3, 3, 1],
             [3, 3, 3, 3, 3, 3, 1, 0, 3, 1],
             [3, 3, 3, 3, 3, 3, 0, 1, 2, 2],
             [1, 3, 3, 3, 3, 0, 0, 1, 2, 0],
             [3, 3, 3, 3, 0, 0, 0, 0, 0, 3]], dtype=int32)

In [5]:
# Total moves taken across all training episodes.
steps.sum()

DeviceArray(24551, dtype=int32)

In [6]:
# Length of the shortest episode seen during training.
steps.min()

DeviceArray(15, dtype=int32)

In [7]:
# Mean length of the last 10 episodes. Actions are still epsilon-greedy,
# so exploratory moves are included.
steps[-10:].mean()

DeviceArray(17.8, dtype=float32)