
Merge branch 'fix_69'
keiohta committed Oct 3, 2020
2 parents e472da1 + 127e42e commit 9489678
Showing 12 changed files with 49 additions and 53 deletions.
11 changes: 5 additions & 6 deletions tf2rl/algos/bi_res_ddpg.py
@@ -33,7 +33,7 @@ def _train_body(self, states, actions, next_states, rewards, dones, weights):

with tf.GradientTape() as tape:
next_action = self.actor(states)
actor_loss = -tf.reduce_mean(self.critic([states, next_action]))
actor_loss = -tf.reduce_mean(self.critic(states, next_action))

actor_grad = tape.gradient(
actor_loss, self.actor.trainable_variables)
@@ -58,17 +58,16 @@ def _compute_td_error_body(self, states, actions, next_states, rewards, dones):
with tf.device(self.device):
not_dones = 1. - tf.cast(dones, dtype=tf.float32)
# Compute standard TD error
target_Q = self.critic_target(
[next_states, self.actor_target(next_states)])
target_Q = self.critic_target(next_states, self.actor_target(next_states))
target_Q = rewards + (not_dones * self.discount * target_Q)
target_Q = tf.stop_gradient(target_Q)
current_Q = self.critic([states, actions])
current_Q = self.critic(states, actions)
td_errors1 = target_Q - current_Q
# Compute residual TD error
next_actions = tf.stop_gradient(self.actor(next_states))
target_Q = self.critic([next_states, next_actions])
target_Q = self.critic(next_states, next_actions)
target_Q = rewards + (not_dones * self.discount * target_Q)
current_Q = tf.stop_gradient(self.critic_target([states, actions]))
current_Q = tf.stop_gradient(self.critic_target(states, actions))
td_errors2 = target_Q - current_Q
return td_errors1, td_errors2

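The recurring change in this commit is the critic call convention: critics now take states and actions as two positional arguments instead of a single [states, actions] list. A minimal sketch of the new convention with a toy network (illustration only, not the tf2rl classes):

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense

class TinyCritic(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.l1 = Dense(32, activation="relu")
        self.l2 = Dense(1)

    def call(self, states, actions):  # was: call(self, inputs) with inputs = [states, actions]
        features = tf.concat((states, actions), axis=1)
        return self.l2(self.l1(features))

critic = TinyCritic()
states = tf.constant(np.zeros((1, 3), dtype=np.float32))
actions = tf.constant(np.zeros((1, 2), dtype=np.float32))
q_values = critic(states, actions)  # new style; the old call was critic([states, actions])
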
45 changes: 21 additions & 24 deletions tf2rl/algos/ddpg.py
@@ -4,11 +4,10 @@

from tf2rl.algos.policy_base import OffPolicyAgent
from tf2rl.misc.target_update_ops import update_target_variables
from tf2rl.misc.huber_loss import huber_loss


class Actor(tf.keras.Model):
def __init__(self, state_shape, action_dim, max_action, units=[400, 300], name="Actor"):
def __init__(self, state_shape, action_dim, max_action, units=(400, 300), name="Actor"):
super().__init__(name=name)

self.l1 = Dense(units[0], name="L1")
@@ -18,7 +17,7 @@ def __init__(self, state_shape, action_dim, max_action, units=[400, 300], name="
self.max_action = max_action

with tf.device("/cpu:0"):
self(tf.constant(np.zeros(shape=(1,)+state_shape, dtype=np.float32)))
self(tf.constant(np.zeros(shape=(1,) + state_shape, dtype=np.float32)))

def call(self, inputs):
features = tf.nn.relu(self.l1(inputs))
@@ -29,23 +28,22 @@ def call(self, inputs):


class Critic(tf.keras.Model):
def __init__(self, state_shape, action_dim, units=[400, 300], name="Critic"):
def __init__(self, state_shape, action_dim, units=(400, 300), name="Critic"):
super().__init__(name=name)

self.l1 = Dense(units[0], name="L1")
self.l2 = Dense(units[1], name="L2")
self.l3 = Dense(1, name="L3")

dummy_state = tf.constant(
np.zeros(shape=(1,)+state_shape, dtype=np.float32))
np.zeros(shape=(1,) + state_shape, dtype=np.float32))
dummy_action = tf.constant(
np.zeros(shape=[1, action_dim], dtype=np.float32))
with tf.device("/cpu:0"):
self([dummy_state, dummy_action])
self(dummy_state, dummy_action)

def call(self, inputs):
states, actions = inputs
features = tf.concat([states, actions], axis=1)
def call(self, states, actions):
features = tf.concat((states, actions), axis=1)
features = tf.nn.relu(self.l1(features))
features = tf.nn.relu(self.l2(features))
features = self.l3(features)
@@ -61,8 +59,8 @@ def __init__(
max_action=1.,
lr_actor=0.001,
lr_critic=0.001,
actor_units=[400, 300],
critic_units=[400, 300],
actor_units=(400, 300),
critic_units=(400, 300),
sigma=0.1,
tau=0.005,
n_warmup=int(1e4),
@@ -108,8 +106,9 @@ def get_action(self, state, test=False, tensor=False):
def _get_action_body(self, state, sigma, max_action):
with tf.device(self.device):
action = self.actor(state)
action += tf.random.normal(shape=action.shape,
mean=0., stddev=sigma, dtype=tf.float32)
if sigma > 0.:
action += tf.random.normal(shape=action.shape,
mean=0., stddev=sigma, dtype=tf.float32)
return tf.clip_by_value(action, -max_action, max_action)

def train(self, states, actions, next_states, rewards, done, weights=None):
@@ -119,9 +118,9 @@ def train(self, states, actions, next_states, rewards, done, weights=None):
states, actions, next_states, rewards, done, weights)

if actor_loss is not None:
tf.summary.scalar(name=self.policy_name+"/actor_loss",
tf.summary.scalar(name=self.policy_name + "/actor_loss",
data=actor_loss)
tf.summary.scalar(name=self.policy_name+"/critic_loss",
tf.summary.scalar(name=self.policy_name + "/critic_loss",
data=critic_loss)

return td_errors
@@ -132,8 +131,7 @@ def _train_body(self, states, actions, next_states, rewards, done, weights):
with tf.GradientTape() as tape:
td_errors = self._compute_td_error_body(
states, actions, next_states, rewards, done)
critic_loss = tf.reduce_mean(
huber_loss(td_errors, delta=self.max_grad) * weights)
critic_loss = tf.reduce_mean(td_errors ** 2)

critic_grad = tape.gradient(
critic_loss, self.critic.trainable_variables)
@@ -142,7 +140,7 @@ def _train_body(self, states, actions, next_states, rewards, done, weights):

with tf.GradientTape() as tape:
next_action = self.actor(states)
actor_loss = -tf.reduce_mean(self.critic([states, next_action]))
actor_loss = -tf.reduce_mean(self.critic(states, next_action))

actor_grad = tape.gradient(
actor_loss, self.actor.trainable_variables)
@@ -169,10 +167,9 @@ def compute_td_error(self, states, actions, next_states, rewards, dones):
def _compute_td_error_body(self, states, actions, next_states, rewards, dones):
with tf.device(self.device):
not_dones = 1. - tf.cast(dones, dtype=tf.float32)
target_Q = self.critic_target(
[next_states, self.actor_target(next_states)])
target_Q = rewards + (not_dones * self.discount * target_Q)
target_Q = tf.stop_gradient(target_Q)
current_Q = self.critic([states, actions])
td_errors = target_Q - current_Q
next_act_target = self.actor_target(next_states)
next_q_target = self.critic_target(next_states, next_act_target)
target_q = rewards + not_dones * self.discount * next_q_target
current_q = self.critic(states, actions)
td_errors = tf.stop_gradient(target_q) - current_q
return td_errors
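
Two behavioral changes land in ddpg.py besides the new critic signature: the critic loss drops the Huber loss import in favor of a plain mean squared TD error, and _get_action_body only draws Gaussian noise when sigma > 0, so test-time actions are deterministic. A small sketch of the noise gating, assuming an action bound of 1.0 (illustration only, not the tf2rl method):

import tensorflow as tf

def add_exploration_noise(action, sigma, max_action=1.0):
    # Sample Gaussian noise only when exploration is requested (sigma > 0);
    # with sigma == 0 (evaluation) the action passes through unchanged.
    if sigma > 0.:
        action += tf.random.normal(shape=action.shape, mean=0., stddev=sigma, dtype=tf.float32)
    return tf.clip_by_value(action, -max_action, max_action)

greedy = add_exploration_noise(tf.constant([[0.3, -0.7]]), sigma=0.0)  # deterministic
noisy = add_exploration_noise(tf.constant([[0.3, -0.7]]), sigma=0.1)   # exploration noise added
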
4 changes: 2 additions & 2 deletions tf2rl/algos/dqn.py
@@ -10,7 +10,7 @@


class QFunc(tf.keras.Model):
def __init__(self, state_shape, action_dim, units=[32, 32],
def __init__(self, state_shape, action_dim, units=(32, 32),
name="QFunc", enable_dueling_dqn=False,
enable_noisy_dqn=False, enable_categorical_dqn=False,
n_atoms=51):
@@ -76,7 +76,7 @@ def __init__(
q_func=None,
name="DQN",
lr=0.001,
units=[32, 32],
units=(32, 32),
epsilon=0.1,
epsilon_min=None,
epsilon_decay_step=int(1e6),
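
The other edits here, mirrored across the remaining algorithm files (gaifo.py, gail.py, sac_discrete.py, td3.py, vail.py, vpg.py), swap list defaults such as units=[32, 32] for tuples. A toy illustration of the mutable-default pitfall this avoids (not tf2rl code):

def make_layers_bad(units=[32, 32]):
    units.append(1)            # mutates the single shared default list
    return list(units)

def make_layers_good(units=(32, 32)):
    return list(units) + [1]   # tuples are immutable, so no state leaks between calls

print(make_layers_bad())   # [32, 32, 1]
print(make_layers_bad())   # [32, 32, 1, 1]  <- the default list remembers the previous call
print(make_layers_good())  # [32, 32, 1]
print(make_layers_good())  # [32, 32, 1]
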
4 changes: 2 additions & 2 deletions tf2rl/algos/gaifo.py
@@ -8,7 +8,7 @@


class Discriminator(DiscriminatorGAIL):
def __init__(self, state_shape, units=[32, 32],
def __init__(self, state_shape, units=(32, 32),
enable_sn=False, output_activation="sigmoid",
name="Discriminator"):
tf.keras.Model.__init__(self, name=name)
@@ -30,7 +30,7 @@ class GAIfO(GAIL):
def __init__(
self,
state_shape,
units=[32, 32],
units=(32, 32),
lr=0.001,
enable_sn=False,
name="GAIfO",
2 changes: 1 addition & 1 deletion tf2rl/algos/gail.py
@@ -7,7 +7,7 @@


class Discriminator(tf.keras.Model):
def __init__(self, state_shape, action_dim, units=[32, 32],
def __init__(self, state_shape, action_dim, units=(32, 32),
enable_sn=False, output_activation="sigmoid",
name="Discriminator"):
super().__init__(name=name)
2 changes: 1 addition & 1 deletion tf2rl/algos/sac_discrete.py
@@ -15,7 +15,7 @@ class CriticQ(tf.keras.Model):
to Q: S -> R^|A|
"""

def __init__(self, state_shape, action_dim, critic_units=[256, 256], name='qf'):
def __init__(self, state_shape, action_dim, critic_units=(256, 256), name='qf'):
super().__init__(name=name)

self.l1 = Dense(critic_units[0], name="L1", activation='relu')
23 changes: 10 additions & 13 deletions tf2rl/algos/td3.py
@@ -8,7 +8,7 @@


class Critic(tf.keras.Model):
def __init__(self, state_shape, action_dim, units=[400, 300], name="Critic"):
def __init__(self, state_shape, action_dim, units=(400, 300), name="Critic"):
super().__init__(name=name)

self.l1 = Dense(units[0], name="L1")
@@ -24,11 +24,10 @@ def __init__(self, state_shape, action_dim, units=[400, 300], name="Critic"):
dummy_action = tf.constant(
np.zeros(shape=[1, action_dim], dtype=np.float32))
with tf.device("/cpu:0"):
self([dummy_state, dummy_action])
self(dummy_state, dummy_action)

def call(self, inputs):
states, actions = inputs
xu = tf.concat([states, actions], axis=1)
def call(self, states, actions):
xu = tf.concat((states, actions), axis=1)

x1 = tf.nn.relu(self.l1(xu))
x1 = tf.nn.relu(self.l2(x1))
@@ -92,7 +91,7 @@ def _train_body(self, states, actions, next_states, rewards, done, weights):
self._it.assign_add(1)
with tf.GradientTape() as tape:
next_actions = self.actor(states)
actor_loss = - tf.reduce_mean(self.critic([states, next_actions]))
actor_loss = - tf.reduce_mean(self.critic(states, next_actions))

remainder = tf.math.mod(self._it, self._actor_update_freq)

@@ -130,11 +129,9 @@ def _compute_td_error_body(self, states, actions, next_states, rewards, dones):
next_action = tf.clip_by_value(
next_action + noise, -self.actor_target.max_action, self.actor_target.max_action)

target_Q1, target_Q2 = self.critic_target(
[next_states, next_action])
target_Q = tf.minimum(target_Q1, target_Q2)
target_Q = rewards + (not_dones * self.discount * target_Q)
target_Q = tf.stop_gradient(target_Q)
current_Q1, current_Q2 = self.critic([states, actions])
next_q1_target, next_q2_target = self.critic_target(next_states, next_action)
next_q_target = tf.minimum(next_q1_target, next_q2_target)
q_target = tf.stop_gradient(rewards + not_dones * self.discount * next_q_target)
current_q1, current_q2 = self.critic(states, actions)

return target_Q - current_Q1, target_Q - current_Q2
return q_target - current_q1, q_target - current_q2
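
The rewritten target in td3.py is the standard clipped double-Q rule with target policy smoothing from TD3; in equation form (a_max is the action bound; the smoothing scale and clip come from the surrounding class):

y = r + \gamma (1 - d)\, \min_{i \in \{1, 2\}} Q'_i\!\left(s',\ \mathrm{clip}\big(\pi'(s') + \epsilon,\ -a_{\max},\ a_{\max}\big)\right),
\qquad \epsilon \sim \mathrm{clip}\big(\mathcal{N}(0, \tilde{\sigma}^2),\ -c,\ c\big)

The method then returns the two TD errors y - Q_1(s, a) and y - Q_2(s, a), with the gradient stopped through y.
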
4 changes: 2 additions & 2 deletions tf2rl/algos/vail.py
@@ -12,7 +12,7 @@ class Discriminator(tf.keras.Model):
LOG_SIG_CAP_MIN = -20 # np.e**-10 = 4.540e-05
EPS = 1e-6

def __init__(self, state_shape, action_dim, units=[32, 32],
def __init__(self, state_shape, action_dim, units=(32, 32),
n_latent_unit=32, enable_sn=False, name="Discriminator"):
super().__init__(name=name)

@@ -57,7 +57,7 @@ def __init__(
self,
state_shape,
action_dim,
units=[32, 32],
units=(32, 32),
n_latent_unit=32,
lr=5e-5,
kl_target=0.5,
4 changes: 2 additions & 2 deletions tf2rl/algos/vpg.py
@@ -39,8 +39,8 @@ def __init__(
critic=None,
actor_critic=None,
max_action=1.,
actor_units=[256, 256],
critic_units=[256, 256],
actor_units=(256, 256),
critic_units=(256, 256),
lr_actor=1e-3,
lr_critic=3e-3,
hidden_activation_actor="relu",
1 change: 1 addition & 0 deletions tf2rl/experiments/irl_trainer.py
@@ -66,6 +66,7 @@ def __call__(self):
obs = next_obs

if done or episode_steps == self._episode_max_steps:
replay_buffer.on_episode_end()
obs = self._env.reset()

n_episode += 1
1 change: 1 addition & 0 deletions tf2rl/experiments/on_policy_trainer.py
@@ -125,6 +125,7 @@ def __call__(self):
tf.summary.flush()

def finish_horizon(self, last_val=0):
self.local_buffer.on_episode_end()
samples = self.local_buffer._encode_sample(
np.arange(self.local_buffer.get_stored_size()))
rews = np.append(samples["rew"], last_val)
1 change: 1 addition & 0 deletions tf2rl/experiments/trainer.py
@@ -111,6 +111,7 @@ def __call__(self):
obs = next_obs

if done or episode_steps == self._episode_max_steps:
replay_buffer.on_episode_end()
obs = self._env.reset()

n_episode += 1
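
All three trainers now call replay_buffer.on_episode_end() when an episode terminates or hits the step limit, letting the buffer finalize its per-episode bookkeeping (e.g. for N-step or episode-aware storage). A minimal usage sketch, assuming the cpprb ReplayBuffer API that tf2rl builds on:

import numpy as np
from cpprb import ReplayBuffer

# Toy buffer: capacity 32, 4-dimensional observations.
rb = ReplayBuffer(32, {"obs": {"shape": (4,)}, "rew": {}, "done": {}})

obs = np.zeros(4, dtype=np.float32)
for step in range(10):
    done = step == 9
    rb.add(obs=obs, rew=1.0, done=float(done))
    if done:                 # episode finished (or the episode step limit was reached)
        rb.on_episode_end()  # close out the current episode in the buffer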
