In [1]:
from mlrl.maze_env import make_maze_env
from mlrl.maze_state import MazeState
from mlrl.search_tree import SearchTree
from mlrl.meta_env import MetaEnv
from mlrl.manhattan_q import ManhattanQHat

object_env = make_maze_env(seed=0)
q_hat = ManhattanQHat(object_env)

def make_maze_search_tree(env) -> SearchTree:
    return SearchTree(env, extract_state=MazeState.extract_state)

meta_env = MetaEnv(object_env, q_hat, make_maze_search_tree)

In [2]:
import tensorflow as tf
import tf_agents



In [3]:
num_iterations = 20000

initial_collect_steps = 100 
collect_steps_per_iteration =   1
replay_buffer_max_length = 100000

batch_size = 64 
learning_rate = 1e-3
log_interval = 200

num_eval_episodes = 10
eval_interval = 1000 

num_actions = meta_env.action_space.n

In [4]:
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import sequential
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.trajectories import trajectory
from tf_agents.specs import tensor_spec
from tf_agents.utils import common

In [5]:
env = tf_py_environment.TFPyEnvironment(tf_agents.environments.gym_wrapper.GymWrapper(meta_env))

eval_object_env = make_maze_env(seed=0)
eval_meta_env = MetaEnv(eval_object_env, q_hat, make_maze_search_tree)
eval_env = tf_py_environment.TFPyEnvironment(tf_agents.environments.gym_wrapper.GymWrapper(eval_meta_env))

env.reset()
eval_env.reset()

TimeStep(
{'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': OrderedDict([('search_tree_tokens',
                              <tf.Tensor: shape=(1, 10, 5), dtype=float32, numpy=
array([[[-1., -1.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]]], dtype=float32)>),
                             ('valid_action_mask',
                              <tf.Tensor: shape=(1, 41), dtype=int32, numpy=
array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
      dtype=int32)>)]),
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'step_type': <tf.Ten

In [6]:
def mask_invalid_action_constraint_splitter(observation):
    return observation['search_tree_tokens'], observation['valid_action_mask']

In [7]:
fc_layer_params = (100, 50)
action_tensor_spec = tensor_spec.from_spec(env.action_spec())
num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

# Define a helper function to create Dense layers configured with the right
# activation and kernel initializer.
def dense_layer(num_units):
  return tf.keras.layers.Dense(
      num_units,
      activation=tf.keras.activations.relu,
      kernel_initializer=tf.keras.initializers.VarianceScaling(
          scale=2.0, mode='fan_in', distribution='truncated_normal'))

# QNetwork consists of a sequence of Dense layers followed by a dense layer
# with `num_actions` units to generate one q_value per available action as
# its output.
dense_layers = [tf.keras.layers.Flatten()] + [dense_layer(num_units) for num_units in fc_layer_params]
q_values_layer = tf.keras.layers.Dense(
    num_actions,
    activation=None,
    kernel_initializer=tf.keras.initializers.RandomUniform(
        minval=-0.03, maxval=0.03),
    bias_initializer=tf.keras.initializers.Constant(-0.2))
q_net = sequential.Sequential(dense_layers + [q_values_layer])

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    env.time_step_spec(),
    env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    observation_and_action_constraint_splitter=mask_invalid_action_constraint_splitter,
    train_step_counter=train_step_counter
)

agent.initialize()

In [8]:
random_policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
                                                env.action_spec(),
                                                observation_and_action_constraint_splitter=mask_invalid_action_constraint_splitter)

In [9]:
import time
from IPython.display import clear_output
from mlrl.utils.plot_search_tree import plot_tree
from mlrl.maze_env import render_maze

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

import io
import numpy as np


def plot_to_array(fig):
    io_buf = io.BytesIO()
    fig.savefig(io_buf, format='raw')
    io_buf.seek(0)
    img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
                         newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
    io_buf.close()
    return img_arr


def plot_meta_env(meta_env: MetaEnv):
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    
    root_state = meta_env.tree.get_root().get_state()
    root_state.set_environment_to_state(meta_env.object_env)
    render_maze(meta_env.object_env, ax=axs[0], title='Maze', show=False)
    plot_tree(meta_env.tree, ax=axs[1], show=False)

    q_dist = eval_meta_env.root_q_distribution()

    sns.barplot(x=list(range(q_dist.shape[0])), y=q_dist, ax=axs[2])
    axs[2].set_ylim([0, 1])
    axs[2].set_title('Root Q-Distribution')
    axs[2].set_xticklabels(object_env.ACTION)
    axs[2].yaxis.set_label_position("right")
    axs[2].yaxis.tick_right()
    plt.tight_layout(rect=[0, 0.03, 1, .9])
    
    return fig


def watch_policy(policy: tf_agents.policies.tf_policy.TFPolicy, max_steps=100):

    eval_env.reset()
    meta_actions = eval_meta_env.get_action_strings()

    for step in range(max_steps):
        try:
            action_step = policy.action(eval_env.current_time_step())
            time_step = eval_env.step(action_step.action)

            plot_meta_env(eval_meta_env)
            action_string = meta_actions[action_step.action.numpy()[0]]
            plt.suptitle(f'{action_string} | Reward: {time_step.reward.numpy()[0]:.3f}')
            plt.show()

            time.sleep(1.5)
            clear_output(wait=True)
        except KeyboardInterrupt:
            break


from IPython.display import HTML
import base64
import imageio


def embed_mp4(filename):
    """Embeds an mp4 file in the notebook."""
    video = open(filename,'rb').read()
    b64 = base64.b64encode(video)
    tag = '''
    <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>'''.format(b64.decode())

    return HTML(tag)


def create_policy_eval_video(policy, filename='video', max_steps=60, fps=1):
    filename = filename + ".mp4"
    meta_actions = eval_meta_env.get_action_strings()
    eval_env.reset()

    with imageio.get_writer(filename, fps=fps) as video:
        fig = plot_meta_env(eval_meta_env)
        plt.suptitle(f'Intial State')
        video.append_data(plot_to_array(fig))

        for step in range(max_steps):
            prior_best_object_action = eval_meta_env.get_best_object_action()
            action_step = policy.action(eval_env.current_time_step())
            time_step = eval_env.step(action_step.action)
            curr_best_object_action = eval_meta_env.get_best_object_action()

            action_string = meta_actions[action_step.action.numpy()[0]]
            fig = plot_meta_env(eval_meta_env)

            if eval_meta_env.tree.get_num_nodes() != 1:
                computational_reward = eval_meta_env.get_computational_reward(prior_best_object_action)
            else:
                computational_reward = 0
            
            plt.suptitle(f'Meta-action: [{action_string}] | '
                         f'Meta-Reward: {time_step.reward.numpy()[0]:.3f} | '
                         f'Best Object-action: {object_env.ACTION[curr_best_object_action]} | '
                         f'Computational-Reward: {computational_reward:.3f}')

            video.append_data(plot_to_array(fig))
            
            plt.show()
            clear_output(wait=True)
    return embed_mp4(filename)

In [19]:
create_policy_eval_video(random_policy)

In [16]:
create_policy_eval_video(random_policy)

In [22]:
def compute_avg_return(environment, policy, num_episodes=5, max_steps=500):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    n_steps = 0
    while not time_step.is_last() and n_steps < max_steps:
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
      n_steps += 1
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]

In [23]:
compute_avg_return(env, random_policy)

-0.50874484

In [24]:
compute_avg_return(env, agent.policy)

0.009036064

In [25]:
from tf_agents.replay_buffers import tf_uniform_replay_buffer

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=env.batch_size,
    max_length=replay_buffer_max_length)


def collect_step(environment, policy, buffer):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer
  buffer.add_batch(traj)

def collect_data(env, policy, buffer, steps):
  for _ in range(steps):
    collect_step(env, policy, buffer)

collect_data(env, random_policy, replay_buffer, initial_collect_steps)

In [26]:
iter(replay_buffer.as_dataset()).next()

Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


(Trajectory(
 {'action': <tf.Tensor: shape=(), dtype=int64, numpy=4>,
  'discount': <tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
  'next_step_type': <tf.Tensor: shape=(), dtype=int32, numpy=1>,
  'observation': {'search_tree_tokens': <tf.Tensor: shape=(10, 5), dtype=float32, numpy=
 array([[-1.   , -1.   ,  0.   ,  0.   ,  0.   ],
        [ 0.   ,  2.   , -0.004,  1.   ,  0.   ],
        [ 0.   ,  1.   , -0.004,  0.   ,  0.   ],
        [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
        [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
        [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
        [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
        [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
        [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
        [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ]], dtype=float32)>,
                  'valid_action_mask': <tf.Tensor: shape=(41,), dtype=int32, numpy=
 array([1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [27]:
# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=batch_size,
    num_steps=2).prefetch(3)

iterator = iter(dataset)

dataset

<PrefetchDataset element_spec=(Trajectory(
{'action': TensorSpec(shape=(64, 2), dtype=tf.int64, name=None),
 'discount': TensorSpec(shape=(64, 2), dtype=tf.float32, name=None),
 'next_step_type': TensorSpec(shape=(64, 2), dtype=tf.int32, name=None),
 'observation': {'search_tree_tokens': TensorSpec(shape=(64, 2, 10, 5), dtype=tf.float32, name=None),
                 'valid_action_mask': TensorSpec(shape=(64, 2, 41), dtype=tf.int32, name=None)},
 'policy_info': (),
 'reward': TensorSpec(shape=(64, 2), dtype=tf.float32, name=None),
 'step_type': TensorSpec(shape=(64, 2), dtype=tf.int32, name=None)}), BufferInfo(ids=TensorSpec(shape=(64, 2), dtype=tf.int64, name=None), probabilities=TensorSpec(shape=(64,), dtype=tf.float32, name=None)))>

In [28]:
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few steps using collect_policy and save to the replay buffer.
  collect_data(env, agent.collect_policy, replay_buffer, collect_steps_per_iteration)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience).loss

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)

Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
step = 200: loss = 1568992.125
step = 400: loss = 155707520.0
step = 600: loss = 1194554368.0
step = 800: loss = 2218108416.0
step = 1000: loss = 19108038656.0
step = 1000: Average Return = -0.49436813592910767
step = 1200: loss = 35338883072.0
step = 1400: loss = 52723650560.0
step = 1600: loss = 24496648192.0
step = 1800: loss = 64770433024.0
step = 2000: loss = 98337308672.0
step = 2000: Average Return = -0.49436813592910767
step = 2200: loss = 91838799872.0
step = 2400: loss = 298557374464.0
step = 2600: loss = 100495605760.0
step = 2800: loss = 166111510528.0
step = 3000: loss = 292375101440.0
step = 3000: Average Return = -0.49436813592910767
step = 3200: loss = 789839609856.0
step = 3400: loss = 277328363520.0
step = 3600: loss = 338292310016.0
step

In [33]:
create_policy_eval_video(agent.policy, max_steps=100)