In [None]:
# https://stable-baselines3.readthedocs.io/en/master/guide/rl.html
# https://spinningup.openai.com/en/latest/spinningup/rl_intro2.html#a-taxonomy-of-rl-algorithms

# 1. Import Dependencies

In [3]:
!pip install stable-baselines3[extra]

Process is interrupted.


In [1]:
import os
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

# 2. Load Environment

In [2]:
ENV_NAME = 'CartPole-v0'
env = gym.make(ENV_NAME)

In [3]:
episodes = 5
for episode in range(1, episodes+1):
    state = env.reset()
    done = False
    score = 0
    
    while not done:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score+= reward
        
    print(f'Episode: {episode} Score: {score}')
env.close()

Episode: 1 Score: 16.0
Episode: 2 Score: 12.0
Episode: 3 Score: 17.0
Episode: 4 Score: 9.0
Episode: 5 Score: 26.0


## Understanding The Environment

https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py

In [10]:
env.action_space # Always Check for choosing support Algorithm

Discrete(2)

In [8]:
# 0-push cart to left, 1-push cart to the right
env.action_space.sample()

1

In [11]:
env.observation_space

Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)

In [9]:
# [cart position, cart velocity, pole angle, pole angular velocity]
env.observation_space.sample()

array([ 4.3913503e+00,  3.3374228e+38,  2.1033937e-01, -2.4968812e+37],
      dtype=float32)

# 3. Train an RL Model

In [4]:
# Make your directory first
log_path = os.path.join('Training', 'Logs')

In [6]:
env = gym.make(ENV_NAME)
env = DummyVecEnv([lambda: env])
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cpu device


In [11]:
model.learn(total_timesteps=20000)

Logging to Training/Logs/PPO_3
-----------------------------
| time/              |      |
|    fps             | 800  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 560         |
|    iterations           | 2           |
|    time_elapsed         | 7           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.006706044 |
|    clip_fraction        | 0.0378      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.555      |
|    explained_variance   | 0.49        |
|    learning_rate        | 0.0003      |
|    loss                 | 66          |
|    n_updates            | 120         |
|    policy_gradient_loss | -0.00239    |
|    value_loss           | 104         |
-----------------------------------------
---

<stable_baselines3.ppo.ppo.PPO at 0x7ff7be8d0550>

# 4. Save and Reload Model

In [8]:
PPO_Path = os.path.join('Training', 'SavedModels', 'PPO_Model_CartPole')

In [12]:
model.save(PPO_Path)

In [13]:
del model

In [14]:
model = PPO.load(PPO_Path, env=env)

# 5. Evaluation

In [15]:
evaluate_policy(model, env, n_eval_episodes=10, render=True)



(200.0, 0.0)

In [16]:
env.close()

# 6. Test Model

In [36]:
episodes = 20
for episode in range(1, episodes+1):
    obs = env.reset()
    done = False
    score = 0
    
    while not done:
        env.render()
        action, _ = model.predict(obs) # Change to our Agent
        obs, reward, done, info = env.step(action)
        score+= reward
    print(f'Episode: {episode} Score: {score}')
env.close()

Episode: 1 Score: [200.]
Episode: 2 Score: [200.]
Episode: 3 Score: [200.]
Episode: 4 Score: [200.]
Episode: 5 Score: [200.]
Episode: 6 Score: [200.]
Episode: 7 Score: [200.]
Episode: 8 Score: [200.]
Episode: 9 Score: [200.]
Episode: 10 Score: [200.]
Episode: 11 Score: [200.]
Episode: 12 Score: [200.]
Episode: 13 Score: [200.]
Episode: 14 Score: [200.]
Episode: 15 Score: [200.]
Episode: 16 Score: [200.]
Episode: 17 Score: [200.]
Episode: 18 Score: [200.]
Episode: 19 Score: [200.]
Episode: 20 Score: [200.]


# 7. Viewing Logs in Tensorboard

In [41]:
training_log_path = os.path.join(log_path, 'PPO_3')

In [42]:
!tensorboard --logdir={training_log_path}

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.5.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C


# 8. Adding a Callback to The Training Stage

In [43]:
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

In [44]:
save_path = os.path.join('Training', 'SavedModels')

In [45]:
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)
eval_callback = EvalCallback(env,
                            callback_on_new_best=stop_callback,
                            eval_freq=10000,
                            best_model_save_path=save_path,
                            verbose=1)

In [46]:
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cpu device


In [47]:
model.learn(total_timesteps=30000, callback=eval_callback)

Logging to Training/Logs/PPO_4
-----------------------------
| time/              |      |
|    fps             | 770  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 526          |
|    iterations           | 2            |
|    time_elapsed         | 7            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0025689472 |
|    clip_fraction        | 0.0761       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.687       |
|    explained_variance   | 0.000449     |
|    learning_rate        | 0.0003       |
|    loss                 | 7.69         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.0116      |
|    value_loss           | 54.1         |
----------------------------



Eval num_timesteps=10000, episode_reward=200.00 +/- 0.00
Episode length: 200.00 +/- 0.00
New best mean reward!
Stopping training because the mean reward 200.00  is above the threshold 200


<stable_baselines3.ppo.ppo.PPO at 0x7ff7bc68f9d0>

# 9. Changing Policies

In [49]:
net_arch = [dict(pi=[128,128,128,128], vf=[128,128,128,128])]

In [51]:
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path, policy_kwargs={'net_arch':net_arch})

Using cpu device


In [52]:
model.learn(total_timesteps=30000, callback=eval_callback)

Logging to Training/Logs/PPO_5
-----------------------------
| time/              |      |
|    fps             | 429  |
|    iterations      | 1    |
|    time_elapsed    | 4    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 282         |
|    iterations           | 2           |
|    time_elapsed         | 14          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.017277732 |
|    clip_fraction        | 0.225       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.682      |
|    explained_variance   | -0.00328    |
|    learning_rate        | 0.0003      |
|    loss                 | 2.28        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.021      |
|    value_loss           | 16.9        |
-----------------------------------------
---



Eval num_timesteps=10000, episode_reward=192.60 +/- 14.80
Episode length: 192.60 +/- 14.80
----------------------------------------
| eval/                   |            |
|    mean_ep_length       | 193        |
|    mean_reward          | 193        |
| time/                   |            |
|    fps                  | 234        |
|    iterations           | 5          |
|    time_elapsed         | 43         |
|    total_timesteps      | 10240      |
| train/                  |            |
|    approx_kl            | 0.00945996 |
|    clip_fraction        | 0.131      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.581     |
|    explained_variance   | 0.496      |
|    learning_rate        | 0.0003     |
|    loss                 | 13.7       |
|    n_updates            | 40         |
|    policy_gradient_loss | -0.0222    |
|    value_loss           | 45.6       |
----------------------------------------
-----------------------------------------
| time

<stable_baselines3.ppo.ppo.PPO at 0x7ff7bc56f6d0>

# 10. Using an Alternate Algorithm

In [53]:
from stable_baselines3 import DQN

In [54]:
model = DQN('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cpu device


In [55]:
model.learn(total_timesteps=30000, callback=eval_callback)

Logging to Training/Logs/DQN_1
----------------------------------
| rollout/            |          |
|    exploration rate | 0.982    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 1999     |
|    time_elapsed     | 0        |
|    total timesteps  | 57       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.959    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 2700     |
|    time_elapsed     | 0        |
|    total timesteps  | 130      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.935    |
| time/               |          |
|    episodes         | 12       |
|    fps              | 2880     |
|    time_elapsed     | 0        |
|    total timesteps  | 206      |
----------------------------------
------------------------

----------------------------------
| rollout/            |          |
|    exploration rate | 0.225    |
| time/               |          |
|    episodes         | 108      |
|    fps              | 3502     |
|    time_elapsed     | 0        |
|    total timesteps  | 2448     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.198    |
| time/               |          |
|    episodes         | 112      |
|    fps              | 3518     |
|    time_elapsed     | 0        |
|    total timesteps  | 2534     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.172    |
| time/               |          |
|    episodes         | 116      |
|    fps              | 3535     |
|    time_elapsed     | 0        |
|    total timesteps  | 2616     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 216      |
|    fps              | 3065     |
|    time_elapsed     | 1        |
|    total timesteps  | 5095     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 220      |
|    fps              | 3070     |
|    time_elapsed     | 1        |
|    total timesteps  | 5184     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 224      |
|    fps              | 3073     |
|    time_elapsed     | 1        |
|    total timesteps  | 5243     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 324      |
|    fps              | 3247     |
|    time_elapsed     | 2        |
|    total timesteps  | 7507     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 328      |
|    fps              | 3258     |
|    time_elapsed     | 2        |
|    total timesteps  | 7597     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 332      |
|    fps              | 3277     |
|    time_elapsed     | 2        |
|    total timesteps  | 7718     |
----------------------------------
----------------------------------
| rollout/          



----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 424      |
|    fps              | 3419     |
|    time_elapsed     | 2        |
|    total timesteps  | 9958     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 428      |
|    fps              | 3384     |
|    time_elapsed     | 2        |
|    total timesteps  | 10056    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 432      |
|    fps              | 3394     |
|    time_elapsed     | 2        |
|    total timesteps  | 10151    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 532      |
|    fps              | 3572     |
|    time_elapsed     | 3        |
|    total timesteps  | 12695    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 536      |
|    fps              | 3575     |
|    time_elapsed     | 3        |
|    total timesteps  | 12776    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 540      |
|    fps              | 3579     |
|    time_elapsed     | 3        |
|    total timesteps  | 12842    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 640      |
|    fps              | 3703     |
|    time_elapsed     | 4        |
|    total timesteps  | 15311    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 644      |
|    fps              | 3706     |
|    time_elapsed     | 4        |
|    total timesteps  | 15385    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 648      |
|    fps              | 3710     |
|    time_elapsed     | 4        |
|    total timesteps  | 15461    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 748      |
|    fps              | 3789     |
|    time_elapsed     | 4        |
|    total timesteps  | 17685    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 752      |
|    fps              | 3791     |
|    time_elapsed     | 4        |
|    total timesteps  | 17767    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 756      |
|    fps              | 3796     |
|    time_elapsed     | 4        |
|    total timesteps  | 17867    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 852      |
|    fps              | 3825     |
|    time_elapsed     | 5        |
|    total timesteps  | 19897    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 856      |
|    fps              | 3828     |
|    time_elapsed     | 5        |
|    total timesteps  | 20022    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 860      |
|    fps              | 3831     |
|    time_elapsed     | 5        |
|    total timesteps  | 20106    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 960      |
|    fps              | 3881     |
|    time_elapsed     | 5        |
|    total timesteps  | 22357    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 964      |
|    fps              | 3882     |
|    time_elapsed     | 5        |
|    total timesteps  | 22413    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 968      |
|    fps              | 3889     |
|    time_elapsed     | 5        |
|    total timesteps  | 22600    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 1068     |
|    fps              | 3880     |
|    time_elapsed     | 6        |
|    total timesteps  | 24974    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 1072     |
|    fps              | 3874     |
|    time_elapsed     | 6        |
|    total timesteps  | 25093    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 1076     |
|    fps              | 3861     |
|    time_elapsed     | 6        |
|    total timesteps  | 25164    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 1176     |
|    fps              | 3844     |
|    time_elapsed     | 7        |
|    total timesteps  | 27342    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 1180     |
|    fps              | 3843     |
|    time_elapsed     | 7        |
|    total timesteps  | 27411    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 1184     |
|    fps              | 3846     |
|    time_elapsed     | 7        |
|    total timesteps  | 27531    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 1280     |
|    fps              | 3796     |
|    time_elapsed     | 7        |
|    total timesteps  | 29810    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 1284     |
|    fps              | 3794     |
|    time_elapsed     | 7        |
|    total timesteps  | 29911    |
----------------------------------


<stable_baselines3.dqn.dqn.DQN at 0x7ff7bc60a310>

In [60]:
dqn_path = os.path.join('Training', 'Saved Models', 'DQN_model')

In [None]:
model.save(dqn_path)

In [None]:
model = DQN.load(dqn_path, env=env)

In [None]:
evaluate_policy(model, env, n_eval_episodes=10, render=True)

In [None]:
env.close()