# 1. Import dependencies

In [1]:
!pip install stable-baselines3[extra]

Collecting stable-baselines3[extra]
  Downloading stable_baselines3-1.1.0-py3-none-any.whl (172 kB)
Collecting torch>=1.4.0
  Downloading torch-1.9.0-cp39-cp39-win_amd64.whl (222.0 MB)
Collecting atari-py~=0.2.0
  Downloading atari_py-0.2.9-cp39-cp39-win_amd64.whl (1.6 MB)
Collecting psutil
  Downloading psutil-5.8.0-cp39-cp39-win_amd64.whl (246 kB)
Installing collected packages: torch, stable-baselines3, psutil, atari-py
Successfully installed atari-py-0.2.9 psutil-5.8.0 stable-baselines3-1.1.0 torch-1.9.0


In [14]:
import os
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

# 2. Load Environment

In [15]:
# Gym environment ID used by every cell that builds an environment.
environment_name = 'CartPole-v0'
# Plain (non-vectorized) environment used for the random-action demo and the
# action/observation space inspection.
env = gym.make(environment_name)

In [3]:
# Baseline: play a few episodes with uniformly sampled actions and print the
# total reward collected in each one.
episodes = 5
try:
    for episode in range(1, episodes+1):
        state = env.reset()
        done = False
        score = 0

        while not done:
            env.render()
            action = env.action_space.sample()
            n_state, reward, done, info = env.step(action)
            score += reward
        print('Episode:{}, Score:{}'.format(episode, score))
finally:
    # Release the environment even if the loop is interrupted or raises,
    # so it is never left open.
    env.close()

Episode:1, Score:18.0
Episode:2, Score:19.0
Episode:3, Score:19.0
Episode:4, Score:12.0
Episode:5, Score:19.0


# Understanding The Environment

In [16]:
# Inspect the action space; the recorded output shows Discrete(2).
env.action_space

Discrete(2)

In [17]:
# Draw a single sample action from the action space.
env.action_space.sample()

1

In [18]:
# Inspect the observation space; the recorded output shows a float32 Box of
# shape (4,) together with its per-dimension low/high bounds.
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [19]:
# Draw a single sample observation from the observation space.
env.observation_space.sample()

array([-7.1619913e-02, -1.8191247e+38, -2.1800922e-01, -2.5326838e+38],
      dtype=float32)

# 3. Train an RL Model

In [20]:
# Relative directory passed to the models as their `tensorboard_log` location.
log_path = os.path.join('Training', 'Logs')

In [21]:
# Build the vectorized training environment. The factory creates the gym
# environment itself rather than capturing an outer variable, so it can never
# resolve to the DummyVecEnv wrapper that `env` is bound to afterwards.
env = DummyVecEnv([lambda: gym.make(environment_name)])
# PPO agent with an MLP policy; training metrics are logged under log_path.
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cuda device


In [9]:
# Train the PPO agent with a budget of 20,000 timesteps. The call returns the
# model, which is why the cell output ends with the PPO object's repr.
model.learn(total_timesteps=20000)

Logging to Training\Logs\PPO_1
-----------------------------
| time/              |      |
|    fps             | 722  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 593         |
|    iterations           | 2           |
|    time_elapsed         | 6           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008631874 |
|    clip_fraction        | 0.0807      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | 0.0024      |
|    learning_rate        | 0.0003      |
|    loss                 | 6.1         |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0117     |
|    value_loss           | 41.9        |
-----------------------------------------
---

<stable_baselines3.ppo.ppo.PPO at 0x1de4dde2700>

# 4. Save and Reload Model

In [22]:
# Relative path used to save and reload the trained PPO model.
PPO_path = os.path.join('Training', 'Saved Models', 'PPO_Model_CartPole')

In [11]:
# Persist the trained model to PPO_path.
model.save(PPO_path)

In [12]:
# Drop the in-memory model so the next cell demonstrably reloads it from disk.
del model

In [13]:
# Reload the saved model and attach the existing vectorized environment to it.
model = PPO.load(PPO_path, env=env)

In [14]:
# Continue training the reloaded model.
# NOTE(review): the recorded log reports total_timesteps 2048 although 1000 was
# requested — presumably PPO always collects a full rollout batch; confirm
# against the model's n_steps setting.
model.learn(total_timesteps=1000)

Logging to Training\Logs\PPO_6
-----------------------------
| time/              |      |
|    fps             | 841  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------


<stable_baselines3.ppo.ppo.PPO at 0x11663de93d0>

# 5. Evaluation

In [16]:
# Evaluate the model over 10 episodes while rendering the environment.
# The recorded output (200.0, 0.0) is presumably (mean reward, std of reward)
# — confirm against the evaluate_policy documentation.
evaluate_policy(model, env, n_eval_episodes=10, render=True)

(200.0, 0.0)

In [17]:
# Close the environment after the rendered evaluation.
env.close()

# 6. Test Model

In [20]:
# Watch the trained agent play. `env` is the single-environment DummyVecEnv
# here, so obs/reward/done are length-1 arrays; this is why the score prints
# as e.g. `[200.]` and why `not done` works as a loop condition.
episodes = 5
try:
    for episode in range(1, episodes+1):
        obs = env.reset()
        done = False
        score = 0

        while not done:
            env.render()
            # Pick the policy's most likely action instead of sampling, so the
            # reported scores reflect the learned policy without action noise
            # (consistent with how evaluate_policy scores the model by default).
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)
            score += reward
        print('Episode:{}, Score:{}'.format(episode, score))
finally:
    # Release the environment even if the loop is interrupted or raises.
    env.close()

Episode:1, Score:[200.]
Episode:2, Score:[200.]
Episode:3, Score:[200.]
Episode:4, Score:[200.]
Episode:5, Score:[200.]


# 7. Viewing Logs in Tensorboard

In [23]:
# Log directory of one specific run; 'PPO_1' matches the run name reported by
# the first training cell's "Logging to ..." line.
training_log_path = os.path.join(log_path, 'PPO_1')

In [13]:
# Launch TensorBoard on that run's log directory via the shell.
# NOTE(review): this presumably blocks the cell until it is interrupted (the
# recorded output is just ^C) — confirm before including it in a Run All.
!tensorboard --logdir={training_log_path}

^C


# 8. Adding a Callback to the Training Stage

In [25]:
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

In [27]:
# Directory handed to EvalCallback as `best_model_save_path`.
save_path = os.path.join('Training', 'Saved Models')

In [29]:
# Evaluate on a dedicated environment: running evaluation episodes on the
# training environment would reset it in the middle of rollout collection.
eval_env = DummyVecEnv([lambda: gym.make(environment_name)])

# Ends training once a new best mean evaluation reward reaches the threshold.
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)
# Periodic evaluation (eval_freq=10000): saves the best model under save_path
# and invokes stop_callback whenever a new best mean reward is found.
eval_callback = EvalCallback(eval_env,
                            callback_on_new_best=stop_callback,
                            eval_freq=10000,
                            best_model_save_path=save_path,
                            verbose=1)

In [30]:
# Fresh PPO agent (replaces the previous `model`) to be trained with the callback.
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cuda device


In [31]:
# Train with periodic evaluation; the callback may end training before the
# 20,000-timestep budget is used up if the reward threshold is reached.
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to Training\Logs\PPO_2
-----------------------------
| time/              |      |
|    fps             | 798  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 606          |
|    iterations           | 2            |
|    time_elapsed         | 6            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0078354925 |
|    clip_fraction        | 0.0973       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.686       |
|    explained_variance   | 0.000879     |
|    learning_rate        | 0.0003       |
|    loss                 | 6.03         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.0145      |
|    value_loss           | 51.3         |
----------------------------



Eval num_timesteps=10000, episode_reward=176.20 +/- 28.06
Episode length: 176.20 +/- 28.06
------------------------------------------
| eval/                   |              |
|    mean_ep_length       | 176          |
|    mean_reward          | 176          |
| time/                   |              |
|    total timesteps      | 10000        |
| train/                  |              |
|    approx_kl            | 0.0084099155 |
|    clip_fraction        | 0.0901       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.607       |
|    explained_variance   | 0.346        |
|    learning_rate        | 0.0003       |
|    loss                 | 21.2         |
|    n_updates            | 40           |
|    policy_gradient_loss | -0.0195      |
|    value_loss           | 64.4         |
------------------------------------------
New best mean reward!
------------------------------
| time/              |       |
|    fps             | 487   |
|    iterations     

<stable_baselines3.ppo.ppo.PPO at 0x1de9041ba90>

# 9. Changing Policies

In [33]:
# Custom network layout: four hidden layers of 128 units each for the policy
# head (`pi`) and, separately, for the value head (`vf`).
net_arch = [{'pi': [128] * 4, 'vf': [128] * 4}]

In [34]:
# PPO agent whose policy is built with the custom net_arch via policy_kwargs.
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path, policy_kwargs={'net_arch':net_arch})

Using cuda device


In [35]:
# Train the custom-architecture agent with periodic evaluation.
# NOTE(review): eval_callback is the same instance already used to train the
# previous model; presumably its internal state (e.g. best mean reward so far)
# carries over — confirm, and consider building a fresh callback per model.
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to Training\Logs\PPO_3
-----------------------------
| time/              |      |
|    fps             | 697  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 516         |
|    iterations           | 2           |
|    time_elapsed         | 7           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.015462247 |
|    clip_fraction        | 0.181       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.682      |
|    explained_variance   | -0.012      |
|    learning_rate        | 0.0003      |
|    loss                 | 2.01        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0181     |
|    value_loss           | 18.7        |
-----------------------------------------
---



Eval num_timesteps=10000, episode_reward=200.00 +/- 0.00
Episode length: 200.00 +/- 0.00
-----------------------------------------
| eval/                   |             |
|    mean_ep_length       | 200         |
|    mean_reward          | 200         |
| time/                   |             |
|    total timesteps      | 10000       |
| train/                  |             |
|    approx_kl            | 0.013740765 |
|    clip_fraction        | 0.167       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.571      |
|    explained_variance   | 0.284       |
|    learning_rate        | 0.0003      |
|    loss                 | 11.3        |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.0215     |
|    value_loss           | 37.4        |
-----------------------------------------
------------------------------
| time/              |       |
|    fps             | 419   |
|    iterations      | 5     |
|    time_elapsed    | 24    |


<stable_baselines3.ppo.ppo.PPO at 0x1de96987e80>

# 10. Using an Alternate Algorithm

In [36]:
from stable_baselines3 import DQN

In [37]:
# Swap the algorithm: DQN agent with an MLP policy on the same environment,
# logging under the same log_path.
model = DQN('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cuda device


In [38]:
# Train the DQN agent with a budget of 20,000 timesteps (no evaluation callback here).
model.learn(total_timesteps=20000)

Logging to Training\Logs\DQN_1
----------------------------------
| rollout/            |          |
|    exploration rate | 0.97     |
| time/               |          |
|    episodes         | 4        |
|    fps              | 6416     |
|    time_elapsed     | 0        |
|    total timesteps  | 64       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.931    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 8076     |
|    time_elapsed     | 0        |
|    total timesteps  | 145      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.895    |
| time/               |          |
|    episodes         | 12       |
|    fps              | 9233     |
|    time_elapsed     | 0        |
|    total timesteps  | 221      |
----------------------------------
------------------------

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 108      |
|    fps              | 10127    |
|    time_elapsed     | 0        |
|    total timesteps  | 2341     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 112      |
|    fps              | 10153    |
|    time_elapsed     | 0        |
|    total timesteps  | 2428     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 116      |
|    fps              | 10265    |
|    time_elapsed     | 0        |
|    total timesteps  | 2557     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 216      |
|    fps              | 10875    |
|    time_elapsed     | 0        |
|    total timesteps  | 4911     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 220      |
|    fps              | 10907    |
|    time_elapsed     | 0        |
|    total timesteps  | 5034     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 224      |
|    fps              | 10908    |
|    time_elapsed     | 0        |
|    total timesteps  | 5111     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 324      |
|    fps              | 10947    |
|    time_elapsed     | 0        |
|    total timesteps  | 7299     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 328      |
|    fps              | 10962    |
|    time_elapsed     | 0        |
|    total timesteps  | 7407     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 332      |
|    fps              | 10945    |
|    time_elapsed     | 0        |
|    total timesteps  | 7472     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 432      |
|    fps              | 11032    |
|    time_elapsed     | 0        |
|    total timesteps  | 9583     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 436      |
|    fps              | 11003    |
|    time_elapsed     | 0        |
|    total timesteps  | 9645     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 440      |
|    fps              | 10992    |
|    time_elapsed     | 0        |
|    total timesteps  | 9712     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 540      |
|    fps              | 10807    |
|    time_elapsed     | 1        |
|    total timesteps  | 11897    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 544      |
|    fps              | 10813    |
|    time_elapsed     | 1        |
|    total timesteps  | 12011    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 548      |
|    fps              | 10784    |
|    time_elapsed     | 1        |
|    total timesteps  | 12065    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 648      |
|    fps              | 10740    |
|    time_elapsed     | 1        |
|    total timesteps  | 14470    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 652      |
|    fps              | 10755    |
|    time_elapsed     | 1        |
|    total timesteps  | 14608    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 656      |
|    fps              | 10750    |
|    time_elapsed     | 1        |
|    total timesteps  | 14665    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 756      |
|    fps              | 10806    |
|    time_elapsed     | 1        |
|    total timesteps  | 16908    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 760      |
|    fps              | 10806    |
|    time_elapsed     | 1        |
|    total timesteps  | 16995    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 764      |
|    fps              | 10777    |
|    time_elapsed     | 1        |
|    total timesteps  | 17089    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 864      |
|    fps              | 10302    |
|    time_elapsed     | 1        |
|    total timesteps  | 19359    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 868      |
|    fps              | 10288    |
|    time_elapsed     | 1        |
|    total timesteps  | 19456    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 872      |
|    fps              | 10274    |
|    time_elapsed     | 1        |
|    total timesteps  | 19522    |
----------------------------------
----------------------------------
| rollout/          

<stable_baselines3.dqn.dqn.DQN at 0x1de969987c0>