Skip to content

Commit

Permalink
Merge pull request #534 from muupan/fix-sarsa-gpu
Browse files Browse the repository at this point in the history
Fix ValueError in SARSA with GPU
  • Loading branch information
muupan committed Aug 30, 2019
2 parents 0158d0b + a3cf4b2 commit 49425d0
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion chainerrl/agents/sarsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from future import standard_library
standard_library.install_aliases() # NOQA

import chainer

from chainerrl.agents import dqn


Expand All @@ -27,9 +29,11 @@ def _compute_target_values(self, exp_batch):
else:
target_next_qout = self.target_model(batch_next_state)
# Choose an action using the behavior policy
next_greedy_actions = chainer.cuda.to_cpu(
target_next_qout.greedy_actions.array)
batch_next_action = self.xp.array([
self.explorer.select_action(
self.t, lambda: target_next_qout.greedy_actions.array[i],
self.t, lambda: next_greedy_actions[i],
action_value=target_next_qout[i:i + 1],
)
for i in range(len(exp_batch['action']))])
Expand Down

0 comments on commit 49425d0

Please sign in to comment.