Basic DQN #1014
CMakeLists.txt
@@ -0,0 +1,19 @@
# Define the files we need to compile
# Anything not in this list will not be compiled into mlpack.
set(SOURCES
  q_learning.hpp
  q_learning_impl.hpp
)

# Add directory name to sources.
set(DIR_SRCS)
foreach(file ${SOURCES})
  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()
# Append sources (with directory name) to list of all mlpack sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)

add_subdirectory(environment)
add_subdirectory(estimator)
add_subdirectory(policy)
q_learning.hpp
@@ -0,0 +1,170 @@
/**
 * @file q_learning.hpp
 * @author Shangtong Zhang
 *
 * This file is the definition of the QLearning class, which implements
 * Q-Learning algorithms.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */

#ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP
#define MLPACK_METHODS_RL_Q_LEARNING_HPP

#include <mlpack/prereqs.hpp>

#include "replay/random_replay.hpp"

namespace mlpack {
namespace rl {

/**
 * Implementation of various Q-Learning algorithms, such as DQN and double DQN.
 *
 * For more details, see the following:
[Review comment] Not sure whether to refer to the DQN Nature paper. (The bibtex of the Nature paper is really, really long.)

[Review comment] What do you think about:

@article{Mnih2013,
  author = {Volodymyr Mnih and
            Koray Kavukcuoglu and
            David Silver and
            Alex Graves and
            Ioannis Antonoglou and
            Daan Wierstra and
            Martin A. Riedmiller},
  title = {Playing Atari with Deep Reinforcement Learning},
  journal = {CoRR},
  year = {2013},
  url = {http://arxiv.org/abs/1312.5602}
}

It's not that short, but I think this will work just fine.

[Review comment] It looks good.
 * @code
 * @article{Mnih2013,
 *   author = {Volodymyr Mnih and
 *             Koray Kavukcuoglu and
 *             David Silver and
 *             Alex Graves and
 *             Ioannis Antonoglou and
 *             Daan Wierstra and
 *             Martin A. Riedmiller},
 *   title = {Playing Atari with Deep Reinforcement Learning},
 *   journal = {CoRR},
 *   year = {2013},
 *   url = {http://arxiv.org/abs/1312.5602}
 * }
 * @endcode
 *
 * @tparam EnvironmentType The environment of the reinforcement learning task.
 * @tparam NetworkType The network to compute action values.
 * @tparam OptimizerType The optimizer to train the network.
 * @tparam PolicyType Behavior policy of the agent.
 * @tparam ReplayType Experience replay method.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename OptimizerType,
  typename PolicyType,
  typename ReplayType = RandomReplay<EnvironmentType>
>
class QLearning
{
 public:
  //! Convenient typedef for state.
  using StateType = typename EnvironmentType::State;

  //! Convenient typedef for action.
  using ActionType = typename EnvironmentType::Action;

  /**
   * Create the QLearning object with the given settings.
   *
   * @param network The network to compute action values.
   * @param optimizer The optimizer to train the network.
   * @param discount Discount for future return.
   * @param policy Behavior policy of the agent.
   * @param replayMethod Experience replay method.
   * @param targetNetworkSyncInterval Interval (in steps) to sync the target
   *     network.
   * @param explorationSteps Steps before starting to learn.
   * @param doubleQLearning Whether to use double Q-Learning.
   * @param stepLimit Maximum steps in each episode; 0 means no limit.
   * @param environment Reinforcement learning task.
   */
  QLearning(NetworkType& network,
            OptimizerType& optimizer,
            double discount,
            PolicyType policy,
            ReplayType replayMethod,
            size_t targetNetworkSyncInterval,
            size_t explorationSteps,
            bool doubleQLearning = false,
            size_t stepLimit = 0,
            EnvironmentType environment = EnvironmentType());

  /**
   * Execute a step in an episode.
   * @return Reward for the step.
   */
  double Step();

  /**
   * Execute an episode.
   * @return Return of the episode.
   */
  double Episode();

  /**
   * @return Total steps from the beginning.
   */
  const size_t& TotalSteps() const { return totalSteps; }

  //! Modify the training mode / test mode indicator.
  bool& Deterministic() { return deterministic; }

  //! Get the indicator of training mode / test mode.
  const bool& Deterministic() const { return deterministic; }

 private:
  /**
   * Select the best action based on the given action values.
   * @param actionValues Action values.
   * @return Selected actions.
   */
  arma::icolvec BestAction(const arma::mat& actionValues);

  //! Reference to the learning network.
  NetworkType& learningNetwork;
[Review comment] Reference members can make the class really hard to serialize; I might suggest that you provide a constructor instead that takes a const reference and copies, and then a constructor that takes an rvalue reference. The same applies to the optimizer. This only applies if you intend that this class will ever be serialized.

[Review comment] It sounds good. I didn't think about serialization before. I'll make it.

[Review comment] I just realized I can't do this easily. My constructor should accept

[Review comment] Actually, do we have to pass the model? Instead we could get the model from the optimizer itself.

[Review comment] Getting the model from the optimizer looks very counterintuitive to me and doesn't solve the problem at the root. What I think is the most elegant way is to design a mechanism to allow rebinding or delayed binding for the optimizer. It's even more necessary for async methods -- each thread will have its own model and may have its own optimizer, so the model must be bound to its corresponding optimizer dynamically. Currently it seems we cannot do this. This would be totally compatible with the existing codebase.

[Review comment] That would be greatly appreciated.

[Review comment] When I refactored the optimizer, I found some optimizers like

[Review comment] If I didn't miss something, as for

[Review comment] As far as I can see,

[Review comment] Sounds more than fair to me.
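To make the reviewer's suggestion concrete, here is a minimal sketch of the copy/move constructor pattern, using a hypothetical `Agent` class rather than the PR's actual QLearning signature; all names are illustrative, not the final mlpack API:

#include <utility>

// Sketch: take the network by const reference (copy) or by rvalue reference
// (move), so the class owns its members by value and stays easy to serialize.
template<typename NetworkType>
class Agent
{
 public:
  //! Copy the given network into this agent.
  explicit Agent(const NetworkType& network) : learningNetwork(network) { }

  //! Take ownership of the given network (no copy).
  explicit Agent(NetworkType&& network)
      : learningNetwork(std::move(network)) { }

 private:
  //! Owned by value instead of held by reference, so serializing the class
  //! does not require tracking an external object.
  NetworkType learningNetwork;
};

Callers with a network they still need would pass an lvalue and get a copy; callers done with the network would pass `std::move(network)` and avoid the copy.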
  //! Locally-stored target network.
  NetworkType targetNetwork;

  //! Reference to the optimizer.
  OptimizerType& optimizer;

  //! Discount factor of future return.
  double discount;

  //! Locally-stored behavior policy.
  PolicyType policy;

  //! Locally-stored experience replay method.
  ReplayType replayMethod;

  //! Interval (in steps) to update the target network.
  size_t targetNetworkSyncInterval;

  //! Random steps before starting to learn.
  size_t explorationSteps;

  //! Whether to use double Q-Learning.
  bool doubleQLearning;

  //! Maximum steps for each episode.
  size_t stepLimit;

  //! Locally-stored reinforcement learning task.
  EnvironmentType environment;

  //! Total steps from the beginning of the task.
  size_t totalSteps;

  //! Locally-stored current state of the agent.
  StateType state;

  //! Locally-stored flag indicating training mode or test mode.
  bool deterministic;
};

} // namespace rl
} // namespace mlpack

// Include implementation.
#include "q_learning_impl.hpp"

#endif
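To make the `doubleQLearning` flag in this header concrete, here is a rough sketch (not the PR's actual implementation) of how the TD target differs between vanilla DQN and double DQN for a single transition. The names `nextLearningValues` and `nextTargetValues` are hypothetical: action-value vectors for the next state from the learning and target networks, respectively.

#include <armadillo>

// Sketch: compute the bootstrap target for one transition.
double TDTarget(const arma::vec& nextLearningValues,
                const arma::vec& nextTargetValues,
                const double reward,
                const double discount,
                const bool doubleQLearning)
{
  if (doubleQLearning)
  {
    // Double DQN: select the greedy action with the learning network, but
    // evaluate it with the target network, which reduces overestimation bias.
    const arma::uword bestAction = nextLearningValues.index_max();
    return reward + discount * nextTargetValues(bestAction);
  }

  // Vanilla DQN: both action selection and evaluation use the target network.
  return reward + discount * nextTargetValues.max();
}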
[Review comment] The original `Predict` function is fairly inconvenient to use, especially when I want to preserve the `predictors`. So I just added a wrapper.

[Review comment] Sounds good to me. I think we should use a reference here; there might be compilers which are not going to apply this optimization.

[Review comment] But returning a non-const reference to a local variable is undefined behavior. And I think NRVO is fairly reliable unless there is a conditional return.

[Review comment] I was talking about `predictors`.

[Review comment] Oh. The reason why I didn't use `arma::mat Predict(arma::mat& predictors);` is that I want `predictors` to stay unchanged after calling `Predict`. So I want a copy here. If the signature of the original function were `void Predict(const arma::mat& predictors, arma::mat& results);`, I wouldn't have written this wrapper. But I found it's not so easy to change the original function to a const reference.

[Review comment] OK, that makes sense. Then I think what we need to do here is to implement `void Predict(mat&&, mat&)` directly and implement `void Predict(const mat&, mat&)` as a wrapper. (Perhaps we need to revert the fix.)

[Review comment] Sounds like a good plan to me.

[Review comment] So will you do this, or should I make it in this PR?

[Review comment] If you like you can do it here; don't feel obligated, just let me know.

[Review comment] Yeah, I can do it here.
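A sketch of the plan agreed above, with a hypothetical `Network` class standing in for the real FFN code: the rvalue overload is the primary implementation and is free to reuse its argument as scratch space, while the const-reference overload copies and forwards, so the caller's `predictors` is never modified.

#include <armadillo>
using arma::mat;

class Network
{
 public:
  //! Core implementation: may modify `predictors` in place during the
  //! forward pass, since the caller has given up ownership.
  void Predict(mat&& predictors, mat& results)
  {
    // ... forward pass writing into results ...
    results = predictors; // placeholder for the real forward pass
  }

  //! Wrapper: copies, so the caller's matrix stays untouched.
  void Predict(const mat& predictors, mat& results)
  {
    Predict(mat(predictors), results);
  }
};

With these two overloads, `net.Predict(data, results)` preserves `data`, and `net.Predict(std::move(data), results)` skips the copy when the caller no longer needs it.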