# 1. Import dependencies

In [None]:
# Execute this first
%pip install git+https://github.com/DLR-RM/stable-baselines3
# Then install the package with extras (gymnasium, atari, etc)
%pip install stable-baselines3[extra]

In [None]:
# Install CUDA acceleration
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117

In [None]:
# Enable Atari environment
%pip install gymnasium[atari]
%pip install gymnasium[accept-rom-license]

In [1]:
import os
import gymnasium as gym
import time
# Algorithm
from stable_baselines3 import A2C
# This allows to vectorize our environment for parallel training
from stable_baselines3.common.vec_env import VecFrameStack
# Makes easier to evaluate how our model is running
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_atari_env

# 2. Test environment

In [2]:
# A single, un-vectorised Breakout instance, used only to inspect the environment's API.
environment_name = 'Breakout-v4'
env = gym.make(id=environment_name)

In [None]:
# Start an episode and display whatever reset() returns.
# NOTE(review): with the gymnasium API this is an (observation, info) pair — confirm for the installed version.
env.reset()

In [None]:
# Display the action space: the set of valid actions that can be passed to env.step().
env.action_space

In [None]:
# Display the observation space describing what reset()/step() return as observations.
env.observation_space

In [None]:
# Sanity-check the raw environment: play five episodes with *random* actions
# (no trained model yet), drawing each frame in a window.
env = gym.make(environment_name, render_mode='human')
episodes = 5

for episode in range(1, episodes + 1):
    # gymnasium's reset() returns (observation, info); unpack instead of binding the tuple.
    obs, info = env.reset()
    terminated = truncated = False
    score = 0
    # An episode is over when the game ends (terminated) OR a time limit cuts it
    # short (truncated); checking only one flag would keep stepping past the end.
    while not (terminated or truncated):
        action = env.action_space.sample()  # Replace with model.predict(obs) once a model is trained
        obs, reward, terminated, truncated, info = env.step(action)
        score += reward
        # No explicit env.render(): with render_mode='human' gymnasium draws frames during step().
    print('Episode: {} Score: {}'.format(episode, score))

In [None]:
# Close the human-rendered environment used for the random-play test (releases its window).
env.close()

# 3. Vectorise environment and train model

In [3]:
# Training environment: 4 seeded Breakout copies run side by side, with each
# observation made of the 4 most recent frames stacked together.
environment_name = 'Breakout-v4'
env = VecFrameStack(make_atari_env(environment_name, n_envs=4, seed=0), n_stack=4)

In [4]:
# A2C with a convolutional policy for image observations; TensorBoard runs are written under Training/Logs.
log_path = os.path.join('Training', 'Logs')
model = A2C(policy='CnnPolicy', env=env, verbose=1, tensorboard_log=log_path)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [5]:
# Train for 300k environment steps counted across all parallel envs; progress goes to
# stdout and TensorBoard, and the returned model object is shown as the cell output.
model.learn(total_timesteps=300_000)

Logging to Training\Logs\A2C_2
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 278      |
|    ep_rew_mean        | 1.57     |
| time/                 |          |
|    fps                | 115      |
|    iterations         | 100      |
|    time_elapsed       | 17       |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -1.38    |
|    explained_variance | 0.349    |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | -0.0843  |
|    value_loss         | 0.00681  |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 281      |
|    ep_rew_mean        | 1.57     |
| time/                 |          |
|    fps                | 147      |
|    iterations         | 200      |
|    time_elapsed       | 27       |
|    total_timesteps    | 4000     |
| train

<stable_baselines3.a2c.a2c.A2C at 0x20dd1178110>

# 4. Save and reload model

In [None]:
# Where the trained model is stored: Training/Saved Models/A2C_Breakout_Model (OS-specific separators).
a2c_path = os.path.join(os.path.join('Training', 'Saved Models'), 'A2C_Breakout_Model')

In [None]:
# Persist the trained model to a2c_path so it can be reloaded later without retraining.
model.save(a2c_path)

In [None]:
# Drop the in-memory model so the next cell demonstrably restores it from disk.
del model

In [None]:
# Rebuild the model from the saved file; no environment is passed here, the evaluation cells supply their own.
model = A2C.load(path=a2c_path)

# 5. Evaluate and test

In [None]:
# Evaluation environment: one Breakout copy with the same Atari preprocessing and
# 4-frame stacking used for training.
# render_mode='human' is requested explicitly: the vectorised Atari env is not created
# with an on-screen render mode by default, so render=True / env.render() on the VecEnv
# would not open a game window without it.
env = make_atari_env(environment_name, n_envs=1, seed=0, env_kwargs={'render_mode': 'human'})
env = VecFrameStack(env, n_stack=4)
# Displays (mean_reward, std_reward) over 10 evaluation episodes.
evaluate_policy(model, env, n_eval_episodes=10, render=True)

In [None]:
# Watch the trained agent play one full game. The original `while True` never ended,
# which hangs "Run All"; stop when the game is over, with a step cap as a safety net.
MAX_WATCH_STEPS = 10_000  # upper bound on agent steps so the cell always finishes

obs = env.reset()
for _ in range(MAX_WATCH_STEPS):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
    # The Monitor added by make_atari_env reports an 'episode' entry only when the
    # whole game (all lives) is over; per-life `dones` would stop too early.
    if 'episode' in info[0]:
        break

In [None]:
# Manually test the model using "model.predict", reporting the real game score per episode.
EPISODES = 5

for episode in range(1, EPISODES + 1):
    # Setup
    done = False
    score = 0

    # Reset environment and get initial observation (batched: one row per sub-env)
    obs = env.reset()

    # RL Loop.
    # SB3's Atari wrapping ends a VecEnv "episode" on every lost life and clips rewards
    # to their sign, so summing `reward` until `dones` would report a clipped per-life
    # score (and as a numpy array). The Monitor applied by make_atari_env records the
    # unclipped result in info['episode'] once the whole game is over, so use that.
    while not done:
        # Graphical view
        env.render()

        # Let the trained policy choose the next action
        action, _ = model.predict(obs)

        # Take that action (VecEnv returns one entry per sub-env; there is only one here)
        obs, reward, dones, info = env.step(action)

        # Update score once the full game has finished
        episode_stats = info[0].get('episode')
        if episode_stats is not None:
            score = episode_stats['r']
            done = True

    # Print statistics
    print(f"Episode #{episode} - Score: {score}")

In [None]:
# Release the evaluation environment created in the "Evaluate and test" section.
env.close()