
Implement two classical control problems for testing reinforcement learning method #989

Merged
merged 6 commits on May 9, 2017
Changes from 5 commits
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -3,3 +3,5 @@ xcode*
.DS_Store
src/mlpack/core/util/gitversion.hpp
src/mlpack/core/util/arma_config.hpp
.idea
cmake-build-*
15 changes: 15 additions & 0 deletions src/mlpack/methods/reinforcement_learning/environment/CMakeLists.txt
@@ -0,0 +1,15 @@
# Define the files we need to compile
# Anything not in this list will not be compiled into mlpack.
set(SOURCES
mountain_car.hpp
cart_pole.hpp
)

# Add directory name to sources.
set(DIR_SRCS)
foreach(file ${SOURCES})
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()
# Append sources (with directory name) to list of all mlpack sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
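The fragment above follows mlpack's usual per-directory CMake pattern: list the headers, prefix each with the directory, and push the result up to the parent scope. A later environment would only need one more entry in SOURCES; for example (the file name acrobot.hpp is purely illustrative, not part of this PR):

```cmake
# Hypothetical sketch: registering an additional environment header.
# Only the SOURCES list changes; the DIR_SRCS/PARENT_SCOPE boilerplate
# below it stays identical.
set(SOURCES
  mountain_car.hpp
  cart_pole.hpp
  acrobot.hpp  # Illustrative new environment, not in this PR.
)
```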
201 changes: 201 additions & 0 deletions src/mlpack/methods/reinforcement_learning/environment/cart_pole.hpp
@@ -0,0 +1,201 @@
/**
* @file cart_pole.hpp
* @author Shangtong Zhang
*
* This file is an implementation of Cart Pole task
* https://gym.openai.com/envs/CartPole-v0
*
* TODO: refactor to OpenAI interface
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
#define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace rl {

/**
* Implementation of Cart Pole task
*/
class CartPole
{
public:

/**
* Implementation of state of Cart Pole
* Each state is a tuple of (position, velocity, angle, angular velocity)
*/
class State
{
public:
//! Construct a state instance
State() : data(4) { }

//! Construct a state instance from given data
State(const arma::colvec& data) : data(data) { }

/**
* Set the internal data to given value
* @param data desired internal data
*/
void Set(const arma::colvec& data) { this->data = data; }

//! Get position
double Position() const { return data[0]; }

//! Modify position
double& Position() { return data[0]; }

//! Get velocity
double Velocity() const { return data[1]; }

//! Modify velocity
double& Velocity() { return data[1]; }

//! Get angle
double Angle() const { return data[2]; }

//! Modify angle
double& Angle() { return data[2]; }

//! Get angular velocity
double AngularVelocity() const { return data[3]; }

//! Modify angular velocity
double& AngularVelocity() { return data[3]; }

//! Encode the state to a column vector
const arma::colvec& Encode() const { return data; }

private:
//! Locally-stored (position, velocity, angle, angular velocity)
arma::colvec data;
};

/**
* Implementation of action of Cart Pole
*/
enum Action
{
backward,
forward,

// Track the size of the action space
size
};

/**
* Construct a Cart Pole instance
* @param gravity gravity
* @param massCart mass of the cart
* @param massPole mass of the pole
* @param length length of the pole
* @param forceMag magnitude of the applied force
* @param tau time interval
* @param thetaThresholdRadians maximum angle
* @param xThreshold maximum position
*/
CartPole(double gravity = 9.8, double massCart = 1.0, double massPole = 0.1, double length = 0.5, double forceMag = 10.0,
double tau = 0.02, double thetaThresholdRadians = 12 * 2 * 3.1416 / 360, double xThreshold = 2.4) :
gravity(gravity), massCart(massCart), massPole(massPole), totalMass(massCart + massPole),
length(length), poleMassLength(massPole * length), forceMag(forceMag), tau(tau),
thetaThresholdRadians(thetaThresholdRadians), xThreshold(xThreshold) { }

/**
* Dynamics of Cart Pole
* Get reward and next state based on current state and current action
* @param state Current state
* @param action Current action
* @param nextState Next state
* @return reward, it's always 1.0
*/
double Sample(const State& state, const Action& action, State& nextState) const
{
double force = action ? forceMag : -forceMag;
double cosTheta = std::cos(state.Angle());
double sinTheta = std::sin(state.Angle());
double temp = (force + poleMassLength * state.AngularVelocity() * state.AngularVelocity() * sinTheta) / totalMass;
double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
(length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
nextState.Position() = state.Position() + tau * state.Velocity();
nextState.Velocity() = state.Velocity() + tau * xAcc;
nextState.Angle() = state.Angle() + tau * state.AngularVelocity();
nextState.AngularVelocity() = state.AngularVelocity() + tau * thetaAcc;

return 1.0;
}

/**
* Dynamics of Cart Pole
* Get reward based on current state and current action
* @param state Current state
* @param action Current action
* @return reward, it's always 1.0
*/
double Sample(const State& state, const Action& action) const
{
State nextState;
return Sample(state, action, nextState);
}

/**
* Initial state representation is randomly generated within [-0.05, 0.05]
* @return Initial state for each episode
*/
State InitialSample() const { return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0); }

/**
* Whether given state is terminal state
* @param state desired state
* @return true if @state is terminal state, otherwise false
*/
bool IsTerminal(const State& state) const
{
return std::abs(state.Position()) > xThreshold ||
std::abs(state.Angle()) > thetaThresholdRadians;
}

private:
//! Locally-stored gravity
double gravity;

//! Locally-stored mass of the cart
double massCart;

//! Locally-stored mass of the pole
double massPole;

//! Locally-stored total mass
double totalMass;

//! Locally-stored length of the pole
double length;

//! Locally-stored moment of pole
double poleMassLength;

//! Locally-stored magnitude of the applied force
double forceMag;

//! Locally-stored time interval
double tau;

//! Locally-stored maximum angle
double thetaThresholdRadians;

//! Locally-stored maximum position
double xThreshold;
};

} // namespace rl
} // namespace mlpack

#endif
172 changes: 172 additions & 0 deletions src/mlpack/methods/reinforcement_learning/environment/mountain_car.hpp
@@ -0,0 +1,172 @@
/**
* @file mountain_car.hpp
* @author Shangtong Zhang
*
* This file is an implementation of Mountain Car task
* https://gym.openai.com/envs/MountainCar-v0
*
* TODO: refactor to OpenAI interface
Member:
I'm not sure that we should be fully compliant with the OpenAI interface. What do you think?

Member Author:
Oh, maybe this comment is a bit misleading. I didn't mean to have the same API as OpenAI Gym; I just want to provide an alternative implementation, i.e. the logic of Sample can also be done via OpenAI Gym if the user specifies this.

*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/

#ifndef MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP
#define MLPACK_METHODS_RL_ENVIRONMENT_MOUNTAIN_CAR_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace rl {

/**
* Implementation of Mountain Car task
*/
class MountainCar
{
public:

/**
* Implementation of state of Mountain Car
Member:
Picky format issue: please use proper punctuation for all comments and parameter descriptions.

* Each state is a (velocity, position) pair
*/
class State
{
public:
//! Construct a state instance
State(): data(2, arma::fill::zeros) { }

/**
* Construct a state based on given data
* @param data desired internal data
*/
State(const arma::colvec& data): data(data) { }

//! Encode the state to a column vector
const arma::colvec& Encode() const { return data; }

/**
* Set the internal data to given value
* @param data desired internal data
*/
void Set(const arma::colvec& data) { this->data = data; }
Member:
We don't use setter/getter functions; I guess a simple solution would be to use arma::colvec& Data() { return data; } here, and do: task.Data() = data;.


//! Get velocity
double Velocity() const { return data[0]; }

//! Modify velocity
double& Velocity() { return data[0]; }

//! Get position
double Position() const { return data[1]; }

//! Modify position
double& Position() { return data[1]; }

private:
//! Locally-stored velocity and position
arma::colvec data;
};

/**
* Implementation of action of Mountain Car
*/
enum Action
{
backward,
stop,
forward,

// Track the size of the action space
size
};

/**
* Construct a Mountain Car instance
* @param positionMin minimum legal position
* @param positionMax maximum legal position
* @param velocityMin minimum legal velocity
* @param velocityMax maximum legal velocity
*/
MountainCar(double positionMin = -1.2, double positionMax = 0.5, double velocityMin = -0.07, double velocityMax = 0.07):
positionMin(positionMin), positionMax(positionMax), velocityMin(velocityMin), velocityMax(velocityMax) { }

/**
* Dynamics of Mountain Car
* Get reward and next state based on current state and current action
* @param state Current state
* @param action Current action
* @param nextState Next state
* @return reward, it's always -1.0
*/
double Sample(const State& state, const Action& action, State& nextState) const
{
int direction = action - 1;
nextState.Velocity() = state.Velocity() + 0.001 * direction - 0.0025 * std::cos(3 * state.Position());
nextState.Velocity() = std::min(std::max(nextState.Velocity(), velocityMin), velocityMax);

nextState.Position() = state.Position() + nextState.Velocity();
nextState.Position() = std::min(std::max(nextState.Position(), positionMin), positionMax);

if (std::abs(nextState.Position() - positionMin) <= 1e-5)
{
nextState.Velocity() = 0.0;
}

return -1.0;
}

/**
* Dynamics of Mountain Car
* Get reward based on current state and current action
* @param state Current state
* @param action Current action
* @return reward, it's always -1.0
*/
double Sample(const State& state, const Action& action) const
{
State nextState;
return Sample(state, action, nextState);
}

/**
* Initial position is randomly generated within [-0.6, -0.4]
* Initial velocity is 0
* @return Initial state for each episode
*/
State InitialSample() const
Member:
We could avoid a copy if we passed the state by reference; if the State is complex, this could speed things up, but I'm not sure that is necessary here. What do you think?

Member Author:
I think it will be optimized by Return Value Optimization even with -O0, unless the user specifies -fno-elide-constructors. On the other hand, by using const State& s = InitialSample(), the user can explicitly extend the lifetime of the local variable, so there is also no overhead. If we refactored it to const State& InitialSample() const, it seems to be undefined behavior, and clang does give a warning for it.

Member:
Thanks for your thoughts and clarification; besides some minor style issues, I think this is ready to be merged. I'll go and fix the issues once I merge this in.

Member Author:
That's cool! Thanks for your review!

{
State state;
state.Velocity() = 0.0;
state.Position() = arma::as_scalar(arma::randu(1)) * 0.2 - 0.6;
return state;
}

/**
* Whether given state is terminal state
* @param state desired state
* @return true if @state is terminal state, otherwise false
*/
bool IsTerminal(const State& state) const { return std::abs(state.Position() - positionMax) <= 1e-5; }
Member:
Another picky format issue: "Lines should be no more than 80 characters wide." For more information, take a look at the DesignGuidelines (https://github.com/mlpack/mlpack/wiki/DesignGuidelines).


private:
//! Locally-stored minimum legal position
double positionMin;

//! Locally-stored maximum legal position
double positionMax;

//! Locally-stored minimum legal velocity
double velocityMin;

//! Locally-stored maximum legal velocity
double velocityMax;

};

} // namespace rl
} // namespace mlpack

#endif