
RL4J - Fix for setTarget() (issue #8107) #8250

Merged
2 commits merged into eclipse:master from aboulang2002:ab2002_rl4j_fixSetTarget on Oct 1, 2019
Conversation

@aboulang2002 (Contributor) commented Sep 22, 2019

What changes were proposed in this pull request?

See issue #8107

After re-reading "Playing Atari with Deep Reinforcement Learning" (http://arxiv.org/abs/1312.5602), I still think there is a problem with the current implementation of standard DQN, but I don't think the target-network should be used for dqnOutputAr.

I have also read section 3, "Understanding Deep Q-Network", of "A Theoretical Analysis of Deep Q-Learning" (https://arxiv.org/pdf/1901.00137.pdf), as well as https://medium.com/@qempsil0914/deep-q-learning-part2-double-deep-q-network-double-dqn-b8fc9212bbb2, which gives a good overview of DQN.
I also reviewed a few DQN implementations on GitHub, for example TensorLayer's tutorial_DQN_variants.py (https://github.com/tensorlayer/tensorlayer/blob/72666dbf23718e8710dc1069dbcf3231d131dff2/examples/reinforcement_learning/tutorial_DQN_variants.py).

First, the target-network should be used to compute tempQ in the standard DQN algorithm.
This can be seen in equation 3.1 (and the one just above it in the text) of https://arxiv.org/pdf/1901.00137.pdf, and also in lines 6-8 of tutorial_DQN_variants.py:

	The max operator in standard DQN uses the same values both to select and to evaluate an action by
	Q(s_t, a_t) = R_{t+1} + \gamma * max_{a}Q_{tar}(s_{t+1}, a)

We can see that the target-network (Q_{tar}) should be used.
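
To make this concrete, here is a minimal sketch of the standard-DQN TD target with the target-network used for the max, as in equation 3.1. It is plain Java with hypothetical names (TdTargetSketch, targetNetworkOutput, tdTarget) and plain double[] arrays standing in for network outputs; it is not the actual RL4J code, just an illustration of the formula above.

    // Sketch only: "targetNetworkOutput" stands for a forward pass of the frozen
    // target-network. Hypothetical placeholder, not the RL4J API.
    public final class TdTargetSketch {

        // One Q-value per action for the next state s', from the target-network.
        static double[] targetNetworkOutput(double[] nextState) {
            return new double[] { 0.1, 0.3, -0.2 }; // placeholder values
        }

        // y = r + gamma * max_a Q_tar(s', a); just r if the episode terminated.
        static double tdTarget(double reward, double[] nextState, boolean isTerminal, double gamma) {
            if (isTerminal) {
                return reward;
            }
            double maxQ = Double.NEGATIVE_INFINITY;
            for (double q : targetNetworkOutput(nextState)) {
                maxQ = Math.max(maxQ, q);
            }
            return reward + gamma * maxQ;
        }

        public static void main(String[] args) {
            double[] nextState = { 0.0, 0.0, 0.0, 0.0 };               // e.g. a cartpole-like observation
            System.out.println(tdTarget(1.0, nextState, false, 0.99)); // prints roughly 1.297
        }
    }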

Second, the target-network is only used as a stable base to compute the loss function, not "to obtain the labels" as @flyheroljg says. This means that the Q-Network should be used for dqnOutputAr. This can be seen on line 8 of tutorial_DQN_variants.py, and also on lines 322-325:

    # calculate loss
    with tf.GradientTape() as q_tape:
        b_q = tf.reduce_sum(qnet(b_o) * tf.one_hot(b_a, out_dim), 1)
        loss = tf.reduce_mean(huber_loss(b_q - (b_r + reward_gamma * b_q_)))

The above snippet is for Double-DQN, but only the td-target part ((b_r + reward_gamma * b_q_) above) differs between standard and double DQN in the computation of the loss function.
This can also be seen in the section "Double Q-Learning Algorithm" of https://medium.com/@qempsil0914/deep-q-learning-part2-double-deep-q-network-double-dqn-b8fc9212bbb2, as well as in equation 3.3 of https://arxiv.org/pdf/1901.00137.pdf.
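
Putting the two points together: the labels passed to the loss (what setTarget() builds as dqnOutputAr) can start from the Q-Network's own predictions and only overwrite the entry of the action actually taken with the TD target computed from the target-network, so the error is zero for every other action. Again a minimal sketch with hypothetical names (LabelSketch, buildLabel) and plain arrays, not the actual RL4J implementation:

    // Sketch only: builds one training label row. "onlineQ" is the Q-Network's
    // output for state s; "tdTarget" is r + gamma * max_a Q_tar(s', a) computed
    // with the target-network (see the previous sketch). Hypothetical names.
    public final class LabelSketch {

        static double[] buildLabel(double[] onlineQ, int actionTaken, double tdTarget) {
            double[] label = onlineQ.clone();   // unchanged entries give zero error, hence no gradient
            label[actionTaken] = tdTarget;      // only the taken action is pushed toward the TD target
            return label;
        }

        public static void main(String[] args) {
            double[] onlineQ = { 0.2, 0.5, -0.1 };            // Q(s, .) from the online Q-Network
            double[] label = buildLabel(onlineQ, 1, 0.9);     // action 1 was taken, TD target 0.9
            System.out.println(java.util.Arrays.toString(label)); // [0.2, 0.9, -0.1]
        }
    }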

How was this patch tested?

With cartpole-v0 (re-written in Java because I couldn't get gym's version to work; I can add it to this PR if requested).

@saudet self-requested a review Sep 30, 2019
@saudet (Member) approved these changes Sep 30, 2019 and left a comment:

That looks good! Please do add your version of Cartpole, and let's merge this :) Thanks

@aboulang2002 (Contributor, Author) commented Sep 30, 2019

I added CartpoleNative, a native Java port of GYM's cartpole-v0. The rendering has not been ported; I may add it later, but it's lower on my list of todos.

@saudet merged commit 5959ff4 into eclipse:master on Oct 1, 2019
1 check passed — eclipsefdn/eca: The author(s) of the pull request is covered by necessary legal agreements in order to proceed!
@aboulang2002 deleted the aboulang2002:ab2002_rl4j_fixSetTarget branch Oct 1, 2019