Add epsilon greedy policy for DQN #1012
CMakeLists.txt (new file):
@@ -0,0 +1,14 @@
# Define the files we need to compile
# Anything not in this list will not be compiled into mlpack.
set(SOURCES
  greedy_policy.hpp
)

# Add directory name to sources.
set(DIR_SRCS)
foreach(file ${SOURCES})
  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()
# Append sources (with directory name) to list of all mlpack sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
greedy_policy.hpp (new file):
@@ -0,0 +1,103 @@
/**
 * @file greedy_policy.hpp
 * @author Shangtong Zhang
 *
 * This file is an implementation of the epsilon greedy policy.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_METHODS_RL_POLICY_GREEDY_POLICY_HPP
#define MLPACK_METHODS_RL_POLICY_GREEDY_POLICY_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace rl {

/**
 * Implementation for epsilon greedy policy.
 *
 * In general we will select an action greedily based on the action value,
 * but sometimes we will also randomly select an action to encourage
 * exploration.
 *
 * @tparam EnvironmentType The reinforcement learning task.
 */
template <typename EnvironmentType>
class GreedyPolicy
{
 public:
  using ActionType = typename EnvironmentType::Action;

  /**
   * Constructor for epsilon greedy policy class.
   *
   * @param initialEpsilon The initial probability to explore
   *     (select a random action).
   * @param annealInterval The steps during which the probability to explore
   *     will anneal.
   * @param minEpsilon Epsilon will never be less than this value.
   */
  GreedyPolicy(double initialEpsilon,
> Review comment: Can we use …
               size_t annealInterval,
               double minEpsilon) :
      epsilon(initialEpsilon),
      minEpsilon(minEpsilon),
      delta((initialEpsilon - minEpsilon) / annealInterval)
  { /* Nothing to do here. */ }

  /**
   * Sample an action based on given action values.
   *
   * @param actionValue Values for each action.
   * @return Sampled action.
   */
  ActionType Sample(const arma::colvec& actionValue)
  {
    double exploration = math::Random();

    // Select the action randomly.
    if (exploration < epsilon)
      return static_cast<ActionType>(math::RandInt(ActionType::size));

    // Select the action greedily.
    size_t bestAction = 0;
    double maxActionValue = actionValue[0];
    for (size_t action = 1; action < ActionType::size; ++action)
> Review comment: hm, do you think we should backport the …
> Reply: I was considering this too. It's more concise but costs twice the time. But the performance here may not matter. I will do it.
> Reply: I agree, performance-wise it's probably not the best decision. I don't mind keeping the existing code, so we can go with the solution you like the most.
> (A sketch of the more concise form discussed here follows the file diff below.)
    {
      if (maxActionValue < actionValue[action])
      {
        maxActionValue = actionValue[action];
        bestAction = action;
      }
    }

    return static_cast<ActionType>(bestAction);
  }

  /**
   * Exploration probability will anneal at each step.
   */
  void Anneal()
  {
    epsilon -= delta;
    epsilon = std::max(minEpsilon, epsilon);
  }

  /**
   * @return Current probability to explore.
   */
  const double& Epsilon() const { return epsilon; }

 private:
  //! Locally-stored probability to explore.
  double epsilon;

  //! Locally-stored lower bound for epsilon.
  double minEpsilon;

  //! Locally-stored stride for epsilon to anneal.
  double delta;
};

} // namespace rl
} // namespace mlpack

#endif
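The review thread on the action-selection loop above is truncated, but it appears to discuss a more concise one-liner for the greedy branch; Armadillo's index_max() (or a backported equivalent for older Armadillo versions, which would explain the mention of "backport") fits that description. This is only a sketch of that assumed alternative, not the code merged in the PR:

// Sketch: the greedy branch of Sample() rewritten with Armadillo's
// index_max(), as the (truncated) review thread seems to suggest.
// Assumes an Armadillo version that provides index_max() for vectors.
ActionType Sample(const arma::colvec& actionValue)
{
  double exploration = math::Random();

  // Select the action randomly with probability epsilon.
  if (exploration < epsilon)
    return static_cast<ActionType>(math::RandInt(ActionType::size));

  // Select the action greedily: index_max() returns the index of the
  // largest element, replacing the manual scan above.
  return static_cast<ActionType>(actionValue.index_max());
}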
Reinforcement learning test file:
@@ -15,6 +15,7 @@
 #include <mlpack/methods/reinforcement_learning/environment/mountain_car.hpp>
 #include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
 #include <mlpack/methods/reinforcement_learning/replay/random_replay.hpp>
+#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>

 #include <boost/test/unit_test.hpp>
 #include "test_tools.hpp"
@@ -87,18 +88,34 @@ BOOST_AUTO_TEST_CASE(RandomReplayTest)
   BOOST_REQUIRE_EQUAL(1, replay.Size());

   //! Overwrite the memory with a nonsense record
-  for (size_t i = 0; i < 5; ++i) {
+  for (size_t i = 0; i < 5; ++i)
     replay.Store(nextState, action, reward, state, true);
-  }

   BOOST_REQUIRE_EQUAL(3, replay.Size());

   //! Sample several times, the original record shouldn't appear
-  for (size_t i = 0; i < 30; ++i) {
+  for (size_t i = 0; i < 30; ++i)
+  {
     replay.Sample(sampledState, sampledAction, sampledReward, sampledNextState, sampledTerminal);
     CheckMatrices(state.Encode(), sampledNextState);
     CheckMatrices(nextState.Encode(), sampledState);
     BOOST_REQUIRE_EQUAL(true, arma::as_scalar(sampledTerminal));
   }
 }

+/**
+ * Construct a greedy policy instance and check if it works as
+ * it should.
+ */
+BOOST_AUTO_TEST_CASE(GreedyPolicyTest)
+{
+  GreedyPolicy<CartPole> policy(1.0, 10, 0.0);
+  for (int i = 0; i < 15; ++i)

> Review comment: Do you mind using …

+    policy.Anneal();
+  BOOST_REQUIRE_CLOSE(0.0, policy.Epsilon(), 1e-5);
+  arma::colvec actionValue = arma::randn<arma::colvec>(CartPole::Action::size);
+  CartPole::Action action = policy.Sample(actionValue);
+  BOOST_REQUIRE_CLOSE(actionValue[action], actionValue.max(), 1e-5);
+}
+
 BOOST_AUTO_TEST_SUITE_END()
> Review comment: I think we can remove … under … here.
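To tie the pieces together, here is a minimal usage sketch of the new policy. The surrounding loop and the random qValues stand in for a real DQN learner and are hypothetical; only the GreedyPolicy interface comes from the diff above. Note the annealing arithmetic: with the test's values (initialEpsilon = 1.0, annealInterval = 10, minEpsilon = 0.0), delta = (1.0 - 0.0) / 10 = 0.1, so epsilon reaches the floor after ten Anneal() calls, which is why the test's fifteen calls still end at exactly 0.0.

// Minimal sketch of driving GreedyPolicy from a (hypothetical) training loop.
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>

using namespace mlpack;
using namespace mlpack::rl;

int main()
{
  // Fully exploratory at first; anneal epsilon from 1.0 to 0.1 over 100 steps.
  GreedyPolicy<CartPole> policy(1.0, 100, 0.1);

  for (size_t step = 0; step < 200; ++step)
  {
    // Hypothetical stand-in for a learner's predicted action values.
    arma::colvec qValues = arma::randn<arma::colvec>(CartPole::Action::size);

    // Random action with probability epsilon, argmax action otherwise.
    CartPole::Action action = policy.Sample(qValues);
    (void) action;  // A real agent would act in the environment here.

    // Epsilon decreases by (1.0 - 0.1) / 100 = 0.009 per call, floored at 0.1.
    policy.Anneal();
  }
}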