
Commit

Merge 9423ea9 into 8d7004b
prabhatnagarajan committed Nov 2, 2018
2 parents 8d7004b + 9423ea9 commit 1163f50
Showing 16 changed files with 279 additions and 146 deletions.
4 changes: 2 additions & 2 deletions chainerrl/agents/al.py
@@ -29,7 +29,7 @@ def __init__(self, *args, **kwargs):
self.alpha = kwargs.pop('alpha', 0.9)
super().__init__(*args, **kwargs)

- def _compute_y_and_t(self, exp_batch, gamma):
+ def _compute_y_and_t(self, exp_batch):

batch_state = exp_batch['state']
batch_size = len(exp_batch['reward'])
@@ -56,7 +56,7 @@ def _compute_y_and_t(self, exp_batch, gamma):
batch_terminal = exp_batch['is_state_terminal']

# T Q: Bellman operator
- t_q = batch_rewards + self.gamma * \
+ t_q = batch_rewards + exp_batch['discount'] * \
(1.0 - batch_terminal) * next_q_max

# T_AL Q: advantage learning operator
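For reference, the target computed in this hunk is the Bellman backup, now using the per-transition discount stored in the batch instead of the agent-wide self.gamma. In LaTeX, with Q' the target network, d_t the terminal flag, and \gamma_t = exp_batch['discount'][t]:

    T Q(s_t, a_t) = r_t + \gamma_t \, (1 - d_t) \, \max_{a'} Q'(s_{t+1}, a')

The advantage-learning correction applied to this target is not visible in the truncated hunk; based on the advantage learning literature (and therefore an assumption about the elided code), it has the form

    T_{AL} Q(s_t, a_t) = T Q(s_t, a_t) - \alpha \, \bigl( \max_{a} Q'(s_t, a) - Q'(s_t, a_t) \bigr)

with \alpha the coefficient popped from kwargs at the top of this file.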
14 changes: 8 additions & 6 deletions chainerrl/agents/categorical_dqn.py
@@ -81,7 +81,7 @@ class CategoricalDQN(dqn.DQN):
DistributionalDiscreteActionValue and clip_delta is ignored.
"""

- def _compute_target_values(self, exp_batch, gamma):
+ def _compute_target_values(self, exp_batch):
"""Compute a batch of target return distributions."""

batch_next_state = exp_batch['next_state']
@@ -100,10 +100,12 @@ def _compute_target_values(self, exp_batch, gamma):

# Tz: (batch_size, n_atoms)
Tz = (batch_rewards[..., None]
- + (1.0 - batch_terminal[..., None]) * gamma * z_values[None])
+ + (1.0 - batch_terminal[..., None])
+ * self.xp.expand_dims(exp_batch['discount'], 1)
+ * z_values[None])
return _apply_categorical_projection(Tz, next_q_max, z_values)

- def _compute_y_and_t(self, exp_batch, gamma):
+ def _compute_y_and_t(self, exp_batch):
"""Compute a batch of predicted/target return distributions."""

batch_size = exp_batch['reward'].shape[0]
@@ -120,14 +122,14 @@ def _compute_y_and_t(self, exp_batch, gamma):
assert batch_q.shape == (batch_size, n_atoms)

with chainer.no_backprop_mode():
- batch_q_target = self._compute_target_values(exp_batch, gamma)
+ batch_q_target = self._compute_target_values(exp_batch)
assert batch_q_target.shape == (batch_size, n_atoms)

return batch_q, batch_q_target

- def _compute_loss(self, exp_batch, gamma, errors_out=None):
+ def _compute_loss(self, exp_batch, errors_out=None):
"""Compute a loss of categorical DQN."""
- y, t = self._compute_y_and_t(exp_batch, gamma)
+ y, t = self._compute_y_and_t(exp_batch)
# Minimize the cross entropy
# y is clipped to avoid log(0)
eltwise_loss = -t * F.log(F.clip(y, 1e-10, 1.))
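The only subtle part of the new Tz expression is broadcasting: exp_batch['discount'] is a length-batch_size vector, so it is expanded to shape (batch_size, 1) before scaling the atom support. A minimal NumPy sketch of the shape arithmetic (sizes and values here are illustrative, not taken from the file):

    import numpy as np

    batch_size, n_atoms = 4, 51
    rewards = np.random.randn(batch_size).astype(np.float32)
    terminal = np.zeros(batch_size, dtype=np.float32)
    discount = np.full(batch_size, 0.99, dtype=np.float32)  # plays the role of exp_batch['discount']
    z_values = np.linspace(-10.0, 10.0, n_atoms, dtype=np.float32)

    # Same expression as the hunk, with self.xp replaced by numpy
    Tz = (rewards[..., None]
          + (1.0 - terminal[..., None])
          * np.expand_dims(discount, 1)   # (batch_size,) -> (batch_size, 1)
          * z_values[None])               # (n_atoms,)    -> (1, n_atoms)
    assert Tz.shape == (batch_size, n_atoms)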
4 changes: 2 additions & 2 deletions chainerrl/agents/ddpg.py
@@ -255,7 +255,7 @@ def compute_actor_loss(self, batch):
def update(self, experiences, errors_out=None):
"""Update the model from experiences"""

- batch = batch_experiences(experiences, self.xp, self.phi)
+ batch = batch_experiences(experiences, self.xp, self.phi, self.gamma)
self.critic_optimizer.update(lambda: self.compute_critic_loss(batch))
self.actor_optimizer.update(lambda: self.compute_actor_loss(batch))

@@ -273,7 +273,7 @@ def update_from_episodes(self, episodes, errors_out=None):
break
transitions.append(ep[i])
batch = batch_experiences(
- transitions, xp=self.xp, phi=self.phi)
+ transitions, xp=self.xp, phi=self.phi, gamma=self.gamma)
batches.append(batch)

with self.model.state_reset(), self.target_model.state_reset():
5 changes: 3 additions & 2 deletions chainerrl/agents/double_dqn.py
@@ -17,7 +17,7 @@ class DoubleDQN(dqn.DQN):
See: http://arxiv.org/abs/1509.06461.
"""

- def _compute_target_values(self, exp_batch, gamma):
+ def _compute_target_values(self, exp_batch):

batch_next_state = exp_batch['next_state']

@@ -31,5 +31,6 @@ def _compute_target_values(self, exp_batch, gamma):

batch_rewards = exp_batch['reward']
batch_terminal = exp_batch['is_state_terminal']
+ discount = exp_batch['discount']

- return batch_rewards + self.gamma * (1.0 - batch_terminal) * next_q_max
+ return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max
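Only the final backup is visible here; the Double DQN-specific step (selecting the greedy action with the online network and evaluating it with the target network, per the paper linked in the class docstring) happens in the elided lines that produce next_q_max. With the per-transition discount \gamma_t now read from the batch, the target is

    y_t = r_t + \gamma_t \, (1 - d_t) \, Q'\bigl(s_{t+1}, \arg\max_{a} Q(s_{t+1}, a)\bigr)

where Q is the online network and Q' the target network.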
4 changes: 2 additions & 2 deletions chainerrl/agents/double_pal.py
@@ -15,7 +15,7 @@

class DoublePAL(pal.PAL):

- def _compute_y_and_t(self, exp_batch, gamma):
+ def _compute_y_and_t(self, exp_batch):

batch_state = exp_batch['state']
batch_size = len(exp_batch['reward'])
@@ -45,7 +45,7 @@ def _compute_y_and_t(self, exp_batch, gamma):
batch_terminal = exp_batch['is_state_terminal']

# T Q: Bellman operator
- t_q = batch_rewards + self.gamma * \
+ t_q = batch_rewards + exp_batch['discount'] * \
(1.0 - batch_terminal) * next_q_max

# T_PAL Q: persistent advantage learning operator
8 changes: 4 additions & 4 deletions chainerrl/agents/dpp.py
@@ -26,7 +26,7 @@ class AbstractDPP(with_metaclass(ABCMeta, DQN)):
def _l_operator(self, qout):
raise NotImplementedError()

- def _compute_target_values(self, exp_batch, gamma):
+ def _compute_target_values(self, exp_batch):

batch_next_state = exp_batch['next_state']

@@ -37,9 +37,9 @@ def _compute_target_values(self, exp_batch, gamma):
batch_terminal = exp_batch['is_state_terminal']

return (batch_rewards +
- self.gamma * (1 - batch_terminal) * next_q_expect)
+ exp_batch['discount'] * (1 - batch_terminal) * next_q_expect)

- def _compute_y_and_t(self, exp_batch, gamma):
+ def _compute_y_and_t(self, exp_batch):

batch_state = exp_batch['state']
batch_size = len(exp_batch['reward'])
@@ -65,7 +65,7 @@ def _compute_y_and_t(self, exp_batch, gamma):

# r + g * LQ'(s_{t+1},a)
batch_q_target = F.reshape(
- self._compute_target_values(exp_batch, gamma), (batch_size, 1))
+ self._compute_target_values(exp_batch), (batch_size, 1))

# Q'(s_t,a_t) + r + g * LQ'(s_{t+1},a) - LQ'(s_t,a)
t = target_q + batch_q_target - target_q_expect
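Putting this hunk's comments together: the DPP target uses an abstract operator L (supplied by subclasses through _l_operator) in place of the max, and the per-transition discount replaces self.gamma exactly as in the other agents. Roughly, in LaTeX:

    \text{target}_t = r_t + \gamma_t \, (1 - d_t) \, L Q'(s_{t+1})
    t_t = Q'(s_t, a_t) + \text{target}_t - L Q'(s_t)

where Q' is the target network; this mirrors the inline comments r + g * LQ'(s_{t+1},a) and Q'(s_t,a_t) + r + g * LQ'(s_{t+1},a) - LQ'(s_t,a).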
47 changes: 25 additions & 22 deletions chainerrl/agents/dqn.py
@@ -201,28 +201,27 @@ def update(self, experiences, errors_out=None):
This function is thread-safe.
Args:
- experiences (list): list of dict that contains
+ experiences (list): list of dicts that contains
state: cupy.ndarray or numpy.ndarray
action: int [0, n_action_types)
reward: float32
next_state: cupy.ndarray or numpy.ndarray
next_legal_actions: list of booleans; True means legal
- gamma (float): discount factor
Returns:
None
"""

- has_weight = 'weight' in experiences[0]
- exp_batch = batch_experiences(experiences, xp=self.xp, phi=self.phi,
- batch_states=self.batch_states)
+ has_weight = 'weight' in experiences[0][0]
+ exp_batch = batch_experiences(
+ experiences, xp=self.xp,
+ phi=self.phi, gamma=self.gamma,
+ batch_states=self.batch_states)
if has_weight:
exp_batch['weights'] = self.xp.asarray(
- [elem['weight'] for elem in experiences],
+ [elem[0]['weight']for elem in experiences],
dtype=self.xp.float32)
if errors_out is None:
errors_out = []
- loss = self._compute_loss(
- exp_batch, self.gamma, errors_out=errors_out)
+ loss = self._compute_loss(exp_batch, errors_out=errors_out)
if has_weight:
self.replay_buffer.update_errors(errors_out)

@@ -250,6 +249,7 @@ def update_from_episodes(self, episodes, errors_out=None):
for _ in episodes:
errors_out.append(0.0)
errors_out_step = []

with state_reset(self.model), state_reset(self.target_model):
loss = 0
tmp = list(reversed(sorted(
@@ -266,16 +266,18 @@
transitions.append(ep[i])
if has_weights:
weights_step.append(weights[index])
- batch = batch_experiences(transitions,
- xp=self.xp,
- phi=self.phi,
- batch_states=self.batch_states)
+ batch = batch_experiences(
+ [transitions],
+ xp=self.xp,
+ phi=self.phi,
+ gamma=self.gamma,
+ batch_states=self.batch_states)
if i == 0:
self.input_initial_batch_to_target_model(batch)
if has_weights:
batch['weights'] = self.xp.asarray(
weights_step, dtype=self.xp.float32)
- loss += self._compute_loss(batch, self.gamma,
+ loss += self._compute_loss(batch,
errors_out=errors_out_step)
if errors_out is not None:
for err, index in zip(errors_out_step, indices):
@@ -293,18 +295,19 @@
if has_weights:
self.replay_buffer.update_errors(errors_out)

- def _compute_target_values(self, exp_batch, gamma):
+ def _compute_target_values(self, exp_batch):
batch_next_state = exp_batch['next_state']

target_next_qout = self.target_model(batch_next_state)
next_q_max = target_next_qout.max

batch_rewards = exp_batch['reward']
batch_terminal = exp_batch['is_state_terminal']
+ discount = exp_batch['discount']

- return batch_rewards + self.gamma * (1.0 - batch_terminal) * next_q_max
+ return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max

- def _compute_y_and_t(self, exp_batch, gamma):
+ def _compute_y_and_t(self, exp_batch):
batch_size = exp_batch['reward'].shape[0]

# Compute Q-values for current states
@@ -318,22 +321,22 @@ def _compute_y_and_t(self, exp_batch, gamma):

with chainer.no_backprop_mode():
batch_q_target = F.reshape(
- self._compute_target_values(exp_batch, gamma),
+ self._compute_target_values(exp_batch),
(batch_size, 1))

return batch_q, batch_q_target

- def _compute_loss(self, exp_batch, gamma, errors_out=None):
+ def _compute_loss(self, exp_batch, errors_out=None):
"""Compute the Q-learning loss for a batch of experiences
Args:
experiences (list): see update()'s docstring
- gamma (float): discount factor
+ discount (float): Amount by the Q-values should be discounted
Returns:
- loss
+ Computed loss from the minibatch of experiences
"""
- y, t = self._compute_y_and_t(exp_batch, gamma)
+ y, t = self._compute_y_and_t(exp_batch)

if errors_out is not None:
del errors_out[:]
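The common thread of this file's changes is that the discount no longer travels as a separate gamma argument: batch_experiences now receives gamma, and the loss reads exp_batch['discount'] instead. The batch is also built from lists of transitions (experiences[0][0]['weight'], [transitions]), which suggests each batched element may be a multi-step sequence whose effective discount is a power of gamma. A hypothetical sketch of that idea (this is not ChainerRL's batch_experiences, and the helper name is invented):

    import numpy as np

    def per_element_discounts(experiences, gamma):
        """Hypothetical: one discount per batched element.

        For an n-step element the tail state is discounted by gamma ** n,
        which would explain storing the discount in exp_batch['discount']
        rather than passing a single scalar to _compute_loss."""
        return np.asarray([gamma ** len(steps) for steps in experiences],
                          dtype=np.float32)

    # e.g. one 1-step and one 3-step element
    dummy = [[{'reward': 1.0}], [{'reward': 0.0}] * 3]
    print(per_element_discounts(dummy, 0.99))   # roughly [0.99, 0.970299]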
4 changes: 2 additions & 2 deletions chainerrl/agents/pal.py
@@ -29,7 +29,7 @@ def __init__(self, *args, **kwargs):
self.alpha = kwargs.pop('alpha', 0.9)
super().__init__(*args, **kwargs)

- def _compute_y_and_t(self, exp_batch, gamma):
+ def _compute_y_and_t(self, exp_batch):

batch_state = exp_batch['state']
batch_size = len(exp_batch['reward'])
@@ -55,7 +55,7 @@ def _compute_y_and_t(self, exp_batch, gamma):
batch_terminal = exp_batch['is_state_terminal']

# T Q: Bellman operator
- t_q = batch_rewards + self.gamma * \
+ t_q = batch_rewards + exp_batch['discount'] * \
(1.0 - batch_terminal) * next_q_max

# T_PAL Q: persistent advantage learning operator
3 changes: 2 additions & 1 deletion chainerrl/agents/pcl.py
@@ -303,10 +303,11 @@ def update_from_replay(self):
for ep in sorted_episodes:
if len(ep) <= t:
break
- transitions.append(ep[t])
+ transitions.append([ep[t]])
batch = batch_experiences(transitions,
xp=self.xp,
phi=self.phi,
+ gamma=self.gamma,
batch_states=self.batch_states)
batchsize = batch['action'].shape[0]
if next_action_distrib is not None:
5 changes: 3 additions & 2 deletions chainerrl/agents/sarsa.py
@@ -16,7 +16,7 @@ class SARSA(dqn.DQN):
compute target Q values, thus is an on-policy algorithm.
"""

- def _compute_target_values(self, exp_batch, gamma):
+ def _compute_target_values(self, exp_batch):

batch_next_state = exp_batch['next_state']
batch_next_action = exp_batch['next_action']
@@ -28,5 +28,6 @@ def _compute_target_values(self, exp_batch, gamma):

batch_rewards = exp_batch['reward']
batch_terminal = exp_batch['is_state_terminal']
+ discount = exp_batch['discount']

- return batch_rewards + self.gamma * (1.0 - batch_terminal) * next_q
+ return batch_rewards + discount * (1.0 - batch_terminal) * next_q
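Unlike the DQN target, the SARSA target evaluates the action actually taken at the next step (exp_batch['next_action']) rather than the greedy one, which is what makes it on-policy, as the class docstring notes. With the per-transition discount introduced here:

    y_t = r_t + \gamma_t \, (1 - d_t) \, Q'(s_{t+1}, a_{t+1})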