# LET's Test Compositionality

Michael Neely & Leila Talha

Natural Language Processing 2, Spring 2020

University of Amsterdam

## Introduction
In many Natural Language Processing (NLP) tasks, data-driven deep learning models have proven more effective than their traditional, symbolic counterparts. Neural networks can process large amounts of data and are therefore more robust to the intrinsic noise of natural language. However, does their success on isolated tasks mean they genuinely understand the languages they are processing? Or are they merely exploiting statistical patterns?
        
The principle of compositionality is a fundamental feature of natural language and serves as a rigorous test of the robustness of a model's learned representation. In its broadest definition, compositionality states that "the meaning of a whole is a function of the meanings of the parts and of the way they are syntactically combined" ([Kamp and Partee, 1995](https://doi.org/10.1016/0010-0277(94)00659-9)). Note that this statement refers only to language itself, not to the behavior of an entity using the language. To address this gap, [Hupkes et al. (2019)](https://arxiv.org/abs/1908.08351) propose a suite of behavioral tests designed to evaluate the compositionality of a neural model across five dimensions. Two of the five are of particular note: systematicity (whether models systematically recombine known parts and rules) and localism (whether models' composition operations are local or global).

In this research, we evaluate the compositionality of recurrent and attention-based neural models in terms of both localism and systematicity in the context of a sequence-to-sequence machine translation task on an artificial compositional language, which we call PCFG-LET.

## Setup

### <a name="install"></a> Install Dependencies

**Warning: after installation, the runtime must be restarted before OpenNMT can be imported.**

In [None]:
!git clone https://github.com/OpenNMT/OpenNMT-py.git
%cd OpenNMT-py
!python setup.py install
%cd ..
!pip install torchtext==0.4.0 treelib

### Optional: mount Google Drive storage for persistence

Ensure you have a folder in the root of your drive named `pcfglet`.

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

### Import libraries



In [None]:
# standard library
import argparse
from collections import Counter, defaultdict, OrderedDict
import glob
import logging
import math
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Union
import uuid

# external libraries
from IPython.display import Image, display
import matplotlib.pyplot as plt
from matplotlib import tight_layout
import numpy as np
import onmt
import onmt.opts
import onmt.translate
import pandas as pd
import seaborn as sns
from tqdm.auto import tqdm
import torch
import treelib

### Optional: load existing datasets and models from Drive

In [None]:
!cp -r "drive/My Drive/pcfglet/" ./

## PCFG List-Edit-Task (PCFG-LET)

It is common practice to evaluate the compositionality of neural models in isolation by carefully designing artificial languages which exhibit desirable compositional phenomena. In theory, by controlling the training data in this manner, models should only be able to successfully generalize if they have formed an appropriate compositional representation or strategy.

As in [Hupkes et al. (2019)](https://arxiv.org/abs/1908.08351), the input alphabet of PCFG-LET consists of four categories of words: numerical characters which define a list, unary and binary functions which manipulate those lists, a comma (`,`) which separates binary arguments, and a special symbol `E` that represents an empty list. These words are combined to generate input sequences that define a series of operations applied to the list argument(s). Note that, by default, the empty token can only appear in intermediate or output sequences.

The task for a particular network is then to translate inputs into interpreted outputs by recursively applying the interpretation functions. Success should only be possible if networks are **local** in their operation and **systematic** in their approach, since the functions are highly sensitive to the order of lists and arguments. The functions of the list-edit-task also prevent sequence explosion, the phenomenon in which applications of interpretation functions lead to excessively long sequences. [Hupkes et al. (2019)](https://arxiv.org/abs/1908.08351) highlight this as an unfortunate side effect of their PCFG-SET.
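To make the recursive interpretation concrete, here is a minimal, self-contained sketch that evaluates a prefix-notation command. It is not the notebook's implementation: the `interpret` helper and its two-entry function tables are assumptions made for illustration, and it ignores the empty token entirely.

```python
from typing import Callable, Dict, List, Tuple

# Illustrative subset only: one unary and one binary function.
UNARY: Dict[str, Callable[[List[int]], List[int]]] = {
    'max': lambda xs: [max(xs)],
}
BINARY: Dict[str, Callable[[List[int], List[int]], List[int]]] = {
    # keep the elements of x that do not occur in y (order-sensitive)
    'filter': lambda x, y: [e for e in x if e not in y],
}

def interpret(tokens: List[str], pos: int = 0) -> Tuple[List[int], int]:
    '''Evaluate the expression starting at tokens[pos]; return (value, next position).'''
    token = tokens[pos]
    if token in UNARY:
        arg, pos = interpret(tokens, pos + 1)
        return UNARY[token](arg), pos
    if token in BINARY:
        left, pos = interpret(tokens, pos + 1)
        assert tokens[pos] == ',', 'binary arguments must be separated by a comma'
        right, pos = interpret(tokens, pos + 1)
        return BINARY[token](left, right), pos
    # otherwise read a run of integers, which forms a list argument
    values = []
    while pos < len(tokens) and tokens[pos].isdigit():
        values.append(int(tokens[pos]))
        pos += 1
    return values, pos

result, _ = interpret('filter 4 8 15 , max 8 3'.split())
print(result)  # [4, 15]
```

The inner `max 8 3` is resolved first and its result becomes the second argument of `filter`. Swapping the two arguments of `filter` changes the answer, which is the order sensitivity the task relies on.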

### Define PCFG-LET as code

Input alphabet and interpretation functions

In [None]:
# numbers from 1 to 519 (inclusive)
VOCAB = list(range(1,520))

# some operations may produce empty lists. Introduce a token to represent this
EMPTY_TOKEN = 'E'

# global that determines whether the empty token is included in input-sequences.
# set to True if testing for systematicity.
INCLUDE_EMPTY_TOKEN = False

# Type: a token will either be an empty token (E), argument separator (,), function name, or vocabulary character
Token = str

# Type: a parsed token will either be an empty token (E), argument separator (,), unary function, binary function, or list of integers
ParsedToken = Union[str, Callable, List[int]]

####################
# helper functions #
####################

def list_to_string(lst: List[Any]) -> str:
  '''Convert a list to space-separated string'''
  return ' '.join(map(str, lst))

def split_list(given_list: List[Any], segments: int = 1) -> List[Any]:
  '''Given a list, split into the defined number of segments'''
  length = len(given_list)
  return [ given_list[i*length // segments: (i+1)*length // segments] for i in range(segments) ]

def is_valid_pcfg_function_name(name: str) -> bool:
  '''Test if a given string is a valid PCFG-LET unary or binary function'''
  return isinstance(name, str) and (hasattr(UnaryFunctions, name) or hasattr(BinaryFunctions, name))

##############################################
# PCFG-LET Unary (single-argument) functions #
##############################################
class UnaryFunctions:

  @staticmethod
  def min(fn_input: ParsedToken) -> ParsedToken:
    '''Return the lowest character in the list'''
    fn_input = list(filter(lambda elem: elem != EMPTY_TOKEN, fn_input))
    if fn_input == EMPTY_TOKEN or fn_input == []:
      return EMPTY_TOKEN
    return [min(fn_input)]

  @staticmethod
  def max(fn_input: ParsedToken) -> ParsedToken:
    '''Return the highest character in the list'''
    fn_input = list(filter(lambda elem: elem != EMPTY_TOKEN, fn_input))
    if fn_input == EMPTY_TOKEN or fn_input == []:
      return EMPTY_TOKEN
    return [max(fn_input)]

  @staticmethod
  def unique(fn_input: ParsedToken) -> ParsedToken:
    '''Remove duplicate occurrences of characters'''
    if fn_input == EMPTY_TOKEN:
      return EMPTY_TOKEN
    fn_input = list(filter(lambda elem: elem != EMPTY_TOKEN, fn_input))
    return list(OrderedDict.fromkeys(fn_input))

  @staticmethod
  def remove_unique(fn_input: ParsedToken) -> ParsedToken:
    '''Remove characters that do not occur at least twice'''
    if fn_input == EMPTY_TOKEN:
      return EMPTY_TOKEN
    fn_input = list(filter(lambda elem: elem != EMPTY_TOKEN, fn_input))
    counts = Counter(fn_input)
    out = [idx for idx, count in counts.items() if count != 1]
    return out or EMPTY_TOKEN
  
  @staticmethod
  def remove_repeated(fn_input: ParsedToken) -> ParsedToken:
    '''Remove all occurrences of characters that occur more than once'''
    if fn_input == EMPTY_TOKEN:
      return EMPTY_TOKEN
    fn_input = list(filter(lambda elem: elem != EMPTY_TOKEN, fn_input))
    counts = Counter(fn_input)
    out = [idx for idx, count in counts.items() if count == 1]
    return out or EMPTY_TOKEN

  @staticmethod
  def mirror(fn_input: ParsedToken) -> ParsedToken:
    '''Swap the two halves of the list; for odd lengths the middle element stays in place'''
    if fn_input == EMPTY_TOKEN:
      return EMPTY_TOKEN
    fn_input = list(filter(lambda elem: elem != EMPTY_TOKEN, fn_input))
    a, b = split_list(fn_input, segments=2)
    if len(fn_input) % 2 == 0:
      return b + a
    else:
      return b[1:] + [b[0]] + a

############################################
# PCFG-LET Binary (two-argument) functions #
############################################
class BinaryFunctions():

  @staticmethod
  def filter(x: ParsedToken, y: ParsedToken) -> ParsedToken:
    '''Order-sensitive: return all characters in x that do not occur in y'''
    if x == EMPTY_TOKEN:
      return EMPTY_TOKEN
    elif y == EMPTY_TOKEN:
      return x
    elif x == EMPTY_TOKEN and y == EMPTY_TOKEN:
      return EMPTY_TOKEN
    else:
      x = list(filter(lambda elem: elem != EMPTY_TOKEN, x))
      filtered = [elem for elem in x if elem not in y]
      return filtered or EMPTY_TOKEN

  @staticmethod
  def union(x: ParsedToken, y: ParsedToken) -> ParsedToken:
    '''Symmetric: return unique characters that occur in x or y'''
    if x == EMPTY_TOKEN:
      return y
    if y == EMPTY_TOKEN:
      return x
    if x == EMPTY_TOKEN and y == EMPTY_TOKEN:
      return EMPTY_TOKEN
    unique_to_y = [elem for elem in y if elem not in x]
    return UnaryFunctions.unique(x + unique_to_y)

  @staticmethod
  def intersection(x: ParsedToken, y:ParsedToken) -> ParsedToken:
    '''Symmetric: return unique characters that occur in both x and y'''
    if x == EMPTY_TOKEN:
      return y
    if y == EMPTY_TOKEN:
      return x
    if x == EMPTY_TOKEN and y == EMPTY_TOKEN:
      return EMPTY_TOKEN
    x_in_y = [elem for elem in x if elem in y]
    y_in_x = [elem for elem in y if elem in x]
    combined = x_in_y + y_in_x
    return UnaryFunctions.unique(x_in_y + y_in_x) if combined else EMPTY_TOKEN

  @staticmethod
  def difference(x: ParsedToken, y:ParsedToken) -> ParsedToken:
    '''Symmetric: return the unique disjunctive union of x and y'''
    unique_to_x = BinaryFunctions.filter(x, y)
    unique_to_y = BinaryFunctions.filter(y, x)
    if unique_to_x == EMPTY_TOKEN and unique_to_y != EMPTY_TOKEN:
      return unique_to_y
    elif unique_to_y == EMPTY_TOKEN and unique_to_x != EMPTY_TOKEN:
      return unique_to_x
    elif unique_to_x == EMPTY_TOKEN and unique_to_y == EMPTY_TOKEN:
      return EMPTY_TOKEN
    return UnaryFunctions.unique(unique_to_x + unique_to_y)

### Define functions to generate PCFG-LET commands

In [None]:
###### The CFG ######

## Non-terminal rules:
# S --> Fu S | Fb S, S
# S --> X              
# X --> X X           

## Lexical rules:
# Fu --> min | max | remove_unique | remove_repeated | mirror
# Fb --> filter | union | intersection | difference 
# X --> 1 | 2 | ... | 518 | 519

#####################

Fu = {'min':.2,'max':.2, 'remove_unique':.2,'remove_repeated':.2, 'mirror':.2}
Fb = {'filter':.25, 'union':.25, 'intersection':.25, 'difference':.25}


def S_rule(command):
  '''Performs one of the S --> . rules from the PCFG, with probabilities [.6, .4]'''  
  if command == []:

    start = np.random.choice([Fu, Fb])
    func_name = np.random.choice(list(start.keys()), p=list(start.values()))
    if start == Fu:
      return [func_name, S_rule(['S'])]
    else:
      return [func_name, S_rule(['S']), ',', S_rule(['S'])]
  
  else:

    next_rule = np.random.choice(['S', 'X'], p=[.6, .4])
    if next_rule == 'S':
      return S_rule([])
    else:
      return X_rule(['X'])


def X_rule(command):
  '''Performs one of the X --> . rules from the PCFG, with probabilities [.45, .55]'''
  while 'X' in command:

    idx_1st_X = command.index('X')
    next_rule = np.random.choice(['X X', 'X'], p=[.45,.55])

    if next_rule == 'X X':
      command.insert(idx_1st_X, 'X')
      return (X_rule(command))
    else:
      if INCLUDE_EMPTY_TOKEN:
        symbol = np.random.choice([EMPTY_TOKEN] + VOCAB)
        if symbol != EMPTY_TOKEN: symbol = int(symbol)
      else:
        symbol = int(np.random.choice(VOCAB))
      command[idx_1st_X] = symbol
    
  return command


def compute_output(command):
  '''Computes the output of a command, given as a nested list'''
  # np.int was only an alias for the built-in int and has been removed from recent NumPy releases
  while isinstance(command, list) and \
  not all(isinstance(x, int) or x==EMPTY_TOKEN for x in command):
    
    func_name = command[0]
    S1 = compute_output(command[1])

    if func_name in list(Fu.keys()):
      return getattr(UnaryFunctions, func_name)(S1)
    else:
      S2 = compute_output(command[3])

      return getattr(BinaryFunctions, func_name)(S1, S2)

  return command

# Used for localism analysis
def recurse_unroll(command: List[Any], wildcards: List[str], steps: List[Any]) -> None:
  '''Side-effect function only used by get_unrolled_steps'''
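  # treat the empty token like a literal, as compute_output does; np.int is replaced by the built-in int
  while isinstance(command, list) and not all(isinstance(x, int) or x == EMPTY_TOKEN for x in command):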
    func_name = command[0]
    S1 = recurse_unroll(command[1], wildcards, steps)
    S1 = list_to_string(S1) if isinstance(S1, list) else S1
    if func_name in list(Fu.keys()):
      next_wildcard = wildcards.pop()
      steps.append(f'{func_name} {S1}|{next_wildcard}')
    else:
      S2 = recurse_unroll(command[3], wildcards, steps)
      S2 = list_to_string(S2) if isinstance(S2, list) else S2
      next_wildcard = wildcards.pop()
      steps.append(f'{func_name} {S1} , {S2}|{next_wildcard}')
    return next_wildcard
  return command

def get_unrolled_steps(raw_command: List[Any], str_command: str) -> List[str]:
  '''Given a raw command and its string version, return a list of the unrolled steps required to solve the command recursively
     e.g. difference remove_unique min 17 , 23 --> ['min 17|*1', 'remove_unique *1|*2', 'difference *2 , 23|*3']'''
  num_functions = sum([is_valid_pcfg_function_name(x) for x in str_command.split(' ')])
  wildcards = list(reversed([f'*{i}' for i in range(1, num_functions + 1)]))
  unrolled_steps = []
  recurse_unroll(raw_command, wildcards, unrolled_steps)
  return unrolled_steps


def depth(l):
  '''Recursively compute depth of nested list
     source: https://stackoverflow.com/questions/6039103/counting-depth-or-the-deepest-level-a-nested-list-goes-to'''
  if isinstance(l, list):
      return 1 + max(depth(item) for item in l)
  else:
      return 0

def flatten(A):
  '''Flattens a list of strings and list, source:
     https://stackoverflow.com/questions/17864466/flatten-a-list-of-strings-and-lists-of-strings-and-lists-in-python'''
  rt = []
  for i in A:
    if isinstance(i,list): rt.extend(flatten(i))
    else: rt.append(i)

  return rt


def generate_command() -> Tuple[str, int, List[str], List[str]]:
  '''Returns a new command as a string, along with its depth, output, and unrolled steps'''
  raw_command = S_rule([])
  command = ' '.join([str(x) for x in flatten(raw_command)])
  # recursive depth function over-reports by 1 e.g. ['max', 1] == depth of 2
  command_depth = depth(raw_command) - 1 
  output = compute_output(raw_command)
  unrolled_steps = get_unrolled_steps(raw_command, command)

  return command, command_depth, output, unrolled_steps


# Samples n commands from the PCFG
def get_commands(n):
  for _ in range(n):
    yield generate_command()

### Generate some sample commands

In [None]:
for command in get_commands(10):
  command_string, command_depth, output, unrolled_steps = command
  print(f'Command: {command_string}\nDepth: {command_depth}\nOutput: {output}\nUnrolled Steps: {unrolled_steps}\n\n')


## Generate Dataset

We create a dataset of 100,000 input-output pairs by sampling from the grammar. To eliminate any possibility of memorization, we ensure each sample is unique. By unique, we mean that the arguments to any given function are **never repeated**. We do **not** perform any form of naturalization, and do **not** limit the length of list arguments given to the functions, but we do cap input sequences at 50 tokens for performance. Data is split into portions of 85%, 5%, and 10% for training, validation, and testing, respectively.
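The uniqueness criterion can be illustrated with a small, self-contained sketch. It is a simplified reading of the rule rather than the generation code used in this notebook: `FUNCTION_NAMES`, `direct_arguments`, and `is_novel` are hypothetical names, and only the literal list that directly follows a function name counts as its arguments.

```python
from typing import List, Set, Tuple

# Hypothetical subset of function names, for illustration only.
FUNCTION_NAMES = {'min', 'max', 'mirror', 'filter', 'union'}

def direct_arguments(command: str) -> List[Tuple[str, str]]:
    '''Pair each function with the literal list that immediately follows it.'''
    tokens = command.split(' ')
    pairs = []
    for idx, token in enumerate(tokens):
        if token not in FUNCTION_NAMES:
            continue
        args = []
        for nxt in tokens[idx + 1:]:
            # stop at the next function name or at the argument separator
            if nxt in FUNCTION_NAMES or nxt == ',':
                break
            args.append(nxt)
        if args:
            pairs.append((token, ' '.join(args)))
    return pairs

def is_novel(command: str, seen: Set[Tuple[str, str]]) -> bool:
    '''Accept a command only if none of its (function, arguments) pairs was seen before.'''
    pairs = direct_arguments(command)
    if any(pair in seen for pair in pairs):
        return False
    seen.update(pairs)
    return True

seen: Set[Tuple[str, str]] = set()
print(is_novel('max 4 8', seen))               # True: first use of ('max', '4 8')
print(is_novel('filter 1 2 , max 4 8', seen))  # False: ('max', '4 8') was already used
```

Rejecting a whole command when any of its pairs repeats is stricter than checking only for duplicate command strings, so a model cannot solve a test command by recalling a sub-expression it saw during training.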

### Define dataset generation code

In [None]:
def ensure_dir(file_path: str) -> None:
  '''Create a directory at the provided path if one does not already exist'''
  directory = os.path.dirname(file_path)
  if not os.path.exists(directory):
    os.makedirs(directory)

def load_dataset_frames() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
  '''Load train, test, validation, and full dataset from the persistence folder'''
  train = pd.read_pickle('pcfglet/train.pkl')
  val = pd.read_pickle('pcfglet/val.pkl')
  test = pd.read_pickle('pcfglet/test.pkl')
  full_set = pd.read_pickle('pcfglet/all.pkl')
  return train, test, val, full_set

def generate_dataset(n_samples: int = 100000, ratios: Tuple[float, float, float] = (0.85, 0.05, 0.1)) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
  '''Builds a full dataset of the given number of samples and splits it into train, validation, and test sets per the given ratios'''
  # compare with a tolerance: floating-point ratios need not sum to exactly 1
  assert math.isclose(sum(ratios), 1.0), "split-ratios do not sum to 1."

  ensure_dir('pcfglet/')

  # generate unique samples: make sure that the function arguments are never repeated
  seen_arguments = {k: defaultdict() for k in list(Fu.keys()) + list(Fb.keys())}
  columns = {'command': [], 'length': [],  'depth': [], 'num_functions': [], 'target': [], 'unrolled_steps': []}

  with tqdm(total=n_samples) as pbar:
    pbar.set_description('Unique commands generated:')
    while len(columns['command']) < n_samples:
      command_tuple = get_commands(1)
      command, depth, target, unrolled_steps = next(command_tuple)
      command_as_list = command.split(' ')
      if len(command_as_list) > 50:
        continue
      is_duplicate = False
      for idx, token in enumerate(command_as_list):
        if is_valid_pcfg_function_name(token):
          remaining_tokens = command_as_list[idx + 1:]
          args = []
          for remaining_token in remaining_tokens:
            if is_valid_pcfg_function_name(remaining_token) or remaining_token == ',':
              break
            else:
              args.append(remaining_token)
          if len(args) > 0:
            args = list_to_string(args)
            found = seen_arguments[token].get(args)
            if found:
              is_duplicate = True
              continue
            else:
              seen_arguments[token][args] = 1
      if is_duplicate:
        continue
      pbar.update(1)
      command_length = len(command_as_list)
      num_functions = sum([is_valid_pcfg_function_name(tok) for tok in command_as_list])
      target = list_to_string(target)
      columns['command'].append(command)
      columns['length'].append(command_length)
      columns['depth'].append(depth)
      columns['num_functions'].append(num_functions)
      columns['target'].append(target)
      columns['unrolled_steps'].append(unrolled_steps)

  df = pd.DataFrame.from_records(columns)

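  # Raw dataframes for analysis
  # ratios are ordered (train, val, test): after shuffling, the first slice is the training set,
  # the middle slice is the test set, and the final slice (of size ratios[1]) is the validation set
  train, test, val = np.split(df.sample(frac=1), [int(ratios[0]*len(df)), int((1-ratios[1])*len(df))])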
  
  train.to_pickle('pcfglet/train.pkl')
  test.to_pickle('pcfglet/test.pkl')
  val.to_pickle('pcfglet/val.pkl')
  df.to_pickle('pcfglet/all.pkl')

  # Commands and targets only for OpenNMT scripts
  train.to_csv('pcfglet/src-train.txt', sep=';', columns=['command'], index=False, header=False, mode='w+')
  train.to_csv('pcfglet/tgt-train.txt', sep=';', columns=['target'], index=False, header=False, mode='w+')
  test.to_csv('pcfglet/src-test.txt', sep=';', columns=['command'], index=False, header=False, mode='w+')
  test.to_csv('pcfglet/tgt-test.txt', sep=';', columns=['target'], index=False, header=False, mode='w+')
  val.to_csv('pcfglet/src-val.txt', sep=';', columns=['command'], index=False, header=False, mode='w+')
  val.to_csv('pcfglet/tgt-val.txt', sep=';', columns=['target'], index=False, header=False, mode='w+')

  return train, test, val, df

### Generate a new dataset if one was not imported


In [None]:
data_files = [
  'pcfglet/train.pkl',
  'pcfglet/val.pkl',
  'pcfglet/test.pkl',
  'pcfglet/all.pkl',
  'pcfglet/src-train.txt',
  'pcfglet/tgt-train.txt',
  'pcfglet/src-val.txt',
  'pcfglet/tgt-val.txt',
  'pcfglet/src-test.txt',
  'pcfglet/tgt-test.txt'
]

# generate a new dataset if all of the required files are not present
if all([Path(f).is_file() for f in data_files]):
  print('Dataset already exists. Loading from disk.')
  TRAIN_SET, TEST_SET, VAL_SET, ALL = load_dataset_frames()
else:
  # Use same split as Hupkes et al. 2019
  TRAIN_SET, TEST_SET, VAL_SET, ALL = generate_dataset()


#### Confirm the dataset matches expectations



In [None]:
assert ALL.command.is_unique
assert TRAIN_SET.command.is_unique
assert VAL_SET.command.is_unique
assert TEST_SET.command.is_unique
assert len(ALL) == 100000
assert len(TRAIN_SET) == 85000
assert len(VAL_SET) == 5000
assert len(TEST_SET) == 10000

### Visualise dataset

In [None]:
def plot_dataset_distribution(dataset: pd.DataFrame):
  dist_plot = sns.jointplot(
      x='depth',
      y='length',
      data=dataset,
      kind="kde",
      xlim=(1,3),
      ylim=(1,20),
      space=0)
  dist_plot.ax_joint.set_xticks([1,2,3])
  dist_plot.ax_joint.set_yticks([2,10,20,])
  plt.savefig('pcfglet/distribution_plot.png')

if not Path('pcfglet/distribution_plot.png').is_file():
  plot_dataset_distribution(ALL)
else:
  display(Image('pcfglet/distribution_plot.png'))

### Preprocess data with OpenNMT

In [None]:
processed_files = [
  'pcfglet/processed.train.0.pt',
  'pcfglet/processed.valid.0.pt',
  'pcfglet/processed.vocab.pt'
]

# no need to pre-process if the files were imported from Drive
if not all([Path(f).is_file() for f in processed_files]):
  !python OpenNMT-py/preprocess.py -train_src pcfglet/src-train.txt -train_tgt pcfglet/tgt-train.txt -valid_src pcfglet/src-val.txt -valid_tgt pcfglet/tgt-val.txt -save_data pcfglet/processed -src_seq_length 100 -tgt_seq_length 100 -overwrite -filter_valid

### Optional Checkpoint: persist dataset and preprocessed files to Drive

In [None]:
# !cp -r pcfglet/processed* "drive/My Drive/pcfglet/"
# !cp -r pcfglet/*.pkl "drive/My Drive/pcfglet/"
# !cp -r pcfglet/src-*.txt "drive/My Drive/pcfglet/"
# !cp -r pcfglet/tgt-*.txt "drive/My Drive/pcfglet/"
# !cp -r pcfglet/*.png "drive/My Drive/pcfglet/"

## Train Models

Models examined in the literature tend to fall into three categories: recurrent, convolutional, or purely attention-based. Initial experimentation with convolutional models led to consistently poor results. Taking this as evidence of the order-sensitive nature of the task, we chose to study the other architectures.

**Model 1: LSTMS2S**

The sequential processing nature of LSTM models is ideal for PCFG-LET. We select a fully recurrent, bidirectional model with attention similar to the LSTMS2S architecture of [Hupkes et al. (2019)](https://arxiv.org/abs/1908.08351). After careful hyperparameter tuning, we settled on a 512-dimensional word vector size with scaled dot-product attention ([Vaswani et al. 2017](https://arxiv.org/abs/1706.03762)) and a batch size of 64 sequences. We train the model for 25 epochs or until convergence (five successive epochs with no improvement in word accuracy and perplexity on the validation set), using the Adam optimizer with a standard learning rate of 0.001.
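The step counts passed to OpenNMT in the training cells (`-train_steps 33200`, `-valid_steps 1328`) follow from this epoch budget. A minimal sketch of the arithmetic, assuming one validation pass per epoch and counting only full batches (the variable names are illustrative):

```python
n_train = 85_000      # 85% of the 100,000 generated samples
batch_size = 64       # sequences per batch
epochs = 25

# number of full batches in one pass over the training set
steps_per_epoch = n_train // batch_size
train_steps = steps_per_epoch * epochs

print(steps_per_epoch, train_steps)  # 1328 33200
```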

**Model 2: Transformer**

Transformers discard recurrent cells in favor of a purely attention-based approach. Such a design is advantageous for processing longer sequences. By inferring order from position encodings, the cost of relating tokens remains uniform regardless of their distance. Unlike LSTMs, additional stacked layers in a transformer improve hierarchical modeling capability. A drawback of the Transformer model is its sensitivity to hyperparameter tuning. We experimented with different settings, but ultimately found the set chosen by [Hupkes et al. (2019)](https://arxiv.org/abs/1908.08351) to be optimal. The only difference between our Transformer model and that of Hupkes et al. is a reduced word vector dimensionality of 256. Like the LSTMS2S, the Transformer model is trained for 25 epochs or until convergence.
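As an illustration of how order can be inferred from position encodings, the sketch below computes the fixed sinusoidal encodings described by Vaswani et al. (2017). It is written in plain Python for readability, and `sinusoidal_position_encoding` is a hypothetical helper. Whether the `-position_encoding` flag in OpenNMT lays out the sine and cosine dimensions in exactly this order is an assumption that has not been checked here.

```python
import math
from typing import List

def sinusoidal_position_encoding(max_len: int, d_model: int) -> List[List[float]]:
    '''Return a max_len x d_model table of fixed sinusoidal position encodings.'''
    encoding = []
    for pos in range(max_len):
        row = []
        for dim in range(d_model):
            # dimensions 2i and 2i+1 share the frequency 1 / 10000^(2i / d_model)
            angle = pos / (10000 ** ((dim - dim % 2) / d_model))
            # even dimensions use sine, odd dimensions use cosine
            row.append(math.sin(angle) if dim % 2 == 0 else math.cos(angle))
        encoding.append(row)
    return encoding

# 50 positions (the input length cap) and a 256-dimensional model
pe = sinusoidal_position_encoding(max_len=50, d_model=256)
print(len(pe), len(pe[0]))  # 50 256
```

Each position receives a distinct vector that is added to the word embedding, so attention can relate any two tokens in a single step while still distinguishing their order.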

Define checkpoint manipulation code

In [None]:
# Model checkpoints are saved as {type}_{run id}_step_{stepnumber}.pt
# e.g. "lstms2s_1_step_6640.pt"

def get_checkpoint_files_in_dir(path: str) -> List[str]:
  '''Find all OpenNMT model checkpoint files in a given directory'''
  model_checkpoints = f'{path}/*_step_*.pt'
  hits = glob.glob(model_checkpoints)
  return hits

def get_best_model_step_from_trace(trace_path: str) -> int:
  '''Extract the best model step number from an OpenNMT training trace file'''
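  backup = None # return latest model if training didn't converge
  # read the trace inside a context manager so the file handle is closed
  with open(trace_path, 'r') as trace_file:
    lines = trace_file.readlines()
  for line in reversed(lines):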
    stripped_line = line.rstrip()
    if 'best model found' in stripped_line.lower():
      return int(stripped_line.split(' ')[-1])
    elif 'saving checkpoint' in stripped_line.lower():
      backup = int(stripped_line.split('_')[-1].split('.')[0])
  return backup

def select_best_model(model_type: str, run_id: int) -> str:
  '''Find the best performing OpenNMT model checkpoint file by analyzing the trace'''
  best_step = get_best_model_step_from_trace(f'pcfglet/{model_type}_{run_id}_trace.txt')
  best_model = None
  deleted = 0
  for hit in get_checkpoint_files_in_dir('pcfglet'):
    _, checkpoint_file = os.path.split(hit)
    file_name, _ = os.path.splitext(checkpoint_file)
    found_model_type, found_run_id, _, found_step = file_name.split('_')
    if found_model_type == model_type and int(found_run_id) == run_id:
      if int(found_step) == best_step:
        best_model = hit
      else:
        os.remove(hit)
        deleted += 1
  print(f'Removed {deleted} checkpoints')
  return best_model

def model_from_run_exists(path: str, model_type: str, run_id: int) -> bool:
  '''Check if a model has already been trained for the given type and run id'''
  for hit in get_checkpoint_files_in_dir(path):
    _, checkpoint_file = os.path.split(hit)
    name_segments = checkpoint_file.split('_')
    if name_segments[0] == model_type and int(name_segments[1]) == run_id:
      return True
  return False

### Train LSTMS2S

Train an LSTM with the specified settings for 25 epochs or until convergence (word accuracy has not improved for five consecutive validation steps). The step number of the best model in each trial is extracted from the trace file, and all other checkpoints are deleted.
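The convergence criterion amounts to early stopping with a patience of five. The sketch below shows the idea on a single metric. It is a simplification, not OpenNMT's implementation: `best_step_with_patience` is a hypothetical helper and the accuracy values are made up purely for illustration.

```python
from typing import List, Optional

def best_step_with_patience(accuracies: List[float], patience: int = 5) -> Optional[int]:
    '''Return the index of the best validation step, stopping once `patience`
    consecutive validations bring no improvement.'''
    best_idx, best_acc, stale = None, float('-inf'), 0
    for idx, acc in enumerate(accuracies):
        if acc > best_acc:
            # new best: remember it and reset the patience counter
            best_idx, best_acc, stale = idx, acc, 0
        else:
            stale += 1
            if stale >= patience:
                break  # training would stop here
    return best_idx

# made-up validation accuracies, one per validation step
history = [61.0, 70.5, 74.2, 74.0, 73.9, 74.1, 74.2, 74.0, 80.0]
print(best_step_with_patience(history))  # 2
```

In this toy history the final value is never reached, because five validations in a row fail to beat the accuracy recorded at index 2. This is why the best checkpoint, and not the last one, is kept.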

#### Trial 1

In [None]:
if not model_from_run_exists('pcfglet', 'lstms2s', 1):
  !python OpenNMT-py/train.py \
    -data pcfglet/processed \
    -save_model "pcfglet/lstms2s_1" \
    -word_vec_size 512 \
    -encoder_type brnn \
    -global_attention dot \
    -batch_size 64 \
    -optim adam \
    -learning_rate 0.001 \
    -train_steps 33200 \
    -valid_steps 1328 \
    -save_checkpoint_steps 1328 \
    -early_stopping 5 \
    -world_size 1 \
    -gpu_ranks 0 2>&1 | tee pcfglet/lstms2s_1_trace.txt
    
else:
  print('A first run LSTM model has already been trained')

BEST_LSTM_RUN_1 = select_best_model('lstms2s', 1)
%env BEST_LSTM_RUN_1=$BEST_LSTM_RUN_1

#### Trial 2

In [None]:
if not model_from_run_exists('pcfglet', 'lstms2s', 2):
  !python OpenNMT-py/train.py \
    -data pcfglet/processed \
    -save_model "pcfglet/lstms2s_2" \
    -word_vec_size 512 \
    -encoder_type brnn \
    -global_attention dot \
    -batch_size 64 \
    -optim adam \
    -learning_rate 0.001 \
    -train_steps 33200 \
    -valid_steps 1328 \
    -save_checkpoint_steps 1328 \
    -early_stopping 5 \
    -world_size 1 \
    -gpu_ranks 0 2>&1 | tee pcfglet/lstms2s_2_trace.txt
    
else:
  print('A second run LSTM model has already been trained')


BEST_LSTM_RUN_2 = select_best_model('lstms2s', 2)
%env BEST_LSTM_RUN_2=$BEST_LSTM_RUN_2

#### Trial 3

In [None]:
if not model_from_run_exists('pcfglet', 'lstms2s', 3):
  !python OpenNMT-py/train.py \
    -data pcfglet/processed \
    -save_model "pcfglet/lstms2s_3" \
    -word_vec_size 512 \
    -encoder_type brnn \
    -global_attention dot \
    -batch_size 64 \
    -optim adam \
    -learning_rate 0.001 \
    -train_steps 33200 \
    -valid_steps 1328 \
    -save_checkpoint_steps 1328 \
    -early_stopping 5 \
    -world_size 1 \
    -gpu_ranks 0 2>&1 | tee pcfglet/lstms2s_3_trace.txt
    
else:
  print('A third run LSTM model has already been trained')


BEST_LSTM_RUN_3 = select_best_model('lstms2s', 3)
%env BEST_LSTM_RUN_3=$BEST_LSTM_RUN_3

#### Optional Checkpoint: persist trained LSTM models and traces to Drive




In [None]:
# !cp -r pcfglet/lstms2s_* "drive/My Drive/pcfglet/"

#### Translate test set

In [None]:
if not Path('pcfglet/pred_lstms2s_run_1.txt').is_file():
  !python OpenNMT-py/translate.py \
  -model "$BEST_LSTM_RUN_1" \
  -src pcfglet/src-test.txt \
  -tgt pcfglet/tgt-test.txt \
  -output pcfglet/pred_lstms2s_run_1.txt

if not Path('pcfglet/pred_lstms2s_run_2.txt').is_file():
  !python OpenNMT-py/translate.py \
  -model "$BEST_LSTM_RUN_2" \
  -src pcfglet/src-test.txt \
  -tgt pcfglet/tgt-test.txt \
  -output pcfglet/pred_lstms2s_run_2.txt

if not Path('pcfglet/pred_lstms2s_run_3.txt').is_file():
  !python OpenNMT-py/translate.py \
  -model "$BEST_LSTM_RUN_3" \
  -src pcfglet/src-test.txt \
  -tgt pcfglet/tgt-test.txt \
  -output pcfglet/pred_lstms2s_run_3.txt

#### Optional Checkpoint: persist Predictions to Drive

In [None]:
# !cp -r pcfglet/pred_lstms2s_* "drive/My Drive/pcfglet/"

### Train Transformer

Train a transformer with the specified settings for approximately 25 epochs or until convergence (word accuracy has not improved for five consecutive validation steps). The step number of the best model in each trial is extracted from the trace file, and all other checkpoints are deleted.
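The training cells for the Transformer combine `-decay_method noam`, `-warmup_steps 8000`, and `-learning_rate 1`. The sketch below shows how the learning rate would evolve under the Noam schedule of Vaswani et al. (2017). It assumes that the model size entering the formula is the 256 given by `-rnn_size`, and `noam_learning_rate` is a hypothetical helper, not an OpenNMT function.

```python
def noam_learning_rate(step: int, model_size: int = 256,
                       warmup_steps: int = 8000, base_lr: float = 1.0) -> float:
    '''Noam schedule: linear warm-up followed by inverse square-root decay (step >= 1).'''
    scale = min(step ** -0.5, step * warmup_steps ** -1.5)
    return base_lr * model_size ** -0.5 * scale

# the rate rises linearly until the warm-up ends, then decays
for step in (1, 4000, 8000, 33200):
    print(step, f'{noam_learning_rate(step):.2e}')
```

Under these assumptions the rate peaks at the end of warm-up, which explains why a nominal learning rate of 1 is usable: the effective rate stays well below it throughout training.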

#### Trial 1

In [None]:
if not model_from_run_exists('pcfglet', 'transformer', 1):
  !python OpenNMT-py/train.py \
    -data pcfglet/processed \
    -save_model "pcfglet/transformer_1"  \
    -layers 6 \
    -rnn_size 256 \
    -word_vec_size 256 \
    -transformer_ff 2048 \
    -heads 8  \
    -encoder_type transformer \
    -decoder_type transformer \
    -position_encoding \
    -train_steps 33200 \
    -max_generator_batches 2 \
    -dropout 0.1 \
    -batch_size 64 \
    -accum_count 2 \
    -optim adam \
    -adam_beta2 0.998 \
    -decay_method noam \
    -warmup_steps 8000 \
    -learning_rate 1 \
    -max_grad_norm 0 \
    -param_init 0 \
    -param_init_glorot \
    -valid_steps 1328 \
    -early_stopping 5 \
    -save_checkpoint_steps 1328 \
    -world_size 1 \
    -gpu_ranks 0 2>&1 | tee pcfglet/transformer_1_trace.txt
else:
  print('A first run Transformer model has already been trained')

BEST_TRANSFORMER_RUN_1 = select_best_model('transformer', 1)
%env BEST_TRANSFORMER_RUN_1=$BEST_TRANSFORMER_RUN_1

#### Trial 2

In [None]:
if not model_from_run_exists('pcfglet', 'transformer', 2):
  !python OpenNMT-py/train.py \
    -data pcfglet/processed \
    -save_model "pcfglet/transformer_2"  \
    -layers 6 \
    -rnn_size 256 \
    -word_vec_size 256 \
    -transformer_ff 2048 \
    -heads 8  \
    -encoder_type transformer \
    -decoder_type transformer \
    -position_encoding \
    -train_steps 33200 \
    -max_generator_batches 2 \
    -dropout 0.1 \
    -batch_size 64 \
    -accum_count 2 \
    -optim adam \
    -adam_beta2 0.998 \
    -decay_method noam \
    -warmup_steps 8000 \
    -learning_rate 1 \
    -max_grad_norm 0 \
    -param_init 0 \
    -param_init_glorot \
    -valid_steps 1328 \
    -early_stopping 5 \
    -save_checkpoint_steps 1328 \
    -world_size 1 \
    -gpu_ranks 0 2>&1 | tee pcfglet/transformer_2_trace.txt
else:
  print('A second run Transformer model has already been trained')


BEST_TRANSFORMER_RUN_2 = select_best_model('transformer', 2)
%env BEST_TRANSFORMER_RUN_2=$BEST_TRANSFORMER_RUN_2

#### Trial 3

In [None]:
if not model_from_run_exists('pcfglet', 'transformer', 3):
  !python OpenNMT-py/train.py \
    -data pcfglet/processed \
    -save_model "pcfglet/transformer_3"  \
    -layers 6 \
    -rnn_size 256 \
    -word_vec_size 256 \
    -transformer_ff 2048 \
    -heads 8  \
    -encoder_type transformer \
    -decoder_type transformer \
    -position_encoding \
    -train_steps 33200 \
    -max_generator_batches 2 \
    -dropout 0.1 \
    -batch_size 64 \
    -accum_count 2 \
    -optim adam \
    -adam_beta2 0.998 \
    -decay_method noam \
    -warmup_steps 8000 \
    -learning_rate 1 \
    -max_grad_norm 0 \
    -param_init 0 \
    -param_init_glorot \
    -valid_steps 1328 \
    -early_stopping 5 \
    -save_checkpoint_steps 1328 \
    -world_size 1 \
    -gpu_ranks 0 2>&1 | tee pcfglet/transformer_3_trace.txt
else:
  print('A third run Transformer model has already been trained')

BEST_TRANSFORMER_RUN_3 = select_best_model('transformer', 3)
%env BEST_TRANSFORMER_RUN_3=$BEST_TRANSFORMER_RUN_3
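
All three Transformer trials share the same optimizer settings (`-decay_method noam`, `-warmup_steps 8000`, `-learning_rate 1`, `-rnn_size 256`). As a rough illustration of what those flags imply, the sketch below computes the learning rate under the Noam schedule of Vaswani et al. (2017). It assumes OpenNMT-py's `noam` option follows that formula; the function name `noam_learning_rate` is hypothetical and not part of OpenNMT.

```python
def noam_learning_rate(step: int, model_size: int = 256,
                       warmup_steps: int = 8000, base_rate: float = 1.0) -> float:
    """Noam schedule: linear warm-up followed by inverse square-root decay.

    Assumption: this mirrors the formula from Vaswani et al. (2017); the
    defaults are the values passed to train.py in the cells above.
    """
    step = max(step, 1)  # guard against step 0
    scale = min(step ** -0.5, step * warmup_steps ** -1.5)
    return base_rate * model_size ** -0.5 * scale


# Learning rate at the first validation step, at the end of warm-up, and at the last step
for step in (1328, 8000, 33200):
    print(step, f'{noam_learning_rate(step):.6f}')
```

Under these assumptions the rate peaks at the end of warm-up (step 8000) and decays afterwards, so roughly the first quarter of the 33200 training steps is spent warming up.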

#### Optional Checkpoint: persist trained Transformer models and traces to Drive

In [None]:
# !cp -r pcfglet/transformer_* "drive/My Drive/pcfglet/"

#### Translate test set

(OpenNMT is not very efficient - this takes about 45 minutes)

In [None]:
if not Path('pcfglet/pred_transformer_run_1.txt').is_file():
  !python OpenNMT-py/translate.py \
  -model "$BEST_TRANSFORMER_RUN_1" \
  -src pcfglet/src-test.txt \
  -tgt pcfglet/tgt-test.txt \
  -output pcfglet/pred_transformer_run_1.txt

if not Path('pcfglet/pred_transformer_run_2.txt').is_file():
  !python OpenNMT-py/translate.py \
  -model "$BEST_TRANSFORMER_RUN_2" \
  -src pcfglet/src-test.txt \
  -tgt pcfglet/tgt-test.txt \
  -output pcfglet/pred_transformer_run_2.txt

if not Path('pcfglet/pred_transformer_run_3.txt').is_file():
  !python OpenNMT-py/translate.py \
  -model "$BEST_TRANSFORMER_RUN_3" \
  -src pcfglet/src-test.txt \
  -tgt pcfglet/tgt-test.txt \
  -output pcfglet/pred_transformer_run_3.txt

#### Optional Checkpoint: persist predictions to Drive

In [None]:
# !cp -r pcfglet/pred_transformer_* "drive/My Drive/pcfglet/"

### Finalize Models

In [None]:
MODELS = {
  'lstms2s': {
    1: BEST_LSTM_RUN_1,
    2: BEST_LSTM_RUN_2,
    3: BEST_LSTM_RUN_3
  },
  'transformer': {
    1: BEST_TRANSFORMER_RUN_1,
    2: BEST_TRANSFORMER_RUN_2,
    3: BEST_TRANSFORMER_RUN_3
  }
}

## Analysis

### Task Accuracy

#### Define code to calculate and display task accuracy by sequence feature

In [None]:
def build_task_accuracy_dataframe() -> pd.DataFrame:
  frame = {
    'length': [],
    'depth': [],
    'num_functions': [],
    'accuracy': [],
    'model_type': [],
    'trial': []
  }
  accuracies = {
    'lstms2s': {
        1: np.zeros(len(TEST_SET)),
        2: np.zeros(len(TEST_SET)),
        3: np.zeros(len(TEST_SET))
    },
    'transformer': {
        1: np.zeros(len(TEST_SET)),
        2: np.zeros(len(TEST_SET)),
        3: np.zeros(len(TEST_SET))
    },
  }
  with open('pcfglet/tgt-test.txt') as target_file:
    targets = [line.rstrip() for line in target_file]
  for model_type in ['lstms2s', 'transformer']:
    for trial in range(1,4):
      with open(f'pcfglet/pred_{model_type}_run_{trial}.txt') as prediction_file:
        predictions = [line.rstrip() for line in prediction_file]
      for i, (prediction, target) in enumerate(zip(predictions, targets)):
        accuracies[model_type][trial][i] = float(prediction == target)
  with tqdm(total=len(TEST_SET) * 2 * 3) as progress_bar:
    for i, sequence_data in enumerate(TEST_SET.itertuples()):
      for model_type in ['lstms2s', 'transformer']:
        for trial in range(1,4):
          frame['length'].append(sequence_data.length)
          frame['depth'].append(sequence_data.depth)
          frame['num_functions'].append(sequence_data.num_functions)
          frame['accuracy'].append(accuracies[model_type][trial][i])
          frame['model_type'].append(model_type)
          frame['trial'].append(trial)
          progress_bar.update(1)
  df = pd.DataFrame(frame)
  df.to_pickle('pcfglet/accuracy.pkl')
  return df

def plot_accuracy_by_dimension(accuracy_frame: pd.DataFrame, dimension: str, xlim: Tuple[int, int], xticks: List[int]) -> None:
  plt.clf()
  ax = sns.lineplot(x=dimension, y='accuracy', data=accuracy_frame, style='model_type')
  ax.set_ylim((0,1))
  ax.set_xlim(xlim)
  ax.set_xticks(xticks)
  ax.grid(True)
  plt.savefig(f'pcfglet/task_accuracy_by_{dimension}.png')

#### Build task accuracy data frame

In [None]:
if not Path('pcfglet/accuracy.pkl').is_file():
  TASK_ACCURACY = build_task_accuracy_dataframe()
else:
  TASK_ACCURACY = pd.read_pickle('pcfglet/accuracy.pkl')

#### Calculate average task accuracy ± standard deviation across all trials

In [None]:
mean_accuracy_by_type = TASK_ACCURACY.groupby(['model_type', 'trial'])['accuracy'].mean()
for model_type in ['lstms2s', 'transformer']:
  mean = np.mean(mean_accuracy_by_type[model_type])
  std = np.std(mean_accuracy_by_type[model_type])
  print(f'Average task accuracy for {model_type}: {mean} ± {std}')
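
The cell above first averages accuracy within each trial and only then takes the mean and standard deviation across trials. The toy example below illustrates that two-stage aggregation on made-up data; the model names `a` and `b` and all values are hypothetical. Note that NumPy's `np.std` defaults to the population standard deviation (`ddof=0`).

```python
import numpy as np
import pandas as pd

# Hypothetical per-sequence outcomes (1.0 = exact match) for two model types, two trials each
toy = pd.DataFrame({
    'model_type': ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b'],
    'trial':      [1, 1, 2, 2, 1, 1, 2, 2],
    'accuracy':   [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
})

# Stage 1: one accuracy value per (model_type, trial)
per_trial = toy.groupby(['model_type', 'trial'])['accuracy'].mean()

# Stage 2: mean and population standard deviation across trials
summary = {}
for model_type in ['a', 'b']:
    values = per_trial[model_type].to_numpy()
    summary[model_type] = (float(np.mean(values)), float(np.std(values)))
    print(f'{model_type}: {summary[model_type][0]} ± {summary[model_type][1]}')
```

Averaging per trial first means the reported deviation reflects run-to-run variation rather than the spread over individual test sequences.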

#### Plot average task accuracy for each model type by the same features as Hupkes et al. 2019.

By sequence length:

In [None]:
if not Path('pcfglet/task_accuracy_by_length.png').is_file():
  plot_accuracy_by_dimension(TASK_ACCURACY, 'length', (5,50), list(range(5,51,5)))
else:
  display(Image('pcfglet/task_accuracy_by_length.png'))

By sequence depth:

In [None]:
if not Path('pcfglet/task_accuracy_by_depth.png').is_file():
  plot_accuracy_by_dimension(TASK_ACCURACY, 'depth', (1,13), list(range(1,14,3)))
else:
  display(Image('pcfglet/task_accuracy_by_depth.png'))

By number of functions in the sequence:

In [None]:
if not Path('pcfglet/task_accuracy_by_num_functions.png').is_file():
  plot_accuracy_by_dimension(TASK_ACCURACY, 'num_functions', (1,15), list(range(1,16,3)))
else:
  display(Image('pcfglet/task_accuracy_by_num_functions.png'))

#### Optional Checkpoint: persist task accuracy dataframe and plots to Drive

In [None]:
# !cp -r pcfglet/accuracy.pkl "drive/My Drive/pcfglet/"
# !cp -r pcfglet/task_accuracy*.png "drive/My Drive/pcfglet/"

### Localism

We test the localism of trained models similarly to ([Hupkes et al. 2019](https://arxiv.org/abs/1908.08351)) by unrolling computations and comparing each model's successive local predictions to its global ones. The final unrolled prediction is **consistent** if it matches the model's global prediction, and **accurate** if it matches the target.
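
To make the two scores concrete, here is a minimal sketch with a toy "model" that is just a lookup table. The `source|target` step format with `*`-prefixed wildcards is our reading of how `process_unrolled` (defined later in this notebook) consumes steps; the function names `reverse` and `copy`, the helper `run_unrolled`, and all values are hypothetical.

```python
from typing import Callable, Dict, List


def run_unrolled(steps: List[str], model: Callable[[str], str]) -> str:
    """Feed each local step to the model, substituting earlier outputs for wildcards."""
    outcomes: Dict[str, str] = {}
    prediction = ''
    for step in steps:
        source, target = step.split('|')
        # Replace any wildcard token with the prediction stored for it
        tokens = [outcomes.get(token, token) for token in source.split(' ')]
        prediction = model(' '.join(tokens))
        if '*' in target:
            outcomes[target] = prediction
    return prediction


# Hypothetical model: correct on each local step, wrong on the full command
toy_model = {
    'reverse 1 2 3': '3 2 1',
    'copy 3 2 1': '3 2 1',
    'copy reverse 1 2 3': '1 2 3',
}.get

steps = ['reverse 1 2 3|*1', 'copy *1|3 2 1']
target = '3 2 1'

global_prediction = toy_model('copy reverse 1 2 3')
local_prediction = run_unrolled(steps, toy_model)

consistent = local_prediction == global_prediction
accurate = local_prediction == target
print(f'consistent: {consistent}, accurate: {accurate}')
```

In this toy case the unrolled prediction is accurate but not consistent, which is the pattern one would expect from a model that handles individual functions well but composes them non-locally.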



#### Define code to predict the output of an arbitrary command

We need to be able to predict the output of any command in order to examine model performance on individual functions when unrolling computations. One possible solution is to transform a string command into a parse tree.
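
Before the tree-based implementation, the following sketch shows the underlying idea on a reduced grammar: commands are in prefix notation, so they can be evaluated by recursive descent. This is a simplified alternative for illustration, not the approach used in the cell below. The dictionaries `UNARY` and `BINARY` are hypothetical stand-ins for the notebook's `UnaryFunctions` and `BinaryFunctions`, and the empty token `E` is not handled.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-ins for the notebook's function classes
UNARY: Dict[str, Callable[[List[int]], List[int]]] = {
    'reverse': lambda seq: seq[::-1],
}
BINARY: Dict[str, Callable[[List[int], List[int]], List[int]]] = {
    'append': lambda left, right: left + right,
}


def evaluate(tokens: List[str], pos: int = 0) -> Tuple[List[int], int]:
    """Evaluate the expression starting at tokens[pos]; return (value, next position)."""
    token = tokens[pos]
    if token in UNARY:
        arg, pos = evaluate(tokens, pos + 1)
        return UNARY[token](arg), pos
    if token in BINARY:
        left, pos = evaluate(tokens, pos + 1)
        if pos >= len(tokens) or tokens[pos] != ',':
            raise ValueError('Expected "," before the second argument')
        right, pos = evaluate(tokens, pos + 1)
        return BINARY[token](left, right), pos
    # Otherwise consume a run of integers as a single sequence argument
    values = []
    while pos < len(tokens) and tokens[pos].isdigit():
        values.append(int(tokens[pos]))
        pos += 1
    if not values:
        raise ValueError(f'Unexpected token: {token}')
    return values, pos


def predict(command: str) -> str:
    """Evaluate a whole command and render the result as a space-separated string."""
    tokens = command.split(' ')
    value, pos = evaluate(tokens)
    if pos != len(tokens):
        raise ValueError('Trailing tokens after a complete command')
    return ' '.join(map(str, value))


print(predict('append reverse 1 2 3 , 4 5'))
```

An explicit tree, as built below, additionally keeps every intermediate node around, which is convenient when inspecting individual function applications.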

In [None]:
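# a parsed sequence will either be an empty token or list of integers
ParsedSequence = Union[str, List[int]]
# a parsed token is a parsed sequence, a comma, or a PCFG-LET function
ParsedToken = Union[str, List[int], Callable]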

def _try_to_parse_int(string: str) -> Union[int, str]:
  '''Cast a string to an integer if possible'''
  try:
    val = int(string)
    return val
  except ValueError:
    return string

def is_valid_token(token: Any) -> bool:
  '''Check if a given string is a valid token in PCFG-LET'''
  is_string = isinstance(token, str)
  is_empty_token = token == 'E'
  is_comma = token == ','
  is_numerical_character = is_string and token.isdigit()
  is_pcfg_function_name = is_string and (hasattr(UnaryFunctions, token) or hasattr(BinaryFunctions, token))
  return any([is_empty_token, is_comma, is_numerical_character, is_pcfg_function_name])

def is_valid_parsed_token(parsed_token: Any) -> bool:
  '''Check if a given PCFG-LET token has been parsed correctly'''
  is_empty_token = parsed_token == 'E'
  is_comma = parsed_token == ','
  is_numerical_list = isinstance(parsed_token, list) and all([isinstance(val, int) for val in parsed_token])
  is_pcfg_function = is_valid_function(parsed_token)
  return any([is_empty_token, is_comma, is_numerical_list, is_pcfg_function])

def is_valid_unary_function(token: Callable) -> bool:
  '''Check if a parsed PCFG-LET token is a Unary function'''
  return callable(token) and hasattr(UnaryFunctions, token.__name__)

def is_valid_binary_function(token: Callable) -> bool:
  '''Check if a parsed PCFG-LET token is a Binary function'''
  return callable(token) and hasattr(BinaryFunctions, token.__name__)

def is_valid_function(token: Callable) -> bool:
  '''Check if a parsed PCFG-LET token is a Unary or Binary function'''
  return is_valid_unary_function(token) or is_valid_binary_function(token)

def is_valid_pcfg_function_argument(token: ParsedSequence) -> bool:
  '''Check if a parsed PCFG-LET token is a valid argument to a Unary or Binary function'''
  is_empty_token = token == 'E'
  is_numerical_list = isinstance(token, list) and all([isinstance(val, int) for val in token])
  return is_empty_token or is_numerical_list

def _command_to_parsed_stack(command: str) -> List[ParsedToken]:
  '''Turn a string command into a parsed stack'''
  stack = []
  tokens = command.split(' ')
  assert all(list(map(is_valid_token, tokens)))
  parsed_tokens = list(map(_try_to_parse_int, tokens))
  current_numerical_sequence = []
  for parsed_token in parsed_tokens:
    if isinstance(parsed_token, int):
      current_numerical_sequence.append(parsed_token)
    else:
      if current_numerical_sequence:
        stack.append(current_numerical_sequence)
        current_numerical_sequence = []
      if hasattr(BinaryFunctions, parsed_token):
        stack.append(getattr(BinaryFunctions, parsed_token))
      elif hasattr(UnaryFunctions, parsed_token):
        stack.append(getattr(UnaryFunctions, parsed_token))
      else:
        stack.append(parsed_token)
  if current_numerical_sequence:
    stack.append(current_numerical_sequence)
  stack = list(reversed(stack))
  assert all(list(map(is_valid_parsed_token, stack)))
  return stack

def _stack_to_parse_tree(stack: List[ParsedToken]) -> treelib.Tree:
  '''Turn a parsed stack into a parse tree'''
  tree = treelib.Tree()
  current_node = None # pointer for traversing the tree during construction
  while stack:
    val = stack.pop()
    tag = val.__name__ if callable(val) else str(val)
    # case 1: create root
    if not tree:
      current_node = tree.create_node(tag=tag, identifier='root', data=val)
      continue
    # case 2: find where to insert second argument of a binary function
    if val == ',':
      while True:
        current_node = tree.parent(current_node.identifier)
        if current_node is None:
          raise ValueError('No matching binary function for second argument')
        children = tree.children(current_node.identifier)
        if len(children) == 1 and callable(current_node.data) and hasattr(BinaryFunctions, current_node.data.__name__):
          break
    # case 3: insert node
    else:
      children = tree.children(current_node.identifier)
      # add a direction stem to the node ID, since treelib does not maintain insertion order
      node_id = f'{str(uuid.uuid4())}|{"RIGHT" if children else "LEFT"}' 
      current_node = tree.create_node(tag=tag, identifier=node_id, data=val, parent=current_node.identifier)
  return tree

def _calculate_output_from_parse_tree(parse_tree: treelib.Tree) -> ParsedSequence:
  '''Condense a parse tree into the final command output'''
  current_node = parse_tree.get_node('root')
  tree = treelib.Tree(tree=parse_tree, deep=True)
  LEFT = '|LEFT'
  RIGHT = '|RIGHT'
  next_direction = LEFT
  while tree.depth() > 0:

    if not is_valid_parsed_token(current_node.data):
      raise ValueError(f'treelib.Tree could not be parsed. Current node data: {current_node.data}')

    children = tree.children(current_node.identifier)
    if len(children) == 0:
      next_direction = LEFT if next_direction == RIGHT else RIGHT
      current_node = tree.parent(current_node.identifier)
    else:
      children_data = [child.data for child in children]
      should_condense = len(children) > 0 and all([is_valid_pcfg_function_argument(child) for child in children_data])
      if should_condense:
        if len(children) == 2 and is_valid_binary_function(current_node.data):
          left_child = next(child.data for child in children if LEFT in child.identifier)
          right_child = next(child.data for child in children if RIGHT in child.identifier)
          result = current_node.data(left_child, right_child)
        elif len(children) == 1 and is_valid_unary_function(current_node.data):
          result = current_node.data(children_data[0])
        else:
          raise ValueError('treelib.Tree could not be condensed. Phrase does not meet rules of the PCFG LET task')
        tree.update_node(current_node.identifier, data=result, tag=str(result))
        current_node = tree.parent(current_node.identifier)
        for child in children:
          tree.remove_node(child.identifier)
      else:
        try:
          child_id = next(child.identifier for child in children if next_direction in child.identifier)
        except StopIteration:
          raise ValueError('treelib.Tree could not be condensed. Phrase does not meet rules of the PCFG LET task')
        current_node = tree.get_node(child_id)
        next_direction = LEFT
  return tree.get_node('root').data

def predict_command(command: str):
  '''Predict the output of any given PCFG-LET string command'''
  stack = _command_to_parsed_stack(command)
  parse_tree = _stack_to_parse_tree(stack)
  output = _calculate_output_from_parse_tree(parse_tree)
  return list_to_string(output)

#### Define code to perform unrolled computations and produce graphs

In [None]:
def translate_one(translator: onmt.translate.Translator, source_sentence: str) -> List[str]:
  '''Translate a source sentence using OpenNMT'''
  prediction = translator.translate(src=[source_sentence], batch_size=1)
  # Looks like ([[tensor(-0.0005)]], [['21 5 7']])
  return prediction[1][0][0].split(' ')

def process_unrolled(steps, translator) -> Tuple[str, List[Tuple[str, str]]]:
  '''Process a series of unrolled steps with wildcards'''
  pairs = []
  collect_outcomes = dict()
  for step in steps:
    source, target = step.split('|')
    source_tokens = source.split(' ')
    for i, token in enumerate(source_tokens):
      if '*' in token:
        source_tokens[i] = collect_outcomes[token]
      source_tokens_flattened = []
      for token in source_tokens:
        if type(token) != list:
          source_tokens_flattened.append(token)
        else:
          source_tokens_flattened.extend(token)
      source_tokens = source_tokens_flattened
    source_command = list_to_string(source_tokens)
    prediction = translate_one(translator, source_tokens)
    str_prediction = list_to_string(prediction)
    pairs.append((source_command, str_prediction))
    if '*' in target:
      collect_outcomes[target] = str_prediction
  return str_prediction, pairs

parser = argparse.ArgumentParser()
onmt.opts.translate_opts(parser)

def load_localism_frames() -> Tuple[pd.DataFrame, pd.DataFrame]:
  '''Load the localism frames from the persistence folder'''
  return pd.read_pickle('pcfglet/localism_sequence.pkl'), pd.read_pickle('pcfglet/localism_function.pkl')

def build_localism_dataframes() -> Tuple[pd.DataFrame, pd.DataFrame]:
  '''Store the results of unrolled computations and individual function performance into two separate dataframes'''
  sequence_localism_frame = {
    'trial': [],
    'model_type': [],
    'command': [],
    'consistency': [],
    'accuracy': [],
    'input_length': [],
    'target_length': [],
    'depth': []
  }
  function_localism_frame = {
    'trial': [],
    'model_type': [],
    'function': [],
    'accuracy': [],
    'target': []
  }
  with tqdm(total=2*3*len(TEST_SET)) as progress_bar:
    for model_type in MODELS.keys():
      for trial, model in MODELS[model_type].items():
        opt = parser.parse_args(['-model', model, '-src', 'pcfglet/src-test.txt'])
        translator = onmt.translate.translator.build_translator(opt, report_score=True, logger=logging.getLogger('onmt'))
        predictions = map(str.rstrip, list(open(f'pcfglet/pred_{model_type}_run_{trial}.txt')))
        for i, (sequence_data, prediction) in enumerate(zip(TEST_SET.itertuples(), predictions)):
          unrolled_predicted, pairs = process_unrolled(sequence_data.unrolled_steps, translator)

          consistent = float(unrolled_predicted == prediction)
          correct = float(unrolled_predicted == sequence_data.target)

          sequence_localism_frame['trial'].append(trial)
          sequence_localism_frame['model_type'].append(model_type)
          sequence_localism_frame['command'].append(sequence_data.command)
          sequence_localism_frame['consistency'].append(consistent)
          sequence_localism_frame['accuracy'].append(correct)
          sequence_localism_frame['input_length'].append(sequence_data.length)
          sequence_localism_frame['target_length'].append(len(sequence_data.target.split(' ')))
          sequence_localism_frame['depth'].append(len(pairs))

          for local_command, local_prediction in pairs:
            local_tokens = local_command.split(' ')
            function = local_tokens[0]

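            # if an earlier step produced no prediction, the command contains an empty token:
            # consider the output incorrect and record no target for it
            if '' in local_tokens:
              correct_output = None
              locally_correct = 0.0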
            else:
              correct_output = predict_command(local_command)
              locally_correct = float(local_prediction == correct_output)

            function_localism_frame['trial'].append(trial)
            function_localism_frame['model_type'].append(model_type)
            function_localism_frame['function'].append(function)
            function_localism_frame['accuracy'].append(locally_correct)
            function_localism_frame['target'].append(correct_output)

          progress_bar.update(1)
  sequence_frame = pd.DataFrame(sequence_localism_frame)
  sequence_frame.to_pickle('pcfglet/localism_sequence.pkl')
  function_frame = pd.DataFrame(function_localism_frame)
  function_frame.to_pickle('pcfglet/localism_function.pkl')

  return sequence_frame, function_frame

def print_localism_scores_for_frame(frame: pd.DataFrame, after_depth: int = None) -> None:
  '''Print average localism consistency and performance for sequences deeper than the given depth'''
  if after_depth:
    frame = frame.loc[frame['depth'] > after_depth]
  print(f'Localism results for depth > {after_depth or 0}')
  mean_consistency_by_type = frame.groupby(['model_type', 'trial'])['consistency'].mean()
  mean_performance_by_type = frame.groupby(['model_type', 'trial'])['accuracy'].mean()
  for model_type in ['lstms2s', 'transformer']:
    mean_consistency = np.mean(mean_consistency_by_type[model_type])
    std_consistency = np.std(mean_consistency_by_type[model_type])
    mean_performance = np.mean(mean_performance_by_type[model_type])
    std_performance = np.std(mean_performance_by_type[model_type])
    print(f'Average Localism consistency for {model_type}: {mean_consistency} ± {std_consistency}')
    print(f'Average Localism performance for {model_type}: {mean_performance} ± {std_performance}')

def plot_localism_lines_by_model_type(frame: pd.DataFrame, x: str, y: str, xlim: Tuple[int, int], xticks: List[int], ylim: Tuple[int, int] = (0,1), yticks: List[int] = None) -> None:
  '''Generate and save a plot of a localism score (consistency or accuracy) by dimension (input_length, depth, etc...)'''
  plt.clf()
  ax = sns.lineplot(x=x, y=y, data=frame, style='model_type')
  ax.set_ylim(ylim)
  ax.set_xlim(xlim)
  ax.set_xticks(xticks)
  ax.grid(True)
  plt.savefig(f'pcfglet/localism_{y}_by_{x}.png')

def plot_performance_by_function(frame, non_empty_targets: bool = False):
  '''Plot performance per function (optionally for functions where the target is not empty)'''
  plt.clf()
  TO_GRAPH = frame
  TO_GRAPH = TO_GRAPH.sort_values('function', ascending=False)
  if non_empty_targets:
    TO_GRAPH = TO_GRAPH.loc[TO_GRAPH['target'] != 'E']
  count_map = TO_GRAPH.loc[(TO_GRAPH["model_type"] =='lstms2s') & (TO_GRAPH["trial"] == 1)].groupby('function').count()['trial'].to_dict()
  f, ax = plt.subplots(figsize=(10, 6))
  sns.barplot(x='accuracy', y='function', data=TO_GRAPH, hue='model_type')
  ax.set_xlim((0,1))
  ax.set_xticks([0,0.2,0.4,0.6,0.8, 1.0])
  ax.set_yticklabels([f'{key} ({count_map[key]})' for key in sorted(count_map.keys(), reverse=True)])
  renderer = tight_layout.get_renderer(f)
  inset_tight_bbox = ax.get_tightbbox(renderer)
  extent = inset_tight_bbox.transformed(f.dpi_scale_trans.inverted())
  plt.savefig(f'pcfglet/localism_performance_by_function{"_non_empty" if non_empty_targets else ""}.png', bbox_inches=extent)

#### Build localism dataframes

This will take about 1.5-2 hours (OpenNMT translation is slow :( )

In [None]:
localism_files = [
  'pcfglet/localism_sequence.pkl',
  'pcfglet/localism_function.pkl'
]


if all([Path(f).is_file() for f in localism_files]):
  print('Localism frames already exist. Loading from disk.')
  SEQUENCE_LOCALISM, PERFORMANCE_BY_FUNCTION = load_localism_frames()
else:
  SEQUENCE_LOCALISM, PERFORMANCE_BY_FUNCTION = build_localism_dataframes()

#### Display results of Localism tests

In [None]:
for i in range(0,4):
  print_localism_scores_for_frame(SEQUENCE_LOCALISM, after_depth=i)

#### Plot consistency score by input length

In [None]:
plot_localism_lines_by_model_type(SEQUENCE_LOCALISM, 'input_length', 'consistency', (1,50), [1] + list(range(5,51,5)))

#### Plot consistency score by target length

In [None]:
plot_localism_lines_by_model_type(SEQUENCE_LOCALISM, 'target_length', 'consistency', (1,50), [1] + list(range(5,51,5)))

#### Plot consistency score by sequence depth

In [None]:
plot_localism_lines_by_model_type(SEQUENCE_LOCALISM, 'depth', 'consistency', (1,10), list(range(1,11)))

#### Plot localism accuracy by input length

In [None]:
plot_localism_lines_by_model_type(SEQUENCE_LOCALISM, 'input_length', 'accuracy', (1,50), [1] + list(range(5,51,5)))

#### Plot localism accuracy by target length

In [None]:
plot_localism_lines_by_model_type(SEQUENCE_LOCALISM, 'target_length', 'accuracy', (1,50), [1] + list(range(5,51,5)))

#### Plot localism accuracy by sequence depth

In [None]:
plot_localism_lines_by_model_type(SEQUENCE_LOCALISM, 'depth', 'accuracy', (1,5), [1,2,3,4,5])

#### Plot performance by function

In [None]:
plot_performance_by_function(PERFORMANCE_BY_FUNCTION)

#### Plot performance by function where the target is not an empty token

In [None]:
plot_performance_by_function(PERFORMANCE_BY_FUNCTION, non_empty_targets=True)

#### Optional Checkpoint: persist localism dataframes and plots to Drive

In [None]:
# !cp -r pcfglet/localism*.pkl "drive/My Drive/pcfglet/"
# !cp -r pcfglet/localism*.png "drive/My Drive/pcfglet/"

#### Additional examination into poor model performance on certain functions

Check how many primitive instances of `intersection` and `remove_unique` (commands in which it is the only function) with a non-empty target are in the training set.

In [None]:
TRAIN_SET.loc[(TRAIN_SET['command'].str.contains('intersection')) & (TRAIN_SET['target'] != 'E') & (TRAIN_SET['num_functions'] == 1)]

In [None]:
TRAIN_SET.loc[(TRAIN_SET['command'].str.contains('remove_unique')) & (TRAIN_SET['target'] != 'E') & (TRAIN_SET['num_functions'] == 1)]
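
The two queries above display the matching rows. If only the counts are needed, the same mask can be summed directly. The sketch below does this on a made-up miniature frame; `toy_train`, `count_primitive`, and all row values are hypothetical and only mirror the column names used above.

```python
import pandas as pd

# Hypothetical miniature training set with the columns the queries above rely on
toy_train = pd.DataFrame({
    'command': [
        'intersection 1 2 , 2 3',
        'copy intersection 1 , 1',
        'intersection 4 , 5',
        'remove_unique 1 1 2',
    ],
    'target': ['2', '1', 'E', '1 1'],
    'num_functions': [1, 2, 1, 1],
})


def count_primitive(frame: pd.DataFrame, function: str) -> int:
    """Count commands where `function` is the only function and the target is non-empty."""
    mask = (
        frame['command'].str.contains(function, regex=False)
        & (frame['target'] != 'E')
        & (frame['num_functions'] == 1)
    )
    return int(mask.sum())


for name in ['intersection', 'remove_unique']:
    print(name, count_primitive(toy_train, name))
```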

### Systematicity

See the separate [notebook](https://colab.research.google.com/drive/1T_wjcRu9625N9mDmDn6oDBkJFRl542gp).

Note: please ensure you have the folder titled 'Systematicity' containing the train, test and validation data and predictions saved in the root of your drive.