# Imports

In [None]:
import os
import gym
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_atari_env

# Test environment

In [None]:
# !curl www.atarimania.com/roms/Roms.rar -o "../input/roms.rar"
# !unrar x "../input/roms.rar" "../input/"
# !unzip "../input/*zip" -d "../input/"
# !cd ../input && rm 'HC ROMS'  'HC ROMS.zip'   roms.rar   ROMS.zip -r

In [None]:
# Shell escape: run the `atari_py.import_roms` module against ../input/ROMS/.
# Presumably this makes the ROM files in that directory available to atari_py;
# the directory is expected to exist already (see the commented-out download
# and extraction commands in the cell above).
!python -m atari_py.import_roms ../input/ROMS/

In [None]:
# Gym environment id; reused by later cells when building vectorised envs.
environment_name = 'Breakout-v0'
# Single (non-vectorised) environment used for the random-play smoke test.
env = gym.make(environment_name)

In [None]:
# Reset the environment; as the only expression in this cell, its return value
# is what the notebook echoes.
env.reset()

In [None]:
# Inspect the environment's action space (echoed as the cell output).
env.action_space

In [None]:
# Inspect the environment's observation space (echoed as the cell output).
env.observation_space

In [None]:
# Smoke-test the environment: play `episodes` episodes with randomly sampled
# actions, rendering every step, and print the total reward of each episode.
episodes = 5
try:
    for episode in range(1, episodes + 1):
        state = env.reset()
        done = False
        score = 0  # running sum of rewards for the current episode

        while not done:
            env.render()
            action = env.action_space.sample()
            # Uses the 4-tuple step API; only `reward` and `done` drive the loop.
            n_state, reward, done, info = env.step(action)
            score += reward
        print(f"Episode: {episode}, Score: {score}")
finally:
    # Always release the environment, even if rendering or stepping raises
    # (for example when no display is available).
    env.close()

# Vectorise environment and train model

In [None]:
# Training environment: 4 Atari environments created with seed 0, wrapped in
# VecFrameStack with n_stack=4.
env = VecFrameStack(
    make_atari_env(environment_name, n_envs=4, seed=0),
    n_stack=4,
)

In [None]:
# Quick visual check of the vectorised environment: reset it, render once,
# then close it.
env.reset()
env.render()
# NOTE(review): this closes the same `env` object that is passed to A2C(...)
# and used by model.learn(...) in the following cells. Confirm the vectorised
# env is still usable after close(), or move this close() to after training.
env.close()

In [None]:
# Directory handed to the model as `tensorboard_log` and to the
# `%tensorboard --logdir` magic at the end of the notebook.
log_path = os.path.join(
    '../output',
    'training',
    'logs',
)

In [None]:
# Create an A2C model using the 'CnnPolicy' policy on the vectorised `env`,
# with verbose output (verbose=1) and TensorBoard logs written under log_path.
model = A2C('CnnPolicy', env, verbose=1, tensorboard_log=log_path)

In [None]:
# Train the model for 200,000 timesteps.
model.learn(total_timesteps=200_000)

# Save and reload model

In [None]:
# Location the trained model is saved to and reloaded from.
# Note that the 'saved models' directory name contains a space.
a2c_path = os.path.join(
    '../output',
    'training',
    'saved models',
    'a2c_breakout',
)

In [None]:
# Persist the trained model to a2c_path.
model.save(a2c_path)
# Drop the in-memory reference so that the next cell has to reload the model
# from disk rather than reuse this object.
del model

In [None]:
# Reload the saved model from a2c_path, passing the current `env` (at this
# point, the 4-environment training env built earlier in the notebook).
model = A2C.load(a2c_path, env)

# Evaluate and test

In [None]:
# Evaluation environment: a single Atari environment (n_envs=1, seed 0),
# wrapped in VecFrameStack with n_stack=4 like the training environment.
# This rebinds `env`.
env = VecFrameStack(
    make_atari_env(environment_name, n_envs=1, seed=0),
    n_stack=4,
)

In [None]:
# Evaluate the reloaded model over 10 rendered episodes and report the result.
# The return value of evaluate_policy must be captured explicitly: it is not
# the last expression in the cell, so the notebook would otherwise discard it
# without displaying anything.
try:
    mean_reward, std_reward = evaluate_policy(
        model, env, n_eval_episodes=10, render=True
    )
    print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
finally:
    # Release the evaluation environment even if evaluation raises.
    env.close()

In [None]:
# Load the TensorBoard notebook extension so the %tensorboard magic is available.
%load_ext tensorboard

In [None]:
# Launch TensorBoard on the training log directory; `{log_path}` is filled in
# from the notebook variable of the same name.
%tensorboard --logdir={log_path}