In [1]:
import dataclasses
import functools

import jax
from jax import numpy as jnp
import numpy as np

from flax import linen
from flax import struct
import optax
from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

import gym
from matplotlib import pylab as plt

import daves_rl_lib
from daves_rl_lib import networks
from daves_rl_lib.algorithms import policy_gradient
from daves_rl_lib.environments import environment_lib
from daves_rl_lib.internal import video_util

video_util.initialize_virtual_display()

Initialized virtual display.


In [10]:
env_name = 'CartPole-v1'  # 'MountainCar-v0'
discount_factor = 0.99
learning_rate = 0.01

In [3]:
env = environment_lib.GymEnvironment(
    gym.make(env_name),
    discount_factor=discount_factor)

In [4]:

policy_net = networks.make_model(
            [24, 24, 2],
            obs_size=env.observation_size,
            activate_final=tfp.distributions.Categorical)

In [11]:
agent = policy_gradient.PolicyGradientAgent(
    policy_net=policy_net,
    policy_optimizer=optax.adam(learning_rate),
    discount_factor=discount_factor,
    max_num_steps=500,
    dummy_observation=env.reset(seed=jax.random.PRNGKey(0)).observation,
    dummy_action=env.action_space.dummy_action())

In [12]:
import time

seed = jax.random.PRNGKey(0)
seed, weights_init_seed, state_init_seed = jax.random.split(seed, 3)
weights = agent.init_weights(seed=weights_init_seed)

select_action = jax.jit(lambda obs, w, s: agent.action_dist(obs, w).sample(seed=s))
update = jax.jit(agent.update)

In [13]:
for episode in range(1000):

    state = env.reset(seed=state_init_seed)
    t0 = time.time()
    num_steps = 0
    while not state.done:
        seed, action_seed, env_seed = jax.random.split(seed, 3)
        action = select_action(state.observation, weights, action_seed)
        next_state = env.step(action)
        weights = update(weights,
                         transition=environment_lib.Transition(
                            observation=state.observation,
                            action=action,
                            next_observation=next_state.observation,
                            reward=next_state.reward,
                            done=next_state.done)
                        )
        state = next_state
        num_steps += 1
    print(f"Finished episode {episode} of length {num_steps} in {time.time() - t0 : .2f}s.")

Finished episode 0 of length 14 in  1.46s.
Finished episode 1 of length 12 in  0.04s.
Finished episode 2 of length 17 in  0.07s.
Finished episode 3 of length 13 in  0.05s.
Finished episode 4 of length 30 in  0.14s.
Finished episode 5 of length 13 in  0.05s.
Finished episode 6 of length 26 in  0.10s.
Finished episode 7 of length 18 in  0.07s.
Finished episode 8 of length 19 in  0.08s.
Finished episode 9 of length 12 in  0.05s.
Finished episode 10 of length 16 in  0.06s.
Finished episode 11 of length 21 in  0.10s.
Finished episode 12 of length 40 in  0.18s.
Finished episode 13 of length 15 in  0.07s.
Finished episode 14 of length 97 in  0.40s.
Finished episode 15 of length 22 in  0.09s.
Finished episode 16 of length 23 in  0.08s.
Finished episode 17 of length 18 in  0.08s.
Finished episode 18 of length 20 in  0.08s.
Finished episode 19 of length 28 in  0.11s.
Finished episode 20 of length 25 in  0.11s.
Finished episode 21 of length 18 in  0.07s.
Finished episode 22 of length 37 in  0.16s

KeyboardInterrupt: 

In [None]:
def surrogate_objective(policy_weights):
    action_dists = agent._policy_net.apply(policy_weights,
                                            observations[:-1])
    return jnp.sum(action_dists.log_prob(actions) * value_estimates)

policy_gradient = jax.grad(surrogate_objective)(weights.policy_weights)

In [None]:
policy_gradient

FrozenDict({
    params: {
        hidden_0: {
            bias: DeviceArray([-10.760525 ,   6.3141522,  -7.884516 ,  -6.9908977,
                         -12.335312 ,  -3.427275 ,   2.4799757, -23.585716 ,
                         -20.826658 , -13.423718 ,  -9.814324 ,   5.314212 ,
                          -7.981817 ,  -4.926464 , -21.78423  ,   9.9800205,
                          10.143917 , -12.011162 ,   1.0859412,  -1.1683055,
                          -7.2421722,  -2.0665948,  11.42239  ,   1.877053 ],            dtype=float32),
            kernel: DeviceArray([[  0.40485567,  -0.17495032,   0.25116804,   0.28446737,
                            0.4553626 ,   0.13152549,  -0.05846979,   0.86966723,
                            0.7634902 ,   0.4890773 ,   0.34746814,  -0.15581599,
                            0.30770284,   0.17460369,   0.7693363 ,  -0.3452842 ,
                           -0.26926816,   0.4176701 ,  -0.04409439,   0.03221723,
                            0.27160507,