# DQN

Ref: https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial

In [None]:
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import sequential, q_network
from tf_agents.specs import tensor_spec
from tf_agents.utils import common
from tf_agents.agents.dqn import dqn_agent
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.drivers import py_driver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.metrics import tf_metrics
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.drivers import dynamic_step_driver
from tf_agents.eval import metric_utils

import matplotlib
import matplotlib.pyplot as plt

import reverb

import tensorflow as tf
import pyvirtualdisplay


import PIL
from PIL import ImageDraw, ImageFont
import numpy as np
import IPython
import imageio
import base64

import os
import time

from absl import logging

In [None]:
# Show INFO-level absl log messages so the logging.info(...) progress output
# emitted by the cells below is visible.
logging.set_verbosity(logging.INFO)

In [None]:
# Start a non-visible 1400x900 virtual display and keep a handle to it.
# NOTE(review): presumably required so the environment's render() calls work
# on a headless machine -- confirm.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()


In [None]:
# Environment to train on. Previously tried alternatives: 'CartPole-v0',
# 'CartPole-v1' (change this single assignment to switch environments).
env_name = 'Acrobot-v1'
train_or_retrain = True  # When False, the training section is skipped.
num_iterations = 100_000  # Number of collect/train iterations of the main loop.

learning_rate = 1e-4  # Learning rate of the Adam optimizer.
num_eval_episodes = 10  # Episodes per evaluation (also the eval metric buffer size).
# NOTE: not referenced by the visible cells; the replay buffer is sized by
# `replay_buffer_capacity` below. Kept for backward compatibility.
replay_buffer_max_length = 100_000
initial_collect_steps = 100  # Steps collected with a random policy before training.
batch_size = 64  # Batch size sampled from the replay buffer for training.
collect_steps_per_iteration = 1  # Environment steps collected per iteration.

# Intervals below are compared against the global step (step % interval == 0).
log_interval = 200
eval_interval = 1_000
video_recording_interval = 1_000

# All outputs (train/eval summaries, checkpoints, videos) live under this root.
root_dir = os.path.join('./data', env_name)
summaries_flush_secs = 10  # Summary writer flush period, in seconds.

fc_layer_params = (100, 50)  # Passed to QNetwork as `fc_layer_params`.
gamma = 0.99  # Discount for future rewards.
reward_scale_factor = 1.0
gradient_clipping = None
debug_summaries = False
summarize_grads_and_vars = False

# Params for train
use_tf_functions = True  # Wrap the driver/train calls with common.function.
train_steps_per_iteration = 1

# Params for collect
epsilon_greedy = 0.1
replay_buffer_capacity = 100_000  # `max_length` of the replay buffer.

# Params for target update
target_update_tau = 0.05
target_update_period = 5

# Params for summaries and logging
summary_interval = 1_000  # Summaries are recorded when step % summary_interval == 0.
eval_metrics_callback = None  # Optional callable(results, step) run after each eval.

# Sampled trajectories contain `train_sequence_length + 1` steps.
train_sequence_length = 1

# Params for checkpoints
train_checkpoint_interval = 10_000
policy_checkpoint_interval = 5_000
rb_checkpoint_interval = 20_000

## Setup

In [None]:
# Resolve the output layout: <root_dir>/{train,eval,video}.
root_dir = os.path.expanduser(root_dir)
train_dir = os.path.join(root_dir, 'train')
eval_dir = os.path.join(root_dir, 'eval')
video_dir = os.path.join(root_dir, 'video')

# Create the video recording directory. exist_ok=True keeps this cell
# re-runnable: without it a second run (or resuming training from existing
# checkpoints) fails with FileExistsError.
os.makedirs(video_dir, exist_ok=True)

In [None]:
# Train Summary Writer
# Writes to train_dir; the flush period is converted from seconds to
# milliseconds. It is installed as the default writer for later summary calls.
train_summary_writer = tf.summary.create_file_writer(
    train_dir, flush_millis=summaries_flush_secs * 1000)
train_summary_writer.set_as_default()

In [None]:
# Eval Summary Writer
# Writes to eval_dir with the same flush period as the train writer. Unlike
# the train writer it is not made the default; it is passed explicitly to the
# evaluation calls.
eval_summary_writer = tf.summary.create_file_writer(
    eval_dir, flush_millis=summaries_flush_secs * 1000)
# Metrics computed during evaluation; buffer_size is set to the number of
# evaluation episodes.
eval_metrics = [
    tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
    tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
]

In [None]:
# Create global_step
# This variable is handed to the agent as its train_step_counter and drives
# every "step % interval == 0" check in the training loop.
global_step = tf.compat.v1.train.get_or_create_global_step()

### Video

In [None]:
def get_timestamp():
    """Return the current local time as a POSIX timestamp (float seconds)."""
    from datetime import datetime as _datetime

    now = _datetime.now()
    return now.timestamp()

In [None]:
def embed_mp4(filename):
    """Embed an MP4 file in the notebook as an inline HTML5 video player.

    Args:
        filename: Path of the MP4 file to read.

    Returns:
        An ``IPython.display.HTML`` object whose markup contains the whole
        file as a base64 ``data:`` URI.
    """
    # Read through a context manager so the file handle is closed right away
    # instead of being left open until garbage collection.
    with open(filename, 'rb') as video_file:
        video = video_file.read()
    b64 = base64.b64encode(video)

    tag = '''
    <video width="640" height="480" controls>
        <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>'''.format(b64.decode())

    return IPython.display.HTML(tag)

In [None]:
def enhance_frame(frame: np.ndarray, text=None) -> np.ndarray:
    """Draw a text caption onto a rendered frame.

    Args:
        frame: Image array accepted by ``PIL.Image.fromarray``.
        text: Caption to draw (callers pass multi-line strings). When
            ``None`` the frame is returned untouched.

    Returns:
        ``frame`` itself when ``text`` is ``None``; otherwise a new RGB array
        with the caption drawn at pixel position (30, 30).
    """
    if text is None:
        return frame

    # Convert array to PIL.Image
    image = PIL.Image.fromarray(frame).convert('RGB')

    # Get draw context
    draw = ImageDraw.Draw(image, 'RGB')

    # Get font. The DejaVu path only exists on some Linux installs; fall back
    # to Pillow's built-in font rather than aborting the recording when the
    # file cannot be opened.
    try:
        font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf', 20)
    except OSError:
        font = ImageFont.load_default()

    # Draw black text with a thin white outline.
    draw.text((30, 30), text, font=font, fill=(0, 0, 0), stroke_width=1, stroke_fill=(255, 255, 255))

    return np.array(image)

In [None]:
def _video_caption(env_name, step, episode, num_episodes, frame_idx):
    """Build the multi-line caption drawn onto a single video frame."""
    lines = [f'Env: {env_name}']
    if step is not None:
        lines.append(f'Stp: {step}')
    lines.append(f'Ep:  {episode}/{num_episodes}')
    lines.append(f'Frm: {frame_idx}')
    return '\n'.join(lines)


def create_policy_eval_video(policy, eval_env, eval_py_env, filename=None,
        num_episodes=3, fps=30, env_name=env_name, freeze_seconds=0,
        step=None):
    """Run a policy for some episodes and record the rendered frames to MP4.

    Args:
        policy: Policy whose ``action(time_step)`` selects each action.
        eval_env: Environment that is reset and stepped with the policy.
        eval_py_env: Environment whose ``render(mode='rgb_array')`` supplies
            the frames. Callers in this notebook pass the Python environment
            that ``eval_env`` wraps.
        filename: Output path WITHOUT extension; ``'.mp4'`` is appended.
            Defaults to the current timestamp.
        num_episodes: Number of episodes to record.
        fps: Frames per second of the written video.
        env_name: Environment name shown in the caption and the log.
        freeze_seconds: Seconds to hold the last frame of each episode.
            Fractional values are allowed; the number of repeated frames is
            ``int(fps * freeze_seconds)``.
        step: Optional training step shown in the caption.

    Returns:
        The path of the written video, including the ``.mp4`` suffix.
    """
    if filename is None:
        filename = str(get_timestamp())

    filename = filename + '.mp4'
    logging.info('Env: %s', env_name)
    logging.info('Filename: %s', filename)

    with imageio.get_writer(filename, fps=fps) as video:

        def capture(episode, frame_idx):
            # Render the current state, caption it and append it to the video.
            caption = _video_caption(
                env_name, step, episode, num_episodes, frame_idx)
            frame = enhance_frame(eval_py_env.render(mode='rgb_array'), caption)
            video.append_data(frame)
            return frame

        for idx in range(num_episodes):
            logging.info('Begin #%d of %d', idx+1, num_episodes)
            time_step = eval_env.reset()
            frame_idx = 0

            # Frame 0 shows the state right after the reset.
            capture(idx + 1, frame_idx)

            while not time_step.is_last():
                action_step = policy.action(time_step)
                time_step = eval_env.step(action_step.action)
                frame_idx += 1

                frame = capture(idx + 1, frame_idx)

                # Freeze frame for a few seconds. int() keeps range() valid
                # when fps * freeze_seconds is not a whole number.
                if time_step.is_last() and freeze_seconds > 0:
                    for _ in range(int(fps * freeze_seconds)):
                        video.append_data(frame)

    logging.info('All done')
    return filename

In [None]:
# Build the DQN pipeline (environments, Q-network, agent, replay buffer,
# collect driver, checkpointers) and, when train_or_retrain is set, run the
# collect/train loop with periodic logging, checkpointing, evaluation and
# video recording. Everything runs under a record_if condition that is true
# only when global_step is a multiple of summary_interval.
with tf.summary.record_if(lambda: tf.math.equal(global_step % summary_interval, 0)):
    # Create env
    # Two separate environment instances are loaded: one for collecting
    # experience and one for evaluation. The raw Python eval environment is
    # kept as well because the video helper calls render() on it.
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_py_env = suite_gym.load(env_name)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
    
    # Create Q network
    q_net = q_network.QNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)
    
    # Create Agent
    tf_agent = dqn_agent.DqnAgent(
        time_step_spec=tf_env.time_step_spec(),
        action_spec=tf_env.action_spec(),
        q_network=q_net,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        
        # Params for collect
        epsilon_greedy=epsilon_greedy,
        
        # Params for target network updates
        target_q_network=None,
        target_update_tau=target_update_tau,  # Default: 1.0, "Factor for soft update of the target network"
        target_update_period=target_update_period,  # Default: 1, "Period for soft update of the target network"
        
        # Params for training
        td_errors_loss_fn=common.element_wise_squared_loss,  # Default: common.element_wise_huber_loss
        gamma=gamma,  # Default: 1.0, Discount for future rewards.
        reward_scale_factor=reward_scale_factor,  # Default: 1.0
        gradient_clipping=gradient_clipping,  # Default: None, "Norm length to clip gradients"
        
        # Params for debugging
        # The agent shares global_step as its train step counter, so the
        # interval checks in the loop below follow the agent's train steps.
        train_step_counter=global_step,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        name=None,  # Default: class name. The agent name.
    )
    tf_agent.initialize()
    
    # Train Metrics
    # Order matters: the first two entries are later passed as step_metrics
    # (train_metrics[:2]).
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]
    
    # Policies
    # eval_policy is used for evaluation, video recording and the policy
    # checkpoint; collect_policy drives experience collection.
    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy
    
    # Replay Buffer
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)
    
    # Collect Driver
    # Each run() takes collect_steps_per_iteration steps in tf_env; the
    # observers add the experience to the replay buffer and update the
    # train metrics.
    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)
    
    # Checkpointers
    # Three independent checkpoints: agent + train metrics, the eval policy,
    # and the replay buffer (only the most recent one is kept).
    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    
    # NOTE(review): going by the method name, this presumably resumes from an
    # existing checkpoint in each directory when one is present -- confirm
    # against the TF-Agents Checkpointer documentation.
    train_checkpointer.initialize_or_restore()
    policy_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()
    
    if train_or_retrain:

        # Speed up with common.function
        if use_tf_functions:
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay buffer data.

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
        )

        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)

        # One-off driver: same environment and observers as collect_driver,
        # but acting with the random policy for initial_collect_steps steps.
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        # Evaluate once before any training iteration of this run.
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        # Collection state carried from one collect_driver.run() call to the
        # next; time_step starts as None for the first call.
        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        # Variables for logging time (steps_per_sec)
        timed_at_step = global_step.numpy()
        time_acc = 0  # Time accumulation

        # Dataset
        # Samples batches of batch_size trajectories, each spanning
        # train_sequence_length + 1 consecutive steps.
        # NOTE(review): the extra step presumably supplies the next
        # observation needed for the TD target -- confirm.
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1,
            single_deterministic_pass=False,
        ).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            # The second element yielded by the dataset is not used.
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            # Only collection and training are timed; logging, checkpointing,
            # evaluation and video recording below are excluded from time_acc.
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )

            # The inner loop reuses `_`; neither loop variable is read.
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                # Logs the loss of the most recent train step only.
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec', data=steps_per_sec, step=global_step)

                # Reset time.
                timed_at_step = global_step.numpy()
                time_acc = 0

            # Write summaries for every train metric, using NumberOfEpisodes
            # and EnvironmentSteps (train_metrics[:2]) as step metrics.
            for train_metric in train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2])

            # Periodic checkpoints, each on its own interval.
            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())
            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())
            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            # Periodic evaluation with eval_policy on the evaluation env.
            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

            # Record a video of current eval agent policy
            # The file is written to video_dir as "<step>_<timestamp>.mp4"
            # (the helper appends the extension).
            if global_step.numpy() % video_recording_interval == 0:
                filename = '{}_{}'.format(global_step.numpy(), get_timestamp())
                full_filename = os.path.join(video_dir, filename)
                create_policy_eval_video(
                    eval_policy,
                    eval_tf_env,
                    eval_py_env,
                    filename=full_filename,
                    fps=15,
                    freeze_seconds=3,
                    num_episodes=1,
                    step=global_step.numpy(),
                )

In [None]:
# Record a final video of the agent's policy at 15 fps, holding the last frame
# of each episode for 3 seconds. No filename is given, so the helper names the
# file after the current timestamp and uses its default number of episodes.
saved_filename = create_policy_eval_video(tf_agent.policy, eval_tf_env, eval_py_env, fps=15, freeze_seconds=3)

In [None]:
# Build the inline HTML video player for the file recorded above; as the last
# expression of the cell, the returned HTML object is shown as its output.
embed_mp4(saved_filename)