This is my first attempt at solving an OpenAI Gym environment. I use vanilla REINFORCE (Monte Carlo policy gradient) to learn a policy for the CartPole-v1 environment.

In [None]:
!pip install jax jaxlib

In [None]:
!pip install gym

In [None]:
import random

# Seed Python's `random` module. Actions are sampled with `random.choices`
# (see `select_action`), so this makes the action-sampling stream repeatable.
# Note that the Gym environment itself is not seeded here.
random.seed(0)

In [None]:
# We need 4 inputs and 2 softmax outputs for our policy network.

import jax.numpy as jnp
import jax.tree_util as tree_util
from jax import grad, jit, vmap
from jax import random as jrandom

# `stax` moved from `jax.experimental` to `jax.example_libraries` in JAX 0.2.25
# and the old location was later removed, so try the new path first.
try:
    from jax.example_libraries import stax
    from jax.example_libraries.stax import Dense, Relu, Softmax
except ImportError:  # older JAX releases
    from jax.experimental import stax
    from jax.experimental.stax import Dense, Relu, Softmax

# Policy network: observation -> Dense(256) -> ReLU -> Dense(2) -> Softmax,
# i.e. one probability per action.
# `policy_init_fun(rng, input_shape)` returns (output_shape, params);
# `policy_net(params, inputs)` applies the network to a single observation or
# to a batch stacked along the leading axis.
policy_init_fun, policy_net = stax.serial(Dense(256), Relu,
                                          Dense(2), Softmax)

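Update rule for REINFORCE (undiscounted, $\gamma = 1$), where $G_t = \sum_{k=t+1}^{T} R_k$ is the return following step $t$:

$\theta \leftarrow \theta + \alpha G_t \nabla_\theta \ln \pi(A_t \mid S_t, \theta)$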

In [None]:
def select_action(policy_params, state):
    """
    Sample an action from the policy distribution at ``state``.

    ``policy_net`` maps an observation to a probability vector over actions
    (its final layer is a Softmax). The returned action is an integer index
    into that vector, drawn with Python's ``random`` module.
    """
    policy_dist = policy_net(policy_params, jnp.array(state))
    # Transfer the probabilities to plain Python floats in one go. Otherwise
    # random.choices would accumulate weights element by element on a JAX
    # array, dispatching one device op per element.
    probabilities = policy_dist.tolist()
    # Sample over every network output instead of a hard-coded [0, 1], so the
    # helper also works for policies with more than two discrete actions.
    return random.choices(range(len(probabilities)), weights=probabilities)[0]

In [None]:
import numpy as np

# Debugging helper: collapses a parameter structure to a single number so that
# successive parameter updates can be compared at a glance.
def sum_params(params) -> float:
    """Return the sum of every scalar entry in a (nested) parameter structure.

    ``params`` may be any JAX pytree, such as the list of per-layer tuples
    returned by ``policy_init_fun``. Its leaves may be JAX arrays, NumPy arrays
    of any float dtype, or scalars. Empty containers (e.g. the ``()`` params of
    a ``Relu`` layer) contribute 0.
    """
    # tree_leaves flattens nested lists/tuples/dicts. Each leaf is then summed
    # as one array instead of being recursed into element by element.
    # The old recursion crashed on 0-d JAX/float64 arrays.
    return float(sum(np.asarray(leaf, dtype=np.float64).sum()
                     for leaf in tree_util.tree_leaves(params)))

In [None]:
# Adam learning rate for the policy parameters.
step_size = 1e-3
consecutive_solutions_required = 3 # How many times should we have to solve the task in a row to stop training?

# `optimizers` moved from `jax.experimental` to `jax.example_libraries` in
# JAX 0.2.25 and the old location was later removed; try the new path first.
try:
    from jax.example_libraries.optimizers import adam
except ImportError:  # older JAX releases
    from jax.experimental.optimizers import adam

# opt_init(params) -> state; opt_update(step, grads, state) -> state, where
# `step` is the 0-based update index; opt_get_params(state) -> params.
opt_init, opt_update, opt_get_params = adam(step_size)

In [None]:
import gym

def _reset_env(gym_env):
    """Reset ``gym_env`` and return only the initial observation.

    gym >= 0.26 returns ``(observation, info)`` from ``reset``; older
    releases return the observation alone.
    """
    result = gym_env.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(gym_env, action):
    """Step ``gym_env`` and return ``(observation, reward, done, info)``.

    gym >= 0.26 returns ``(observation, reward, terminated, truncated, info)``;
    the episode is over when either flag is set. Older releases already
    return the 4-tuple used here.
    """
    result = gym_env.step(action)
    if len(result) == 5:
        observation, reward, terminated, truncated, info = result
        return observation, reward, terminated or truncated, info
    return result


def _reinforce_loss(params, states, actions, returns):
    """Negated REINFORCE objective for one episode.

    Args:
        params: policy parameters, as produced by ``policy_init_fun``.
        states: observations stacked along axis 0, shape (T, 4).
        actions: integer action taken at each step, shape (T,).
        returns: undiscounted return G_t from each step onward, shape (T,).

    Returns:
        ``-sum_t G_t * log pi(A_t | S_t)``. Descending its gradient performs
        the REINFORCE ascent step on the policy parameters.
    """
    probs = policy_net(params, states)  # one action distribution per step
    chosen = probs[jnp.arange(actions.shape[0]), actions]
    return -jnp.sum(returns * jnp.log(chosen))


# Initialize the environment.
env = gym.make("CartPole-v1")
observation = _reset_env(env)

# Initialize parameters.
policy_output_shape, policy_params = policy_init_fun(jrandom.PRNGKey(0), (1, 4))

# Initialize optimizer.
opt_state = opt_init(policy_params)

# Running sum of this episode's rewards (logged and checked against the threshold).
total_reward = 0
# This is a list of tuples containing (state, action, reward).
episode_SARs = []

# Used for plotting.
total_rewards = []

consecutive_solutions = 0
episode = 1
print(f"Training until reward threshold {env.spec.reward_threshold} is attained.")
while True:
    action = select_action(policy_params, observation)
    previous_observation = observation
    observation, reward, done, info = _step_env(env, action)

    # Update episode training data.
    total_reward += reward
    episode_SARs.append((previous_observation, action, reward))

    if done:
        total_rewards.append(total_reward)
        print(f"Episode {episode} reward achieved: {int(total_reward)}.")

        if total_reward >= env.spec.reward_threshold:
            consecutive_solutions += 1
            if consecutive_solutions == consecutive_solutions_required:
                # Training complete.
                print("Task solved!")
                break
        else:
            consecutive_solutions = 0

        # Reset the state to initial conditions.
        observation = _reset_env(env)

        # Batch the whole episode so the loss is a handful of array ops
        # instead of one traced network call per step.
        states, actions, rewards = zip(*episode_SARs)
        states = jnp.array(states)
        actions = jnp.array(actions)
        rewards = jnp.array(rewards, dtype=jnp.float32)
        # G_t = r_t + r_{t+1} + ... (undiscounted): a reverse cumulative sum.
        returns = jnp.cumsum(rewards[::-1])[::-1]

        policy_grad = grad(_reinforce_loss)(policy_params, states, actions, returns)

        # The optimizer's step index is 0-based (Adam's bias correction uses
        # i + 1), while `episode` starts at 1, so pass the number of prior updates.
        opt_state = opt_update(episode - 1, policy_grad, opt_state)

        # Grab the updated parameters.
        policy_params = opt_get_params(opt_state)

        episode_SARs = []
        episode += 1
        total_reward = 0

In [None]:
from matplotlib import pyplot as plt

# Learning curve: total reward per training episode. The training log numbers
# episodes from 1, so the x-axis starts at 1 to match it.
episodes = range(1, len(total_rewards) + 1)
plt.plot(episodes, total_rewards)
plt.xlabel("Episode")
plt.ylabel("Total reward")
plt.show()

In [None]:
!apt-get install python-opengl -y
!pip install pyvirtualdisplay

In [None]:
from pyvirtualdisplay import Display
import os

# Start an invisible virtual X display (visible=0) so `env.render` has a
# display to draw to, presumably because this notebook runs on a headless
# host; confirm for your environment.
display = Display(visible=0, size=(1400, 900))
display.start()
# Point X clients at the virtual display using the ":<display>.<screen>" form.
# NOTE(review): `_obj._screen` is a private pyvirtualdisplay attribute and may
# change between releases; some versions may already set DISPLAY in start(),
# so verify against the installed version.
os.environ["DISPLAY"] = ":" + str(display.display) + "." + str(display._obj._screen)

In [None]:
from matplotlib import animation, rc

# Replay one episode with the trained policy. After each step, a rendered RGB
# frame is captured so the episode can be shown as an inline animation.
fig = plt.figure()

frame = []

total_reward = 0
observation = env.reset()

done = False
while not done:
    action = select_action(policy_params, observation)
    observation, reward, done, info = env.step(action)

    total_reward += reward

    # ArtistAnimation expects one list of artists per animation frame.
    frame.append([plt.imshow(env.render("rgb_array"))])

anim = animation.ArtistAnimation(fig, frame, interval=100, repeat_delay=1000, blit=True)
# Display animations as interactive JavaScript HTML in the notebook.
rc("animation", html="jshtml")

print(f"Final reward: {total_reward}.")

anim