Basic DQN #1014

Merged
merged 14 commits on Jun 16, 2017
COPYRIGHT.txt (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@ Copyright:
Copyright 2014, Udit Saxena <saxenda.udit@gmail.com>
Copyright 2014-2015, Stephen Tu <tu.stephenl@gmail.com>
Copyright 2014-2015, Jaskaran Singh <jaskaranvirdi@ymail.com>
Copyright 2015, Shangtong Zhang <zhangshangtong.cpp@icloud.com>
Copyright 2015&2017, Shangtong Zhang <zhangshangtong.cpp@gmail.com>
Copyright 2015, Hritik Jain <hritik.jain.cse13@itbhu.ac.in>
Copyright 2015, Vladimir Glazachev <glazachev.vladimir@gmail.com>
Copyright 2015, QiaoAn Chen <kazenoyumechen@gmail.com>
src/mlpack/methods/ann/ffn.hpp (10 changes: 10 additions & 0 deletions)
@@ -141,6 +141,16 @@ class FFN
*/
void Predict(arma::mat& predictors, arma::mat& results);

/**
* Predict the responses to a given set of predictors. The responses will
* reflect the output of the given output layer as returned by the
* output layer function.
*
* @param predictors Input predictors.
* @return Matrix of output predictions.
*/
arma::mat Predict(arma::mat predictors);

Member Author
The original Predict function is fairly inconvenient to use, especially when I want to preserve the predictors, so I just added a wrapper.
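As a usage illustration, a minimal sketch (ModelType stands for an FFN instantiation; the Example function and its arguments are made up for this example):

#include <armadillo>

// `model` is any trained FFN instance; `testData` holds the test predictors.
template<typename ModelType>
void Example(ModelType& model, const arma::mat& testData)
{
  // Original interface: needs a separate results matrix, and the predictors
  // argument is a non-const reference, so it may be modified in place.
  arma::mat predictors(testData);
  arma::mat results;
  model.Predict(predictors, results);

  // New wrapper: takes the predictors by value, so testData stays untouched
  // and the predictions are simply returned.
  arma::mat predictions = model.Predict(testData);
}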

Member
Sounds good to me. I think we should use a reference here; there might be compilers that won't perform this optimization.

Member Author
But returning a non-const reference to a local variable is undefined behavior. And I think NRVO is fairly reliable unless there is a conditional return.
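For reference, a small standalone illustration of both points; the PredictBad/PredictGood names are made up for the example:

#include <armadillo>

// Undefined behavior: `results` is a local object, so the returned
// reference dangles as soon as the function returns.
arma::mat& PredictBad(arma::mat predictors)
{
  arma::mat results = 2 * predictors;
  return results;
}

// Fine: return by value. With a single, unconditional return statement,
// NRVO lets most compilers construct `results` directly in the caller.
arma::mat PredictGood(arma::mat predictors)
{
  arma::mat results = 2 * predictors;
  return results;
}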

Member
I was talking about predictors.

Member Author
Oh. The reason I didn't use arma::mat Predict(arma::mat& predictors); is that I want predictors to stay unchanged after calling Predict, so I want a copy here. If the signature of the original function were void Predict(const arma::mat& predictors, arma::mat& results);, I wouldn't have written this wrapper. But I found it's not so easy to change the original function to take a const reference.

Member Author
OK, that makes sense. Then I think what we need to do here is to implement void Predict(mat&&, mat&) directly and implement void Predict(const mat&, mat&) as a wrapper. (Perhaps we need to revert the fix.)
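A minimal sketch of that plan, assuming it is applied inside FFN (the final signatures in the merged code may differ):

// The rvalue overload owns its argument and may modify it freely.
void Predict(arma::mat&& predictors, arma::mat& results);

// The const-reference overload copies and forwards, so the caller's matrix
// stays unchanged.
void Predict(const arma::mat& predictors, arma::mat& results)
{
  arma::mat predictorsCopy(predictors);
  Predict(std::move(predictorsCopy), results);
}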

Member
Sounds like a good plan to me.

Member Author
So will you do this, or should I make the change in this PR?

Member
If you like, you can do it here; don't feel obligated, just let me know.

Member Author
Yeah, I can do it here.

/**
* Evaluate the feedforward network with the given parameters. This function
* is usually called by the optimizer to train the model.
src/mlpack/methods/ann/ffn_impl.hpp (9 changes: 9 additions & 0 deletions)
@@ -167,6 +167,15 @@ void FFN<OutputLayerType, InitializationRuleType>::Predict(
}
}

template<typename OutputLayerType, typename InitializationRuleType>
arma::mat FFN<OutputLayerType, InitializationRuleType>::Predict(
arma::mat predictors)
{
arma::mat results;
Predict(predictors, results);
return results;
}

template<typename OutputLayerType, typename InitializationRuleType>
double FFN<OutputLayerType, InitializationRuleType>::Evaluate(
const arma::mat& /* parameters */, const size_t i, const bool deterministic)
src/mlpack/methods/reinforcement_learning/CMakeLists.txt (19 changes: 19 additions & 0 deletions)
@@ -0,0 +1,19 @@
# Define the files we need to compile
# Anything not in this list will not be compiled into mlpack.
set(SOURCES
q_learning.hpp
q_learning_impl.hpp
)

# Add directory name to sources.
set(DIR_SRCS)
foreach(file ${SOURCES})
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()
# Append sources (with directory name) to list of all mlpack sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)

add_subdirectory(environment)
add_subdirectory(estimator)
add_subdirectory(policy)
@@ -51,14 +51,15 @@ class GreedyPolicy
* Sample an action based on given action values.
*
* @param actionValue Values for each action.
* @param deterministic If true, always select the action greedily.
* @return Sampled action.
*/
ActionType Sample(const arma::colvec& actionValue)
ActionType Sample(const arma::colvec& actionValue, bool deterministic = false)
{
double exploration = math::Random();

// Select the action randomly.
if (exploration < epsilon)
if (!deterministic && exploration < epsilon)
return static_cast<ActionType>(math::RandInt(ActionType::size));

// Select the action greedily.
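For clarity, a short usage sketch of the new parameter; policy and actionValue stand for an already-constructed GreedyPolicy instance and an action-value column vector, and are not defined here:

// Training: epsilon-greedy exploration, as before.
auto trainAction = policy.Sample(actionValue);

// Evaluation: pass deterministic = true to always act greedily,
// regardless of the current epsilon.
auto testAction = policy.Sample(actionValue, true);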
src/mlpack/methods/reinforcement_learning/q_learning.hpp (170 changes: 170 additions & 0 deletions)
@@ -0,0 +1,170 @@
/**
* @file q_learning.hpp
* @author Shangtong Zhang
*
* This file is the definition of the QLearning class,
* which implements Q-Learning algorithms.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP
#define MLPACK_METHODS_RL_Q_LEARNING_HPP

#include <mlpack/prereqs.hpp>

#include "replay/random_replay.hpp"

namespace mlpack {
namespace rl {

/**
* Implementation of various Q-Learning algorithms, such as DQN and double DQN.
*
* For more details, see the following:
Member Author
Not sure whether to refer to the DQN Nature paper. (The BibTeX of the Nature paper is really, really long.)

Member
What do you think about:

@article{Mnih2013,
  author    = {Volodymyr Mnih and
               Koray Kavukcuoglu and
               David Silver and
               Alex Graves and
               Ioannis Antonoglou and
               Daan Wierstra and
               Martin A. Riedmiller},
  title     = {Playing Atari with Deep Reinforcement Learning},
  journal   = {CoRR},
  year      = {2013},
  url       = {http://arxiv.org/abs/1312.5602}
}

It's not that short, but I think this will work just fine.

Member Author
It looks good.

* @code
* @article{Mnih2013,
* author = {Volodymyr Mnih and
* Koray Kavukcuoglu and
* David Silver and
* Alex Graves and
* Ioannis Antonoglou and
* Daan Wierstra and
* Martin A. Riedmiller},
* title = {Playing Atari with Deep Reinforcement Learning},
* journal = {CoRR},
* year = {2013},
* url = {http://arxiv.org/abs/1312.5602}
* }
* @endcode
*
* @tparam EnvironmentType The environment of the reinforcement learning task.
* @tparam NetworkType The network to compute action value.
* @tparam OptimizerType The optimizer to train the network.
* @tparam PolicyType Behavior policy of the agent.
* @tparam ReplayType Experience replay method.
*/
template <
typename EnvironmentType,
typename NetworkType,
typename OptimizerType,
typename PolicyType,
typename ReplayType = RandomReplay<EnvironmentType>
>
class QLearning
{
public:
//! Convenient typedef for state.
using StateType = typename EnvironmentType::State;

//! Convenient typedef for action.
using ActionType = typename EnvironmentType::Action;

/**
* Create the QLearning object with given settings.
*
* @param network The network to compute action value.
* @param optimizer The optimizer to train the network.
* @param discount Discount for future return.
* @param policy Behavior policy of the agent.
* @param replayMethod Experience replay method.
* @param targetNetworkSyncInterval Interval (steps) to sync the target network.
* @param explorationSteps Steps before starting to learn.
* @param doubleQLearning Whether to use double Q-Learning.
* @param stepLimit Maximum steps in each episode; 0 means no limit.
* @param environment Reinforcement learning task.
*/
QLearning(NetworkType& network,
OptimizerType& optimizer,
double discount,
PolicyType policy,
ReplayType replayMethod,
size_t targetNetworkSyncInterval,
size_t explorationSteps,
bool doubleQLearning = false,
size_t stepLimit = 0,
EnvironmentType environment = EnvironmentType());

/**
* Execute a step in an episode.
* @return Reward for the step.
*/
double Step();

/**
* Execute an episode.
* @return Return of the episode.
*/
double Episode();

/**
* @return Total steps from the beginning.
*/
const size_t& TotalSteps() const { return totalSteps; }

//! Modify the training mode / test mode indicator.
bool& Deterministic() { return deterministic; }

//! Get the indicator of training mode / test mode.
const bool& Deterministic() const { return deterministic; }

private:
/**
* Select the best action based on the given action values.
* @param actionValues Action values.
* @return Selected actions.
*/
arma::icolvec BestAction(const arma::mat& actionValues);

//! Reference to the learning network.
NetworkType& learningNetwork;
Member
Reference members can make the class really hard to serialize; I might suggest that you provide a constructor instead that takes a const reference and copies, and then a constructor that takes an rvalue reference. The same applies to the optimizer. This only applies if you intend that this class will ever be serialized.
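A minimal sketch of the constructor pair being suggested, assuming learningNetwork becomes a value member rather than a reference; the remaining parameters are elided:

//! Construct by copying the given network; the caller's object is left intact.
QLearning(const NetworkType& network /* , ... */) :
    learningNetwork(network)
{ /* ... */ }

//! Construct by taking ownership of the given network (no copy).
QLearning(NetworkType&& network /* , ... */) :
    learningNetwork(std::move(network))
{ /* ... */ }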

Member Author
Sounds good. I hadn't thought about serialization before. I'll make the change.

Member Author
I just realized I can't do this easily. My constructor should accept the model and the optimizer; however, the problem is that the optimizer itself locally stores a reference to the model. If I std::move the model object, the reference in the optimizer will be invalidated.
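To make the issue concrete, a small self-contained illustration; OptLike is a made-up stand-in for any optimizer that stores a reference to the function it optimizes, as the mlpack optimizers do:

#include <utility>
#include <vector>

// Stand-in for an optimizer that keeps a reference to its function/model.
template<typename FunctionType>
struct OptLike
{
  explicit OptLike(FunctionType& function) : function(function) { }
  FunctionType& function;
};

int main()
{
  std::vector<double> model{1.0, 2.0, 3.0};
  OptLike<std::vector<double>> opt(model);   // opt.function refers to `model`.

  // Moving the model into another owner (e.g. a QLearning member) leaves
  // opt.function bound to the moved-from `model`, not to `agentModel`.
  std::vector<double> agentModel = std::move(model);

  return static_cast<int>(opt.function.size());  // moved-from size, not 3
}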

Member
Actually, do we have to pass the model? Instead we could get the model from the optimizer itself.

Member Author
Getting the model from the optimizer looks counter-intuitive to me and doesn't solve the problem at its root. What I think is the most elegant way is to design a mechanism that allows rebinding or delayed binding of the optimizer. It's even more necessary for async methods -- each thread will have its own model and may have its own optimizer, so each model must be bound to its corresponding optimizer dynamically. Currently it seems we cannot do this.
I have some solutions. One is delayed binding via a lambda function: instead of passing in an optimizer, we pass in an optimizer generator (that's what I did in my PyTorch implementation), like this:
auto opt = [](auto& model) { return StandardSGD<std::decay_t<decltype(model)>>(model, 0.0001, 2); };
The drawback is that the user has to write this lambda function and the code becomes fairly bloated, so I don't want to do this in C++.
Another delayed-binding solution is to add an overload of the optimizer's Optimize, like this:
double Optimize(arma::mat& iterate, FunctionType& function);
It looks like the easiest way. But the problem is: if we pass in the function every time we call Optimize, why do we still store it in the constructor? And if we do this, do we need to change the original constructor? Although I think this -- totally separating the optimizer and the network -- is the best design, I don't know what the impact of this change on the existing codebase would be.
Rebinding is also a solution; we can refactor the optimizer with std::reference_wrapper like this:

#include <functional> // for std::reference_wrapper

template<typename FunctionType>
class GradientDescent
{
 public:
  GradientDescent(FunctionType& function,
                  const double stepSize = 0.01,
                  const size_t maxIterations = 100000,
                  const double tolerance = 1e-5);

  //! Get the instantiated function to be optimized.
  const FunctionType& Function() const { return function.get(); }
  //! Modify the instantiated function.
  FunctionType& Function() { return function.get(); }

  //! Rebind the optimizer to a different function instance.
  void Rebind(std::reference_wrapper<FunctionType> f) { function = f; }

 private:
  //! The instantiated function.
  std::reference_wrapper<FunctionType> function;
};

This would be totally compatible with the existing codebase.
If I were doing everything from scratch, I'd prefer the second solution. But if refactoring the constructors of the optimizer classes would lead to too many regression issues, I'm fine with implementing the third one.
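A possible use of the Rebind() sketch above; NetworkType here is a placeholder for whatever function type the optimizer is instantiated with:

NetworkType netA, netB;                  // placeholder function/model objects
GradientDescent<NetworkType> opt(netA);  // optimizer initially bound to netA

// Later (for example, after the network has been moved into the agent),
// point the optimizer at the object it should actually optimize.
opt.Rebind(netB);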

Member
That would be greatly appreciated.

Member Author
When I refactored the optimizers, I found that some optimizers like L-BFGS have many member functions, many of which use the locally stored function reference, and even the constructor uses the function:
const size_t rows = function.GetInitialPoint().n_rows;
So for those classes, would storing a reference be a better choice? After seeing L-BFGS, I would vote for the std::reference_wrapper idea; with it we only need to modify the code within the optimizer folder, and no other code would be affected. (And aug_lagrangian also seems very complicated.)

Member
If I didn't miss something, in L_BFGS the function is only used in private member functions; aug_lagrangian looks a little bit more complicated, since we also have to provide an interface for the utility AugLagrangianFunction. I guess if the proposed change would make @rcurtin and @micyril's hyperparameter tuner project easier, it's worth the work. Maybe they can provide more information on that point. If they say it would be helpful, we could split the work?

Member Author
As far as I can see, function is used in the constructor in L_BFGS and SA. I can do this refactor for some simple optimizers like ada_delta, ada_grad, adam, gradient_descent, minibatch_sgd, rmsprop, sgdr, and smorms3. But for the other, more complicated optimizers, I'm afraid I'm not the right person to do that -- I have no idea how they work; this is even the first time I've heard those names :<

Member
Sounds more than fair to me.


//! Locally-stored target network.
NetworkType targetNetwork;

//! Reference to the optimizer.
OptimizerType& optimizer;

//! Discount factor of future return.
double discount;

//! Locally-stored behavior policy.
PolicyType policy;

//! Locally-stored experience method.
ReplayType replayMethod;

//! Interval (steps) to update target network.
size_t targetNetworkSyncInterval;

//! Random steps before starting to learn.
size_t explorationSteps;

//! Whether to use double Q-Learning.
bool doubleQLearning;

//! Maximum steps for each episode.
size_t stepLimit;

//! Locally-stored reinforcement learning task.
EnvironmentType environment;

//! Total steps from the beginning of the task.
size_t totalSteps;

//! Locally-stored current state of the agent.
StateType state;

//! Locally-stored flag indicating training mode or test mode.
bool deterministic;
};

} // namespace rl
} // namespace mlpack

// Include implementation
#include "q_learning_impl.hpp"
#endif
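Finally, a rough usage sketch of the interface declared above, assuming network, optimizer, policy, replayMethod, and environment objects of matching types have already been constructed elsewhere; the hyperparameter values are placeholders, not recommendations from this PR:

// Assemble the agent from already-constructed components; see the
// constructor documentation above for the meaning of each argument.
QLearning<EnvironmentType, NetworkType, OptimizerType, PolicyType>
    agent(network, optimizer,
          /* discount */ 0.99,
          policy, replayMethod,
          /* targetNetworkSyncInterval */ 100,
          /* explorationSteps */ 100,
          /* doubleQLearning */ false,
          /* stepLimit */ 200,
          environment);

agent.Deterministic() = false;            // training mode
for (size_t episode = 0; episode < 1000; ++episode)
  agent.Episode();                        // run one training episode

agent.Deterministic() = true;             // evaluation mode: act greedily
const double testReturn = agent.Episode();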