In [1]:
import gym
import keras_gym as km
from tensorflow.keras.layers import Conv2D, Lambda, Dense, Flatten
from tensorflow.keras import backend as K



In [2]:
env = gym.make('Freeway-v0')
env = km.wrappers.ImagePreprocessor(env, height=84, width=84, grayscale=True)
env = km.wrappers.FrameStacker(env, num_frames=3)
env = km.wrappers.TrainMonitor(env)

In [3]:
km.enable_logging()

In [4]:
class Func(km.FunctionApproximator):
    def body(self, S):
        def diff_transform(S):
            S = K.cast(S, 'float32') / 255
            M = km.utils.diff_transform_matrix(num_frames=3)
            return K.dot(S, M)

        X = Lambda(diff_transform)(S)
        X = Conv2D(filters=32, kernel_size=8, strides=4, activation='relu')(X)
        X = Conv2D(filters=64, kernel_size=4, strides=2, activation='relu')(X)
        X = Conv2D(filters=64, kernel_size=3, strides=2, activation='relu')(X)
        X = Flatten()(X)
        X = Dense(units=512, activation='relu')(X)
        X = Dense(units=256, activation='relu')(X)        
        X = Dense(units=3, activation='linear')(X)
        return X

In [5]:
func = Func(env, lr=0.01)
pi = km.SoftmaxPolicy(func, update_strategy='ppo')
v = km.V(func, gamma=0.99, bootstrap_n=250000, bootstrap_with_target_model=True)
actor_critic = km.ActorCritic(pi, v)

In [6]:
buffer = km.caching.ExperienceReplayBuffer.from_value_function(
    value_function=v, capacity=100000, batch_size=64)

In [9]:
while env.T < 1000000:
    s = env.reset()

    for t in range(env.spec.max_episode_steps):
        a = pi(s, use_target_model=True)  # target_model == pi_old
        s_next, r, done, info = env.step(a)

        buffer.add(s, a, r, done, env.ep)

        if len(buffer) >= buffer.capacity:
            num_batches = int(4 * buffer.capacity / buffer.batch_size)
            for _ in range(num_batches):
                actor_critic.batch_update(*buffer.sample())
            buffer.clear()

            actor_critic.sync_target_model(tau=0.5)

        if done:
            break

        s = s_next

    if env.ep % 50 == 0:
        km.utils.generate_gif(
            env=env,
            policy=pi,
            filepath='Logs/{:06d}.gif'.format(env.ep),
            resize_to=(320, 420))

ms
INFO:TrainMonitor:ep: 60, T: 163,907, G: 0, avg_G: 0, t: 2725, dt: 4.903ms
INFO:TrainMonitor:ep: 61, T: 166,652, G: 0, avg_G: 0, t: 2744, dt: 4.910ms
INFO:TrainMonitor:ep: 62, T: 169,376, G: 0, avg_G: 0, t: 2723, dt: 4.923ms
INFO:TrainMonitor:ep: 63, T: 172,104, G: 0, avg_G: 0, t: 2727, dt: 4.904ms
INFO:TrainMonitor:ep: 64, T: 174,843, G: 0, avg_G: 0, t: 2738, dt: 4.917ms
INFO:TrainMonitor:ep: 65, T: 177,573, G: 0, avg_G: 0, t: 2729, dt: 4.976ms
INFO:TrainMonitor:ep: 66, T: 180,312, G: 0, avg_G: 0, t: 2738, dt: 4.907ms
INFO:TrainMonitor:ep: 67, T: 183,057, G: 0, avg_G: 0, t: 2744, dt: 4.904ms
INFO:TrainMonitor:ep: 68, T: 185,810, G: 0, avg_G: 0, t: 2752, dt: 4.968ms
INFO:TrainMonitor:ep: 69, T: 188,533, G: 0, avg_G: 0, t: 2722, dt: 4.957ms
INFO:TrainMonitor:ep: 70, T: 191,247, G: 0, avg_G: 0, t: 2713, dt: 4.940ms
INFO:TrainMonitor:ep: 71, T: 193,979, G: 0, avg_G: 0, t: 2731, dt: 4.935ms
INFO:TrainMonitor:ep: 72, T: 196,718, G: 0, avg_G: 0, t: 2738, dt: 4.901ms
INFO:TrainMonitor:ep: 

KeyboardInterrupt: 