In [None]:
# Setup Colab environment

import sys

if 'google.colab' in sys.modules:
    !apt-get install -qq cmake libopenmpi-dev zlib1g-dev
    !pip install -q git+https://github.com/RerRayne/stable-baselines

    !git clone --quiet https://github.com/Pastafarianist/rl-attention.git
    
    sys.path.append('rl-attention')

In [None]:
import json
import logging
import os
import random
import time
from datetime import datetime
from pprint import pprint
from pathlib import Path

import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from stable_baselines import A2C
from stable_baselines.common import set_global_seeds
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import VecFrameStack, VecNormalize
from stable_baselines.results_plotter import load_results, ts2xy

from models import attention_cnn

from google.colab import drive
from google.colab import files

# Attach Google Drive so that outputs written below outlive the Colab VM.
drive.mount('/content/gdrive')

# Output locations on the mounted drive.
log_dir = Path("/content/gdrive/My Drive/rl-attention/tensorboard_logs/")
saved_params_dir = Path("/content/gdrive/My Drive/rl-attention/saved_params/")

# Create both directories (and any missing parents); existing ones are left alone.
log_dir.mkdir(parents=True, exist_ok=True)
saved_params_dir.mkdir(parents=True, exist_ok=True)

In [None]:
class Callback(object):
    """Training callback that periodically displays a frame and saves the model.

    The instance counts how often it has been called. Whenever the counter
    hits the trigger value (1, 101, 201, ... with the default interval) it
    renders the environment, prints the counter and saves the parameters.

    :param save_interval: (int) number of calls between two displays/saves
    :param save_target: (str) "GDRIVE" saves into ``saved_params_dir``; any
        other value saves to /tmp/ and hands the file to ``files.download``
    """

    def __init__(self, save_interval=100, save_target="GDRIVE"):
        if save_interval < 1:
            raise ValueError("save_interval must be a positive integer")
        # Kept for compatibility; nothing in this class reads or updates it.
        self.best_mean_reward = -np.inf
        self.n_steps = 0
        self.save_interval = save_interval
        self.save_target = save_target

    def __call__(self, _locals, _globals):
        """
        Callback called at each step (for DQN and others) or after n steps (see ACER or PPO2)
        :param _locals: (dict) must provide the model under the key "self"
            when a display/save is due
        :param _globals: (dict) unused
        :return: (bool) always True
        """
        # `1 % save_interval` preserves the original cadence (counter == 1, 101, ...)
        # while still firing on every call when save_interval is 1.
        if self.n_steps % self.save_interval == 1 % self.save_interval:
            model = _locals["self"]
            self._show_frame(model)
            print("nsteps", self.n_steps)
            self._save_params(model)

        self.n_steps += 1
        return True

    def _show_frame(self, model):
        """Render the model's environment as an RGB array and display it."""
        # grid(None) would *toggle* the grid; False switches it off unconditionally.
        plt.grid(False)
        plt.imshow(model.env.render(mode='rgb_array'))
        plt.show()

    def _save_params(self, model):
        """Save the model under a timestamped file name to the configured target."""
        model_save_filename = "params-" + datetime.now().strftime('%Y-%m-%d_%H:%M:%S') + ".pkl"

        if self.save_target == "GDRIVE":
            model.save(str(Path(saved_params_dir, model_save_filename)))
        else:
            # Save locally first, then trigger a browser download of that file.
            model_save_filepath = str(Path("/tmp/", model_save_filename))
            model.save(model_save_filepath)
            files.download(model_save_filepath)

In [None]:
# Build the vectorised Pong environment: 16 parallel copies, seeded with 0.
env = make_atari_env('PongNoFrameskip-v4', num_env=16, seed=0)
env = VecFrameStack(env, n_stack=4) # stack 4 frames

# TODO: perhaps the below can be packaged up into a nice .py file
# with different types of training algorithms etc.

# NOTE(review): `log_dir` is created earlier in the notebook but is not passed
# to the model here (e.g. as a tensorboard log directory) - confirm whether
# that is intended.

# Where the initial parameters come from:
#   'FRESH'  - start from a newly initialised model
#   'UPLOAD' - restore from a file uploaded through the browser
#   'GDRIVE' - restore from a fixed checkpoint in `saved_params_dir`
model_params_source = 'FRESH' # or: 'UPLOAD' or 'GDRIVE'

if model_params_source == 'FRESH':
    model = A2C(
        CnnPolicy,
        env,
        lr_schedule='constant',
        verbose=1,
        policy_kwargs={'cnn_extractor': attention_cnn}
    )

elif model_params_source in ('UPLOAD', 'GDRIVE'):
    if model_params_source == 'UPLOAD':
        uploaded_files = files.upload()  # will save the file to the /content folder
        if not uploaded_files:
            # A cancelled upload would otherwise surface as a bare IndexError.
            raise ValueError("model_params_source is 'UPLOAD' but no file was uploaded")
        uploaded_filename = list(uploaded_files.keys())[0]
        params_path = uploaded_filename
    else:
        params_path = str(Path(saved_params_dir, "./params-2019-04-21_18:50:34.pkl"))

    # NOTE(review): prioritized_replay / param_noise look like DQN options -
    # confirm they have any effect when loading an A2C model.
    model = A2C.load(
        params_path,
        env,
        prioritized_replay=True,
        param_noise=True,
        policy_kwargs={'cnn_extractor': attention_cnn}
    )

else:
    # Without this branch a typo leaves `model` undefined and the failure only
    # shows up later as a NameError on model.learn().
    raise ValueError(
        "Unknown model_params_source %r; expected 'FRESH', 'UPLOAD' or 'GDRIVE'"
        % (model_params_source,)
    )

callback = Callback()

model.learn(total_timesteps=int(1e7), callback=callback)