Async n-step q-learning and one step sarsa #1084

Merged: 4 commits, merged on Aug 23, 2017. Showing changes from 3 commits.
88 changes: 86 additions & 2 deletions src/mlpack/methods/reinforcement_learning/async_learning.hpp
@@ -16,6 +16,8 @@

#include <mlpack/prereqs.hpp>
#include "worker/one_step_q_learning_worker.hpp"
#include "worker/one_step_sarsa_worker.hpp"
#include "worker/n_step_q_learning_worker.hpp"
#include "training_config.hpp"

namespace mlpack {
@@ -128,7 +130,14 @@ class AsyncLearning
EnvironmentType environment;
};

// Forward declaration.
/**
* Forward declaration of OneStepQLearningWorker.
*
* @tparam EnvironmentType The type of the reinforcement learning task.
* @tparam NetworkType The type of the network model.
* @tparam UpdaterType The type of the optimizer.
* @tparam PolicyType The type of the behavior policy.
*/
template <
typename EnvironmentType,
typename NetworkType,
@@ -137,7 +146,46 @@ template <
>
class OneStepQLearningWorker;

// Convenient typedef for async one step q-learning.
/**
* Forward declaration of OneStepSarsaWorker.
*
* @tparam EnvironmentType The type of the reinforcement learning task.
* @tparam NetworkType The type of the network model.
* @tparam UpdaterType The type of the optimizer.
* @tparam PolicyType The type of the behavior policy.
*/
template <
typename EnvironmentType,
Review comment (Member): Can you comment on the template parameter?

typename NetworkType,
typename UpdaterType,
typename PolicyType
>
class OneStepSarsaWorker;

/**
* Forward declaration of NStepQLearningWorker.
*
* @tparam EnvironmentType The type of the reinforcement learning task.
* @tparam NetworkType The type of the network model.
* @tparam UpdaterType The type of the optimizer.
* @tparam PolicyType The type of the behavior policy.
*/
template <
typename EnvironmentType,
typename NetworkType,
typename UpdaterType,
typename PolicyType
>
class NStepQLearningWorker;

/**
* Convenient typedef for async one step q-learning.
*
* @tparam EnvironmentType The type of the reinforcement learning task.
* @tparam NetworkType The type of the network model.
* @tparam UpdaterType The type of the optimizer.
* @tparam PolicyType The type of the behavior policy.
*/
template <
typename EnvironmentType,
typename NetworkType,
Expand All @@ -148,6 +196,42 @@ using OneStepQLearning = AsyncLearning<OneStepQLearningWorker<EnvironmentType,
NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
UpdaterType, PolicyType>;

/**
* Convenient typedef for async one step Sarsa.
*
* @tparam EnvironmentType The type of the reinforcement learning task.
* @tparam NetworkType The type of the network model.
* @tparam UpdaterType The type of the optimizer.
* @tparam PolicyType The type of the behavior policy.
*/
template <
typename EnvironmentType,
typename NetworkType,
typename UpdaterType,
typename PolicyType
>
using OneStepSarsa = AsyncLearning<OneStepSarsaWorker<EnvironmentType,
NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
UpdaterType, PolicyType>;

/**
* Convenient typedef for async n step q-learning.
*
* @tparam EnvironmentType The type of the reinforcement learning task.
* @tparam NetworkType The type of the network model.
* @tparam UpdaterType The type of the optimizer.
* @tparam PolicyType The type of the behavior policy.
*/
template <
typename EnvironmentType,
typename NetworkType,
typename UpdaterType,
typename PolicyType
>
using NStepQLearning = AsyncLearning<NStepQLearningWorker<EnvironmentType,
NetworkType, UpdaterType, PolicyType>, EnvironmentType, NetworkType,
UpdaterType, PolicyType>;

} // namespace rl
} // namespace mlpack
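
For context on how these aliases are meant to be used: each one simply fixes the worker type that AsyncLearning drives, so all three agents share the same training loop and differ only in how a worker turns its transitions into gradients. A minimal sketch of spelling one out follows; the concrete type names (CartPole, FFNType, VanillaUpdate, GreedyPolicy) are illustrative placeholders, not types taken from this diff.

// Illustrative sketch only: the four concrete types below are placeholders.
using CartPoleAgent = mlpack::rl::NStepQLearning<
    CartPole,                 // EnvironmentType: the reinforcement learning task.
    FFNType,                  // NetworkType: the function approximator.
    VanillaUpdate,            // UpdaterType: the optimizer update rule.
    GreedyPolicy<CartPole>>;  // PolicyType: the behavior policy.
// CartPoleAgent expands to AsyncLearning<NStepQLearningWorker<...>, ...>, so
// swapping in OneStepSarsaWorker or OneStepQLearningWorker changes only how
// each parallel worker computes its update.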

@@ -2,6 +2,8 @@
# Anything not in this list will not be compiled into mlpack.
set(SOURCES
one_step_q_learning_worker.hpp
one_step_sarsa_worker.hpp
n_step_q_learning_worker.hpp
)

# Add directory name to sources.
@@ -0,0 +1,236 @@
/**
* @file n_step_q_learning_worker.hpp
* @author Shangtong Zhang
*
* This file is the definition of the NStepQLearningWorker class,
* which implements an episode for the async n step Q-Learning algorithm.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_RL_WORKER_N_STEP_Q_LEARNING_WORKER_HPP
#define MLPACK_METHODS_RL_WORKER_N_STEP_Q_LEARNING_WORKER_HPP

#include <mlpack/methods/reinforcement_learning/training_config.hpp>

namespace mlpack {
namespace rl {

/**
* N step Q-Learning worker.
*
* @tparam EnvironmentType The type of the reinforcement learning task.
* @tparam NetworkType The type of the network model.
* @tparam UpdaterType The type of the optimizer.
* @tparam PolicyType The type of the behavior policy. *
Review comment (Member): Can you remove the * at the end?

*/
template <
typename EnvironmentType,
typename NetworkType,
typename UpdaterType,
typename PolicyType
>
class NStepQLearningWorker
{
public:
using StateType = typename EnvironmentType::State;
using ActionType = typename EnvironmentType::Action;
using TransitionType = std::tuple<StateType, ActionType, double, StateType>;
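// A transition stores (state, action, reward, next state); the update loop
// below reads the fields back through std::get.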

/**
Review comment (Member): Do you mind adding a method description here? Something like:
"Construct N step Q-Learning worker with the given parameters and environment."

* @param updater The optimizer.
* @param environment The reinforcement learning task.
* @param config Hyper-parameters.
* @param deterministic Whether the worker should act deterministically.
*/
NStepQLearningWorker(
const UpdaterType& updater,
const EnvironmentType& environment,
const TrainingConfig& config,
bool deterministic):
updater(updater),
environment(environment),
config(config),
deterministic(deterministic),
pending(config.UpdateInterval())
{ reset(); }
Review comment (Member): Can you use Upper camel casing for all method names?


/**
* Initialize the worker.
* @param learningNetwork The shared network.
*/
void Initialize(NetworkType& learningNetwork)
{
updater.Initialize(learningNetwork.Parameters().n_rows,
learningNetwork.Parameters().n_cols);
// Build local network.
network = learningNetwork;
}

/**
* The agent will execute one step.
*
* @param learningNetwork The shared learning network.
* @param targetNetwork The shared target network.
* @param totalSteps The shared counter for total steps.
* @param policy The shared behavior policy.
* @param totalReward This will be the episode return if the episode ends
* after this step. Otherwise this is invalid.
* @return Whether the current episode ends after this step.
*/
bool Step(NetworkType& learningNetwork,
NetworkType& targetNetwork,
size_t& totalSteps,
PolicyType& policy,
double& totalReward)
{
// Interact with the environment.
arma::colvec actionValue;
network.Predict(state.Encode(), actionValue);
ActionType action = policy.Sample(actionValue, deterministic);
StateType nextState;
double reward = environment.Sample(state, action, nextState);
bool terminal = environment.IsTerminal(nextState);

episodeReturn += reward;
steps++;

terminal = terminal || steps >= config.StepLimit();
if (deterministic)
{
if (terminal)
{
totalReward = episodeReturn;
reset();
// Sync with latest learning network.
network = learningNetwork;
return true;
}
state = nextState;
return false;
}

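// totalSteps is shared by every worker thread, so increment it atomically.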
#pragma omp atomic
totalSteps++;

pending[pendingIndex] = std::make_tuple(state, action, reward, nextState);
pendingIndex++;

if (terminal || pendingIndex >= config.UpdateInterval())
{
// Initialize the gradient storage.
arma::mat totalGradients(learningNetwork.Parameters().n_rows,
learningNetwork.Parameters().n_cols);

// Bootstrap from the value of next state.
arma::colvec actionValue;
double target = 0;
if (!terminal)
{
#pragma omp critical
{ targetNetwork.Predict(nextState.Encode(), actionValue); };
target = actionValue.max();
}

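// Walking the buffer backwards builds the n-step return incrementally: after
// transition i is processed, target holds
//   reward_i + discount * reward_{i+1} + ... + discount^k * (bootstrap value),
// which is then used as the regression target for Q(state_i, action_i).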
// Update in reverse order.
for (int i = pending.size() - 1; i >= 0; --i)
{
TransitionType &transition = pending[i];
target = config.Discount() * target + std::get<2>(transition);

// Compute the training target for current state.
network.Forward(std::get<0>(transition).Encode(), actionValue);
actionValue[std::get<1>(transition)] = target;

// Compute gradient.
arma::mat gradients;
network.Backward(actionValue, gradients);

// Accumulate gradients.
totalGradients += gradients;
Review comment (Member): We should initialize totalGradients with zero.
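A minimal way to do that, assuming Armadillo's fill::zeros constructor argument, would be:
arma::mat totalGradients(learningNetwork.Parameters().n_rows,
    learningNetwork.Parameters().n_cols, arma::fill::zeros);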

}

// Clamp the accumulated gradients.
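// (Each element is limited to the range [-GradientLimit(), GradientLimit()]
// before the shared network is updated.)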
totalGradients.transform(
[&](double gradient)
{ return std::min(std::max(gradient, -config.GradientLimit()),
config.GradientLimit()); });

// Perform async update of the global network.
updater.Update(learningNetwork.Parameters(),
config.StepSize(), totalGradients);

// Sync the local network with the global network.
network = learningNetwork;

pendingIndex = 0;
}

// Update global target network.
if (totalSteps % config.TargetNetworkSyncInterval() == 0)
{
#pragma omp critical
{ targetNetwork = learningNetwork; }
}

policy.Anneal();

if (terminal)
{
totalReward = episodeReturn;
reset();
return true;
}
state = nextState;
return false;
}

private:
/**
* Reset the worker for a new episode.
*/
void reset()
{
steps = 0;
episodeReturn = 0;
pendingIndex = 0;
state = environment.InitialSample();
}

//! Locally-stored optimizer.
UpdaterType updater;

//! Locally-stored task.
EnvironmentType environment;

//! Locally-stored hyper-parameters.
TrainingConfig config;

//! Whether this episode is deterministic or not.
bool deterministic;

//! Total steps in current episode.
size_t steps;

//! Total reward in current episode.
double episodeReturn;

//! Buffer for delayed update.
std::vector<TransitionType> pending;

//! Current position of the buffer.
size_t pendingIndex;

//! Local network of the worker.
NetworkType network;

//! Current state of the agent.
StateType state;
};

} // namespace rl
} // namespace mlpack

#endif