In [None]:
// You can easily run this notebook at https://lab.mlpack.org/

In [1]:
#include <mlpack/core.hpp>

In [2]:
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
#include <mlpack/methods/reinforcement_learning/q_learning.hpp>
#include <mlpack/methods/reinforcement_learning/q_networks/simple_dqn.hpp>
#include <mlpack/methods/reinforcement_learning/environment/mountain_car.hpp>
#include <mlpack/methods/reinforcement_learning/environment/acrobot.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
#include <mlpack/methods/reinforcement_learning/environment/double_pole_cart.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>

In [3]:
#include <ensmallen.hpp>

In [4]:
using namespace mlpack;

In [5]:
using namespace mlpack::ann;

In [6]:
using namespace ens;

In [7]:
using namespace mlpack::rl;

In [8]:
// Set up the network.
// The Q-network is a SimpleDQN built from four integers.
// NOTE(review): these are presumably (input dimension, first hidden layer
// size, second hidden layer size, output dimension); 4 and 2 would then match
// CartPole's observation size and action count -- confirm against the
// SimpleDQN constructor of the installed mlpack version.
  SimpleDQN<> model(4, 256, 128, 2);

In [9]:
// Set up the policy and replay method.
// Both are specialised on the CartPole environment and are handed to the
// agent when it is constructed.
// NOTE(review): the GreedyPolicy arguments are presumably (initial epsilon,
// anneal interval, minimum epsilon, decay rate) -- confirm against the
// installed mlpack version.
GreedyPolicy<CartPole> policy(1.0, 1000, 0.1, 0.99);
// NOTE(review): the RandomReplay arguments are presumably (batch size,
// buffer capacity) -- confirm.
RandomReplay<CartPole> replayMethod(20, 10000);

In [10]:
// Hyper-parameters for the agent, assigned through TrainingConfig's accessors.
TrainingConfig config;

// Per-episode step limit and exploration step count.
config.StepLimit() = 200;
config.ExplorationSteps() = 100;

// Learning-rule settings.
config.DoubleQLearning() = false;
config.TargetNetworkSyncInterval() = 100;
config.Discount() = 0.9;
config.StepSize() = 0.01;

In [11]:
// Set up DQN agent.
// The agent type is parameterised on the CartPole environment, the network
// and policy types declared earlier (picked up via decltype), and AdamUpdate.
// config, model, policy and replayMethod are all passed through std::move, so
// treat those four variables as consumed once this cell has run.
// NOTE(review): whether they are really left moved-from depends on the
// QLearning constructor signature -- confirm before reusing any of them.
QLearning<CartPole, decltype(model), AdamUpdate, decltype(policy)>
  agent(std::move(config), std::move(model), std::move(policy),
  std::move(replayMethod));

In [12]:
// Training the agent on local environment:

arma::running_stat<double> averageReturn;
size_t episodes = 0;
bool converged = true;

while (true)
{
    double episodeReturn = agent.Episode();
    averageReturn(episodeReturn);
    episodes += 1;
    
    std::cout << "Average return: " << averageReturn.mean()
        << " Episode return: " << episodeReturn << std::endl;
    
    if (episodes > 1000)
    {
      std::cout << "Cart Pole with DQN failed." << std::endl;
      converged = false;
      break;
    }
    
    if (averageReturn.mean() > 50)
    {
      agent.Deterministic() = true;
      arma::running_stat<double> testReturn;
      for (size_t i = 0; i < 10; ++i)
        testReturn(agent.Episode());

      std::cout << "Average return in deterministic test: "
          << testReturn.mean() << std::endl;
      break;
    }
}

Average return: 25 Episode return: 25
Average return: 16.5 Episode return: 8
Average return: 18 Episode return: 21
Average return: 20.5 Episode return: 28
Average return: 19.4 Episode return: 15
Average return: 19 Episode return: 17
Average return: 17.8571 Episode return: 11
Average return: 16.75 Episode return: 9
Average return: 16.2222 Episode return: 12
Average return: 24.9 Episode return: 103
Average return: 24 Episode return: 15
Average return: 23.1667 Episode return: 14
Average return: 22.2308 Episode return: 11
Average return: 22.1429 Episode return: 21
Average return: 22.2667 Episode return: 24
Average return: 21.875 Episode return: 16
Average return: 21.9412 Episode return: 23
Average return: 22 Episode return: 23
Average return: 21.4211 Episode return: 11
Average return: 20.95 Episode return: 12
Average return: 20.9048 Episode return: 20
Average return: 20.4545 Episode return: 11
Average return: 20.2609 Episode return: 16
Average return: 20.25 Episode return: 20
Average retur

In [13]:
#include <gym/environment.hpp>

In [14]:
#include "xwidgets/ximage.hpp"
#include "xwidgets/xvideo.hpp"
#include "xwidgets/xaudio.hpp"

In [15]:
// Running totals accumulated while the agent plays on the gym environment.
size_t totalSteps = 0;
double totalReward = 0.0;

In [16]:
// Create the gym environment handle used for the remaining cells.
// NOTE(review): the arguments are presumably (host, port, environment id),
// i.e. the "CartPole-v0" environment served from gym.kurg.org:4040 -- confirm
// against the gym client header.
gym::Environment env("gym.kurg.org", "4040", "CartPole-v0");

In [17]:
// Start the environment's monitor with "./dummy/" as its directory argument.
// NOTE(review): the directory is presumably where the recording is written
// and the two flags are presumably (force, resume) -- confirm against the gym
// client API.
env.monitor.start("./dummy/", true, true);

In [18]:
// Reset the environment before the evaluation run. There is no trailing
// semicolon, so the notebook kernel echoes the value of the expression.
env.reset()

@0x7f85c5e95530

In [19]:
// Call render on the environment. There is no trailing semicolon, so the
// kernel echoes any value the call yields.
// NOTE(review): presumably this asks the gym server to render the current
// frame -- confirm.
env.render()

In [20]:
// Testing the agent on gym's environment
while (1)
  {
    // State from the environment is passed to the agent's internal representation
    agent.State().Data() = env.observation;
    
    // with the given state, the agent performs an action according to its defined policy
    agent.Step();
    
    // Action to take, decided by the policy
    arma::mat action = {double(agent.Action())};

    // arma::mat action = env.action_space.sample();
    std::cout << "action: " << action;

    env.step(action);
    totalReward += env.reward;
    totalSteps += 1;

    if (env.done)
      break;

    std::cout << " Current step: " << totalSteps << " current reward: "
              << totalReward << std::endl;
  }

action:         0
 Current step: 1 current reward: 1
action:         0
 Current step: 2 current reward: 2
action:    1.0000
 Current step: 3 current reward: 3
action:         0
 Current step: 4 current reward: 4
action:    1.0000
 Current step: 5 current reward: 5
action:    1.0000
 Current step: 6 current reward: 6
action:         0
 Current step: 7 current reward: 7
action:    1.0000
 Current step: 8 current reward: 8
action:         0
 Current step: 9 current reward: 9
action:    1.0000
 Current step: 10 current reward: 10
action:         0
 Current step: 11 current reward: 11
action:    1.0000
 Current step: 12 current reward: 12
action:    1.0000
 Current step: 13 current reward: 13
action:         0
 Current step: 14 current reward: 14
action:         0
 Current step: 15 current reward: 15
action:    1.0000
 Current step: 16 current reward: 16
action:    1.0000
 Current step: 17 current reward: 17
action:         0
 Current step: 18 current reward: 18
action:    1.0000
 Current s

In [21]:
// Close the environment, then fetch and print the URL it reports.
// NOTE(review): judging by the printed output, this URL points at a .webm
// recording of the run -- confirm.
env.close();
std::string url = env.url();
std::cout << url << std::endl;

https://gym.kurg.org/647380835f904/output.webm


In [22]:
// Build a video widget from the URL obtained above. The bare `video`
// expression (no trailing semicolon) makes the kernel display the widget.
auto video = xw::video_from_url(url).finalize();
video

A Jupyter widget