In [None]:
# Prompt for a file upload into the Colab runtime.
# Presumably the uploaded file is gym-foo.zip, which a later cell unzips -- TODO confirm.
# `uploaded` is not referenced again in the visible cells.
from google.colab import files
uploaded = files.upload()

In [None]:
!unzip gym-foo.zip
!pip install -e /content/gym-foo

In [None]:
# Install TF-Agents into the runtime.
# NOTE(review): no version is pinned here, so the installed release can differ
# between runs -- consider pinning the version this notebook was developed with.
!pip install tf_agents

In [None]:
# Mount Google Drive at /content/gdrive. The default value of the `root_dir`
# flag defined further down points to a folder under this mount.
from google.colab import drive
drive.mount("/content/gdrive")

In [None]:
import imageio
from absl import logging
from tf_agents.environments.py_environment import PyEnvironment
from tf_agents.environments.tf_environment import TFEnvironment
from tf_agents.policies import tf_policy
from tf_agents.trajectories.policy_step import PolicyStep


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os
import time

from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
from tf_agents.agents.ppo import ppo_clip_agent
from tf_agents.drivers import dynamic_episode_driver
from tf_agents.environments import parallel_py_environment
from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.networks import actor_distribution_rnn_network
from tf_agents.networks import value_network
from tf_agents.networks import value_rnn_network
from tf_agents.policies import policy_saver
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.system import system_multiprocessing as multiprocessing
from tf_agents.utils import common

import gym_foo
from tf_agents.environments import suite_gym

In [None]:
def create_video(
    py_environment: PyEnvironment,
    tf_environment: TFEnvironment,
    policy,
    num_episodes=1,
    video_filename='imageio.mp4',
    fps=1,
):
  """Roll out `policy` and write the rendered frames to a video file.

  Args:
    py_environment: Environment whose `render()` output supplies the frames.
      NOTE(review): the frames only follow the rollout if this is the same
      environment instance that `tf_environment` wraps -- confirm at call sites.
    tf_environment: Environment that is reset once per episode and stepped
      with the actions chosen by `policy`.
    policy: Object providing `get_initial_state(batch_size)` and
      `action(time_step, policy_state)`.
    num_episodes: Number of episodes to record back to back into one file.
    video_filename: Output path handed to `imageio.get_writer`.
    fps: Frame rate handed to `imageio.get_writer`.
  """
  logging.info("Generating video %s", video_filename)
  with imageio.get_writer(video_filename, fps=fps) as video:
    for episode in range(num_episodes):
      logging.info("Generating episode %d of %d", episode + 1, num_episodes)

      time_step = tf_environment.reset()
      policy_state = policy.get_initial_state(tf_environment.batch_size)

      # First frame: the environment right after reset, before any action.
      video.append_data(py_environment.render())

      while not time_step.is_last():
        policy_step: PolicyStep = policy.action(time_step, policy_state)
        # Carry the policy state into the next call. Without this a recurrent
        # policy would be queried from its initial state on every step.
        policy_state = policy_step.state

        time_step = tf_environment.step(policy_step.action)
        video.append_data(py_environment.render())

  logging.info("Finished video %s", video_filename)

In [None]:
def env_load_fn(env_name):
  """Load the Gym environment `env_name` via `suite_gym` and reset it once.

  Args:
    env_name: Environment name passed to `suite_gym.load`.

  Returns:
    The loaded environment, after one `reset()` call.
  """
  environment = suite_gym.load(env_name)
  environment.reset()
  return environment

In [None]:

# Command-line flags read by `main`.
# The one-letter 'f' flag presumably absorbs the `-f` argument a notebook
# kernel is started with -- TODO confirm.
flags.DEFINE_string(name='f', default='', help='kernel')
flags.DEFINE_string(
    name='root_dir',
    default='/content/gdrive/My Drive/DATA691_RL_logs10',
    help='Root directory for writing logs/summaries/checkpoints.',
)
flags.DEFINE_string(
    name='env_name', default='foo-v0', help='Name of an environment'
)

# Collection settings.
flags.DEFINE_integer(
    name='replay_buffer_capacity',
    default=1001,
    help='Replay buffer capacity per env.',
)
flags.DEFINE_integer(
    name='num_parallel_environments',
    default=256,
    help='Number of environments to run in parallel',
)
flags.DEFINE_integer(
    name='num_environment_steps',
    default=25000000,
    help='Number of environment steps to run before finishing.',
)

# Training and evaluation settings.
flags.DEFINE_integer(
    name='num_epochs',
    default=25,
    help='Number of epochs for computing policy updates.',
)
flags.DEFINE_integer(
    name='collect_episodes_per_iteration',
    default=30,
    help=(
        'The number of episodes to take in the environment before '
        'each update. This is the total across all parallel '
        'environments.'
    ),
)
flags.DEFINE_integer(
    name='num_eval_episodes',
    default=30,
    help='The number of episodes to run eval on.',
)
flags.DEFINE_boolean(
    name='use_rnns',
    default=True,
    help='If true, use RNN for policy and value function.',
)

FLAGS = flags.FLAGS


@gin.configurable
def train_eval(
    root_dir,
    env_name='foo-v0',
    env_load_fn=env_load_fn,
    random_seed=None,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    lstm_size=(64,),
    # Params for collect
    num_environment_steps=25000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-3,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=100,
    # Params for summaries and logging
    train_checkpoint_interval=1000,
    policy_checkpoint_interval=1000,
    log_interval=100,
    summary_interval=100,
    summaries_flush_secs=1,
    use_tf_functions=True,
    debug_summaries=False,
    summarize_grads_and_vars=False,
):
  """Train a PPO-clip agent on `env_name` with periodic eval, checkpoints and videos.

  Each iteration collects episodes from the parallel environments into a
  replay buffer, runs one `tf_agent.train` call on everything gathered, and
  clears the buffer. All environments created here are closed before the
  function returns, including when it exits through an exception.

  Args:
    root_dir: Base output directory. `train`, `eval`, `policy_saved_model`
      and `videos` subdirectories are placed under it.
    env_name: Environment name handed to `env_load_fn`.
    env_load_fn: Callable mapping an environment name to a Python environment.
    random_seed: If not None, passed to `tf.compat.v1.set_random_seed`.
    actor_fc_layers: Fully connected layer sizes of the actor network (used as
      the input FC layers when `use_rnns` is true).
    value_fc_layers: Fully connected layer sizes of the value network (used as
      the input FC layers when `use_rnns` is true).
    use_rnns: Build the RNN variants of the actor and value networks.
    lstm_size: LSTM size of the actor RNN network. It is not forwarded to the
      value RNN network.
    num_environment_steps: Training stops once the environment-steps metric
      reaches this value.
    collect_episodes_per_iteration: `num_episodes` of the collect driver.
    num_parallel_environments: Number of environment copies used for
      collection; also the batch size of the replay buffer and train metrics.
    replay_buffer_capacity: `max_length` of the replay buffer.
    num_epochs: Forwarded to the PPO agent.
    learning_rate: Learning rate of the Adam optimizer.
    num_eval_episodes: Episodes per evaluation; also the eval metric buffer
      size.
    eval_interval: Evaluate when the global step read at the start of an
      iteration is a multiple of this value (this includes step 0).
    train_checkpoint_interval: Global-step interval for training checkpoints.
      Only checked on iterations that also match `log_interval`.
    policy_checkpoint_interval: Global-step interval for recording a video and
      saving the policy checkpoint and SavedModel. Only checked on iterations
      that also match `log_interval`.
    log_interval: Global-step interval for logging loss and throughput.
    summary_interval: Summaries are recorded when the global step is a
      multiple of this value.
    summaries_flush_secs: Flush period of the summary writers, in seconds.
    use_tf_functions: Wrap collection and training in `common.function`.
    debug_summaries: Forwarded to the PPO agent.
    summarize_grads_and_vars: Forwarded to the PPO agent.

  Raises:
    AttributeError: If `root_dir` is None.
  """
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')
  videos_dir = os.path.join(root_dir, 'videos')
  os.makedirs(videos_dir, exist_ok=True)

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000
  )
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000
  )
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)
  ):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)
    eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments
        )
    )

    # Everything below runs under try/finally so both environments are closed
    # when training finishes, fails, or is interrupted.
    try:
      optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

      if use_rnns:
        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            tf_env.observation_spec(),
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            output_fc_layer_params=None,
            lstm_size=lstm_size,
        )
        # NOTE(review): lstm_size is not forwarded here, so the value network
        # keeps the library's default LSTM size -- confirm this is intended.
        value_net = value_rnn_network.ValueRnnNetwork(
            tf_env.observation_spec(),
            input_fc_layer_params=value_fc_layers,
            output_fc_layer_params=None,
        )
      else:
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_env.observation_spec(),
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
            activation_fn=tf.keras.activations.tanh,
        )
        value_net = value_network.ValueNetwork(
            tf_env.observation_spec(),
            fc_layer_params=value_fc_layers,
            activation_fn=tf.keras.activations.tanh,
        )

      tf_agent = ppo_clip_agent.PPOClipAgent(
          tf_env.time_step_spec(),
          tf_env.action_spec(),
          optimizer,
          actor_net=actor_net,
          value_net=value_net,
          entropy_regularization=0.001,
          importance_ratio_clipping=0.7,
          normalize_observations=False,
          normalize_rewards=False,
          use_gae=True,
          num_epochs=num_epochs,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
      )
      tf_agent.initialize()

      environment_steps_metric = tf_metrics.EnvironmentSteps()
      step_metrics = [
          tf_metrics.NumberOfEpisodes(),
          environment_steps_metric,
      ]

      train_metrics = step_metrics + [
          tf_metrics.AverageReturnMetric(batch_size=num_parallel_environments),
          tf_metrics.AverageEpisodeLengthMetric(
              batch_size=num_parallel_environments
          ),
      ]

      eval_policy = tf_agent.policy
      collect_policy = tf_agent.collect_policy

      replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          tf_agent.collect_data_spec,
          batch_size=num_parallel_environments,
          max_length=replay_buffer_capacity,
      )

      train_checkpointer = common.Checkpointer(
          ckpt_dir=train_dir,
          agent=tf_agent,
          global_step=global_step,
          metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
      )
      policy_checkpointer = common.Checkpointer(
          ckpt_dir=os.path.join(train_dir, 'policy'),
          policy=eval_policy,
          global_step=global_step,
      )
      saved_model = policy_saver.PolicySaver(
          eval_policy, train_step=global_step
      )

      train_checkpointer.initialize_or_restore()

      collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
          tf_env,
          collect_policy,
          observers=[replay_buffer.add_batch] + train_metrics,
          num_episodes=collect_episodes_per_iteration,
      )

      def train_step():
        """Train the agent once on everything currently in the replay buffer."""
        trajectories = replay_buffer.gather_all()
        return tf_agent.train(experience=trajectories)

      if use_tf_functions:
        # TODO(b/123828980): Enable once the cause for slowdown was identified.
        collect_driver.run = common.function(
            collect_driver.run, autograph=False
        )
        tf_agent.train = common.function(tf_agent.train, autograph=False)
        train_step = common.function(train_step)

      collect_time = 0
      train_time = 0
      timed_at_step = global_step.numpy()

      while environment_steps_metric.result() < num_environment_steps:
        # Read once per iteration, before training. All interval checks and
        # checkpoint labels below use this pre-training value.
        global_step_val = global_step.numpy()
        if global_step_val % eval_interval == 0:
          metric_utils.eager_compute(
              eval_metrics,
              eval_tf_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=global_step,
              summary_writer=eval_summary_writer,
              summary_prefix='Metrics',
          )

        start_time = time.time()
        collect_driver.run()
        collect_time += time.time() - start_time

        start_time = time.time()
        total_loss, _ = train_step()
        # The collected data is used for this one training call only.
        replay_buffer.clear()
        train_time += time.time() - start_time

        for train_metric in train_metrics:
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=step_metrics
          )

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = (global_step_val - timed_at_step) / (
              collect_time + train_time
          )
          logging.info('%.3f steps/sec', steps_per_sec)
          logging.info(
              'collect_time = %.3f, train_time = %.3f',
              collect_time,
              train_time,
          )
          with tf.compat.v2.summary.record_if(True):
            tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_sec,
                step=global_step,
            )

          if global_step_val % train_checkpoint_interval == 0:
            train_checkpointer.save(global_step=global_step_val)

          if global_step_val % policy_checkpoint_interval == 0:
            # Record one episode in a dedicated environment and close that
            # environment afterwards, so one is not leaked per checkpoint.
            video_py_env = env_load_fn(env_name)
            try:
              video_tf_env = tf_py_environment.TFPyEnvironment(video_py_env)
              create_video(
                  video_py_env,
                  video_tf_env,
                  tf_agent.policy,
                  num_episodes=1,
                  video_filename=os.path.join(
                      videos_dir, "video_%d.mp4" % global_step_val
                  ),
              )
            finally:
              video_py_env.close()

            policy_checkpointer.save(global_step=global_step_val)
            saved_model_path = os.path.join(
                saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9)
            )
            saved_model.save(saved_model_path)

          # Restart the throughput measurement window.
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

      # One final eval before exiting.
      metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='Metrics',
      )
    finally:
      # Close the wrapped Python environments so the parallel collection
      # environment and the eval environment are released on every exit path.
      tf_env.pyenv.close()
      eval_tf_env.pyenv.close()



In [None]:
def main(_):
  """Entry point handed to `app.run`: configure logging/TF and start training.

  Sets absl logging verbosity to INFO, enables TF v2 behavior, and calls
  `train_eval` with the values of the flags defined in this notebook.

  Args:
    _: Unused positional argument.
  """
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_v2_behavior()

  # Flag values forwarded to train_eval as keyword arguments.
  flag_overrides = dict(
      env_name=FLAGS.env_name,
      use_rnns=FLAGS.use_rnns,
      num_environment_steps=FLAGS.num_environment_steps,
      collect_episodes_per_iteration=FLAGS.collect_episodes_per_iteration,
      num_parallel_environments=FLAGS.num_parallel_environments,
      replay_buffer_capacity=FLAGS.replay_buffer_capacity,
      num_epochs=FLAGS.num_epochs,
      num_eval_episodes=FLAGS.num_eval_episodes,
  )
  train_eval(FLAGS.root_dir, **flag_overrides)


if __name__ == '__main__':
  # Mark root_dir as required, then hand `app.run(main)` to the TF-Agents
  # multiprocessing wrapper.
  flags.mark_flag_as_required('root_dir')
  run_app = functools.partial(app.run, main)
  multiprocessing.handle_main(run_app)