In [17]:
from __future__ import absolute_import, division, print_function

import base64
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image

import tensorflow as tf

from tf_agents.agents.ppo import ppo_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import sequential
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
tf.compat.v1.enable_v2_behavior()

In [18]:
num_iterations = 5000 # @param {type:"integer"}
collect_episodes_per_iteration = 2 # @param {type:"integer"}

initial_collect_steps = 1000  # @param {type:"integer"} 
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_max_length = 100000  # @param {type:"integer"}

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
log_interval = 1  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 100  # @param {type:"integer"}

In [19]:
import gym
import slimevolleygym
from tf_agents.environments import suite_gym

env_name = "SlimeVolley-v0"

In [20]:
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

In [21]:
from tf_agents.networks.actor_distribution_network import ActorDistributionNetwork
from tf_agents.networks.value_network import ValueNetwork

def create_networks(observation_spec, action_spec):
    actor_net = ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=(200, 100),
        activation_fn=tf.nn.elu)
    value_net = ValueNetwork(
        observation_spec,
        fc_layer_params=(200, 100),
        activation_fn=tf.nn.elu)
    return actor_net, value_net


actor_net, value_net = create_networks(train_env.observation_spec(), train_env.action_spec())

In [22]:
global_step = tf.compat.v1.train.get_or_create_global_step()
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate, epsilon=1e-5)

In [23]:
tf_agent = ppo_agent.PPOAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    optimizer,
    actor_net,
    value_net,
    num_epochs=25,
    train_step_counter=global_step,
    discount_factor=0.99,
    gradient_clipping=0.5,
    entropy_regularization=1e-2,
    importance_ratio_clipping=0.2,
    use_gae=True,
    use_td_lambda_return=True
)
tf_agent.initialize()



In [24]:
eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy

In [25]:
def compute_avg_return(environment, policy, num_episodes=10):
    total_return = 0.0
    for _ in range(num_episodes):

        time_step = environment.reset()
        episode_return = 0.0

        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return

    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]


# Please also see the metrics module for standard implementations of different
# metrics.

In [26]:
replay_buffer_capacity=301
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=tf_agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

In [27]:
def collect_episode(environment, policy, num_episodes):

    episode_counter = 0
    environment.reset()

    while episode_counter < num_episodes:
        time_step = environment.current_time_step()
        action_step = policy.action(time_step)
        next_time_step = environment.step(action_step.action)
        traj = trajectory.from_transition(time_step, action_step, next_time_step)

        # Add trajectory to the replay buffer
        replay_buffer.add_batch(traj)

        if traj.is_boundary():
            episode_counter += 1


# This loop is so common in RL, that we provide standard implementations of
# these. For more details see the drivers module.

In [35]:
train_env.current_time_step().reward

<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>

In [28]:
from tqdm.notebook import tqdm

try:
    %%time
except:
    pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
tf_agent.train = common.function(tf_agent.train)

# Reset the train step
tf_agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in tqdm(range(num_iterations)):

    # Collect a few episodes using collect_policy and save to the replay buffer.
    collect_episode(
        train_env, tf_agent.collect_policy, collect_episodes_per_iteration)

    # Use data from the buffer and update the agent's network.
    experience = replay_buffer.gather_all()
    train_loss = tf_agent.train(experience)
    replay_buffer.clear()

    step = tf_agent.train_step_counter.numpy()

    if step % log_interval == 0:
        print('step = {0}: loss = {1}'.format(step, train_loss.loss))

    if step % eval_interval == 0:
        avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
        print('step = {0}: Average Return = {1}'.format(step, avg_return))
        returns.append(avg_return)

HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))

step = 25: loss = 1.1559851169586182
step = 50: loss = 1.2501906156539917
step = 75: loss = 0.4638909101486206
step = 100: loss = 2.630178689956665
step = 100: Average Return = -5.0
step = 125: loss = 0.9760697484016418
step = 150: loss = 1.2097859382629395
step = 175: loss = 0.7155101299285889
step = 200: loss = 0.29848402738571167
step = 200: Average Return = -4.699999809265137
step = 225: loss = 0.9492667317390442
step = 250: loss = 0.627423882484436
step = 275: loss = 0.30667296051979065
step = 300: loss = 0.43363332748413086
step = 300: Average Return = -4.900000095367432
step = 325: loss = 0.2887089252471924
step = 350: loss = 0.6410825848579407
step = 375: loss = 0.6365753412246704
step = 400: loss = 0.8086400032043457
step = 400: Average Return = -4.800000190734863
step = 425: loss = 0.49603354930877686
step = 450: loss = 0.7364003658294678
step = 475: loss = 1.2218672037124634
step = 500: loss = 0.5745640993118286
step = 500: Average Return = -4.699999809265137
step = 525: los

step = 4100: loss = 0.5263229012489319
step = 4100: Average Return = -5.0
step = 4125: loss = 0.5788453221321106
step = 4150: loss = 0.37773191928863525
step = 4175: loss = 0.2932767868041992
step = 4200: loss = 0.36351001262664795
step = 4200: Average Return = -4.800000190734863
step = 4225: loss = 0.280598521232605
step = 4250: loss = 0.27224165201187134
step = 4275: loss = 0.6963028907775879
step = 4300: loss = 0.3055267930030823
step = 4300: Average Return = -4.900000095367432
step = 4325: loss = 0.2434307038784027
step = 4350: loss = 0.6430200338363647
step = 4375: loss = 0.41566312313079834
step = 4400: loss = 0.2943326234817505
step = 4400: Average Return = -4.699999809265137
step = 4425: loss = 0.7132957577705383
step = 4450: loss = 0.6608529090881348
step = 4475: loss = 0.12000926584005356
step = 4500: loss = 0.09917829930782318
step = 4500: Average Return = -4.800000190734863
step = 4525: loss = 0.45170268416404724
step = 4550: loss = 0.24010147154331207
step = 4575: loss = 0

step = 8150: loss = 0.2718872129917145
step = 8175: loss = 0.3116513192653656
step = 8200: loss = 0.3563634157180786
step = 8200: Average Return = -4.900000095367432
step = 8225: loss = 0.3683766722679138
step = 8250: loss = 0.573664665222168
step = 8275: loss = 0.4140697121620178
step = 8300: loss = 0.5274245142936707
step = 8300: Average Return = -4.900000095367432
step = 8325: loss = 0.5427643656730652
step = 8350: loss = 0.33407747745513916
step = 8375: loss = 0.4078744053840637
step = 8400: loss = 0.6433783769607544
step = 8400: Average Return = -4.800000190734863
step = 8425: loss = 0.46205881237983704
step = 8450: loss = 0.352082222700119
step = 8475: loss = 0.4168441593647003
step = 8500: loss = 0.34541451930999756
step = 8500: Average Return = -4.900000095367432
step = 8525: loss = 0.5237377882003784
step = 8550: loss = 0.5200231671333313
step = 8575: loss = 0.4439460039138794
step = 8600: loss = 0.3218826651573181
step = 8600: Average Return = -5.0
step = 8625: loss = 0.23915

step = 12150: loss = 0.22616243362426758
step = 12175: loss = 1.5958257913589478
step = 12200: loss = 0.19227974116802216
step = 12200: Average Return = -4.800000190734863
step = 12225: loss = 0.1607682704925537
step = 12250: loss = 0.31688573956489563
step = 12275: loss = 0.46413174271583557
step = 12300: loss = 0.5947014689445496
step = 12300: Average Return = -4.900000095367432
step = 12325: loss = 0.7871896028518677
step = 12350: loss = 0.4980195164680481
step = 12375: loss = 0.3791332542896271
step = 12400: loss = 0.20969194173812866
step = 12400: Average Return = -4.699999809265137
step = 12425: loss = 0.5346651077270508
step = 12450: loss = 0.3496337831020355
step = 12475: loss = 0.20748358964920044
step = 12500: loss = 0.3260191082954407
step = 12500: Average Return = -4.699999809265137
step = 12525: loss = 0.20924700796604156
step = 12550: loss = 0.5083172917366028
step = 12575: loss = 0.49458229541778564
step = 12600: loss = 1.3463999032974243
step = 12600: Average Return = -

step = 16100: Average Return = -5.0
step = 16125: loss = 0.22899232804775238
step = 16150: loss = 0.29573866724967957
step = 16175: loss = 0.42527616024017334
step = 16200: loss = 0.3030565679073334
step = 16200: Average Return = -5.0
step = 16225: loss = 0.5708224177360535
step = 16250: loss = 0.40255653858184814
step = 16275: loss = 0.3106139004230499
step = 16300: loss = 0.4586334228515625
step = 16300: Average Return = -4.5
step = 16325: loss = 0.2003219723701477
step = 16350: loss = 0.19661906361579895
step = 16375: loss = 0.33830568194389343
step = 16400: loss = 0.4389812648296356
step = 16400: Average Return = -5.0
step = 16425: loss = 0.4234065115451813
step = 16450: loss = 0.5151466727256775
step = 16475: loss = 1.6237845420837402
step = 16500: loss = 0.3746531009674072
step = 16500: Average Return = -4.900000095367432
step = 16525: loss = 0.453885942697525
step = 16550: loss = 0.35283130407333374
step = 16575: loss = 0.4059564173221588
step = 16600: loss = 0.4280478358268738


step = 20100: loss = 0.2457386553287506
step = 20100: Average Return = -4.699999809265137
step = 20125: loss = 0.42091143131256104
step = 20150: loss = 1.5366687774658203
step = 20175: loss = 0.4343491792678833
step = 20200: loss = 0.4692985713481903
step = 20200: Average Return = -4.900000095367432
step = 20225: loss = 0.22271966934204102
step = 20250: loss = 0.4425097703933716
step = 20275: loss = 0.27142685651779175
step = 20300: loss = 0.3166648745536804
step = 20300: Average Return = -4.900000095367432
step = 20325: loss = 0.32813364267349243
step = 20350: loss = 0.444928377866745
step = 20375: loss = 0.18644163012504578
step = 20400: loss = 0.3016880750656128
step = 20400: Average Return = -4.699999809265137
step = 20425: loss = 1.0301939249038696
step = 20450: loss = 0.2899523973464966
step = 20475: loss = 0.33445900678634644
step = 20500: loss = 0.3388253152370453
step = 20500: Average Return = -4.699999809265137
step = 20525: loss = 0.3737327456474304
step = 20550: loss = 0.43

step = 24050: loss = 0.5163354277610779
step = 24075: loss = 0.30402812361717224
step = 24100: loss = 0.6287400722503662
step = 24100: Average Return = -4.900000095367432
step = 24125: loss = 0.543488085269928
step = 24150: loss = 0.25631213188171387
step = 24175: loss = 0.23132094740867615
step = 24200: loss = 0.3967549204826355
step = 24200: Average Return = -5.0
step = 24225: loss = 0.9067341685295105
step = 24250: loss = 0.6325486302375793
step = 24275: loss = 0.3156733810901642
step = 24300: loss = 0.42634138464927673
step = 24300: Average Return = -5.0
step = 24325: loss = 0.2359749972820282
step = 24350: loss = 0.37105804681777954
step = 24375: loss = 0.40121838450431824
step = 24400: loss = 0.3034859001636505
step = 24400: Average Return = -5.0
step = 24425: loss = 0.2572840750217438
step = 24450: loss = 0.3411848545074463
step = 24475: loss = 0.2445983588695526
step = 24500: loss = 0.4366665482521057
step = 24500: Average Return = -5.0
step = 24525: loss = 0.20056429505348206


step = 28025: loss = 0.2732219398021698
step = 28050: loss = 0.2966265380382538
step = 28075: loss = 0.325954794883728
step = 28100: loss = 0.27768170833587646
step = 28100: Average Return = -4.599999904632568
step = 28125: loss = 0.4309832453727722
step = 28150: loss = 0.4871305823326111
step = 28175: loss = 0.2693346440792084
step = 28200: loss = 0.16974025964736938
step = 28200: Average Return = -4.900000095367432
step = 28225: loss = 0.2091635912656784
step = 28250: loss = 1.3034299612045288
step = 28275: loss = 0.3450092673301697
step = 28300: loss = 0.34126055240631104
step = 28300: Average Return = -4.5
step = 28325: loss = 0.4664051830768585
step = 28350: loss = 0.3356919288635254
step = 28375: loss = 0.42463254928588867
step = 28400: loss = 0.19957032799720764
step = 28400: Average Return = -5.0
step = 28425: loss = 0.30205726623535156
step = 28450: loss = 0.5504007935523987
step = 28475: loss = 0.27032098174095154
step = 28500: loss = 0.3410736322402954
step = 28500: Average 

step = 31975: loss = 0.31052157282829285
step = 32000: loss = 0.7280240058898926
step = 32000: Average Return = -4.699999809265137
step = 32025: loss = 0.3514127731323242
step = 32050: loss = 0.6998111605644226
step = 32075: loss = 0.3262835443019867
step = 32100: loss = 0.7932946085929871
step = 32100: Average Return = -4.599999904632568
step = 32125: loss = 0.11926527321338654
step = 32150: loss = 0.2517646551132202
step = 32175: loss = 0.5043090581893921
step = 32200: loss = 0.491834819316864
step = 32200: Average Return = -4.900000095367432
step = 32225: loss = 0.5212425589561462
step = 32250: loss = 0.30859965085983276
step = 32275: loss = 0.25990861654281616
step = 32300: loss = 0.19941258430480957
step = 32300: Average Return = -4.800000190734863
step = 32325: loss = 0.09542853385210037
step = 32350: loss = 0.27125120162963867
step = 32375: loss = 0.29197776317596436
step = 32400: loss = 0.08023890852928162
step = 32400: Average Return = -5.0
step = 32425: loss = 0.7974388599395

step = 35925: loss = 0.6248746514320374
step = 35950: loss = 0.41906291246414185
step = 35975: loss = 2.3272035121917725
step = 36000: loss = 0.4507881700992584
step = 36000: Average Return = -4.800000190734863
step = 36025: loss = 0.41361820697784424
step = 36050: loss = 0.25375279784202576
step = 36075: loss = 0.24468500912189484
step = 36100: loss = 0.37833207845687866
step = 36100: Average Return = -4.900000095367432
step = 36125: loss = 0.40103358030319214
step = 36150: loss = 0.24379707872867584
step = 36175: loss = 0.4244092106819153
step = 36200: loss = 0.11221921443939209
step = 36200: Average Return = -4.900000095367432
step = 36225: loss = 0.30098825693130493
step = 36250: loss = 0.26785406470298767
step = 36275: loss = 0.1983637511730194
step = 36300: loss = 0.5841137766838074
step = 36300: Average Return = -4.900000095367432
step = 36325: loss = 0.3357975482940674
step = 36350: loss = 0.48231998085975647
step = 36375: loss = 0.3737083375453949
step = 36400: loss = 0.192720

step = 39900: Average Return = -4.800000190734863
step = 39925: loss = 0.5528547763824463
step = 39950: loss = 0.3496798276901245
step = 39975: loss = 0.20130613446235657
step = 40000: loss = 0.40385547280311584
step = 40000: Average Return = -4.800000190734863
step = 40025: loss = 0.2637914717197418
step = 40050: loss = 0.21939027309417725
step = 40075: loss = 0.34705498814582825
step = 40100: loss = 0.38752007484436035
step = 40100: Average Return = -5.0
step = 40125: loss = 0.22841207683086395
step = 40150: loss = 1.0049504041671753
step = 40175: loss = 0.4753292202949524
step = 40200: loss = 0.08999703824520111
step = 40200: Average Return = -5.0
step = 40225: loss = 0.32989558577537537
step = 40250: loss = 0.2901899814605713
step = 40275: loss = 0.19491781294345856
step = 40300: loss = 0.25552886724472046
step = 40300: Average Return = -5.0
step = 40325: loss = 0.27650004625320435
step = 40350: loss = 0.4490308165550232
step = 40375: loss = 0.4184052646160126
step = 40400: loss = 

step = 43875: loss = 0.1615353226661682
step = 43900: loss = 0.30414554476737976
step = 43900: Average Return = -4.900000095367432
step = 43925: loss = 0.2550002932548523
step = 43950: loss = 0.3923290967941284
step = 43975: loss = 0.3579144775867462
step = 44000: loss = 0.24523484706878662
step = 44000: Average Return = -4.800000190734863
step = 44025: loss = 0.7566260099411011
step = 44050: loss = 0.3443264067173004
step = 44075: loss = 0.20470194518566132
step = 44100: loss = 0.7036063075065613
step = 44100: Average Return = -4.900000095367432
step = 44125: loss = 0.9425840377807617
step = 44150: loss = 0.5706779956817627
step = 44175: loss = 0.2423376739025116
step = 44200: loss = 0.36028194427490234
step = 44200: Average Return = -4.900000095367432
step = 44225: loss = 0.3669598698616028
step = 44250: loss = 0.240467369556427
step = 44275: loss = 0.2207304984331131
step = 44300: loss = 0.4627702236175537
step = 44300: Average Return = -4.800000190734863
step = 44325: loss = 0.1900

step = 47800: Average Return = -4.800000190734863
step = 47825: loss = 0.46777796745300293
step = 47850: loss = 0.36733949184417725
step = 47875: loss = 0.5491920113563538
step = 47900: loss = 0.36044055223464966
step = 47900: Average Return = -4.699999809265137
step = 47925: loss = 0.39081352949142456
step = 47950: loss = 0.29981324076652527
step = 47975: loss = 0.4646138846874237
step = 48000: loss = 0.3764020502567291
step = 48000: Average Return = -4.900000095367432
step = 48025: loss = 0.24290494620800018
step = 48050: loss = 0.29041779041290283
step = 48075: loss = 0.5893824696540833
step = 48100: loss = 0.5824771523475647
step = 48100: Average Return = -4.900000095367432
step = 48125: loss = 0.3410012125968933
step = 48150: loss = 0.22995877265930176
step = 48175: loss = 0.22986099123954773
step = 48200: loss = 0.523926854133606
step = 48200: Average Return = -4.900000095367432
step = 48225: loss = 0.2904279828071594
step = 48250: loss = 0.2582052946090698
step = 48275: loss = 0

step = 51725: loss = 0.43087679147720337
step = 51750: loss = 0.3505319058895111
step = 51775: loss = 0.8059607148170471
step = 51800: loss = 0.28078439831733704
step = 51800: Average Return = -4.900000095367432
step = 51825: loss = 0.21265631914138794
step = 51850: loss = 0.5011957287788391
step = 51875: loss = 0.7090011835098267
step = 51900: loss = 0.5386852622032166
step = 51900: Average Return = -5.0
step = 51925: loss = 0.47506749629974365
step = 51950: loss = 0.5489835143089294
step = 51975: loss = 0.3382619023323059
step = 52000: loss = 0.36710458993911743
step = 52000: Average Return = -4.800000190734863
step = 52025: loss = 0.3689701557159424
step = 52050: loss = 0.5030471682548523
step = 52075: loss = 0.4299165606498718
step = 52100: loss = 0.361268013715744
step = 52100: Average Return = -4.800000190734863
step = 52125: loss = 0.28702372312545776
step = 52150: loss = 0.4387619197368622
step = 52175: loss = 0.14997076988220215
step = 52200: loss = 0.7267452478408813
step = 5

step = 55700: loss = 0.47243624925613403
step = 55700: Average Return = -4.699999809265137
step = 55725: loss = 0.3395160138607025
step = 55750: loss = 0.3842303454875946
step = 55775: loss = 0.4449331760406494
step = 55800: loss = 0.35178714990615845
step = 55800: Average Return = -4.699999809265137
step = 55825: loss = 0.7725856900215149
step = 55850: loss = 0.22764310240745544
step = 55875: loss = 0.5116342902183533
step = 55900: loss = 0.2733641564846039
step = 55900: Average Return = -4.800000190734863
step = 55925: loss = 0.3595157265663147
step = 55950: loss = 0.30702540278434753
step = 55975: loss = 0.16353298723697662
step = 56000: loss = 0.2531071901321411
step = 56000: Average Return = -4.800000190734863
step = 56025: loss = 0.4242837727069855
step = 56050: loss = 0.30808261036872864
step = 56075: loss = 0.332015722990036
step = 56100: loss = 0.323779821395874
step = 56100: Average Return = -4.900000095367432
step = 56125: loss = 0.3327682316303253
step = 56150: loss = 0.484

step = 59650: loss = 0.5787374377250671
step = 59675: loss = 0.44641584157943726
step = 59700: loss = 0.576362133026123
step = 59700: Average Return = -5.0
step = 59725: loss = 0.3160637319087982
step = 59750: loss = 0.20227846503257751
step = 59775: loss = 0.2648298442363739
step = 59800: loss = 0.1557321548461914
step = 59800: Average Return = -4.5
step = 59825: loss = 0.9750165343284607
step = 59850: loss = 0.29739460349082947
step = 59875: loss = 0.45202624797821045
step = 59900: loss = 0.3342505097389221
step = 59900: Average Return = -4.599999904632568
step = 59925: loss = 0.31107261776924133
step = 59950: loss = 0.18838319182395935
step = 59975: loss = 0.3308482766151428
step = 60000: loss = 0.44259965419769287
step = 60000: Average Return = -5.0
step = 60025: loss = 0.5027800798416138
step = 60050: loss = 0.26075032353401184
step = 60075: loss = 0.2515360116958618
step = 60100: loss = 0.4843669533729553
step = 60100: Average Return = -4.900000095367432
step = 60125: loss = 0.41

step = 63600: Average Return = -4.900000095367432
step = 63625: loss = 0.2658464014530182
step = 63650: loss = 0.20053061842918396
step = 63675: loss = 1.194706916809082
step = 63700: loss = 0.26830822229385376
step = 63700: Average Return = -4.900000095367432
step = 63725: loss = 0.45513883233070374
step = 63750: loss = 0.5450139045715332
step = 63775: loss = 0.39284735918045044
step = 63800: loss = 0.3766825199127197
step = 63800: Average Return = -4.5
step = 63825: loss = 0.1768062561750412
step = 63850: loss = 0.31279510259628296
step = 63875: loss = 0.4099474847316742
step = 63900: loss = 0.292464941740036
step = 63900: Average Return = -4.900000095367432
step = 63925: loss = 0.3575281798839569
step = 63950: loss = 0.29987210035324097
step = 63975: loss = 0.2642904222011566
step = 64000: loss = 0.41278165578842163
step = 64000: Average Return = -4.900000095367432
step = 64025: loss = 0.31488844752311707
step = 64050: loss = 0.26443517208099365
step = 64075: loss = 0.54904913902282

KeyboardInterrupt: 

In [29]:
def embed_mp4(filename):
    """Embeds an mp4 file in the notebook."""
    video = open(filename,'rb').read()
    b64 = base64.b64encode(video)
    tag = '''
    <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>'''.format(b64.decode())

    return IPython.display.HTML(tag)

In [30]:
import imageio

num_episodes = 20
video_filename = 'imageio.mp4'
with imageio.get_writer(video_filename, fps=60) as video:
    for _ in range(num_episodes):
        time_step = eval_env.reset()
        video.append_data(eval_py_env.render())
        while not time_step.is_last():
            action_step = tf_agent.policy.action(time_step)
            time_step = eval_env.step(action_step.action)
            video.append_data(eval_py_env.render())

embed_mp4(video_filename)

