From 099faed745150d353158da54a1754059e1246052 Mon Sep 17 00:00:00 2001
From: Shangtong Zhang
Date: Mon, 27 Mar 2017 22:23:08 -0600
Subject: [PATCH 1/4] Fix bug of variadic template parameters of Optimizer

---
 src/mlpack/methods/ann/ffn.hpp      | 5 +++--
 src/mlpack/methods/ann/ffn_impl.hpp | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/methods/ann/ffn.hpp b/src/mlpack/methods/ann/ffn.hpp
index b0269b17507..6a612b12c4e 100644
--- a/src/mlpack/methods/ann/ffn.hpp
+++ b/src/mlpack/methods/ann/ffn.hpp
@@ -92,11 +92,12 @@ class FFN
    * @param optimizer Instantiated optimizer used to train the model.
    */
   template<
-      template<typename...> class OptimizerType = mlpack::optimization::RMSprop
+      template<typename...> class OptimizerType = mlpack::optimization::RMSprop,
+      typename ...Args
   >
   void Train(const arma::mat& predictors,
              const arma::mat& responses,
-             OptimizerType<NetworkType>& optimizer);
+             OptimizerType<NetworkType, Args...>& optimizer);
 
   /**
    * Train the feedforward network on the given input data. By default, the
diff --git a/src/mlpack/methods/ann/ffn_impl.hpp b/src/mlpack/methods/ann/ffn_impl.hpp
index a8c6f04be54..0d192976c70 100644
--- a/src/mlpack/methods/ann/ffn_impl.hpp
+++ b/src/mlpack/methods/ann/ffn_impl.hpp
@@ -74,11 +74,11 @@ FFN<OutputLayerType, InitializationRuleType>::~FFN()
 }
 
 template<typename OutputLayerType, typename InitializationRuleType>
-template<template<typename...> class OptimizerType>
+template<template<typename...> class OptimizerType, typename ...Args>
 void FFN<OutputLayerType, InitializationRuleType>::Train(
     const arma::mat& predictors,
     const arma::mat& responses,
-    OptimizerType<NetworkType>& optimizer)
+    OptimizerType<NetworkType, Args...>& optimizer)
 {
   numFunctions = responses.n_cols;
 
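The point of the added `typename ...Args` pack is that Train() can now accept an optimizer whose type carries template arguments beyond the network type. A minimal sketch of the intended call pattern, assuming the mlpack 2.x ANN and optimizer APIs (the layer layout, toy data, and RMSprop arguments below are illustrative only, not taken from the patch):

    #include <mlpack/core.hpp>
    #include <mlpack/methods/ann/ffn.hpp>
    #include <mlpack/methods/ann/layer/layer.hpp>

    using namespace mlpack::ann;
    using namespace mlpack::optimization;

    int main()
    {
      // Toy data: 100 random 10-dimensional points with placeholder labels.
      arma::mat predictors(10, 100, arma::fill::randu);
      arma::mat responses = arma::ones<arma::mat>(1, 100);

      FFN<> model;
      model.Add<Linear<>>(10, 2);
      model.Add<LogSoftMax<>>();

      // The optimizer is templated on the network type; any further template
      // arguments the optimizer takes are now forwarded through Args... .
      RMSprop<decltype(model)> opt(model, 0.01, 0.99, 1e-8, 10000, 1e-5, true);
      model.Train(predictors, responses, opt);
      return 0;
    }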
From e981050fd9a7199a4b4cc3226982fafdead50dbc Mon Sep 17 00:00:00 2001
From: Shangtong Zhang
Date: Mon, 24 Apr 2017 20:04:18 -0600
Subject: [PATCH 2/4] Classical control tasks, Mountain Car and Cart Pole

---
 .gitignore                           |   2 +
 .../environment/CMakeLists.txt       |  15 ++
 .../environment/cart_pole.hpp        | 182 ++++++++++++++++++
 .../environment/mountain_car.hpp     | 155 +++++++++++++++
 4 files changed, 354 insertions(+)
 create mode 100644 src/mlpack/methods/reinforcement_learning/environment/CMakeLists.txt
 create mode 100644 src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
 create mode 100644 src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp

diff --git a/.gitignore b/.gitignore
index 6a6e0b36ecd..826abc20d12 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,5 @@ xcode*
 .DS_Store
 src/mlpack/core/util/gitversion.hpp
 src/mlpack/core/util/arma_config.hpp
+.idea
+cmake-build-*
\ No newline at end of file
diff --git a/src/mlpack/methods/reinforcement_learning/environment/CMakeLists.txt b/src/mlpack/methods/reinforcement_learning/environment/CMakeLists.txt
new file mode 100644
index 00000000000..58baa7e14d3
--- /dev/null
+++ b/src/mlpack/methods/reinforcement_learning/environment/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Define the files we need to compile
+# Anything not in this list will not be compiled into mlpack.
+set(SOURCES
+  mountain_car.hpp
+  cart_pole.hpp
+)
+
+# Add directory name to sources.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+# Append sources (with directory name) to list of all mlpack sources (used at
+# the parent scope).
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
diff --git a/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp b/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
new file mode 100644
index 00000000000..355b9f08b91
--- /dev/null
+++ b/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
@@ -0,0 +1,182 @@
+/**
+ * @file cart_pole.hpp
+ * @author Shangtong Zhang
+ *
+ * This file is an implementation of Cart Pole task
+ * https://gym.openai.com/envs/CartPole-v0
+ *
+ * TODO: refactor to OpenAI interface
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+
+#ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
+#define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace rl {
+
+namespace cart_pole_details {
+// Some constants of Cart Pole task
+constexpr double gravity = 9.8;
+constexpr double massCart = 1.0;
+constexpr double massPole = 0.1;
+constexpr double totalMass = massCart + massPole;
+constexpr double length = 0.5;
+constexpr double poleMassLength = massPole * length;
+constexpr double forceMag = 10.0;
+constexpr double tau = 0.02;
+constexpr double thetaThresholdRadians = 12 * 2 * 3.1416 / 360;
+constexpr double xThreshold = 2.4;
+}
+
+/**
+ * Implementation of Cart Pole task
+ */
+class CartPole
+{
+ public:
+
+  /**
+   * Implementation of state of Cart Pole
+   * Each state is a tuple of (position, velocity, angle, angular velocity)
+   */
+  class State
+  {
+   public:
+    //! Construct a state instance
+    State() : data(4) { }
+
+    //! Construct a state instance from given data
+    State(arma::colvec data) : data(data) { }
+
+    //! Get position
+    double X() const
+    {
+      return data[0];
+    }
+
+    //! Modify position
+    double& X()
+    {
+      return data[0];
+    }
+
+    //! Get velocity
+    double XDot() const
+    {
+      return data[1];
+    }
+
+    //! Modify velocity
+    double& XDot()
+    {
+      return data[1];
+    }
+
+    //! Get angle
+    double Theta() const
+    {
+      return data[2];
+    }
+
+    //! Modify angle
+    double& Theta()
+    {
+      return data[2];
+    }
+
+    //! Get angular velocity
+    double ThetaDot() const
+    {
+      return data[3];
+    }
+
+    //! Modify angular velocity
+    double& ThetaDot()
+    {
+      return data[3];
+    }
+
+    //! Encode the state to a column vector
+    const arma::colvec& Encode() const
+    {
+      return data;
+    }
+
+    //! Whether current state is terminal state
+    bool IsTerminal() const
+    {
+      using namespace cart_pole_details;
+      return std::abs(X()) > xThreshold ||
+          std::abs(Theta()) > thetaThresholdRadians;
+    }
+
+   private:
+    //! Locally-stored (position, velocity, angle, angular velocity)
+    arma::colvec data;
+  };
+
+  /**
+   * Implementation of action of Cart Pole
+   */
+  class Action
+  {
+   public:
+    enum Actions
+    {
+      backward,
+      forward
+    };
+
+    //! # of actions
+    static constexpr size_t count = 2;
+  };
+
+  /**
+   * Dynamics of Cart Pole
+   * Get next state and next action based on current state and current action
+   * @param state Current state
+   * @param action Current action
+   * @param nextState Next state
+   * @param reward Reward is always 1
+   */
+  void Sample(const State& state, const Action::Actions& action,
+      State& nextState, double& reward)
+  {
+    using namespace cart_pole_details;
+    double force = action ? forceMag : -forceMag;
+    double cosTheta = std::cos(state.Theta());
+    double sinTheta = std::sin(state.Theta());
+    double temp = (force + poleMassLength * state.ThetaDot() * state.ThetaDot() * sinTheta) / totalMass;
+    double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
+        (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
+    double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
+    nextState.X() = state.X() + tau * state.XDot();
+    nextState.XDot() = state.XDot() + tau * xAcc;
+    nextState.Theta() = state.Theta() + tau * state.ThetaDot();
+    nextState.ThetaDot() = state.ThetaDot() + tau * thetaAcc;
+
+    reward = 1.0;
+  }
+
+  /**
+   * Initial state representation is randomly generated within [-0.05, 0.05]
+   * @return Initial state for each episode
+   */
+  State InitialSample()
+  {
+    return State((arma::randu(4) - 0.5) / 10.0);
+  }
+
+};
+
+} // namespace rl
+} // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp b/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp
new file mode 100644
index 00000000000..8575ea3bbd6
--- /dev/null
+++ b/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp
@@ -0,0 +1,155 @@
+/**
+ * @file mountain_car.hpp
+ * @author Shangtong Zhang
+ *
+ * This file is an implementation of Mountain Car task
+ * https://gym.openai.com/envs/MountainCar-v0
+ *
+ * TODO: refactor to OpenAI interface
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+
+#ifndef MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP
+#define MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace rl {
+
+namespace mountain_car_details {
+constexpr double positionMin = -1.2;
+constexpr double positionMax = 0.5;
+constexpr double velocityMin = -0.07;
+constexpr double velocityMax = 0.07;
+}
+
+/**
+ * Implementation of Mountain Car task
+ */
+class MountainCar
+{
+ public:
+
+  /**
+   * Implementation of state of Mountain Car
+   * Each state is a (velocity, position) pair
+   */
+  class State
+  {
+   public:
+    //! Construct a state instance
+    State(double velocity = 0, double position = 0) : data(2)
+    {
+      this->Velocity() = velocity;
+      this->Position() = position;
+    }
+
+    //! Encode the state to a column vector
+    const arma::colvec& Encode() const
+    {
+      return data;
+    }
+
+    //! Get velocity
+    double Velocity() const
+    {
+      return data[0];
+    }
+
+    //! Modify velocity
+    double& Velocity()
+    {
+      return data[0];
+    }
+
+    //! Get position
+    double Position() const
+    {
+      return data[1];
+    }
+
+    //! Modify position
+    double& Position()
+    {
+      return data[1];
+    }
+
+    //! Whether current state is terminal state
+    bool IsTerminal() const
+    {
+      using namespace mountain_car_details;
+      return std::abs(Position() - positionMax) <= 1e-5;
+    }
+
+   private:
+    //! Locally-stored velocity and position
+    arma::colvec data;
+  };
+
+  /**
+   * Implementation of action of Mountain Car
+   */
+  class Action
+  {
+   public:
+    enum Actions
+    {
+      backward,
+      stop,
+      forward
+    };
+
+    //! # of actions
+    static constexpr size_t count = 3;
+  };
+
+  /**
+   * Dynamics of Mountain Car
+   * Get next state and next action based on current state and current action
+   * @param state Current state
+   * @param action Current action
+   * @param nextState Next state
+   * @param reward Reward is always -1
+   */
+  void Sample(const State& state, const Action::Actions& action,
+      State& nextState, double& reward)
+  {
+    using namespace mountain_car_details;
+    int direction = action - 1;
+    nextState.Velocity() = state.Velocity() + 0.001 * direction - 0.0025 * std::cos(3 * state.Position());
+    nextState.Velocity() = std::min(std::max(nextState.Velocity(), velocityMin), velocityMax);
+
+    nextState.Position() = state.Position() + nextState.Velocity();
+    nextState.Position() = std::min(std::max(nextState.Position(), positionMin), positionMax);
+
+    reward = -1.0;
+    if (std::abs(nextState.Position() - positionMin) <= 1e-5)
+    {
+      nextState.Velocity() = 0.0;
+    }
+  }
+
+  /**
+   * Initial position is randomly generated within [-0.6, -0.4]
+   * Initial velocity is 0
+   * @return Initial state for each episode
+   */
+  State InitialSample()
+  {
+    State state;
+    state.Velocity() = 0.0;
+    state.Position() = arma::as_scalar(arma::randu(1)) * 0.2 - 0.6;
+    return state;
+  }
+
+};
+
+} // namespace rl
+} // namespace mlpack
+
+#endif
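Both environments added here follow the same usage pattern: draw an initial state with InitialSample(), then repeatedly call Sample() until State::IsTerminal() fires. A minimal sketch of one Mountain Car episode under a uniformly random policy, assuming only this patch plus mlpack::math::RandInt() from mlpack core (the loop itself is illustrative, not part of the patch):

    #include <iostream>
    #include <mlpack/core.hpp>
    #include <mlpack/methods/reinforcement_learning/environment/mountain_car.hpp>

    using namespace mlpack::rl;

    int main()
    {
      MountainCar task;
      MountainCar::State state = task.InitialSample();
      double totalReward = 0.0;
      while (!state.IsTerminal())
      {
        // Pick one of the three actions uniformly at random; a random walk
        // can need many steps to reach the goal, so this is demo-only.
        auto action = static_cast<MountainCar::Action::Actions>(
            mlpack::math::RandInt(0, MountainCar::Action::count));
        MountainCar::State nextState;
        double reward;
        task.Sample(state, action, nextState, reward);
        totalReward += reward;
        state = nextState;
      }
      std::cout << "Episode return: " << totalReward << std::endl;
      return 0;
    }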
From c2930166b4361d349af3feac6f70005b4d99d89a Mon Sep 17 00:00:00 2001
From: Shangtong Zhang
Date: Fri, 28 Apr 2017 22:35:08 -0600
Subject: [PATCH 3/4] Refactor rl tasks and add a simple unit test

---
 .../environment/cart_pole.hpp            | 187 ++++++++++--------
 .../environment/mountain_car.hpp         | 129 ++++++------
 src/mlpack/tests/CMakeLists.txt          |   1 +
 src/mlpack/tests/rl_environment_test.cpp |  46 +++++
 4 files changed, 223 insertions(+), 140 deletions(-)
 create mode 100644 src/mlpack/tests/rl_environment_test.cpp

diff --git a/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp b/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
index 355b9f08b91..23e9f69662f 100644
--- a/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
+++ b/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
@@ -21,20 +21,6 @@
 namespace mlpack {
 namespace rl {
 
-namespace cart_pole_details {
-// Some constants of Cart Pole task
-constexpr double gravity = 9.8;
-constexpr double massCart = 1.0;
-constexpr double massPole = 0.1;
-constexpr double totalMass = massCart + massPole;
-constexpr double length = 0.5;
-constexpr double poleMassLength = massPole * length;
-constexpr double forceMag = 10.0;
-constexpr double tau = 0.02;
-constexpr double thetaThresholdRadians = 12 * 2 * 3.1416 / 360;
-constexpr double xThreshold = 2.4;
-}
-
 /**
  * Implementation of Cart Pole task
  */
@@ -53,69 +39,40 @@ class CartPole
     State() : data(4) { }
 
     //! Construct a state instance from given data
-    State(arma::colvec data) : data(data) { }
+    State(const arma::colvec& data) : data(data) { }
+
+    /**
+     * Set the internal data to given value
+     * @param data desired internal data
+     */
+    void Set(const arma::colvec& data) { this->data = data; }
 
     //! Get position
-    double X() const
-    {
-      return data[0];
-    }
+    double Position() const { return data[0]; }
 
     //! Modify position
-    double& X()
-    {
-      return data[0];
-    }
+    double& Position() { return data[0]; }
 
     //! Get velocity
-    double XDot() const
-    {
-      return data[1];
-    }
+    double Velocity() const { return data[1]; }
 
     //! Modify velocity
-    double& XDot()
-    {
-      return data[1];
-    }
+    double& Velocity() { return data[1]; }
 
     //! Get angle
-    double Theta() const
-    {
-      return data[2];
-    }
+    double Angle() const { return data[2]; }
 
     //! Modify angle
-    double& Theta()
-    {
-      return data[2];
-    }
+    double& Angle() { return data[2]; }
 
     //! Get angular velocity
-    double ThetaDot() const
-    {
-      return data[3];
-    }
+    double AngularVelocity() const { return data[3]; }
 
     //! Modify angular velocity
-    double& ThetaDot()
-    {
-      return data[3];
-    }
+    double& AngularVelocity() { return data[3]; }
 
     //! Encode the state to a column vector
-    const arma::colvec& Encode() const
-    {
-      return data;
-    }
-
-    //! Whether current state is terminal state
-    bool IsTerminal() const
-    {
-      using namespace cart_pole_details;
-      return std::abs(X()) > xThreshold ||
-          std::abs(Theta()) > thetaThresholdRadians;
-    }
+    const arma::colvec& Encode() const { return data; }
 
    private:
     //! Locally-stored (position, velocity, angle, angular velocity)
@@ -125,55 +82,117 @@ class CartPole
   /**
    * Implementation of action of Cart Pole
    */
-  class Action
+  enum Action
   {
-   public:
-    enum Actions
-    {
-      backward,
-      forward
-    };
-
-    //! # of actions
-    static constexpr size_t count = 2;
+    backward,
+    forward,
+
+    // Track the size of the action space
+    size
   };
 
+  /**
+   * Construct a Cart Pole instance
+   * @param gravity gravity
+   * @param massCart mass of the cart
+   * @param massPole mass of the pole
+   * @param length length of the pole
+   * @param forceMag magnitude of the applied force
+   * @param tau time interval
+   * @param thetaThresholdRadians maximum angle
+   * @param xThreshold maximum position
+   */
+  CartPole(double gravity = 9.8, double massCart = 1.0, double massPole = 0.1, double length = 0.5, double forceMag = 10.0,
+      double tau = 0.02, double thetaThresholdRadians = 12 * 2 * 3.1416 / 360, double xThreshold = 2.4) :
+      gravity(gravity), massCart(massCart), massPole(massPole), totalMass(massCart + massPole),
+      length(length), poleMassLength(massPole * length), forceMag(forceMag), tau(tau),
+      thetaThresholdRadians(thetaThresholdRadians), xThreshold(xThreshold) { }
+
   /**
    * Dynamics of Cart Pole
-   * Get next state and next action based on current state and current action
+   * Get reward and next state based on current state and current action
    * @param state Current state
    * @param action Current action
    * @param nextState Next state
-   * @param reward Reward is always 1
+   * @return reward, it's always 1.0
    */
-  void Sample(const State& state, const Action::Actions& action,
-      State& nextState, double& reward)
+  double Sample(const State& state, const Action& action, State& nextState) const
   {
-    using namespace cart_pole_details;
     double force = action ? forceMag : -forceMag;
-    double cosTheta = std::cos(state.Theta());
-    double sinTheta = std::sin(state.Theta());
-    double temp = (force + poleMassLength * state.ThetaDot() * state.ThetaDot() * sinTheta) / totalMass;
+    double cosTheta = std::cos(state.Angle());
+    double sinTheta = std::sin(state.Angle());
+    double temp = (force + poleMassLength * state.AngularVelocity() * state.AngularVelocity() * sinTheta) / totalMass;
     double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
         (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
     double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
-    nextState.X() = state.X() + tau * state.XDot();
-    nextState.XDot() = state.XDot() + tau * xAcc;
-    nextState.Theta() = state.Theta() + tau * state.ThetaDot();
-    nextState.ThetaDot() = state.ThetaDot() + tau * thetaAcc;
+    nextState.Position() = state.Position() + tau * state.Velocity();
+    nextState.Velocity() = state.Velocity() + tau * xAcc;
+    nextState.Angle() = state.Angle() + tau * state.AngularVelocity();
+    nextState.AngularVelocity() = state.AngularVelocity() + tau * thetaAcc;
 
-    reward = 1.0;
+    return 1.0;
+  }
+
+  /**
+   * Dynamics of Cart Pole
+   * Get reward based on current state and current action
+   * @param state Current state
+   * @param action Current action
+   * @return reward, it's always 1.0
+   */
+  double Sample(const State& state, const Action& action) const
+  {
+    State nextState;
+    return Sample(state, action, nextState);
   }
 
   /**
    * Initial state representation is randomly generated within [-0.05, 0.05]
    * @return Initial state for each episode
    */
-  State InitialSample()
+  State InitialSample() const
   {
     return State((arma::randu(4) - 0.5) / 10.0);
   }
+
+  /**
+   * Whether given state is terminal state
+   * @param state desired state
+   * @return true if @state is terminal state, otherwise false
+   */
+  bool IsTerminal(const State& state) const
+  {
+    return std::abs(state.Position()) > xThreshold ||
+        std::abs(state.Angle()) > thetaThresholdRadians;
+  }
 
+ private:
+  //! Locally-stored gravity
+  double gravity;
+
+  //! Locally-stored mass of the cart
+  double massCart;
+
+  //! Locally-stored mass of the pole
+  double massPole;
+
+  //! Locally-stored total mass
+  double totalMass;
+
+  //! Locally-stored length of the pole
+  double length;
+
+  //! Locally-stored moment of pole
+  double poleMassLength;
+
+  //! Locally-stored magnitude of the applied force
+  double forceMag;
+
+  //! Locally-stored time interval
+  double tau;
+
+  //! Locally-stored maximum angle
+  double thetaThresholdRadians;
+
+  //! Locally-stored maximum position
+  double xThreshold;
 };
 
 } // namespace rl
diff --git a/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp b/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp
index 8575ea3bbd6..edd5c12660a 100644
--- a/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp
+++ b/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp
@@ -21,13 +21,6 @@
 namespace mlpack {
 namespace rl {
 
-namespace mountain_car_details {
-constexpr double positionMin = -1.2;
-constexpr double positionMax = 0.5;
-constexpr double velocityMin = -0.07;
-constexpr double velocityMax = 0.07;
-}
-
 /**
  * Implementation of Mountain Car task
  */
@@ -43,48 +36,34 @@ class MountainCar
   {
    public:
     //! Construct a state instance
-    State(double velocity = 0, double position = 0) : data(2)
-    {
-      this->Velocity() = velocity;
-      this->Position() = position;
-    }
+    State(): data(2, arma::fill::zeros) { }
+
+    /**
+     * Construct a state based on given data
+     * @param data desired internal data
+     */
+    State(const arma::colvec& data): data(data) { }
 
     //! Encode the state to a column vector
-    const arma::colvec& Encode() const
-    {
-      return data;
-    }
+    const arma::colvec& Encode() const { return data; }
+
+    /**
+     * Set the internal data to given value
+     * @param data desired internal data
+     */
+    void Set(const arma::colvec& data) { this->data = data; }
 
     //! Get velocity
-    double Velocity() const
-    {
-      return data[0];
-    }
+    double Velocity() const { return data[0]; }
 
     //! Modify velocity
-    double& Velocity()
-    {
-      return data[0];
-    }
+    double& Velocity() { return data[0]; }
 
     //! Get position
-    double Position() const
-    {
-      return data[1];
-    }
+    double Position() const { return data[1]; }
 
     //! Modify position
-    double& Position()
-    {
-      return data[1];
-    }
-
-    //! Whether current state is terminal state
-    bool IsTerminal() const
-    {
-      using namespace mountain_car_details;
-      return std::abs(Position() - positionMax) <= 1e-5;
-    }
+    double& Position() { return data[1]; }
 
    private:
     //! Locally-stored velocity and position
     arma::colvec data;
   };
 
   /**
    * Implementation of action of Mountain Car
    */
-  class Action
+  enum Action
   {
-   public:
-    enum Actions
-    {
-      backward,
-      stop,
-      forward
-    };
+    backward,
+    stop,
+    forward,
 
-    //! # of actions
-    static constexpr size_t count = 3;
+    // Track the size of the action space
+    size
   };
 
+  /**
+   * Construct a Mountain Car instance
+   * @param positionMin minimum legal position
+   * @param positionMax maximum legal position
+   * @param velocityMin minimum legal velocity
+   * @param velocityMax maximum legal velocity
+   */
+  MountainCar(double positionMin = -1.2, double positionMax = 0.5, double velocityMin = -0.07, double velocityMax = 0.07):
+      positionMin(positionMin), positionMax(positionMax), velocityMin(velocityMin), velocityMax(velocityMax) { }
+
   /**
    * Dynamics of Mountain Car
-   * Get next state and next action based on current state and current action
+   * Get reward and next state based on current state and current action
    * @param state Current state
    * @param action Current action
    * @param nextState Next state
-   * @param reward Reward is always -1
+   * @return reward, it's always -1.0
    */
-  void Sample(const State& state, const Action::Actions& action,
-      State& nextState, double& reward)
+  double Sample(const State& state, const Action& action, State& nextState) const
   {
-    using namespace mountain_car_details;
     int direction = action - 1;
     nextState.Velocity() = state.Velocity() + 0.001 * direction - 0.0025 * std::cos(3 * state.Position());
     nextState.Velocity() = std::min(std::max(nextState.Velocity(), velocityMin), velocityMax);
 
     nextState.Position() = state.Position() + nextState.Velocity();
     nextState.Position() = std::min(std::max(nextState.Position(), positionMin), positionMax);
 
-    reward = -1.0;
     if (std::abs(nextState.Position() - positionMin) <= 1e-5)
     {
       nextState.Velocity() = 0.0;
     }
+
+    return -1.0;
+  }
+
+  /**
+   * Dynamics of Mountain Car
+   * Get reward based on current state and current action
+   * @param state Current state
+   * @param action Current action
+   * @return reward, it's always -1.0
+   */
+  double Sample(const State& state, const Action& action) const
+  {
+    State nextState;
+    return Sample(state, action, nextState);
+  }
 
   /**
    * Initial position is randomly generated within [-0.6, -0.4]
    * Initial velocity is 0
    * @return Initial state for each episode
    */
-  State InitialSample()
+  State InitialSample() const
   {
     State state;
     state.Velocity() = 0.0;
     state.Position() = arma::as_scalar(arma::randu(1)) * 0.2 - 0.6;
     return state;
   }
 
+  /**
+   * Whether given state is terminal state
+   * @param state desired state
+   * @return true if @state is terminal state, otherwise false
+   */
+  bool IsTerminal(const State& state) const { return std::abs(state.Position() - positionMax) <= 1e-5; }
+
+ private:
+  //! Locally-stored minimum legal position
+  double positionMin;
+
+  //! Locally-stored maximum legal position
+  double positionMax;
+
+  //! Locally-stored minimum legal velocity
+  double velocityMin;
+
+  //! Locally-stored maximum legal velocity
+  double velocityMax;
+
 };
 
 } // namespace rl
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 475764580ef..f96a25d39e3 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -76,6 +76,7 @@ add_executable(mlpack_test
   recurrent_network_test.cpp
   rectangle_tree_test.cpp
   regularized_svd_test.cpp
+  rl_environment_test.cpp
   rmsprop_test.cpp
   sa_test.cpp
   sdp_primal_dual_test.cpp
diff --git a/src/mlpack/tests/rl_environment_test.cpp b/src/mlpack/tests/rl_environment_test.cpp
new file mode 100644
index 00000000000..c2f64315c2b
--- /dev/null
+++ b/src/mlpack/tests/rl_environment_test.cpp
@@ -0,0 +1,46 @@
+/**
+ * @file rl_environment_test.hpp
+ * @author Shangtong Zhang
+ *
+ * Basic test for the reinforcement learning task environment
+ *
+ * mlpack is free software; you may redistribute it and/or modify it under the
+ * terms of the 3-clause BSD license. You should have received a copy of the
+ * 3-clause BSD license along with mlpack. If not, see
+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
+ */
+
+#include <mlpack/core.hpp>
+
+#include <mlpack/methods/reinforcement_learning/environment/mountain_car.hpp>
+#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "test_tools.hpp"
+
+using namespace mlpack;
+using namespace mlpack::rl;
+
+BOOST_AUTO_TEST_SUITE(RLEnvironmentTest)
+
+BOOST_AUTO_TEST_CASE(MountainCarTest)
+{
+  const auto task = MountainCar();
+  auto state = task.InitialSample();
+  auto action = MountainCar::Action::backward;
+  auto reward = task.Sample(state, action);
+  BOOST_REQUIRE(!task.IsTerminal(state));
+  BOOST_REQUIRE_EQUAL(3, MountainCar::Action::size);
+}
+
+BOOST_AUTO_TEST_CASE(CartPoleTest)
+{
+  const auto task = CartPole();
+  auto state = task.InitialSample();
+  auto action = CartPole::Action::backward;
+  auto reward = task.Sample(state, action);
+  BOOST_REQUIRE(!task.IsTerminal(state));
+  BOOST_REQUIRE_EQUAL(2, CartPole::Action::size);
+}
+
+BOOST_AUTO_TEST_SUITE_END()
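Note how the refactor changes the calling convention: the reward is now Sample()'s return value, terminal checks move from the state onto the environment, and both tasks accept custom physical constants at construction. The earlier episode-loop sketch would be rewritten against this interface roughly as follows (again a sketch only; mlpack::math::RandInt() is assumed for action selection and is not part of the patch):

    // Sketch: one Cart Pole episode against the refactored interface.
    CartPole task;  // Default physical constants.
    CartPole::State state = task.InitialSample();
    double episodeReturn = 0.0;
    while (!task.IsTerminal(state))
    {
      auto action = static_cast<CartPole::Action>(
          mlpack::math::RandInt(0, CartPole::Action::size));
      CartPole::State nextState;
      episodeReturn += task.Sample(state, action, nextState);
      state = nextState;
    }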
From 1acd5d181510f10c63b9e4ba04067466213835d8 Mon Sep 17 00:00:00 2001
From: Shangtong Zhang
Date: Mon, 1 May 2017 15:32:29 -0600
Subject: [PATCH 4/4] Fix code style issues

---
 .../environment/cart_pole.hpp            | 120 +++++++++---------
 .../environment/mountain_car.hpp         |  96 +++++++-------
 src/mlpack/tests/rl_environment_test.cpp |  22 ++--
 3 files changed, 119 insertions(+), 119 deletions(-)

diff --git a/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp b/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
index 23e9f69662f..4d9dceb7116 100644
--- a/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
+++ b/src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
@@ -2,10 +2,10 @@
  * @file cart_pole.hpp
  * @author Shangtong Zhang
  *
- * This file is an implementation of Cart Pole task
+ * This file is an implementation of Cart Pole task:
  * https://gym.openai.com/envs/CartPole-v0
  *
- * TODO: refactor to OpenAI interface
+ * TODO: provide an option to use dynamics directly from OpenAI gym.
  *
  * mlpack is free software; you may redistribute it and/or modify it under the
  * terms of the 3-clause BSD license. You should have received a copy of the
@@ -22,85 +22,82 @@ namespace mlpack {
 namespace rl {
 
 /**
- * Implementation of Cart Pole task
+ * Implementation of Cart Pole task.
  */
 class CartPole
 {
  public:
 
   /**
-   * Implementation of state of Cart Pole
-   * Each state is a tuple of (position, velocity, angle, angular velocity)
+   * Implementation of state of Cart Pole.
+   * Each state is a tuple of (position, velocity, angle, angular velocity).
    */
   class State
   {
    public:
-    //! Construct a state instance
+    //! Construct a state instance.
     State() : data(4) { }
 
-    //! Construct a state instance from given data
+    //! Construct a state instance from given data.
     State(const arma::colvec& data) : data(data) { }
 
-    /**
-     * Set the internal data to given value
-     * @param data desired internal data
-     */
-    void Set(const arma::colvec& data) { this->data = data; }
+    //! Modify the internal representation of the state.
+    arma::colvec& Data() { return data; }
 
-    //! Get position
+    //! Get position.
     double Position() const { return data[0]; }
 
-    //! Modify position
+    //! Modify position.
    double& Position() { return data[0]; }
 
-    //! Get velocity
+    //! Get velocity.
     double Velocity() const { return data[1]; }
 
-    //! Modify velocity
+    //! Modify velocity.
     double& Velocity() { return data[1]; }
 
-    //! Get angle
+    //! Get angle.
     double Angle() const { return data[2]; }
 
-    //! Modify angle
+    //! Modify angle.
     double& Angle() { return data[2]; }
 
-    //! Get angular velocity
+    //! Get angular velocity.
     double AngularVelocity() const { return data[3]; }
 
-    //! Modify angular velocity
+    //! Modify angular velocity.
     double& AngularVelocity() { return data[3]; }
 
-    //! Encode the state to a column vector
+    //! Encode the state to a column vector.
     const arma::colvec& Encode() const { return data; }
 
    private:
-    //! Locally-stored (position, velocity, angle, angular velocity)
+    //! Locally-stored (position, velocity, angle, angular velocity).
     arma::colvec data;
   };
 
   /**
-   * Implementation of action of Cart Pole
+   * Implementation of action of Cart Pole.
    */
   enum Action
   {
     backward,
     forward,
 
-    // Track the size of the action space
+    // Track the size of the action space.
     size
   };
 
   /**
-   * Construct a Cart Pole instance
-   * @param gravity gravity
-   * @param massCart mass of the cart
-   * @param massPole mass of the pole
-   * @param length length of the pole
-   * @param forceMag magnitude of the applied force
-   * @param tau time interval
-   * @param thetaThresholdRadians maximum angle
-   * @param xThreshold maximum position
+   * Construct a Cart Pole instance.
+   * @param gravity gravity.
+   * @param massCart mass of the cart.
+   * @param massPole mass of the pole.
+   * @param length length of the pole.
+   * @param forceMag magnitude of the applied force.
+   * @param tau time interval.
+   * @param thetaThresholdRadians maximum angle.
+   * @param xThreshold maximum position.
    */
   CartPole(double gravity = 9.8, double massCart = 1.0, double massPole = 0.1, double length = 0.5, double forceMag = 10.0,
       double tau = 0.02, double thetaThresholdRadians = 12 * 2 * 3.1416 / 360, double xThreshold = 2.4) :
       gravity(gravity), massCart(massCart), massPole(massPole), totalMass(massCart + massPole),
       length(length), poleMassLength(massPole * length), forceMag(forceMag), tau(tau),
       thetaThresholdRadians(thetaThresholdRadians), xThreshold(xThreshold) { }
 
   /**
-   * Dynamics of Cart Pole
-   * Get reward and next state based on current state and current action
-   * @param state Current state
-   * @param action Current action
-   * @param nextState Next state
-   * @return reward, it's always 1.0
+   * Dynamics of Cart Pole.
+   * Get reward and next state based on current state and current action.
+   * @param state Current state.
+   * @param action Current action.
+   * @param nextState Next state.
+   * @return reward, it's always 1.0.
    */
   double Sample(const State& state, const Action& action, State& nextState) const
   {
     double force = action ? forceMag : -forceMag;
     double cosTheta = std::cos(state.Angle());
     double sinTheta = std::sin(state.Angle());
     double temp = (force + poleMassLength * state.AngularVelocity() * state.AngularVelocity() * sinTheta) / totalMass;
     double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
         (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
     double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
     nextState.Position() = state.Position() + tau * state.Velocity();
     nextState.Velocity() = state.Velocity() + tau * xAcc;
     nextState.Angle() = state.Angle() + tau * state.AngularVelocity();
     nextState.AngularVelocity() = state.AngularVelocity() + tau * thetaAcc;
 
     return 1.0;
   }
 
   /**
-   * Dynamics of Cart Pole
-   * Get reward based on current state and current action
-   * @param state Current state
-   * @param action Current action
-   * @return reward, it's always 1.0
+   * Dynamics of Cart Pole.
+   * Get reward based on current state and current action.
+   * @param state Current state.
+   * @param action Current action.
+   * @return reward, it's always 1.0.
    */
   double Sample(const State& state, const Action& action) const
   {
     State nextState;
     return Sample(state, action, nextState);
   }
 
   /**
-   * Initial state representation is randomly generated within [-0.05, 0.05]
-   * @return Initial state for each episode
+   * Initial state representation is randomly generated within [-0.05, 0.05].
+   * @return Initial state for each episode.
    */
-  State InitialSample() const { return State((arma::randu(4) - 0.5) / 10.0); }
+  State InitialSample() const
+  {
+    return State((arma::randu(4) - 0.5) / 10.0);
+  }
 
   /**
-   * Whether given state is terminal state
-   * @param state desired state
-   * @return true if @state is terminal state, otherwise false
+   * Whether given state is terminal state.
+   * @param state desired state.
+   * @return true if @state is terminal state, otherwise false.
    */
   bool IsTerminal(const State& state) const
   {
     return std::abs(state.Position()) > xThreshold ||
         std::abs(state.Angle()) > thetaThresholdRadians;
   }
 
  private:
-  //! Locally-stored gravity
+  //! Locally-stored gravity.
   double gravity;
 
-  //! Locally-stored mass of the cart
+  //! Locally-stored mass of the cart.
   double massCart;
 
-  //! Locally-stored mass of the pole
+  //! Locally-stored mass of the pole.
   double massPole;
 
-  //! Locally-stored total mass
+  //! Locally-stored total mass.
   double totalMass;
 
-  //! Locally-stored length of the pole
+  //! Locally-stored length of the pole.
   double length;
 
-  //! Locally-stored moment of pole
+  //! Locally-stored moment of pole.
   double poleMassLength;
 
-  //! Locally-stored magnitude of the applied force
+  //! Locally-stored magnitude of the applied force.
   double forceMag;
 
-  //! Locally-stored time interval
+  //! Locally-stored time interval.
   double tau;
 
-  //! Locally-stored maximum angle
+  //! Locally-stored maximum angle.
   double thetaThresholdRadians;
 
-  //! Locally-stored maximum position
+  //! Locally-stored maximum position.
   double xThreshold;
 };
 
 } // namespace rl
diff --git a/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp b/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp
index edd5c12660a..05c4394ee4d 100644
--- a/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp
+++ b/src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp
@@ -2,10 +2,10 @@
  * @file mountain_car.hpp
  * @author Shangtong Zhang
  *
- * This file is an implementation of Mountain Car task
+ * This file is an implementation of Mountain Car task:
  * https://gym.openai.com/envs/MountainCar-v0
  *
- * TODO: refactor to OpenAI interface
+ * TODO: provide an option to use dynamics directly from OpenAI gym.
  *
  * mlpack is free software; you may redistribute it and/or modify it under the
  * terms of the 3-clause BSD license. You should have received a copy of the
@@ -22,128 +22,128 @@ namespace mlpack {
 namespace rl {
 
 /**
- * Implementation of Mountain Car task
+ * Implementation of Mountain Car task.
  */
 class MountainCar
 {
  public:
 
   /**
-   * Implementation of state of Mountain Car
-   * Each state is a (velocity, position) pair
+   * Implementation of state of Mountain Car.
+   * Each state is a (velocity, position) pair.
    */
   class State
   {
    public:
-    //! Construct a state instance
+    //! Construct a state instance.
     State(): data(2, arma::fill::zeros) { }
 
     /**
-     * Construct a state based on given data
-     * @param data desired internal data
+     * Construct a state based on given data.
+     * @param data desired internal data.
      */
     State(const arma::colvec& data): data(data) { }
 
-    //! Encode the state to a column vector
+    //! Encode the state to a column vector.
     const arma::colvec& Encode() const { return data; }
 
-    /**
-     * Set the internal data to given value
-     * @param data desired internal data
-     */
-    void Set(const arma::colvec& data) { this->data = data; }
+    //! Modify the internal representation of the state.
+    arma::colvec& Data() { return data; }
 
-    //! Get velocity
+    //! Get velocity.
     double Velocity() const { return data[0]; }
 
-    //! Modify velocity
+    //! Modify velocity.
     double& Velocity() { return data[0]; }
 
-    //! Get position
+    //! Get position.
     double Position() const { return data[1]; }
 
-    //! Modify position
+    //! Modify position.
     double& Position() { return data[1]; }
 
    private:
-    //! Locally-stored velocity and position
+    //! Locally-stored velocity and position.
     arma::colvec data;
   };
 
   /**
-   * Implementation of action of Mountain Car
+   * Implementation of action of Mountain Car.
    */
   enum Action
   {
     backward,
     stop,
     forward,
 
-    // Track the size of the action space
+    //! Track the size of the action space.
     size
   };
 
   /**
-   * Construct a Mountain Car instance
-   * @param positionMin minimum legal position
-   * @param positionMax maximum legal position
-   * @param velocityMin minimum legal velocity
-   * @param velocityMax maximum legal velocity
+   * Construct a Mountain Car instance.
+   * @param positionMin minimum legal position.
+   * @param positionMax maximum legal position.
+   * @param velocityMin minimum legal velocity.
+   * @param velocityMax maximum legal velocity.
    */
   MountainCar(double positionMin = -1.2, double positionMax = 0.5, double velocityMin = -0.07, double velocityMax = 0.07):
       positionMin(positionMin), positionMax(positionMax), velocityMin(velocityMin), velocityMax(velocityMax) { }
 
   /**
-   * Dynamics of Mountain Car
-   * Get reward and next state based on current state and current action
-   * @param state Current state
-   * @param action Current action
-   * @param nextState Next state
-   * @return reward, it's always -1.0
+   * Dynamics of Mountain Car.
+   * Get reward and next state based on current state and current action.
+   * @param state Current state.
+   * @param action Current action.
+   * @param nextState Next state.
+   * @return reward, it's always -1.0.
    */
   double Sample(const State& state, const Action& action, State& nextState) const
   {
     int direction = action - 1;
     nextState.Velocity() = state.Velocity() + 0.001 * direction - 0.0025 * std::cos(3 * state.Position());
     nextState.Velocity() = std::min(std::max(nextState.Velocity(), velocityMin), velocityMax);
 
     nextState.Position() = state.Position() + nextState.Velocity();
     nextState.Position() = std::min(std::max(nextState.Position(), positionMin), positionMax);
 
     if (std::abs(nextState.Position() - positionMin) <= 1e-5)
     {
       nextState.Velocity() = 0.0;
     }
 
     return -1.0;
   }
 
   /**
-   * Dynamics of Mountain Car
-   * Get reward based on current state and current action
-   * @param state Current state
-   * @param action Current action
-   * @return reward, it's always -1.0
+   * Dynamics of Mountain Car.
+   * Get reward based on current state and current action.
+   * @param state Current state.
+   * @param action Current action.
+   * @return reward, it's always -1.0.
    */
   double Sample(const State& state, const Action& action) const
   {
     State nextState;
     return Sample(state, action, nextState);
   }
 
   /**
-   * Initial position is randomly generated within [-0.6, -0.4]
-   * Initial velocity is 0
-   * @return Initial state for each episode
+   * Initial position is randomly generated within [-0.6, -0.4].
+   * Initial velocity is 0.
+   * @return Initial state for each episode.
    */
   State InitialSample() const
   {
     State state;
     state.Velocity() = 0.0;
     state.Position() = arma::as_scalar(arma::randu(1)) * 0.2 - 0.6;
     return state;
   }
 
   /**
-   * Whether given state is terminal state
-   * @param state desired state
-   * @return true if @state is terminal state, otherwise false
+   * Whether given state is terminal state.
+   * @param state desired state.
+   * @return true if @state is terminal state, otherwise false.
    */
-  bool IsTerminal(const State& state) const { return std::abs(state.Position() - positionMax) <= 1e-5; }
+  bool IsTerminal(const State& state) const
+  {
+    return std::abs(state.Position() - positionMax) <= 1e-5;
+  }
 
  private:
-  //! Locally-stored minimum legal position
+  //! Locally-stored minimum legal position.
   double positionMin;
 
-  //! Locally-stored maximum legal position
+  //! Locally-stored maximum legal position.
   double positionMax;
 
-  //! Locally-stored minimum legal velocity
+  //! Locally-stored minimum legal velocity.
   double velocityMin;
 
-  //! Locally-stored maximum legal velocity
+  //! Locally-stored maximum legal velocity.
   double velocityMax;
 
 };
 
 } // namespace rl
diff --git a/src/mlpack/tests/rl_environment_test.cpp b/src/mlpack/tests/rl_environment_test.cpp
index c2f64315c2b..09f106a46d8 100644
--- a/src/mlpack/tests/rl_environment_test.cpp
+++ b/src/mlpack/tests/rl_environment_test.cpp
@@ -2,7 +2,7 @@
 /**
  * @file rl_environment_test.hpp
  * @author Shangtong Zhang
  *
- * Basic test for the reinforcement learning task environment
+ * Basic test for the reinforcement learning task environment.
  *
  * mlpack is free software; you may redistribute it and/or modify it under the
  * terms of the 3-clause BSD license. You should have received a copy of the
  * 3-clause BSD license along with mlpack. If not, see
  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
  */
 
@@ -23,22 +23,22 @@ using namespace mlpack::rl;
 
 BOOST_AUTO_TEST_SUITE(RLEnvironmentTest)
 
-BOOST_AUTO_TEST_CASE(MountainCarTest)
+BOOST_AUTO_TEST_CASE(SimpleMountainCarTest)
 {
-  const auto task = MountainCar();
-  auto state = task.InitialSample();
-  auto action = MountainCar::Action::backward;
-  auto reward = task.Sample(state, action);
+  const MountainCar task = MountainCar();
+  MountainCar::State state = task.InitialSample();
+  MountainCar::Action action = MountainCar::Action::backward;
+  double reward = task.Sample(state, action);
   BOOST_REQUIRE(!task.IsTerminal(state));
   BOOST_REQUIRE_EQUAL(3, MountainCar::Action::size);
 }
 
-BOOST_AUTO_TEST_CASE(CartPoleTest)
+BOOST_AUTO_TEST_CASE(SimpleCartPoleTest)
 {
-  const auto task = CartPole();
-  auto state = task.InitialSample();
-  auto action = CartPole::Action::backward;
-  auto reward = task.Sample(state, action);
+  const CartPole task = CartPole();
+  CartPole::State state = task.InitialSample();
+  CartPole::Action action = CartPole::Action::backward;
+  double reward = task.Sample(state, action);
   BOOST_REQUIRE(!task.IsTerminal(state));
   BOOST_REQUIRE_EQUAL(2, CartPole::Action::size);
 }
 
 BOOST_AUTO_TEST_SUITE_END()
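After this final patch both tasks expose an identical compile-time surface: a State with Encode(), an Action enum whose trailing size member reports the number of actions, and Sample()/InitialSample()/IsTerminal() on the environment itself. That uniformity is what lets an agent take the environment as a template parameter. A hedged sketch of a generic rollout helper built on nothing beyond this interface (the helper is illustrative and not part of the patch set; mlpack::math::RandInt() is assumed from mlpack core):

    #include <mlpack/core.hpp>
    #include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
    #include <mlpack/methods/reinforcement_learning/environment/mountain_car.hpp>

    // Works for any type modeling the environment interface above.
    template<typename EnvironmentType>
    double RunEpisode(const EnvironmentType& task)
    {
      typename EnvironmentType::State state = task.InitialSample();
      double episodeReturn = 0.0;
      while (!task.IsTerminal(state))
      {
        // Uniformly random action; size is the last enum member.
        auto action = static_cast<typename EnvironmentType::Action>(
            mlpack::math::RandInt(0, EnvironmentType::Action::size));
        typename EnvironmentType::State nextState;
        episodeReturn += task.Sample(state, action, nextState);
        state = nextState;
      }
      return episodeReturn;
    }

    // Usage: double r1 = RunEpisode(mlpack::rl::CartPole());
    //        double r2 = RunEpisode(mlpack::rl::MountainCar());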