In [None]:
import json
import os
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from pprint import pprint

import requests


# Setup environment: on Colab, clone the project repository (if it is not
# already present) and add it to sys.path; locally, use the working directory.
if 'google.colab' in sys.modules:
    repo_dir = 'rl-attention'
    if repo_dir not in sys.path:
        if not Path(repo_dir).exists():
            # subprocess.run(check=True) raises CalledProcessError if the clone
            # fails; a `!git clone` shell escape would ignore the exit status and
            # let the notebook continue without the repository.
            subprocess.run(
                ['git', 'clone', '--quiet',
                 'https://github.com/Pastafarianist/rl-attention.git', repo_dir],
                check=True,
            )
        sys.path.append(repo_dir)
    config_path = Path(repo_dir) / 'config.json'
else:
    config_path = Path('config.json')

# Load the base configuration file and show it.
with config_path.open() as config_file:
    cfg = json.load(config_file)
print('Original config:')
pprint(cfg)

# Per-run overrides: add keys here to change them for this run only.
# Leaving the dict empty keeps config.json as-is.
custom_cfg = {}

# Apply any overrides and show the resulting configuration.
if len(custom_cfg) > 0:
    cfg.update(custom_cfg)
    print('\nModified config:')
    pprint(cfg)

# Generate run ID. `run_ts` is created only the first time this runs in the
# current kernel, so re-running the cell reproduces the same run name.
if 'run_ts' not in globals():
    # Millisecond-resolution local timestamp with ':' swapped for '-'.
    run_ts = (
        datetime.now()
        .isoformat(sep='_', timespec='milliseconds')
        .replace(':', '-')
    )
run_name = '{env_name},{algo},{network},{train_seed},{run_ts}'.format(run_ts=run_ts, **cfg)
print('Run ID is {}'.format(run_name))

# `bootstrap` is part of the project repository, so it can only be imported
# after the repository has been put on sys.path earlier in this cell.
from bootstrap import setup_environment

# GitHub accounts whose published public SSH keys are passed to setup_environment.
SSH_KEY_GITHUB_USERS = ['Pastafarianist', 'HoagyC', 'fabiansteuer', 'skosch', 'RerRayne']


def fetch_github_ssh_keys(username, timeout=10):
    """Return the public SSH keys GitHub publishes for `username`.

    Parameters
    ----------
    username : str
        GitHub account name.
    timeout : float
        Seconds to wait for GitHub before giving up (without a timeout the
        request could block the notebook indefinitely).

    Returns
    -------
    list of str
        The stripped, non-blank lines of https://github.com/<username>.keys.
        If the request fails or GitHub answers with an error status, a warning
        is printed and an empty list is returned. This ensures an error page
        body such as "Not Found" is never treated as a key.
    """
    url = 'https://github.com/{}.keys'.format(username)
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as exc:
        print('Warning: could not fetch SSH keys for {}: {}'.format(username, exc),
              file=sys.stderr)
        return []
    return [line.strip() for line in response.text.splitlines() if line.strip()]


output_dir = setup_environment(
    run_name=run_name,
    ssh_keys=(
        key
        for username in SSH_KEY_GITHUB_USERS
        for key in fetch_github_ssh_keys(username)
    ),
)

# Legacy baselines-style logging (useful for comparisons): plain-text log,
# CSV and TensorBoard outputs, written to directories under output_dir.
from stable_baselines.logger import configure

log_dir, tensorboard_dir = output_dir / cfg['log_dir'], output_dir / cfg['tb_dir']
configure(
    log_dir=str(log_dir),
    tensorboard_dir=str(tensorboard_dir),
    format_strs=['log', 'csv', 'tensorboard'],
)

In [None]:
import gym
import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
%matplotlib inline

from stable_baselines import PPO2
from stable_baselines.common import set_global_seeds
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.vec_env import VecFrameStack, VecNormalize

from tqdm import tqdm_notebook

from models import get_network_builder
from losses import get_loss

In [None]:
class Callback(object):
    """Per-update callback passed to ``PPO2.learn``.

    Every call advances a notebook progress bar. The model is checkpointed to
    ``output_dir / 'model.pkl'`` on the first update, every ``save_every``
    updates after that, and on the final update. When ``display_frames`` is
    set, a rendered frame is shown before each checkpoint.

    Parameters
    ----------
    display_frames : bool, default False
        If True, render the model's env as an RGB array and show it at each
        checkpoint.
    save_every : int, default 100
        Checkpoint period, in updates. Must be >= 1.
    """

    def __init__(self, display_frames=False, save_every=100):
        if save_every < 1:
            raise ValueError('save_every must be >= 1, got {}'.format(save_every))
        self.display_frames = display_frames
        self.save_every = save_every
        self.pbar = None

    def __call__(self, _locals, _globals):
        """Handle one training update.

        Reads three entries of the learn loop's ``locals()``:
        - 'self': the model;
        - 'update': the current update index, assumed 1-based as in PPO2.learn;
        - 'nupdates': the total number of updates.

        Always returns True, so training is never stopped by this callback.
        """
        model = _locals['self']
        update = _locals['update']
        nupdates = _locals['nupdates']

        # Update 1 means a new learn() call has started: discard any bar left
        # open by an earlier call that was interrupted before its last update,
        # otherwise it would be reused with a stale total.
        if update == 1 and self.pbar is not None:
            self.pbar.close()
            self.pbar = None
        if self.pbar is None:
            # The bar advances by n_batch per update.
            self.pbar = tqdm_notebook(total=nupdates * model.n_batch)
        self.pbar.update(model.n_batch)

        is_last_update = update == nupdates
        if is_last_update:
            self.pbar.close()
            self.pbar = None

        # Checkpoint on updates 1, 1 + save_every, 1 + 2 * save_every, ... and the last one.
        if (update - 1) % self.save_every == 0 or is_last_update:
            if self.display_frames:
                # grid(None) *toggles* the grid, which turns it on under default
                # settings; hide it explicitly so it is not drawn over the frame.
                plt.grid(False)
                plt.imshow(model.env.render(mode='rgb_array'))
                plt.show()
            model.save(str(output_dir / 'model.pkl'))

        return True

In [None]:
# Dump the entire effective configuration (base config + overrides) next to
# the run's outputs. Serializing first means a non-JSON-serializable override
# raises before config.json is opened, so no truncated file is left behind.
config_json = json.dumps(cfg, indent=2)
(output_dir / 'config.json').write_text(config_json)

# Seed the global RNGs via set_global_seeds, then build 8 parallel Atari
# environments (seeded from the same value) whose observations stack the last 4 frames.
set_global_seeds(cfg['train_seed'])
env = VecFrameStack(
    make_atari_env(cfg['env_name'], num_env=8, seed=cfg['train_seed']),
    n_stack=4,
)

# TODO: perhaps the below can be packaged up into a nice .py file
# with different types of training algorithms etc.

# Where the model comes from:
#   'FRESH'  - build a new PPO2 model from the config;
#   'UPLOAD' - load a model file uploaded through Colab's file picker;
#   'GDRIVE' - load output_dir / 'model.pkl', the path Callback checkpoints to.
model_params_source = 'FRESH'  # or: 'UPLOAD' or 'GDRIVE'
# Validate up front so a typo fails before any model/loss objects are built.
if model_params_source not in ('FRESH', 'UPLOAD', 'GDRIVE'):
    raise ValueError("Source should be one of: FRESH, UPLOAD, GDRIVE")

kwargs = {
    'verbose': 1,
    # PPO2 calls this with a progress fraction. Presumably it goes from 1 to 0
    # over training, i.e. a linearly annealed learning rate. TODO confirm for this fork.
    'learning_rate': lambda frac: 0.00025 * frac,
    'attn_loss': get_loss(cfg['attn_loss'])(),
    'attn_coef': cfg['attn_coef'],
    'policy_kwargs': {
        'cnn_extractor': get_network_builder(cfg['network'])
    },
    'tensorboard_log': str(tensorboard_dir),
}

if model_params_source == 'FRESH':
    model = PPO2(cfg['policy_type'], env, **kwargs)
else:
    if model_params_source == 'UPLOAD':
        # `files` only exists on Colab, so it is imported where it is needed.
        from google.colab import files
        uploaded_files = files.upload()  # will save the file to the /content folder
        if not uploaded_files:
            raise ValueError('No model file was uploaded')
        model_path = next(iter(uploaded_files))
    else:  # 'GDRIVE'
        # Pass a str, matching how Callback saves; some loaders treat
        # non-str paths as already-open file objects.
        model_path = str(output_dir / 'model.pkl')
    # NOTE(review): model files are presumably pickle-based, so loading one can
    # execute arbitrary code. Only load checkpoints from trusted sources.
    model = PPO2.load(model_path, env, **kwargs)

# Train the model; Callback drives the progress bar and periodic checkpoints.
callback = Callback()
model.learn(total_timesteps=cfg['time_steps'], callback=callback,
            log_interval=cfg['log_interval'], tb_log_name=None)

In [None]:
!du -h $output_dir