[![Binder](https://mybinder.org/badge_logo.svg)](https://lab.mlpack.org/v2/gh/mlpack/examples/master?urlpath=lab%2Ftree%2Fq_learning%2Fcartpole_dqn.ipynb)

You can easily run this notebook at https://lab.mlpack.org/

This notebook shows how to get started with training reinforcement learning agents, particularly DQN agents, using mlpack. Here, we train a [Simple DQN](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) agent to achieve high scores on the [CartPole](https://gym.openai.com/envs/CartPole-v0) environment.

The agent is trained and tested on OpenAI Gym environments that run on a remote server and are accessed through a TCP API, which also lets us render the environment and record videos. More details can be found [here](https://github.com/zoq/gym_tcp_api).

A video of the trained agent can be seen at the end.

## Including necessary libraries and namespaces

In [1]:
#include <mlpack/core.hpp>

In [2]:
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/reinforcement_learning/q_learning.hpp>
#include <mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp>
#include <mlpack/methods/reinforcement_learning/environment/env_type.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>

In [3]:
// Used to run the agent on gym's environment (provided externally) for testing.
#include <gym/environment.hpp>

In [4]:
// Used to generate and display a video of the trained agent.
#include "xwidgets/ximage.hpp"
#include "xwidgets/xvideo.hpp"
#include "xwidgets/xaudio.hpp"

In [5]:
using namespace mlpack;

In [6]:
using namespace mlpack::ann;

In [7]:
using namespace ens;

In [8]:
using namespace mlpack::rl;

## Initializing the agent

In [9]:
// Describe the CartPole interface to mlpack's generic discrete-action
// environment: 2 possible actions, chosen from a 4-dimensional observation.
// These are static (global) settings of DiscreteActionEnv.
DiscreteActionEnv::Action::size = 2;
DiscreteActionEnv::State::dimension = 4;

In [10]:
// Set up the network.
// The input width is the state dimension and the output width is the number
// of actions (one Q-value per action); 128 and 32 are the hidden layer sizes.
// Reading the sizes from DiscreteActionEnv (instead of repeating 4 and 2)
// keeps the network consistent with the configured state/action space.
SimpleDQN<> model(DiscreteActionEnv::State::dimension, 128, 32,
                  DiscreteActionEnv::Action::size);

In [11]:
// Exploration policy and experience replay buffer.
// Argument labels follow mlpack's GreedyPolicy / RandomReplay constructors
// (verify against the installed mlpack version).
GreedyPolicy<DiscreteActionEnv> policy(/* initialEpsilon */ 1.0,
                                       /* annealInterval */ 1000,
                                       /* minEpsilon */ 0.1,
                                       /* decayRate */ 0.99);
RandomReplay<DiscreteActionEnv> replayMethod(/* batchSize */ 32,
                                             /* capacity */ 2000);

In [12]:
// Hyper-parameters for the Q-learning agent.
TrainingConfig config;

// Plain DQN (no Double Q-learning); target network refreshed every 100 steps.
config.DoubleQLearning() = false;
config.TargetNetworkSyncInterval() = 100;

// Number of initial exploration steps before learning starts.
config.ExplorationSteps() = 100;

// Optimizer step size and reward discount factor.
config.StepSize() = 0.001;
config.Discount() = 0.99;

// Maximum number of steps per episode.
config.StepLimit() = 200;

In [13]:
// Set up DQN agent.
QLearning<DiscreteActionEnv, decltype(model), AdamUpdate, decltype(policy)>
    agent(config, model, policy, replayMethod);

## Preparation for training the agent

In [14]:
// Set up the gym training environment.
// Both environments are gym_tcp_api clients for "CartPole-v0" on the same
// server (the first two arguments look like host and port); training and
// testing each get their own connection.
gym::Environment env("gym.kurg.org", "4040", "CartPole-v0");

// Set up the gym testing environment.
gym::Environment envTest("gym.kurg.org", "4040", "CartPole-v0");
// Start test env monitor.
// NOTE(review): compression(9) presumably sets the compression level used on
// the TCP connection — confirm against gym_tcp_api.
// The monitor writes to "./dummy/"; the two boolean flags are assumed to be
// force/resume — TODO confirm in gym_tcp_api.
envTest.compression(9);
envTest.monitor.start("./dummy/", true, true);

In [15]:
// Initializing training variables (shared across calls to train()).
// Returns of the most recent episodes; at most consecutiveEpisodes entries.
std::vector<double> returnList;
// Number of episodes completed so far.
size_t episodes = 0;
// The number of episode returns to keep track of.
const size_t consecutiveEpisodes = 50;

In [16]:
// Trains the agent on the remote gym CartPole environment `env` (accessed
// through the gym TCP API) until the agent's cumulative step counter reaches
// at least `numSteps`. An episode in progress is always finished, so the
// final count may exceed `numSteps`. The counter is never reset, so repeated
// calls continue from where the previous call stopped.
//
// After every episode, prints the average return over the stored window of
// recent episodes (at most `consecutiveEpisodes`), the episode return, and
// the total number of steps taken.
void train(const size_t numSteps)
{
  agent.Deterministic() = false;
  std::cout << "Training for " << numSteps << " steps." << std::endl;
  while (agent.TotalSteps() < numSteps)
  {
    double episodeReturn = 0;
    env.reset();
    do
    {
      // Give the current gym observation to the agent and let it pick an action.
      agent.State().Data() = env.observation;
      agent.SelectAction();
      arma::mat action = {double(agent.Action().action)};

      env.step(action);
      DiscreteActionEnv::State nextState;
      nextState.Data() = env.observation;

      // Store the transition using the configured discount factor, so the
      // replay buffer cannot drift out of sync with config.Discount().
      replayMethod.Store(agent.State(), agent.Action(), env.reward, nextState,
          env.done, config.Discount());
      episodeReturn += env.reward;
      agent.TotalSteps()++;

      // Learning starts only once the initial exploration phase is over.
      // (`continue` in a do-while jumps straight to the `!env.done` check.)
      if (agent.Deterministic() || agent.TotalSteps() < config.ExplorationSteps())
        continue;
      agent.TrainAgent();
    } while (!env.done);

    returnList.push_back(episodeReturn);
    episodes += 1;

    // Keep only the most recent `consecutiveEpisodes` returns.
    if (returnList.size() > consecutiveEpisodes)
      returnList.erase(returnList.begin());

    // returnList is non-empty here: a return was just pushed.
    const double averageReturn = std::accumulate(returnList.begin(),
        returnList.end(), 0.0) / returnList.size();

    // Report the size of the window actually averaged; early in training it
    // is smaller than `consecutiveEpisodes`.
    std::cout << "Avg return in last " << returnList.size()
        << " episodes: " << averageReturn
        << "\t Episode return: " << episodeReturn
        << "\t Total steps: " << agent.TotalSteps() << std::endl;
  }
}

## Let the training begin

In [17]:
// Training the agent for a total of at least 2500 steps.
// The step count is cumulative and the last episode is always finished, so
// training may stop slightly past 2500 steps.
train(2500)

Training for 2500 steps.
Avg return in last 50 episodes: 16	 Episode return: 16	 Total steps: 16
Avg return in last 50 episodes: 15	 Episode return: 14	 Total steps: 30
Avg return in last 50 episodes: 16	 Episode return: 18	 Total steps: 48
Avg return in last 50 episodes: 15	 Episode return: 12	 Total steps: 60
Avg return in last 50 episodes: 15.2	 Episode return: 16	 Total steps: 76
Avg return in last 50 episodes: 15.5	 Episode return: 17	 Total steps: 93
Avg return in last 50 episodes: 17.7143	 Episode return: 31	 Total steps: 124
Avg return in last 50 episodes: 17.625	 Episode return: 17	 Total steps: 141
Avg return in last 50 episodes: 17	 Episode return: 12	 Total steps: 153
Avg return in last 50 episodes: 16.5	 Episode return: 12	 Total steps: 165
Avg return in last 50 episodes: 15.9091	 Episode return: 10	 Total steps: 175
Avg return in last 50 episodes: 15.8333	 Episode return: 15	 Total steps: 190
Avg return in last 50 episodes: 16.3846	 Episode return: 23	 Total steps: 213
Av

Avg return in last 50 episodes: 11.1	 Episode return: 28	 Total steps: 1586
Avg return in last 50 episodes: 11.12	 Episode return: 9	 Total steps: 1595
Avg return in last 50 episodes: 10.9	 Episode return: 10	 Total steps: 1605
Avg return in last 50 episodes: 10.84	 Episode return: 14	 Total steps: 1619
Avg return in last 50 episodes: 10.78	 Episode return: 8	 Total steps: 1627
Avg return in last 50 episodes: 10.78	 Episode return: 9	 Total steps: 1636
Avg return in last 50 episodes: 10.72	 Episode return: 9	 Total steps: 1645
Avg return in last 50 episodes: 11.46	 Episode return: 48	 Total steps: 1693
Avg return in last 50 episodes: 11.46	 Episode return: 10	 Total steps: 1703
Avg return in last 50 episodes: 11.46	 Episode return: 13	 Total steps: 1716
Avg return in last 50 episodes: 12.58	 Episode return: 65	 Total steps: 1781
Avg return in last 50 episodes: 13.38	 Episode return: 50	 Total steps: 1831
Avg return in last 50 episodes: 15.24	 Episode return: 102	 Total steps: 1933
Avg 

## Testing the trained agent

In [18]:
// Evaluate the trained agent on the test environment (with deterministic
// action selection) and display a video of the run.
agent.Deterministic() = true;

// Resets the environment.
envTest.reset();
envTest.render();

double totalReward = 0;
size_t totalSteps = 0;

// Testing the agent on gym's environment.
while (true)
{
  // State from the environment is passed to the agent's internal representation.
  agent.State().Data() = envTest.observation;

  // With the given state, the agent selects an action according to its defined policy.
  agent.SelectAction();

  // Action to take, decided by the policy.
  arma::mat action = {double(agent.Action().action)};

  envTest.step(action);
  // Accumulate the reward reported by the *test* environment, not the
  // training environment `env`.
  totalReward += envTest.reward;
  totalSteps += 1;

  if (envTest.done)
  {
    std::cout << " Total steps: " << totalSteps << "\t Total reward: "
        << totalReward << std::endl;
    break;
  }

  // Uncomment the following lines to see the reward and action in each step.
  // std::cout << " Current step: " << totalSteps << "\t current reward: "
  //   << totalReward << "\t Action taken: " << action;
}

// Close the connection, then build a video widget from the environment's URL
// (presumably where the monitor recording is served — confirm in gym_tcp_api).
envTest.close();
std::string url = envTest.url();

auto video = xw::video_from_url(url).finalize();
video

 Total steps: 200	 Total reward: 200


A Jupyter widget

## A little more training...

In [19]:
// Continue training the same agent until it has taken at least 5000
// environment steps in total; steps from the earlier training run count
// toward this limit.
train(5000)

Training for 5000 steps.
Avg return in last 50 episodes: 28.42	 Episode return: 134	 Total steps: 2671
Avg return in last 50 episodes: 29.16	 Episode return: 46	 Total steps: 2717
Avg return in last 50 episodes: 32.06	 Episode return: 153	 Total steps: 2870
Avg return in last 50 episodes: 33.1	 Episode return: 61	 Total steps: 2931
Avg return in last 50 episodes: 35.02	 Episode return: 104	 Total steps: 3035
Avg return in last 50 episodes: 36.9	 Episode return: 103	 Total steps: 3138
Avg return in last 50 episodes: 39.32	 Episode return: 130	 Total steps: 3268
Avg return in last 50 episodes: 41.48	 Episode return: 116	 Total steps: 3384
Avg return in last 50 episodes: 45.28	 Episode return: 200	 Total steps: 3584
Avg return in last 50 episodes: 48.94	 Episode return: 193	 Total steps: 3777
Avg return in last 50 episodes: 52.74	 Episode return: 200	 Total steps: 3977
Avg return in last 50 episodes: 56.4	 Episode return: 200	 Total steps: 4177
Avg return in last 50 episodes: 60.18	 Episo

## Final agent testing

In [None]:
agent.Deterministic() = true;

// Resets the environment.
envTest.reset();
envTest.render();

double totalReward = 0;
size_t totalSteps = 0;

// Testing the agent on gym's environment.
while (1)
{
  // State from the environment is passed to the agent's internal representation.
  agent.State().Data() = envTest.observation;

  // With the given state, the agent selects an action according to its defined policy.
  agent.SelectAction();

  // Action to take, decided by the policy.
  arma::mat action = {double(agent.Action().action)};

  envTest.step(action);
  totalReward += env.reward;
  totalSteps += 1;

  if (envTest.done)
  {
    std::cout << " Total steps: " << totalSteps << "\t Total reward: "
        << totalReward << std::endl;
    break;
  }

  // Uncomment the following lines to see the reward and action in each step.
  // std::cout << " Current step: " << totalSteps << "\t current reward: "
  //   << totalReward << "\t Action taken: " << action;
}

envTest.close();
std::string url = envTest.url();

auto video = xw::video_from_url(url).finalize();
video