Add epsilon greedy policy for DQN #1012

Merged (2 commits) on May 30, 2017

Changes from 1 commit
14 changes: 14 additions & 0 deletions src/mlpack/methods/reinforcement_learning/policy/CMakeLists.txt
@@ -0,0 +1,14 @@
# Define the files we need to compile
# Anything not in this list will not be compiled into mlpack.
set(SOURCES
greedy_policy.hpp
)

# Add directory name to sources.
set(DIR_SRCS)
foreach(file ${SOURCES})
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()
# Append sources (with directory name) to list of all mlpack sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
103 changes: 103 additions & 0 deletions src/mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp
@@ -0,0 +1,103 @@
/**
* @file greedy_policy.hpp
* @author Shangtong Zhang
*
* This file is an implementation of the epsilon greedy policy.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_RL_POLICY_GREEDY_POLICY_HPP
#define MLPACK_METHODS_RL_POLICY_GREEDY_POLICY_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace rl {

/**
* Implementation of the epsilon greedy policy.
*
* In general we will select an action greedily based on the action value,
* however under sometimes we will also randomly select an action to
Review comment (Member): I think we can remove "under" here.

* encourage exploration.
*
* @tparam EnvironmentType The reinforcement learning task.
*/
template <typename EnvironmentType>
class GreedyPolicy {
public:
using ActionType = typename EnvironmentType::Action;

/**
* Constructor for epsilon greedy policy class.
* @param initialEpsilon The initial probability to explore (select a random action).
* @param annealInterval The number of steps over which the probability to explore anneals to minEpsilon.
* @param minEpsilon Epsilon will never be less than this value.
*/
GreedyPolicy(double initialEpsilon,
Review comment (Member): Can we use const for the parameters?

size_t annealInterval,
double minEpsilon) :
epsilon(initialEpsilon),
minEpsilon(minEpsilon),
delta((initialEpsilon - minEpsilon) / annealInterval)
{ /* Nothing to do here. */ }

/**
* Sample an action based on given action values.
* @param actionValue Values for each action.
* @return Sampled action
*/
ActionType Sample(const arma::colvec& actionValue)
{
double exploration = math::Random();

// Select the action randomly.
if (exploration < epsilon)
return static_cast<ActionType>(math::RandInt(ActionType::size));

// Select the action greedily.
size_t bestAction = 0;
double maxActionValue = actionValue[0];
for (size_t action = 1; action < ActionType::size; ++action)
Review comment (Member): Hm, do you think we should backport the index_max function? I guess we could also do arma::as_scalar(arma::find(actionValue.max(), 1)). Let me know what you think.

Review comment (Member, Author): I was considering this too. It's more concise but takes twice the time. The performance here may not matter, though. I will do it.

Review comment (Member): I agree, performance-wise it's probably not the best decision; I don't mind keeping the existing code, so we can go with whichever solution you like most.

(A short standalone sketch of the argmax alternatives follows this file's diff.)

{
if (maxActionValue < actionValue[action])
{
maxActionValue = actionValue[action];
bestAction = action;
}
}
return static_cast<ActionType>(bestAction);
};

/**
* Exploration probability will anneal at each step.
*/
void Anneal()
{
epsilon -= delta;
epsilon = std::max(minEpsilon, epsilon);
}

/**
* @return Current probability to explore.
*/
const double& Epsilon() const { return epsilon; }

private:
//! Locally-stored probability to explore.
double epsilon;

//! Locally-stored lower bound for epsilon.
double minEpsilon;

//! Locally-stored stride for epsilon to anneal.
double delta;
};

} // namespace rl
} // namespace mlpack

#endif
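To make the argmax discussion in the review thread above concrete, here is a small standalone sketch (not part of this PR) comparing the explicit loop used in Sample() with the two Armadillo alternatives that were mentioned: index_max(), which requires a reasonably recent Armadillo release, and a find()-based fallback. Note that find() needs an element-wise comparison against the maximum, so the expression below differs slightly from the one quoted in the comment.

// Standalone sketch (not part of this PR): three ways to get the greedy
// action index from a vector of action values.  Assumes Armadillo is
// available; index_max() needs a reasonably recent Armadillo release.
#include <armadillo>
#include <iostream>

int main()
{
  const arma::colvec actionValue("0.1 0.7 0.3 0.5");

  // 1) Explicit loop, as in GreedyPolicy::Sample() above (single pass).
  size_t bestAction = 0;
  double maxActionValue = actionValue[0];
  for (size_t action = 1; action < actionValue.n_elem; ++action)
  {
    if (maxActionValue < actionValue[action])
    {
      maxActionValue = actionValue[action];
      bestAction = action;
    }
  }

  // 2) index_max(), the function the review suggests backporting.
  const arma::uword viaIndexMax = actionValue.index_max();

  // 3) find()-based fallback: index of the first element equal to the
  //    maximum.  This scans the data twice, hence "takes twice the time".
  const arma::uword viaFind =
      arma::as_scalar(arma::find(actionValue == actionValue.max(), 1));

  // All three print 1, since 0.7 is the largest value.
  std::cout << bestAction << " " << viaIndexMax << " " << viaFind << std::endl;
  return 0;
}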
23 changes: 20 additions & 3 deletions src/mlpack/tests/rl_components_test.cpp
@@ -15,6 +15,7 @@
#include <mlpack/methods/reinforcement_learning/environment/mountain_car.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
#include <mlpack/methods/reinforcement_learning/replay/random_replay.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>

#include <boost/test/unit_test.hpp>
#include "test_tools.hpp"
@@ -87,18 +88,34 @@ BOOST_AUTO_TEST_CASE(RandomReplayTest)
BOOST_REQUIRE_EQUAL(1, replay.Size());

//! Overwrite the memory with a nonsense record
for (size_t i = 0; i < 5; ++i)
  replay.Store(nextState, action, reward, state, true);

BOOST_REQUIRE_EQUAL(3, replay.Size());

//! Sample several times, the original record shouldn't appear
for (size_t i = 0; i < 30; ++i)
{
replay.Sample(sampledState, sampledAction, sampledReward, sampledNextState, sampledTerminal);
CheckMatrices(state.Encode(), sampledNextState);
CheckMatrices(nextState.Encode(), sampledState);
BOOST_REQUIRE_EQUAL(true, arma::as_scalar(sampledTerminal));
}
}

/**
* Construct a greedy policy instance and check that it works as
* it should.
*/
BOOST_AUTO_TEST_CASE(GreedyPolicyTest)
{
GreedyPolicy<CartPole> policy(1.0, 10, 0.0);
for (int i = 0; i < 15; ++i)
Review comment (Member): Do you mind using size_t instead of int here? I know it is kind of a pedantic request.

policy.Anneal();
BOOST_REQUIRE_CLOSE(0.0, policy.Epsilon(), 1e-5);
arma::colvec actionValue = arma::randn<arma::colvec>(CartPole::Action::size);
CartPole::Action action = policy.Sample(actionValue);
BOOST_REQUIRE_CLOSE(actionValue[action], actionValue.max(), 1e-5);
}

BOOST_AUTO_TEST_SUITE_END()
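As a closing illustration, here is a minimal standalone sketch of the anneal schedule that GreedyPolicyTest above exercises: delta = (initialEpsilon - minEpsilon) / annealInterval, and each Anneal() call subtracts delta and clamps the result at minEpsilon. EpsilonSchedule is a made-up stand-in used only for this sketch; it is not mlpack's GreedyPolicy and has no mlpack or Armadillo dependency.

// Standalone sketch (not part of this PR): the epsilon anneal arithmetic in
// isolation.  EpsilonSchedule is a hypothetical stand-in for illustration.
#include <algorithm>
#include <cstddef>
#include <iostream>

class EpsilonSchedule
{
 public:
  EpsilonSchedule(const double initialEpsilon,
                  const size_t annealInterval,
                  const double minEpsilon) :
      epsilon(initialEpsilon),
      minEpsilon(minEpsilon),
      delta((initialEpsilon - minEpsilon) / annealInterval)
  { /* Nothing to do here. */ }

  // Subtract delta once and never drop below minEpsilon.
  void Anneal() { epsilon = std::max(minEpsilon, epsilon - delta); }

  double Epsilon() const { return epsilon; }

 private:
  double epsilon;
  double minEpsilon;
  double delta;
};

int main()
{
  // Same parameters as GreedyPolicyTest: delta = (1.0 - 0.0) / 10 = 0.1, so
  // epsilon reaches 0.0 after ten calls and stays there for the extra five.
  EpsilonSchedule schedule(1.0, 10, 0.0);
  for (size_t i = 0; i < 15; ++i)
  {
    schedule.Anneal();
    std::cout << "step " << (i + 1) << ": epsilon = " << schedule.Epsilon()
              << std::endl;
  }
  return 0;
}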