# 1. Import Dependencies

In [1]:
import gym 
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_atari_env
import os

# 2. Test Environment

In [2]:
# Gym id of the Atari game used by every environment in this notebook.
environment_name = 'Breakout-v0'

## Install the Atari ROMs
Breakout cannot be created until the Atari ROMs are installed. Download `Roms.rar` from the Atari 2600 VCS ROM Collection (http://www.atarimania.com/rom_collection_archive_atari_2600_roms.html) and extract it. Then import the ROMs by running:

python -m atari_py.import_roms <path to folder>

In [3]:
# Create a single, un-vectorised Gym environment for a quick manual sanity check.
env = gym.make(environment_name)

In [4]:
# Sanity-check the environment with a random agent: run `episodes` episodes,
# sampling every action uniformly from the action space until the environment
# reports `done`, and print the total reward collected in each episode.
episodes = 5
try:
    for episode in range(1, episodes + 1):
        env.reset()
        done = False
        score = 0

        while not done:
            env.render()
            action = env.action_space.sample()
            _, reward, done, _ = env.step(action)
            score += reward
        print('Episode:{} Score:{}'.format(episode, score))
finally:
    # Always release the environment (and its render window), even if an
    # episode raises or the cell is interrupted mid-run.
    env.close()



Episode:1 Score:0.0
Episode:2 Score:1.0
Episode:3 Score:1.0
Episode:4 Score:2.0
Episode:5 Score:2.0


In [5]:
# Draw one random action from the action space (the same call the random agent uses each step).
env.action_space.sample()

3

In [6]:
# Draw one random observation to see what the agent receives.
# Judging by the output, it is an image-like array of 0-255 triplets (presumably RGB pixels; confirm via .shape/.dtype).
# NOTE(review): displaying the whole array floods the notebook; `.shape` and `.dtype` would be more readable.
env.observation_space.sample()

array([[[134,  63, 126],
        [ 47, 103, 237],
        [240,  82, 187],
        ...,
        [251, 157, 111],
        [161, 234,  10],
        [125,  98, 116]],

       [[ 65, 204, 242],
        [  4, 168, 208],
        [ 80, 110, 179],
        ...,
        [ 24,  12, 244],
        [ 38, 173,  87],
        [132, 117,   0]],

       [[137, 197,  45],
        [ 12,  38, 120],
        [246, 204, 222],
        ...,
        [156, 192,  48],
        [155,  97, 122],
        [ 23,  27, 251]],

       ...,

       [[100, 200,  92],
        [  2,  56,  74],
        [ 29, 224,  40],
        ...,
        [ 11,  15, 216],
        [ 85, 249, 100],
        [ 99, 233, 244]],

       [[145, 141,  42],
        [ 13, 243, 195],
        [ 91, 170, 149],
        ...,
        [170, 213, 223],
        [193, 222,   3],
        [ 94, 162, 184]],

       [[115,  85,   8],
        [255, 113, 168],
        [198, 123, 242],
        ...,
        [ 96,  99, 117],
        [ 32, 193, 191],
        [227,   6,  12]]

# 3. Vectorise Environment and Train Model

In [7]:
# Build 4 parallel, seeded copies of the game for A2C training. Reuse
# `environment_name` so the trained game cannot drift from the one tested above.
env = make_atari_env(environment_name, n_envs=4, seed=0)

In [8]:
# Stack the 4 most recent frames into each observation.
# Presumably this lets the CNN policy perceive motion (e.g. ball direction), which a single frame cannot show.
env = VecFrameStack(env, n_stack=4)

In [9]:
# Directory passed as `tensorboard_log` when the model is built; each training run logs into a subfolder here.
log_path = os.path.join("Training", "Logs")

In [10]:
# A2C with a convolutional policy, since observations are image frames.
# verbose=1 prints the rollout/train tables during learning; TensorBoard logs go under log_path.
model = A2C("CnnPolicy", env, verbose=1, tensorboard_log=log_path)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [11]:
# Train for 400k environment steps (summed across the 4 parallel envs).
# This is slow: the log above shows roughly 70 steps/second.
model.learn(total_timesteps=400_000)

Logging to Training\Logs\A2C_1
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 288      |
|    ep_rew_mean        | 1.59     |
| time/                 |          |
|    fps                | 66       |
|    iterations         | 100      |
|    time_elapsed       | 30       |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -1.38    |
|    explained_variance | 0.0902   |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | -0.102   |
|    value_loss         | 0.0428   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 302      |
|    ep_rew_mean        | 1.96     |
| time/                 |          |
|    fps                | 74       |
|    iterations         | 200      |
|    time_elapsed       | 53       |
|    total_timesteps    | 4000     |
| train

<stable_baselines3.a2c.a2c.A2C at 0x2a5ce79e8c8>

# 4. Save and Reload Model

In [12]:
# Where the trained model is saved to and reloaded from.
a2c_path = os.path.join("Training", "Saved Models", "A2C_model")

In [13]:
# Persist the trained model to disk at a2c_path.
model.save(a2c_path)

In [14]:
# Drop the in-memory model so the next cells prove it can be restored from disk alone.
del model

In [15]:
# Rebuild the environment for evaluation: a single copy of the same game, with
# the same 4-frame stacking the model was trained on. Reuse `environment_name` so
# evaluation cannot silently target a different game than training.
env = make_atari_env(environment_name, n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)

In [16]:
# Restore the saved model and attach it to the freshly built evaluation env.
model = A2C.load(a2c_path, env=env)

Wrapping the env in a VecTransposeImage.


# 5. Evaluate and Test

In [19]:
# Run the trained policy for 10 episodes while rendering, and display the result.
# Presumably the returned tuple is (mean episode reward, std of episode reward); see the SB3 evaluate_policy docs.
evaluate_policy(model, env, n_eval_episodes=10, render=True)

(13.8, 4.354308211415448)

In [None]:
# Watch the trained agent play. The vectorised env resets itself whenever an
# episode ends, so the original unbounded `while True` never returned and hung
# the kernel. Instead, stop at the first reported `done` or after a fixed
# step budget, whichever comes first.
max_watch_steps = 5000
obs = env.reset()
for _ in range(max_watch_steps):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
    if dones.any():
        # NOTE(review): with SB3's Atari wrappers, `done` may fire on losing a
        # life rather than on game over; confirm if whole games should be shown.
        break

In [20]:
# Release the evaluation environment and any render window it opened.
env.close()