In [1]:
import gym

from stable_baselines3 import DQN, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_vec_env

from sb3_contrib import QRDQN

In [16]:
from typing import Callable, Union


def linear_schedule(initial_value: Union[float, str]) -> Callable[[float], float]:
    """
    Linear learning rate schedule.

    The returned function can be passed as ``learning_rate`` to an SB3 model.
    It is called with the remaining training progress and returns the
    learning rate to use at that point.

    :param initial_value: Learning rate at the start of training. A string
        such as ``"1.5e-3"`` is converted with ``float()``.
    :return: A function mapping remaining progress (1.0 at the beginning,
        0.0 at the end) to ``progress * initial_value``. The rate therefore
        decays linearly from ``initial_value`` down to 0.
    """
    if isinstance(initial_value, str):
        initial_value = float(initial_value)

    def func(progress: float) -> float:
        """
        Scale the initial learning rate by the remaining progress.

        :param progress: Remaining progress; decreases from 1 (beginning) to 0 (end).
        :return: The learning rate for this point in training.
        """
        return progress * initial_value

    return func

In [20]:
# Fixed seed so that Restart & Run All reproduces the same training run.
SEED = 0

env = gym.make("LunarLander-v2")
model = QRDQN(
    "MlpPolicy",
    env,
    # Constant learning rate. Alternative: learning_rate=linear_schedule(1.5e-3)
    # (defined above) decays the rate linearly to 0 over training.
    learning_rate=6.3e-4,
    batch_size=128,
    buffer_size=100000,
    learning_starts=10000,
    gamma=0.995,
    target_update_interval=1,
    train_freq=256,
    # Per the SB3 docs, -1 means "as many gradient steps as env steps collected".
    gradient_steps=-1,
    # Epsilon is annealed from 1.0 to exploration_final_eps over this fraction
    # of the total_timesteps passed to learn() (0.24 * 1e5 = 24k steps).
    exploration_fraction=0.24,
    exploration_final_eps=0.18,
    policy_kwargs=dict(net_arch=[256, 256], n_quantiles=170),
    seed=SEED,
    verbose=1,
)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [3]:
# env = make_vec_env("LunarLander-v2", n_envs=16)
# model = PPO(
#     "MlpPolicy",
#     env,
#     n_steps=1024,
#     batch_size=64,
#     gae_lambda=0.98,
#     gamma=0.999,
#     n_epochs=4,
#     ent_coef=0.01,
#     verbose=1
# )


In [21]:
# Training budget; exploration_fraction in the model cell is a fraction of this.
TOTAL_TIMESTEPS = 100_000
model.learn(total_timesteps=TOTAL_TIMESTEPS);  # `;` hides the returned model's repr

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 88.8     |
|    ep_rew_mean      | -139     |
|    exploration_rate | 0.988    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 737      |
|    time_elapsed     | 0        |
|    total_timesteps  | 355      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 91.2     |
|    ep_rew_mean      | -152     |
|    exploration_rate | 0.975    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 735      |
|    time_elapsed     | 0        |
|    total_timesteps  | 730      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 93.2     |
|    ep_rew_mean      | -151     |
|    exploration_rate | 0.962    |
| time/               |          |
|    episodes       

<sb3_contrib.qrdqn.qrdqn.QRDQN at 0x1afcd966710>

In [22]:
# Persist the trained policy to disk, then drop it from the kernel namespace.
POLICY_PATH = "data/policies/LunarLander-v2#qrdqn#canitrot_dubreuil_sb3"
model.save(POLICY_PATH)
del model