# Solving Reinforcement Learning problems by using scikit-decide

[Scikit-decide](https://github.com/airbus/scikit-decide) is an open-source library developed by Airbus to model and solve sequential decision-making problems such as reinforcement learning, planning and scheduling.
The documentation of the library can be found [here](https://airbus.github.io/scikit-decide/).

One benefit of scikit-decide is that a problem of a given class, e.g. path planning, can be solved with algorithms from different communities, e.g. reinforcement learning and planning, and the results can then be compared.

In this notebook, we will:
1. Solve classical RL benchmarks by using well-known RL frameworks which are directly accessible in scikit-decide.
2. Solve planning problems by using the same RL frameworks, demonstrating the ability of the library to solve planning problems seen as RL problems.
3. Solve an aircraft taxiing control problem by using RL algorithms.

Below are the imports used throughout this notebook, preceded by the installation of any potentially missing packages.

In [None]:
!python -m pip --default-timeout=1000 install "scikit-decide[all]" "shimmy[gym]" pygame gym_jsbsim JSBSim==1.1.1 --use-pep517

In [None]:
from time import sleep
from typing import Optional

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output, display

import folium
from gym_jsbsim.catalogs.catalog import Catalog as prp
from gym_jsbsim.envs.taxi_utils import *

from stable_baselines3 import PPO
from ray.rllib.algorithms.dqn import DQN, DQNConfig

from skdecide import Solver
from skdecide.builders.domain.initialization import DeterministicInitialized
from skdecide.hub.domain.gym import (
    GymDiscreteActionDomain,
    GymDomain,
    GymDomainHashable
)
from skdecide.hub.domain.simple_grid_world import SimpleGridWorld
from skdecide.hub.solver.stable_baselines import StableBaseline
from skdecide.hub.solver.ray_rllib import RayRLlib
from skdecide.hub.solver.astar import Astar
from skdecide.hub.solver.mcts import UCT
from skdecide.utils import rollout as skd_rollout

# choose standard matplotlib inline backend to render plots
%matplotlib inline

## Solving classical RL benchmarks

We will consider the classical Cart Pole problem, as provided by Gymnasium and defined [here](https://gymnasium.farama.org/environments/classic_control/cart_pole/).

### Wrap Gymnasium environment in scikit-decide

We choose the Gymnasium environment we would like to use.

In [None]:
ENV_NAME = "CartPole-v1"

We define a domain factory using the `GymDomain` proxy available in scikit-decide, which wraps the Gymnasium environment.

In [None]:
domain_factory = lambda: GymDomain(gym.make(ENV_NAME, render_mode="rgb_array"))

Here is a screenshot of such an environment.

In [None]:
domain = domain_factory()
domain.reset()
plt.imshow(domain.render())
plt.axis("off")
domain.close()

### Solve with Reinforcement Learning (StableBaseline + PPO)

We first try a solver coming from the Reinforcement Learning community that makes use of [stable_baselines3](https://github.com/DLR-RM/stable-baselines3), which gives access to a lot of RL algorithms.

Here we choose the [Proximal Policy Optimization (PPO)](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) solver. It directly optimizes the weights of the policy network using stochastic gradient ascent. See more details in the stable baselines [documentation](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) and the [original paper](https://arxiv.org/abs/1707.06347).
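
To give an intuition of what PPO optimizes, here is a minimal sketch of the clipped surrogate objective from the original paper, for a single sample. The function name and the `clip_range` value of 0.2 are illustrative choices; the actual stable_baselines3 implementation works on batches of tensors and adds value and entropy terms.

```python
def clipped_surrogate(ratio: float, advantage: float, clip_range: float = 0.2) -> float:
    """PPO clipped surrogate objective for a single sample (to be maximized).

    ratio is pi_new(a|s) / pi_old(a|s); advantage estimates how much better
    the action was than average.
    """
    clipped_ratio = max(1.0 - clip_range, min(ratio, 1.0 + clip_range))
    # Taking the minimum keeps the more pessimistic of the two estimates.
    return min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage: the gain is capped once the ratio exceeds 1 + clip_range.
print(clipped_surrogate(ratio=1.5, advantage=2.0))
# Negative advantage: the unclipped (worse) value is kept.
print(clipped_surrogate(ratio=1.5, advantage=-2.0))
```

The clipping removes any incentive to move the new policy far away from the one that collected the data, which is what makes several gradient steps on the same batch reasonably safe.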

#### Check compatibility
We check the compatibility of the domain with the chosen solver.

In [None]:
domain = domain_factory()
assert StableBaseline.check_domain(domain)
domain.close()

#### Solver instantiation

In [None]:
solver = StableBaseline(
    PPO, "MlpPolicy", learn_config={"total_timesteps": 10000}, verbose=True
)

#### Training solver on domain

In [None]:
GymDomain.solve_with(solver, domain_factory)

#### Rolling out a solution

We can use the trained solver to roll out an episode to see if this is actually solving the problem at hand.

For educational purposes, we define here our own rollout (which you will probably need if you want to actually use the solver in a real case). If you want to take a look at the (more complex) one already implemented in the library, see the `rollout()` function in the [utils.py](https://github.com/airbus/scikit-decide/blob/master/skdecide/utils.py) module.

By default, we display the solution in a matplotlib figure. If you only need to check whether the episode succeeds, you can specify `render=False`. In this case, the rollout is greatly sped up and a message is still printed at the end of the process stating whether it succeeded, along with the number of steps taken.

In [None]:
def rollout(
    domain: GymDomain,
    solver: Solver,
    max_steps: int,
    pause_between_steps: Optional[float] = 0.01,
    render: bool = True,
):
    """Roll out one episode in a domain according to the policy of a trained solver.

    Args:
        domain: the domain in which the episode is rolled out
        solver: a trained solver
        max_steps: maximum number of steps of the episode
        pause_between_steps: time (s) paused between agent movements.
          No pause if None.
        render: if True, the rollout is rendered in a matplotlib figure as an animation;
            if False, the rollout runs much faster.

    """
    # Initialize episode
    solver.reset()
    observation = domain.reset()

    # Initialize image
    if render:
        plt.ioff()
        fig, ax = plt.subplots(1)
        ax.axis("off")
        plt.ion()
        img = ax.imshow(domain.render())
        display(fig)

    # loop until max_steps or goal is reached
    for i_step in range(1, max_steps + 1):
        if pause_between_steps is not None:
            sleep(pause_between_steps)

        # choose action according to solver
        action = solver.sample_action(observation)
        # apply the action and get the resulting outcome
        outcome = domain.step(action)
        observation = outcome.observation

        # update image
        if render:
            img.set_data(domain.render())
            fig.canvas.draw()
            clear_output(wait=True)
            display(fig)

        # final state reached?
        if outcome.termination:
            break

    # close the figure to avoid jupyter duplicating the last image
    if render:
        plt.close(fig)

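    # episode successful?
    # Cart Pole has no goal state: we count the episode as a success if the
    # pole is still balanced (no termination) when the loop ends.
    is_goal_reached = not outcome.termination
    if is_goal_reached:
        print(f"Pole still balanced after {i_step} steps!")
    else:
        print(f"Episode terminated after {i_step} steps!")

    return is_goal_reached, i_step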

We create a domain for the rollout and close it at the end. If it is not closed, an OpenGL popup window stays open, at least in local Jupyter sessions.

In [None]:
domain = domain_factory()
try:
    rollout(
        domain=domain,
        solver=solver,
        max_steps=300,
        pause_between_steps=None,
        render=True,
    )
finally:
    domain.close()

#### Cleaning up

Some solvers need proper cleaning before being deleted.

In [None]:
solver._cleanup()

Note that this is automatically done if you use the solver within a `with` statement. The syntax would look something like:

```python
with solver_factory() as solver:
    MyDomain.solve_with(solver, domain_factory)
    rollout(domain=domain, solver=solver)
```
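
To illustrate why the `with` statement takes care of the cleanup, here is a minimal sketch of the context manager protocol. `ToySolver` is a hypothetical stand-in written for this example only; the way scikit-decide solvers implement the protocol may differ in its details.

```python
class ToySolver:
    """Hypothetical stand-in for a solver that holds resources."""

    def __init__(self) -> None:
        self.cleaned_up = False

    def _cleanup(self) -> None:
        # Release whatever the solver holds (workers, native memory, ...).
        self.cleaned_up = True

    def __enter__(self) -> "ToySolver":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Called when the `with` block exits, even if it raised an exception.
        self._cleanup()


with ToySolver() as toy_solver:
    pass  # solving and rolling out would happen here

print(toy_solver.cleaned_up)
```

Because `__exit__` also runs when the block raises, the cleanup cannot be forgotten, unlike an explicit call placed at the end of a cell.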

### Solve with Reinforcement Learning (RLlib + DQN)

For the sake of demonstration, we now show how to solve the same Cart Pole benchmark with a different RL library and algorithm: the [DQN](https://arxiv.org/abs/1312.5602v1) algorithm as implemented in the [Ray RLlib](https://docs.ray.io/en/latest/rllib/rllib-algorithms.html#dqn) library.

Interestingly, one can note that the domain definition is exactly the same as before. We just change the solving library and algorithm.

In [None]:
if RayRLlib.check_domain(domain):
    solver_factory = lambda: RayRLlib(DQN, train_iterations=100, config=DQNConfig().training(n_step=300))

    with solver_factory() as solver:
        # Solve domain
        GymDomain.solve_with(solver, domain_factory)

        # Test solution
        rollout(
            domain=domain,
            solver=solver,
            max_steps=300,
            pause_between_steps=None,
            render=True,
        )
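
The `n_step` option passed to `DQNConfig().training(...)` above selects n-step Q-learning targets: instead of bootstrapping after a single reward, the target accumulates `n` discounted rewards before bootstrapping from the estimated value of the state reached. The sketch below only illustrates that formula; the function name and numbers are made up for the example, and RLlib's own implementation additionally handles batching, episode boundaries and target networks.

```python
from typing import Sequence


def n_step_target(rewards: Sequence[float], bootstrap_value: float, gamma: float) -> float:
    """Return r_0 + gamma * r_1 + ... + gamma^(n-1) * r_(n-1) + gamma^n * bootstrap_value."""
    target = bootstrap_value
    # Fold from the last reward back to the first one.
    for reward in reversed(rewards):
        target = reward + gamma * target
    return target


# Three rewards of 1 (as in Cart Pole), then bootstrap from max_a Q(s_3, a) = 10.
print(n_step_target([1.0, 1.0, 1.0], bootstrap_value=10.0, gamma=0.5))
```

Larger values of `n` propagate rewards faster through the value estimates, at the price of targets that depend more heavily on the behaviour policy that generated the rewards.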

## Solving planning domains with Reinforcement Learning

Reinforcement Learning algorithms can learn to solve Markov Decision Processes, a large problem class which also includes path planning problems.

To demonstrate this, we consider a simple grid world domain, but more challenging path planning problems like mazes can be found in the [documentation](https://airbus.github.io/scikit-decide/notebooks/#maze-tutorial) of the library.

We define a domain factory creating a simple grid world environment.

In [None]:
domain_factory = lambda: SimpleGridWorld(num_cols=10, num_rows=10)

Below, we learn a control policy with DQN from RLlib that tries to reach the goal state of the grid, located at cell (9, 9), starting from cell (0, 0).

In [None]:
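# create a grid world domain (the previous `domain` is still the Cart Pole one)
domain = domain_factory()
if RayRLlib.check_domain(domain):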
    solver_factory = lambda: RayRLlib(
        DQN, train_iterations=100, config=DQNConfig().training(n_step=25)
    )

    with solver_factory() as solver:
        # Solve domain
        SimpleGridWorld.solve_with(solver, domain_factory)

        # Test solution
        skd_rollout(
            domain=domain,
            solver=solver,
            max_steps=25,
        )
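
To make the "path planning as an MDP" view concrete, here is a small sketch that treats a 10 x 10 grid as a deterministic MDP and solves it exactly with value iteration (a dynamic programming method, not DQN). The reward of -1 per move and the four-move action set are assumptions made for this example; the rewards actually used by `SimpleGridWorld` may differ.

```python
from typing import Dict, Tuple

Cell = Tuple[int, int]
MOVES = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # left, right, up, down


def value_iteration(num_cols: int, num_rows: int, goal: Cell) -> Dict[Cell, float]:
    """Optimal undiscounted values of a grid where every move yields reward -1."""
    values = {(x, y): 0.0 for x in range(num_cols) for y in range(num_rows)}
    changed = True
    while changed:
        changed = False
        for cell in values:
            if cell == goal:
                continue  # terminal state keeps value 0
            # Bellman backup: best neighbour value plus the move reward.
            best = max(
                -1.0 + values[(cell[0] + dx, cell[1] + dy)]
                for dx, dy in MOVES
                if (cell[0] + dx, cell[1] + dy) in values
            )
            if best != values[cell]:
                values[cell] = best
                changed = True
    return values


values = value_iteration(num_cols=10, num_rows=10, goal=(9, 9))
print(values[(0, 0)])  # minus the length of the shortest path from (0, 0)
```

Under these assumptions, maximizing the return is the same as minimizing the path length, which is why an RL solver such as DQN can be applied to a planning domain: it approximates these values from sampled transitions instead of sweeping over all states.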

As a comparison, and to show the ability of the scikit-decide library to solve the same domain with algorithms from different research fields, we briefly show how to use the famous [A\*](https://en.wikipedia.org/wiki/A*_search_algorithm) path planning algorithm to solve the grid domain.

In [None]:
# make sure we check and roll out a grid world domain
domain = domain_factory()
if Astar.check_domain(domain):
    solver_factory = lambda: Astar(domain_factory)

    with solver_factory() as solver:
        # Solve domain
        SimpleGridWorld.solve_with(solver, domain_factory)

        # Test solution
        skd_rollout(
            domain=domain,
            solver=solver,
            max_steps=25,
        )
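
For readers unfamiliar with A\*, the following self-contained sketch runs the algorithm on an obstacle-free 10 x 10 grid with the Manhattan distance as heuristic. It is written from scratch for illustration and is independent of the scikit-decide `Astar` solver; the function name, the unit move cost and the four-move neighbourhood are assumptions of this example.

```python
import heapq
from typing import Dict, List, Optional, Tuple

Cell = Tuple[int, int]


def astar_grid(start: Cell, goal: Cell, num_cols: int, num_rows: int) -> Optional[List[Cell]]:
    """Shortest 4-connected path on an obstacle-free grid, or None if unreachable."""

    def heuristic(cell: Cell) -> int:
        # Manhattan distance: never overestimates with unit-cost moves.
        return abs(goal[0] - cell[0]) + abs(goal[1] - cell[1])

    frontier = [(heuristic(start), 0, start)]  # entries are (g + h, g, cell)
    best_cost: Dict[Cell, int] = {start: 0}
    parent: Dict[Cell, Cell] = {}
    while frontier:
        _, cost, cell = heapq.heappop(frontier)
        if cell == goal:
            # Walk the parent links back to the start.
            path = [cell]
            while path[-1] != start:
                path.append(parent[path[-1]])
            return path[::-1]
        if cost > best_cost[cell]:
            continue  # stale queue entry
        for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nxt = (cell[0] + dx, cell[1] + dy)
            if not (0 <= nxt[0] < num_cols and 0 <= nxt[1] < num_rows):
                continue
            new_cost = cost + 1
            if new_cost < best_cost.get(nxt, new_cost + 1):
                best_cost[nxt] = new_cost
                parent[nxt] = cell
                heapq.heappush(frontier, (new_cost + heuristic(nxt), new_cost, nxt))
    return None


path = astar_grid(start=(0, 0), goal=(9, 9), num_cols=10, num_rows=10)
print(len(path) - 1)  # number of moves
```

Unlike DQN, A\* needs no training phase: it exploits the known deterministic transitions and a heuristic to return an optimal path directly.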

## Learning to control a taxiing aircraft

To finish the demonstration of reinforcement learning with scikit-decide, we want to learn to control the taxiing phase of an aircraft.

We rely on the open-source [JSBSim](https://jsbsim.sourceforge.net/) aircraft simulator through the [gym-jsbsim](https://pypi.org/project/gym-jsbsim/) environment.

We opt for a continuous (online) planning algorithm named [UCT](https://link.springer.com/chapter/10.1007/11871842_29), a variant of Monte-Carlo Tree Search that originally became popular for playing games like Chess and Go.
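
At the core of UCT is the UCB1 rule used to choose which action to explore next at each tree node: it balances the average value observed so far against a bonus for rarely tried actions. Here is a minimal sketch of that rule; the function name, the action labels and the statistics are hypothetical, and the scikit-decide implementation is considerably more elaborate.

```python
import math
from typing import Dict, Tuple


def ucb1_select(stats: Dict[str, Tuple[int, float]], exploration: float = math.sqrt(2)) -> str:
    """Pick the action maximizing mean value + exploration bonus.

    stats maps each action to (visit_count, total_value).
    """
    total_visits = sum(visits for visits, _ in stats.values())

    def score(action: str) -> float:
        visits, total_value = stats[action]
        if visits == 0:
            return math.inf  # always try unvisited actions first
        bonus = exploration * math.sqrt(math.log(total_visits) / visits)
        return total_value / visits + bonus

    return max(stats, key=score)


# Hypothetical statistics collected at one tree node.
stats = {"left": (10, 6.0), "right": (2, 1.0), "brake": (0, 0.0)}
print(ucb1_select(stats))
```

With `exploration=0` the rule becomes purely greedy; larger values spend more of the rollout budget on actions that have been tried less often.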

First, we embed the gym-jsbsim environment in a scikit-decide `GymDomain` compatible with UCT. It requires us to inherit from `GymDomainHashable`, `GymDiscreteActionDomain` and `DeterministicInitialized` which provide methods called by the scikit-decide implementation of the UCT algorithm.

In [None]:
ENV_NAME = "GymJsbsim-TaxiapControlTask-v0"

def normalize_and_round(state):
    ns = np.array([s[0] for s in state])
    # ns = ns / np.linalg.norm(ns) if np.linalg.norm(ns) != 0 else ns
    scale = np.array([10.0, 0.1, 100.0, 100.0, 100.0, 100.0, 1.0, 1.0, 1.0, 1.0])
    ns = 1 / (1 + np.exp(-ns / scale))
    np.around(ns, decimals=3, out=ns)
    return ns


class D(GymDomainHashable, GymDiscreteActionDomain, DeterministicInitialized):
    pass


class GymUCTDomain(D):
    """This class wraps a cost-based deterministic gymnasium environment as a domain
        usable by UCT

    !!! warning
        Using this class requires gymnasium to be installed.
    """

    def __init__(
        self,
        gym_env: gym.Env,
        discretization_factor: int = 10,
        branching_factor: Optional[int] = None,
        max_depth: Optional[int] = None,
    ) -> None:
        """Initialize GymUCTDomain.

        # Parameters
        gym_env: The deterministic Gym environment (gym.env) to wrap.
        discretization_factor: Number of discretized action variable values per continuous action variable
        branching_factor: if not None, sample branching_factor actions from the resulting list of discretized actions
        max_depth: maximum depth of states to explore from the initial state
        """
        GymDomainHashable.__init__(self, gym_env=gym_env)
        GymDiscreteActionDomain.__init__(
            self,
            discretization_factor=discretization_factor,
            branching_factor=branching_factor,
        )
        gym_env._max_episode_steps = max_depth
        self._map = None
        self._path = None

    def _render_from(self, memory: D.T_memory[D.T_state], **kwargs):
        # Get rid of the current state and just look at the gym env's current internal state
        lon = self._gym_env.sim.get_property_value(prp.position_long_gc_deg)
        lat = self._gym_env.sim.get_property_value(prp.position_lat_geod_deg)
        if self._map is None:
            self._map = folium.Map(location=[lat, lon], zoom_start=18)
            taxiPath = taxi_path()
            for p in taxiPath.centerlinepoints:
                folium.Marker((p[1], p[0]), popup=p).add_to(self._map)
            folium.PolyLine(
                taxiPath.centerlinepoints, color="blue", weight=2.5, opacity=1
            ).add_to(self._map)
            self._path = folium.PolyLine(
                [(lat, lon)], color="red", weight=2.5, opacity=1
            )
            self._path.add_to(self._map)
            # write the wrapper page that periodically reloads the updated map
            with open("gym_jsbsim_map.html", "w") as f:
                f.write(
                    "<!DOCTYPE html>\n"
                    + "<HTML>\n"
                    + "<HEAD>\n"
                    + '<META http-equiv="refresh" content="60">\n'
                    + "</HEAD>\n"
                    + "<FRAMESET>\n"
                    + '<FRAME src="gym_jsbsim_map_update.html">\n'
                    + "</FRAMESET>\n"
                    + "</HTML>"
                )
        else:
            self._path.locations.append(folium.utilities.validate_location((lat, lon)))
            self._map.location = folium.utilities.validate_location((lat, lon))
        self._map.save("gym_jsbsim_map_update.html")

domain_factory = lambda: GymUCTDomain(
    gym_env=gym.make("GymV21Environment-v0", env_id=ENV_NAME),
    discretization_factor=9,
    max_depth=500,
)
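
The `discretization_factor` parameter sets how many values are kept for each continuous action variable, and the discrete action set is the combination of these values across variables. The sketch below illustrates the idea with evenly spaced values that include both bounds; this spacing, the function name and the two action variables in [-1, 1] are assumptions made for the example, not a description of scikit-decide internals or of the gym-jsbsim action space.

```python
import itertools
from typing import List, Sequence, Tuple


def discretize_actions(
    lows: Sequence[float], highs: Sequence[float], discretization_factor: int
) -> List[Tuple[float, ...]]:
    """Cartesian product of evenly spaced values for each continuous action variable.

    Requires discretization_factor >= 2 so that both bounds are included.
    """
    per_variable = []
    for low, high in zip(lows, highs):
        step = (high - low) / (discretization_factor - 1)
        per_variable.append([low + i * step for i in range(discretization_factor)])
    return list(itertools.product(*per_variable))


# Two hypothetical action variables in [-1, 1], 9 values each: 9 ** 2 = 81 actions.
actions = discretize_actions(lows=[-1.0, -1.0], highs=[1.0, 1.0], discretization_factor=9)
print(len(actions))
```

The number of discrete actions grows exponentially with the number of action variables, which is the motivation for the optional `branching_factor` parameter: it samples a subset of the discretized actions instead of expanding all of them.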

We now call UCT with a budget of 10 rollouts or 1 second (`time_budget=1000`, presumably expressed in milliseconds) in each state of the environment. This means that the policy is optimised while it is being executed: in each visited state, the policy is first optimised for at most 1 second of computation and 10 rollouts, then only its first action is executed.

In [None]:
if UCT.check_domain(domain_factory()):
    solver_factory = lambda: UCT(
        domain_factory=domain_factory,
        time_budget=1000,
        rollout_budget=10,
        max_depth=500,
        discount=1.0,
        transition_mode=UCT.Options.TransitionMode.Step,
        continuous_planning=True,
        online_node_garbage=True,
        parallel=False,
        debug_logs=False,
    )
    with solver_factory() as solver:
        GymUCTDomain.solve_with(solver, domain_factory)
        solver._domain.reset()
        # use the library rollout: our custom one does not accept these arguments
        skd_rollout(
            domain_factory(),
            solver,
            num_episodes=1,
            max_steps=500,
            max_framerate=30,
            outcome_formatter=lambda o: f"{o.observation} - reward: {o.value.reward:.2f}",
            action_formatter=lambda a: f"{a}",
            verbose=True,
        )
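
The alternation between planning and acting described above can be summarized by the following sketch. It is a toy illustration on a made-up one-dimensional problem: `run_online`, `toy_plan` and `toy_step` are hypothetical names, and in the real setting the planning call would be a budget-limited UCT search.

```python
from typing import Callable, List, Tuple


def run_online(
    state: int,
    plan: Callable[[int], int],
    step: Callable[[int, int], Tuple[int, bool]],
    max_steps: int,
) -> List[int]:
    """Alternate planning and acting: replan in every visited state."""
    visited = [state]
    for _ in range(max_steps):
        action = plan(state)  # budget-limited optimization from the current state
        state, done = step(state, action)  # execute only the first planned action
        visited.append(state)
        if done:
            break
    return visited


# Toy 1-D problem: walk from 0 to 5, one unit per step.
GOAL = 5
toy_plan = lambda s: 1 if s < GOAL else -1
toy_step = lambda s, a: (s + a, s + a == GOAL)
print(run_online(0, toy_plan, toy_step, max_steps=20))
```

Replanning in every state keeps each search short and lets the controller react to the state actually reached, at the cost of paying the planning budget at every step.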