Reinforcement Learning: Ornstein-Uhlenbeck noise #3499

Merged 3 commits on Jul 1, 2023
2 changes: 2 additions & 0 deletions HISTORY.md
@@ -1,5 +1,7 @@
### mlpack ?.?.?
###### ????-??-??
* Reinforcement Learning: Ornstein-Uhlenbeck noise (#3499).

* Reinforcement Learning: Deep Deterministic Policy Gradient (#3494).

### mlpack 4.2.0
7 changes: 7 additions & 0 deletions src/mlpack/methods/reinforcement_learning/ddpg.hpp
@@ -46,13 +46,15 @@ namespace mlpack {
* @tparam EnvironmentType The environment of the reinforcement learning task.
* @tparam QNetworkType The network used to estimate the critic's Q-values.
* @tparam PolicyNetworkType The network to compute action value.
* @tparam NoiseType The noise to add for exploration.
* @tparam UpdaterType How to apply gradients when training.
* @tparam ReplayType Experience replay method.
*/
template <
typename EnvironmentType,
typename QNetworkType,
typename PolicyNetworkType,
typename NoiseType,
typename UpdaterType,
typename ReplayType = RandomReplay<EnvironmentType>
>
@@ -75,6 +77,7 @@ class DDPG
* @param config Hyper-parameters for training.
* @param learningQNetwork The network to compute action value.
* @param policyNetwork The network to produce an action given a state.
* @param noise The noise instance for exploration.
* @param replayMethod Experience replay method.
* @param qNetworkUpdater How to apply gradients to Q network when training.
* @param policyNetworkUpdater How to apply gradients to policy network
@@ -84,6 +87,7 @@ class DDPG
DDPG(TrainingConfig& config,
QNetworkType& learningQNetwork,
PolicyNetworkType& policyNetwork,
NoiseType& noise,
ReplayType& replayMethod,
UpdaterType qNetworkUpdater = UpdaterType(),
UpdaterType policyNetworkUpdater = UpdaterType(),
@@ -150,6 +154,9 @@ class DDPG
//! Locally-stored policy network.
PolicyNetworkType& policyNetwork;

//! Locally-stored noise instance.
NoiseType& noise;

//! Locally-stored target policy network.
PolicyNetworkType targetPNetwork;

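Note: with the new NoiseType template parameter, callers now pass a noise instance alongside the networks. A minimal sketch of the updated instantiation, assuming the Pendulum environment, networks, config, and replay method are set up as in the tests further below:

    // Hypothetical setup; network and config construction omitted.
    OUNoise ouNoise(1 /* action dimension */, 0.0, 0.15, 0.2);
    DDPG<Pendulum, decltype(qNetwork), decltype(policyNetwork), OUNoise, AdamUpdate>
        agent(config, qNetwork, policyNetwork, ouNoise, replayMethod);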
23 changes: 20 additions & 3 deletions src/mlpack/methods/reinforcement_learning/ddpg_impl.hpp
@@ -23,25 +23,29 @@ template <
typename EnvironmentType,
typename QNetworkType,
typename PolicyNetworkType,
typename NoiseType,
typename UpdaterType,
typename ReplayType
>
DDPG<
EnvironmentType,
QNetworkType,
PolicyNetworkType,
NoiseType,
UpdaterType,
ReplayType
>::DDPG(TrainingConfig& config,
QNetworkType& learningQNetwork,
PolicyNetworkType& policyNetwork,
NoiseType& noise,
ReplayType& replayMethod,
UpdaterType qNetworkUpdater,
UpdaterType policyNetworkUpdater,
EnvironmentType environment):
config(config),
learningQNetwork(learningQNetwork),
policyNetwork(policyNetwork),
noise(noise),
replayMethod(replayMethod),
qNetworkUpdater(std::move(qNetworkUpdater)),
#if ENS_VERSION_MAJOR >= 2
@@ -55,6 +59,9 @@ DDPG<
totalSteps(0),
deterministic(false)
{
// Reset the noise instance.
noise.reset();

// Set up q-learning and policy networks.
targetPNetwork = policyNetwork;
targetQNetwork = learningQNetwork;
@@ -106,13 +113,15 @@ template <
typename EnvironmentType,
typename QNetworkType,
typename PolicyNetworkType,
typename NoiseType,
typename UpdaterType,
typename ReplayType
>
DDPG<
EnvironmentType,
QNetworkType,
PolicyNetworkType,
NoiseType,
UpdaterType,
ReplayType
>::~DDPG()
@@ -127,13 +136,15 @@ template <
typename EnvironmentType,
typename QNetworkType,
typename PolicyNetworkType,
typename NoiseType,
typename UpdaterType,
typename ReplayType
>
void DDPG<
EnvironmentType,
QNetworkType,
PolicyNetworkType,
NoiseType,
UpdaterType,
ReplayType
>::SoftUpdate(double rho)
@@ -148,13 +159,15 @@ template <
typename EnvironmentType,
typename QNetworkType,
typename PolicyNetworkType,
typename NoiseType,
typename UpdaterType,
typename ReplayType
>
void DDPG<
EnvironmentType,
QNetworkType,
PolicyNetworkType,
NoiseType,
UpdaterType,
ReplayType
>::Update()
@@ -255,13 +268,15 @@ template <
typename EnvironmentType,
typename QNetworkType,
typename PolicyNetworkType,
typename NoiseType,
typename UpdaterType,
typename ReplayType
>
void DDPG<
EnvironmentType,
QNetworkType,
PolicyNetworkType,
NoiseType,
UpdaterType,
ReplayType
>::SelectAction()
@@ -272,9 +287,9 @@ void DDPG<

if (!deterministic)
{
arma::colvec noise = arma::randn<arma::colvec>(outputAction.n_rows) * 0.1;
noise = arma::clamp(noise, -0.25, 0.25);
outputAction = outputAction + noise;
arma::colvec sample = noise.sample() * 0.1;
sample = arma::clamp(sample, -0.25, 0.25);
outputAction = outputAction + sample;
}
action.action = arma::conv_to<std::vector<double>>::from(outputAction);
}
@@ -283,13 +298,15 @@ template <
typename EnvironmentType,
typename QNetworkType,
typename PolicyNetworkType,
typename NoiseType,
typename UpdaterType,
typename ReplayType
>
double DDPG<
EnvironmentType,
QNetworkType,
PolicyNetworkType,
NoiseType,
UpdaterType,
ReplayType
>::Episode()
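The SelectAction() change above swaps the previous independent Gaussian perturbation for a sample drawn from the injected noise object, keeping the same 0.1 scaling and [-0.25, 0.25] clamping. A standalone sketch of that exploration step, assuming Armadillo and the OUNoise class added below:

    // outputAction stands in for the action proposed by the policy network.
    arma::colvec outputAction(1, arma::fill::zeros);
    OUNoise ouNoise(outputAction.n_rows);

    // Temporally correlated exploration noise, scaled and clamped as in SelectAction().
    arma::colvec sample = ouNoise.sample() * 0.1;
    sample = arma::clamp(sample, -0.25, 0.25);
    outputAction += sample;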
17 changes: 17 additions & 0 deletions src/mlpack/methods/reinforcement_learning/noise/noise.hpp
@@ -0,0 +1,17 @@
/**
* @file methods/reinforcement_learning/noise/noise.hpp
* @author Tarek Elsayed
*
* Convenience include for reinforcement learning noises.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_REINFORCEMENT_LEARNING_NOISE_NOISE_HPP
#define MLPACK_METHODS_REINFORCEMENT_LEARNING_NOISE_NOISE_HPP

#include "ornstein_uhlenbeck.hpp"

#endif
78 changes: 78 additions & 0 deletions src/mlpack/methods/reinforcement_learning/noise/ornstein_uhlenbeck.hpp
@@ -0,0 +1,78 @@
/**
* @file methods/reinforcement_learning/noise/ornstein_uhlenbeck.hpp
* @author Tarek Elsayed
*
* This file is the implementation of the OUNoise class.
* The Ornstein-Uhlenbeck process generates temporally correlated exploration
* noise, which is well suited to physical control problems with inertia.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_RL_NOISE_ORNSTEIN_UHLENBECK_HPP
#define MLPACK_METHODS_RL_NOISE_ORNSTEIN_UHLENBECK_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
class OUNoise
{
public:
/**
* @param size The size of the noise vector.
* @param mu The mean of the noise process.
* @param theta The rate of mean reversion.
* @param sigma The standard deviation of the noise.
*/
OUNoise(int size,
double mu = 0.0,
double theta = 0.15,
double sigma = 0.2) :
mu(mu * arma::ones<arma::colvec>(size)),
theta(theta),
sigma(sigma)
{
reset();
}

/**
* Reset the internal state to the mean (mu).
*/
void reset()
{
state = mu;
}

/**
* Update the internal state and return it as a noise sample.
*
* @return Noise sample.
*/
arma::colvec sample()
{
arma::colvec x = state;
arma::colvec dx = theta * (mu - x) +
sigma * arma::randn<arma::colvec>(x.n_elem);
state = x + dx;
return state;
}

private:
//! Locally-stored state of the noise process.
arma::colvec state;

//! Locally-stored mean of the noise process.
arma::colvec mu;

//! Locally-stored rate of mean reversion.
double theta;

//! Locally-stored standard deviation of the noise.
double sigma;
};

} // namespace mlpack

#endif
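For reference, sample() performs one Euler step of the Ornstein-Uhlenbeck process with unit time step, which is what makes consecutive samples temporally correlated rather than independent:

    x_{t+1} = x_t + \theta (\mu - x_t) + \sigma \, \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, I)

Here \theta controls how strongly the state is pulled back toward the mean \mu, \sigma scales the Gaussian perturbation, and reset() restarts the process at \mu.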
@@ -17,6 +17,7 @@
#include "q_networks/q_networks.hpp"
#include "replay/replay.hpp"
#include "worker/worker.hpp"
#include "noise/noise.hpp"

#include "training_config.hpp"
#include "async_learning.hpp"
50 changes: 46 additions & 4 deletions src/mlpack/tests/q_learning_test.cpp
@@ -598,9 +598,19 @@ TEST_CASE("PendulumWithDDPG", "[QLearningTest]")
qNetwork.Add(new ReLU());
qNetwork.Add(new Linear(1));

// Set up the OUNoise parameters.
int size = 1;
double mu = 0.0;
double theta = 1.0;
double sigma = 0.1;

// Create an instance of the OUNoise class.
OUNoise ouNoise(size, mu, theta, sigma);

// Set up Deep Deterministic Policy Gradient agent.
DDPG<Pendulum, decltype(qNetwork), decltype(policyNetwork), AdamUpdate>
agent(config, qNetwork, policyNetwork, replayMethod);
DDPG<Pendulum, decltype(qNetwork), decltype(policyNetwork),
OUNoise, AdamUpdate>
agent(config, qNetwork, policyNetwork, ouNoise, replayMethod);

converged = testAgent<decltype(agent)>(agent, -900, 500, 10);
if (converged)
@@ -633,10 +643,19 @@ TEST_CASE("DDPGForMultipleActions", "[QLearningTest]")
config.TargetNetworkSyncInterval() = 1;
config.UpdateInterval() = 3;

// Set up the OUNoise parameters.
int size = 4;
double mu = 0.0;
double theta = 1.0;
double sigma = 0.1;

// Create an instance of the OUNoise class.
OUNoise ouNoise(size, mu, theta, sigma);

// Set up the DDPG agent.
DDPG<ContinuousActionEnv<3, 4>, decltype(qNetwork), decltype(policyNetwork),
AdamUpdate>
agent(config, qNetwork, policyNetwork, replayMethod);
OUNoise, AdamUpdate>
agent(config, qNetwork, policyNetwork, ouNoise, replayMethod);

agent.State().Data() = arma::randu<arma::colvec>
(ContinuousActionEnv<3, 4>::State::dimension, 1);
@@ -651,3 +670,26 @@
// If the agent is able to reach this point of the test, it is assured
// that the agent can handle multiple actions in continuous space.
}

//! Test Ornstein-Uhlenbeck noise class.
TEST_CASE("OUNoiseTest", "[QLearningTest]")
{
// Set up the OUNoise parameters.
int size = 3;
double mu = 0.0;
double theta = 0.15;
double sigma = 0.2;

// Create an instance of the OUNoise class.
OUNoise ouNoise(size, mu, theta, sigma);

// Test the reset function.
ouNoise.reset();
arma::colvec state = ouNoise.sample();
REQUIRE(state.n_elem == size);

// Verify that a subsequent sample differs from the previous sample.
arma::colvec sample = ouNoise.sample();
bool isNotEqual = arma::any(sample != state);
REQUIRE(isNotEqual);
}