New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Async n-step q-learning and one step sarsa #1084
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Diff of the new file (0 original lines; 236 added lines, +1 to +236):
/** | ||
* @file n_step_q_learning_worker.hpp | ||
* @author Shangtong Zhang | ||
* | ||
* This file is the definition of NStepQLearningWorker class, | ||
* which implements an episode for async n step Q-Learning algorithm. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_METHODS_RL_WORKER_N_STEP_Q_LEARNING_WORKER_HPP | ||
#define MLPACK_METHODS_RL_WORKER_N_STEP_Q_LEARNING_WORKER_HPP | ||
|
||
#include <mlpack/methods/reinforcement_learning/training_config.hpp> | ||
|
||
namespace mlpack { | ||
namespace rl { | ||
|
||
/** | ||
* N step Q-Learning worker. | ||
* | ||
* @tparam EnvironmentType The type of the reinforcement learning task. | ||
* @tparam NetworkType The type of the network model. | ||
* @tparam UpdaterType The type of the optimizer. | ||
* @tparam PolicyType The type of the behavior policy. * | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you remove the |
||
*/ | ||
template < | ||
typename EnvironmentType, | ||
typename NetworkType, | ||
typename UpdaterType, | ||
typename PolicyType | ||
> | ||
class NStepQLearningWorker | ||
{ | ||
public: | ||
using StateType = typename EnvironmentType::State; | ||
using ActionType = typename EnvironmentType::Action; | ||
using TransitionType = std::tuple<StateType, ActionType, double, StateType>; | ||
|
||
/** | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mind to add a method description here; something like should do:
|
||
* @param updater The optimizer. | ||
* @param environment The reinforcement learning task. | ||
* @param config Hyper-parameters. | ||
* @param deterministic Whether it should be deterministic. | ||
*/ | ||
NStepQLearningWorker( | ||
const UpdaterType& updater, | ||
const EnvironmentType& environment, | ||
const TrainingConfig& config, | ||
bool deterministic): | ||
updater(updater), | ||
environment(environment), | ||
config(config), | ||
deterministic(deterministic), | ||
pending(config.UpdateInterval()) | ||
{ reset(); } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you use Upper camel casing for all method names? |
||
|
||
/** | ||
* Initialize the worker. | ||
* @param learningNetwork The shared network. | ||
*/ | ||
void Initialize(NetworkType& learningNetwork) | ||
{ | ||
updater.Initialize(learningNetwork.Parameters().n_rows, | ||
learningNetwork.Parameters().n_cols); | ||
// Build local network. | ||
network = learningNetwork; | ||
} | ||
|
||
/** | ||
* The agent will execute one step. | ||
* | ||
* @param learningNetwork The shared learning network. | ||
* @param targetNetwork The shared target network. | ||
* @param totalSteps The shared counter for total steps. | ||
* @param policy The shared behavior policy. | ||
* @param totalReward This will be the episode return if the episode ends | ||
* after this step. Otherwise this is invalid. | ||
* @return Indicate whether current episode ends after this step. | ||
*/ | ||
bool Step(NetworkType& learningNetwork, | ||
NetworkType& targetNetwork, | ||
size_t& totalSteps, | ||
PolicyType& policy, | ||
double& totalReward) | ||
{ | ||
// Interact with the environment. | ||
arma::colvec actionValue; | ||
network.Predict(state.Encode(), actionValue); | ||
ActionType action = policy.Sample(actionValue, deterministic); | ||
StateType nextState; | ||
double reward = environment.Sample(state, action, nextState); | ||
bool terminal = environment.IsTerminal(nextState); | ||
|
||
episodeReturn += reward; | ||
steps++; | ||
|
||
terminal = terminal || steps >= config.StepLimit(); | ||
if (deterministic) | ||
{ | ||
if (terminal) | ||
{ | ||
totalReward = episodeReturn; | ||
reset(); | ||
// Sync with latest learning network. | ||
network = learningNetwork; | ||
return true; | ||
} | ||
state = nextState; | ||
return false; | ||
} | ||
|
||
#pragma omp atomic | ||
totalSteps++; | ||
|
||
pending[pendingIndex] = std::make_tuple(state, action, reward, nextState); | ||
pendingIndex++; | ||
|
||
if (terminal || pendingIndex >= config.UpdateInterval()) | ||
{ | ||
// Initialize the gradient storage. | ||
arma::mat totalGradients(learningNetwork.Parameters().n_rows, | ||
learningNetwork.Parameters().n_cols); | ||
|
||
// Bootstrap from the value of next state. | ||
arma::colvec actionValue; | ||
double target = 0; | ||
if (!terminal) | ||
{ | ||
#pragma omp critical | ||
{ targetNetwork.Predict(nextState.Encode(), actionValue); }; | ||
target = actionValue.max(); | ||
} | ||
|
||
// Update in reverse order. | ||
for (int i = pending.size() - 1; i >= 0; --i) | ||
{ | ||
TransitionType &transition = pending[i]; | ||
target = config.Discount() * target + std::get<2>(transition); | ||
|
||
// Compute the training target for current state. | ||
network.Forward(std::get<0>(transition).Encode(), actionValue); | ||
actionValue[std::get<1>(transition)] = target; | ||
|
||
// Compute gradient. | ||
arma::mat gradients; | ||
network.Backward(actionValue, gradients); | ||
|
||
// Accumulate gradients. | ||
totalGradients += gradients; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should initialize |
||
} | ||
|
||
// Clamp the accumulated gradients. | ||
totalGradients.transform( | ||
[&](double gradient) | ||
{ return std::min(std::max(gradient, -config.GradientLimit()), | ||
config.GradientLimit()); }); | ||
|
||
// Perform async update of the global network. | ||
updater.Update(learningNetwork.Parameters(), | ||
config.StepSize(), totalGradients); | ||
|
||
// Sync the local network with the global network. | ||
network = learningNetwork; | ||
|
||
pendingIndex = 0; | ||
} | ||
|
||
// Update global target network. | ||
if (totalSteps % config.TargetNetworkSyncInterval() == 0) | ||
{ | ||
#pragma omp critical | ||
{ targetNetwork = learningNetwork; } | ||
} | ||
|
||
policy.Anneal(); | ||
|
||
if (terminal) | ||
{ | ||
totalReward = episodeReturn; | ||
reset(); | ||
return true; | ||
} | ||
state = nextState; | ||
return false; | ||
} | ||
|
||
private: | ||
/** | ||
* Reset the worker for a new episdoe. | ||
*/ | ||
void reset() | ||
{ | ||
steps = 0; | ||
episodeReturn = 0; | ||
pendingIndex = 0; | ||
state = environment.InitialSample(); | ||
} | ||
|
||
//! Locally-stored optimizer. | ||
UpdaterType updater; | ||
|
||
//! Locally-stored task. | ||
EnvironmentType environment; | ||
|
||
//! Locally-stored hyper-parameters. | ||
TrainingConfig config; | ||
|
||
//! Whether this episode is deterministic or not. | ||
bool deterministic; | ||
|
||
//! Total steps in current episode. | ||
size_t steps; | ||
|
||
//! Total reward in current episode. | ||
double episodeReturn; | ||
|
||
//! Buffer for delayed update. | ||
std::vector<TransitionType> pending; | ||
|
||
//! Current position of the buffer. | ||
size_t pendingIndex; | ||
|
||
//! Local network of the worker. | ||
NetworkType network; | ||
|
||
//! Current state of the agent. | ||
StateType state; | ||
}; | ||
|
||
} // namespace rl | ||
} // namespace mlpack | ||
|
||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you comment on the template parameter?