[![mlpack-lab Image](https://img.shields.io/endpoint?url=https%3A%2F%2Flab.kurg.org%2Fstatus%2Fstatus.json)](https://lab.mlpack.org)

You can easily run this notebook at https://lab.mlpack.org/

This notebook shows how to get started with training reinforcement learning agents, particularly DQN agents, using mlpack. Here, we train a [Simple DQN](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) agent to achieve high scores in the [CartPole](https://gym.openai.com/envs/CartPole-v0) environment.

mlpack contains non-GUI implementations of some of OpenAI Gym's environments. In this notebook, we train the agent on one such environment because training on it is fast.

For testing, we run the agent on the OpenAI Gym toolkit's GUI environment, which is provided through a distributed infrastructure (a TCP API). A video of the trained agent can be seen at the end.

## Including necessary libraries and namespaces

In [1]:
#include <mlpack/core.hpp>

In [2]:
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/reinforcement_learning/q_learning.hpp>
#include <mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>

In [3]:
// Used to run the agent on gym's environment (provided externally) for testing
#include <gym/environment.hpp>

In [4]:
// Used to generate and display a video of the trained agent.
#include "xwidgets/ximage.hpp"
#include "xwidgets/xvideo.hpp"
#include "xwidgets/xaudio.hpp"

In [5]:
using namespace mlpack;

In [6]:
using namespace mlpack::ann;

In [7]:
using namespace ens;

In [8]:
using namespace mlpack::rl;

## Initializing the agent

In [9]:
// Set up the network.
// SimpleDQN<>(4, 256, 128, 2): presumably input size 4 (the CartPole state
// vector), two hidden layers of 256 and 128 units, and 2 outputs (one Q-value
// per action; the test output below shows actions 0 and 1) -- TODO confirm
// the argument order against the SimpleDQN constructor.
SimpleDQN<> model(4, 256, 128, 2);

In [10]:
// Set up the policy and replay method.
// Epsilon-greedy exploration policy. The arguments are presumably
// (initial epsilon = 1.0, anneal interval = 1000 steps, minimum epsilon = 0.1,
// decay rate = 0.99) -- TODO confirm against mlpack's GreedyPolicy constructor.
GreedyPolicy<CartPole> policy(1.0, 1000, 0.1, 0.99);
// Uniform experience replay; presumably (batch size = 20, capacity = 10000
// transitions) -- TODO confirm against mlpack's RandomReplay constructor.
RandomReplay<CartPole> replayMethod(20, 10000);

In [11]:
// Set up training configurations.
TrainingConfig config;

// Learning rate of the optimizer and discount factor applied to future rewards.
config.StepSize() = 0.01;
config.Discount() = 0.9;

// Cap every episode at 200 steps.
config.StepLimit() = 200;

// Number of initial exploration steps (presumably taken before learning
// updates begin -- TODO confirm in mlpack's TrainingConfig docs).
config.ExplorationSteps() = 100;

// Synchronize the target network every 100 steps; use plain (not double)
// Q-learning targets.
config.TargetNetworkSyncInterval() = 100;
config.DoubleQLearning() = false;

In [12]:
// Set up DQN agent.
QLearning<CartPole, decltype(model), AdamUpdate, decltype(policy)>
  agent(std::move(config), std::move(model), std::move(policy),
  std::move(replayMethod));

## Training the agent

In [13]:
// Training the agent on mlpack's own implementation of the CartPole environment:
// keep running training episodes until the running mean of all episode returns
// exceeds `targetAverageReturn`, or give up once `maxEpisodes` have been played.
const size_t maxEpisodes = 1000;          // training budget (episodes)
const double targetAverageReturn = 50.0;  // success threshold on the running mean
const size_t testEpisodes = 10;           // greedy episodes for the final check
const size_t logInterval = 10;            // print progress every N episodes

arma::running_stat<double> averageReturn;
size_t episodes = 0;

while (true)
{
    // Run one training episode and record its return.
    const double episodeReturn = agent.Episode();
    averageReturn(episodeReturn);
    ++episodes;

    if (episodes % logInterval == 0)
    {
        std::cout << "Average return: " << averageReturn.mean()
            << "\t Episode return: " << episodeReturn
            << "\t Episode number: " << episodes << std::endl;
    }

    // Check for success before the budget, so that reaching the target on the
    // last allowed episode is not reported as a failure.
    if (averageReturn.mean() > targetAverageReturn)
    {
        // Evaluate the learned policy without exploration.
        agent.Deterministic() = true;
        arma::running_stat<double> testReturn;
        for (size_t i = 0; i < testEpisodes; ++i)
            testReturn(agent.Episode());

        std::cout << "Average return in deterministic test: "
            << testReturn.mean() << std::endl;
        break;
    }

    // Stop once exactly `maxEpisodes` episodes have been played.
    if (episodes >= maxEpisodes)
    {
        std::cout << "Cart Pole with DQN failed." << std::endl;
        break;
    }
}

Average return: 21.5	 Episode return: 28	 Episode number: 10
Average return: 19.85	 Episode return: 13	 Episode number: 20
Average return: 21.8667	 Episode return: 39	 Episode number: 30
Average return: 43.175	 Episode return: 146	 Episode number: 40
Average return in deterministic test: 145.1


## Testing the agent

In [14]:
// Accumulators for the gym test run: cumulative reward collected and number
// of environment steps taken.
double totalReward{0.0};
size_t totalSteps{0};

In [15]:
// Connect to the remote gym server that hosts the GUI CartPole environment.
// The arguments are presumably (host, port, environment id) -- see the
// gym_tcp_api C++ client.
gym::Environment env("gym.kurg.org", "4040", "CartPole-v0");

In [16]:
// Start the environment monitor writing to "./dummy/"; presumably this records
// the run so that the video URL can be fetched later with env.url(). The two
// `true` flags are presumably `force` and `resume` -- TODO confirm against the
// gym_tcp_api monitor API.
env.monitor.start("./dummy/", true, true);

In [17]:
// Reset the remote environment before the test episode (presumably this also
// refreshes env.observation, which the test loop reads). There is no trailing
// semicolon, so the notebook displays the returned value.
env.reset()

@0x7fb1a0d8d530

In [18]:
// Ask the gym server to render the environment's GUI view.
env.render()

In [19]:
// Testing the agent on gym's environment
while (1)
  {
    // State from the environment is passed to the agent's internal representation
    agent.State().Data() = env.observation;
    
    // with the given state, the agent performs an action according to its defined policy
    agent.Step();
    
    // Action to take, decided by the policy
    arma::mat action = {double(agent.Action())};

    std::cout << "action: " << action;

    env.step(action);
    totalReward += env.reward;
    totalSteps += 1;

    if (env.done)
      break;

    std::cout << " Current step: " << totalSteps << "\t current reward: "
              << totalReward << std::endl;
  }

action:    1.0000
 Current step: 1	 current reward: 1
action:    1.0000
 Current step: 2	 current reward: 2
action:    1.0000
 Current step: 3	 current reward: 3
action:         0
 Current step: 4	 current reward: 4
action:         0
 Current step: 5	 current reward: 5
action:    1.0000
 Current step: 6	 current reward: 6
action:         0
 Current step: 7	 current reward: 7
action:         0
 Current step: 8	 current reward: 8
action:    1.0000
 Current step: 9	 current reward: 9
action:         0
 Current step: 10	 current reward: 10
action:    1.0000
 Current step: 11	 current reward: 11
action:         0
 Current step: 12	 current reward: 12
action:    1.0000
 Current step: 13	 current reward: 13
action:         0
 Current step: 14	 current reward: 14
action:    1.0000
 Current step: 15	 current reward: 15
action:         0
 Current step: 16	 current reward: 16
action:    1.0000
 Current step: 17	 current reward: 17
action:         0
 Current step: 18	 current reward: 18
action:   

In [20]:
// Close the remote environment (presumably finishing the monitor recording),
// then ask the server for the URL of the recorded video.
env.close();
std::string url = env.url();

In [21]:
// Build a video widget from the recorded episode's URL. The bare `video`
// expression (no semicolon) makes the notebook display the widget.
auto video = xw::video_from_url(url).finalize();
video

A Jupyter widget