Basic DQN #1014

Merged
zoq merged 14 commits into mlpack:master on Jun 16, 2017

@ShangtongZhang
Member

ShangtongZhang commented May 31, 2017

This (double) DQN implementation is ready for classical control problems (e.g. CartPole), but it is not yet ready for Atari games with pixel input.
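
For context, here is a minimal Armadillo sketch of the double-DQN target computation, the piece that distinguishes double DQN from plain DQN: the online network selects the greedy next action and the target network evaluates it. The function name, argument shapes, and the notTerminal mask are illustrative assumptions, not code from this PR.

#include <armadillo>

// qOnlineNext / qTargetNext: (numActions x batchSize) Q-values for the next
// states, produced by the online and target networks respectively.
// notTerminal(i) is 0 if transition i ended the episode, 1 otherwise.
arma::rowvec DoubleDQNTargets(const arma::mat& qOnlineNext,
                              const arma::mat& qTargetNext,
                              const arma::rowvec& rewards,
                              const arma::rowvec& notTerminal,
                              const double discount)
{
  arma::rowvec targets(rewards.n_elem);
  for (size_t i = 0; i < rewards.n_elem; ++i)
  {
    // Select the action with the online network, but evaluate it with the
    // target network; plain DQN would use the target network for both.
    const arma::uword bestAction = arma::index_max(qOnlineNext.col(i));
    targets(i) = rewards(i) +
        notTerminal(i) * discount * qTargetNext(bestAction, i);
  }
  return targets;
}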

(Outdated review comments on src/mlpack/methods/ann/ffn.hpp, hidden.)

/**
 * Implementation of various Q-Learning algorithms, such as DQN, double DQN.
 *
 * For more details, see the following:

@ShangtongZhang

ShangtongZhang May 31, 2017

Member

Not sure whether to refer to the DQN Nature paper. (The BibTeX entry for the Nature paper is really, really long.)

@zoq

zoq May 31, 2017

Member

What do you think about:

@article{Mnih2013,
  author    = {Volodymyr Mnih and
               Koray Kavukcuoglu and
               David Silver and
               Alex Graves and
               Ioannis Antonoglou and
               Daan Wierstra and
               Martin A. Riedmiller},
  title     = {Playing Atari with Deep Reinforcement Learning},
  journal   = {CoRR},
  year      = {2013},
  url       = {http://arxiv.org/abs/1312.5602}
}

It's not that short, but I think this will work just fine.

@ShangtongZhang

ShangtongZhang May 31, 2017

Member

It looks good.

(Outdated review comments on src/mlpack/methods/reinforcement_learning/q_learning_impl.hpp, hidden.)

@zoq
Member

zoq commented Jun 6, 2017

Looks ready to me; do you have anything else to add?

@rcurtin

Just a few comments from me, hope they are helpful.

@ShangtongZhang
Member

ShangtongZhang commented Jun 6, 2017

I suddenly realized we don't need to provide both Predict(mat&&, mat&) and Predict(const mat&, mat&). We can provide only Predict(mat m, mat&) and make sure the implementation uses std::move(m) internally. If the user wants to avoid a copy, they can invoke the function as Predict(std::move(..), ..) and there will be no copy. If the user wants to preserve the matrix, they can simply call Predict(..., ...); there is still only one copy, as before. Did I miss something? We could do the same thing for some constructors, as long as the class doesn't store a reference.
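
As an illustration of this calling convention, here is a minimal sketch with a made-up Model class; it is not mlpack's actual Predict() signature, just the by-value-plus-std::move pattern being proposed.

#include <armadillo>
#include <utility>

class Model
{
 public:
  // One by-value overload replaces Predict(const arma::mat&, arma::mat&)
  // and Predict(arma::mat&&, arma::mat&).
  void Predict(arma::mat predictors, arma::mat& results)
  {
    // 'predictors' is our own object here, so we may modify it in place and
    // hand it off with std::move(); the caller's matrix is only affected if
    // they passed it in with std::move() themselves.
    predictors.each_row() -= arma::mean(predictors, 0); // e.g. center the data
    results = std::move(predictors);
  }
};

int main()
{
  Model model;
  arma::mat data(10, 5, arma::fill::randu);
  arma::mat results;

  model.Predict(data, results);            // One copy; 'data' is preserved.
  model.Predict(std::move(data), results); // No copy; 'data' is consumed.
  return 0;
}

The copy, when one is needed, happens at the call site rather than being forced by the signature, so a single overload covers both use cases.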

@zoq
Member

zoq commented Jun 6, 2017

I think providing Predict(const mat&, mat&) is still a good idea, even if we do it just to explicitly tell the user that the function doesn't change the input.

@ShangtongZhang
Member

ShangtongZhang commented Jun 6, 2017

But I think Predict(mat, mat&) also explicitly tells the user that this function won't change the input.

@rcurtin
Member

rcurtin commented Jun 6, 2017

I read http://cpptruths.blogspot.com/2012/03/rvalue-references-in-constructor-when.html and I need to think it over. Maybe we can save a lot of code with this approach, but if we do this:

  • We should apply it to all mlpack code (so we should just open an issue about it, not do it all here, of course). It is probably best applied only to user-facing interfaces like constructors, Train(), and Predict().
  • We must make sure the documentation is clear that, whenever you can, you should pass your data with std::move(). This would be my big sticking point, since it is now much easier for the user to accidentally copy their data and get slow code. How would we best make this clear to the user? A comment in every constructor (one possible shape is sketched below)? Something else?
  • We should make sure that gcc and clang are able to optimize this the same as (or better than) what we do now.

But the benefit would be big, because it allows users to pass in some values to be copied and some to be moved, and it allows us to reduce the amount of boilerplate code by a pretty large amount.
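
As one possible answer to the documentation question in the second point, here is a sketch of what a comment on a by-value constructor could look like; the class and the wording are hypothetical, not existing mlpack documentation.

#include <armadillo>
#include <utility>

class Example
{
 public:
  /**
   * Construct the object, taking ownership of the given dataset.
   *
   * The dataset is taken by value: if you do not need your copy afterwards,
   * pass it with std::move() to avoid copying the data, e.g.
   * Example e(std::move(dataset)). Otherwise the data is copied and your
   * matrix is left untouched.
   *
   * @param dataset Training data (pass with std::move() to avoid a copy).
   */
  Example(arma::mat dataset) : dataset(std::move(dataset)) { }

 private:
  arma::mat dataset;
};

int main()
{
  arma::mat data(3, 100, arma::fill::randu);
  Example copied(data);           // 'data' is copied and preserved.
  Example moved(std::move(data)); // No copy; 'data' is left empty.
  return 0;
}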

@ShangtongZhang
Member

ShangtongZhang commented Jun 6, 2017

  • I agree we should open a new issue to refactor the existing codebase. But in the meantime, can we still adopt this incrementally in newly added code like this PR? I think the old functions are compatible with the new calling convention: if the user calls with std::move, neither the old nor the new function makes a copy; if the user calls directly, the new function makes one extra copy, while the old function makes at most one extra copy. And under the new calling convention, if the user doesn't use std::move, we can assume they can afford an extra copy.
  • Yeah, this could be a headache. Besides clear documentation, what I can think of is that we should also refactor our test cases carefully to show users correct usage examples. When I was a beginner with mlpack, I turned to the test cases to find out how to use a class.
  • I don't think this would be a big issue. As far as I know, this style doesn't depend on any special optimizations like NRVO; it's just a move constructor and works somewhat like the copy-and-swap idiom.

@zoq
Member

zoq commented Jun 7, 2017

> I agree we should open a new issue to refactor the existing codebase. But in the meantime, can we still adopt this incrementally in newly added code like this PR?

I'm fine with doing it for the ann code; it's a part that's not released yet, so I think we won't run into any issues.

@rcurtin
Member

rcurtin commented Jun 7, 2017

I thought about it more, and I can't see any reason not to switch to pass-by-value instead of having an lvalue-reference overload and an rvalue-reference overload. So I think we should do it throughout the mlpack codebase; I'd say go ahead and do it here too.

What's really nice about this is that it doesn't break the API, since you still interact with the methods in exactly the same way. I opened #1021 to handle the rest of the code.

ShangtongZhang added some commits Jun 8, 2017

@rcurtin

No more concerns from my end.

@ShangtongZhang
Member

ShangtongZhang commented Jun 14, 2017

This PR is still not fully backward compatible -- using optimizers not covered in #1026 with FFN and RNN will lead to a compile error. But I doubt anyone has actually used such unusual optimizers with FFN and RNN. I'm fine with delaying the merge of this PR until all the optimizers are ready.

@zoq

zoq approved these changes Jun 15, 2017

Just made some minor comments; we don't have to wait for the optimizer transition. I don't think anyone is going to use lbfgs in combination with this code in the near future.

 * @param stepLimit Maximum steps in each episode, 0 means no limit.
 * @param environment Reinforcement learning task.
 */
QLearning(NetworkType network,

@zoq

zoq Jun 15, 2017

Member

Do you mind using const for the discount, targetNetworkSyncInterval, etc. parameters, to be consistent with the style of the codebase?
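
A minimal sketch of the requested style, with const-qualified value parameters; the class and parameter list here are illustrative, not the PR's actual QLearning constructor.

#include <cstddef>

class Agent
{
 public:
  // const on by-value parameters, consistent with the codebase style.
  Agent(const double discount, const std::size_t targetNetworkSyncInterval) :
      discount(discount),
      targetNetworkSyncInterval(targetNetworkSyncInterval)
  { }

 private:
  double discount;
  std::size_t targetNetworkSyncInterval;
};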

{
  arma::Col<size_t> bestActions(actionValues.n_cols);
  arma::rowvec maxActionValues = arma::max(actionValues, 0);
  for (size_t i = 0; i < actionValues.n_cols; ++i)

@zoq

zoq Jun 15, 2017

Member

To improve readability, maybe we should use braces here?
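
For reference, a braced version of such a loop; the selection logic in the body is illustrative and not the PR's exact code.

#include <armadillo>

// Pick the index of the highest-valued action in each column of a
// (numActions x numStates) matrix of action values.
arma::Col<size_t> BestActions(const arma::mat& actionValues)
{
  arma::Col<size_t> bestActions(actionValues.n_cols);
  for (size_t i = 0; i < actionValues.n_cols; ++i)
  {
    // Braces around the body, even though it is a single statement.
    bestActions(i) = arma::index_max(actionValues.col(i));
  }
  return bestActions;
}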

@zoq
Member

zoq commented Jun 15, 2017

There is one more thing: can you remove the move from model.Train(X, Y, opt); in convolutional_network_test.cpp? That should resolve the issue with the failing test.

@ShangtongZhang
Member

ShangtongZhang commented Jun 15, 2017

Hope it's ready to merge now.

@zoq
Member

zoq commented Jun 15, 2017

Thanks for looking into the issues.

zoq merged commit 41cbc50 into mlpack:master on Jun 16, 2017

3 checks passed

Style Checks: Build finished.
continuous-integration/appveyor/pr: AppVeyor build succeeded.
continuous-integration/travis-ci/pr: The Travis CI build passed.
@zoq
Member

zoq commented Jun 16, 2017

Nice work!

@zoq
Member

zoq commented Jun 16, 2017

I ran the CartPoleWithDoubleDQN test in a loop to see how often it fails; it turns out it fails relatively often, so I tried to stabilize the test and also tried another network architecture to reduce the test time. Let me know if I messed something up. Here is the commit: 2c51181

@ShangtongZhang
Member

ShangtongZhang commented Jun 17, 2017

Thanks for looking into this issue! It's amazing that such a small network works.
