New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement two classical control problems for testing reinforcement learning method #989
Changes from 5 commits
099faed
7e9fdea
a41cda1
e981050
c293016
1acd5d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,5 @@ xcode* | |
.DS_Store | ||
src/mlpack/core/util/gitversion.hpp | ||
src/mlpack/core/util/arma_config.hpp | ||
.idea | ||
cmake-build-* |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# List every header that belongs to this directory; anything absent from
# this list is not compiled into mlpack.
set(SOURCES
  mountain_car.hpp
  cart_pole.hpp
)

# Prefix each file with this directory so the parent scope gets full paths.
set(DIR_SRCS)
foreach(file ${SOURCES})
  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()
# Append sources (with directory name) to list of all mlpack sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
/**
 * @file cart_pole.hpp
 * @author Shangtong Zhang
 *
 * This file is an implementation of the Cart Pole task:
 * https://gym.openai.com/envs/CartPole-v0
 *
 * TODO: refactor to OpenAI interface.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */

#ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
#define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace rl {

/**
 * Implementation of the Cart Pole task.
 */
class CartPole
{
 public:
  /**
   * Implementation of the state of Cart Pole.
   * Each state is a tuple of (position, velocity, angle, angular velocity).
   */
  class State
  {
   public:
    //! Construct a state instance, zero-initialized.
    //! (Armadillo sized constructors do not initialize memory, so we must
    //! request zero fill explicitly.)
    State() : data(4, arma::fill::zeros) { }

    /**
     * Construct a state instance from the given data.
     * @param data Desired internal data.
     */
    State(const arma::colvec& data) : data(data) { }

    /**
     * Set the internal data to the given value.
     * @param data Desired internal data.
     */
    void Set(const arma::colvec& data) { this->data = data; }

    //! Get the position.
    double Position() const { return data[0]; }

    //! Modify the position.
    double& Position() { return data[0]; }

    //! Get the velocity.
    double Velocity() const { return data[1]; }

    //! Modify the velocity.
    double& Velocity() { return data[1]; }

    //! Get the angle.
    double Angle() const { return data[2]; }

    //! Modify the angle.
    double& Angle() { return data[2]; }

    //! Get the angular velocity.
    double AngularVelocity() const { return data[3]; }

    //! Modify the angular velocity.
    double& AngularVelocity() { return data[3]; }

    //! Encode the state to a column vector.
    const arma::colvec& Encode() const { return data; }

   private:
    //! Locally-stored (position, velocity, angle, angular velocity).
    arma::colvec data;
  };

  /**
   * Implementation of the action of Cart Pole.
   */
  enum Action
  {
    backward,
    forward,

    // Track the size of the action space.
    size
  };

  /**
   * Construct a Cart Pole instance.
   *
   * @param gravity Gravity.
   * @param massCart Mass of the cart.
   * @param massPole Mass of the pole.
   * @param length Length of the pole.
   * @param forceMag Magnitude of the applied force.
   * @param tau Time interval.
   * @param thetaThresholdRadians Maximum angle (default: 12 degrees,
   *     expressed in radians).
   * @param xThreshold Maximum position.
   */
  CartPole(const double gravity = 9.8,
           const double massCart = 1.0,
           const double massPole = 0.1,
           const double length = 0.5,
           const double forceMag = 10.0,
           const double tau = 0.02,
           const double thetaThresholdRadians = 12 * 2 * 3.1416 / 360,
           const double xThreshold = 2.4) :
      gravity(gravity),
      massCart(massCart),
      massPole(massPole),
      totalMass(massCart + massPole),
      length(length),
      poleMassLength(massPole * length),
      forceMag(forceMag),
      tau(tau),
      thetaThresholdRadians(thetaThresholdRadians),
      xThreshold(xThreshold)
  { }

  /**
   * Dynamics of the Cart Pole system.  Get the reward and next state based on
   * the current state and current action, using Euler integration with step
   * size tau.
   *
   * @param state Current state.
   * @param action Current action.
   * @param nextState Next state (output parameter).
   * @return Reward; it is always 1.0.
   */
  double Sample(const State& state,
                const Action& action,
                State& nextState) const
  {
    // Map {backward, forward} to a signed force along the track.
    const double force = action ? forceMag : -forceMag;
    const double cosTheta = std::cos(state.Angle());
    const double sinTheta = std::sin(state.Angle());

    // Equations of motion for the cart-pole system.
    const double temp = (force + poleMassLength * state.AngularVelocity() *
        state.AngularVelocity() * sinTheta) / totalMass;
    const double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
        (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
    const double xAcc = temp - poleMassLength * thetaAcc * cosTheta /
        totalMass;

    // Euler step.
    nextState.Position() = state.Position() + tau * state.Velocity();
    nextState.Velocity() = state.Velocity() + tau * xAcc;
    nextState.Angle() = state.Angle() + tau * state.AngularVelocity();
    nextState.AngularVelocity() = state.AngularVelocity() + tau * thetaAcc;

    return 1.0;
  }

  /**
   * Dynamics of the Cart Pole system.  Get the reward based on the current
   * state and current action, discarding the next state.
   *
   * @param state Current state.
   * @param action Current action.
   * @return Reward; it is always 1.0.
   */
  double Sample(const State& state, const Action& action) const
  {
    State nextState;
    return Sample(state, action, nextState);
  }

  /**
   * Initial state representation is randomly generated within [-0.05, 0.05].
   *
   * @return Initial state for each episode.
   */
  State InitialSample() const
  {
    return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
  }

  /**
   * Whether the given state is a terminal state.
   *
   * @param state Desired state.
   * @return True if the given state is a terminal state, otherwise false.
   */
  bool IsTerminal(const State& state) const
  {
    return std::abs(state.Position()) > xThreshold ||
        std::abs(state.Angle()) > thetaThresholdRadians;
  }

 private:
  //! Locally-stored gravity.
  double gravity;

  //! Locally-stored mass of the cart.
  double massCart;

  //! Locally-stored mass of the pole.
  double massPole;

  //! Locally-stored total mass.
  double totalMass;

  //! Locally-stored length of the pole.
  double length;

  //! Locally-stored moment of the pole.
  double poleMassLength;

  //! Locally-stored magnitude of the applied force.
  double forceMag;

  //! Locally-stored time interval.
  double tau;

  //! Locally-stored maximum angle.
  double thetaThresholdRadians;

  //! Locally-stored maximum position.
  double xThreshold;
};

} // namespace rl
} // namespace mlpack

#endif
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
/**
 * @file mountain_car.hpp
 * @author Shangtong Zhang
 *
 * This file is an implementation of the Mountain Car task:
 * https://gym.openai.com/envs/MountainCar-v0
 *
 * TODO: refactor to OpenAI interface.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */

#ifndef MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP
#define MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace rl {

/**
 * Implementation of the Mountain Car task.
 */
class MountainCar
{
 public:
  /**
   * Implementation of the state of Mountain Car.
   * Each state is a (velocity, position) pair.
   */
  class State
  {
   public:
    //! Construct a state instance, zero-initialized.
    State() : data(2, arma::fill::zeros) { }

    /**
     * Construct a state based on the given data.
     * @param data Desired internal data.
     */
    State(const arma::colvec& data) : data(data) { }

    //! Encode the state to a column vector.
    const arma::colvec& Encode() const { return data; }

    /**
     * Set the internal data to the given value.
     * @param data Desired internal data.
     */
    void Set(const arma::colvec& data) { this->data = data; }

    //! Get the velocity.
    double Velocity() const { return data[0]; }

    //! Modify the velocity.
    double& Velocity() { return data[0]; }

    //! Get the position.
    double Position() const { return data[1]; }

    //! Modify the position.
    double& Position() { return data[1]; }

   private:
    //! Locally-stored velocity and position.
    arma::colvec data;
  };

  /**
   * Implementation of the action of Mountain Car.
   */
  enum Action
  {
    backward,
    stop,
    forward,

    // Track the size of the action space.
    size
  };

  /**
   * Construct a Mountain Car instance.
   *
   * @param positionMin Minimum legal position.
   * @param positionMax Maximum legal position.
   * @param velocityMin Minimum legal velocity.
   * @param velocityMax Maximum legal velocity.
   */
  MountainCar(const double positionMin = -1.2,
              const double positionMax = 0.5,
              const double velocityMin = -0.07,
              const double velocityMax = 0.07) :
      positionMin(positionMin),
      positionMax(positionMax),
      velocityMin(velocityMin),
      velocityMax(velocityMax)
  { }

  /**
   * Dynamics of the Mountain Car system.  Get the reward and next state based
   * on the current state and current action.
   *
   * @param state Current state.
   * @param action Current action.
   * @param nextState Next state (output parameter).
   * @return Reward; it is always -1.0.
   */
  double Sample(const State& state,
                const Action& action,
                State& nextState) const
  {
    // Map {backward, stop, forward} to {-1, 0, +1}.
    int direction = action - 1;
    nextState.Velocity() = state.Velocity() + 0.001 * direction - 0.0025 *
        std::cos(3 * state.Position());
    nextState.Velocity() = std::min(
        std::max(nextState.Velocity(), velocityMin), velocityMax);

    nextState.Position() = state.Position() + nextState.Velocity();
    nextState.Position() = std::min(
        std::max(nextState.Position(), positionMin), positionMax);

    // The car stops when it hits the left wall.
    if (std::abs(nextState.Position() - positionMin) <= 1e-5)
    {
      nextState.Velocity() = 0.0;
    }

    return -1.0;
  }

  /**
   * Dynamics of the Mountain Car system.  Get the reward based on the current
   * state and current action, discarding the next state.
   *
   * @param state Current state.
   * @param action Current action.
   * @return Reward; it is always -1.0.
   */
  double Sample(const State& state, const Action& action) const
  {
    State nextState;
    return Sample(state, action, nextState);
  }

  /**
   * The initial position is randomly generated within [-0.6, -0.4]; the
   * initial velocity is 0.
   *
   * @return Initial state for each episode.
   */
  State InitialSample() const
  {
    State state;
    state.Velocity() = 0.0;
    state.Position() = arma::as_scalar(arma::randu(1)) * 0.2 - 0.6;
    return state;
  }

  /**
   * Whether the given state is a terminal state (i.e. the car has reached the
   * goal at the maximum position).
   *
   * @param state Desired state.
   * @return True if the given state is a terminal state, otherwise false.
   */
  bool IsTerminal(const State& state) const
  {
    return std::abs(state.Position() - positionMax) <= 1e-5;
  }

 private:
  //! Locally-stored minimum legal position.
  double positionMin;

  //! Locally-stored maximum legal position.
  double positionMax;

  //! Locally-stored minimum legal velocity.
  double velocityMin;

  //! Locally-stored maximum legal velocity.
  double velocityMax;
};

} // namespace rl
} // namespace mlpack

#endif
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure that we should be fully compliant with the OpenAI interface. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, maybe this comment is kind of misleading — I didn't mean to have the same API as OpenAI Gym. I just want to provide an alternative implementation; i.e., the logic of
Sample
can also be done via OpenAI Gym if the user specifies this.