
Recurrent DQN families with a new interface #436

Merged: 57 commits merged into chainer:master on Aug 9, 2019

Conversation

@muupan (Member) commented on Apr 7, 2019

Merge #431 before this PR.

This PR resolves #112

  • Use the new StatelessRecurrent interface to support recurrent models in DQN variants (a rough usage sketch follows this list).
  • Add examples/ale/train_drqn_ale.py as a solid example of recurrent DQN.
    • Remove options for recurrent DQN from other examples for simplicity.
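
Below is a minimal sketch of how these pieces might fit together, modeled loosely on examples/ale/train_drqn_ale.py. The exact names (StatelessRecurrentSequential, the recurrent= and episodic_update_len= arguments, the layer sizes) are assumptions here and should be checked against the merged code.

```python
# Hedged sketch only: class and argument names below are assumed from the PR
# description and may differ from the merged implementation.
import functools

import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np

import chainerrl
from chainerrl import explorers, replay_buffer
from chainerrl.action_value import DiscreteActionValue

n_actions = 6  # e.g. Pong

# Recurrent Q-network: Nature-DQN conv stack followed by an LSTM layer,
# built with the (assumed) StatelessRecurrent interface.
q_func = chainerrl.links.StatelessRecurrentSequential(
    L.Convolution2D(None, 32, 8, stride=4), F.relu,
    L.Convolution2D(None, 64, 4, stride=2), F.relu,
    L.Convolution2D(None, 64, 3, stride=1), F.relu,
    functools.partial(F.reshape, shape=(-1, 3136)),
    L.NStepLSTM(1, 3136, 512, 0),
    L.Linear(None, n_actions),
    DiscreteActionValue,
)

opt = chainer.optimizers.Adam(1e-4, eps=1e-4)
opt.setup(q_func)

# Recurrent replay draws whole subsequences, so an episodic buffer is used.
rbuf = replay_buffer.EpisodicReplayBuffer(10 ** 6)

explorer = explorers.LinearDecayEpsilonGreedy(
    1.0, 0.01, 10 ** 6, lambda: np.random.randint(n_actions))

agent = chainerrl.agents.DoubleDQN(
    q_func, opt, rbuf, gamma=0.99, explorer=explorer,
    minibatch_size=32,        # 32 subsequences per batch
    episodic_update_len=10,   # each subsequence is up to 10 steps long
    recurrent=True,           # assumed flag introduced by this PR
)
```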

TODO

  • Evaluate recurrent DQN on flickered Atari
  • Compare computational efficiency against the old recurrent interface.
  • Fix all the affected agents that inherit DQN
  • Support CategoricalDoubleDQN

@muupan changed the title from "[WIP] Recurrent DQN with a new interface" to "[WIP] Recurrent DQN families with a new interface" on Apr 8, 2019
It does not make much sense to use this without recurrent models, so supporting recurrent models in these examples can be future work.
@prabhatnagarajan (Contributor) left a comment:

Can you add checkboxes in the main ChainerRL repo README for the models that are now supported as recurrent?

Resolved review threads (outdated): examples/ale/train_drqn_ale.py (×2), examples/ale/train_ppo_ale.py
@muupan (Member, Author) commented on May 6, 2019:

I performed experiments to validate recurrent DQN on flickering Pong with 1-frame observations and p=0.5.

  • DoubleDQN 1-frame flicker: examples/ale/train_drqn_ale.py --flicker --no-frame-stack --env PongNoFrameskip-v4
    • not recurrent, batch size is 32
  • DoubleDQN 1-frame flicker recurrent: examples/ale/train_drqn_ale.py --flicker --no-frame-stack --recurrent --env PongNoFrameskip-v4
    • recurrent, each batch consists of 32 subsequences of up to 10 steps.
  • DoubleDQN 1-frame flicker recurrent small: examples/ale/train_drqn_ale.py --flicker --no-frame-stack --recurrent --batch-size 4 --episodic-update-len 8 --env PongNoFrameskip-v4
    • recurrent, each batch consists of 4 subsequences of up to 8 steps.

Each configuration is evaluated with 3 trials with 3 different random seeds.
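
For reference, the flickering observation setting used above (each frame blanked with probability p) can be implemented roughly as the wrapper sketched below. This is an illustrative assumption of what --flicker does, not the actual code in train_drqn_ale.py.

```python
# Hedged sketch of a flickering-observation wrapper (the real --flicker
# implementation may differ): with probability `flicker_prob`, the current
# frame is replaced by an all-zero frame, so the agent has to rely on its
# recurrent state rather than the current observation alone.
import gym
import numpy as np


class FlickerFrame(gym.ObservationWrapper):
    """Blank out each observation with probability `flicker_prob`."""

    def __init__(self, env, flicker_prob=0.5):
        super().__init__(env)
        self.flicker_prob = flicker_prob

    def observation(self, observation):
        if np.random.rand() < self.flicker_prob:
            return np.zeros_like(observation)
        return observation


# Usage (hypothetical): env = FlickerFrame(gym.make('PongNoFrameskip-v4'), 0.5)
```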


As you can see from the elapsed column, the recurrent configuration is ~3x slower than the non-recurrent one, while recurrent small is ~2x slower.

DoubleDQN 1-frame flicker (non-recurrent):

| steps | episodes | elapsed (s) | mean | median | stdev | max | min | average_q | average_loss | n_updates |
|---|---|---|---|---|---|---|---|---|---|---|
| 250354 | 274 | 1364 | -20.97 | -21.0 | 0.17 | -20.0 | -21.0 | -0.155 | 0.0107 | 50089 |
| 500592 | 558 | 2848 | -21.00 | -21.0 | 0.00 | -21.0 | -21.0 | -0.338 | 0.0106 | 112648 |
| 750113 | 849 | 4325 | -20.99 | -21.0 | 0.08 | -20.0 | -21.0 | -0.518 | 0.0114 | 175029 |
| 1000676 | 1151 | 5812 | -21.00 | -21.0 | 0.00 | -21.0 | -21.0 | -0.667 | 0.0119 | 237669 |

DoubleDQN 1-frame flicker recurrent:

| steps | episodes | elapsed (s) | mean | median | stdev | max | min | average_q | average_loss | n_updates |
|---|---|---|---|---|---|---|---|---|---|---|
| 250750 | 271 | 3797 | -20.42 | -20.0 | 0.62 | -18.0 | -21.0 | -0.161 | 0.0057 | 50047 |
| 500779 | 532 | 8021 | -20.63 | -21.0 | 0.49 | -20.0 | -21.0 | -0.322 | 0.0040 | 112554 |
| 751002 | 767 | 12226 | -20.05 | -20.0 | 0.82 | -18.0 | -21.0 | -0.343 | 0.0047 | 175110 |
| 1000710 | 962 | 16430 | -18.55 | -19.0 | 1.54 | -14.0 | -21.0 | -0.353 | 0.0043 | 237537 |

DoubleDQN 1-frame flicker recurrent small:

| steps | episodes | elapsed (s) | mean | median | stdev | max | min | average_q | average_loss | n_updates |
|---|---|---|---|---|---|---|---|---|---|---|
| 250325 | 276 | 2906 | -21.00 | -21.0 | 0.00 | -21.0 | -21.0 | -0.156 | 0.0117 | 49941 |
| 500477 | 551 | 6064 | -20.91 | -21.0 | 0.29 | -20.0 | -21.0 | -0.358 | 0.0070 | 112479 |
| 750165 | 828 | 9194 | -20.72 | -21.0 | 0.53 | -19.0 | -21.0 | -0.469 | 0.0054 | 174901 |
| 1000720 | 1090 | 12292 | -20.02 | -20.0 | 1.05 | -16.0 | -21.0 | -0.508 | 0.0046 | 237539 |

Resolved review threads: chainerrl/replay_buffer.py, chainerrl/agents/dqn.py (×3), chainerrl/agents/iqn.py (×2), chainerrl/agents/sarsa.py (×3)
@prabhatnagarajan (Contributor) left a comment:

Almost done. Can you make these small changes?

Resolved review threads (outdated): tests/agents_tests/test_iqn.py, examples_tests/atari/test_drqn.sh (×2)
muupan and others added 3 commits on August 9, 2019 (co-authored by Prabhat Nagarajan).
@prabhatnagarajan (Contributor) left a comment:

Looks good. Please take a look at the comments for minor improvements.

Review thread on examples/atari/train_drqn_ale.py (outdated, resolved):
    rbuf = replay_buffer.EpisodicReplayBuffer(10 ** 6)
else:
    # Q-network without LSTM
    q_func = chainer.Sequential(
muupan (Member, Author) replied:

Partly because it is consistent with the recurrent version, and partly because I think it is easier to understand, and thus better as an example.
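
For context, the non-recurrent Q-network being discussed probably looks something like the sketch below, mirroring the layer structure of the recurrent model; this is an assumed reconstruction, not the exact code from the example.

```python
# Hedged sketch of a non-recurrent Q-network written with chainer.Sequential,
# mirroring the (assumed) recurrent model so the two variants stay consistent.
import chainer
import chainer.functions as F
import chainer.links as L
from chainerrl.action_value import DiscreteActionValue

n_actions = 6

q_func = chainer.Sequential(
    L.Convolution2D(None, 32, 8, stride=4), F.relu,
    L.Convolution2D(None, 64, 4, stride=2), F.relu,
    L.Convolution2D(None, 64, 3, stride=1), F.relu,
    L.Linear(None, 512), F.relu,   # L.Linear flattens the conv features itself
    L.Linear(None, n_actions),
    DiscreteActionValue,
)
```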

Review thread (diff context):

args.final_exploration_frames,
lambda: np.random.randint(n_actions))

opt = chainer.optimizers.Adam(1e-4, eps=1e-4)
Contributor:

I'm assuming this should be fine, but why did you use Adam?

muupan (Member, Author):

It is just because Adam seems to be preferred in the recent literature. I don't mean to reproduce any particular paper in this example.

muupan and others added 2 commits on August 9, 2019 (co-authored by Prabhat Nagarajan).
@muupan merged commit 36aa37c into chainer:master on Aug 9, 2019.
@muupan deleted the recurrent-dqn branch on Aug 9, 2019.
@muupan added this to the v0.8 milestone on Feb 6, 2020.
Linked issue: Partial batch computation for recurrent models with replay (#112)