Merge pull request #214 from tareknaser/gym
Reinforcement Learning: Examples for `DDPG` and `TD3` with Gymnasium Environments.
zoq committed Sep 4, 2023
2 parents 01c4616 + dc6f7d8 commit 9af53a8
Showing 4 changed files with 462 additions and 0 deletions.
41 changes: 41 additions & 0 deletions reinforcement_learning_gym/mountain_car_ddpg/Makefile
@@ -0,0 +1,41 @@
# This is a simple Makefile used to build the example source code.
# This example might require some modifications in order to work correctly on
# your system.
# If you're not using the Armadillo wrapper, replace `armadillo` with linker commands
# for the BLAS and LAPACK libraries that you are using.

TARGET := mountain_car_ddpg
SRC := mountain_car_ddpg.cpp
LIBS_NAME := armadillo boost_iostreams

CXX := g++
CXXFLAGS += -std=c++14 -Wall -Wextra -O3 -DNDEBUG -fopenmp
# Use these CXXFLAGS instead if you want to compile with debugging symbols and
# without optimizations.
# CXXFLAGS += -std=c++14 -Wall -Wextra -g -O0
LDFLAGS += -fopenmp
# Add header directories for any includes that aren't on the
# default compiler search path.
INCLFLAGS := -I .
# If you have mlpack or ensmallen installed somewhere nonstandard, uncomment and
# update the lines below.
# INCLFLAGS += -I/path/to/ensmallen/include/
CXXFLAGS += $(INCLFLAGS)

OBJS := $(SRC:.cpp=.o)
LIBS := $(addprefix -l,$(LIBS_NAME))
CLEAN_LIST := $(TARGET) $(OBJS)

# default rule
default: all

$(TARGET): $(OBJS)
$(CXX) $(CXXFLAGS) $(OBJS) -o $(TARGET) $(LDFLAGS) $(LIBS)

.PHONY: all
all: $(TARGET)

.PHONY: clean
clean:
@echo CLEAN $(CLEAN_LIST)
@rm -f $(CLEAN_LIST)
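
(A typical build-and-run sequence, assuming mlpack, Armadillo, and Boost.Iostreams are installed and a gym_tcp_api server is reachable on localhost:4040: run `make`, then `./mountain_car_ddpg`.)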
198 changes: 198 additions & 0 deletions reinforcement_learning_gym/mountain_car_ddpg/mountain_car_ddpg.cpp
@@ -0,0 +1,198 @@
/**
 * In this example, we train a
 * [Deep Deterministic Policy Gradient](https://arxiv.org/abs/1509.02971)
 * agent to achieve high scores on the
 * [Mountain Car Continuous]
 * (https://gymnasium.farama.org/environments/classic_control/mountain_car_continuous/)
 * environment.
 * The agent trains and tests against the Gymnasium toolkit's environments,
 * served through a distributed infrastructure (TCP API). More details can be
 * found [here](https://github.com/zoq/gym_tcp_api).
 * A video of the trained agent can be seen at the end.
 */

// Including necessary libraries and namespaces.
#include <mlpack.hpp>

// Used to run the agent on gym's environment for testing.
#include "../gym/environment.hpp"

using namespace mlpack;
using namespace ens;
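
// Note: this example assumes a gym_tcp_api server is already running on
// localhost:4040 (see the link above for setup instructions).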

template<typename EnvironmentType,
         typename QNetworkType,
         typename PolicyNetworkType,
         typename NoiseType,
         typename UpdaterType,
         typename ReplayType = RandomReplay<EnvironmentType>>
void Train(gym::Environment& env,
           DDPG<EnvironmentType, QNetworkType, PolicyNetworkType,
               NoiseType, UpdaterType>& agent,
           ReplayType& replayMethod,
           TrainingConfig& config,
           std::vector<double>& returnList,
           size_t& episodes,
           size_t& consecutiveEpisodes,
           const size_t numSteps)
{
  agent.Deterministic() = false;
  std::cout << "Training for " << numSteps << " steps." << std::endl;
  while (agent.TotalSteps() < numSteps)
  {
    double episodeReturn = 0;
    env.reset();
    do
    {
      agent.State().Data() = env.observation;
      agent.SelectAction();
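      // Rescale the agent's TanH-bounded action before stepping the
      // environment.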
      arma::mat action = {double(agent.Action().action[0] * 2)};

      env.step(action);
      ContinuousActionEnv<2, 1>::State nextState;
      nextState.Data() = env.observation;

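      // Store the transition in the replay buffer; the final argument is
      // the discount factor.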
      replayMethod.Store(
          agent.State(), agent.Action(), env.reward, nextState, env.done,
          0.99);
      episodeReturn += env.reward;
      agent.TotalSteps()++;

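      // Skip gradient updates until the initial exploration phase is over
      // (or when running deterministically).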
      if (agent.Deterministic()
          || agent.TotalSteps() < config.ExplorationSteps())
      {
        continue;
      }

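      // Otherwise perform config.UpdateInterval() agent updates.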
      for (size_t i = 0; i < config.UpdateInterval(); i++)
        agent.Update();
    } while (!env.done);

    returnList.push_back(episodeReturn);
    episodes += 1;

    if (returnList.size() > consecutiveEpisodes)
      returnList.erase(returnList.begin());

    double averageReturn =
        std::accumulate(returnList.begin(), returnList.end(), 0.0) /
        returnList.size();
    if (episodes % 4 == 0)
    {
      std::cout << "Avg return in last " << returnList.size()
                << " episodes: " << averageReturn
                << "\t Episode return: " << episodeReturn
                << "\t Total steps: " << agent.TotalSteps() << std::endl;
    }
  }
}

int main()
{
  // Initializing the agent.

  // Set up the actor and critic networks.
  FFN<EmptyLoss, GaussianInitialization> policyNetwork(
      EmptyLoss(), GaussianInitialization(0, 0.01));
  policyNetwork.Add<Linear>(128);
  policyNetwork.Add<ReLU>();
  policyNetwork.Add<Linear>(128);
  policyNetwork.Add<ReLU>();
  policyNetwork.Add<Linear>(ContinuousActionEnv<2, 1>::Action::size);
  policyNetwork.Add<TanH>();
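  // The actor maps the 2-dimensional state to a single action, squashed
  // into [-1, 1] by the final TanH layer.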

  FFN<EmptyLoss, GaussianInitialization> qNetwork(
      EmptyLoss(), GaussianInitialization(0, 0.01));
  qNetwork.Add<Linear>(128);
  qNetwork.Add<ReLU>();
  qNetwork.Add<Linear>(128);
  qNetwork.Add<ReLU>();
  qNetwork.Add<Linear>(1);
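  // The critic outputs a scalar Q-value estimate for a state-action pair.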

  // Set up the replay method (batch size 32, buffer capacity 10000).
  RandomReplay<ContinuousActionEnv<2, 1>> replayMethod(32, 10000);

  // Set up training configurations.
  TrainingConfig config;
  config.ExplorationSteps() = 3200;
  config.TargetNetworkSyncInterval() = 1;
  config.UpdateInterval() = 1;

  // Set up the OUNoise parameters.
  int size = 1;
  double mu = 0.0;
  double theta = 1.0;
  double sigma = 0.1;
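  // (mu is the long-run mean the noise reverts to, theta the reversion
  // rate, and sigma the scale of the random perturbations.)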

  // Create an instance of the OUNoise class.
  OUNoise ouNoise(size, mu, theta, sigma);
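  // For intuition: each sample roughly follows the discretized
  // Ornstein-Uhlenbeck update x += theta * (mu - x) + sigma * N(0, 1),
  // which yields temporally correlated exploration noise.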

  // Set up the DDPG agent.
  DDPG<ContinuousActionEnv<2, 1>,
       decltype(qNetwork),
       decltype(policyNetwork),
       OUNoise,
       AdamUpdate>
      agent(config, qNetwork, policyNetwork, ouNoise, replayMethod);

  // Preparation for training the agent.

  // Set up the gym training environment.
  gym::Environment env("localhost", "4040", "MountainCarContinuous-v0");

  // Initializing training variables.
  std::vector<double> returnList;
  size_t episodes = 0;

  // The number of episode returns to keep track of.
  size_t consecutiveEpisodes = 25;

  // Training the agent for a total of 100000 steps.
  Train(env,
        agent,
        replayMethod,
        config,
        returnList,
        episodes,
        consecutiveEpisodes,
        100000);

  // Testing the trained agent.
  agent.Deterministic() = true;

  // Creating and setting up the gym environment for testing.
  gym::Environment envTest("localhost", "4040",
      "MountainCarContinuous-v0-render");

  // Resets the environment.
  envTest.reset();

  double totalReward = 0;
  size_t totalSteps = 0;

  while (true)
  {
    // State from the environment is passed to the agent's internal
    // representation.
    agent.State().Data() = envTest.observation;

    // With the given state, the agent selects an action according to its
    // learned policy.
    agent.SelectAction();

    // Action to take, decided by the policy (rescaled as in training).
    arma::mat action = {double(agent.Action().action[0] * 2)};

    envTest.step(action);
    totalReward += envTest.reward;
    totalSteps += 1;

    if (envTest.done)
    {
      std::cout << " Total steps: " << totalSteps
                << "\t Total reward: " << totalReward << std::endl;
      break;
    }
  }

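  // Close the environment and print the URL at which the recorded run can
  // be viewed (provided by the gym_tcp_api server).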
  envTest.close();
  std::cout << envTest.url() << std::endl;
}
41 changes: 41 additions & 0 deletions reinforcement_learning_gym/pendulum_td3/Makefile
@@ -0,0 +1,41 @@
# This is a simple Makefile used to build the example source code.
# This example might require some modifications in order to work correctly on
# your system.
# If you're not using the Armadillo wrapper, replace `armadillo` with linker commands
# for the BLAS and LAPACK libraries that you are using.

TARGET := pendulum_td3
SRC := pendulum_td3.cpp
LIBS_NAME := armadillo boost_iostreams

CXX := g++
CXXFLAGS += -std=c++14 -Wall -Wextra -O3 -DNDEBUG -fopenmp
# Use these CXXFLAGS instead if you want to compile with debugging symbols and
# without optimizations.
# CXXFLAGS += -std=c++14 -Wall -Wextra -g -O0
LDFLAGS += -fopenmp
# Add header directories for any includes that aren't on the
# default compiler search path.
INCLFLAGS := -I .
# If you have mlpack or ensmallen installed somewhere nonstandard, uncomment and
# update the lines below.
# INCLFLAGS += -I/path/to/ensmallen/include/
CXXFLAGS += $(INCLFLAGS)

OBJS := $(SRC:.cpp=.o)
LIBS := $(addprefix -l,$(LIBS_NAME))
CLEAN_LIST := $(TARGET) $(OBJS)

# default rule
default: all

$(TARGET): $(OBJS)
$(CXX) $(CXXFLAGS) $(OBJS) -o $(TARGET) $(LDFLAGS) $(LIBS)

.PHONY: all
all: $(TARGET)

.PHONY: clean
clean:
@echo CLEAN $(CLEAN_LIST)
@rm -f $(CLEAN_LIST)
182 changes: 182 additions & 0 deletions reinforcement_learning_gym/pendulum_td3/pendulum_td3.cpp
Large diffs are not rendered by default.
