In [23]:
import gym
from collections import namedtuple
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

In [24]:
# Hyperparameters for the training run below.
PERCENTILE = 70     # reward percentile passed to filter_batch as the cutoff
BATCH_SIZE = 16     # number of episodes collected per batch by iterate_batches
HIDDEN_SIZE = 128   # width of the hidden Dense layer in Net

In [25]:
class Net(tf.keras.Model):
    """Small fully connected network mapping an observation to action scores.

    The model is one ReLU-activated hidden ``Dense`` layer followed by a
    ``Dense`` output layer with no activation, so the output is one raw
    (unnormalised) score per action.

    Args:
        obs_size: Number of features in a single observation.
        hidden_size: Number of units in the hidden layer.
        n_actions: Number of output units (one per action).
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        super().__init__()
        # Layers are created hidden-first, output-second and kept as locals so
        # that only ``self.net`` is registered as an attribute of the model.
        hidden_layer = tf.keras.layers.Dense(
            hidden_size, input_shape=(obs_size,), activation='relu'
        )
        output_layer = tf.keras.layers.Dense(n_actions)
        self.net = tf.keras.Sequential([hidden_layer, output_layer])

    def call(self, x):
        """Convert ``x`` to a float32 tensor and run it through the network."""
        return self.net(tf.convert_to_tensor(x, dtype=tf.float32))

In [26]:
# One finished episode: its total reward and the list of steps taken.
Episode = namedtuple('Episode', 'reward steps')
# One step of an episode: the observation seen and the action chosen for it.
EpisodeStep = namedtuple('EpisodeStep', 'observation action')

In [27]:
def _reset_env(env):
    """Reset ``env`` and return only the initial observation.

    Newer Gym releases return ``(observation, info)`` from ``reset()`` while
    older ones return the bare observation; both shapes are accepted.
    """
    result = env.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(env, action):
    """Step ``env`` once and return ``(next_observation, reward, done)``.

    Accepts both the 5-tuple ``(obs, reward, terminated, truncated, info)``
    and the older 4-tuple ``(obs, reward, done, info)`` step results. The
    ``info`` dict is never used to decide whether the episode has ended.
    """
    result = env.step(action)
    if len(result) == 5:
        next_obs, reward, terminated, truncated, _info = result
        return next_obs, reward, bool(terminated or truncated)
    next_obs, reward, done, _info = result
    return next_obs, reward, bool(done)


def iterate_batches(env, net, batch_size):
    """Play episodes forever with the current policy and yield them in batches.

    At every step the network output for the current observation is turned
    into action probabilities with a softmax, and an action is sampled from
    that distribution. Each finished episode is stored as an ``Episode`` whose
    ``steps`` are ``EpisodeStep(observation, action)`` records.

    Args:
        env: Gym-style environment exposing ``reset()`` and ``step(action)``.
        net: Callable that maps a ``(1, obs_size)`` array to action scores.
        batch_size: Number of completed episodes per yielded batch.

    Yields:
        list[Episode]: ``batch_size`` completed episodes. The generator never
        terminates on its own; the consumer is expected to stop iterating.
    """
    batch = []
    episode_reward = 0.0
    episode_steps = []
    obs = _reset_env(env)

    while True:
        # The network expects a batch dimension, so present a batch of one.
        obs_v = np.array(obs).reshape(1, -1)
        act_probs = tf.nn.softmax(net(obs_v)).numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)

        next_obs, reward, done = _step_env(env, action)
        episode_reward += reward
        # Record the observation the action was chosen for, not its result.
        episode_steps.append(EpisodeStep(observation=obs, action=action))

        if done:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))
            episode_reward = 0.0
            episode_steps = []
            next_obs = _reset_env(env)
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs

In [28]:
def filter_batch(batch, percentile):
    """Select the steps of the best episodes in ``batch`` as training data.

    Episodes whose total reward is below the given percentile of the batch's
    rewards are discarded; the observations and actions of the remaining
    episodes are concatenated in batch order.

    Args:
        batch: Sequence of ``(reward, steps)`` episodes, where each step has
            ``observation`` and ``action`` attributes.
        percentile: Percentile (as accepted by ``np.percentile``) used as the
            reward cutoff.

    Returns:
        Tuple ``(train_obs_v, train_act_v, reward_bound, reward_mean)``: the
        kept observations and actions as float32 tensors, the reward cutoff,
        and the mean reward over the whole batch as a Python float.
    """
    rewards = [episode.reward for episode in batch]
    reward_bound = np.percentile(rewards, percentile)
    reward_mean = float(np.mean(rewards))

    # Keep every step of each episode that is not below the cutoff.
    kept_steps = [
        step
        for reward, steps in batch
        if not reward < reward_bound
        for step in steps
    ]
    train_obs = [step.observation for step in kept_steps]
    train_act = [step.action for step in kept_steps]

    train_obs_v = tf.convert_to_tensor(train_obs, dtype=tf.float32)
    # NOTE(review): actions are stored as float32 here; if they are integer
    # action indices, an integer dtype may be the better fit — confirm.
    train_act_v = tf.convert_to_tensor(train_act, dtype=tf.float32)
    return train_obs_v, train_act_v, reward_bound, reward_mean


In [29]:
if __name__ == "__main__":
    # Environment and the sizes the policy network needs.
    env = gym.make("CartPole-v1")
    n_actions = env.action_space.n
    obs_size = env.observation_space.shape[0]

    # Policy network and optimizer; training is done manually below with a
    # gradient tape rather than through Keras' compile/fit machinery.
    net = Net(obs_size, HIDDEN_SIZE, n_actions)
    optimizer = Adam(learning_rate=0.01)

    for iter_no, batch in enumerate(iterate_batches(env, net, BATCH_SIZE)):
        obs_v, acts_v, reward_b, reward_m = filter_batch(batch, PERCENTILE)

        # Cross-entropy between the network's raw scores and the actions that
        # were taken in the episodes kept by filter_batch.
        with tf.GradientTape() as tape:
            action_scores_v = net(obs_v, training=True)
            per_step_loss = tf.keras.losses.sparse_categorical_crossentropy(
                acts_v, action_scores_v, from_logits=True
            )
            loss_v = tf.reduce_mean(per_step_loss)

        trainable_vars = net.trainable_variables
        gradients = tape.gradient(loss_v, trainable_vars)
        optimizer.apply_gradients(zip(gradients, trainable_vars))

        status = "%d: reward_mean=%.1f, rw_bound=%.1f" % (
            iter_no, reward_m, reward_b)
        print(status)

        # Stop once the batch's mean reward exceeds the target.
        if reward_m > 199:
            print("Solved!")
            break

0: reward_mean=24.0, rw_bound=30.5
1: reward_mean=31.4, rw_bound=37.5
2: reward_mean=31.9, rw_bound=32.5
3: reward_mean=39.5, rw_bound=42.5
4: reward_mean=43.8, rw_bound=50.5
5: reward_mean=35.9, rw_bound=39.5
6: reward_mean=45.8, rw_bound=52.5
7: reward_mean=68.1, rw_bound=80.0
8: reward_mean=53.5, rw_bound=58.0
9: reward_mean=57.9, rw_bound=75.0
10: reward_mean=58.6, rw_bound=68.0
11: reward_mean=76.6, rw_bound=69.5
12: reward_mean=66.5, rw_bound=73.5
13: reward_mean=70.2, rw_bound=93.0
14: reward_mean=71.5, rw_bound=75.0
15: reward_mean=77.6, rw_bound=81.5
16: reward_mean=78.0, rw_bound=80.5
17: reward_mean=88.1, rw_bound=84.5
18: reward_mean=93.5, rw_bound=108.0
19: reward_mean=106.1, rw_bound=121.0
20: reward_mean=94.2, rw_bound=96.0
21: reward_mean=130.9, rw_bound=140.0
22: reward_mean=117.8, rw_bound=128.0
23: reward_mean=91.0, rw_bound=100.5
24: reward_mean=97.4, rw_bound=98.0
25: reward_mean=109.1, rw_bound=127.0
26: reward_mean=131.7, rw_bound=125.5
27: reward_mean=119.6, rw_