In [1]:
import gym
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers

from plot_utils import plot_animation
from cart_pole_utils import play_one_episode

In [2]:
# One seed shared by every random number generator seeded in this notebook.
SEED = 47

# Seed the global NumPy and TensorFlow generators so runs are repeatable.
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [3]:
# Small policy network: 4 observation values -> 5 ELU units -> 1 sigmoid unit.
# The single output is treated elsewhere in this notebook as the probability
# of picking action 0 ("left").
model = keras.models.Sequential(
    [
        layers.Dense(units=5, activation="elu", input_shape=(4,)),
        layers.Dense(units=1, activation="sigmoid"),
    ]
)

In [4]:
# Sanity check: feed one freshly reset observation through the network.
env = gym.make("CartPole-v1")
env.seed(SEED)

obs = env.reset()
print(f"obs = {obs}")

# The model works on batches, so prepend a batch axis of length 1.
prob = model.predict(np.expand_dims(obs, axis=0))
print(f"prob = {prob}")

# Release the environment; it is only needed for this one observation.
env.close()

obs = [ 0.02232422 -0.02619596  0.02769897  0.00666565]
prob = [[0.50284725]]


In [5]:
def policy(obs):
    """Sample an action from the network's predicted probability of going left.

    Args:
        obs: a single observation as a 1-D NumPy array. A batch axis is
            prepended before it is passed to the module-level ``model``.

    Returns:
        int: 0 with probability ``left_prob`` (the model output), 1 otherwise.
    """
    # predict() returns a (1, 1) array for a batch of one; pull out the scalar
    # so the comparison below yields a plain boolean. Calling int() on a 2-D
    # array is deprecated since NumPy 1.25.
    left_prob = model.predict(obs[np.newaxis])[0, 0]
    return int(np.random.rand() > left_prob)

In [6]:
# Run one rendered, seeded episode with the current policy and report the
# reward it collected.
rewards, frames = play_one_episode(
    policy,
    render=True,
    seed=SEED,
)
print(f"rewards = {rewards}")
# plot_animation(frames)  # uncomment to view the returned frames

rewards = 32.0


In [7]:
NUM_ENVS = 50      # number of environments stepped in parallel (batch size)
NUM_ITERS = 5000   # number of gradient updates

# Re-seed NumPy so the action sampling below does not depend on how many
# random numbers earlier cells consumed.
np.random.seed(SEED)

envs = [gym.make("CartPole-v1") for _ in range(NUM_ENVS)]
try:
    # Give every environment its own seed.
    for index, env in enumerate(envs):
        env.seed(index)

    observations = [env.reset() for env in envs]
    optimizer = keras.optimizers.RMSprop()
    loss_fn = keras.losses.binary_crossentropy

    for i in range(NUM_ITERS):
        # Stack the current observations once; reused for targets and the forward pass.
        obs_batch = np.array(observations)

        # Supervised target for the hard-coded rule: if the angle (observation
        # index 2) is negative we want prob(left) = 1, otherwise prob(left) = 0.
        # Shape is (NUM_ENVS, 1) to match the model output.
        target_left_probs = (obs_batch[:, 2:3] < 0).astype(np.float32)

        with tf.GradientTape() as tape:
            predict_left_probs = model(obs_batch)
            loss = tf.reduce_mean(loss_fn(target_left_probs, predict_left_probs))
        print("\rIteration: {}, loss: {:.3f}".format(i, loss.numpy()), end="")
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Sample one action per environment from the predicted probabilities
        # (action 0 with probability prob(left), action 1 otherwise).
        actions = (np.random.rand(NUM_ENVS, 1) > predict_left_probs.numpy()).astype(np.int32)
        for env_index, env in enumerate(envs):
            obs, reward, done, info = env.step(actions[env_index][0])
            # Start a new episode as soon as one ends so the batch stays full.
            observations[env_index] = obs if not done else env.reset()
finally:
    # Always release the environments, even if training is interrupted
    # (e.g. KeyboardInterrupt in the middle of the loop) or raises.
    for env in envs:
        env.close()

Iteration: 4999, loss: 0.080

In [8]:
# Evaluate the policy again after training. Use the same keyword arguments
# (including the seed) as the earlier pre-training call so this run is seeded
# and the two results can be compared like for like.
rewards, frames = play_one_episode(policy, render=True, seed=SEED)
print(f"rewards = {rewards}")
# plot_animation(frames)

rewards = 38.0
