In [None]:
# %pip install tf-agents==0.14.0

In [1]:
from tf_agents.environments.tf_py_environment import TFPyEnvironment
from tf_agents.environments.gym_wrapper import GymWrapper

from utils import make_env


env = make_env(max_step=1000)
env = GymWrapper(env)
env = TFPyEnvironment(env)

env.reset()
_ = env.render()

In [2]:
train_iters = 5_000_000
buffer_size = 1_000_000

update_period = 10 # run a training step every 10 collect steps
episode_max_step = 1_000

In [3]:
import tensorflow as tf
from tensorflow.keras import layers

from tf_agents.networks.q_network import QNetwork


preprocessing_layer = layers.Lambda(lambda obs: tf.cast(obs, tf.float32) / 4.)
fc_layer_params = [128, 64, 64]
dropout_layer_params = [0.5, 0.5, 0.5]

q_net = QNetwork(
    env.observation_spec(),
    env.action_spec(),
    preprocessing_layers=preprocessing_layer,
    fc_layer_params=fc_layer_params,
    dropout_layer_params=dropout_layer_params
)

In [4]:
from tensorflow.keras import optimizers, losses

from tf_agents.agents.dqn.dqn_agent import DqnAgent


train_step = tf.Variable(0)

optimizer = optimizers.Adam(learning_rate=3e-7)
                                     
epsilon_fn = optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1.0,
    decay_steps=train_iters // update_period,
    end_learning_rate=0.01
)

agent = DqnAgent(
    env.time_step_spec(),
    env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    target_update_period=2_000,
    td_errors_loss_fn=losses.Huber(reduction="none"),
    gamma=0.95, # discount factor
    train_step_counter=train_step,
    epsilon_greedy=lambda: epsilon_fn(train_step)
)

agent.initialize()

In [5]:
from tf_agents.replay_buffers import tf_uniform_replay_buffer


replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=env.batch_size,
    max_length=buffer_size
)
replay_buffer_observer = replay_buffer.add_batch

dataset = replay_buffer.as_dataset(
    sample_batch_size=64,
    num_steps=2,
    num_parallel_calls=3
).prefetch(3)

Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


In [6]:
from tf_agents.metrics import tf_metrics
from tf_agents.eval.metric_utils import log_metrics

import logging


class ShowProgress:

    def __init__(self, total):
        self.counter = 0
        self.total = total
        
    def __call__(self, trajectory):
        if not trajectory.is_boundary():
            self.counter += 1
        if self.counter % 100 == 0:
            print("\r{}/{}".format(self.counter, self.total), end="")
            

train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric(),
]

logging.getLogger().setLevel(logging.INFO)
log_metrics(train_metrics)

INFO:absl: 
		 NumberOfEpisodes = 0
		 EnvironmentSteps = 0
		 AverageReturn = 0.0
		 AverageEpisodeLength = 0.0


In [7]:
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver


# class LimitSteps:

#     def __init__(self, max):
#         self.max = max
#         self.counter = 0
        
#     @tf.function
#     def __call__(self, trajectory):
#         # if not trajectory.is_boundary():
#         #     self.counter += 1
#         # else:
#         #     self.counter = 0

#         if self.counter > self.max:
#             env.reset()


collect_driver = DynamicStepDriver(
    env,
    agent.collect_policy,
    observers=[replay_buffer_observer] + train_metrics,
    num_steps=update_period
)

In [8]:
from tf_agents.policies.random_tf_policy import RandomTFPolicy


initial_collect_policy = RandomTFPolicy(
    env.time_step_spec(),
    env.action_spec()
)

init_driver = DynamicStepDriver(
    env,
    initial_collect_policy,
    observers=[replay_buffer.add_batch, ShowProgress(buffer_size//4)],
    num_steps=buffer_size//4
)

final_time_step, final_policy_state = init_driver.run()

250000/250000

In [9]:
from tf_agents.utils.common import function


collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)

In [10]:
def train_agent(n_iterations):
    best_average_return = 0

    time_step = None
    policy_state = agent.collect_policy.get_initial_state(env.batch_size)

    iterator = iter(dataset)
    for iteration in range(n_iterations):
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        trajectories, buffer_info = next(iterator)
        
        train_loss = agent.train(trajectories)

        print("\r{} loss:{:.5f}".format(iteration, train_loss.loss.numpy()), end="")

        if iteration % (n_iterations//1000) == 0:
            log_metrics(train_metrics)

        if train_metrics[2].result() > best_average_return:
            best_average_return = train_metrics[2].result()
            print(f"Saving best model with: Average Return of {best_average_return} in iters of {iteration}")
            tf.saved_model.save(agent._q_network, f"models/tf-agents/logs/Snake DQN TF-Agents ({n_iterations} iters) with rwd#{best_average_return}")

    log_metrics(train_metrics)
    tf.saved_model.save(agent._q_network, f"models/tf-agents/Snake DQN TF-Agents ({n_iterations} iters)")

In [11]:
train_agent(n_iterations=train_iters)

Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))


Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
INFO:absl: 
		 NumberOfEpisodes = 0
		 EnvironmentSteps = 10
		 AverageReturn = 0.0
		 AverageEpisodeLength = 0.0


4999 loss:0.45013

INFO:absl: 
		 NumberOfEpisodes = 138
		 EnvironmentSteps = 50010
		 AverageReturn = -315.5895080566406
		 AverageEpisodeLength = 395.79998779296875


9998 loss:1.05939

INFO:absl: 
		 NumberOfEpisodes = 271
		 EnvironmentSteps = 100010
		 AverageReturn = -234.10958862304688
		 AverageEpisodeLength = 325.5


14995 loss:0.90009

INFO:absl: 
		 NumberOfEpisodes = 399
		 EnvironmentSteps = 150010
		 AverageReturn = -294.71002197265625
		 AverageEpisodeLength = 358.8999938964844


20000 loss:0.90225

INFO:absl: 
		 NumberOfEpisodes = 522
		 EnvironmentSteps = 200010
		 AverageReturn = -323.2605895996094
		 AverageEpisodeLength = 370.3999938964844


25000 loss:0.45416

INFO:absl: 
		 NumberOfEpisodes = 661
		 EnvironmentSteps = 250010
		 AverageReturn = -280.0398864746094
		 AverageEpisodeLength = 331.6000061035156


29998 loss:0.74252

INFO:absl: 
		 NumberOfEpisodes = 794
		 EnvironmentSteps = 300010
		 AverageReturn = -264.5798034667969
		 AverageEpisodeLength = 313.79998779296875


34998 loss:1.04908

INFO:absl: 
		 NumberOfEpisodes = 918
		 EnvironmentSteps = 350010
		 AverageReturn = -424.1109313964844
		 AverageEpisodeLength = 515.5999755859375


39998 loss:0.75457

INFO:absl: 
		 NumberOfEpisodes = 1058
		 EnvironmentSteps = 400010
		 AverageReturn = -227.77963256835938
		 AverageEpisodeLength = 336.5


44997 loss:0.44660

INFO:absl: 
		 NumberOfEpisodes = 1180
		 EnvironmentSteps = 450010
		 AverageReturn = -391.54034423828125
		 AverageEpisodeLength = 509.0


49997 loss:0.73099

INFO:absl: 
		 NumberOfEpisodes = 1310
		 EnvironmentSteps = 500010
		 AverageReturn = -356.4299621582031
		 AverageEpisodeLength = 412.70001220703125


54996 loss:0.89279

INFO:absl: 
		 NumberOfEpisodes = 1449
		 EnvironmentSteps = 550010
		 AverageReturn = -160.9098663330078
		 AverageEpisodeLength = 258.6000061035156


60000 loss:0.60291

INFO:absl: 
		 NumberOfEpisodes = 1589
		 EnvironmentSteps = 600010
		 AverageReturn = -268.09954833984375
		 AverageEpisodeLength = 369.8999938964844


64998 loss:1.35089

INFO:absl: 
		 NumberOfEpisodes = 1724
		 EnvironmentSteps = 650010
		 AverageReturn = -192.39974975585938
		 AverageEpisodeLength = 277.79998779296875


70000 loss:0.60685

INFO:absl: 
		 NumberOfEpisodes = 1850
		 EnvironmentSteps = 700010
		 AverageReturn = -336.020263671875
		 AverageEpisodeLength = 417.0


74998 loss:0.46386

INFO:absl: 
		 NumberOfEpisodes = 1978
		 EnvironmentSteps = 750010
		 AverageReturn = -268.47027587890625
		 AverageEpisodeLength = 387.29998779296875


79997 loss:0.45176

INFO:absl: 
		 NumberOfEpisodes = 2117
		 EnvironmentSteps = 800010
		 AverageReturn = -205.1697998046875
		 AverageEpisodeLength = 329.20001220703125


84996 loss:1.33308

INFO:absl: 
		 NumberOfEpisodes = 2253
		 EnvironmentSteps = 850010
		 AverageReturn = -196.8495635986328
		 AverageEpisodeLength = 293.79998779296875


89996 loss:1.04686

INFO:absl: 
		 NumberOfEpisodes = 2374
		 EnvironmentSteps = 900010
		 AverageReturn = -327.5998840332031
		 AverageEpisodeLength = 479.8999938964844


94997 loss:0.89337

INFO:absl: 
		 NumberOfEpisodes = 2501
		 EnvironmentSteps = 950010
		 AverageReturn = -253.4294891357422
		 AverageEpisodeLength = 417.70001220703125


99998 loss:0.15334

INFO:absl: 
		 NumberOfEpisodes = 2638
		 EnvironmentSteps = 1000010
		 AverageReturn = -158.35995483398438
		 AverageEpisodeLength = 262.70001220703125


104997 loss:0.89357

INFO:absl: 
		 NumberOfEpisodes = 2764
		 EnvironmentSteps = 1050010
		 AverageReturn = -147.1097412109375
		 AverageEpisodeLength = 252.6999969482422


109997 loss:0.74286

INFO:absl: 
		 NumberOfEpisodes = 2891
		 EnvironmentSteps = 1100010
		 AverageReturn = -252.47958374023438
		 AverageEpisodeLength = 386.8999938964844


114998 loss:0.74605

INFO:absl: 
		 NumberOfEpisodes = 3014
		 EnvironmentSteps = 1150010
		 AverageReturn = -190.419921875
		 AverageEpisodeLength = 345.1000061035156


119997 loss:0.15958

INFO:absl: 
		 NumberOfEpisodes = 3139
		 EnvironmentSteps = 1200010
		 AverageReturn = -254.4198455810547
		 AverageEpisodeLength = 433.70001220703125


124999 loss:0.75955

INFO:absl: 
		 NumberOfEpisodes = 3273
		 EnvironmentSteps = 1250010
		 AverageReturn = -332.12103271484375
		 AverageEpisodeLength = 533.2999877929688


129996 loss:1.33991

INFO:absl: 
		 NumberOfEpisodes = 3393
		 EnvironmentSteps = 1300010
		 AverageReturn = -293.5908203125
		 AverageEpisodeLength = 447.8999938964844


134999 loss:0.59962

INFO:absl: 
		 NumberOfEpisodes = 3515
		 EnvironmentSteps = 1350010
		 AverageReturn = -162.10968017578125
		 AverageEpisodeLength = 356.8999938964844


139997 loss:0.74831

INFO:absl: 
		 NumberOfEpisodes = 3651
		 EnvironmentSteps = 1400010
		 AverageReturn = -229.6694793701172
		 AverageEpisodeLength = 415.70001220703125


144995 loss:1.45971

INFO:absl: 
		 NumberOfEpisodes = 3770
		 EnvironmentSteps = 1450010
		 AverageReturn = -217.54931640625
		 AverageEpisodeLength = 444.0


149995 loss:1.04579

INFO:absl: 
		 NumberOfEpisodes = 3889
		 EnvironmentSteps = 1500010
		 AverageReturn = -321.4305725097656
		 AverageEpisodeLength = 495.8999938964844


154999 loss:0.60186

INFO:absl: 
		 NumberOfEpisodes = 4009
		 EnvironmentSteps = 1550010
		 AverageReturn = -282.5997009277344
		 AverageEpisodeLength = 522.2999877929688


159996 loss:0.30760

INFO:absl: 
		 NumberOfEpisodes = 4125
		 EnvironmentSteps = 1600010
		 AverageReturn = -154.29953002929688
		 AverageEpisodeLength = 355.6000061035156


164998 loss:0.88539

INFO:absl: 
		 NumberOfEpisodes = 4242
		 EnvironmentSteps = 1650010
		 AverageReturn = -302.7800598144531
		 AverageEpisodeLength = 606.4000244140625


169997 loss:0.60161

INFO:absl: 
		 NumberOfEpisodes = 4361
		 EnvironmentSteps = 1700010
		 AverageReturn = -233.52975463867188
		 AverageEpisodeLength = 374.1000061035156


174999 loss:0.45455

INFO:absl: 
		 NumberOfEpisodes = 4484
		 EnvironmentSteps = 1750010
		 AverageReturn = -189.41954040527344
		 AverageEpisodeLength = 433.0


179998 loss:0.60480

INFO:absl: 
		 NumberOfEpisodes = 4609
		 EnvironmentSteps = 1800010
		 AverageReturn = -230.4996795654297
		 AverageEpisodeLength = 450.79998779296875


184999 loss:0.30958

INFO:absl: 
		 NumberOfEpisodes = 4730
		 EnvironmentSteps = 1850010
		 AverageReturn = -206.6894073486328
		 AverageEpisodeLength = 413.3999938964844


189997 loss:0.45050

INFO:absl: 
		 NumberOfEpisodes = 4836
		 EnvironmentSteps = 1900010
		 AverageReturn = -232.7694549560547
		 AverageEpisodeLength = 511.70001220703125


194999 loss:1.31801

INFO:absl: 
		 NumberOfEpisodes = 4950
		 EnvironmentSteps = 1950010
		 AverageReturn = -194.45936584472656
		 AverageEpisodeLength = 422.70001220703125


199998 loss:0.45941

INFO:absl: 
		 NumberOfEpisodes = 5074
		 EnvironmentSteps = 2000010
		 AverageReturn = -209.48934936523438
		 AverageEpisodeLength = 448.79998779296875


205000 loss:0.29559

INFO:absl: 
		 NumberOfEpisodes = 5193
		 EnvironmentSteps = 2050010
		 AverageReturn = -297.67010498046875
		 AverageEpisodeLength = 484.6000061035156


209997 loss:0.90197

INFO:absl: 
		 NumberOfEpisodes = 5306
		 EnvironmentSteps = 2100010
		 AverageReturn = -212.810546875
		 AverageEpisodeLength = 479.1000061035156


214998 loss:0.16178

INFO:absl: 
		 NumberOfEpisodes = 5406
		 EnvironmentSteps = 2150010
		 AverageReturn = -251.4903106689453
		 AverageEpisodeLength = 572.2000122070312


219996 loss:0.30651

INFO:absl: 
		 NumberOfEpisodes = 5510
		 EnvironmentSteps = 2200010
		 AverageReturn = -257.3700256347656
		 AverageEpisodeLength = 530.5


224999 loss:0.60536

INFO:absl: 
		 NumberOfEpisodes = 5612
		 EnvironmentSteps = 2250010
		 AverageReturn = -267.21002197265625
		 AverageEpisodeLength = 518.2999877929688


229999 loss:0.45973

INFO:absl: 
		 NumberOfEpisodes = 5711
		 EnvironmentSteps = 2300010
		 AverageReturn = -196.3594207763672
		 AverageEpisodeLength = 488.5


234998 loss:0.30912

INFO:absl: 
		 NumberOfEpisodes = 5811
		 EnvironmentSteps = 2350010
		 AverageReturn = -265.7491149902344
		 AverageEpisodeLength = 647.4000244140625


240000 loss:0.75824

INFO:absl: 
		 NumberOfEpisodes = 5903
		 EnvironmentSteps = 2400010
		 AverageReturn = -200.4195556640625
		 AverageEpisodeLength = 482.70001220703125


244996 loss:1.04235

INFO:absl: 
		 NumberOfEpisodes = 6009
		 EnvironmentSteps = 2450010
		 AverageReturn = -146.23965454101562
		 AverageEpisodeLength = 376.3999938964844


250000 loss:0.73391

INFO:absl: 
		 NumberOfEpisodes = 6101
		 EnvironmentSteps = 2500010
		 AverageReturn = -227.43954467773438
		 AverageEpisodeLength = 531.5999755859375


254997 loss:0.15970

INFO:absl: 
		 NumberOfEpisodes = 6192
		 EnvironmentSteps = 2550010
		 AverageReturn = -328.01953125
		 AverageEpisodeLength = 712.7999877929688


259997 loss:1.04636

INFO:absl: 
		 NumberOfEpisodes = 6286
		 EnvironmentSteps = 2600010
		 AverageReturn = -274.08990478515625
		 AverageEpisodeLength = 614.2000122070312


264998 loss:0.60858

INFO:absl: 
		 NumberOfEpisodes = 6372
		 EnvironmentSteps = 2650010
		 AverageReturn = -128.3494873046875
		 AverageEpisodeLength = 461.1000061035156


269999 loss:0.16060

INFO:absl: 
		 NumberOfEpisodes = 6459
		 EnvironmentSteps = 2700010
		 AverageReturn = -228.4298553466797
		 AverageEpisodeLength = 538.4000244140625


274998 loss:0.45542

INFO:absl: 
		 NumberOfEpisodes = 6546
		 EnvironmentSteps = 2750010
		 AverageReturn = -195.6701202392578
		 AverageEpisodeLength = 522.0999755859375


279998 loss:0.75063

INFO:absl: 
		 NumberOfEpisodes = 6637
		 EnvironmentSteps = 2800010
		 AverageReturn = -178.35000610351562
		 AverageEpisodeLength = 537.0


284999 loss:0.45280

INFO:absl: 
		 NumberOfEpisodes = 6723
		 EnvironmentSteps = 2850010
		 AverageReturn = -221.5706787109375
		 AverageEpisodeLength = 544.9000244140625


289997 loss:0.30515

INFO:absl: 
		 NumberOfEpisodes = 6815
		 EnvironmentSteps = 2900010
		 AverageReturn = -144.25018310546875
		 AverageEpisodeLength = 469.1000061035156


294997 loss:0.44471

INFO:absl: 
		 NumberOfEpisodes = 6888
		 EnvironmentSteps = 2950010
		 AverageReturn = -211.57992553710938
		 AverageEpisodeLength = 610.2000122070312


299996 loss:0.45076

INFO:absl: 
		 NumberOfEpisodes = 6966
		 EnvironmentSteps = 3000010
		 AverageReturn = -209.77023315429688
		 AverageEpisodeLength = 607.7999877929688


304996 loss:0.01582

INFO:absl: 
		 NumberOfEpisodes = 7041
		 EnvironmentSteps = 3050010
		 AverageReturn = -234.0299835205078
		 AverageEpisodeLength = 638.2000122070312


309997 loss:0.30241

INFO:absl: 
		 NumberOfEpisodes = 7121
		 EnvironmentSteps = 3100010
		 AverageReturn = -248.5295867919922
		 AverageEpisodeLength = 649.5


314997 loss:0.59737

INFO:absl: 
		 NumberOfEpisodes = 7195
		 EnvironmentSteps = 3150010
		 AverageReturn = -227.0993194580078
		 AverageEpisodeLength = 635.4000244140625


319999 loss:0.16609

INFO:absl: 
		 NumberOfEpisodes = 7272
		 EnvironmentSteps = 3200010
		 AverageReturn = -190.68994140625
		 AverageEpisodeLength = 679.5


324996 loss:0.61209

INFO:absl: 
		 NumberOfEpisodes = 7341
		 EnvironmentSteps = 3250010
		 AverageReturn = -161.579833984375
		 AverageEpisodeLength = 610.2000122070312


329996 loss:0.16209

INFO:absl: 
		 NumberOfEpisodes = 7413
		 EnvironmentSteps = 3300010
		 AverageReturn = -135.1595001220703
		 AverageEpisodeLength = 550.4000244140625


334998 loss:0.45390

INFO:absl: 
		 NumberOfEpisodes = 7482
		 EnvironmentSteps = 3350010
		 AverageReturn = -214.5096435546875
		 AverageEpisodeLength = 694.5999755859375


339997 loss:0.45106

INFO:absl: 
		 NumberOfEpisodes = 7551
		 EnvironmentSteps = 3400010
		 AverageReturn = -260.9894104003906
		 AverageEpisodeLength = 703.9000244140625


345000 loss:0.45862

INFO:absl: 
		 NumberOfEpisodes = 7618
		 EnvironmentSteps = 3450010
		 AverageReturn = -275.3587951660156
		 AverageEpisodeLength = 923.7000122070312


349997 loss:0.15139

INFO:absl: 
		 NumberOfEpisodes = 7686
		 EnvironmentSteps = 3500010
		 AverageReturn = -246.73941040039062
		 AverageEpisodeLength = 717.5


354998 loss:0.45203

INFO:absl: 
		 NumberOfEpisodes = 7752
		 EnvironmentSteps = 3550010
		 AverageReturn = -174.7889404296875
		 AverageEpisodeLength = 703.2999877929688


359998 loss:0.15728

INFO:absl: 
		 NumberOfEpisodes = 7812
		 EnvironmentSteps = 3600010
		 AverageReturn = -279.3794250488281
		 AverageEpisodeLength = 831.0


364997 loss:0.72297

INFO:absl: 
		 NumberOfEpisodes = 7873
		 EnvironmentSteps = 3650010
		 AverageReturn = -224.3194580078125
		 AverageEpisodeLength = 887.2999877929688


369998 loss:0.45142

INFO:absl: 
		 NumberOfEpisodes = 7936
		 EnvironmentSteps = 3700010
		 AverageReturn = -176.89938354492188
		 AverageEpisodeLength = 853.2999877929688


374997 loss:0.60208

INFO:absl: 
		 NumberOfEpisodes = 7996
		 EnvironmentSteps = 3750010
		 AverageReturn = -208.71875
		 AverageEpisodeLength = 824.5999755859375


379998 loss:0.44971

INFO:absl: 
		 NumberOfEpisodes = 8055
		 EnvironmentSteps = 3800010
		 AverageReturn = -217.1383819580078
		 AverageEpisodeLength = 954.4000244140625


384997 loss:0.31005

INFO:absl: 
		 NumberOfEpisodes = 8110
		 EnvironmentSteps = 3850010
		 AverageReturn = -222.48873901367188
		 AverageEpisodeLength = 912.0999755859375


389998 loss:0.15912

INFO:absl: 
		 NumberOfEpisodes = 8167
		 EnvironmentSteps = 3900010
		 AverageReturn = -194.76913452148438
		 AverageEpisodeLength = 767.4000244140625


390671 loss:0.15823

KeyboardInterrupt: 

In [13]:
frames = []

env.pyenv.envs[0].reset()

def save_frames(trajectory):
    global frames
    frames.append(env.pyenv.envs[0].render())

watch_driver = DynamicStepDriver(
    env,
    agent.policy,
    observers=[save_frames, ShowProgress(2_000)],
    num_steps=2_000)
final_time_step, final_policy_state = watch_driver.run()

2000/2000

In [14]:
import PIL

import os


image_path = os.path.join("rl videos", f"snake_tfagents_{train_iters}.gif")

frame_images = []
for frame in frames:
    frame_images.append(PIL.Image.fromarray(frame))

frame_images[0].save(image_path, format='GIF',
                     append_images=frame_images[1:],
                     save_all=True,
                     duration=30,
                     loop=0)