In [1]:
import warnings

warnings.simplefilter('ignore')

import tensorflow as tf
import numpy as np 

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

In [2]:
from src.simulation.example_states import RL_Network
import gym
import gym_train_env


In [3]:
# Turn on TensorFlow 2.x behaviour through the compat.v1 API before any
# environments, networks or agents are built in the cells below.
tf.compat.v1.enable_v2_behavior()

In [4]:
# Display the installed TensorFlow version so the run can be reproduced.
tf.version.VERSION

'2.1.0'

In [37]:
# Number of passes through the training loop (one train call per pass).
num_iterations = 40000

# Size of the initial data-collection phase, going by the name
# (NOTE(review): confirm at the point of use).
initial_collect_steps = 1000
# Environment steps collected with the agent's collect policy per training iteration.
collect_steps_per_iteration = 1
# Passed as max_length when the replay buffer is constructed.
replay_buffer_max_length = 100000

# Passed as sample_batch_size when the replay buffer is read as a dataset.
batch_size = 64
# Learning rate handed to the Adam optimizer.
learning_rate = 1e-3
# The training loss is printed whenever the train step is a multiple of this.
log_interval = 200

# Number of episodes averaged by each evaluation.
num_eval_episodes = 10
# The policy is evaluated whenever the train step is a multiple of this.
eval_interval = 1000


In [6]:
# Gym id of the environment used throughout the notebook.
# NOTE(review): presumably registered by the `gym_train_env` import — confirm.
env_name = 'train-tf-v0'

In [7]:
# Load one instance of the environment through the TF-Agents gym suite.
# NOTE(review): `env` is not read by any later visible cell; training and
# evaluation use their own instances loaded in the next cell.
env = suite_gym.load(env_name)

In [8]:
# Two independent instances of the same environment: one is stepped during
# training/data collection, the other only during evaluation and rendering.
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

In [27]:
# Wrap each Python environment in a TFPyEnvironment; the agent, the policies
# and the replay buffer below are all built against these wrapped environments.
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

In [10]:
# Single-entry tuple forwarded to QNetwork as its fc_layer_params argument.
fc_layer_params = (15,)
# Q-network built from the training environment's observation and action specs.
q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)


In [11]:
# Adam optimizer (TF1-compat API) configured with the learning_rate hyperparameter.
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

# Counter of train steps handed to the agent. It starts at 0 (previously 1) so
# that it agrees with the explicit reset to 0 done before the training loop and
# with the `step % interval` logging there, even if that reset is skipped.
train_step_counter = tf.Variable(0)

# DQN agent built on the Q-network above, trained with an element-wise
# squared TD-error loss.
agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

agent.initialize()




To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



In [12]:
# Convenience aliases for the agent's two policies.
# NOTE(review): the later visible cells use agent.policy and
# agent.collect_policy directly, so these names are currently unused.
eval_policy = agent.policy
collect_policy = agent.collect_policy

In [13]:
# Random policy over the training environment's specs; it is the policy passed
# to collect_data when the replay buffer is seeded before training.
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())


In [14]:
def compute_avg_return(environment, policy, num_episodes=10):
  """Run `policy` for `num_episodes` episodes and return the mean episode return.

  Each episode starts with `environment.reset()` and is stepped with
  `policy.action(...)` until `time_step.is_last()` is true; the rewards of the
  steps taken are summed per episode and then averaged over all episodes.

  Args:
    environment: object exposing `reset()` and `step(action)` that return time
      steps with `is_last()` and `reward`.
    policy: object exposing `action(time_step)` whose result has an `.action`.
    num_episodes: number of episodes to average over; must be positive.

  Returns:
    The average return. When rewards are tensor-like (have `.numpy()`), this is
    element 0 of the averaged value (assumes a batch of one environment —
    TODO confirm); if no reward was ever accumulated it is the float 0.0.

  Raises:
    ValueError: if `num_episodes` is not positive.
  """
  if num_episodes <= 0:
    # Previously this surfaced as a ZeroDivisionError (0) or an AttributeError
    # on a float (negative values).
    raise ValueError(
        'num_episodes must be positive, got {}'.format(num_episodes))

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  if not hasattr(avg_return, 'numpy'):
    # Every episode was already terminal at reset, so no reward was ever added
    # and the accumulator is still a plain Python float.
    return avg_return
  return avg_return.numpy()[0]


In [16]:
# Uniform replay buffer holding the agent's collect_data_spec. Its batch size
# is taken from train_env and its capacity from replay_buffer_max_length.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_max_length)


In [17]:
def collect_step(environment, policy, buffer):
  """Advance `environment` one step under `policy` and record the transition.

  Reads the environment's current time step, asks `policy` for an action,
  applies it, and appends the resulting transition to `buffer`.

  Args:
    environment: environment exposing `current_time_step()` and `step(action)`.
    policy: policy exposing `action(time_step)`.
    buffer: replay buffer exposing `add_batch(trajectory)`.
  """
  before = environment.current_time_step()
  chosen = policy.action(before)
  after = environment.step(chosen.action)

  # Bundle (time step, action, next time step) into one trajectory and store it.
  buffer.add_batch(trajectory.from_transition(before, chosen, after))

def collect_data(env, policy, buffer, steps):
  """Call `collect_step(env, policy, buffer)` `steps` times in a row.

  Args:
    env: environment to step.
    policy: policy used to choose each action.
    buffer: replay buffer that receives every collected transition.
    steps: number of single-step collections to perform.
  """
  for _step_index in range(steps):
    collect_step(env, policy, buffer)

# Seed the replay buffer with random-policy experience before training.
# The step count now comes from the initial_collect_steps hyperparameter
# instead of a hard-coded 100, which had left that setting without any effect.
collect_data(train_env, random_policy, replay_buffer,
             steps=initial_collect_steps)


In [18]:
# Expose the replay buffer as a dataset: each element is a batch of
# `batch_size` samples, each spanning num_steps=2 steps, and up to 3 elements
# are prefetched.
# NOTE(review): num_steps=2 presumably supplies the (current, next) step pair
# the DQN loss needs — confirm against the agent's documentation.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=batch_size,
    num_steps=2).prefetch(3)

In [21]:
# Iterator over the sampled dataset; the training loop draws one batch per
# iteration from it with next(iterator).
iterator = iter(dataset)


In [44]:
try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few steps using collect_policy and save to the replay buffer.
  for _ in range(collect_steps_per_iteration):
    collect_step(train_env, agent.collect_policy, replay_buffer)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience).loss

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)


step = 200: loss = 73.57440948486328
step = 400: loss = 33.012603759765625
step = 600: loss = 21.662010192871094
step = 800: loss = 93.2724609375
step = 1000: loss = 23.248300552368164
step = 1000: Average Return = 65.0
step = 1200: loss = 44.3319091796875
step = 1400: loss = 10.447744369506836
step = 1600: loss = 100.12498474121094
step = 1800: loss = 37.70118713378906
step = 2000: loss = 71.86082458496094
step = 2000: Average Return = 65.0
step = 2200: loss = 35.24655532836914
step = 2400: loss = 38.35594940185547
step = 2600: loss = 9.173202514648438
step = 2800: loss = 18.139476776123047
step = 3000: loss = 46.07527160644531
step = 3000: Average Return = 65.0
step = 3200: loss = 31.991531372070312
step = 3400: loss = 68.22601318359375
step = 3600: loss = 53.76866912841797
step = 3800: loss = 45.02898406982422
step = 4000: loss = 14.941591262817383
step = 4000: Average Return = 65.0
step = 4200: loss = 41.79685592651367
step = 4400: loss = 10.952231407165527
step = 4600: loss = 67.9

step = 36000: loss = 208.04690551757812
step = 36000: Average Return = 5.0
step = 36200: loss = 118.8736572265625
step = 36400: loss = 112.92021942138672
step = 36600: loss = 173.27456665039062
step = 36800: loss = 192.6452178955078
step = 37000: loss = 129.38682556152344
step = 37000: Average Return = 5.0
step = 37200: loss = 77.99213409423828
step = 37400: loss = 110.86070251464844
step = 37600: loss = 81.22589111328125
step = 37800: loss = 92.59668731689453
step = 38000: loss = 62.86809539794922
step = 38000: Average Return = 5.0
step = 38200: loss = 117.46170043945312
step = 38400: loss = 82.58624267578125
step = 38600: loss = 169.33485412597656
step = 38800: loss = 93.06294250488281
step = 39000: loss = 112.71649169921875
step = 39000: Average Return = 5.0
step = 39200: loss = 113.35415649414062
step = 39400: loss = 118.10310363769531
step = 39600: loss = 178.39678955078125
step = 39800: loss = 112.40154266357422
step = 40000: loss = 129.64846801757812
step = 40000: Average Return

In [43]:
# Play 20 episodes with the agent's policy on the evaluation environment and
# render the underlying Python environment once each episode has finished.
episodes_remaining = 20
while episodes_remaining:
    time_step = eval_env.reset()
    # Step until the environment reports the final time step of the episode.
    while not time_step.is_last():
        time_step = eval_env.step(agent.policy.action(time_step).action)

    eval_py_env.render()
    print('')
    episodes_remaining -= 1

 1:56 -  5:54 -  5:55 -  5:57 -  5:58 -  5:59 - 
 2: 9 -  5:17 -  5:55 -  5:59 -  5:58 -  0: 0 - 
 0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 - 

 1:56 -  5:54 -  5:55 -  5:57 -  5:58 -  5:59 - 
 2: 9 -  5:17 -  5:55 -  5:59 -  5:58 -  0: 0 - 
 0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 - 

 1:56 -  5:54 -  5:55 -  5:57 -  5:58 -  5:59 - 
 2: 9 -  5:17 -  5:55 -  5:59 -  5:58 -  0: 0 - 
 0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 - 

 1:56 -  5:54 -  5:55 -  5:57 -  5:58 -  5:59 - 
 2: 9 -  5:17 -  5:55 -  5:59 -  5:58 -  0: 0 - 
 0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 - 

 1:56 -  5:54 -  5:55 -  5:57 -  5:58 -  5:59 - 
 2: 9 -  5:17 -  5:55 -  5:59 -  5:58 -  0: 0 - 
 0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 - 

 1:56 -  5:54 -  5:55 -  5:57 -  5:58 -  5:59 - 
 2: 9 -  5:17 -  5:55 -  5:59 -  5:58 -  0: 0 - 
 0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 -  0: 0 - 

 1:56 -  5:54 -  5:55 -  5:57 -  5:58 -  5:59 - 
 2: 9 -  5:17 -  5:55 -  5:59 -  5:58 -  0: 0 - 
 0: 0 -  0: 0 