# 1. Import Dependencies

In [1]:
import gym 
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_atari_env
import os

# 2. Test Environment

To install the Atari environments:
- download the ROM archive from http://www.atarimania.com/roms/Roms.rar
- extract the archive (this produces two folders: `ROMS` and `HC ROMS`)
- install atari_py: `pip3 install atari-py`
- import the ROMs: `python -m atari_py.import_roms ./ROMS`

In [2]:
# Gym id of the Atari game to load.
environment_name = "Breakout-v0"
# Single, non-vectorized environment used by the exploration cells that follow.
env = gym.make(environment_name)

In [3]:
# Reset the environment and show the shape of the observation it returns.
x = env.reset()
x.shape

(210, 160, 3)

In [4]:
# Draw ten random actions on one line to see which values the action space
# yields, then display the action space itself as the cell output.
print(*(env.action_space.sample() for _ in range(10)), end=' ')
env.action_space

2 2 2 2 1 3 2 2 1 3 

Discrete(4)

In [5]:
# Display the observation space (bounds and shape of the frames the env emits).
env.observation_space

Box([[[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]

 ...

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  ...
  [0 0 0]
  [0 0 0]
  [0 0 0]]], [[[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
  [255 255 255]
  [255 255 255]
  [255 255 255]]

 [[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
  [255 255 255]
  [255 255 255]
  [255 255 255]]

 [[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
  [255 255 255]
  [255 255 255]
  [255 255 255]]

 ...

 [[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
  [255 255 255]
  [255 255 255]
  [255 255 255]]

 [[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
  [255 255 255]
  [255 255 255]
  [255 255 255]]

 [[255 255 255]
  [255 255 255]
  [255 255 255]
  ...
 

In [6]:
# Play a few episodes with uniformly random actions and print each episode's
# total reward, as a baseline to compare the trained agent against.
episodes = 5
try:
    for episode in range(1, episodes+1):
        env.reset()
        done = False
        score = 0

        while not done:
            env.render()
            action = env.action_space.sample()
            # Only the reward and the termination flag are used here; the
            # observation and info dict are intentionally ignored.
            _obs, reward, done, _info = env.step(action)
            score += reward
        print('Episode:{} Score:{}'.format(episode, score))
finally:
    # Close the render window even if the loop is interrupted or a step fails.
    env.close()

Episode:1 Score:3.0
Episode:2 Score:0.0
Episode:3 Score:1.0
Episode:4 Score:2.0
Episode:5 Score:1.0


# 3. Vectorize Environment and Train Model

In [7]:
# Train on 4 copies of the environment at once: vectorizing several
# environments lets the agent collect experience from all of them per step,
# which speeds up training. Observations are then stacked 4 frames deep.
env = VecFrameStack(
    make_atari_env('Breakout-v0', n_envs=4, seed=0),
    n_stack=4,
)

In [8]:
# Relative directory that is handed to A2C as its tensorboard_log location.
log_path = os.path.join('Training', 'Logs')

In [9]:
# A2C agent using the "CnnPolicy" policy on the vectorized env; verbose=1 prints
# training statistics and tensorboard_log points the logger at log_path.
model = A2C("CnnPolicy", env, verbose=1, tensorboard_log=log_path)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [11]:
# Train the agent for 500,000 environment timesteps (long-running cell).
# NOTE(review): the recorded output reports n_updates=976 at iterations=100 and
# logs to A2C_2, so this model had presumably already been trained by an earlier
# learn() call; a fresh Restart & Run All may not reproduce these numbers.
model.learn(total_timesteps=500000)

Logging to Training/Logs/A2C_2
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 295      |
|    ep_rew_mean        | 1.86     |
| time/                 |          |
|    fps                | 363      |
|    iterations         | 100      |
|    time_elapsed       | 5        |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -0.824   |
|    explained_variance | 0.961    |
|    learning_rate      | 0.0007   |
|    n_updates          | 976      |
|    policy_loss        | -0.00768 |
|    value_loss         | 0.00891  |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 301      |
|    ep_rew_mean        | 1.96     |
| time/                 |          |
|    fps                | 366      |
|    iterations         | 200      |
|    time_elapsed       | 10       |
|    total_timesteps    | 4000     |
| train

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 419      |
|    ep_rew_mean        | 4.55     |
| time/                 |          |
|    fps                | 366      |
|    iterations         | 1400     |
|    time_elapsed       | 76       |
|    total_timesteps    | 28000    |
| train/                |          |
|    entropy_loss       | -0.829   |
|    explained_variance | 0.944    |
|    learning_rate      | 0.0007   |
|    n_updates          | 2276     |
|    policy_loss        | -0.0247  |
|    value_loss         | 0.0539   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 427      |
|    ep_rew_mean        | 4.87     |
| time/                 |          |
|    fps                | 368      |
|    iterations         | 1500     |
|    time_elapsed       | 81       |
|    total_timesteps    | 30000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 444      |
|    ep_rew_mean        | 5.2      |
| time/                 |          |
|    fps                | 379      |
|    iterations         | 2800     |
|    time_elapsed       | 147      |
|    total_timesteps    | 56000    |
| train/                |          |
|    entropy_loss       | -1.01    |
|    explained_variance | 0.795    |
|    learning_rate      | 0.0007   |
|    n_updates          | 3676     |
|    policy_loss        | 0.0671   |
|    value_loss         | 0.127    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 447      |
|    ep_rew_mean        | 5.28     |
| time/                 |          |
|    fps                | 379      |
|    iterations         | 2900     |
|    time_elapsed       | 152      |
|    total_timesteps    | 58000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 489      |
|    ep_rew_mean        | 5.99     |
| time/                 |          |
|    fps                | 383      |
|    iterations         | 4100     |
|    time_elapsed       | 213      |
|    total_timesteps    | 82000    |
| train/                |          |
|    entropy_loss       | -0.176   |
|    explained_variance | -0.432   |
|    learning_rate      | 0.0007   |
|    n_updates          | 4976     |
|    policy_loss        | 0.0144   |
|    value_loss         | 0.075    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 491      |
|    ep_rew_mean        | 5.9      |
| time/                 |          |
|    fps                | 383      |
|    iterations         | 4200     |
|    time_elapsed       | 218      |
|    total_timesteps    | 84000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 541      |
|    ep_rew_mean        | 7.08     |
| time/                 |          |
|    fps                | 386      |
|    iterations         | 5500     |
|    time_elapsed       | 284      |
|    total_timesteps    | 110000   |
| train/                |          |
|    entropy_loss       | -0.356   |
|    explained_variance | 0.773    |
|    learning_rate      | 0.0007   |
|    n_updates          | 6376     |
|    policy_loss        | -0.032   |
|    value_loss         | 0.0785   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 550      |
|    ep_rew_mean        | 7.25     |
| time/                 |          |
|    fps                | 387      |
|    iterations         | 5600     |
|    time_elapsed       | 289      |
|    total_timesteps    | 112000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 571      |
|    ep_rew_mean        | 7.59     |
| time/                 |          |
|    fps                | 389      |
|    iterations         | 6900     |
|    time_elapsed       | 354      |
|    total_timesteps    | 138000   |
| train/                |          |
|    entropy_loss       | -0.45    |
|    explained_variance | 0.864    |
|    learning_rate      | 0.0007   |
|    n_updates          | 7776     |
|    policy_loss        | 0.0465   |
|    value_loss         | 0.162    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 572      |
|    ep_rew_mean        | 7.64     |
| time/                 |          |
|    fps                | 389      |
|    iterations         | 7000     |
|    time_elapsed       | 359      |
|    total_timesteps    | 140000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 588      |
|    ep_rew_mean        | 8.06     |
| time/                 |          |
|    fps                | 391      |
|    iterations         | 8300     |
|    time_elapsed       | 424      |
|    total_timesteps    | 166000   |
| train/                |          |
|    entropy_loss       | -0.145   |
|    explained_variance | 0.293    |
|    learning_rate      | 0.0007   |
|    n_updates          | 9176     |
|    policy_loss        | -0.00913 |
|    value_loss         | 0.0865   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 592      |
|    ep_rew_mean        | 8.05     |
| time/                 |          |
|    fps                | 391      |
|    iterations         | 8400     |
|    time_elapsed       | 428      |
|    total_timesteps    | 168000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 636      |
|    ep_rew_mean        | 9.15     |
| time/                 |          |
|    fps                | 393      |
|    iterations         | 9700     |
|    time_elapsed       | 492      |
|    total_timesteps    | 194000   |
| train/                |          |
|    entropy_loss       | -0.0451  |
|    explained_variance | 0.76     |
|    learning_rate      | 0.0007   |
|    n_updates          | 10576    |
|    policy_loss        | 0.00148  |
|    value_loss         | 0.119    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 637      |
|    ep_rew_mean        | 9.26     |
| time/                 |          |
|    fps                | 394      |
|    iterations         | 9800     |
|    time_elapsed       | 497      |
|    total_timesteps    | 196000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 656      |
|    ep_rew_mean        | 9.79     |
| time/                 |          |
|    fps                | 396      |
|    iterations         | 11100    |
|    time_elapsed       | 560      |
|    total_timesteps    | 222000   |
| train/                |          |
|    entropy_loss       | -0.315   |
|    explained_variance | 0.917    |
|    learning_rate      | 0.0007   |
|    n_updates          | 11976    |
|    policy_loss        | 0.0203   |
|    value_loss         | 0.116    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 670      |
|    ep_rew_mean        | 10       |
| time/                 |          |
|    fps                | 396      |
|    iterations         | 11200    |
|    time_elapsed       | 565      |
|    total_timesteps    | 224000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 699      |
|    ep_rew_mean        | 10.9     |
| time/                 |          |
|    fps                | 398      |
|    iterations         | 12500    |
|    time_elapsed       | 627      |
|    total_timesteps    | 250000   |
| train/                |          |
|    entropy_loss       | -0.0454  |
|    explained_variance | 0.417    |
|    learning_rate      | 0.0007   |
|    n_updates          | 13376    |
|    policy_loss        | 0.00423  |
|    value_loss         | 0.064    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 691      |
|    ep_rew_mean        | 10.9     |
| time/                 |          |
|    fps                | 398      |
|    iterations         | 12600    |
|    time_elapsed       | 632      |
|    total_timesteps    | 252000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 665      |
|    ep_rew_mean        | 10.2     |
| time/                 |          |
|    fps                | 399      |
|    iterations         | 13900    |
|    time_elapsed       | 695      |
|    total_timesteps    | 278000   |
| train/                |          |
|    entropy_loss       | -0.151   |
|    explained_variance | 0.364    |
|    learning_rate      | 0.0007   |
|    n_updates          | 14776    |
|    policy_loss        | -0.125   |
|    value_loss         | 0.499    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 674      |
|    ep_rew_mean        | 10.2     |
| time/                 |          |
|    fps                | 399      |
|    iterations         | 14000    |
|    time_elapsed       | 700      |
|    total_timesteps    | 280000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 719      |
|    ep_rew_mean        | 11.4     |
| time/                 |          |
|    fps                | 401      |
|    iterations         | 15300    |
|    time_elapsed       | 762      |
|    total_timesteps    | 306000   |
| train/                |          |
|    entropy_loss       | -0.105   |
|    explained_variance | 0.741    |
|    learning_rate      | 0.0007   |
|    n_updates          | 16176    |
|    policy_loss        | 0.0347   |
|    value_loss         | 0.104    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 713      |
|    ep_rew_mean        | 11.1     |
| time/                 |          |
|    fps                | 401      |
|    iterations         | 15400    |
|    time_elapsed       | 767      |
|    total_timesteps    | 308000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 657      |
|    ep_rew_mean        | 9.9      |
| time/                 |          |
|    fps                | 401      |
|    iterations         | 16700    |
|    time_elapsed       | 831      |
|    total_timesteps    | 334000   |
| train/                |          |
|    entropy_loss       | -0.144   |
|    explained_variance | 0.623    |
|    learning_rate      | 0.0007   |
|    n_updates          | 17576    |
|    policy_loss        | -0.0189  |
|    value_loss         | 0.181    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 651      |
|    ep_rew_mean        | 9.77     |
| time/                 |          |
|    fps                | 401      |
|    iterations         | 16800    |
|    time_elapsed       | 836      |
|    total_timesteps    | 336000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 700      |
|    ep_rew_mean        | 11.3     |
| time/                 |          |
|    fps                | 402      |
|    iterations         | 18100    |
|    time_elapsed       | 899      |
|    total_timesteps    | 362000   |
| train/                |          |
|    entropy_loss       | -0.238   |
|    explained_variance | -1.7     |
|    learning_rate      | 0.0007   |
|    n_updates          | 18976    |
|    policy_loss        | 0.151    |
|    value_loss         | 0.185    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 709      |
|    ep_rew_mean        | 11.5     |
| time/                 |          |
|    fps                | 402      |
|    iterations         | 18200    |
|    time_elapsed       | 904      |
|    total_timesteps    | 364000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 718      |
|    ep_rew_mean        | 11.5     |
| time/                 |          |
|    fps                | 403      |
|    iterations         | 19500    |
|    time_elapsed       | 967      |
|    total_timesteps    | 390000   |
| train/                |          |
|    entropy_loss       | -0.105   |
|    explained_variance | 0.871    |
|    learning_rate      | 0.0007   |
|    n_updates          | 20376    |
|    policy_loss        | 0.00264  |
|    value_loss         | 0.028    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 715      |
|    ep_rew_mean        | 11.4     |
| time/                 |          |
|    fps                | 403      |
|    iterations         | 19600    |
|    time_elapsed       | 972      |
|    total_timesteps    | 392000   |
| train/                |          |
|

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 736       |
|    ep_rew_mean        | 11.8      |
| time/                 |           |
|    fps                | 403       |
|    iterations         | 20900     |
|    time_elapsed       | 1035      |
|    total_timesteps    | 418000    |
| train/                |           |
|    entropy_loss       | -0.216    |
|    explained_variance | 0.866     |
|    learning_rate      | 0.0007    |
|    n_updates          | 21776     |
|    policy_loss        | -6.09e-05 |
|    value_loss         | 0.033     |
-------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 737      |
|    ep_rew_mean        | 11.8     |
| time/                 |          |
|    fps                | 403      |
|    iterations         | 21000    |
|    time_elapsed       | 1040     |
|    total_timesteps    | 420000   |
| train/             

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 775      |
|    ep_rew_mean        | 13.1     |
| time/                 |          |
|    fps                | 404      |
|    iterations         | 22200    |
|    time_elapsed       | 1098     |
|    total_timesteps    | 444000   |
| train/                |          |
|    entropy_loss       | -0.142   |
|    explained_variance | 0.394    |
|    learning_rate      | 0.0007   |
|    n_updates          | 23076    |
|    policy_loss        | 0.0331   |
|    value_loss         | 0.109    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 787      |
|    ep_rew_mean        | 13.4     |
| time/                 |          |
|    fps                | 404      |
|    iterations         | 22300    |
|    time_elapsed       | 1103     |
|    total_timesteps    | 446000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 764      |
|    ep_rew_mean        | 12.8     |
| time/                 |          |
|    fps                | 404      |
|    iterations         | 23600    |
|    time_elapsed       | 1165     |
|    total_timesteps    | 472000   |
| train/                |          |
|    entropy_loss       | -0.164   |
|    explained_variance | 0.873    |
|    learning_rate      | 0.0007   |
|    n_updates          | 24476    |
|    policy_loss        | 0.035    |
|    value_loss         | 0.0814   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 758      |
|    ep_rew_mean        | 12.6     |
| time/                 |          |
|    fps                | 404      |
|    iterations         | 23700    |
|    time_elapsed       | 1170     |
|    total_timesteps    | 474000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 762      |
|    ep_rew_mean        | 12.7     |
| time/                 |          |
|    fps                | 405      |
|    iterations         | 25000    |
|    time_elapsed       | 1232     |
|    total_timesteps    | 500000   |
| train/                |          |
|    entropy_loss       | -0.203   |
|    explained_variance | 0.171    |
|    learning_rate      | 0.0007   |
|    n_updates          | 25876    |
|    policy_loss        | 0.0178   |
|    value_loss         | 0.0761   |
------------------------------------


<stable_baselines3.a2c.a2c.A2C at 0x7f1b9be5a640>

# 4. Save and Reload Model

In [12]:
# Relative path of the saved model; reused later when loading it back.
a2c_path = os.path.join('Training', 'Saved Models', 'A2C_model')
# Persist the trained agent to disk.
model.save(a2c_path)

In [13]:
# Remove the in-memory model so the following cells reload it from disk.
del model

In [14]:
# evaluate_policy() is given a single environment here, so rebuild the
# environment with n_envs=1 and the same 4-frame stacking used for training.
env = VecFrameStack(
    make_atari_env('Breakout-v0', n_envs=1, seed=0),
    n_stack=4,
)

In [15]:
# Reload the saved agent from a2c_path and attach the single-env evaluation env.
model = A2C.load(a2c_path, env)

Wrapping the env in a VecTransposeImage.


# 5. Evaluate and Test

In [16]:
# Run 10 rendered evaluation episodes; the returned tuple (mean episode reward,
# standard deviation of the reward) is displayed as the cell output.
evaluate_policy(model, env, n_eval_episodes=10, render=True)

(13.1, 4.846648326421054)

In [17]:
# Close the evaluation environment and the render window it opened.
env.close()

In [23]:
# Watch the trained agent play up to 100 episodes.
# The evaluation env was already closed by the earlier env.close() cell, so
# build a fresh single-env stack here instead of stepping a closed environment.
env = VecFrameStack(make_atari_env('Breakout-v0', n_envs=1, seed=0), n_stack=4)
try:
    for i in range(100):
        obs = env.reset()
        while True:
            action, _states = model.predict(obs)
            obs, rewards, done, info = env.step(action)
            env.render()
            # A vectorized env returns one done flag per sub-environment; with a
            # single env, read that flag explicitly instead of relying on the
            # truthiness of a one-element array.
            if done[0]:
                break
finally:
    # Release the emulator and render window even if the loop is interrupted.
    env.close()