diff --git a/COPYRIGHT.txt b/COPYRIGHT.txt index 7b5b3f808ed..b89fb04345b 100644 --- a/COPYRIGHT.txt +++ b/COPYRIGHT.txt @@ -37,7 +37,7 @@ Copyright: Copyright 2014, Udit Saxena Copyright 2014-2015, Stephen Tu Copyright 2014-2015, Jaskaran Singh - Copyright 2015, Shangtong Zhang + Copyright 2015&2017, Shangtong Zhang Copyright 2015, Hritik Jain Copyright 2015, Vladimir Glazachev Copyright 2015, QiaoAn Chen diff --git a/src/mlpack/methods/ann/ffn.hpp b/src/mlpack/methods/ann/ffn.hpp index f297a56c4ee..f3f69bc4da0 100644 --- a/src/mlpack/methods/ann/ffn.hpp +++ b/src/mlpack/methods/ann/ffn.hpp @@ -51,15 +51,18 @@ class FFN /** * Create the FFN object with the given predictors and responses set (this is - * the set that is used to train the network) and the given optimizer. + * the set that is used to train the network). * Optionally, specify which initialize rule and performance function should * be used. * + * If you want to pass in a parameter and discard the original parameter + * object, be sure to use std::move to avoid unnecessary copy. + * * @param outputLayer Output layer used to evaluate the network. * @param initializeRule Optional instantiated InitializationRule object * for initializing the network parameter. */ - FFN(OutputLayerType&& outputLayer = OutputLayerType(), + FFN(OutputLayerType outputLayer = OutputLayerType(), InitializationRuleType initializeRule = InitializationRuleType()); //! Copy constructor. @@ -73,19 +76,22 @@ class FFN /** * Create the FFN object with the given predictors and responses set (this is - * the set that is used to train the network) and the given optimizer. + * the set that is used to train the network). * Optionally, specify which initialize rule and performance function should * be used. * + * If you want to pass in a parameter and discard the original parameter + * object, be sure to use std::move to avoid unnecessary copy. + * * @param predictors Input training variables. * @param responses Outputs results from input training variables. * @param outputLayer Output layer used to evaluate the network. * @param initializeRule Optional instantiated InitializationRule object * for initializing the network parameter. */ - FFN(const arma::mat& predictors, - const arma::mat& responses, - OutputLayerType&& outputLayer = OutputLayerType(), + FFN(arma::mat predictors, + arma::mat responses, + OutputLayerType outputLayer = OutputLayerType(), InitializationRuleType initializeRule = InitializationRuleType()); //! Destructor to release allocated memory. @@ -99,6 +105,9 @@ class FFN * optimization. If this is not what you want, then you should access the * parameters vector directly with Parameters() and modify it as desired. * + * If you want to pass in a parameter and discard the original parameter + * object, be sure to use std::move to avoid unnecessary copy. + * * @tparam OptimizerType Type of optimizer to use to train the model. * @param predictors Input training variables. * @param responses Outputs results from input training variables. @@ -109,8 +118,8 @@ class FFN mlpack::optimization::RMSProp, typename... OptimizerTypeArgs > - void Train(const arma::mat& predictors, - const arma::mat& responses, + void Train(arma::mat predictors, + arma::mat responses, OptimizerType& optimizer); /** @@ -122,6 +131,9 @@ class FFN * optimization. If this is not what you want, then you should access the * parameters vector directly with Parameters() and modify it as desired. 
* + * If you want to pass in a parameter and discard the original parameter + * object, be sure to use std::move to avoid unnecessary copy. + * * @tparam OptimizerType Type of optimizer to use to train the model. * @param predictors Input training variables. * @param responses Outputs results from input training variables. @@ -129,17 +141,20 @@ class FFN template< template class OptimizerType = mlpack::optimization::RMSProp > - void Train(const arma::mat& predictors, const arma::mat& responses); + void Train(arma::mat predictors, arma::mat responses); /** * Predict the responses to a given set of predictors. The responses will * reflect the output of the given output layer as returned by the * output layer function. * + * If you want to pass in a parameter and discard the original parameter + * object, be sure to use std::move to avoid unnecessary copy. + * * @param predictors Input predictors. * @param results Matrix to put output predictions of responses into. */ - void Predict(const arma::mat& predictors, arma::mat& results); + void Predict(arma::mat predictors, arma::mat& results); /** * Evaluate the feedforward network with the given parameters. This function @@ -226,7 +241,7 @@ class FFN * @param predictors Input data variables. * @param responses Outputs results from input data variables. */ - void ResetData(const arma::mat& predictors, const arma::mat& responses); + void ResetData(arma::mat predictors, arma::mat responses); /** * The Backward algorithm (part of the Forward-Backward algorithm). Computes diff --git a/src/mlpack/methods/ann/ffn_impl.hpp b/src/mlpack/methods/ann/ffn_impl.hpp index 38de848934b..62948f11335 100644 --- a/src/mlpack/methods/ann/ffn_impl.hpp +++ b/src/mlpack/methods/ann/ffn_impl.hpp @@ -29,9 +29,9 @@ namespace ann /** Artificial Neural Network. */ { template FFN::FFN( - OutputLayerType&& outputLayer, InitializationRuleType initializeRule) : + OutputLayerType outputLayer, InitializationRuleType initializeRule) : outputLayer(std::move(outputLayer)), - initializeRule(initializeRule), + initializeRule(std::move(initializeRule)), width(0), height(0), reset(false) @@ -41,22 +41,20 @@ FFN::FFN( template FFN::FFN( - const arma::mat& predictors, - const arma::mat& responses, - OutputLayerType&& outputLayer, + arma::mat predictors, + arma::mat responses, + OutputLayerType outputLayer, InitializationRuleType initializeRule) : outputLayer(std::move(outputLayer)), - initializeRule(initializeRule), + initializeRule(std::move(initializeRule)), width(0), height(0), - reset(false) + reset(false), + predictors(std::move(predictors)), + responses(std::move(responses)), + deterministic(true) { - numFunctions = responses.n_cols; - - this->predictors = std::move(predictors); - this->responses = std::move(responses); - - this->deterministic = true; + numFunctions = this->responses.n_cols; } template @@ -68,7 +66,7 @@ FFN::~FFN() template void FFN::ResetData( - const arma::mat& predictors, const arma::mat& responses) + arma::mat predictors, arma::mat responses) { numFunctions = responses.n_cols; this->predictors = std::move(predictors); @@ -88,15 +86,15 @@ template< typename... OptimizerTypeArgs > void FFN::Train( - const arma::mat& predictors, - const arma::mat& responses, + arma::mat predictors, + arma::mat responses, OptimizerType& optimizer) { - ResetData(predictors, responses); + ResetData(std::move(predictors), std::move(responses)); // Train the model. 
Timer::Start("ffn_optimization"); - const double out = optimizer.Optimize(parameter); + const double out = optimizer.Optimize(*this, parameter); Timer::Stop("ffn_optimization"); Log::Info << "FFN::FFN(): final objective of trained model is " << out @@ -106,7 +104,7 @@ void FFN::Train( template template class OptimizerType> void FFN::Train( - const arma::mat& predictors, const arma::mat& responses) + arma::mat predictors, arma::mat responses) { numFunctions = responses.n_cols; @@ -125,7 +123,7 @@ void FFN::Train( // Train the model. Timer::Start("ffn_optimization"); - const double out = optimizer.Optimize(parameter); + const double out = optimizer.Optimize(*this, parameter); Timer::Stop("ffn_optimization"); Log::Info << "FFN::FFN(): final objective of trained model is " << out @@ -134,7 +132,7 @@ void FFN::Train( template void FFN::Predict( - const arma::mat& predictors, arma::mat& results) + arma::mat predictors, arma::mat& results) { if (parameter.is_empty()) { @@ -148,7 +146,8 @@ void FFN::Predict( } arma::mat resultsTemp; - Forward(std::move(predictors.col(0))); + Forward(std::move(arma::mat(predictors.colptr(0), + predictors.n_rows, 1, false, true))); resultsTemp = boost::apply_visitor(outputParameterVisitor, network.back()).col(0); @@ -157,7 +156,8 @@ void FFN::Predict( for (size_t i = 1; i < predictors.n_cols; i++) { - Forward(std::move(predictors.col(i))); + Forward(std::move(arma::mat(predictors.colptr(i), + predictors.n_rows, 1, false, true))); resultsTemp = boost::apply_visitor(outputParameterVisitor, network.back()); diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp index 185faadceef..3453cfadfd8 100644 --- a/src/mlpack/methods/ann/rnn.hpp +++ b/src/mlpack/methods/ann/rnn.hpp @@ -46,10 +46,13 @@ class RNN /** * Create the RNN object with the given predictors and responses set (this is - * the set that is used to train the network) and the given optimizer. + * the set that is used to train the network). * Optionally, specify which initialize rule and performance function should * be used. * + * If you want to pass in a parameter and discard the original parameter + * object, be sure to use std::move to avoid unnecessary copy. + * * @param rho Maximum number of steps to backpropagate through time (BPTT). * @param single Predict only the last element of the input sequence. * @param outputLayer Output layer used to evaluate the network. @@ -63,10 +66,13 @@ class RNN /** * Create the RNN object with the given predictors and responses set (this is - * the set that is used to train the network) and the given optimizer. + * the set that is used to train the network). * Optionally, specify which initialize rule and performance function should * be used. * + * If you want to pass in a parameter and discard the original parameter + * object, be sure to use std::move to avoid unnecessary copy. + * * @param predictors Input training variables. * @param responses Outputs results from input training variables. * @param rho Maximum number of steps to backpropagate through time (BPTT). @@ -75,8 +81,8 @@ class RNN * @param initializeRule Optional instantiated InitializationRule object * for initializing the network parameter. */ - RNN(const arma::mat& predictors, - const arma::mat& responses, + RNN(arma::mat predictors, + arma::mat responses, const size_t rho, const bool single = false, OutputLayerType outputLayer = OutputLayerType(), @@ -93,6 +99,9 @@ class RNN * optimization. 
If this is not what you want, then you should access the * parameters vector directly with Parameters() and modify it as desired. * + * If you want to pass in a parameter and discard the original parameter + * object, be sure to use std::move to avoid unnecessary copy. + * * @tparam OptimizerType Type of optimizer to use to train the model. * @param predictors Input training variables. * @param responses Outputs results from input training variables. @@ -103,8 +112,8 @@ class RNN mlpack::optimization::StandardSGD, typename... OptimizerTypeArgs > - void Train(const arma::mat& predictors, - const arma::mat& responses, + void Train(arma::mat predictors, + arma::mat responses, OptimizerType& optimizer); /** @@ -116,6 +125,9 @@ class RNN * optimization. If this is not what you want, then you should access the * parameters vector directly with Parameters() and modify it as desired. * + * If you want to pass in a parameter and discard the original parameter + * object, be sure to use std::move to avoid unnecessary copy. + * * @tparam OptimizerType Type of optimizer to use to train the model. * @param predictors Input training variables. * @param responses Outputs results from input training variables. @@ -124,17 +136,20 @@ class RNN template class OptimizerType = mlpack::optimization::StandardSGD > - void Train(const arma::mat& predictors, const arma::mat& responses); + void Train(arma::mat predictors, arma::mat responses); /** * Predict the responses to a given set of predictors. The responses will * reflect the output of the given output layer as returned by the * output layer function. * + * If you want to pass in a parameter and discard the original parameter + * object, be sure to use std::move to avoid unnecessary copy. + * * @param predictors Input predictors. * @param results Matrix to put output predictions of responses into. */ - void Predict(const arma::mat& predictors, arma::mat& results); + void Predict(arma::mat predictors, arma::mat& results); /** * Evaluate the recurrent neural network with the given parameters. This diff --git a/src/mlpack/methods/ann/rnn_impl.hpp b/src/mlpack/methods/ann/rnn_impl.hpp index 4130ef09330..0db301e82f9 100644 --- a/src/mlpack/methods/ann/rnn_impl.hpp +++ b/src/mlpack/methods/ann/rnn_impl.hpp @@ -34,8 +34,8 @@ RNN::RNN( OutputLayerType outputLayer, InitializationRuleType initializeRule) : rho(rho), - outputLayer(outputLayer), - initializeRule(initializeRule), + outputLayer(std::move(outputLayer)), + initializeRule(std::move(initializeRule)), inputSize(0), outputSize(0), targetSize(0), @@ -47,27 +47,25 @@ RNN::RNN( template RNN::RNN( - const arma::mat& predictors, - const arma::mat& responses, + arma::mat predictors, + arma::mat responses, const size_t rho, const bool single, OutputLayerType outputLayer, InitializationRuleType initializeRule) : rho(rho), - outputLayer(outputLayer), - initializeRule(initializeRule), + outputLayer(std::move(outputLayer)), + initializeRule(std::move(initializeRule)), inputSize(0), outputSize(0), targetSize(0), reset(false), - single(single) + single(single), + predictors(std::move(predictors)), + responses(std::move(responses)), + deterministic(true) { - numFunctions = responses.n_cols; - - this->predictors = std::move(predictors); - this->responses = std::move(responses); - - this->deterministic = true; + numFunctions = this->responses.n_cols; ResetDeterministic(); } @@ -86,8 +84,8 @@ template< typename... 
OptimizerTypeArgs > void RNN::Train( - const arma::mat& predictors, - const arma::mat& responses, + arma::mat predictors, + arma::mat responses, OptimizerType& optimizer) { numFunctions = responses.n_cols; @@ -106,7 +104,7 @@ void RNN::Train( // Train the model. Timer::Start("rnn_optimization"); - const double out = optimizer.Optimize(parameter); + const double out = optimizer.Optimize(*this, parameter); Timer::Stop("rnn_optimization"); Log::Info << "RNN::RNN(): final objective of trained model is " << out @@ -116,7 +114,7 @@ void RNN::Train( template template class OptimizerType> void RNN::Train( - const arma::mat& predictors, const arma::mat& responses) + arma::mat predictors, arma::mat responses) { numFunctions = responses.n_cols; @@ -136,7 +134,7 @@ void RNN::Train( // Train the model. Timer::Start("rnn_optimization"); - const double out = optimizer.Optimize(parameter); + const double out = optimizer.Optimize(*this, parameter); Timer::Stop("rnn_optimization"); Log::Info << "RNN::RNN(): final objective of trained model is " << out @@ -145,7 +143,7 @@ void RNN::Train( template void RNN::Predict( - const arma::mat& predictors, arma::mat& results) + arma::mat predictors, arma::mat& results) { if (parameter.is_empty()) { @@ -163,7 +161,10 @@ void RNN::Predict( for (size_t i = 0; i < predictors.n_cols; i++) { - SinglePredict(predictors.col(i), resultsTemp); + SinglePredict( + arma::mat(predictors.colptr(i), predictors.n_rows, 1, false, true), + resultsTemp); + results.col(i) = resultsTemp; } } @@ -199,7 +200,8 @@ double RNN::Evaluate( ResetDeterministic(); } - arma::mat input = predictors.col(i); + arma::mat input = arma::mat(predictors.colptr(i), predictors.n_rows, + 1, false, true); arma::mat target = arma::mat(responses.colptr(i), responses.n_rows, 1, false, true); diff --git a/src/mlpack/methods/reinforcement_learning/CMakeLists.txt b/src/mlpack/methods/reinforcement_learning/CMakeLists.txt new file mode 100644 index 00000000000..ceaaa68242f --- /dev/null +++ b/src/mlpack/methods/reinforcement_learning/CMakeLists.txt @@ -0,0 +1,19 @@ +# Define the files we need to compile +# Anything not in this list will not be compiled into mlpack. +set(SOURCES + q_learning.hpp + q_learning_impl.hpp +) + +# Add directory name to sources. +set(DIR_SRCS) +foreach(file ${SOURCES}) + set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) +endforeach() +# Append sources (with directory name) to list of all mlpack sources (used at +# the parent scope). +set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) + +add_subdirectory(environment) +add_subdirectory(estimator) +add_subdirectory(policy) diff --git a/src/mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp b/src/mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp index 036cd147ed6..e44cf11c620 100644 --- a/src/mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp +++ b/src/mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp @@ -51,14 +51,15 @@ class GreedyPolicy * Sample an action based on given action values. * * @param actionValue Values for each action. + * @param deterministic Always select the action greedily. * @return Sampled action. */ - ActionType Sample(const arma::colvec& actionValue) + ActionType Sample(const arma::colvec& actionValue, bool deterministic = false) { double exploration = math::Random(); // Select the action randomly. - if (exploration < epsilon) + if (!deterministic && exploration < epsilon) return static_cast(math::RandInt(ActionType::size)); // Select the action greedily. 
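The new `deterministic` flag above lets the same policy object drive both exploration during training and purely greedy action selection during evaluation. Below is a minimal usage sketch; the stub environment, the assumption that GreedyPolicy is templated on the environment type, and the (initialEpsilon, annealInterval, minEpsilon) constructor arguments (mirroring the tests later in this patch) are illustrative assumptions, not part of the patch.

#include <mlpack/core.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>

// Hypothetical stand-in environment; GreedyPolicy only needs the nested
// Action type, which is assumed to expose a trailing `size` enumerator.
struct DummyEnvironment
{
  enum Action
  {
    left,
    right,
    // Number of available actions.
    size
  };
};

void GreedyPolicySketch()
{
  using namespace mlpack::rl;

  // Epsilon starts at 1.0 and is annealed towards 0.1 by Anneal().
  GreedyPolicy<DummyEnvironment> policy(1.0, 1000, 0.1);

  arma::colvec actionValue = {0.3, 0.7};

  // Training mode: with probability epsilon a random action is returned.
  DummyEnvironment::Action explored = policy.Sample(actionValue);

  // Evaluation mode: the epsilon branch is skipped, so the greedy action
  // (the index of the largest action value, here `right`) is always returned.
  DummyEnvironment::Action greedy = policy.Sample(actionValue, true);

  (void) explored;
  (void) greedy;
}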
diff --git a/src/mlpack/methods/reinforcement_learning/q_learning.hpp b/src/mlpack/methods/reinforcement_learning/q_learning.hpp new file mode 100644 index 00000000000..4b797808af2 --- /dev/null +++ b/src/mlpack/methods/reinforcement_learning/q_learning.hpp @@ -0,0 +1,173 @@ +/** + * @file q_learning.hpp + * @author Shangtong Zhang + * + * This file is the definition of QLearning class, + * which implements Q-Learning algorithms. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ + +#ifndef MLPACK_METHODS_RL_Q_LEARNING_HPP +#define MLPACK_METHODS_RL_Q_LEARNING_HPP + +#include + +#include "replay/random_replay.hpp" + +namespace mlpack { +namespace rl { + +/** + * Implementation of various Q-Learning algorithms, such as DQN, double DQN. + * + * For more details, see the following: + * @code + * @article{Mnih2013, + * author = {Volodymyr Mnih and + * Koray Kavukcuoglu and + * David Silver and + * Alex Graves and + * Ioannis Antonoglou and + * Daan Wierstra and + * Martin A. Riedmiller}, + * title = {Playing Atari with Deep Reinforcement Learning}, + * journal = {CoRR}, + * year = {2013}, + * url = {http://arxiv.org/abs/1312.5602} + * } + * @endcode + * + * @tparam EnvironmentType The environment of the reinforcement learning task. + * @tparam NetworkType The network to compute action value. + * @tparam OptimizerType The optimizer to train the network. + * @tparam PolicyType Behavior policy of the agent. + * @tparam ReplayType Experience replay method. + */ +template < + typename EnvironmentType, + typename NetworkType, + typename OptimizerType, + typename PolicyType, + typename ReplayType = RandomReplay +> +class QLearning +{ + public: + //! Convenient typedef for state. + using StateType = typename EnvironmentType::State; + + //! Convenient typedef for action. + using ActionType = typename EnvironmentType::Action; + + /** + * Create the QLearning object with given settings. + * + * If you want to pass in a parameter and discard the original parameter + * object, be sure to use std::move to avoid unnecessary copy. + * + * @param network The network to compute action value. + * @param optimizer The optimizer to train the network. + * @param discount Discount for future return. + * @param policy Behavior policy of the agent. + * @param replayMethod Experience replay method. + * @param targetNetworkSyncInterval Interval (steps) to sync the target network. + * @param explorationSteps Steps before starting to learn. + * @param doubleQLearning Whether to use double Q-Learning. + * @param stepLimit Maximum steps in each episode, 0 means no limit. + * @param environment Reinforcement learning task. + */ + QLearning(NetworkType network, + OptimizerType optimizer, + const double discount, + PolicyType policy, + ReplayType replayMethod, + const size_t targetNetworkSyncInterval, + const size_t explorationSteps, + const bool doubleQLearning = false, + const size_t stepLimit = 0, + EnvironmentType environment = EnvironmentType()); + + /** + * Execute a step in an episode. + * @return Reward for the step. + */ + double Step(); + + /** + * Execute an episode. + * @return Return of the episode. + */ + double Episode(); + + /** + * @return Total steps from beginning. + */ + const size_t& TotalSteps() const { return totalSteps; } + + //! 
Modify the training mode / test mode indicator. + bool& Deterministic() { return deterministic; } + + //! Get the indicator of training mode / test mode. + const bool& Deterministic() const { return deterministic; } + + private: + /** + * Select the best action based on given action value. + * @param actionValues Action values. + * @return Selected actions. + */ + arma::Col BestAction(const arma::mat& actionValues); + + //! Locally-stored learning network. + NetworkType learningNetwork; + + //! Locally-stored target network. + NetworkType targetNetwork; + + //! Locally-stored optimizer. + OptimizerType optimizer; + + //! Discount factor of future return. + double discount; + + //! Locally-stored behavior policy. + PolicyType policy; + + //! Locally-stored experience method. + ReplayType replayMethod; + + //! Interval (steps) to update target network. + size_t targetNetworkSyncInterval; + + //! Random steps before starting to learn. + size_t explorationSteps; + + //! Whether to use double Q-Learning. + bool doubleQLearning; + + //! Maximum steps for each episode. + size_t stepLimit; + + //! Locally-stored reinforcement learning task. + EnvironmentType environment; + + //! Total steps from the beginning of the task. + size_t totalSteps; + + //! Locally-stored current state of the agent. + StateType state; + + //! Locally-stored flag indicating training mode or test mode. + bool deterministic; +}; + +} // namespace rl +} // namespace mlpack + +// Include implementation +#include "q_learning_impl.hpp" +#endif diff --git a/src/mlpack/methods/reinforcement_learning/q_learning_impl.hpp b/src/mlpack/methods/reinforcement_learning/q_learning_impl.hpp new file mode 100644 index 00000000000..80a1291e1fe --- /dev/null +++ b/src/mlpack/methods/reinforcement_learning/q_learning_impl.hpp @@ -0,0 +1,218 @@ +/** + * @file q_learning_impl.hpp + * @author Shangtong Zhang + * + * This file is the implementation of QLearning class. + * + * mlpack is free software; you may redistribute it and/or modify it under the + * terms of the 3-clause BSD license. You should have received a copy of the + * 3-clause BSD license along with mlpack. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. 
+ */
+
+#ifndef MLPACK_METHODS_RL_Q_LEARNING_IMPL_HPP
+#define MLPACK_METHODS_RL_Q_LEARNING_IMPL_HPP
+
+#include "q_learning.hpp"
+
+namespace mlpack {
+namespace rl {
+
+template <
+  typename EnvironmentType,
+  typename NetworkType,
+  typename OptimizerType,
+  typename PolicyType,
+  typename ReplayType
+>
+QLearning<
+  EnvironmentType,
+  NetworkType,
+  OptimizerType,
+  PolicyType,
+  ReplayType
+>::QLearning(NetworkType network,
+             OptimizerType optimizer,
+             const double discount,
+             PolicyType policy,
+             ReplayType replayMethod,
+             const size_t targetNetworkSyncInterval,
+             const size_t explorationSteps,
+             const bool doubleQLearning,
+             const size_t stepLimit,
+             EnvironmentType environment):
+    learningNetwork(std::move(network)),
+    optimizer(std::move(optimizer)),
+    discount(discount),
+    policy(std::move(policy)),
+    replayMethod(std::move(replayMethod)),
+    targetNetworkSyncInterval(targetNetworkSyncInterval),
+    explorationSteps(explorationSteps),
+    doubleQLearning(doubleQLearning),
+    stepLimit(stepLimit),
+    environment(std::move(environment)),
+    totalSteps(0),
+    deterministic(false)
+{
+  learningNetwork.ResetParameters();
+  targetNetwork = learningNetwork;
+}
+
+template <
+  typename EnvironmentType,
+  typename NetworkType,
+  typename OptimizerType,
+  typename PolicyType,
+  typename ReplayType
+>
+arma::Col QLearning<
+  EnvironmentType,
+  NetworkType,
+  OptimizerType,
+  PolicyType,
+  ReplayType
+>::BestAction(const arma::mat& actionValues)
+{
+  arma::Col bestActions(actionValues.n_cols);
+  arma::rowvec maxActionValues = arma::max(actionValues, 0);
+  for (size_t i = 0; i < actionValues.n_cols; ++i)
+  {
+    bestActions(i) = arma::as_scalar(
+        arma::find(actionValues.col(i) == maxActionValues[i], 1));
+  }
+  return bestActions;
+}
+
+template <
+  typename EnvironmentType,
+  typename NetworkType,
+  typename OptimizerType,
+  typename BehaviorPolicyType,
+  typename ReplayType
+>
+double QLearning<
+  EnvironmentType,
+  NetworkType,
+  OptimizerType,
+  BehaviorPolicyType,
+  ReplayType
+>::Step()
+{
+  // Get the action value for each action at the current state.
+  arma::colvec actionValue;
+  learningNetwork.Predict(state.Encode(), actionValue);
+
+  // Select an action according to the behavior policy.
+  ActionType action = policy.Sample(actionValue, deterministic);
+
+  // Interact with the environment to advance to the next state.
+  StateType nextState;
+  double reward = environment.Sample(state, action, nextState);
+
+  // Store the transition for replay.
+  replayMethod.Store(state, action, reward,
+      nextState, environment.IsTerminal(nextState));
+
+  // Update the current state.
+  state = nextState;
+
+  if (deterministic || totalSteps < explorationSteps)
+    return reward;
+
+  // Start experience replay.
+
+  // Sample from previous experience.
+  arma::mat sampledStates;
+  arma::icolvec sampledActions;
+  arma::colvec sampledRewards;
+  arma::mat sampledNextStates;
+  arma::icolvec isTerminal;
+  replayMethod.Sample(sampledStates, sampledActions, sampledRewards,
+      sampledNextStates, isTerminal);
+
+  // Compute the action values of the next states with the target network.
+  arma::mat nextActionValues;
+  targetNetwork.Predict(sampledNextStates, nextActionValues);
+
+  arma::Col bestActions;
+  if (doubleQLearning)
+  {
+    // With double Q-Learning, the learning network selects the best action,
+    // while its value is still taken from the target network above.
+    arma::mat nextActionValues;
+    learningNetwork.Predict(sampledNextStates, nextActionValues);
+    bestActions = BestAction(nextActionValues);
+  }
+  else
+  {
+    bestActions = BestAction(nextActionValues);
+  }
+
+  // Compute the update target.
+  arma::mat target;
+  learningNetwork.Predict(sampledStates, target);
+  for (size_t i = 0; i < sampledNextStates.n_cols; ++i)
+  {
+    target(sampledActions[i], i) = sampledRewards[i] +
+        discount * (isTerminal[i] ? 0.0 : nextActionValues(bestActions[i], i));
+  }
+
+  // Learn from experience.
+  learningNetwork.Train(sampledStates, target, optimizer);
+
+  return reward;
+}
+
+template <
+  typename EnvironmentType,
+  typename NetworkType,
+  typename OptimizerType,
+  typename BehaviorPolicyType,
+  typename ReplayType
+>
+double QLearning<
+  EnvironmentType,
+  NetworkType,
+  OptimizerType,
+  BehaviorPolicyType,
+  ReplayType
+>::Episode()
+{
+  // Get the initial state from the environment.
+  state = environment.InitialSample();
+
+  // Track the steps in this episode.
+  size_t steps = 0;
+
+  // Track the return of this episode.
+  double totalReturn = 0.0;
+
+  // Run until reaching a terminal state.
+  while (!environment.IsTerminal(state))
+  {
+    if (stepLimit && steps >= stepLimit)
+      break;
+
+    totalReturn += Step();
+    steps++;
+
+    if (deterministic)
+      continue;
+
+    totalSteps++;
+
+    // Update the target network.
+    if (totalSteps % targetNetworkSyncInterval == 0)
+      targetNetwork = learningNetwork;
+
+    if (totalSteps > explorationSteps)
+      policy.Anneal();
+  }
+  return totalReturn;
+}
+
+} // namespace rl
+} // namespace mlpack
+
+#endif
+
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index c59c6336043..97998c76c91 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -69,6 +69,7 @@ add_executable(mlpack_test
   octree_test.cpp
   pca_test.cpp
   perceptron_test.cpp
+  q_learning_test.cpp
   qdafn_test.cpp
   quic_svd_test.cpp
   radical_test.cpp
diff --git a/src/mlpack/tests/convolutional_network_test.cpp b/src/mlpack/tests/convolutional_network_test.cpp
index ea54a4ab95b..d6f22f3f5f3 100644
--- a/src/mlpack/tests/convolutional_network_test.cpp
+++ b/src/mlpack/tests/convolutional_network_test.cpp
@@ -92,7 +92,7 @@ BOOST_AUTO_TEST_CASE(VanillaNetworkTest)
 
   RMSProp opt(model, 0.001, 0.88, 1e-8, 5000, -1);
 
-  model.Train(std::move(X), std::move(Y), opt);
+  model.Train(X, Y, opt);
 
   arma::mat predictionTemp;
   model.Predict(X, predictionTemp);
diff --git a/src/mlpack/tests/feedforward_network_test.cpp b/src/mlpack/tests/feedforward_network_test.cpp
index affa3cc4fa2..d783830c9cc 100644
--- a/src/mlpack/tests/feedforward_network_test.cpp
+++ b/src/mlpack/tests/feedforward_network_test.cpp
@@ -69,7 +69,7 @@ void BuildVanillaNetwork(MatType& trainData,
 
   RMSProp opt(model, 0.01, 0.88, 1e-8, maxEpochs * trainData.n_cols, -1);
 
-  model.Train(std::move(trainData), std::move(trainLabels), opt);
+  model.Train(trainData, trainLabels, opt);
 
   MatType predictionTemp;
   model.Predict(testData, predictionTemp);
@@ -197,7 +197,7 @@ void BuildDropoutNetwork(MatType& trainData,
 
   RMSProp opt(model, 0.01, 0.88, 1e-8, maxEpochs * trainData.n_cols, -1);
 
-  model.Train(std::move(trainData), std::move(trainLabels), opt);
+  model.Train(trainData, trainLabels, opt);
 
   MatType predictionTemp;
   model.Predict(testData, predictionTemp);
@@ -327,7 +327,7 @@ void BuildDropConnectNetwork(MatType& trainData,
 
   RMSProp opt(model, 0.01, 0.88, 1e-8, maxEpochs * trainData.n_cols, -1);
 
-  model.Train(std::move(trainData), std::move(trainLabels), opt);
+  model.Train(trainData, trainLabels, opt);
 
   MatType predictionTemp;
   model.Predict(testData, predictionTemp);
diff --git a/src/mlpack/tests/ksinit_test.cpp b/src/mlpack/tests/ksinit_test.cpp
index 467bee075a9..cf398084144 100644
--- a/src/mlpack/tests/ksinit_test.cpp
+++ b/src/mlpack/tests/ksinit_test.cpp
@@ -88,7 +88,7 @@ void BuildVanillaNetwork(MatType& trainData,
 
   RMSProp opt(model, 0.01, 0.88, 1e-8, maxEpochs * trainData.n_cols, 1e-18);
 
-  model.Train(std::move(trainData), std::move(trainLabels), opt);
+  model.Train(trainData, trainLabels, opt);
 
   MatType prediction;
 
diff --git a/src/mlpack/tests/q_learning_test.cpp b/src/mlpack/tests/q_learning_test.cpp
new file mode 100644
index 00000000000..9b1429d8211
--- /dev/null
+++ b/src/mlpack/tests/q_learning_test.cpp
@@ -0,0 +1,149 @@
+/**
+ * @file q_learning_test.cpp
+ * @author Shangtong Zhang
+ *
+ * Test for the Q-Learning implementation.
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include "test_tools.hpp"
+
+using namespace mlpack;
+using namespace mlpack::ann;
+using namespace mlpack::optimization;
+using namespace mlpack::rl;
+
+BOOST_AUTO_TEST_SUITE(QLearningTest);
+
+//! Test DQN in the Cart Pole task.
+BOOST_AUTO_TEST_CASE(CartPoleWithDQN)
+{
+  // Set up the network.
+  FFN, GaussianInitialization> model;
+  model.Add>(4, 128);
+  model.Add>();
+  model.Add>(128, 128);
+  model.Add>();
+  model.Add>(128, 2);
+
+  // Set up the optimizer.
+  StandardSGD opt(model, 0.0001, 2);
+
+  // Set up the policy and replay method.
+  GreedyPolicy policy(1.0, 1000, 0.1);
+  RandomReplay replayMethod(10, 10000);
+
+  // Set up the DQN agent.
+  QLearning
+      agent(std::move(model), std::move(opt), 0.9, std::move(policy),
+      std::move(replayMethod), 100, 100, false, 200);
+
+  arma::running_stat averageReturn;
+  size_t episodes = 0;
+  bool converged = true;
+  while (true)
+  {
+    double episodeReturn = agent.Episode();
+    averageReturn(episodeReturn);
+    episodes += 1;
+
+    if (episodes > 1000)
+    {
+      Log::Debug << "Cart Pole with DQN failed." << std::endl;
+      converged = false;
+      break;
+    }
+
+    /**
+     * A running average return of 35 is enough to show that the
+     * implementation works; the criterion is kept low so the test stays fast.
+     */
+    Log::Debug << "Average return: " << averageReturn.mean()
+        << " Episode return: " << episodeReturn << std::endl;
+    if (averageReturn.mean() > 35)
+    {
+      agent.Deterministic() = true;
+      arma::running_stat testReturn;
+      for (size_t i = 0; i < 10; ++i)
+        testReturn(agent.Episode());
+      Log::Debug << "Average return in deterministic test: "
+          << testReturn.mean() << std::endl;
+      break;
+    }
+  }
+  BOOST_REQUIRE(converged);
+}
+
+//! Test double DQN in the Cart Pole task.
+BOOST_AUTO_TEST_CASE(CartPoleWithDoubleDQN)
+{
+  // Set up the network.
+  FFN, GaussianInitialization> model;
+  model.Add>(4, 128);
+  model.Add>();
+  model.Add>(128, 128);
+  model.Add>();
+  model.Add>(128, 2);
+
+  // Set up the optimizer.
+  StandardSGD opt(model, 0.0001, 2);
+
+  // Set up the policy and replay method.
+  GreedyPolicy policy(1.0, 1000, 0.1);
+  RandomReplay replayMethod(10, 10000);
+
+  // Set up the DQN agent.
+  QLearning
+      agent(std::move(model), std::move(opt), 0.9, std::move(policy),
+      std::move(replayMethod), 100, 100, true, 200);
+
+  arma::running_stat averageReturn;
+  size_t episodes = 0;
+  bool converged = true;
+  while (true)
+  {
+    double episodeReturn = agent.Episode();
+    averageReturn(episodeReturn);
+    episodes += 1;
+
+    if (episodes > 1000)
+    {
+      converged = false;
+      Log::Debug << "Cart Pole with double DQN failed." << std::endl;
+      break;
+    }
+
+    /**
+     * A running average return of 35 is enough to show that the
+     * implementation works; the criterion is kept low so the test stays fast.
+     */
+    Log::Debug << "Average return: " << averageReturn.mean()
+        << " Episode return: " << episodeReturn << std::endl;
+    if (averageReturn.mean() > 35)
+    {
+      agent.Deterministic() = true;
+      arma::running_stat testReturn;
+      for (size_t i = 0; i < 10; ++i)
+        testReturn(agent.Episode());
+      Log::Debug << "Average return in deterministic test: "
+          << testReturn.mean() << std::endl;
+      break;
+    }
+  }
+  BOOST_REQUIRE(converged);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
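The tests above now pass lvalue matrices to Train(), which keeps the caller's copies intact; the new by-value signatures still allow handing the data over with std::move, as the updated documentation comments describe. Below is a minimal sketch of both call styles; the concrete layer types, loss, initialization rule, header paths, and RMSProp arguments are assumptions modeled on the tests in this patch, not prescribed usage.

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/init_rules/random_init.hpp>
#include <mlpack/core/optimizers/rmsprop/rmsprop.hpp>

using namespace mlpack::ann;
using namespace mlpack::optimization;

void TrainSketch(arma::mat trainData, arma::mat trainLabels)
{
  // A small feedforward network; the layer and loss choices are illustrative.
  FFN<MeanSquaredError<>, RandomInitialization> model;
  model.Add<Linear<>>(trainData.n_rows, 16);
  model.Add<SigmoidLayer<>>();
  model.Add<Linear<>>(16, trainLabels.n_rows);

  RMSProp opt(model, 0.01, 0.88, 1e-8, 5000, -1);

  // Passing lvalues: Train() receives copies, so trainData and trainLabels
  // remain usable afterwards (e.g. for a second Train() call or Predict()).
  model.Train(trainData, trainLabels, opt);

  // Passing with std::move: the matrices are moved into the network and the
  // local objects are left in a moved-from state, avoiding the extra copy.
  model.Train(std::move(trainData), std::move(trainLabels), opt);
}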