[![Binder](https://mybinder.org/badge_logo.svg)](https://lab.mlpack.org/v2/gh/mlpack/examples/master?urlpath=lab%2Ftree%2Fq_learning%2Facrobot_dqn.ipynb)

You can easily run this notebook at https://lab.mlpack.org/

This notebook is shows how to get use 3-Step Double DQN with Prioritized Replay to train an agent to get high scores for the [Acrobot](https://gym.openai.com/envs/Acrobot-v1) environment. 

We make the agent train and test on OpenAI Gym toolkit's GUI interface provided through a distributed infrastructure (TCP API). More details can be found [here](https://github.com/zoq/gym_tcp_api).

A video of the trained agent can be seen in the end.

## Including necessary libraries and namespaces

In [1]:
#include <mlpack/core.hpp>

In [2]:
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/reinforcement_learning/q_learning.hpp>
#include <mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp>
#include <mlpack/methods/reinforcement_learning/environment/env_type.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>

In [3]:
// Used to run the agent on gym's environment (provided externally) for testing.
#include <gym/environment.hpp>

In [4]:
// Used to generate and display a video of the trained agent.
#include "xwidgets/ximage.hpp"
#include "xwidgets/xvideo.hpp"
#include "xwidgets/xaudio.hpp"

In [5]:
using namespace mlpack;

In [6]:
using namespace mlpack::ann;

In [7]:
using namespace ens;

In [8]:
using namespace mlpack::rl;

## Initializing the agent

In [9]:
// Set up the state and action space.
DiscreteActionEnv::State::dimension = 6;
DiscreteActionEnv::Action::size = 3;

In [10]:
// Set up the network.
FFN<MeanSquaredError<>, RandomInitialization> module(MeanSquaredError<>(), RandomInitialization(-1, 1));
module.Add<Linear<>>(DiscreteActionEnv::State::dimension, 64);
module.Add<ReLULayer<>>();
module.Add<Linear<>>(64, DiscreteActionEnv::Action::size);
SimpleDQN<FFN<MeanSquaredError<>, RandomInitialization>> model(module);

In [11]:
// Set up the policy method.
GreedyPolicy<DiscreteActionEnv> policy(1.0, 1000, 0.1, 0.99);
// To enable 3-step learning, we set the last parameter of the replay method as 3.
PrioritizedReplay<DiscreteActionEnv> replayMethod(64, 5000, 0.6, 3);

In [12]:
// Set up training configurations.
TrainingConfig config;
config.TargetNetworkSyncInterval() = 100;
config.ExplorationSteps() = 500;

// We use double Q learning for this example.
config.DoubleQLearning() = true;

In [13]:
// Set up DQN agent.
QLearning<DiscreteActionEnv, decltype(model), AdamUpdate, decltype(policy), decltype(replayMethod)>
    agent(config, model, policy, replayMethod);

## Preparation for training the agent

In [14]:
// Set up the gym training environment.
gym::Environment env("gym.kurg.org", "4040", "Acrobot-v1");

// Initializing training variables.
std::vector<double> returnList;
size_t episodes = 0;
bool converged = true;

// The number of episode returns to keep track of.
size_t consecutiveEpisodes = 50;

In [15]:
// Function to train the agent on the Acrobot-v1 gym environment.
void train(const size_t numSteps)
{
  agent.Deterministic() = false;
  std::cout << "Training for " << numSteps << " steps." << std::endl;
  while (agent.TotalSteps() < numSteps)
  {
    double episodeReturn = 0;
    env.reset();
    do
    {
      agent.State().Data() = env.observation;
      agent.SelectAction();
      arma::mat action = {double(agent.Action().action)};

      env.step(action);
      DiscreteActionEnv::State nextState;
      nextState.Data() = env.observation;

      replayMethod.Store(agent.State(), agent.Action(), env.reward, nextState,
          env.done, 0.99);
      episodeReturn += env.reward;
      agent.TotalSteps()++;
      if (agent.Deterministic() || agent.TotalSteps() < config.ExplorationSteps())
        continue;
      agent.TrainAgent();
    } while (!env.done);
    returnList.push_back(episodeReturn);
    episodes += 1;

    if (returnList.size() > consecutiveEpisodes)
      returnList.erase(returnList.begin());
        
    double averageReturn = std::accumulate(returnList.begin(),
                                           returnList.end(), 0.0) /
                           returnList.size();
    if(episodes % 5 == 0)
    {
      std::cout << "Avg return in last " << consecutiveEpisodes
          << " episodes: " << averageReturn
          << "\t Episode return: " << episodeReturn
          << "\t Total steps: " << agent.TotalSteps() << std::endl;
    }
  }
}

## Let the training begin

In [16]:
// Training the agent for a total of at least 25000 steps.
train(25000)

Training for 25000 steps.
Avg return in last 50 episodes: -476.4	 Episode return: -465	 Total steps: 2384
Avg return in last 50 episodes: -352.7	 Episode return: -141	 Total steps: 3534
Avg return in last 50 episodes: -289.267	 Episode return: -209	 Total steps: 4351
Avg return in last 50 episodes: -254.25	 Episode return: -106	 Total steps: 5102
Avg return in last 50 episodes: -225.68	 Episode return: -123	 Total steps: 5664
Avg return in last 50 episodes: -206.167	 Episode return: -87	 Total steps: 6212
Avg return in last 50 episodes: -192.743	 Episode return: -108	 Total steps: 6778
Avg return in last 50 episodes: -185.325	 Episode return: -146	 Total steps: 7450
Avg return in last 50 episodes: -194.822	 Episode return: -191	 Total steps: 8808
Avg return in last 50 episodes: -193.8	 Episode return: -164	 Total steps: 9736
Avg return in last 50 episodes: -158.52	 Episode return: -126	 Total steps: 10359
Avg return in last 50 episodes: -152.08	 Episode return: -165	 Total steps: 11187

## Testing the trained agent

In [17]:
agent.Deterministic() = true;

// Creating and setting up the gym environment for testing.
gym::Environment envTest("gym.kurg.org", "4040", "Acrobot-v1");
envTest.monitor.start("./dummy/", true, true);

// Resets the environment.
envTest.reset();
envTest.render();

double totalReward = 0;
size_t totalSteps = 0;

// Testing the agent on gym's environment.
while (1)
{
  // State from the environment is passed to the agent's internal representation.
  agent.State().Data() = envTest.observation;

  // With the given state, the agent selects an action according to its defined policy.
  agent.SelectAction();

  // Action to take, decided by the policy.
  arma::mat action = {double(agent.Action().action)};

  envTest.step(action);
  totalReward += envTest.reward;
  totalSteps += 1;

  if (envTest.done)
  {
    std::cout << " Total steps: " << totalSteps << "\t Total reward: "
        << totalReward << std::endl;
    break;
  }

  // Uncomment the following lines to see the reward and action in each step.
  // std::cout << " Current step: " << totalSteps << "\t current reward: "
  //   << totalReward << "\t Action taken: " << action;
}

envTest.close();
std::string url = envTest.url();
std::cout << url << std::endl;
auto video = xw::video_from_url(url).finalize();
video

 Total steps: 122	 Total reward: -121
https://gym.kurg.org/b52506a7015d4/output.webm


A Jupyter widget

Due to the stochasticity of the environment, it's quite possible that sometimes the agent is not able to solve it in each test. So, we test the agent once more, just to be sure.

You may test the agent any number of times by rerunning either of the testing cells. 

In [18]:
agent.Deterministic() = true;

// Creating and setting up the gym environment for testing.
gym::Environment envTest("gym.kurg.org", "4040", "Acrobot-v1");
envTest.monitor.start("./dummy/", true, true);

// Resets the environment.
envTest.reset();
envTest.render();

double totalReward = 0;
size_t totalSteps = 0;

// Testing the agent on gym's environment.
while (1)
{
  // State from the environment is passed to the agent's internal representation.
  agent.State().Data() = envTest.observation;

  // With the given state, the agent selects an action according to its defined policy.
  agent.SelectAction();

  // Action to take, decided by the policy.
  arma::mat action = {double(agent.Action().action)};

  envTest.step(action);
  totalReward += envTest.reward;
  totalSteps += 1;

  if (envTest.done)
  {
    std::cout << " Total steps: " << totalSteps << "\t Total reward: "
        << totalReward << std::endl;
    break;
  }

  // Uncomment the following lines to see the reward and action in each step.
  // std::cout << " Current step: " << totalSteps << "\t current reward: "
  //   << totalReward << "\t Action taken: " << action;
}

envTest.close();
std::string url = envTest.url();
std::cout << url << std::endl;
auto video = xw::video_from_url(url).finalize();
video

 Total steps: 151	 Total reward: -150
https://gym.kurg.org/be13778abc714/output.webm


A Jupyter widget