Recurrent DQN families with a new interface #436

Merged
Merged 57 commits into master from recurrent-dqn on Aug 9, 2019

Commits (57)
a65be21
Merge branch 'recurrent-ppo-squash' into recurrent-dqn
muupan Apr 6, 2019
03f1e06
Use new recurrent model in DQN
muupan Apr 7, 2019
19b8051
Fix double dqn
muupan Apr 7, 2019
91bc211
Add DRQN example
muupan Apr 7, 2019
a85f32a
Fix dtype error
muupan Apr 7, 2019
24459d8
Use mean of loss to avoid effect from sequence length
muupan Apr 7, 2019
1bdfaa6
Stabilize tests
muupan Apr 8, 2019
b0a4063
Renew recurrent AL
muupan Apr 8, 2019
ac6bd4d
Renew recurrent PAL
muupan Apr 8, 2019
f57e1d9
Renew recurrent DoublePAL
muupan Apr 8, 2019
78e0059
Add recurrent DPP
muupan Apr 8, 2019
15740ff
Make SARSA off-policy, batched, and recurrent
muupan Apr 8, 2019
9783ab7
Renew recurrent CategoricalDQN
muupan Apr 8, 2019
8ed0587
Make IQN recurrent
muupan Apr 8, 2019
e9efeaf
Fix to support py2
muupan Apr 8, 2019
9f41f5c
Renew recurrent ResidualDQN
muupan Apr 8, 2019
c1a1581
Remove --episodic-replay options
muupan Apr 8, 2019
b6fdf53
Update examples/ale/train_drqn_ale.py
prabhatnagarajan May 6, 2019
90f0b5a
Update examples/ale/train_drqn_ale.py
prabhatnagarajan May 6, 2019
9febe00
Update examples/ale/train_ppo_ale.py
prabhatnagarajan May 6, 2019
c7a519d
Add options
muupan Apr 11, 2019
162330a
Merge branch 'recurrent-ppo-squash' into recurrent-dqn
muupan May 6, 2019
496ffc6
Merge branch 'master' into recurrent-dqn
muupan May 6, 2019
bb40059
Use the new recurrent interface for CategoricalDoubleDQN
muupan May 6, 2019
6dc49e9
Mark IQN as supporting recurrent
muupan May 6, 2019
730ee88
Merge branch 'master' into recurrent-dqn
muupan May 11, 2019
61c8d06
Merge branch 'master' into recurrent-dqn
muupan Jun 25, 2019
c631ed7
Add --final-eval-n-episodes for faster testing
muupan Jun 25, 2019
cdf3d2c
Add test for examples/ale/train_drqn_ale.py
muupan Jun 25, 2019
2194d7d
Move to examples/atari
muupan Jul 1, 2019
d974719
Merge branch 'master' into recurrent-dqn
muupan Jul 1, 2019
7e1ef1a
Rename the option
muupan Jul 1, 2019
ba5bfc3
Add test for train_drqn_ale.py
muupan Jul 1, 2019
573d29b
Add DRQN to examples/atari/README.md
muupan Jul 1, 2019
cf848d6
Remove examples_tests/ale/test_drqn.sh
muupan Jul 1, 2019
6ce8fb6
Update chainerrl/agents/dqn.py
muupan Jul 1, 2019
55241fd
Update examples_tests/atari/test_drqn.sh
muupan Jul 1, 2019
52105cd
Update examples_tests/atari/test_drqn.sh
muupan Jul 1, 2019
8499465
Update examples_tests/atari/test_drqn.sh
muupan Jul 1, 2019
8f89ee4
Merge branch 'master' into recurrent-dqn
muupan Jul 17, 2019
eccee63
Update chainerrl/replay_buffer.py
muupan Jul 18, 2019
0722695
Merge branch 'master' into recurrent-dqn
muupan Jul 18, 2019
ea692c7
Add DRQN to README
muupan Jul 18, 2019
df482b2
Add assert message
muupan Jul 18, 2019
aea32d7
Update chainerrl/agents/sarsa.py
muupan Jul 31, 2019
4100079
Update chainerrl/agents/sarsa.py
muupan Jul 31, 2019
c5df80a
Restore chainer.using_config('train', False)
muupan Jul 31, 2019
2c68749
Reduce redundancy
muupan Jul 31, 2019
f0a67bd
Use a full path
muupan Jul 31, 2019
65b25cd
Merge two methods into one
muupan Jul 31, 2019
2360686
Merge branch 'master' into recurrent-dqn
muupan Jul 31, 2019
78b5e3c
Fix syntax error with python2
muupan Jul 31, 2019
4e5d0b9
Update examples_tests/atari/test_drqn.sh
muupan Aug 9, 2019
fdca04d
Remove commented out code
muupan Aug 9, 2019
22d13df
Update examples_tests/atari/test_drqn.sh
muupan Aug 9, 2019
467f845
Update examples/atari/train_drqn_ale.py
muupan Aug 9, 2019
9070c76
Add a link to DRQN paper
muupan Aug 9, 2019
Files changed
6 changes: 4 additions & 2 deletions README.md
@@ -41,7 +41,7 @@ For more information, you can refer to [ChainerRL's documentation](http://chaine
| DQN (including DoubleDQN etc.) | ✓ | ✓ (NAF) | ✓ | x |
| Categorical DQN | ✓ | x | ✓ | x |
| Rainbow | ✓ | x | ✓ | x |
| IQN | ✓ | x | x | x |
| IQN | ✓ | x | ✓ | x |
| DDPG | x | ✓ | ✓ | x |
| A3C | ✓ | ✓ | ✓ | ✓ |
| ACER | ✓ | ✓ | ✓ | ✓ |
@@ -63,7 +63,7 @@ Following algorithms have been implemented in ChainerRL:
- [Categorical DQN](https://arxiv.org/abs/1707.06887)
- examples: [[atari]](examples/atari/train_categorical_dqn_ale.py) [[general gym]](examples/gym/train_categorical_dqn_gym.py)
- [DQN (Deep Q-Network)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) (including [Double DQN](https://arxiv.org/abs/1509.06461), [Persistent Advantage Learning (PAL)](https://arxiv.org/abs/1512.04860), Double PAL, [Dynamic Policy Programming (DPP)](http://www.jmlr.org/papers/volume13/azar12a/azar12a.pdf))
- examples: [[atari reproduction]](examples/atari/reproduction/dqn) [[atari]](examples/atari/train_dqn_ale.py) [[atari (batched)]](examples/atari/train_dqn_batch_ale.py) [[general gym]](examples/gym/train_dqn_gym.py)
- examples: [[atari reproduction]](examples/atari/reproduction/dqn) [[atari]](examples/atari/train_dqn_ale.py) [[atari (batched)]](examples/atari/train_dqn_batch_ale.py) [[flickering atari]](examples/atari/train_drqn_ale.py) [[general gym]](examples/gym/train_dqn_gym.py)
- [DDPG (Deep Deterministic Policy Gradients)](https://arxiv.org/abs/1509.02971) (including [SVG(0)](https://arxiv.org/abs/1510.09142))
- examples: [[mujoco reproduction]](examples/mujoco/reproduction/ddpg) [[mujoco]](examples/mujoco/train_ddpg_gym.py) [[mujoco (batched)]](examples/mujoco/train_ddpg_batch_gym.py)
- [IQN (Implicit Quantile Networks)](https://arxiv.org/abs/1806.06923)
@@ -90,6 +90,8 @@ Following useful techniques have been also implemented in ChainerRL:
- examples: [[Rainbow]](examples/atari/reproduction/rainbow) [[DQN/DoubleDQN/PAL]](examples/atari/train_dqn_ale.py)
- [Normalized Advantage Function](https://arxiv.org/abs/1603.00748)
- examples: [[DQN]](examples/gym/train_dqn_gym.py) (for continuous-action envs only)
- [Deep Recurrent Q-Network](https://arxiv.org/abs/1507.06527)
- examples: [[DQN]](examples/atari/train_drqn_ale.py)


## Visualization
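For orientation, here is a minimal sketch of how a recurrent Q-function might be plugged into the new interface this PR introduces, loosely modelled on the DRQN example (examples/atari/train_drqn_ale.py) referenced in the README hunk above. The helper names (`StatelessRecurrentSequential`, `EpisodicReplayBuffer`) and the DQN constructor arguments (`recurrent`, `episodic_update_len`, `batch_accumulator`) are assumptions inferred from this PR's diffs and commit messages, not a verbatim copy of the example.

```python
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import numpy as np

n_actions = 4  # hypothetical action-space size

# Assumed helper from the new recurrent interface: a sequential link that
# threads recurrent state through NStep* links and exposes the
# n_step_forward() method used throughout the agent diffs below
# (subsequences in, concatenated outputs plus updated recurrent state out).
q_func = chainerrl.links.StatelessRecurrentSequential(
    L.Linear(None, 64),
    F.relu,
    L.NStepLSTM(1, 64, 64, 0),  # recurrent core, DRQN-style
    L.Linear(64, n_actions),
    chainerrl.action_value.DiscreteActionValue,
)

opt = chainer.optimizers.Adam()
opt.setup(q_func)

# Episodic replay so that whole subsequences can be sampled and fed to
# n_step_forward() during updates.
rbuf = chainerrl.replay_buffer.EpisodicReplayBuffer(capacity=10 ** 5)

explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.1,
    random_action_func=lambda: np.random.randint(n_actions))

agent = chainerrl.agents.DQN(
    q_func, opt, rbuf,
    gamma=0.99,
    explorer=explorer,
    replay_start_size=1000,
    recurrent=True,            # take the new recurrent code path
    episodic_update_len=10,    # assumed name for the sampled sequence length
    batch_accumulator='mean',  # cf. the "Use mean of loss ..." commit above
)
```

The agent can then be trained with the usual `chainerrl.experiments` loops; only the Q-function, the replay buffer, and the `recurrent=True` flag differ from the feed-forward setup.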
28 changes: 17 additions & 11 deletions chainerrl/agents/al.py
@@ -10,7 +10,6 @@
from chainer import functions as F

from chainerrl.agents import dqn
from chainerrl.recurrent import state_kept


class AL(dqn.DQN):
@@ -34,22 +33,32 @@ def _compute_y_and_t(self, exp_batch):
batch_state = exp_batch['state']
batch_size = len(exp_batch['reward'])

qout = self.q_function(batch_state)
if self.recurrent:
qout, _ = self.model.n_step_forward(
batch_state, exp_batch['recurrent_state'],
output_mode='concat')
else:
qout = self.model(batch_state)

batch_actions = exp_batch['action']

batch_q = qout.evaluate_actions(batch_actions)

# Compute target values
batch_next_state = exp_batch['next_state']

with chainer.no_backprop_mode():
target_qout = self.target_q_function(batch_state)
if self.recurrent:
target_qout, _ = self.target_model.n_step_forward(
batch_state, exp_batch['recurrent_state'],
output_mode='concat')
target_next_qout, _ = self.target_model.n_step_forward(
batch_next_state, exp_batch['next_recurrent_state'],
output_mode='concat')
else:
target_qout = self.target_model(batch_state)
target_next_qout = self.target_model(batch_next_state)

batch_next_state = exp_batch['next_state']

with state_kept(self.target_q_function):
target_next_qout = self.target_q_function(
batch_next_state)
next_q_max = F.reshape(target_next_qout.max, (batch_size,))

batch_rewards = exp_batch['reward']
@@ -65,6 +74,3 @@ def _compute_y_and_t(self, exp_batch):
tal_q = t_q + self.alpha * cur_advantage

return batch_q, tal_q

def input_initial_batch_to_target_model(self, batch):
pass
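For reference, the `tal_q` returned above is the Advantage Learning target from the PAL paper linked in the README. The one-step target `t_q` and the advantage term `cur_advantage` are computed in lines omitted from this hunk; the sketch below assumes, consistent with the visible code, that both are evaluated with the target network Q'.

```latex
% Advantage Learning target (alpha = self.alpha, Q' = target network).
% The bracketed term is t_q; the second term is alpha * cur_advantage,
% where cur_advantage = Q'(s,a) - max_b Q'(s,b) <= 0.
y_{\mathrm{AL}}(s,a) = \bigl[\, r + \gamma \max_{a'} Q'(s',a') \,\bigr]
                     + \alpha \bigl( Q'(s,a) - \max_{b} Q'(s,b) \bigr)
```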
16 changes: 11 additions & 5 deletions chainerrl/agents/categorical_double_dqn.py
@@ -9,7 +9,6 @@

from chainerrl.agents import categorical_dqn
from chainerrl.agents.categorical_dqn import _apply_categorical_projection
from chainerrl.recurrent import state_kept


class CategoricalDoubleDQN(categorical_dqn.CategoricalDQN):
@@ -24,10 +23,17 @@ def _compute_target_values(self, exp_batch):
batch_rewards = exp_batch['reward']
batch_terminal = exp_batch['is_state_terminal']

with chainer.using_config('train', False), state_kept(self.q_function):
next_qout = self.q_function(batch_next_state)

target_next_qout = self.target_q_function(batch_next_state)
with chainer.using_config('train', False):
if self.recurrent:
target_next_qout, _ = self.target_model.n_step_forward(
batch_next_state, exp_batch['next_recurrent_state'],
output_mode='concat')
next_qout, _ = self.model.n_step_forward(
batch_next_state, exp_batch['next_recurrent_state'],
output_mode='concat')
else:
target_next_qout = self.target_model(batch_next_state)
next_qout = self.model(batch_next_state)

next_q_max = target_next_qout.evaluate_actions(
next_qout.greedy_actions)
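In equation form, the hunk above is the Double DQN idea applied to the distributional setting: the greedy next action is chosen from the online model's expected returns, while the return distribution that feeds the target comes from the target model (both run through `n_step_forward` when `self.recurrent` is set).

```latex
% Double action selection for Categorical Double DQN.
% Z_theta = online return distribution, Z_theta' = target return distribution,
% Phi = categorical projection onto the fixed atom support
% (cf. _apply_categorical_projection imported at the top of this file).
a^{*} = \arg\max_{a} \, \mathbb{E}\bigl[ Z_{\theta}(s', a) \bigr], \qquad
\text{target} = \Phi\bigl( r + \gamma\, Z_{\theta'}(s', a^{*}) \bigr)
```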
14 changes: 12 additions & 2 deletions chainerrl/agents/categorical_dqn.py
@@ -129,7 +129,12 @@ def _compute_target_values(self, exp_batch):
"""Compute a batch of target return distributions."""

batch_next_state = exp_batch['next_state']
target_next_qout = self.target_model(batch_next_state)
if self.recurrent:
target_next_qout, _ = self.target_model.n_step_forward(
batch_next_state, exp_batch['next_recurrent_state'],
output_mode='concat')
else:
target_next_qout = self.target_model(batch_next_state)

batch_rewards = exp_batch['reward']
batch_terminal = exp_batch['is_state_terminal']
@@ -158,7 +163,12 @@ def _compute_y_and_t(self, exp_batch):
batch_state = exp_batch['state']

# (batch_size, n_actions, n_atoms)
qout = self.model(batch_state)
if self.recurrent:
qout, _ = self.model.n_step_forward(
batch_state, exp_batch['recurrent_state'],
output_mode='concat')
else:
qout = self.model(batch_state)
n_atoms = qout.z_values.size

batch_actions = exp_batch['action']
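For completeness, the target distribution built from `target_next_qout` is, in lines not shown in this hunk, presumably projected back onto the fixed atom support, following the Categorical DQN paper linked in the README:

```latex
% C51 categorical projection (Bellemare et al., 2017).
% Atoms z_1..z_N with spacing Delta z; [.]_a^b denotes clipping to [a, b];
% p_j(s', a*) is the target network's probability mass on atom j.
\hat{\mathcal{T}} z_j = \bigl[ r + \gamma z_j \bigr]_{V_{\min}}^{V_{\max}}, \qquad
\bigl( \Phi \hat{\mathcal{T}} Z(s', a^{*}) \bigr)_i
  = \sum_{j} \Bigl[ 1 - \frac{\lvert \hat{\mathcal{T}} z_j - z_i \rvert}{\Delta z} \Bigr]_0^1
    p_j(s', a^{*})
```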
23 changes: 18 additions & 5 deletions chainerrl/agents/double_dqn.py
@@ -8,7 +8,6 @@
import chainer

from chainerrl.agents import dqn
from chainerrl.recurrent import state_kept


class DoubleDQN(dqn.DQN):
@@ -21,10 +20,24 @@ def _compute_target_values(self, exp_batch):

batch_next_state = exp_batch['next_state']

with chainer.using_config('train', False), state_kept(self.q_function):
next_qout = self.q_function(batch_next_state)

target_next_qout = self.target_q_function(batch_next_state)
with chainer.using_config('train', False):
if self.recurrent:
next_qout, _ = self.model.n_step_forward(
batch_next_state,
exp_batch['next_recurrent_state'],
output_mode='concat',
)
else:
next_qout = self.model(batch_next_state)

if self.recurrent:
target_next_qout, _ = self.target_model.n_step_forward(
batch_next_state,
exp_batch['next_recurrent_state'],
output_mode='concat',
)
else:
target_next_qout = self.target_model(batch_next_state)

next_q_max = target_next_qout.evaluate_actions(
next_qout.greedy_actions)
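The last visible line above is the Double DQN estimator: the online model selects the next action and the target model evaluates it. With the reward and discount terms added outside this hunk, the target is

```latex
% Double DQN target (van Hasselt et al., linked from the README).
% theta = online parameters, theta' = target-network parameters.
y = r + \gamma \, (1 - \mathrm{terminal}) \;
    Q_{\theta'}\!\bigl( s', \, \arg\max_{a'} Q_{\theta}(s', a') \bigr)
```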
30 changes: 20 additions & 10 deletions chainerrl/agents/double_pal.py
@@ -10,7 +10,6 @@
from chainer import functions as F

from chainerrl.agents import pal
from chainerrl.recurrent import state_kept


class DoublePAL(pal.PAL):
@@ -20,24 +19,35 @@ def _compute_y_and_t(self, exp_batch):
batch_state = exp_batch['state']
batch_size = len(exp_batch['reward'])

qout = self.q_function(batch_state)
if self.recurrent:
qout, _ = self.model.n_step_forward(
batch_state, exp_batch['recurrent_state'],
output_mode='concat')
else:
qout = self.model(batch_state)

batch_actions = exp_batch['action']
batch_q = qout.evaluate_actions(batch_actions)

# Compute target values

with chainer.no_backprop_mode():
target_qout = self.target_q_function(batch_state)

batch_next_state = exp_batch['next_state']
if self.recurrent:
next_qout, _ = self.model.n_step_forward(
batch_next_state, exp_batch['next_recurrent_state'],
output_mode='concat')
target_qout, _ = self.target_model.n_step_forward(
batch_state, exp_batch['recurrent_state'],
output_mode='concat')
target_next_qout, _ = self.target_model.n_step_forward(
batch_next_state, exp_batch['next_recurrent_state'],
output_mode='concat')
else:
next_qout = self.model(batch_next_state)
target_qout = self.target_model(batch_state)
target_next_qout = self.target_model(batch_next_state)

with state_kept(self.q_function):
next_qout = self.q_function(batch_next_state)

with state_kept(self.target_q_function):
target_next_qout = self.target_q_function(
batch_next_state)
next_q_max = F.reshape(target_next_qout.evaluate_actions(
next_qout.greedy_actions), (batch_size,))

21 changes: 18 additions & 3 deletions chainerrl/agents/dpp.py
@@ -30,7 +30,12 @@ def _compute_target_values(self, exp_batch):

batch_next_state = exp_batch['next_state']

target_next_qout = self.target_q_function(batch_next_state)
if self.recurrent:
target_next_qout, _ = self.target_model.n_step_forward(
batch_next_state, exp_batch['next_recurrent_state'],
output_mode='concat')
else:
target_next_qout = self.target_model(batch_next_state)
next_q_expect = self._l_operator(target_next_qout)

batch_rewards = exp_batch['reward']
@@ -44,7 +49,12 @@ def _compute_y_and_t(self, exp_batch):
batch_state = exp_batch['state']
batch_size = len(exp_batch['reward'])

qout = self.q_function(batch_state)
if self.recurrent:
qout, _ = self.model.n_step_forward(
batch_state, exp_batch['recurrent_state'],
output_mode='concat')
else:
qout = self.model(batch_state)

batch_actions = exp_batch['action']
# Q(s_t,a_t)
@@ -53,7 +63,12 @@

with chainer.no_backprop_mode():
# Compute target values
target_qout = self.target_q_function(batch_state)
if self.recurrent:
target_qout, _ = self.target_model.n_step_forward(
batch_state, exp_batch['recurrent_state'],
output_mode='concat')
else:
target_qout = self.target_model(batch_state)

# Q'(s_t,a_t)
target_q = F.reshape(target_qout.evaluate_actions(
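For context, `self._l_operator` applied above is the Boltzmann-softmax-style operator of Dynamic Policy Programming, and the two hunks gather the ingredients of the DPP action-preference update from the Azar et al. paper linked in the README (the final combination of `batch_q`, `target_q`, and the target values happens in lines not shown here):

```latex
% DPP action-preference recursion (Azar et al., 2012).
% Psi plays the role of the Q-function/action preferences; Psi' (the target
% network) is used on the right-hand side; L is the softmax-style operator
% implemented by _l_operator.
\Psi_{k+1}(s,a) = \Psi_k(s,a) + r + \gamma \, \mathcal{L}\Psi_k(s') - \mathcal{L}\Psi_k(s)
```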