[![mlpack-lab Image](https://img.shields.io/endpoint?url=https%3A%2F%2Flab.kurg.org%2Fstatus%2Fstatus.json)](https://lab.mlpack.org)

You can easily run this notebook at https://lab.mlpack.org/

This notebook shows how to get started with training reinforcement learning agents, particularly DQN agents, using mlpack. Here, we train a [Simple DQN](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) agent to achieve high scores in the [CartPole](https://gym.openai.com/envs/CartPole-v0) environment.

mlpack contains non-GUI implementations of some of OpenAI Gym's environments. In this notebook, we train the agent on one of these environments because training on it is fast.

For testing, we run the agent on the OpenAI Gym toolkit's GUI, which is provided through a distributed infrastructure (a TCP API). More details can be found [here](https://github.com/zoq/gym_tcp_api).

A video of the trained agent can be seen at the end.

## Including necessary libraries and namespaces

In [1]:
#include <mlpack/core.hpp>

In [2]:
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/reinforcement_learning/q_learning.hpp>
#include <mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>

In [3]:
// Used to run the agent on gym's environment (provided externally) for testing.
#include <gym/environment.hpp>

In [4]:
// Used to generate and display a video of the trained agent.
#include "xwidgets/ximage.hpp"
#include "xwidgets/xvideo.hpp"
#include "xwidgets/xaudio.hpp"

In [5]:
using namespace mlpack;

In [6]:
using namespace mlpack::ann;

In [7]:
using namespace ens;

In [8]:
using namespace mlpack::rl;

## Initializing the agent

In [9]:
// Set up the network.
// NOTE(review): the arguments are presumably (input dimension, first hidden
// layer size, second hidden layer size, output dimension), i.e. one input per
// CartPole observation value and one output per action — confirm against
// SimpleDQN's constructor.
SimpleDQN<> model(4, 128, 32, 2);

In [10]:
// Set up the policy and replay method.
// NOTE(review): GreedyPolicy's arguments are presumably (initial epsilon,
// anneal interval, minimum epsilon, decay rate) and RandomReplay's are
// presumably (batch size, buffer capacity) — confirm against the mlpack headers.
GreedyPolicy<CartPole> policy(1.0, 1000, 0.1, 0.99);
RandomReplay<CartPole> replayMethod(32, 1000);

In [11]:
// Set up training configurations.
TrainingConfig config;
config.StepSize() = 0.001;
config.Discount() = 0.99;
config.TargetNetworkSyncInterval() = 100;
config.ExplorationSteps() = 100;
config.DoubleQLearning() = false;
config.StepLimit() = 200;

In [12]:
// Set up DQN agent.
// AdamUpdate is supplied as the updater template argument. The config,
// network, policy and replay buffer are handed over with std::move, so treat
// `config`, `model`, `policy` and `replayMethod` as moved-from after this cell
// and access them only through `agent`.
QLearning<CartPole, decltype(model), AdamUpdate, decltype(policy)>
  agent(std::move(config), std::move(model), std::move(policy),
  std::move(replayMethod));

## Preparation for training the agent

In [13]:
// Training state shared by every call to train(). Nothing in this notebook
// resets these, so the episode count and the running average of episode
// returns accumulate across successive train() calls.
size_t episodes{0};
arma::running_stat<double> averageReturn;

In [14]:
// Function to train the agent on mlpack's own implementation of the CartPole environment.
void train(const size_t threshold)
{
  agent.Deterministic() = false;
  while (true)
  {
    double episodeReturn = agent.Episode();
    averageReturn(episodeReturn);
    episodes += 1;

    if(episodes % 10 == 0)
    {
      std::cout << "Average return: " << averageReturn.mean()
          << "\t Episode return: " << episodeReturn 
          << "\t Episode number: " << episodes << std::endl;
    }

    if (episodes > 1000)
    {
      std::cout << "Cart Pole with DQN failed." << std::endl;
      break;
    }

    if (averageReturn.mean() > threshold)
    {
      agent.Deterministic() = true;
      arma::running_stat<double> testReturn;
      for (size_t i = 0; i < 10; ++i)
        testReturn(agent.Episode());

      std::cout << "Average return in deterministic test: "
          << testReturn.mean() << std::endl;
      break;
    }
  }
}

## Let the training begin

In [15]:
// Train the agent until the running average return exceeds 25. train() gives
// up once more than 1000 episodes have been run in total.
train(25)

Average return: 19.8	 Episode return: 20	 Episode number: 10
Average return: 19.75	 Episode return: 11	 Episode number: 20
Average return: 19.4667	 Episode return: 10	 Episode number: 30
Average return: 17.375	 Episode return: 10	 Episode number: 40
Average return: 17.18	 Episode return: 7	 Episode number: 50
Average return: 16.2833	 Episode return: 11	 Episode number: 60
Average return: 15.5571	 Episode return: 20	 Episode number: 70
Average return: 14.775	 Episode return: 8	 Episode number: 80
Average return: 14.1	 Episode return: 10	 Episode number: 90
Average return: 13.6	 Episode return: 10	 Episode number: 100
Average return: 13.1636	 Episode return: 10	 Episode number: 110
Average return: 13.125	 Episode return: 15	 Episode number: 120
Average return: 15.5077	 Episode return: 35	 Episode number: 130
Average return: 20.45	 Episode return: 62	 Episode number: 140
Average return: 25.3333	 Episode return: 126	 Episode number: 150
Average return in deterministic test: 144.3


## Testing the trained agent

In [16]:
// Creating and setting up the gym environment for testing.
gym::Environment env("gym.kurg.org", "4040", "CartPole-v0");
env.monitor.start("./dummy/", true, true);

// Resets the environment.
env.reset();
env.render();

double totalReward = 0;
size_t totalSteps = 0;

// Testing the agent on gym's environment.
while (1)
{
  // State from the environment is passed to the agent's internal representation.
  agent.State().Data() = env.observation;

  // With the given state, the agent performs an action according to its defined policy.
  agent.Step();

  // Action to take, decided by the policy.
  arma::mat action = {double(agent.Action())};

  env.step(action);
  totalReward += env.reward;
  totalSteps += 1;

  if (env.done)
  {
    std::cout << " Total steps: " << totalSteps << "\t Total reward: "
        << totalReward << std::endl;
    break;
  }

  // Uncomment the following lines to see the reward and action in each step.
  // std::cout << " Current step: " << totalSteps << "\t current reward: "
  //   << totalReward << "\t Action taken: " << action;
}

env.close();
std::string url = env.url();

auto video = xw::video_from_url(url).finalize();
video

 Total steps: 140	 Total reward: 140


A Jupyter widget

## A little more training...

In [17]:
// Training the same agent till the running average return exceeds 160.
// The average is taken over every episode since the first train() call,
// including the early low-return ones, so reaching 160 needs many episodes.
// train() gives up once more than 1000 episodes have been run in total.
train(160)

Average return: 30.8	 Episode return: 112	 Episode number: 160
Average return: 36.9706	 Episode return: 200	 Episode number: 170
Average return: 44.2167	 Episode return: 134	 Episode number: 180
Average return: 51.6421	 Episode return: 200	 Episode number: 190
Average return: 58.92	 Episode return: 177	 Episode number: 200
Average return: 64.7048	 Episode return: 147	 Episode number: 210
Average return: 65.6045	 Episode return: 30	 Episode number: 220
Average return: 67.2565	 Episode return: 200	 Episode number: 230
Average return: 72.7875	 Episode return: 200	 Episode number: 240
Average return: 77.876	 Episode return: 200	 Episode number: 250
Average return: 82.5731	 Episode return: 200	 Episode number: 260
Average return: 86.9222	 Episode return: 200	 Episode number: 270
Average return: 90.5679	 Episode return: 200	 Episode number: 280
Average return: 94.3414	 Episode return: 200	 Episode number: 290
Average return: 97.8633	 Episode return: 200	 Episode number: 300
Average return: 1

## Final agent testing!

In [18]:
// Creating and setting up the gym environment for testing.
gym::Environment env("gym.kurg.org", "4040", "CartPole-v0");
env.monitor.start("./dummy/", true, true);

// Resets the environment.
env.reset();
env.render();

double totalReward = 0;
size_t totalSteps = 0;

// Testing the agent on gym's environment.
while (1)
{
  // State from the environment is passed to the agent's internal representation.
  agent.State().Data() = env.observation;

  // With the given state, the agent performs an action according to its defined policy.
  agent.Step();

  // Action to take, decided by the policy.
  arma::mat action = {double(agent.Action())};

  env.step(action);
  totalReward += env.reward;
  totalSteps += 1;

  if (env.done)
  {
    std::cout << " Total steps: " << totalSteps << "\t Total reward: "
        << totalReward << std::endl;
    break;
  }

  // Uncomment the following lines to see the reward and action in each step.
  // std::cout << " Current step: " << totalSteps << "\t current reward: "
  //   << totalReward << "\t Action taken: " << action;
}

env.close();
std::string url = env.url();

auto video = xw::video_from_url(url).finalize();
video

 Total steps: 200	 Total reward: 200


A Jupyter widget