In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
import os

# Location of the STK scenario file to load. Set the STK_SCENARIO_PATH
# environment variable to point at another copy of the scenario; the original
# author's local path is kept as the fallback so existing setups keep working.
scenario_path = os.environ.get(
    "STK_SCENARIO_PATH",
    "C:\\Users\\dkolano\\OneDrive - Agile Space Industries\\Documents\\STK 12\\test_astrogator_collision\\test_astrogator_collision.sc")
# Both flags are forwarded unchanged to Simulation.simulation_from_file as its
# `visible` and `userControl` keyword arguments.
visible = True
userControl = True


In [3]:
import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import sequential
from tf_agents.policies import py_tf_eager_policy
from tf_agents.trajectories import trajectory
from tf_agents.specs import tensor_spec
from tf_agents.utils import common
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.replay_buffers import py_uniform_replay_buffer
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import dynamic_step_driver
from tf_agents.drivers import dynamic_episode_driver
from tf_agents.policies import policy_saver

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.metrics import py_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_py_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.train import actor
from tf_agents.train import learner
from tf_agents.train import triggers
from tf_agents.train.utils import spec_utils
from tf_agents.train.utils import strategy_utils
from tf_agents.train.utils import train_utils


import os, tempfile, shutil
import matplotlib.pyplot as plt

# System temp directory; used in later cells as the root for saved policies,
# the learner's output and the training checkpoint.
tempdir = tempfile.gettempdir()


from cs238_final_project.collision_avoidance.environment import Environment
from cs238_final_project.simulation.simulation import Simulation


In [4]:
# Number of training iterations. The note inherited with this cell suggested
# 1e6 iterations for better results (about 2 hrs) and 1e5 as a shorter run
# (about 1 hr); the value actually used here is 55000.
num_iterations = 55000  # @param {type:"integer"}

# Environment steps gathered before training starts and per training
# iteration, and the maximum length of the replay buffer.
initial_collect_steps = 10  # @param {type:"integer"}
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_capacity = 1000000  # @param {type:"integer"}

# Number of trajectories sampled from the replay buffer per training batch.
batch_size = 64  # @param {type:"integer"}

# Adam learning rates and the remaining scalar settings passed to SacAgent.
critic_learning_rate = 4e-4  # @param {type:"number"}
actor_learning_rate = 4e-4  # @param {type:"number"}
alpha_learning_rate = 4e-4  # @param {type:"number"}
target_update_tau = 0.005  # @param {type:"number"}
target_update_period = 1  # @param {type:"number"}
gamma = 1.0  # @param {type:"number"}
reward_scale_factor = 1.0  # @param {type:"number"}

# Fully connected layer sizes for the actor network and for the critic's
# joint (observation + action) layers.
actor_fc_layer_params = (32, 16)
critic_joint_fc_layer_params = (32, 16)

# Training-loop cadence: print/record the loss every `log_interval` steps.
log_interval = 200  # @param {type:"integer"}

# Run `num_eval_episodes` evaluation episodes every `eval_interval` steps.
num_eval_episodes = 8  # @param {type:"integer"}
eval_interval = 2000  # @param {type:"integer"}

# Policies are saved at the same cadence as evaluation.
policy_save_interval = eval_interval  # @param {type:"integer"}


In [None]:
# Build the project's Simulation object from the scenario file, with
# use_stk_engine=False and the visibility / user-control flags set in the
# configuration cell.
sim = Simulation.simulation_from_file(
    scenario_path, use_stk_engine=False,
    visible=visible, userControl=userControl)


In [6]:
# Wrap the simulation in the project's Environment (continuous=True), then in
# a TFPyEnvironment so it can be used with the TF drivers and agent below.
train_py_env = Environment(sim, continuous=True)
collect_env = tf_py_environment.TFPyEnvironment(train_py_env)
# Alternative kept for reference: use the Python environment directly.
# collect_env = train_py_env

In [7]:
# Specs describing what the environment emits and accepts; reused by the
# actor network and the agent in later cells.
observation_spec = collect_env.observation_spec()
action_spec = collect_env.action_spec()

# Critic network over (observation, action) pairs. No separate observation or
# action layers are configured; only the joint fully connected layers are.
critic_net = critic_network.CriticNetwork(
    (observation_spec, action_spec),
    observation_fc_layer_params=None,
    action_fc_layer_params=None,
    joint_fc_layer_params=critic_joint_fc_layer_params,
    kernel_initializer='glorot_uniform',
    last_kernel_initializer='glorot_uniform')


In [8]:
# Actor network mapping observations to an action distribution; continuous
# actions are produced through a TanhNormalProjectionNetwork.
actor_net = actor_distribution_network.ActorDistributionNetwork(
    observation_spec,
    action_spec,
    fc_layer_params=actor_fc_layer_params,
    continuous_projection_net=(
        tanh_normal_projection_network.TanhNormalProjectionNetwork))


In [9]:

# Global step variable; also handed to the policy-saving triggers, the learner
# and the checkpointer in later cells.
train_step = tf.compat.v1.train.get_or_create_global_step()

# SAC agent built from actor_net and critic_net, with one Adam optimizer each
# for the actor, the critic and alpha, all configured from the
# hyperparameter cell.
tf_agent = sac_agent.SacAgent(
    collect_env.time_step_spec(),
    action_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=tf.keras.optimizers.Adam(
        learning_rate=actor_learning_rate),
    critic_optimizer=tf.keras.optimizers.Adam(
        learning_rate=critic_learning_rate),
    alpha_optimizer=tf.keras.optimizers.Adam(
        learning_rate=alpha_learning_rate),
    target_update_tau=target_update_tau,
    target_update_period=target_update_period,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=gamma,
    reward_scale_factor=reward_scale_factor,
    train_step_counter=train_step)

tf_agent.initialize()


In [10]:
# Evaluation uses the agent's policy wrapped in GreedyPolicy; data collection
# uses the agent's collect policy.
eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
collect_policy = tf_agent.collect_policy

In [11]:
# Uniform replay buffer sized for the agent's collect data spec and the
# environment's batch size, holding at most `replay_buffer_capacity` entries.
data_spec = tf_agent.collect_data_spec
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec, batch_size=collect_env.batch_size, max_length=replay_buffer_capacity)



In [None]:
# Metrics updated while collecting training data.
num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()
ave_return = tf_metrics.AverageReturnMetric()

# A separate set of the same metrics, updated only during evaluation.
num_episodes_eval = tf_metrics.NumberOfEpisodes()
env_steps_eval = tf_metrics.EnvironmentSteps()
ave_return_eval = tf_metrics.AverageReturnMetric()

# Training observers also write every collected step into the replay buffer;
# evaluation observers only update the evaluation metrics.
observers = [replay_buffer.add_batch, num_episodes,
             env_steps, ave_return]
observers_eval = [num_episodes_eval, env_steps_eval,
                  ave_return_eval]

def reset_eval_metrics():
    """Call ``reset()`` on every observer in ``observers_eval``, in order."""
    for eval_metric in observers_eval:
        eval_metric.reset()


# Dataset generates trajectories with shape [Bx2x...]: batches of
# `batch_size` samples, each made of 2 consecutive steps from the buffer.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3)

# Iterator over the sampled batches.
iterator = iter(dataset)


In [None]:
# Seed the replay buffer before training: run the collect policy for
# `initial_collect_steps` environment steps, recording them in the buffer only
# (the training metrics are not attached to this driver).
initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
    collect_env,
    collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=initial_collect_steps)
initial_collect_driver.run()


In [14]:
# Driver used inside the training loop: advances the environment by
# `collect_steps_per_iteration` steps per run, feeding the replay buffer and
# the training metrics.
collect_driver = dynamic_step_driver.DynamicStepDriver(
    collect_env,
    collect_policy,
    observers=observers,
    num_steps=collect_steps_per_iteration)

# Driver used for evaluation: runs `num_eval_episodes` full episodes with the
# greedy policy and updates only the evaluation metrics.
# NOTE(review): this driver is given the same `collect_env` instance as the
# collect driver, so evaluation episodes are played on the environment that
# training collection also steps. Confirm this sharing is intended.
eval_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    collect_env,
    eval_policy,
    observers=observers_eval,
    num_episodes=num_eval_episodes)


In [15]:
# Remove saved policies and checkpoint folders left under the temp directory
# by earlier runs; folders that do not exist are ignored.
# NOTE(review): the Learner created later is given `tempdir` itself as its
# root directory. Confirm whether it writes to subfolders that are not removed
# here, which could carry state over between runs.
shutil.rmtree(os.path.join(
    tempdir, learner.POLICY_SAVED_MODEL_DIR), ignore_errors=True)
shutil.rmtree(os.path.join(tempdir, 'checkpoint'), ignore_errors=True)
shutil.rmtree(os.path.join(tempdir, 'checkpoints'), ignore_errors=True)


In [None]:
# Folder under the temp directory where the saved policies are written.
saved_model_dir = os.path.join(tempdir, learner.POLICY_SAVED_MODEL_DIR)

# Triggers handed to the learner: one saves the agent's policy every
# `policy_save_interval` train steps, the other logs steps per second every
# 1000 train steps.
learning_triggers = [
    triggers.PolicySavedModelTrigger(
        saved_model_dir,
        tf_agent,
        train_step,
        interval=policy_save_interval),
    triggers.StepPerSecondLogTrigger(train_step, interval=1000),
]


def experience_dataset_fn():
    """Return the module-level replay-buffer ``dataset`` for the learner."""
    return dataset


# Learner that trains `tf_agent` on batches from `experience_dataset_fn`,
# rooted at the temp directory and firing the triggers defined above it.
agent_learner = learner.Learner(
    tempdir,
    train_step,
    tf_agent,
    experience_dataset_fn,
    triggers=learning_triggers)


# Separate checkpointer covering the agent, its policy, the replay buffer and
# the global step; only the most recent checkpoint is kept.
# NOTE(review): in the cells shown, this checkpointer is only ever asked to
# save; no restore call is visible.
checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=tf_agent,
    policy=tf_agent.policy,
    replay_buffer=replay_buffer,
    global_step=train_step
)


In [17]:
def get_eval_metrics():
  """Run the evaluation driver once and gather the evaluation metrics.

  Returns:
    A dict mapping each metric's ``name`` to its ``result()``, for every
    observer in ``observers_eval``, read after the driver has finished.
  """
  eval_driver.run()
  return {observer.name: observer.result() for observer in observers_eval}


# metrics = get_eval_metrics()


In [18]:
def log_eval_metrics(step, metrics):
  """Print one line summarising the evaluation metrics at a training step.

  Args:
    step: training step number shown at the start of the line.
    metrics: mapping of metric name to a value that accepts the ``.6f``
      format; entries are printed in the mapping's iteration order.
  """
  formatted = []
  for name, result in metrics.items():
    formatted.append('{} = {:.6f}'.format(name, result))
  print('step = {0}: {1}'.format(step, (', ').join(formatted)))


# log_eval_metrics(0, metrics)


In [19]:
def plot_metrics(losses, returns, training_returns):
  """Draw three side-by-side panels: loss, evaluation returns, training returns.

  The two return panels have the top of their y-axis pinned to 0 while the
  lower limit is left at whatever matplotlib chose for the data.

  NOTE: the name ``plot_metrics`` is defined again in a later cell of this
  notebook; once that cell has run, this definition is shadowed.

  Args:
    losses: sequence of training-loss values.
    returns: sequence of average evaluation returns.
    training_returns: sequence of average training returns.
  """
  plt.figure(figsize=(16, 6))

  loss_ax = plt.subplot(1, 3, 1)
  loss_ax.plot(losses)
  loss_ax.set_title('Training Loss')

  return_panels = (
      (2, returns, 'Evaluation Average Returns'),
      (3, training_returns, 'Average Training Returns'),
  )
  for position, series, heading in return_panels:
    panel = plt.subplot(1, 3, position)
    panel.plot(series)
    panel.set_title(heading)
    lower_limit = panel.get_ylim()[0]
    panel.set_ylim([lower_limit, 0])

  plt.show()


In [None]:

# Reset the train step
tf_agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training. The evaluation metrics are
# cleared first so that re-running this cell does not mix episodes from an
# earlier evaluation into the baseline value.
reset_eval_metrics()
avg_return = get_eval_metrics()["AverageReturn"]
returns = [avg_return]
losses = []
training_returns = []

for _ in range(num_iterations):
  # Training: one collect-driver run, then a single learner iteration.
  collect_driver.run()
  loss_info = agent_learner.run(iterations=1)
  step = agent_learner.train_step_numpy

  # Logging. Done before evaluation so that the plot drawn at an evaluation
  # step already contains this step's loss and training return.
  if log_interval and step % log_interval == 0:
    loss_value = loss_info.loss.numpy()
    training_return = ave_return.result().numpy()
    print('step = {0}: loss = {1}, average return: {2}'.format(
        step, loss_value, training_return))
    training_returns.append(training_return)
    losses.append(loss_value)

  # Evaluating: checkpoint, run a fresh set of evaluation episodes, report.
  if eval_interval and step % eval_interval == 0:
    train_checkpointer.save(train_step)
    reset_eval_metrics()
    metrics = get_eval_metrics()
    log_eval_metrics(step, metrics)
    returns.append(metrics["AverageReturn"])
    plot_metrics(losses, returns, training_returns)

print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
print('Average Return: ', ave_return.result().numpy())

plot_metrics(losses, returns, training_returns)


In [None]:
import numpy as np

# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old names, so use whichever spelling this
# installation provides instead of hard-coding 'seaborn-dark'. If neither is
# available the current style is left untouched.
for _style_name in ('seaborn-dark', 'seaborn-v0_8-dark'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break


def plot_state(x, y, new_figure=True):
    """Plot one agent path with markers on its two end samples.

    With ``new_figure`` true, a new 10x10 figure is opened and a 'Target'
    marker is drawn at the origin before the path; the legend, title, axis
    labels and tick sizes are applied afterwards. With ``new_figure`` false,
    the path and its markers are added to the current figure and nothing else
    is changed.

    NOTE(review): the last sample (index -1) is labelled 'Initial State' and
    the first sample (index 0) 'Final State'. Confirm the arrays really are
    ordered end-to-start; otherwise the labels are swapped.

    Args:
        x: indexable sequence of horizontal coordinates.
        y: indexable sequence of vertical coordinates, same length as ``x``.
        new_figure: whether to start and decorate a new figure.
    """
    if new_figure:
        plt.figure(figsize=(10, 10))
        plt.plot([0], [0], 'o', label='Target', markersize=12, color='black')

    plt.plot(x, y, '--', label='Agent Path', linewidth=3, color='blue')
    endpoint_markers = (
        (-1, 'o', 'Initial State', 12, 'red'),
        (0, '*', 'Final State', 15, 'green'),
    )
    for index, symbol, caption, size, colour in endpoint_markers:
        plt.plot(x[index], y[index], symbol, label=caption,
                 markersize=size, color=colour)

    if not new_figure:
        return

    plt.legend()
    plt.title('Agent Paths on Three Representative Tests', fontsize=20)
    plt.xlabel('$250(\\lambda_t - \\lambda_a)$', fontsize=16)
    plt.ylabel('$a_t-a_a$', fontsize=16)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)


In [None]:
def plot_metrics(losses, returns, training_returns=None,
                 log_interval=200, eval_interval=2000):
  """Plot training loss and evaluation returns against the training step.

  This redefinition replaces the earlier three-panel ``plot_metrics``. The
  optional ``training_returns`` argument keeps the three-argument call made by
  the training loop working after this cell has been run.

  Args:
    losses: sequence of loss values, one per logging interval. Entry ``i`` is
      placed at step ``(i + 1) * log_interval``, matching the training loop,
      which records its first loss at step ``log_interval`` rather than 0.
    returns: sequence of average evaluation returns. Entry ``i`` is placed at
      step ``i * eval_interval``; entry 0 is the pre-training evaluation.
    training_returns: optional sequence of average training returns, recorded
      alongside ``losses``. When given, a third panel is drawn for it.
    log_interval: number of training steps between loss entries.
    eval_interval: number of training steps between evaluation entries.
  """
  plt.style.use('ggplot')
  panel_count = 2 if training_returns is None else 3
  # 7 inches of width per panel keeps the original 14x7 two-panel figure.
  plt.figure(figsize=(7 * panel_count, 7))

  plt.subplot(1, panel_count, 1)
  x = (np.arange(len(losses)) + 1) * log_interval
  plt.plot(x, losses, linewidth=2)
  plt.title('Training Loss', fontsize=20)
  plt.ylabel('Training Loss', fontsize=14)
  plt.xlabel('Training Step', fontsize=14)
  plt.xticks(fontsize=12)
  plt.yticks(fontsize=12)

  plt.subplot(1, panel_count, 2)
  x = np.arange(len(returns)) * eval_interval
  plt.plot(x, returns, linewidth=2)
  plt.title('Evaluation Average Returns', fontsize=20)
  plt.xlabel('Training Step', fontsize=14)
  plt.ylabel('Average Evaluation Return', fontsize=14)
  plt.xticks(fontsize=12)
  plt.yticks(fontsize=12)

  if training_returns is not None:
    plt.subplot(1, panel_count, 3)
    x = (np.arange(len(training_returns)) + 1) * log_interval
    plt.plot(x, training_returns, linewidth=2)
    plt.title('Average Training Returns', fontsize=20)
    plt.xlabel('Training Step', fontsize=14)
    plt.ylabel('Average Training Return', fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

  plt.show()
