# Import Packages


In [1]:
import os;
import gym;
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

# Load Environment


In [2]:
environment_name ='CartPole-v1'
env=gym.make(environment_name)


In [3]:
environment_name

'CartPole-v1'

In [5]:
episodes=5
for episode in range(1,episodes+1):
    state=env.reset()
    done=False
    score=0
    while not done:
        env.render()
        action=env.action_space.sample()
        n_state,reward,done,info=env.step(action)
        score+=reward
    print('Episode:{} Score{}'.format(episode,score))
env.close()
        

Episode:1 Score17.0
Episode:2 Score17.0
Episode:3 Score16.0
Episode:4 Score10.0
Episode:5 Score17.0


# Train Model


In [4]:
#make dir in path
log_path=os.path.join('Training','Logs')

In [5]:
log_path

'Training/Logs'

In [6]:
env=gym.make(environment_name)
env=DummyVecEnv([lambda:env])
model=PPO('MlpPolicy',env,verbose=1,tensorboard_log=log_path)

Using cpu device


In [7]:
model.learn(total_timesteps=20000)

Logging to Training/Logs/PPO_6
-----------------------------
| time/              |      |
|    fps             | 858  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1357        |
|    iterations           | 2           |
|    time_elapsed         | 3           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008344904 |
|    clip_fraction        | 0.0825      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.687      |
|    explained_variance   | -0.00299    |
|    learning_rate        | 0.0003      |
|    loss                 | 10.9        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0127     |
|    value_loss           | 56          |
-----------------------------------------
---

<stable_baselines3.ppo.ppo.PPO at 0x2873339a0>

In [8]:
PPO_Path=os.path.join('Training','Saved Models','PPO_Model_CartPole')

In [9]:
model.save(PPO_Path)

In [10]:
evaluate_policy(model,env,n_eval_episodes=10,render=True)



(500.0, 0.0)

# Testing

In [13]:
env.close()


In [15]:
episodes=5
for episode in range(1,episodes+1):
    obs=env.reset()
    done=False
    score=0
    while not done:
        env.render()
        action,_=model.predict(obs)
        n_state,reward,done,info=env.step(action)
        score+=reward
    print('Episode:{} Score{}'.format(episode,score))
env.close()
        

Episode:1 Score[15.]
Episode:2 Score[22.]
Episode:3 Score[15.]
Episode:4 Score[10.]
Episode:5 Score[21.]


# Viewing Logs in TensorBoard


In [16]:
training_log_path=os.path.join(log_path,'PPO_2')

In [17]:
training_log_path

'Training/Logs/PPO_2'

Use this command in terminal 
$tensorboard --logdir=Training/Logs/PPO_2
http://localhost:6006/

# Adding Callback to the training stage

In [23]:
from stable_baselines3.common.callbacks import EvalCallback,StopTrainingOnRewardThreshold

In [27]:
save_path=os.path.join('Training','Saved Models')

In [28]:
stop_callback=StopTrainingOnRewardThreshold(reward_threshold=200,verbose=1)
eval_callback=EvalCallback(env,callback_on_new_best=stop_callback,
                          eval_freq=10000,
                          best_model_save_path=save_path,
                          verbose=1)

In [29]:
model=PPO('MlpPolicy',env,verbose=1,tensorboard_log=log_path)

Using cpu device


In [30]:
model.learn(total_timesteps=20000,callback=eval_callback)

Logging to Training/Logs/PPO_7
-----------------------------
| time/              |      |
|    fps             | 6317 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 4357        |
|    iterations           | 2           |
|    time_elapsed         | 0           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008362288 |
|    clip_fraction        | 0.0862      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | -0.00139    |
|    learning_rate        | 0.0003      |
|    loss                 | 6.48        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0125     |
|    value_loss           | 52.3        |
-----------------------------------------
---



<stable_baselines3.ppo.ppo.PPO at 0x2b59c55e0>

# Changing Policies

In [32]:
net_arch=[dict(pi=[128,128,128,128],vf=[128,128,128,128])]

In [34]:
model=PPO('MlpPolicy',env,verbose=1,tensorboard_log=log_path,policy_kwargs={'net_arch':net_arch})

Using cpu device


In [35]:
model.learn(total_timesteps=20000,callback=eval_callback)

Logging to Training/Logs/PPO_8
-----------------------------
| time/              |      |
|    fps             | 5134 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 3035        |
|    iterations           | 2           |
|    time_elapsed         | 1           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.014861594 |
|    clip_fraction        | 0.222       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.681      |
|    explained_variance   | 0.00236     |
|    learning_rate        | 0.0003      |
|    loss                 | 3.09        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0246     |
|    value_loss           | 18          |
-----------------------------------------
---



------------------------------
| time/              |       |
|    fps             | 2304  |
|    iterations      | 5     |
|    time_elapsed    | 4     |
|    total_timesteps | 10240 |
------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 2276        |
|    iterations           | 6           |
|    time_elapsed         | 5           |
|    total_timesteps      | 12288       |
| train/                  |             |
|    approx_kl            | 0.008786695 |
|    clip_fraction        | 0.111       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.564      |
|    explained_variance   | 0.799       |
|    learning_rate        | 0.0003      |
|    loss                 | 3.5         |
|    n_updates            | 50          |
|    policy_gradient_loss | -0.0118     |
|    value_loss           | 18.8        |
-----------------------------------------
---------------------------

<stable_baselines3.ppo.ppo.PPO at 0x2b55f92e0>

# Using Alternative Algorithm

In [36]:
from stable_baselines3 import DQN

In [37]:
model=DQN('MlpPolicy',env,verbose=1,tensorboard_log=log_path)

Using cpu device


In [38]:
model.learn(total_timesteps=20000)

Logging to Training/Logs/DQN_1
----------------------------------
| rollout/            |          |
|    exploration rate | 0.964    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 5810     |
|    time_elapsed     | 0        |
|    total timesteps  | 75       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.921    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 7588     |
|    time_elapsed     | 0        |
|    total timesteps  | 167      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.88     |
| time/               |          |
|    episodes         | 12       |
|    fps              | 9151     |
|    time_elapsed     | 0        |
|    total timesteps  | 252      |
----------------------------------
------------------------

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 108      |
|    fps              | 21437    |
|    time_elapsed     | 0        |
|    total timesteps  | 2343     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 112      |
|    fps              | 21908    |
|    time_elapsed     | 0        |
|    total timesteps  | 2462     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 116      |
|    fps              | 22145    |
|    time_elapsed     | 0        |
|    total timesteps  | 2536     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 216      |
|    fps              | 26100    |
|    time_elapsed     | 0        |
|    total timesteps  | 4568     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 220      |
|    fps              | 26230    |
|    time_elapsed     | 0        |
|    total timesteps  | 4654     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 224      |
|    fps              | 26296    |
|    time_elapsed     | 0        |
|    total timesteps  | 4719     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 324      |
|    fps              | 28412    |
|    time_elapsed     | 0        |
|    total timesteps  | 6887     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 328      |
|    fps              | 28549    |
|    time_elapsed     | 0        |
|    total timesteps  | 7009     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 332      |
|    fps              | 28591    |
|    time_elapsed     | 0        |
|    total timesteps  | 7075     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 432      |
|    fps              | 29187    |
|    time_elapsed     | 0        |
|    total timesteps  | 9227     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 436      |
|    fps              | 29194    |
|    time_elapsed     | 0        |
|    total timesteps  | 9317     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 440      |
|    fps              | 29177    |
|    time_elapsed     | 0        |
|    total timesteps  | 9435     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 540      |
|    fps              | 29502    |
|    time_elapsed     | 0        |
|    total timesteps  | 11743    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 544      |
|    fps              | 29514    |
|    time_elapsed     | 0        |
|    total timesteps  | 11826    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 548      |
|    fps              | 29474    |
|    time_elapsed     | 0        |
|    total timesteps  | 11934    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 648      |
|    fps              | 30174    |
|    time_elapsed     | 0        |
|    total timesteps  | 14221    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 652      |
|    fps              | 30205    |
|    time_elapsed     | 0        |
|    total timesteps  | 14341    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 656      |
|    fps              | 30222    |
|    time_elapsed     | 0        |
|    total timesteps  | 14415    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 756      |
|    fps              | 30539    |
|    time_elapsed     | 0        |
|    total timesteps  | 16680    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 760      |
|    fps              | 30581    |
|    time_elapsed     | 0        |
|    total timesteps  | 16793    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 764      |
|    fps              | 30591    |
|    time_elapsed     | 0        |
|    total timesteps  | 16864    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 864      |
|    fps              | 30644    |
|    time_elapsed     | 0        |
|    total timesteps  | 18867    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 868      |
|    fps              | 30650    |
|    time_elapsed     | 0        |
|    total timesteps  | 18946    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 872      |
|    fps              | 30611    |
|    time_elapsed     | 0        |
|    total timesteps  | 18999    |
----------------------------------
----------------------------------
| rollout/          

<stable_baselines3.dqn.dqn.DQN at 0x2b5af1ac0>

In [40]:
model.save(save_path)