In [21]:
import gym
import numpy as np
from stable_baselines3 import PPO

from stable_baselines3.ppo.policies import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy

from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import DQN

# MountainCar environment; the DQN cell further down also trains on this same object.
env = gym.make(id='MountainCar-v0')

# PPO agent (not trained in this cell); `evaluate` below reaches its env through `model.get_env()`.
model = PPO(policy=MlpPolicy, env=env, verbose=0)

def evaluate(model, num_episodes=100, deterministic=False):
    """
    Evaluate a RL agent on the environment attached to the model.

    This function will only work for a single environment: the env returned by
    ``model.get_env()`` is expected to be vectorized with exactly one sub-env, so
    ``reward`` and ``done`` come back as length-1 arrays.

    :param model: (BaseRLModel object) the RL Agent; its ``get_env()`` and
        ``predict`` methods are used
    :param num_episodes: (int) number of episodes to evaluate it
    :param deterministic: (bool) forwarded to ``model.predict``; the default
        (False) keeps the original sampling behaviour
    :return: (list[float]) total reward of each episode, in episode order
    """
    env = model.get_env()
    all_episode_rewards = []
    for _ in range(num_episodes):
        done = False
        obs = env.reset()
        score = 0.0
        while not done:
            # _states are only useful when using LSTM policies
            action, _states = model.predict(obs, deterministic=deterministic)
            # here, action, rewards and dones are arrays
            # because we are using vectorized env
            obs, reward, done, info = env.step(action)
            # .item() unwraps the length-1 reward array into a plain float so the
            # score does not silently become an ndarray (and it raises if more
            # than one env is attached, enforcing the single-env contract).
            score += np.asarray(reward).item()
        # One total per episode goes into the returned list.
        all_episode_rewards.append(score)

    return all_episode_rewards

# mean_reward_before_train = evaluate(model, num_episodes=100)

In [22]:
import gym

from stable_baselines3 import DQN


# DQN hyperparameters for MountainCar-v0, kept in one place so they can be read and tuned.
DQN_HYPERPARAMS = dict(
    learning_rate=4e-3,
    batch_size=128,
    buffer_size=10_000,
    learning_starts=1000,
    gamma=0.98,
    target_update_interval=600,
    train_freq=16,
    gradient_steps=8,
    exploration_fraction=0.2,
    exploration_final_eps=0.07,
    policy_kwargs={'net_arch': [256, 256]},
)
DQN_SEED = 0  # passed to DQN(seed=...) so a Restart & Run All reproduces the training run
TOTAL_TIMESTEPS = 150_000
# NOTE: the file name is historical -- the agent is trained on MountainCar-v0, not CartPole.
# It is kept unchanged so anything that already loads "dqn_cartpole" keeps working.
MODEL_PATH = "dqn_cartpole"

# Trains on the `env` (MountainCar-v0) created in the previous cell.
model = DQN('MlpPolicy', env, verbose=1, seed=DQN_SEED, **DQN_HYPERPARAMS)
model.learn(total_timesteps=TOTAL_TIMESTEPS)
model.save(MODEL_PATH)

# Reload from disk so the evaluation below runs on the saved artifact.
model = DQN.load(MODEL_PATH)

# Fresh, unwrapped environment for evaluation; `evaluate` resets it at the start of each episode.
env = gym.make("MountainCar-v0")
def evaluate(model, num_episodes=1000, eval_env=None, deterministic=True):
    """
    Evaluate a RL agent by rolling it out in a single (non-vectorized) gym environment.

    Assumes the classic gym API: ``reset()`` returns the observation and
    ``step()`` returns ``(obs, reward, done, info)``.

    :param model: (BaseRLModel object) the RL Agent; only its ``predict`` method is used
    :param num_episodes: (int) number of episodes to evaluate it
    :param eval_env: environment to roll out in; when None, the notebook-level
        ``env`` is used (looked up at call time), matching the previous behaviour
    :param deterministic: (bool) forwarded to ``model.predict`` (True by default)
    :return: (list[float]) total undiscounted reward of each episode, in episode order
    """
    if eval_env is None:
        # Fall back to the module-level `env`, resolved when called rather than when defined.
        eval_env = env
    all_episode_rewards = []
    for _ in range(num_episodes):
        done = False
        obs = eval_env.reset()
        score = 0
        while not done:
            action, _states = model.predict(obs, deterministic=deterministic)
            obs, reward, done, info = eval_env.step(action)
            score += reward
        all_episode_rewards.append(score)

    return all_episode_rewards

evaluate(model)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
| train/              |          |
|    learning_rate    | 0.00063  |
|    loss             | 0.24     |
|    n_updates        | 32076    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 99.3     |
|    ep_rew_mean      | -98.3    |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 232      |
|    fps              | 149      |
|    time_elapsed     | 216      |
|    total_timesteps  | 32429    |
| train/              |          |
|    learning_rate    | 0.00063  |
|    loss             | 0.204    |
|    n_updates        | 32428    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 99.4     |
|    ep_rew_mean      | -98.5    |
|    exploration_rate | 0.1      |
| time/               |          |
|    episodes         | 2

[-72.0,
 -79.0,
 -78.0,
 -64.0,
 -79.0,
 -71.0,
 -73.0,
 -63.0,
 -96.0,
 -64.0,
 -72.0,
 -92.0,
 -86.0,
 -71.0,
 -72.0,
 -75.0,
 -75.0,
 -70.0,
 -73.0,
 -95.0,
 -72.0,
 -78.0,
 -72.0,
 -100.0,
 -75.0,
 -85.0,
 -75.0,
 -91.0,
 -72.0,
 -79.0,
 -64.0,
 -73.0,
 -87.0,
 -113.0,
 -72.0,
 -72.0,
 -123.0,
 -79.0,
 -98.0,
 -72.0,
 -91.0,
 -80.0,
 -63.0,
 -130.0,
 -85.0,
 -64.0,
 -75.0,
 -94.0,
 -73.0,
 -63.0,
 -63.0,
 -75.0,
 -75.0,
 -80.0,
 -64.0,
 -96.0,
 -75.0,
 -95.0,
 -72.0,
 -64.0,
 -116.0,
 -64.0,
 -77.0,
 -87.0,
 -73.0,
 -67.0,
 -103.0,
 -80.0,
 -85.0,
 -80.0,
 -72.0,
 -72.0,
 -75.0,
 -77.0,
 -81.0,
 -92.0,
 -78.0,
 -72.0,
 -93.0,
 -78.0,
 -79.0,
 -72.0,
 -79.0,
 -92.0,
 -70.0,
 -71.0,
 -70.0,
 -73.0,
 -74.0,
 -79.0,
 -74.0,
 -79.0,
 -88.0,
 -72.0,
 -73.0,
 -76.0,
 -75.0,
 -63.0,
 -95.0,
 -95.0,
 -64.0,
 -78.0,
 -80.0,
 -92.0,
 -63.0,
 -72.0,
 -72.0,
 -77.0,
 -64.0,
 -88.0,
 -85.0,
 -79.0,
 -79.0,
 -63.0,
 -87.0,
 -86.0,
 -63.0,
 -110.0,
 -78.0,
 -74.0,
 -72.0,
 -73.0,
 -72.0,
 -64.0,
 