Standard imports:

In [1]:
import gym
import jax
import jax.numpy as jnp
import optax
import coax
import haiku as hk

Define an environment:

In [2]:
# build the FrozenLake variant registered under the non-slippery id
env = gym.make('FrozenLakeNonSlippery-v0')

# show both spaces, one per line: observations first (a 4 x 4 grid),
# then actions (up, down, left, right)
print(env.observation_space, env.action_space, sep='\n')

Discrete(16)
Discrete(4)


To roll out an episode with randomly sampled actions:

In [3]:
# start a fresh episode; `s` holds the current state throughout
s = env.reset()
print("Initial state:")
env.render()

print("\nUnrolling an episode:")
step_limit = env.spec.max_episode_steps  # step cap taken from the env spec
for t in range(step_limit):
    # draw an action from the action space (no policy involved yet)
    a = env.action_space.sample()

    # apply it and draw the resulting grid
    s_next, r, done, info = env.step(a)
    env.render()

    # the env signalled the end of the episode: stop stepping
    if done:
        break

    # otherwise the next state becomes the current one
    s = s_next

Initial state:

[41mS[0mFFF
FHFH
FFFH
HFFG

Unrolling an episode:
  (Down)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Down)
SFFF
FHFH
[41mF[0mFFH
HFFG
  (Up)
SFFF
[41mF[0mHFH
FFFH
HFFG
  (Right)
SFFF
F[41mH[0mFH
FFFH
HFFG


Define a Q-function to select an action:

In [4]:
def forward_pass(S, is_training):
    """Linear Q-function body: one output per discrete action.

    Args:
        S: input fed straight into the linear layer (presumably the batch of
            preprocessed observations that coax passes in -- TODO confirm).
        is_training: accepted as part of the signature but unused here.

    Returns:
        The output of a single ``hk.Linear`` layer with
        ``env.action_space.n`` units whose weights are initialized to zero.
    """
    n_actions = env.action_space.n
    return hk.Linear(n_actions, w_init=jnp.zeros)(S)


# wrap the forward pass as a coax Q-function tied to this environment
q = coax.Q(forward_pass, env)



Define a value-based policy:

In [5]:
# temperature of the Boltzmann policy built on top of q
# (presumably lower values make action selection greedier -- confirm in coax docs)
TEMPERATURE = 0.1
pi = coax.BoltzmannPolicy(q, temperature=TEMPERATURE)

Define how to update the Q-function (the policy is derived from it):

In [6]:
# reward tracer: 1-step transitions with discount factor 0.9
nstep = coax.reward_tracing.NStep(n=1, gamma=0.9)

# Q-learning updater bound to the q-function defined above
qlearning = coax.td_learning.QLearning(q)

Train for 500 episodes:

In [7]:
%%time
# wrapper to get some training logs
# need to restart kernel?
env = coax.wrappers.TrainMonitor(env)

for _ in range(500):
    s = env.reset()
    
    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, info = env.step(a)
        
        # update the q-function
        nstep.add(s, a, r, done)
        while nstep:
            transition = nstep.pop()
            qlearning.update(transition)
        
        if done:
            break
        
        s = s_next

CPU times: user 2min 4s, sys: 1min 23s, total: 3min 27s
Wall time: 2min 4s


Render an episode with the trained policy:

In [8]:
# act with the policy's mode instead of sampling from it
mode_policy = pi.mode
coax.render_episode(env, policy=mode_policy)


[41mS[0mFFF
FHFH
FFFH
HFFG
  (Right)
S[41mF[0mFF
FHFH
FFFH
HFFG
  (Right)
SF[41mF[0mF
FHFH
FFFH
HFFG
  (Down)
SFFF
FH[41mF[0mH
FFFH
HFFG
  (Down)
SFFF
FHFH
FF[41mF[0mH
HFFG
  (Down)
SFFF
FHFH
FFFH
HF[41mF[0mG
  (Right)
SFFF
FHFH
FFFH
HFF[41mG[0m
