In [None]:
# Setup Docker (including Colab) environment

import sys
from pathlib import Path

if Path('/.dockerenv').exists():
    !apt install cmake libopenmpi-dev zlib1g-dev
    !pip install git+https://github.com/RerRayne/stable-baselines

    !git clone https://github.com/Pastafarianist/rl-attention.git
    
    sys.path.append('rl-attention')

In [None]:
import json
import logging
import os
import random
import time
from pprint import pprint

import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from stable_baselines import A2C
from stable_baselines.common import set_global_seeds
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import VecFrameStack, VecNormalize
from stable_baselines.results_plotter import load_results, ts2xy

from models import attention_cnn


# Scratch directory under /tmp (nothing in the visible cells writes to it yet;
# presumably intended for training logs — confirm against later cells).
log_dir = Path("/tmp/gym/")
# parents=True: also create missing parent directories instead of raising
# FileNotFoundError; exist_ok=True keeps the cell safe to re-run.
log_dir.mkdir(parents=True, exist_ok=True)

In [None]:
class Callback(object):
    """Training callback that periodically shows a rendered environment frame.

    An instance is passed as ``callback=`` to ``model.learn``. It counts how
    many times it has been invoked and, on invocations 1, 101, 201, ...
    (0-based counter), renders the training environment and displays the
    frame with matplotlib.
    """

    def __init__(self):
        # Initialised to -inf; this class never updates it (kept for callers
        # that may read or set it).
        self.best_mean_reward = -np.inf
        # Number of times the callback has been invoked so far.
        self.n_steps = 0

    def __call__(self, _locals, _globals):
        """
        Callback called at each step (for DQN and others) or after n steps (see ACER or PPO2)

        :param _locals: (dict) local variables of the caller; ``_locals["self"]``
            must expose an ``env`` with ``render(mode='rgb_array')``
        :param _globals: (dict) global variables of the caller (unused)
        :return: (bool) always True
        """
        if self.n_steps % 100 == 1:
            # Display frames.
            # grid(False) switches the grid off explicitly; grid(None) only
            # toggles it, so the outcome depended on the active plot style.
            plt.grid(False)
            plt.imshow(_locals["self"].env.render(mode='rgb_array'))
            plt.show()

        self.n_steps += 1
        return True

In [None]:
# Sixteen Pong environments created with seed 0, wrapped so that 4 frames are
# stacked per observation.
env = VecFrameStack(
    make_atari_env('PongNoFrameskip-v4', num_env=16, seed=0),
    n_stack=4,
)

# Callback instance handed to the training loop below.
callback = Callback()

# A2C agent using CnnPolicy with the attention CNN as its feature extractor
# and a constant learning-rate schedule.
model = A2C(
    CnnPolicy,
    env,
    verbose=1,
    lr_schedule='constant',
    policy_kwargs=dict(cnn_extractor=attention_cnn),
)

# Train for ten million timesteps.
model.learn(callback=callback, total_timesteps=10 ** 7)