[![Binder](https://mybinder.org/badge_logo.svg)](https://lab.mlpack.org/v2/gh/mlpack/examples/master?urlpath=lab%2Ftree%2Fq_learning%2Fbipedal_walker_sac.ipynb)

You can easily run this notebook at https://lab.mlpack.org/

Here, we train a [Soft Actor-Critic](https://arxiv.org/abs/1801.01290) agent to achieve high scores in the [Bipedal Walker](https://gym.openai.com/envs/BipedalWalker-v2/) environment (the code below uses the `BipedalWalker-v3` version).

The agent is trained and tested on the OpenAI Gym toolkit, whose environments (including their GUI rendering) are served through a distributed infrastructure (a TCP API). More details can be found [here](https://github.com/zoq/gym_tcp_api).

A video of the trained agent can be seen at the end.

## Including necessary libraries and namespaces

In [1]:
#include <mlpack/core.hpp>

In [2]:
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/reinforcement_learning/sac.hpp>
#include <mlpack/methods/ann/loss_functions/empty_loss.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/reinforcement_learning/environment/env_type.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>

In [3]:
// Used to run the agent on gym's environment (provided externally) for testing.
#include <gym/environment.hpp>

In [4]:
// Used to generate and display a video of the trained agent.
#include "xwidgets/ximage.hpp"
#include "xwidgets/xvideo.hpp"
#include "xwidgets/xaudio.hpp"

In [5]:
using namespace mlpack;

In [6]:
using namespace mlpack::ann;

In [7]:
using namespace ens;

In [8]:
using namespace mlpack::rl;

## Initializing the agent

In [9]:
// Tell mlpack's generic continuous environment what the BipedalWalker spaces
// look like: a 4-component continuous action and a 24-component observation.
ContinuousActionEnv::Action::size = 4;
ContinuousActionEnv::State::dimension = 24;

In [10]:
// Actor (policy) and critic (Q) networks. Both use an EmptyLoss output layer
// and a GaussianInitialization(0, 0.1) weight initializer.
// NOTE(review): no layers are added here; the architecture is presumably
// restored from the serialized XML files loaded later — confirm.
FFN<EmptyLoss<>, GaussianInitialization> policyNetwork{
    EmptyLoss<>(), GaussianInitialization(0, 0.1)};
FFN<EmptyLoss<>, GaussianInitialization> qNetwork{
    EmptyLoss<>(), GaussianInitialization(0, 0.1)};

In [11]:
// Uniform experience replay constructed with (32, 10000) — presumably
// (batchSize, capacity) per mlpack's RandomReplay constructor; confirm.
RandomReplay<ContinuousActionEnv> replayMethod{32, 10000};

In [12]:
// Hyper-parameters shared by the SAC agent.
TrainingConfig config;
config.UpdateInterval() = 3;
config.TargetNetworkSyncInterval() = 1;
config.ExplorationSteps() = 3200;

In [13]:
!ls

50policyNetwork.xml
50qNetwork.xml
appveyor.yml
bipedal_walker_sac.ipynb
breast_cancer_wisconsin_transformation_with_pca
cifar10_transformation_with_pca
forest_covertype_prediction_with_random_forests
go
LICENSE.txt
lstm_electricity_consumption
lstm_stock_prediction
Manifest.toml
mnist_batch_norm
mnist_cnn
mnist_simple
mnist_vae_cnn
movie_lens_prediction_with_cf
neural_network_regression
Project.toml
README.md
reinforcement_learning_gym
tools


In [14]:
// Restore the pretrained critic (Q) network stored under the archive name
// "episode". fatal = true makes a failed load throw instead of only logging a
// warning, so the notebook never evaluates an untrained, layerless network.
data::Load("./50qNetwork.xml", "episode", qNetwork, true);

In file included from input_line_7:1:
In file included from /srv/conda/envs/notebook/include/mlpack/core.hpp:83:
In file included from /srv/conda/envs/notebook/include/mlpack/prereqs.hpp:88:
In file included from /srv/conda/envs/notebook/include/mlpack/core/data/has_serialize.hpp:18:
In file included from /srv/conda/envs/notebook/include/boost/archive/xml_oarchive.hpp:31:
In file included from /srv/conda/envs/notebook/include/boost/archive/basic_xml_oarchive.hpp:22:
In file included from /srv/conda/envs/notebook/include/boost/archive/detail/common_oarchive.hpp:22:
In file included from /srv/conda/envs/notebook/include/boost/archive/detail/interface_oarchive.hpp:23:
In file included from /srv/conda/envs/notebook/include/boost/archive/detail/oserializer.hpp:40:
In file included from /srv/conda/envs/notebook/include/boost/serialization/extended_type_info_typeid.hpp:32:
        use(* m_instance);
              ^~~~~~~~~~


Interpreter Exception: 

In [None]:
// Restore the pretrained actor (policy) network stored under the archive name
// "episode". fatal = true makes a failed load throw instead of only logging a
// warning, so the notebook never evaluates an untrained, layerless network.
data::Load("50policyNetwork.xml", "episode", policyNetwork, true);

## Testing the trained agent

In [None]:
std::cout << "Loading complete!" << std::endl;
// Set up Soft actor-critic agent.
SAC<ContinuousActionEnv, decltype(qNetwork), decltype(policyNetwork), AdamUpdate>
    agent(config, qNetwork, policyNetwork, replayMethod);

agent.Deterministic() = true;

// Creating and setting up the gym environment for testing.
gym::Environment envTest("gym.kurg.org", "4040", "BipedalWalker-v3");
envTest.monitor.start("./dummy/", true, true);

// Resets the environment.
envTest.reset();
envTest.render();

double totalReward = 0;
size_t totalSteps = 0;

// Testing the agent on gym's environment.
while (1)
{
  // State from the environment is passed to the agent's internal representation.
  agent.State().Data() = envTest.observation;

  // With the given state, the agent selects an action according to its defined policy.
  agent.SelectAction();

  // Action to take, decided by the policy.
  arma::mat action = {agent.Action().action};

  envTest.step(action);
  totalReward += envTest.reward;
  totalSteps += 1;

  if (envTest.done)
  {
    std::cout << " Total steps: " << totalSteps << "\t Total reward: "
        << totalReward << std::endl;
    break;
  }

  // Uncomment the following lines to see the reward and action in each step.
  // std::cout << " Current step: " << totalSteps << "\t current reward: "
  //   << totalReward << "\t Action taken: " << action;
}

envTest.close();
std::string url = envTest.url();
std::cout << url;
auto video = xw::video_from_url(url).finalize();
video

## A little more training...

In [None]:
// Train the same agent for a total of at least 100000 steps.
// NOTE(review): `train` is not defined in any of the cells shown here; a cell
// defining the training loop must be run before this one.
train(100000);

# Final agent testing!

In [None]:
agent.Deterministic() = true;

// Creating and setting up the gym environment for testing.
gym::Environment envTest("gym.kurg.org", "4040", "BipedalWalker-v3");
envTest.monitor.start("./dummy/", true, true);

// Resets the environment.
envTest.reset();
envTest.render();

double totalReward = 0;
size_t totalSteps = 0;

// Testing the agent on gym's environment.
while (1)
{
  // State from the environment is passed to the agent's internal representation.
  agent.State().Data() = envTest.observation;

  // With the given state, the agent selects an action according to its defined policy.
  agent.SelectAction();

  // Action to take, decided by the policy.
  arma::mat action = {double(agent.Action().action[0] * 2)};

  envTest.step(action);
  totalReward += envTest.reward;
  totalSteps += 1;

  if (envTest.done)
  {
    std::cout << " Total steps: " << totalSteps << "\t Total reward: "
        << totalReward << std::endl;
    break;
  }

  // Uncomment the following lines to see the reward and action in each step.
  // std::cout << " Current step: " << totalSteps << "\t current reward: "
  //   << totalReward << "\t Action taken: " << action;
}

envTest.close();
std::string url = envTest.url();
std::cout << url;
auto video = xw::video_from_url(url).finalize();
video