# Training notebook for the simplicial and relational agents

This is the training notebook for the simplicial agent of the paper "Logic and the 2-Simplicial Transformer", by James Clift, Dmitry Doryn, Daniel Murfet and James Wallbridge. The official repository is here: https://github.com/dmurfet/2simplicialtransformer.

Libraries used:
* Keras, TensorFlow (1.13.1)
* Keras-Transformer (https://github.com/kpot/keras-transformer)
* RLlib / Ray
* OpenAI Gym

Quotes throughout this notebook are always from the paper:

- V. Zambaldi, D. Raposo, A. Santoro, V. Bapst, Y. Li, I. Babuschkin, K. Tuyls, D. Reichert, T. Lillicrap, E. Lockhart, M. Shanahan, V. Langston, R. Pascanu, M. Botvinick, O. Vinyals and P. Battaglia, [Deep reinforcement learning with relational inductive biases](https://openreview.net/forum?id=HkxaFoC9KQ), in Proceedings of the International Conference on Learning Representations (ICLR), 2019.

## Installation

There is a more detailed installation guide in the GitHub repository. Note that, to be safe, you should have version 0.7.0.dev2 of Ray installed

```
pip3 install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl
```

and have patched the Ray RLlib files using the versions distributed with this notebook, for instance:

```
cp ~/2simplicialtransformer/python/policy_evaluator-0.7.0.dev2-edited.py ~/.local/lib/python3.6/site-packages/ray/rllib/evaluation/policy_evaluator.py
cp ~/2simplicialtransformer/python/impala-0.7.0.dev2-edited.py ~/.local/lib/python3.6/site-packages/ray/rllib/agents/impala/impala.py
```

The exact paths depend on where Python is installed. The environment and agent libraries also need to be on your Python path:

```
ln -s ~/2simplicialtransformer/env/bridge_boxworld.py ~/.local/lib/python3.6/site-packages/bridge_boxworld.py
ln -s ~/2simplicialtransformer/agent/agent_relational.py ~/.local/lib/python3.6/site-packages/agent_relational.py
ln -s ~/2simplicialtransformer/agent/agent_simplicial.py ~/.local/lib/python3.6/site-packages/agent_simplicial.py
```
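Before going further, it can help to confirm that the symlinked modules are actually importable. The following is a minimal sketch using only the standard library; the helper name `missing_modules` is hypothetical and not part of the repository.

```python
import importlib.util

def missing_modules(names):
    """Return the top-level module names that cannot be found on the Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# Module names taken from the symlinks above.
REQUIRED = ["bridge_boxworld", "agent_relational", "agent_simplicial"]

if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    if missing:
        print("Not importable:", ", ".join(missing))
    else:
        print("All environment and agent modules found.")
```

Run it with the same interpreter that will run the notebook, since each Python installation has its own `site-packages`.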

## Before running this notebook

You will need to have Ray running, e.g. with `ray start --head --redis-port=6379 --num-cpus=64 --num-gpus=1`. Copy the Redis address printed by `ray start` into the `REDIS_ADDRESS` flag below, and set the number of virtual CPUs and GPUs. Ray automatically writes TensorBoard files, which you can view with `tensorboard --logdir=~/ray_results > /dev/null 2>&1 &`.
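The config cell below sets `NUM_WORKERS` to one less than the number of vCPUs. As a rough sketch, that value can be derived rather than hard-coded. The helper `suggest_num_workers` and its floor of one worker are assumptions for illustration, and `os.cpu_count()` only reports the local machine, not a multi-node Ray cluster.

```python
import os

def suggest_num_workers(num_cpus=None):
    """Suggest a worker count: one fewer than the available vCPUs, but at least 1."""
    if num_cpus is None:
        num_cpus = os.cpu_count() or 1  # os.cpu_count() may return None
    return max(1, num_cpus - 1)

print(suggest_num_workers(64))  # 63, matching --num-cpus=64 above
```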

## Config flags

In [None]:
REDIS_ADDRESS="192.168.2.2:6379" # output of ray start
AGENT = "base_model" # simplicial_model (simplicial agent), base_model (relational agent)
LOG_FILENAME = "/home/murfetd/log-base-test.txt" # should vary between runs
NUM_WORKERS = 63 # one less than your number of vCPUs
NUM_GPUS = 1
TIMESTEPS_TOTAL = 100e8 # training will never finish, interrupt to stop

## Imports

In [None]:
from __future__ import print_function
import numpy as np
import math
import io
import os
import sys
import base64
from random import randint, sample, shuffle
import random

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # hide TensorFlow INFO and WARNING messages

import tensorflow as tf
from tensorflow import logging
logging.set_verbosity(tf.logging.ERROR)

import ray
from ray import tune
from ray.tune import grid_search

import gym
gym.logger.set_level(40) #error only
#gym.logger.set_level(10) #debug
from gym import spaces
from gym.utils import seeding
from gym.spaces import Discrete, Box
import copy
import colorsys

## Environment

The environment is an object conforming to the OpenAI Gym Environment API. The environment is split between a part in the notebook (below) and a part in the bridge BoxWorld environment library. There is no good reason for this, but nobody promised you a notebook conforming to the highest standards of software engineering. It works.
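For readers unfamiliar with the Gym interface, the sketch below shows the `reset`/`step` contract that `BoxWorld` follows, driven by a random policy. `ToyEnv` and `random_rollout` are hypothetical names introduced here so the example runs without Gym or the BoxWorld library; they are not part of the repository.

```python
import random

class ToyEnv:
    """Hypothetical stand-in exposing the classic Gym reset/step interface."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t  # initial observation

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 2 else 0.0
        done = self.t >= self.horizon
        return self.t, reward, done, {}  # (obs, reward, done, info)


def random_rollout(env, num_actions=4, max_steps=1000, rng=None):
    """Run one episode with uniformly random actions; return (total_reward, steps)."""
    rng = rng or random.Random()
    env.reset()
    total_reward, steps, done = 0.0, 0, False
    while not done and steps < max_steps:
        _, reward, done, _ = env.step(rng.randrange(num_actions))
        total_reward += reward
        steps += 1
    return total_reward, steps

total, steps = random_rollout(ToyEnv(horizon=5), rng=random.Random(0))
print(steps)  # 5
```

The same `random_rollout` loop should work with a `BoxWorld` instance, since it only relies on `reset` and `step` returning the four-tuple shown above.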

In [None]:
import bridge_boxworld as boxworld

def createColoursArrayHSV():
  colours = []
  
  for h in range(0, 360, 18):
    rgb = colorsys.hsv_to_rgb(h/360, 0.7, 0.8)
    colours.append(tuple([int(round(c*255.0)) for c in rgb]))
    
  # randomise the order of the colours
  random.shuffle(colours)
  
  colours.insert(0,(255,255,255))

  return colours

def board_to_rgb(b,inv,colours):
  """Converts a Clift board into an RGB board"""
  
  # Keep in mind that in boxworld.py, b[x][y] is the entry in 
  # x-coordinate x and y-coordinate y with the origin in the
  # upper left corner of the screen, but in rgb_board (according
  # to matplotlib convention) [y,x,:] is the RGB vector for the
  # tile in position (x,y).
  
  numCols = len(b)
  numRows = len(b[0])

  # We have an extra column for the inventory
  # "Keys that an agent has in its possession are depicted in the input
  # observation as a pixel in the top-left corner."
  rgb_board = np.zeros([numRows,numCols+1,3],dtype=np.uint8)
  
  for x in range(numCols+1):
    for y in range(numRows):
        if( x < numCols ):
            c = b[x][y]
        else: # inventory
            if( y < len(inv) ):
                c = [inv[y]]
            else:
                c = []
                    
        if( len(c) == 0 and x == numCols ):
            rgb_board[y,x,:] = [10,10,10]
        elif( len(c) == 0 ):
            rgb_board[y,x,:] = [199,199,199]
        elif( len(c) == 1 ):
            rgb_board[y,x,:] = colours[c[0]]

  return rgb_board

# NOTE: we have borrowed from this example
# https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_env.py

class BoxWorld(gym.Env):

    # The state of the game consists of
    #
    # 1. The board, indexed as board[x][y]; each entry is a list that is empty
    #    or holds a single tile index (see board_to_rgb)
    # 2. The player position, a pair (x,y) of integers
    # 3. The inventory, a list of nonzero integers
    #
    # When asked for an observation of the environment, we return board_to_rgb
    # applied to the current board, with the player position painted PLAYER_COLOUR.
    # Note that observations are generated with RGB range 0-255. There is an extra
    # column added on the right hand side, which contains (from the top) a rendering
    # of the content of the player inventory, one square per key (if this exceeds
    # the size of the screen, too bad! TODO).
    
    metadata = {
        'render.modes': ['rgb_array'],
        'video.frames_per_second': 30
    }
 
    reward_range = (-float('inf'), float('inf'))
    
    def __init__(self, config):
        self.num_rows         = config["num_rows"]
        self.num_cols         = config["num_cols"]
        self.max_decoy_paths  = config["max_decoy_paths"]
        self.max_decoy_length = config["max_decoy_length"]
        self.min_sol_length   = config["min_sol_length"]
        self.max_sol_length   = config["max_sol_length"]
        self.episode_horizon  = config["episode_horizon"]
        self.monitor_interval = config["monitor_interval"]
        self.multi_lock       = config["multi_lock"]
        self.has_bridge       = config["has_bridge"]
        self.num_steps        = 0
        
        # The valid entries in a square are 0,....,sol_length
        
        # For spaces see https://github.com/openai/gym/blob/master/gym/spaces/box.py
        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(low=0,high=255, shape=(self.num_rows,self.num_cols+1,3), dtype=np.uint8)

        self.seed()

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, action):
        assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action))
        
        self.num_steps += 1
        
        # actions
        # 0 = left
        # 1 = up
        # 2 = right
        # 3 = down
        
        if( action == 0 ): frame_input = [-1,0]
        if( action == 1 ): frame_input = [0,-1]
        if( action == 2 ): frame_input = [1,0]
        if( action == 3 ): frame_input = [0,1]
          
        old_inventory = copy.deepcopy(self.inventory)
        
        self.board, self.pos, self.inventory = boxworld.updateState(self.board,self.pos,self.inventory,frame_input,self.multi_lock)
        
        # p.14 of Zambaldi et al states:
        # "An agent receives a reward of +10 for collecting the gem, +1 for opening a box
        # in the solution path and −1 for opening a distractor box. A level terminates
        # immediately after the gem is collected or a distractor box is opened."
        #
        # Currently we give +10 for the gem, -1 for stepping onto a decoy, and
        # +1 for any other change in inventory, which is a proxy for opening a box
        
        reward = 0.0
        done = False
        
        # If you get the Gem, game over (you win)
        if( boxworld.GOAL_TILE in self.inventory ):
          done = True
          reward += 10.0

        # If you get a decoy, game over (you lose)
        if( self.pos in self.decoys ):
          done = True
          reward += -1.0
        
        if( self.inventory != old_inventory and not done ):
          reward += 1.0
            
        if( self.episode_horizon != -1 and self.num_steps >= self.episode_horizon ):
          done = True

        obs = self._generate_obs()
       
        return obs, reward, done, {}

    def reset(self):
        random.seed() # in particular, we use this seed for the box colours
        self.decoys = []
        self.num_steps = 0
        
        self.colours = createColoursArrayHSV()
        self.sol_length = random.randint(self.min_sol_length,self.max_sol_length)
        self.num_decoy_paths = random.randint(0,self.max_decoy_paths)

        self.board, self.decoys = boxworld.generatePuzzle(numCols=self.num_cols,
                                             numRows=self.num_rows,
                                             solutionLength=self.sol_length,
                                             numDecoyPaths=self.num_decoy_paths,
                                             maxDecoyLength=self.max_decoy_length,
                                             multiLock=self.multi_lock,
                                             hasBridge=self.has_bridge)
        self.pos = [0,0] # (x,y)
        self.inventory = []
        
        obs = self._generate_obs()
        
        return obs

    def render(self, mode='rgb_array'):
        obs = self._generate_obs()
        
        # Scale up for the video
        SCALE_FACTOR = 10
        obs_scaled = np.zeros([self.num_rows*SCALE_FACTOR, (self.num_cols+1)*SCALE_FACTOR,3],dtype=np.uint8)
        
        for y in range(self.num_rows):
          for x in range(self.num_cols+1):
            c = obs[y,x]
            
            for i in range(SCALE_FACTOR):
              for j in range(SCALE_FACTOR):
                obs_scaled[y*SCALE_FACTOR+i,x*SCALE_FACTOR+j] = c
            
        return obs_scaled

    def _generate_obs(self):
        """Generate an RGB observation of the board"""
        rgb_board = board_to_rgb(self.board,self.inventory,self.colours)
        rgb_board[self.pos[1],self.pos[0]] = [127,127,127] # PLAYER_COLOUR
        
        return rgb_board
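The indexing convention in `board_to_rgb` is easy to get wrong: the board is addressed as `b[x][y]`, while the image is addressed as `[y, x]`, with one extra inventory column on the right. The following is a small pure-Python sketch of that transposition, using nested lists instead of NumPy. The function `board_to_grid` and the `"."` placeholder for empty squares are hypothetical.

```python
def board_to_grid(board, inventory, empty="."):
    """Transpose board[x][y] into rows grid[y][x], appending an inventory column."""
    num_cols = len(board)
    num_rows = len(board[0])
    grid = []
    for y in range(num_rows):
        row = [board[x][y][0] if board[x][y] else empty for x in range(num_cols)]
        # Extra column on the right: one inventory item per row, from the top.
        row.append(inventory[y] if y < len(inventory) else empty)
        grid.append(row)
    return grid

# A 3-column, 2-row board with tile 5 at x=2, y=0 and tile 7 at x=0, y=1.
board = [[[], [7]], [[], []], [[5], []]]
for row in board_to_grid(board, inventory=[9]):
    print(row)
# ['.', '.', 5, 9]
# [7, '.', '.', '.']
```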

## Agent

In [None]:
from ray.rllib.models import ModelCatalog
from agent_relational import BaseModel
from agent_simplicial import SimplicialModel

ModelCatalog.register_custom_model("base_model", BaseModel)
ModelCatalog.register_custom_model("simplicial_model", SimplicialModel)

## Training

Relevant quotes from Zambaldi et al:

- "We used distributed A2C agents with off-policy corrections (Espeholt et al., 2018). Each agents consisted of 100 actors generating trajectories of experience, and a single learner, which learns pi and B using the actors’ experiences. The model updates were performed on GPU using mini-batches of 32 trajectories provided by the actors via a queue. The agents used an entropy cost of 0.005, discount of 0.99 and unroll length of 40 steps. Training was done using RMSprop optimiser with momentum of 0, epsilon of 0.1 and a decay term of 0.99. The learning rate was tuned, taking values between 1e−5 and 2e−4."
- p.5 "The training set-up consisted of Box-World levels with solution lengths of at least 1 and up to 4. This ensured that an untrained agent would have a small probability of reaching the goal by chance, at least on some levels. The number of distractor branches was randomly sampled from 0 to 4. Training was split into two variants of the task: one with distractor branches of length 1; another one with distractor branches of length 3 (see Figure 3)."
- Note their solution length 1 is our sol_length=1 (i.e. one locked box)
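The quoted mini-batch of 32 trajectories with an unroll length of 40 steps fixes the number of timesteps consumed per learner update. The sketch below works through that arithmetic. Reading it as the origin of the `sample_batch_size` and `train_batch_size` values in the training configuration is an interpretation, not something the source states.

```python
UNROLL_LENGTH = 40           # steps per trajectory (quoted above)
TRAJECTORIES_PER_BATCH = 32  # trajectories per mini-batch (quoted above)

timesteps_per_update = UNROLL_LENGTH * TRAJECTORIES_PER_BATCH
print(timesteps_per_update)  # 1280
```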

For default IMPALA hyperparameters see
- https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/impala/impala.py

For an explanation of the various flags specific to Ray Tune, see
- https://ray.readthedocs.io/en/latest/tune-usage.html
- https://ray.readthedocs.io/en/latest/rllib-training.html

Custom metrics:
- https://ray.readthedocs.io/en/latest/rllib-training.html#callbacks-and-custom-metrics
- https://github.com/ray-project/ray/blob/master/python/ray/rllib/evaluation/sampler.py
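The callback in the next cell records a per-episode `win` flag and a `win_on_decoys_<i>` flag keyed by the number of decoys. The following is a minimal pure-Python sketch of the same bookkeeping, followed by an average over episodes to give win rates. The value `GOAL_TILE = 0` and the helpers `episode_metrics` and `mean_metrics` are hypothetical, and the way RLlib itself aggregates custom metrics may differ from this averaging.

```python
GOAL_TILE = 0  # hypothetical value; the real constant is boxworld.GOAL_TILE

def episode_metrics(inventory, num_decoys, max_decoys=4):
    """Per-episode flags: overall win, plus a win flag keyed by decoy count."""
    win = 1 if GOAL_TILE in inventory else 0
    metrics = {"win": win}
    if 0 <= num_decoys <= max_decoys:
        metrics["win_on_decoys_" + str(num_decoys)] = win
    return metrics

def mean_metrics(episodes):
    """Average each metric over the episodes in which it was recorded."""
    totals, counts = {}, {}
    for metrics in episodes:
        for key, value in metrics.items():
            totals[key] = totals.get(key, 0) + value
            counts[key] = counts.get(key, 0) + 1
    return {key: totals[key] / counts[key] for key in totals}

episodes = [
    episode_metrics([3, GOAL_TILE], num_decoys=0),  # won, no decoys
    episode_metrics([3], num_decoys=0),             # lost, no decoys
    episode_metrics([GOAL_TILE], num_decoys=2),     # won, two decoys
]
print(mean_metrics(episodes))
```

Keying the flag by decoy count means each `win_on_decoys_<i>` average is taken only over episodes with exactly `i` decoys, which is what makes the per-difficulty win rates comparable.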

In [None]:
def on_episode_end(info):
    env = (info["env"].get_unwrapped())[0]
    episode = info["episode"]
    
    if( boxworld.GOAL_TILE in env.inventory ):
      episode.custom_metrics["win"] = 1
    else:
      episode.custom_metrics["win"] = 0    

    for i in range(0,5):
        if( len(env.decoys) == i ):
            if( boxworld.GOAL_TILE in env.inventory ):
                episode.custom_metrics["win_on_decoys_" + str(i)] = 1
            else:
                episode.custom_metrics["win_on_decoys_" + str(i)] = 0

# Capture logging output.
LOG_STDOUT = True

class flushfile():
    def __init__(self, f):
        self.f = f
    def __getattr__(self,name): 
        return object.__getattribute__(self.f, name)
    def write(self, x):
        self.f.write(x)
        self.f.flush()
    def flush(self):
        self.f.flush()

if( LOG_STDOUT ):
  oldstdout = sys.stdout
  sys.stdout = open(LOG_FILENAME, 'w')
  sys.stdout = flushfile(sys.stdout)

# You can try an object_store_memory cap to avoid RayOutOfMemoryError
# on large PBT sessions (e.g. 10 workers), for example object_store_memory=5000000000
ray.init(redis_address=REDIS_ADDRESS, log_to_driver=False)

tune.run(
        "IMPALA",
        stop={
            "timesteps_total": TIMESTEPS_TOTAL,
        },
        config={
            "env": BoxWorld,
            "callbacks": {
                "on_episode_end": tune.function(on_episode_end)
            },
            "model": {
                "custom_model": AGENT,
                "custom_options": {"transformer_model_dim": 64,
                                   "transformer_simplicial_model_dim": 48,
                                   "transformer_num_heads": 2, # default 2 (for 1-simplicial attention)
                                   "transformer_depth": 2, # default 2
                                   "conv_padding": "valid", # default "valid"
                                   "num_virtual_entities": 2
                                  }
            },
            "monitor": False,
            "ignore_worker_failures": True,
            "sample_batch_size": 40,
            "train_batch_size": 1280,
            "num_workers": NUM_WORKERS,
            "num_gpus": NUM_GPUS,
            "log_level": "WARN", # DEBUG INFO WARN ERROR, default is INFO
            "opt_type": "rmsprop",
            "decay": 0.99, # RMSprop
            "momentum": 0.0, # RMSprop
            "epsilon": 0.1, # RMSprop
            "entropy_coeff": 0.005,
            "lr": 2e-4,
            "env_config": {"num_rows":7,
                           "num_cols":9,
                           "min_sol_length":1,
                           "max_sol_length":3,
                           "max_decoy_paths":0,
                           "max_decoy_length":1,
                           "multi_lock":True,
                           "has_bridge":True,
                           "episode_horizon":-1,# -1 for no horizon
                          "monitor_interval":800}, 
        },
        reuse_actors=True,
        checkpoint_freq=50,
        checkpoint_at_end=True,
        max_failures=100,
        resume=False,
)

if( LOG_STDOUT ):
  sys.stdout = oldstdout
    
ray.shutdown()
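The cell above swaps `sys.stdout` by hand and restores it afterwards, so an exception during training would leave stdout pointing at the log file. One possible alternative, sketched here with the standard library only, is a context manager that restores stdout on exit. The name `stdout_to_file` is hypothetical, and this version is line-buffered rather than flushing on every write as `flushfile` does.

```python
import contextlib
import os
import sys
import tempfile

@contextlib.contextmanager
def stdout_to_file(path):
    """Redirect print output to `path`, line-buffered, restoring stdout on exit."""
    # buffering=1 selects line buffering in text mode, so each line is flushed.
    with open(path, "w", buffering=1) as log_file:
        with contextlib.redirect_stdout(log_file):
            yield log_file

# Usage with a temporary file standing in for LOG_FILENAME.
log_path = os.path.join(tempfile.mkdtemp(), "log-example.txt")
original_stdout = sys.stdout
with stdout_to_file(log_path):
    print("training output goes here")
assert sys.stdout is original_stdout  # restored, even if the body had raised

with open(log_path) as f:
    print(f.read().strip())  # training output goes here
```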

## Manual environment driving

Execute the following cell if you want to step through the environment by hand.
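When driving the environment manually it is convenient to type moves as keys rather than action ids. The sketch below maps keys to the action encoding used in `BoxWorld.step` (0 = left, 1 = up, 2 = right, 3 = down). The WASD bindings and the helper names are hypothetical, and `net_displacement` ignores walls and boxes, which the real environment handles in `boxworld.updateState`.

```python
# Action encoding used by BoxWorld.step: 0 = left, 1 = up, 2 = right, 3 = down.
KEY_TO_ACTION = {"a": 0, "w": 1, "d": 2, "s": 3}  # hypothetical WASD bindings
ACTION_TO_MOVE = {0: (-1, 0), 1: (0, -1), 2: (1, 0), 3: (0, 1)}  # (dx, dy)

def actions_from_keys(keys):
    """Translate a string of key presses into action ids, ignoring other characters."""
    return [KEY_TO_ACTION[k] for k in keys.lower() if k in KEY_TO_ACTION]

def net_displacement(actions):
    """Sum the (dx, dy) moves, ignoring walls and boxes."""
    dx = sum(ACTION_TO_MOVE[a][0] for a in actions)
    dy = sum(ACTION_TO_MOVE[a][1] for a in actions)
    return dx, dy

actions = actions_from_keys("ddsw")
print(actions)                    # [2, 2, 3, 1]
print(net_displacement(actions))  # (2, 0)
```

Each resulting action id can be passed to `env.step` in turn.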

In [None]:
# Manually test the environment

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

env = BoxWorld(config={"num_rows":7,
                           "num_cols":9,
                           "min_sol_length":1,
                           "max_sol_length":3,
                           "max_decoy_paths":0,
                           "max_decoy_length":1,
                           "multi_lock":True,
                           "has_bridge":True,
                           "episode_horizon":-1,# -1 for no horizon
                          "monitor_interval":200})
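# Reset repeatedly (this also exercises puzzle generation); the last puzzle is kept
for i in range(1000):
  env.reset()

# step() returns (observation, reward, done, info); action 2 moves right
obs, reward, done, info = env.step(2)

plt.imshow(obs/255)
plt.show()

print(reward)
print(done)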
print(env.inventory)
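The reward and done flag printed by the cell above follow the rule implemented in `BoxWorld.step`. As a quick reference, the sketch below restates that rule as a pure function. The name `step_reward` and its boolean arguments are hypothetical, and the episode horizon is left out.

```python
def step_reward(got_gem, on_decoy, inventory_changed):
    """Return (reward, done) following the rule in BoxWorld.step (horizon ignored)."""
    reward, done = 0.0, False
    if got_gem:            # gem collected: episode ends with +10
        done = True
        reward += 10.0
    if on_decoy:           # decoy square: episode ends with -1
        done = True
        reward -= 1.0
    if inventory_changed and not done:
        reward += 1.0      # proxy for opening a box on the solution path
    return reward, done

print(step_reward(False, False, True))   # (1.0, False)
print(step_reward(True, False, True))    # (10.0, True)
print(step_reward(False, True, False))   # (-1.0, True)
print(step_reward(False, False, False))  # (0.0, False)
```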