In [None]:
#@title Imports

# Licensed under the Apache License, Version 2.0

import json
import io
import random
import time
import numpy as np
import statistics
import tensorflow.compat.v1 as tf

# Switch the TF1-compat API to eager execution and stop immediately if eager
# mode is not active afterwards.
tf.enable_eager_execution()
assert tf.executing_eagerly()

# Print the TensorFlow version so the environment is recorded in the output.
print(tf.version.VERSION)


In [None]:
#@title General dataset utility functions

# Dataset identifiers. The string values are the names used to select a
# dataset, so they must stay exactly as written.
DATASET_ADDITION_ALIGNED = "addition_aligned"
DATASET_ADDITION_NEGATIVES = "addition_negatives"
DATASET_REVERSING = "reversing-new"
DATASET_DUPLICATION = "duplication"
DATASET_CARTESIAN = "cartesian"
DATASET_INTERSECTION_BOOLEAN = "intersection_boolean-new"
DATASET_SCAN_LENGTH = "scan_length"
DATASET_SCAN_ADD_JUMP = "scan_add_jump"
DATASET_PCFG_PRODUCTIVITY = "pcfg_productivity"
DATASET_PCFG_SYSTEMATICITY = "pcfg_systematicity"
DATASET_COGS_FULL = "cogs_full"
DATASET_CFQ_MCD1 = "cfq_mcd1"
DATASET_CFQ_MCD1_INTERMEDIATE = "cfq_mcd1_intermediate"


# Special tokens shared by every dataset. The vocabularies built below start
# with (a prefix of) these tokens in this order, so each token's ID is simply
# its position in this list.
PAD_TOKEN, SEP_TOKEN, END_TOKEN, START_TOKEN, END_ITERATION_TOKEN = (
    "[PAD]", "[SEP]", "[END]", "[START]", "[ENDITER]")

(PAD_TOKEN_IDX, SEP_TOKEN_IDX, END_TOKEN_IDX, START_TOKEN_IDX,
 END_ITERATION_TOKEN_IDX) = range(5)


def max_length(tensor):
  """Returns the length of the longest row of `tensor` (a sequence of sequences).

  Raises ValueError when `tensor` has no rows.
  """
  return max(map(len, tensor))


def decode(seq, vocab):
  """Decodes a sequence of token IDs to a string.

  Decoding stops at the first ID equal to 0 (the padding ID). Every decoded
  token is followed by a single space, so non-empty output ends with a space.
  """
  pieces = []
  for token_id in seq:
    if token_id == 0:
      break
    pieces.append(vocab[token_id] + " ")
  return "".join(pieces)


def create_dataset_tensors(examples_in_raw, examples_out_raw, max_len_inp, max_len_targ, vocab_to_int):
  """Translates a dataset to padded token ID tensors.

  Args:
    examples_in_raw: list of input token sequences.
    examples_out_raw: list of target token sequences.
    max_len_inp: length the inputs are padded to.
    max_len_targ: length the targets are padded to.
    vocab_to_int: token -> ID mapping; an unknown token raises KeyError.

  Returns:
    (in_tensor, out_tensor) from Keras `pad_sequences` with padding='post'
    (padding value 0). The shape of each is printed as a side effect. Keras
    truncates sequences longer than maxlen, from the front by default.
  """
  def to_ids(examples):
    return [[vocab_to_int[token] for token in example] for example in examples]

  ids_in = to_ids(examples_in_raw)
  ids_out = to_ids(examples_out_raw)

  pad = tf.keras.preprocessing.sequence.pad_sequences
  in_tensor = pad(ids_in, padding='post', maxlen=max_len_inp)
  out_tensor = pad(ids_out, padding='post', maxlen=max_len_targ)

  for padded in (in_tensor, out_tensor):
    print(padded.shape)

  return in_tensor, out_tensor


def prepare_tf_dataset_tensors(vocab, vocab_to_int, input_tensor_train, target_tensor_train,
                               input_tensor_val_list, target_tensor_val_list,
                               batch_size):
  """Converts dataset tensors to TensorFlow datasets.

  Before building the datasets, it prints a summary for visual inspection:
  set sizes and lengths, the first and last training example, and the
  vocabulary.

  Args:
    vocab: list mapping token ID -> token string (used only for printing).
    vocab_to_int: unused here.
    input_tensor_train, target_tensor_train: padded training inputs/targets.
    input_tensor_val_list, target_tensor_val_list: parallel lists of padded
      evaluation inputs/targets, one entry per evaluation split.
    batch_size: batch size; incomplete trailing batches are dropped.

  Returns:
    (dataset_train, dataset_val_list). Every dataset is shuffled with a
    buffer as large as the training set and then batched.
  """
  buffer_size = len(input_tensor_train)

  print(f"Training set: {len(input_tensor_train)} examples, of length {len(input_tensor_train[0])} -> {len(target_tensor_train[0])}")
  for i, input_tensor_test in enumerate(input_tensor_val_list):
    target_tensor_test = target_tensor_val_list[i]
    print(f"Test set: {len(input_tensor_test)} examples, of length {len(input_tensor_test[0])} -> {len(target_tensor_test[0])}")
  print(f"Example Input 1: {decode(input_tensor_train[0], vocab)}")
  print(f"Example Output 1: {decode(target_tensor_train[0], vocab)}")
  print(f"Example Input 2: {decode(input_tensor_train[-1], vocab)}")
  print(f"Example Output 2: {decode(target_tensor_train[-1], vocab)}")
  print(f"Vocab size: {len(vocab)}")
  print(f"Vocab: {vocab}")

  def batched(inputs, targets):
    # Same pipeline for train and eval: slice -> shuffle -> fixed-size batches.
    return (tf.data.Dataset.from_tensor_slices((inputs, targets))
            .shuffle(buffer_size)
            .batch(batch_size, drop_remainder=True))

  dataset_train = batched(input_tensor_train, target_tensor_train)
  dataset_val_list = [batched(inputs, target_tensor_val_list[i])
                      for i, inputs in enumerate(input_tensor_val_list)]
  return (dataset_train, dataset_val_list)

In [None]:
#@title "Addition" dataset generation

def create_addition_dataset(trainsize, testsize, vocab, vocab_to_int,
                            reversedigits=True, leftpadding=12,
                            addAlignmentTokens=False, negativeProbability=0.0):
  """Generates the addition dataset.

  Inputs look like `a [SEP] b [END]` and targets like `[START] a+b [END]`.
  Every number is written with one symbol per token.

  Args:
    trainsize: number of training examples (operands of 1-8 digits).
    testsize: number of examples in each generated test split.
    vocab: ignored; a fresh vocabulary is built and returned.
    vocab_to_int: ignored; a fresh mapping is built and returned.
    reversedigits: if True, operands and sum are emitted in reverse order.
    leftpadding: operands and sum shorter than this are left-padded to this
      many symbols. The pad symbol is "0", or "#" when negatives are enabled.
    addAlignmentTokens: if True, each input symbol is followed by a position
      token "P<k>", where k counts down from the symbol count to 1.
    negativeProbability: probability that each operand is negated.

  Returns:
    (vocab, vocab_to_int, input_tensor_train, target_tensor_train,
     val_inputs, val_targets). The four validation splits are: the first
     `testsize` training examples, then fresh examples with operands of 5-6,
     6-8 and 9-10 digits.
  """

  max_number_length = 12

  def add_alignment(digits):
    # Interleave each symbol with its position counted from the right:
    # ["1", "2", "3"] -> ["1", "P3", "2", "P2", "3", "P1"].
    digits2 = []
    pos = len(digits)
    for digit in digits:
      digits2.append(digit)
      digits2.append("P" + str(pos))
      pos -= 1
    return digits2

  def random_operand(length, numbers):
    # Draws `length` random digits; returns (value, list of digit strings).
    digits = [str(random.choice(numbers)) for _ in range(length)]
    return int("".join(digits) or "0"), digits

  def create_example(minlen, maxlen, leftpadding=0, reversedigits=False,
                     addAlignmentTokens=False, negativeProbability=0.0):
    numbers = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    l1 = random.randint(minlen, maxlen)
    l2 = random.randint(minlen, maxlen)
    int1, digits1 = random_operand(l1, numbers)
    int2, digits2 = random_operand(l2, numbers)
    # Both coin flips are always drawn, so the random stream does not depend
    # on negativeProbability.
    if random.random() < negativeProbability:
      int1 = -int1
      digits1 = ["-"] + digits1
    if random.random() < negativeProbability:
      int2 = -int2
      digits2 = ["-"] + digits2

    total = int1 + int2
    # str() renders a zero sum as "0". A digit-by-digit loop that runs
    # "while n > 0" would emit no digits at all for 0, leaving an empty (or
    # all-padding) target.
    digitssum = list(str(abs(total)))
    if total < 0:
      digitssum = ["-"] + digitssum

    leftpaddingtoken = "#" if negativeProbability > 0 else "0"
    digits1 = [leftpaddingtoken] * (leftpadding - len(digits1)) + digits1
    digits2 = [leftpaddingtoken] * (leftpadding - len(digits2)) + digits2
    digitssum = [leftpaddingtoken] * (leftpadding - len(digitssum)) + digitssum

    if addAlignmentTokens:
      digits1 = add_alignment(digits1)
      digits2 = add_alignment(digits2)

    if reversedigits:
      digits1.reverse()
      digits2.reverse()
      digitssum.reverse()
    example_in = digits1 + [SEP_TOKEN] + digits2 + [END_TOKEN]
    example_out = [START_TOKEN] + digitssum + [END_TOKEN]

    return example_in, example_out

  def create_examples(n, minlen, maxlen, leftpadding=0, reversedigits=False,
                      addAlignmentTokens=False, negativeProbability=0.0):
    examples_in = []
    examples_out = []
    for _ in range(n):
      ein, eout = create_example(minlen, maxlen, leftpadding=leftpadding,
                                 reversedigits=reversedigits,
                                 addAlignmentTokens=addAlignmentTokens,
                                 negativeProbability=negativeProbability)
      examples_in.append(ein)
      examples_out.append(eout)
    return examples_in, examples_out

  vocab = [PAD_TOKEN, SEP_TOKEN, END_TOKEN, START_TOKEN,
           "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
  vocab_to_int = {PAD_TOKEN:0, SEP_TOKEN:SEP_TOKEN_IDX,
                  END_TOKEN:END_TOKEN_IDX,
                  START_TOKEN:START_TOKEN_IDX,
                  "0":4, "1":5, "2":6, "3":7, "4":8,
                  "5":9, "6":10, "7":11, "8":12, "9":13}
  if addAlignmentTokens:
    # add_alignment() emits P<len> ... P1 for each operand. An unpadded
    # operand has at most 11 symbols (up to 10 digits plus a "-" sign), and a
    # padded one has `leftpadding` symbols. A fixed P0..P11 range therefore
    # lacks "P12" under the default leftpadding=12 and raises KeyError when
    # tokenizing. The range is extended only when needed, so P0..P11 keep
    # their IDs.
    max_position = max(leftpadding, max_number_length - 1)
    for i in range(max_position + 1):
      vocab_to_int["P" + str(i)] = len(vocab)
      vocab.append("P" + str(i))
  if negativeProbability > 0:
    vocab_to_int["-"] = len(vocab)
    vocab.append("-")
    vocab_to_int["#"] = len(vocab)
    vocab.append("#")

  example_options = dict(leftpadding=leftpadding, reversedigits=reversedigits,
                         addAlignmentTokens=addAlignmentTokens,
                         negativeProbability=negativeProbability)
  train_examples_in_raw, train_examples_out_raw = create_examples(
      trainsize, 1, 8, **example_options)
  test_easy_examples_in_raw, test_easy_examples_out_raw = create_examples(
      testsize, 5, 6, **example_options)
  test_examples_in_raw, test_examples_out_raw = create_examples(
      testsize, 6, 8, **example_options)
  test_hard_examples_in_raw, test_hard_examples_out_raw = create_examples(
      testsize, 9, 10, **example_options)

  # Pad to the longest sequence across *all* splits. If a split were left
  # out, pad_sequences would silently truncate its longer examples.
  max_len_inp = max(max_length(train_examples_in_raw),
                    max_length(test_easy_examples_in_raw),
                    max_length(test_examples_in_raw),
                    max_length(test_hard_examples_in_raw))
  max_len_targ = max(max_length(train_examples_out_raw),
                     max_length(test_easy_examples_out_raw),
                     max_length(test_examples_out_raw),
                     max_length(test_hard_examples_out_raw))

  input_tensor_train, target_tensor_train = create_dataset_tensors(
      train_examples_in_raw, train_examples_out_raw, max_len_inp, max_len_targ,
      vocab_to_int)
  val_splits = [
      (train_examples_in_raw[0:testsize], train_examples_out_raw[0:testsize]),
      (test_easy_examples_in_raw, test_easy_examples_out_raw),
      (test_examples_in_raw, test_examples_out_raw),
      (test_hard_examples_in_raw, test_hard_examples_out_raw),
  ]
  val_tensors = [create_dataset_tensors(examples_in, examples_out, max_len_inp,
                                        max_len_targ, vocab_to_int)
                 for examples_in, examples_out in val_splits]

  return (vocab, vocab_to_int,
          input_tensor_train, target_tensor_train,
          [inputs for inputs, _ in val_tensors],
          [targets for _, targets in val_tensors])


In [None]:
#@title Algorithmic dataset generation (reversing, duplicating, cartesian, intersection)

def create_reversing_dataset(trainsize, testsize, vocab, vocab_to_int,
                             trainmindigits=1, trainmaxdigits=16,
                             testmindigits=17, testmaxdigits=24):
  """Generate the reversing dataset.

  Each input is `d1 ... dn [END]` and each target is
  `[START] dn ... d1 [END]`. Training and "easy" test sequences have
  trainmindigits..trainmaxdigits digits; "hard" test sequences have
  testmindigits..testmaxdigits digits. The `vocab` / `vocab_to_int`
  arguments are ignored and rebuilt here.

  Returns:
    (vocab, vocab_to_int, input_tensor_train, target_tensor_train,
     [val inputs], [val targets]). The three evaluation splits are the first
     `testsize` training examples, the easy split and the hard split.
  """
  digit_tokens = [str(d) for d in range(10)]

  def create_example(minlen, maxlen):
    length = random.randint(minlen, maxlen)
    sequence = [random.choice(digit_tokens) for _ in range(length)]
    return (sequence + [END_TOKEN],
            [START_TOKEN] + sequence[::-1] + [END_TOKEN])

  def create_examples(n, minlen, maxlen):
    pairs = [create_example(minlen, maxlen) for _ in range(n)]
    return [pair[0] for pair in pairs], [pair[1] for pair in pairs]

  # Special tokens first, so each token's ID equals its position.
  vocab = [PAD_TOKEN, SEP_TOKEN, END_TOKEN, START_TOKEN] + digit_tokens
  vocab_to_int = {token: index for index, token in enumerate(vocab)}

  train_in, train_out = create_examples(trainsize, trainmindigits, trainmaxdigits)
  easy_in, easy_out = create_examples(testsize, trainmindigits, trainmaxdigits)
  hard_in, hard_out = create_examples(testsize, testmindigits, testmaxdigits)

  max_len_inp = max(max_length(train_in), max_length(easy_in),
                    max_length(hard_in))
  max_len_targ = max(max_length(train_out), max_length(easy_out),
                     max_length(hard_out))

  def to_tensors(examples_in, examples_out):
    return create_dataset_tensors(examples_in, examples_out, max_len_inp,
                                  max_len_targ, vocab_to_int)

  input_tensor_train, target_tensor_train = to_tensors(train_in, train_out)
  val_splits = [to_tensors(train_in[0:testsize], train_out[0:testsize]),
                to_tensors(easy_in, easy_out),
                to_tensors(hard_in, hard_out)]

  return (vocab, vocab_to_int,
          input_tensor_train, target_tensor_train,
          [inputs for inputs, _ in val_splits],
          [targets for _, targets in val_splits])


# `variableduplication` selects how many copies of the input the target holds:
#   0: always two copies, with no prefix (the setting used in the paper).
#   1: 1, 2, 3 or 4 copies, chosen at random. The count is given as a single
#      digit token followed by [SEP] at the start of the input.
#   2: same as 1, but the count is given in unary: that many "1" tokens
#      (e.g. "1 1 1" for three copies), followed by [SEP].
#   3: same as 2, but the hard test split asks for 5 or 6 copies while
#      training uses 1-4 (this is almost the same as a cartesian product).
def create_duplicating_dataset(trainsize, testsize, vocab, vocab_to_int,
                               trainmindigits=1, trainmaxdigits=16,
                               testmindigits=17, testmaxdigits=24,
                               variableduplication=0):
  """Creates the duplicating dataset.

  The target repeats the input digit sequence several times. How many times,
  and how the count is announced in the input, depends on
  `variableduplication`:
    0: always two copies, no prefix.
    1: 1-4 copies; the input starts with the count as a digit token, then [SEP].
    2: 1-4 copies; the input starts with that many "1" tokens, then [SEP].
    3: like 2, but the hard test split uses 5-6 copies.
  Any other value behaves like 0. `vocab` / `vocab_to_int` are ignored and
  rebuilt here.

  Returns:
    (vocab, vocab_to_int, input_tensor_train, target_tensor_train,
     [val inputs], [val targets]). The three evaluation splits are: the first
     `testsize` training examples, a split with the training length range,
     and a split with testmindigits..testmaxdigits digits.
  """
  digit_tokens = [str(d) for d in range(10)]

  def create_example(minlen, maxlen, minduplications=1, maxduplications=4):
    length = random.randint(minlen, maxlen)
    if variableduplication == 1:
      n_duplications = random.choice([1, 2, 3, 4])
      prefix = [digit_tokens[n_duplications], SEP_TOKEN]
    elif variableduplication in (2, 3):
      n_duplications = random.randint(minduplications, maxduplications)
      prefix = [digit_tokens[1]] * n_duplications + [SEP_TOKEN]
    else:
      n_duplications = 2
      prefix = []
    sequence = [random.choice(digit_tokens) for _ in range(length)]
    return (prefix + sequence + [END_TOKEN],
            [START_TOKEN] + sequence * n_duplications + [END_TOKEN])

  def create_examples(n, minlen, maxlen, minduplications=1, maxduplications=4):
    pairs = [create_example(minlen, maxlen,
                            minduplications=minduplications,
                            maxduplications=maxduplications)
             for _ in range(n)]
    return [pair[0] for pair in pairs], [pair[1] for pair in pairs]

  # Special tokens first, so each token's ID equals its position.
  vocab = [PAD_TOKEN, SEP_TOKEN, END_TOKEN, START_TOKEN] + digit_tokens
  vocab_to_int = {token: index for index, token in enumerate(vocab)}

  # Only mode 3 asks the hard split for more copies than seen in training.
  if variableduplication == 3:
    hard_options = {"minduplications": 5, "maxduplications": 6}
  else:
    hard_options = {}

  train_in, train_out = create_examples(trainsize, trainmindigits, trainmaxdigits)
  easy_in, easy_out = create_examples(testsize, trainmindigits, trainmaxdigits)
  hard_in, hard_out = create_examples(testsize, testmindigits, testmaxdigits,
                                      **hard_options)

  max_len_inp = max(max_length(train_in), max_length(easy_in),
                    max_length(hard_in))
  max_len_targ = max(max_length(train_out), max_length(easy_out),
                     max_length(hard_out))

  def to_tensors(examples_in, examples_out):
    return create_dataset_tensors(examples_in, examples_out, max_len_inp,
                                  max_len_targ, vocab_to_int)

  input_tensor_train, target_tensor_train = to_tensors(train_in, train_out)
  val_splits = [to_tensors(train_in[0:testsize], train_out[0:testsize]),
                to_tensors(easy_in, easy_out),
                to_tensors(hard_in, hard_out)]

  return (vocab, vocab_to_int,
          input_tensor_train, target_tensor_train,
          [inputs for inputs, _ in val_splits],
          [targets for _, targets in val_splits])


def create_cartesian_dataset(trainsize, testsize, vocab, vocab_to_int,
                             trainmindigits=1, trainmaxdigits=6,
                             testmindigits=7, testmaxdigits=8, reverse=False):
  """Creates the cartesian product dataset.

  Each input is `x1 .. xn [SEP] y1 .. ym [END]`, with random digits x and
  random letters a-j y. Each target is
  `[START] x1 y1 [SEP] x1 y2 [SEP] ... xn ym [SEP] [END]`. With
  `reverse=True` the roles swap: the product (without [START]) is the input,
  and `[START]` followed by the two sets is the target. `vocab` /
  `vocab_to_int` are ignored and rebuilt here.

  Returns:
    (vocab, vocab_to_int, input_tensor_train, target_tensor_train,
     [val inputs], [val targets]). The three evaluation splits are the first
     `testsize` training examples, an in-range split, and a split with
     testmindigits..testmaxdigits elements per set.
  """
  digit_tokens = [str(d) for d in range(10)]
  letter_tokens = list("abcdefghij")

  def create_example(minlen, maxlen):
    len1 = random.randint(minlen, maxlen)
    len2 = random.randint(minlen, maxlen)
    set1 = [random.choice(digit_tokens) for _ in range(len1)]
    set2 = [random.choice(letter_tokens) for _ in range(len2)]
    sets = set1 + [SEP_TOKEN] + set2 + [END_TOKEN]
    product = [token
               for x in set1
               for y in set2
               for token in (x, y, SEP_TOKEN)]
    product.append(END_TOKEN)
    if reverse:
      return product, [START_TOKEN] + sets
    return sets, [START_TOKEN] + product

  def create_examples(n, minlen, maxlen):
    pairs = [create_example(minlen, maxlen) for _ in range(n)]
    return [pair[0] for pair in pairs], [pair[1] for pair in pairs]

  # Special tokens first, so each token's ID equals its position.
  vocab = [PAD_TOKEN, SEP_TOKEN, END_TOKEN, START_TOKEN] + digit_tokens + letter_tokens
  vocab_to_int = {token: index for index, token in enumerate(vocab)}

  train_in, train_out = create_examples(trainsize, trainmindigits, trainmaxdigits)
  easy_in, easy_out = create_examples(testsize, trainmindigits, trainmaxdigits)
  hard_in, hard_out = create_examples(testsize, testmindigits, testmaxdigits)

  max_len_inp = max(max_length(train_in), max_length(easy_in),
                    max_length(hard_in))
  max_len_targ = max(max_length(train_out), max_length(easy_out),
                     max_length(hard_out))

  def to_tensors(examples_in, examples_out):
    return create_dataset_tensors(examples_in, examples_out, max_len_inp,
                                  max_len_targ, vocab_to_int)

  input_tensor_train, target_tensor_train = to_tensors(train_in, train_out)
  val_splits = [to_tensors(train_in[0:testsize], train_out[0:testsize]),
                to_tensors(easy_in, easy_out),
                to_tensors(hard_in, hard_out)]

  return (vocab, vocab_to_int,
          input_tensor_train, target_tensor_train,
          [inputs for inputs, _ in val_splits],
          [targets for _, targets in val_splits])


def create_intersection_dataset(trainsize, testsize, vocab, vocab_to_int,
                                trainminelements=1, trainmaxelements=16,
                                testminelements=17, testmaxelements=24):
  """Creates the intersection boolean dataset.

  Each input is `set1 [SEP] set2 [END]` over the 100 tokens "a0" .. "j9".
  Each target is `[START] true [END]` if the sets share an element, and
  `[START] false [END]` otherwise. Examples alternate between requested-true
  and requested-false labels. `vocab` / `vocab_to_int` are ignored and
  rebuilt here.

  Returns:
    (vocab, vocab_to_int, input_tensor_train, target_tensor_train,
     [val inputs], [val targets]). The three evaluation splits are the first
     `testsize` training examples, an in-range split, and a split drawn with
     testminelements..testmaxelements elements.
  """
  elements = [letter + digit
              for letter in "abcdefghij"
              for digit in "0123456789"]

  def create_example(minlen, maxlen, label):
    len1 = random.randint(minlen, maxlen)
    len2 = random.randint(minlen, maxlen)
    want_overlap = label == "true"

    # NOTE: repeated draws are skipped here, so set1 can end up with fewer
    # than len1 elements. set2, below, always gets exactly len2.
    set1 = []
    for _ in range(len1):
      candidate = random.choice(elements)
      if candidate not in set1:
        set1.append(candidate)

    set2 = []
    intersection = []
    while len(set2) < len2:
      candidate = random.choice(elements)
      if candidate in set2:
        continue
      if candidate in set1:
        if not want_overlap:
          continue
        intersection.append(candidate)
      set2.append(candidate)

    # A "true" example that happened not to overlap gets one set1 element
    # forced into a random slot of set2.
    if want_overlap and not intersection:
      forced = random.choice(set1)
      set2[random.choice(range(len(set2)))] = forced
      intersection.append(forced)

    answer = "true" if intersection else "false"
    return (set1 + [SEP_TOKEN] + set2 + [END_TOKEN],
            [START_TOKEN, answer, END_TOKEN])

  def create_examples(n, minlen, maxlen):
    examples_in = []
    examples_out = []
    for i in range(n):
      requested = "true" if i % 2 == 0 else "false"
      example_in, example_out = create_example(minlen, maxlen, requested)
      examples_in.append(example_in)
      examples_out.append(example_out)
    n_positive = sum(1 for example_out in examples_out if "true" in example_out)
    print(f"positive: {n_positive}, negative: {n - n_positive}")
    return examples_in, examples_out

  # Special tokens and the two labels first, so each token's ID equals its
  # position.
  vocab = [PAD_TOKEN, SEP_TOKEN, END_TOKEN, START_TOKEN,
           "true", "false"] + elements
  vocab_to_int = {token: index for index, token in enumerate(vocab)}

  train_in, train_out = create_examples(trainsize, trainminelements, trainmaxelements)
  easy_in, easy_out = create_examples(testsize, trainminelements, trainmaxelements)
  hard_in, hard_out = create_examples(testsize, testminelements, testmaxelements)

  max_len_inp = max(max_length(train_in), max_length(easy_in),
                    max_length(hard_in))
  max_len_targ = max(max_length(train_out), max_length(easy_out),
                     max_length(hard_out))

  def to_tensors(examples_in, examples_out):
    return create_dataset_tensors(examples_in, examples_out, max_len_inp,
                                  max_len_targ, vocab_to_int)

  input_tensor_train, target_tensor_train = to_tensors(train_in, train_out)
  val_splits = [to_tensors(train_in[0:testsize], train_out[0:testsize]),
                to_tensors(easy_in, easy_out),
                to_tensors(hard_in, hard_out)]

  return (vocab, vocab_to_int,
          input_tensor_train, target_tensor_train,
          [inputs for inputs, _ in val_splits],
          [targets for _, targets in val_splits])

In [None]:
#@title "SCAN" dataset generation

# Set the appropriate path here:
# Each entry is [key in `uploaded`, path to read]. The paths are relative.
to_download = [
    ["tasks_train_length.txt",
     "SCAN/length_split/tasks_train_length.txt"],
    ["tasks_test_length.txt",
     "SCAN/length_split/tasks_test_length.txt"],
    ["tasks_train_addprim_jump.txt",
     "SCAN/add_prim_split/tasks_train_addprim_jump.txt"],
    ["tasks_test_addprim_jump.txt",
     "SCAN/add_prim_split/tasks_test_addprim_jump.txt"],
]
uploaded = {}

# Preload the raw bytes of each file; create_scan_dataset() decodes them.
# (Each file was previously also decoded into an unused global `lines`; that
# dead work has been removed.)
for name, path in to_download:
  with tf.io.gfile.GFile(path, "rb") as f:
    uploaded[name] = f.read()

def create_scan_length_dataset(vocab, vocab_to_int,
                               map_output_to_input=False):
  """Builds SCAN from `tasks_train_length.txt` / `tasks_test_length.txt`."""
  train_file, test_file = "tasks_train_length.txt", "tasks_test_length.txt"
  return create_scan_dataset(vocab, vocab_to_int, train_file, test_file,
                             map_output_to_input=map_output_to_input)

def create_scan_add_jump_dataset(vocab, vocab_to_int,
                                 map_output_to_input=False):
  """Builds SCAN from `tasks_train_addprim_jump.txt` / `tasks_test_addprim_jump.txt`."""
  train_file = "tasks_train_addprim_jump.txt"
  test_file = "tasks_test_addprim_jump.txt"
  return create_scan_dataset(vocab, vocab_to_int, train_file, test_file,
                             map_output_to_input=map_output_to_input)

def create_scan_dataset(vocab, vocab_to_int,
                        train_filename,
                        test_filename,
                        map_output_to_input=False,
                        keep_all_steps=False):
  """Creates a version of the SCAN dataset.

  Reads the preloaded bytes in `uploaded` and tokenizes every
  `IN: ... OUT: ...` line, growing a fresh vocabulary as new tokens appear.
  The resulting ID sequences are padded to a common length across train and
  test.

  Args:
    vocab, vocab_to_int: ignored; a fresh vocabulary is built and returned.
    train_filename, test_filename: keys into `uploaded`.
    map_output_to_input: if True, the tokens I_WALK/I_JUMP/I_LOOK/I_RUN are
      rewritten to walk/jump/look/run wherever they appear.
    keep_all_steps: if True, every consecutive pair of steps of an instance
      becomes an example; otherwise only first -> last is kept. Every
      instance built here has exactly two steps, so both give the same pairs.

  Returns:
    (vocab, vocab_to_int, input_tensor_train, target_tensor_train,
     [val0_inputs, test_inputs], [val0_targets, test_targets]). val0 is the
    first len(test) training instances.
  """
  vocab = [PAD_TOKEN, SEP_TOKEN, END_TOKEN, START_TOKEN,
           END_ITERATION_TOKEN]
  vocab_to_int = {PAD_TOKEN: 0, SEP_TOKEN: SEP_TOKEN_IDX,
                  END_TOKEN: END_TOKEN_IDX,
                  START_TOKEN: START_TOKEN_IDX,
                  END_ITERATION_TOKEN: END_ITERATION_TOKEN_IDX}
  action_to_word = {"I_WALK": "walk", "I_JUMP": "jump",
                    "I_LOOK": "look", "I_RUN": "run"}

  def tokenize(text):
    # [START] + token IDs + [END]; unseen tokens get the next free ID.
    ids = [START_TOKEN_IDX]
    for token in text.split(" "):
      if map_output_to_input:
        token = action_to_word.get(token, token)
      if token not in vocab_to_int:
        vocab_to_int[token] = len(vocab)
        vocab.append(token)
      ids.append(vocab_to_int[token])
    ids.append(END_TOKEN_IDX)
    return ids

  def load_instances(filename):
    instances = []
    for line in uploaded[filename].decode("utf-8").split("\n"):
      if not line.startswith("IN:"):
        continue
      parts = line[4:].split(" OUT: ")
      instances.append([tokenize(parts[0]), tokenize(parts[1])])
    print("# instances: " + str(len(instances)))
    return instances

  def to_pairs(instances):
    inputs = []
    targets = []
    for steps in instances:
      if keep_all_steps:
        inputs.extend(steps[:-1])
        targets.extend(steps[1:])
      else:
        inputs.append(steps[0])
        targets.append(steps[-1])
    return inputs, targets

  def pad(sequences, maxlen):
    return tf.keras.preprocessing.sequence.pad_sequences(
        sequences, padding='post', maxlen=maxlen)

  train_instances = load_instances(train_filename)
  test_instances = load_instances(test_filename)

  train_in, train_out = to_pairs(train_instances)
  test_in, test_out = to_pairs(test_instances)
  testsize = len(test_instances)
  val0_in, val0_out = to_pairs(train_instances[0:testsize])

  # Pad every split to the longest sequence seen in either train or test.
  max_len_inp = max(max_length(train_in), max_length(test_in))
  max_len_targ = max(max_length(train_out), max_length(test_out))

  return (vocab, vocab_to_int,
          pad(train_in, max_len_inp), pad(train_out, max_len_targ),
          [pad(val0_in, max_len_inp), pad(test_in, max_len_inp)],
          [pad(val0_out, max_len_targ), pad(test_out, max_len_targ)])


In [None]:
#@title "PCFG" dataset generation

# Set the appropriate path here:
# Each entry is [key in uploaded_pcfg, source-file path, target-file path].
# The paths are relative; adjust the "pcfg/" prefix to where the data lives.
to_download_pcfg = [
    [f"pcfg_{split}_{part}",
     f"pcfg/{split}/{part}.src",
     f"pcfg/{split}/{part}.tgt"]
    for split in ("productivity", "systematicity")
    for part in ("train", "test")
]
uploaded_pcfg = {}

# Length limits for the PCFG data. Presumably these are the `maxlen` values
# passed to the loader for train and test respectively (TODO: confirm at the
# call site).
MAX_TRAIN_LEN = 128
MAX_TEST_LEN = 256

# Preloading the PCFG files:
# Preload the raw bytes of each PCFG source/target file pair and report the
# number of lines in each source file.
for name, pathin, pathout in to_download_pcfg:
  # Nested `with` statements work on every Python 3 version. The
  # parenthesized multi-manager form `with (A as a, B as b):` is only
  # officially supported from Python 3.10.
  with tf.io.gfile.GFile(pathin, "rb") as fin:
    with tf.io.gfile.GFile(pathout, "rb") as fout:
      uploaded_pcfg[name] = [fin.read(), fout.read()]
  lines = uploaded_pcfg[name][0].decode("utf-8").split("\n")
  print(name + ": " + str(len(lines)) + " lines")


def create_pcfg_dataset(pcfg_split, vocab, vocab_to_int,
                        max_train_len=128, max_test_len=256):
  """Creates a version of the PCFG dataset.

  Args:
    pcfg_split: split name (e.g. "productivity" or "systematicity"); selects
      the preloaded "pcfg_<split>_train" / "pcfg_<split>_test" entries of
      `uploaded_pcfg`.
    vocab: ignored; a fresh vocabulary starting with the special tokens is
      always built (parameter kept for interface compatibility).
    vocab_to_int: ignored, see `vocab`.
    max_train_len: training pairs whose tokenized target (including START and
      END) is longer than this are dropped. This is a parameter rather than a
      read of the MAX_TRAIN_LEN global because the "CFQ" cell overwrites that
      global (to 256); under Run-All the PCFG value of 128 was silently lost.
    max_test_len: same limit for the test pairs.

  Returns:
    (vocab, vocab_to_int, input_tensor_train, target_tensor_train,
     [input_tensor_val0, input_tensor_val],
     [target_tensor_val0, target_tensor_val])
    where the *_val0 tensors hold the first len(test) training pairs and all
    tensors are padded to a common input width and a common target width.

  Raises:
    ValueError: if a source file and its target file have a different number
      of lines.
  """

  vocab = [PAD_TOKEN, SEP_TOKEN, END_TOKEN, START_TOKEN,
           END_ITERATION_TOKEN]
  vocab_to_int = {PAD_TOKEN:0, SEP_TOKEN:SEP_TOKEN_IDX,
                  END_TOKEN:END_TOKEN_IDX,
                  START_TOKEN:START_TOKEN_IDX,
                  END_ITERATION_TOKEN:END_ITERATION_TOKEN_IDX}

  def create_dataset_tensors_pcfg(instances, maxlen_inp=None, maxlen_targ=None):
    """Pads each instance's consecutive (step i -> step i+1) pairs into an
    input array and a target array (post-padding with 0)."""
    in_tensor = []
    out_tensor = []
    for instance in instances:
      for i in range(len(instance)-1):
        in_tensor.append(instance[i])
        out_tensor.append(instance[i+1])

    in_tensor = tf.keras.preprocessing.sequence.pad_sequences(
        in_tensor, padding='post', maxlen=maxlen_inp)
    out_tensor = tf.keras.preprocessing.sequence.pad_sequences(
        out_tensor, padding='post', maxlen=maxlen_targ)

    return in_tensor, out_tensor

  def read_lines(blob):
    """Decodes a preloaded file and splits it into lines.

    The empty string that `split("\\n")` yields after a trailing newline is
    dropped; otherwise it would become a bogus [START, "", END] example and
    add "" to the vocabulary.
    """
    lines = blob.decode("utf-8").split("\n")
    if lines and lines[-1] == "":
      lines.pop()
    return lines

  def tokenize(text):
    """Maps a space-separated string to [START] + token ids + [END]."""
    return ([START_TOKEN_IDX] +
            [vocab_to_int[x] for x in text.split(" ")] +
            [END_TOKEN_IDX])

  def load_and_tokenize_data(filename, maxlen):
    """Tokenizes one preloaded (src, tgt) file pair, growing the vocab.

    Every pair contributes to the vocabulary and to the returned raw list;
    pairs whose tokenized target is longer than `maxlen` are left out of the
    tokenized list.
    """
    lines_in = read_lines(uploaded_pcfg[filename][0])
    lines_out = read_lines(uploaded_pcfg[filename][1])
    if len(lines_in) != len(lines_out):
      raise ValueError(
          f"{filename}: {len(lines_in)} source lines but "
          f"{len(lines_out)} target lines")

    max_in_len = 0
    max_out_len = 0
    instances_raw = []
    instances = []
    for line_in, line_out in zip(lines_in, lines_out):
      instance_raw = [line_in, line_out]
      instances_raw.append(instance_raw)

      for instance_part in instance_raw:
        for token in instance_part.split(" "):
          if token not in vocab_to_int:
            vocab_to_int[token] = len(vocab)
            vocab.append(token)

      instance_in_tokenized = tokenize(line_in)
      instance_out_tokenized = tokenize(line_out)
      if len(instance_out_tokenized) > maxlen:
        continue
      instances.append([instance_in_tokenized, instance_out_tokenized])

      max_in_len = max(max_in_len, len(instance_in_tokenized))
      max_out_len = max(max_out_len, len(instance_out_tokenized))

    print("max_in_len: " + str(max_in_len))
    print("max_out_len: " + str(max_out_len))
    return instances_raw, instances

  _, instances_train = load_and_tokenize_data(
      "pcfg_"+pcfg_split+"_train", max_train_len)
  _, instances_test = load_and_tokenize_data(
      "pcfg_"+pcfg_split+"_test", max_test_len)

  # First pass: pad each split to its own longest sequence only to measure the
  # widths shared by every returned tensor.
  input_tensor_train, target_tensor_train = create_dataset_tensors_pcfg(
      instances_train)
  input_tensor_val, target_tensor_val = create_dataset_tensors_pcfg(
      instances_test)
  max_len_inp = max(max_length(input_tensor_train),
                    max_length(input_tensor_val))
  max_len_targ = max(max_length(target_tensor_train),
                     max_length(target_tensor_val))

  testset_size = len(instances_test)
  input_tensor_train, target_tensor_train = create_dataset_tensors_pcfg(
      instances_train, maxlen_inp=max_len_inp, maxlen_targ=max_len_targ)
  input_tensor_val0, target_tensor_val0 = create_dataset_tensors_pcfg(
      instances_train[0:testset_size], maxlen_inp=max_len_inp,
      maxlen_targ=max_len_targ)
  input_tensor_val, target_tensor_val = create_dataset_tensors_pcfg(
      instances_test, maxlen_inp=max_len_inp, maxlen_targ=max_len_targ)

  return (vocab, vocab_to_int,
          input_tensor_train, target_tensor_train,
          [input_tensor_val0, input_tensor_val],
          [target_tensor_val0, target_tensor_val])

In [None]:
#@title "COGS" dataset generation

# Set the appropriate path here. Each entry is [dataset name, .tsv path]:
to_download_cogs = [
    ["COGS_train", "COGS/data/train.tsv"],
    ["COGS_test", "COGS/data/test.tsv"],
    ["COGS_gen", "COGS/data/gen.tsv"],
]
uploaded_cogs = {}

# Read every COGS file once as raw bytes and report its line count; decoding
# and tokenization happen later in create_cogs_dataset.
for name, path in to_download_cogs:
  with tf.io.gfile.GFile(path, "rb") as f:
    uploaded_cogs[name] = f.read()
    lines = uploaded_cogs[name].decode("utf-8").split("\n")
    print(f"{name}: {len(lines)} lines")

def create_cogs_dataset(vocab, vocab_to_int, max_len=256):
  """Generates the COGS dataset.

  Note: set 'max_len' to 512 to ensure even the longest instances fit.
  Otherwise, longer instances will be discarded. The function that loads this
  dataset via the 'DATASET_COGS_FULL' name (used in the paper) already does
  this.

  Args:
    vocab: ignored; a fresh vocabulary starting with the special tokens is
      always built (parameter kept for interface compatibility).
    vocab_to_int: ignored, see `vocab`.
    max_len: pairs whose tokenized target (with START/END) is longer than this
      are dropped, in every split.

  Returns:
    (vocab, vocab_to_int, input_tensor_train, target_tensor_train,
     input_tensors_eval, target_tensors_eval). The eval lists hold, in order:
    the first len(test) training pairs, the test split, one entry per
    generalization distribution (third TSV column of gen.tsv, in order of
    first appearance), and the whole generalization split. All tensors are
    padded to one shared input width and one shared target width.
  """

  MAX_COGS_TRAIN_LEN = max_len
  MAX_COGS_TEST_LEN = max_len

  vocab = [PAD_TOKEN, SEP_TOKEN, END_TOKEN, START_TOKEN,
           END_ITERATION_TOKEN]
  vocab_to_int = {PAD_TOKEN:0, SEP_TOKEN:SEP_TOKEN_IDX,
                  END_TOKEN:END_TOKEN_IDX,
                  START_TOKEN:START_TOKEN_IDX,
                  END_ITERATION_TOKEN:END_ITERATION_TOKEN_IDX}

  def create_dataset_tensors_cogs(instances, maxlen_inp=None, maxlen_targ=None):
    """Pads each instance's consecutive (step i -> step i+1) pairs into an
    input array and a target array (post-padding with 0)."""
    in_tensor = []
    out_tensor = []
    for instance in instances:
      for i in range(len(instance)-1):
        in_tensor.append(instance[i])
        out_tensor.append(instance[i+1])

    in_tensor = tf.keras.preprocessing.sequence.pad_sequences(
        in_tensor, padding='post', maxlen=maxlen_inp)
    out_tensor = tf.keras.preprocessing.sequence.pad_sequences(
        out_tensor, padding='post', maxlen=maxlen_targ)

    return in_tensor, out_tensor

  def load_and_tokenize_data_internal(lines, maxlen):
    """Tokenizes "input<TAB>output<TAB>distribution" rows, growing the vocab.

    Rows without exactly three tab-separated fields (e.g. the empty string a
    trailing newline leaves) are skipped. Every kept row contributes to the
    vocabulary and to the raw list; pairs whose tokenized output is longer
    than `maxlen` are left out of the tokenized list.
    """
    max_in_len = 0
    max_out_len = 0
    instances_raw = []
    instances = []
    for line in lines:
      parts = line.split("\t")
      if len(parts) != 3:
        continue
      instance_raw = [parts[0], parts[1]]
      instances_raw.append(instance_raw)

      for instance_part in instance_raw:
        for token in instance_part.split(" "):
          if token not in vocab_to_int:
            vocab_to_int[token] = len(vocab)
            vocab.append(token)

      # tokenize:
      instance_in_tokenized = (
          [START_TOKEN_IDX] +
          [vocab_to_int[x] for x in instance_raw[0].split(" ")] +
          [END_TOKEN_IDX])
      instance_out_tokenized = (
          [START_TOKEN_IDX] +
          [vocab_to_int[x] for x in instance_raw[1].split(" ")] +
          [END_TOKEN_IDX])
      if len(instance_out_tokenized) > maxlen:
        continue
      instances.append([instance_in_tokenized, instance_out_tokenized])

      max_in_len = max(max_in_len, len(instance_in_tokenized))
      max_out_len = max(max_out_len, len(instance_out_tokenized))

    print("max_in_len: " + str(max_in_len))
    print("max_out_len: " + str(max_out_len))
    return instances_raw, instances

  def load_and_tokenize_data(filename, maxlen):
    """Tokenizes a whole preloaded COGS file."""
    lines = uploaded_cogs[filename].decode("utf-8").split("\n")
    return load_and_tokenize_data_internal(lines, maxlen)

  def load_and_tokenize_data_by_distribution(filename, maxlen):
    """Tokenizes a preloaded COGS file separately for each distribution.

    Returns two parallel lists (raw, tokenized), one entry per distinct value
    of the third TSV column, in order of first appearance.
    """
    lines = uploaded_cogs[filename].decode("utf-8").split("\n")
    distribution_names = []
    lines_by_distribution = {}
    for line in lines:
      parts = line.split("\t")
      if len(parts) == 3:
        distribution = parts[2]
        if distribution not in lines_by_distribution:
          distribution_names.append(distribution)
          lines_by_distribution[distribution] = []
        lines_by_distribution[distribution].append(line)

    for distribution in distribution_names:
      print(f"{distribution}: {len(lines_by_distribution[distribution])}")

    instances_raw_by_distribution = []
    instances_by_distribution = []
    for distribution in distribution_names:
      instances_raw, instances = load_and_tokenize_data_internal(
          lines_by_distribution[distribution], maxlen)
      instances_raw_by_distribution.append(instances_raw)
      instances_by_distribution.append(instances)
    return instances_raw_by_distribution, instances_by_distribution

  _, instances_train = load_and_tokenize_data(
      "COGS_train", MAX_COGS_TRAIN_LEN)
  _, instances_test = load_and_tokenize_data(
      "COGS_test", MAX_COGS_TEST_LEN)
  _, instances_gen = load_and_tokenize_data(
      "COGS_gen", MAX_COGS_TEST_LEN)
  _, instances_gen_by_dist = (
      load_and_tokenize_data_by_distribution("COGS_gen", MAX_COGS_TEST_LEN))

  instances_list = [instances_train, instances_train[:len(instances_test)],
                    instances_test] + instances_gen_by_dist + [instances_gen]

  # First pass: measure the longest input and the longest TARGET over all
  # splits. (Measuring the target width from the input tensors made
  # pad_sequences, which truncates from the front by default, cut the start
  # off targets longer than the longest input.) Empty splits are skipped,
  # since there is nothing to measure.
  max_len_inp = 0
  max_len_targ = 0
  for instances in instances_list:
    if not instances:
      continue
    input_tensor, target_tensor = create_dataset_tensors_cogs(instances)
    max_len_inp = max(max_len_inp, max_length(input_tensor))
    max_len_targ = max(max_len_targ, max_length(target_tensor))

  # Second pass: pad every split to the shared widths.
  input_tensors = []
  target_tensors = []
  for instances in instances_list:
    input_tensor, target_tensor = create_dataset_tensors_cogs(
        instances, maxlen_inp=max_len_inp, maxlen_targ=max_len_targ)
    input_tensors.append(input_tensor)
    target_tensors.append(target_tensor)

  return (vocab, vocab_to_int,
          input_tensors[0], target_tensors[0],
          input_tensors[1:], target_tensors[1:])


In [None]:
#@title "CFQ" dataset generation

# Set the appropriate path here. Each entry is [dataset name, .txt path]:
to_download_cfq = [
    ["cfq_mcd1_train", "cfq/mcd1/train.txt"],
    ["cfq_mcd1_dev", "cfq/mcd1/dev.txt"],
    ["cfq_mcd1_test", "cfq/mcd1/test.txt"],
]
uploaded_cfq = {}

# NOTE: these reassign the globals of the same name set in the "PCFG" cell.
MAX_TRAIN_LEN = 256
MAX_TEST_LEN = 256

# Read every CFQ file once as raw bytes (via tf.io.gfile, so local and remote
# paths both work) and report its line count.
for name, path in to_download_cfq:
  with tf.io.gfile.GFile(path, "rb") as f:
    uploaded_cfq[name] = f.read()
    lines = uploaded_cfq[name].decode("utf-8").split("\n")
    print(f"{name}: {len(lines)} lines")


def cfq_decompose_output(line):
  """Splits a CFQ SPARQL query string into (prefix, triplets, postfix).

  - prefix: every space-separated token up to and including the first "{".
  - triplets: the tokens between that "{" and the next "}", re-joined and
    split on " . "; [""] when that span is empty.
  - postfix: the closing "}" and every token after it.
  Each token in prefix and postfix is followed by a single space. Without a
  "{", the whole line is the prefix.
  """
  tokens = line.split(" ")
  body_start = tokens.index("{") + 1 if "{" in tokens else len(tokens)
  remainder = tokens[body_start:]
  if "}" in remainder:
    body_end = body_start + remainder.index("}")
  else:
    body_end = len(tokens)
  prefix = "".join(token + " " for token in tokens[:body_start])
  postfix = "".join(token + " " for token in tokens[body_end:])
  triplets = " ".join(tokens[body_start:body_end]).strip().split(" . ")
  return prefix, triplets, postfix


def cfq_rewrite_cartesian(triplets):
  """Groups triplets that form a cartesian product over one relation.

  Looks at the first triplet "s r o". If it has three tokens and r is not
  "a", it collects every triplet using relation r. When the collected left
  and right entities form a full cartesian product, they are replaced by one
  grouped triplet "( s1 s2 .. ) ( r ) ( o1 o2 .. )". Otherwise only the first
  triplet is rewritten as "( s ) ( r ) ( o )". Any other triplet is kept
  verbatim. The rest of the list is processed recursively.
  """
  if not triplets:
    return triplets
  head, rest = triplets[0], triplets[1:]
  parts = head.split(" ")
  if len(parts) != 3 or parts[1] == "a":
    return [head] + cfq_rewrite_cartesian(rest)

  subject, relation, obj = parts
  lefts = [subject]
  rights = [obj]
  seen_pairs = {(subject, obj)}
  unrelated = []
  for other in rest:
    other_parts = other.split(" ")
    if len(other_parts) == 3 and other_parts[1] == relation:
      seen_pairs.add((other_parts[0], other_parts[2]))
      if other_parts[0] not in lefts:
        lefts.append(other_parts[0])
      if other_parts[2] not in rights:
        rights.append(other_parts[2])
    else:
      unrelated.append(other)

  is_product = all((left, right) in seen_pairs
                   for left in lefts for right in rights)
  if not is_product:
    single = "( " + subject + " ) ( " + relation + " ) ( " + obj + " )"
    return [single] + cfq_rewrite_cartesian(rest)

  grouped = ("( " + " ".join(lefts) + " ) ( " + relation + " ) ( "
             + " ".join(rights) + " )")
  return [grouped] + cfq_rewrite_cartesian(unrelated)


def cfq_merge_cartesians(triplets):
  """Merges grouped triplets that share both their left and right groups.

  Grouped triplets have the form "( lefts ) ( relations ) ( rights )", as
  produced by `cfq_rewrite_cartesian`. The first grouped triplet absorbs the
  relations of every later grouped triplet with the same left and right
  groups, e.g. "( a ) ( r1 ) ( b )" + "( a ) ( r2 ) ( b )" ->
  "( a ) ( r1 r2 ) ( b )". All other triplets are kept in order, and the
  rest of the list is processed recursively.

  `str.startswith` is used instead of `triplet[0]`, so empty triplets (which
  `cfq_decompose_output` returns for a query with no or an empty "{ }" body)
  are passed through instead of raising IndexError.
  """
  if not triplets:
    return triplets
  triplet = triplets[0]
  rest = triplets[1:]
  tokens = triplet.split(" ) ( ")
  if not triplet.startswith("(") or len(tokens) != 3:
    return [triplet] + cfq_merge_cartesians(rest)

  relations = [tokens[1]]
  to_keep = []
  for triplet2 in rest:
    tokens2 = triplet2.split(" ) ( ")
    if (triplet2.startswith("(") and len(tokens2) == 3 and
        tokens[0] == tokens2[0] and tokens[2] == tokens2[2]):
      relations.append(tokens2[1])
    else:
      to_keep.append(triplet2)
  new_triplet = (tokens[0] + " ) ( " + " ".join(relations) + " ) ( " +
                 tokens[2])
  return [new_triplet] + cfq_merge_cartesians(to_keep)


def simplify_cfq_output(output):
  """Rewrites a CFQ query so cartesian-product triplets appear in grouped form.

  The triplets between the braces are grouped by `cfq_rewrite_cartesian` and
  then merged by `cfq_merge_cartesians`. The prefix and postfix are kept as
  returned by `cfq_decompose_output`, joined with a single space before the
  postfix.
  """
  prefix, triplets, postfix = cfq_decompose_output(output)
  grouped = cfq_merge_cartesians(cfq_rewrite_cartesian(triplets))
  return f"{prefix}{' . '.join(grouped)} {postfix}"


def expand_cfq_output(output):
  """Expands grouped triplets of a CFQ query back into plain triplets.

  This is the inverse of `simplify_cfq_output`. A grouped triplet
  "( l1 l2 ) ( r1 r2 ) ( o1 o2 )" becomes every "l r o" combination
  (left-major, then relation, then right). Other triplets are kept verbatim.
  That includes empty ones, and ones that start with "(" but do not have
  three " ) ( "-separated groups. Such triplets can appear in model
  predictions; they used to raise IndexError. Every input that did not raise
  before is expanded exactly as before.

  Returns prefix + " . ".join(triplets) + " " + postfix.
  """
  prefix, triplets, postfix = cfq_decompose_output(output)
  new_triplets = []
  for triplet in triplets:
    tokens = triplet.split(" ) ( ")
    if not triplet.startswith("(") or len(tokens) < 3:
      new_triplets.append(triplet)
      continue
    left_elements = tokens[0][2:].split(" ")    # drop leading "( "
    relations = tokens[1].split(" ")
    right_elements = tokens[2][:-2].split(" ")  # drop trailing " )"
    for left_element in left_elements:
      for relation in relations:
        for right_element in right_elements:
          new_triplets.append(left_element + " " + relation + " " +
                              right_element)
  return prefix + " . ".join(new_triplets) + " " + postfix


def cfq_equivalent_output(output1, output2):
  """Returns True when two CFQ queries match up to the order of triplets.

  Both strings are stripped and decomposed. The prefixes and postfixes must
  be identical, the triplet lists must have the same length, and every
  triplet of one must occur in the other.
  """
  head1, body1, tail1 = cfq_decompose_output(output1.strip())
  head2, body2, tail2 = cfq_decompose_output(output2.strip())
  return (head1 == head2 and
          tail1 == tail2 and
          len(body1) == len(body2) and
          set(body1) == set(body2))


def create_cfq_mcd1_dataset(vocab, vocab_to_int, simplify_cartesians=False,
                            max_train_len=256, max_test_len=256):
  """Creates a version of the CFQ (MCD1 split) dataset.

  Each preloaded file holds records of two lines, "INPUT: ..." followed by
  "OUTPUT: ...".

  Args:
    vocab: ignored; a fresh vocabulary starting with the special tokens is
      always built (parameter kept for interface compatibility).
    vocab_to_int: ignored, see `vocab`.
    simplify_cartesians: if True, target queries are rewritten with
      `simplify_cfq_output` before tokenization.
    max_train_len: training pairs whose tokenized target (with START/END) is
      longer than this are dropped. This is a parameter rather than a read of
      the MAX_TRAIN_LEN global, which the "PCFG" cell sets to 128.
    max_test_len: same limit for the "cfq_mcd1_dev" and "cfq_mcd1_test" files.

  Returns:
    (vocab, vocab_to_int, input_tensor_train, target_tensor_train,
     [input_tensor_val0, input_tensor_val, input_tensor_val2],
     [target_tensor_val0, target_tensor_val, target_tensor_val2]).
    Here val0 is the first len(dev) training pairs, val is the dev file and
    val2 the test file. All tensors are padded to a common input width and a
    common target width.
  """

  vocab = [PAD_TOKEN, SEP_TOKEN, END_TOKEN, START_TOKEN,
           END_ITERATION_TOKEN]
  vocab_to_int = {PAD_TOKEN:0, SEP_TOKEN:SEP_TOKEN_IDX,
                  END_TOKEN:END_TOKEN_IDX,
                  START_TOKEN:START_TOKEN_IDX,
                  END_ITERATION_TOKEN:END_ITERATION_TOKEN_IDX}

  def create_dataset_tensors_cfq(instances, maxlen_inp=None, maxlen_targ=None):
    """Pads each instance's consecutive (step i -> step i+1) pairs into an
    input array and a target array (post-padding with 0)."""
    in_tensor = []
    out_tensor = []
    for instance in instances:
      for i in range(len(instance)-1):
        in_tensor.append(instance[i])
        out_tensor.append(instance[i+1])

    in_tensor = tf.keras.preprocessing.sequence.pad_sequences(
        in_tensor, padding='post', maxlen=maxlen_inp)
    out_tensor = tf.keras.preprocessing.sequence.pad_sequences(
        out_tensor, padding='post', maxlen=maxlen_targ)

    return in_tensor, out_tensor

  def load_and_tokenize_data(filename, maxlen):
    """Tokenizes one preloaded CFQ file, growing the vocabulary.

    Every record contributes to the vocabulary and to the raw list; pairs
    whose tokenized target is longer than `maxlen` are left out of the
    tokenized list.
    """
    max_in_len = 0
    max_out_len = 0
    instances_raw = []
    instances = []
    lines = uploaded_cfq[filename].decode("utf-8").split("\n")
    for i in range(len(lines)//2):
      line_in = lines[i*2].strip().replace("INPUT: ", "")
      line_out = lines[i*2+1].strip().replace("OUTPUT: ", "")
      if simplify_cartesians:
        # simplify_cfq_output leaves a trailing space; without strip() the
        # split(" ") below would append a spurious "" token to every target.
        line_out = simplify_cfq_output(line_out).strip()
      instance_raw = [line_in, line_out]
      instances_raw.append(instance_raw)

      for instance_part in instance_raw:
        for token in instance_part.split(" "):
          if token not in vocab_to_int:
            vocab_to_int[token] = len(vocab)
            vocab.append(token)

      # tokenize:
      instance_in_tokenized = (
          [START_TOKEN_IDX] +
          [vocab_to_int[x] for x in instance_raw[0].split(" ")] +
          [END_TOKEN_IDX])
      instance_out_tokenized = (
          [START_TOKEN_IDX] +
          [vocab_to_int[x] for x in instance_raw[1].split(" ")] +
          [END_TOKEN_IDX])
      if len(instance_out_tokenized) > maxlen:
        continue
      instances.append([instance_in_tokenized, instance_out_tokenized])

      max_in_len = max(max_in_len, len(instance_in_tokenized))
      max_out_len = max(max_out_len, len(instance_out_tokenized))

    print("max_in_len: " + str(max_in_len))
    print("max_out_len: " + str(max_out_len))
    return instances_raw, instances

  _, instances_train = load_and_tokenize_data(
      "cfq_mcd1_train", max_train_len)
  _, instances_test = load_and_tokenize_data(
      "cfq_mcd1_dev", max_test_len)
  _, instances_gen = load_and_tokenize_data(
      "cfq_mcd1_test", max_test_len)

  # First pass: pad each split to its own longest sequence only to measure the
  # widths shared by every returned tensor.
  input_tensor_train, target_tensor_train = create_dataset_tensors_cfq(
      instances_train)
  input_tensor_val, target_tensor_val = create_dataset_tensors_cfq(
      instances_test)
  input_tensor_val2, target_tensor_val2 = create_dataset_tensors_cfq(
      instances_gen)
  max_len_inp = max(max_length(input_tensor_train),
                    max_length(input_tensor_val),
                    max_length(input_tensor_val2))
  max_len_targ = max(max_length(target_tensor_train),
                     max_length(target_tensor_val),
                     max_length(target_tensor_val2))

  testset_size = len(instances_test)
  input_tensor_train, target_tensor_train = create_dataset_tensors_cfq(
      instances_train, maxlen_inp=max_len_inp, maxlen_targ=max_len_targ)
  input_tensor_val0, target_tensor_val0 = create_dataset_tensors_cfq(
      instances_train[0:testset_size], maxlen_inp=max_len_inp,
      maxlen_targ=max_len_targ)
  input_tensor_val, target_tensor_val = create_dataset_tensors_cfq(
      instances_test, maxlen_inp=max_len_inp, maxlen_targ=max_len_targ)
  input_tensor_val2, target_tensor_val2 = create_dataset_tensors_cfq(
      instances_gen, maxlen_inp=max_len_inp, maxlen_targ=max_len_targ)

  return (vocab, vocab_to_int,
          input_tensor_train, target_tensor_train,
          [input_tensor_val0, input_tensor_val, input_tensor_val2],
          [target_tensor_val0, target_tensor_val, target_tensor_val2])

In [None]:
#@title This code block defines Transformer Encoder/Decoder layers, following this tutorial (adding relative attention, a copy decoder, and a few other improvements): https://www.tensorflow.org/tutorials/text/transformer

# ---- Position-encoding / attention variants ----
# Plain relative variants, interpreted by MultiHeadAttention:
#   RELATIVE          -> learned relative-position key embeddings
#   RELATIVE_BIASONLY -> only a learned per-offset logit bias
#   RELATIVE_BIAS     -> both of the above
RELATIVE = "relative"
RELATIVE_BIASONLY = "relative_biasonly"
RELATIVE_BIAS = "relative_bias"

# "16" variants (presumably a relative radius of 16 -- confirm where they are
# consumed). The "_d2e" ones also use relative positions for dec2enc
# attention.
RELATIVE16 = "relative16"
RELATIVE16_BIASONLY = "relative16_biasonly"
RELATIVE16_BIAS = "relative16_bias"
RELATIVE16_D2E = "relative16_d2e"
RELATIVE16_D2E_BIASONLY = "relative16_d2e_biasonly"
RELATIVE16_D2E_BIAS = "relative16_d2e_bias"

# Absolute sine/cosine position encodings added to the embeddings.
ABSOLUTE_SINUSOIDAL = "absolute_sinusoidal"

# ---- Decoder variants ----
# COPY_DECODER2 factors the copy-decoder feed-forward layer so that it does not
# blow up on datasets with large vocabularies.
STANDARD_DECODER = "standard_decoder"
COPY_DECODER = "copy_decoder"
COPY_DECODER2 = "copy_decoder2"

# ---- Realformer (residual attention-logit) variants ----
NO_REALFORMER = "no_realformer"
REALFORMER_LOGITS = "realformer_logits"
REALFORMER_LOGITSPSF = "realformer_logits_plus_softmax"

# ---- Feed-forward and decoding options ----
STANDARD_FF = "standard_ff"
GREEDY_DECODING = "greedy_decoding"


def get_angles(pos, i, d_model):
  """Returns pos / 10000**(2*(i//2)/d_model), the sinusoid argument per (pos, i).

  `pos` and `i` broadcast against each other (e.g. a column and a row vector).
  """
  exponent = (2 * (i//2)) / np.float32(d_model)
  rates = 1 / np.power(10000, exponent)
  return pos * rates

def positional_encoding(position, d_model):
  """Builds the sinusoidal absolute position table.

  Returns a float32 tensor of shape (1, position, d_model). Even feature
  columns hold sin(angle) and odd columns hold cos(angle), with angles from
  `get_angles`.
  """
  positions = np.arange(position)[:, np.newaxis]
  features = np.arange(d_model)[np.newaxis, :]
  table = get_angles(positions, features, d_model)

  table[:, 0::2] = np.sin(table[:, 0::2])  # 2i
  table[:, 1::2] = np.cos(table[:, 1::2])  # 2i+1

  return tf.cast(table[np.newaxis, ...], dtype=tf.float32)

def create_padding_mask(seq):
  """Returns 1.0 where `seq` holds token id 0 (padding) and 0.0 elsewhere.

  Two singleton axes are inserted so the result, of shape
  (batch_size, 1, 1, seq_len), broadcasts over heads and query positions
  when added to attention logits.
  """
  is_padding = tf.cast(tf.math.equal(seq, 0), tf.float32)
  return is_padding[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
  """Returns a (size, size) mask that is 1.0 strictly above the diagonal.

  Those are the future positions a query must not attend to; the diagonal
  and everything below it are 0.0.
  """
  lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return 1 - lower_triangle


def scaled_dot_product_relative_attention(q, k, v,
                                          mask=None,
                                          relative_ids=None,
                                          relative_embeddings=None,
                                          relative_biases=None,
                                          attention_logits_residual=None):
  """Scaled dot-product attention with optional relative-position terms.

  The logits are

    logits = (q k^T [+ q . r] [+ b]) / sqrt(depth_k)
             [+ mask * -1e9] [+ attention_logits_residual]

  where r = relative_embeddings(relative_ids) and b is
  relative_biases(relative_ids) squeezed on its last axis. Each bracketed
  term is only added when its argument is given; r and b also require
  relative_ids. The output is softmax(logits, axis=-1) @ v. With no relative
  inputs this is regular scaled dot-product attention.

  Args:
    q: query, shape (..., seq_len_q, depth). It must be rank 4
      (batch, heads, seq_len_q, depth) when relative_embeddings is used,
      because of the "bhqd,qkd->bhqk" einsum.
    k: key, shape (..., seq_len_k, depth).
    v: value, shape (..., seq_len_v, depth_v), with seq_len_v == seq_len_k.
    mask: optional float tensor broadcastable to (..., seq_len_q, seq_len_k).
      Positions where it is 1 get -1e9 added to their logits. Defaults to
      None.
    relative_ids: optional integer ids of shape (seq_len_q, seq_len_k).
      Defaults to None.
    relative_embeddings: optional layer mapping relative_ids to vectors of
      size `depth` (a tf.keras.layers.Embedding in MultiHeadAttention).
    relative_biases: optional layer mapping relative_ids to one value per
      (query, key) pair, which is added to the logits.
    attention_logits_residual: optional tensor added to the scaled, masked
      logits before the softmax (Realformer-style residual attention).

  Returns:
    (output, scaled_attention_logits):
    - output has shape (..., seq_len_q, depth_v).
    - scaled_attention_logits has shape (..., seq_len_q, seq_len_k). It holds
      the PRE-softmax logits (after scaling, masking and the residual), not
      the softmax weights, so it can be fed back as the next layer's
      attention_logits_residual.
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  if relative_ids is not None:
    if relative_embeddings is not None:
      r = relative_embeddings(relative_ids)  # (seq_len_q, seq_len_k, depth)
      matmul_qrel = tf.einsum("bhqd,qkd->bhqk", q, r)
      matmul_qk += matmul_qrel
    if relative_biases is not None:
      matmul_qk += tf.squeeze(relative_biases(relative_ids), axis=-1)

  # scale matmul_qk (relative terms are scaled together with q k^T)
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)

  # Realformer-style residual connection:
  if attention_logits_residual is not None:
    scaled_attention_logits += attention_logits_residual

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  # (..., seq_len_q, seq_len_k)
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, scaled_attention_logits


class MultiHeadAttention(tf.keras.layers.Layer):
  """Multi-head attention with optional learned relative-position terms.

  `position_encodings` selects the relative terms:
  - RELATIVE: relative key embeddings of size `depth`.
  - RELATIVE_BIASONLY: only a scalar logit bias.
  - RELATIVE_BIAS: both.
  Any other value disables them. Both tables have 2*relative_radius+1 rows.
  call() returns (output, attention_logits); the logits are the pre-softmax
  values from scaled_dot_product_relative_attention.
  """

  def __init__(self, d_model, num_heads, relative_radius,
               position_encodings=RELATIVE):
    super().__init__()
    self.num_heads = num_heads
    self.d_model = d_model
    self.relative_vocab_size = relative_radius*2+1

    assert d_model % self.num_heads == 0
    self.depth = d_model // self.num_heads

    # Attribute names and creation order are kept stable so that saved
    # weights/checkpoints keep matching.
    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)
    self.dense = tf.keras.layers.Dense(d_model)

    wants_key_embeddings = position_encodings in (RELATIVE, RELATIVE_BIAS)
    wants_logit_bias = position_encodings in (RELATIVE_BIAS, RELATIVE_BIASONLY)
    self.relative_embeddings = (
        tf.keras.layers.Embedding(self.relative_vocab_size, self.depth)
        if wants_key_embeddings else None)
    self.relative_biases = (
        tf.keras.layers.Embedding(self.relative_vocab_size, 1)
        if wants_logit_bias else None)

  def split_heads(self, x, batch_size):
    """Reshapes (batch, seq, d_model) into (batch, num_heads, seq, depth)."""
    per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(per_head, perm=[0, 2, 1, 3])

  def call(self, v, k, q,
           mask=None,
           relative_ids=None,
           attention_logits_residual=None):
    batch_size = tf.shape(q)[0]

    # Linear projections, each (batch_size, seq_len, d_model).
    projected_q = self.wq(q)
    projected_k = self.wk(k)
    projected_v = self.wv(v)

    # -> (batch_size, num_heads, seq_len, depth)
    heads_q = self.split_heads(projected_q, batch_size)
    heads_k = self.split_heads(projected_k, batch_size)
    heads_v = self.split_heads(projected_v, batch_size)

    # per_head_out: (batch_size, num_heads, seq_len_q, depth)
    # logits:       (batch_size, num_heads, seq_len_q, seq_len_k), pre-softmax
    per_head_out, logits = scaled_dot_product_relative_attention(
        heads_q, heads_k, heads_v, mask, relative_ids,
        self.relative_embeddings, self.relative_biases,
        attention_logits_residual)

    # Back to (batch_size, seq_len_q, num_heads, depth), then merge heads into
    # (batch_size, seq_len_q, d_model) and apply the output projection.
    seq_major = tf.transpose(per_head_out, perm=[0, 2, 1, 3])
    merged = tf.reshape(seq_major, (batch_size, -1, self.d_model))
    return self.dense(merged), logits


def point_wise_feed_forward_network(d_model, dff):
  """Two-layer position-wise MLP: d_model -> dff (ReLU) -> d_model."""
  expand = tf.keras.layers.Dense(dff, activation='relu')  # (batch, seq, dff)
  project = tf.keras.layers.Dense(d_model)  # (batch, seq, d_model)
  return tf.keras.Sequential([expand, project])


class EncoderLayer(tf.keras.layers.Layer):
  """One post-LayerNorm Transformer encoder block.

  Self-attention, then a feed-forward network. Each sub-block is followed by
  dropout, a residual addition and LayerNorm. call() returns
  (output, attention_logits), where the logits come from the self-attention.
  """

  def __init__(self, d_model, num_heads, dff, rate=0.1, relative_radius=8,
               position_encodings=RELATIVE):
    super().__init__()

    # Sublayer attribute names and creation order are kept for weight
    # compatibility.
    self.mha = MultiHeadAttention(d_model, num_heads, relative_radius,
                                  position_encodings)
    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask, relative_ids,
           attention_logits_residual=None):
    # Self-attention sub-block; all activations are
    # (batch_size, input_seq_len, d_model).
    attended, logits = self.mha(
        x, x, x, mask, relative_ids,
        attention_logits_residual)
    attended = self.dropout1(attended, training=training)
    after_attention = self.layernorm1(x + attended)

    # Feed-forward sub-block.
    transformed = self.ffn(after_attention)
    transformed = self.dropout2(transformed, training=training)
    block_output = self.layernorm2(after_attention + transformed)

    return block_output, logits

class DecoderLayer(tf.keras.layers.Layer):
  """One post-LayerNorm Transformer decoder block.

  The sub-blocks are, in order: masked self-attention, attention over the
  encoder output, and a feed-forward network. Each is followed by dropout, a
  residual addition and LayerNorm. call() returns
  (output, self_attention_logits, encoder_attention_logits).
  """

  def __init__(self, d_model, num_heads, dff, rate=0.1, relative_radius=8,
               position_encodings=RELATIVE):
    super().__init__()

    # Sublayer attribute names and creation order are kept for weight
    # compatibility.
    self.mha1 = MultiHeadAttention(d_model, num_heads, relative_radius,
                                   position_encodings)
    self.mha2 = MultiHeadAttention(d_model, num_heads, relative_radius,
                                   position_encodings)

    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)

  def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask,
           dec_relative_ids, dec2enc_relative_ids,
           attention_logits_residual1,
           attention_logits_residual2):
    # Masked self-attention over the target prefix:
    # (batch_size, target_seq_len, d_model).
    self_attended, self_logits = self.mha1(
        x, x, x, look_ahead_mask, dec_relative_ids,
        attention_logits_residual1)
    self_attended = self.dropout1(self_attended, training=training)
    after_self = self.layernorm1(self_attended + x)

    # Queries from the decoder, keys/values from enc_output
    # (batch_size, input_seq_len, d_model).
    cross_attended, cross_logits = self.mha2(
        enc_output, enc_output, after_self,
        padding_mask, dec2enc_relative_ids,
        attention_logits_residual2)
    cross_attended = self.dropout2(cross_attended, training=training)
    after_cross = self.layernorm2(cross_attended + after_self)

    # Position-wise feed-forward network.
    transformed = self.ffn(after_cross)
    transformed = self.dropout3(transformed, training=training)
    block_output = self.layernorm3(transformed + after_cross)

    return block_output, self_logits, cross_logits


class Encoder(tf.keras.layers.Layer):
  """Transformer encoder: embedding, optional absolute positions, N layers.

  Constructor args:
    num_layers: number of EncoderLayer applications.
    d_model: embedding / model dimensionality.
    num_heads: attention heads per layer.
    dff: hidden size of each layer's feed-forward network.
    input_vocab_size: size of the input embedding table.
    maximum_position_encoding: number of sinusoidal positions precomputed when
      `position_encodings == ABSOLUTE_SINUSOIDAL` (unused otherwise).
    rate: dropout rate.
    relative_radius: forwarded to every EncoderLayer.
    position_encodings: position-encoding mode.
    shared_layers: if True, a single EncoderLayer instance is applied
      `num_layers` times (weight sharing).
    realformer_connections: when different from NO_REALFORMER, each layer's
      attention logits are passed to the next layer as a residual. Defaults to
      NO_REALFORMER, consistent with Decoder (the old default `False` only
      matched `NO_REALFORMER` by coincidence of its value).
  """

  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               maximum_position_encoding, rate=0.1, relative_radius=8,
               position_encodings=RELATIVE,
               shared_layers=False,
               realformer_connections=NO_REALFORMER):
    super(Encoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers
    self.position_encodings = position_encodings
    self.shared_layers = shared_layers
    self.realformer_connections = realformer_connections

    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    if self.position_encodings == ABSOLUTE_SINUSOIDAL:
      self.pos_encoding = positional_encoding(maximum_position_encoding,
                                              self.d_model)

    if self.shared_layers:
      # The same layer object repeated => all positions share weights.
      layer = EncoderLayer(d_model, num_heads, dff, rate,
                           relative_radius=relative_radius,
                           position_encodings=position_encodings)
      self.enc_layers = [layer] * num_layers
    else:
      self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate,
                                      relative_radius=relative_radius,
                                      position_encodings=position_encodings)
                         for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask, relative_ids):
    """Encodes token ids `x`.

    Returns:
      (encoded, attention_for_viz): `encoded` is
      (batch_size, input_seq_len, d_model); `attention_for_viz` is a list with
      the attention logits returned by each layer.
    """
    seq_len = tf.shape(x)[1]

    # Embedding scaled by sqrt(d_model), plus absolute positions if enabled.
    x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    if self.position_encodings == ABSOLUTE_SINUSOIDAL:
      # Absolute positions replace relative ids entirely.
      relative_ids = None
      x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    attention_for_viz = []  # we accumulate attention logits for visualization
    attention_logits_residual = None
    for i in range(self.num_layers):
      x, attention_logits_residual = self.enc_layers[i](
          x, training, mask=mask, relative_ids=relative_ids,
          attention_logits_residual=attention_logits_residual)
      attention_for_viz.append(attention_logits_residual)
      if self.realformer_connections == NO_REALFORMER:
        # Without RealFormer connections the next layer gets no residual.
        attention_logits_residual = None

    return x, attention_for_viz  # x: (batch_size, input_seq_len, d_model)

  def detailed_param_count(self):
    """Prints the parameter count of the embedding and the encoder layers."""
    print(f"  encoder embedding: {self.embedding.count_params()}")
    if self.shared_layers:
      print(f"  encoder layer weights (shared): {self.enc_layers[0].count_params()}")
    else:
      print(f"  encoder layer weights: {self.enc_layers[0].count_params()*len(self.enc_layers)}")


class Decoder(tf.keras.layers.Layer):
  """Transformer decoder: embedding, optional absolute positions, N layers.

  With `shared_layers=True` one DecoderLayer object is reused for every layer.
  When `realformer_connections` differs from NO_REALFORMER, the self- and
  cross-attention logits of each layer are forwarded to the next layer.
  """

  def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
               maximum_position_encoding, rate=0.1, relative_radius=8,
               position_encodings=RELATIVE,
               shared_layers=False,
               realformer_connections=NO_REALFORMER):
    super(Decoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers
    self.position_encodings = position_encodings
    self.shared_layers = shared_layers
    self.realformer_connections = realformer_connections

    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    if position_encodings == ABSOLUTE_SINUSOIDAL:
      self.pos_encoding = positional_encoding(maximum_position_encoding,
                                              d_model)

    def build_layer():
      return DecoderLayer(d_model, num_heads, dff, rate,
                          relative_radius=relative_radius,
                          position_encodings=position_encodings)

    if shared_layers:
      # A single instance repeated num_layers times => shared weights.
      self.dec_layers = [build_layer()] * num_layers
    else:
      self.dec_layers = [build_layer() for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask,
           dec_relative_ids, dec2enc_relative_ids):
    """Decodes target ids `x` against `enc_output`.

    Returns:
      (decoded, dec2dec_attention_for_viz, dec2enc_attention_for_viz) where
      `decoded` is (batch_size, target_seq_len, d_model) and the two lists hold
      each layer's self-attention and cross-attention logits respectively.
    """
    target_len = tf.shape(x)[1]

    x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    if self.position_encodings == ABSOLUTE_SINUSOIDAL:
      # Absolute positions replace both relative-id matrices.
      dec_relative_ids = None
      dec2enc_relative_ids = None
      x += self.pos_encoding[:, :target_len, :]

    x = self.dropout(x, training=training)

    self_logits_per_layer = []
    cross_logits_per_layer = []
    self_residual = None
    cross_residual = None
    for layer_idx in range(self.num_layers):
      x, self_residual, cross_residual = self.dec_layers[layer_idx](
          x, enc_output, training, look_ahead_mask, padding_mask,
          dec_relative_ids, dec2enc_relative_ids,
          attention_logits_residual1=self_residual,
          attention_logits_residual2=cross_residual)
      self_logits_per_layer.append(self_residual)
      cross_logits_per_layer.append(cross_residual)
      if self.realformer_connections == NO_REALFORMER:
        # Residual logits are only carried over with RealFormer connections.
        self_residual = None
        cross_residual = None

    return x, self_logits_per_layer, cross_logits_per_layer

  def detailed_param_count(self):
    """Prints the parameter count of the embedding and the decoder layers."""
    print(f"  decoder embedding: {self.embedding.count_params()}")
    first_layer_params = self.dec_layers[0].count_params()
    if self.shared_layers:
      print(f"  decoder layer weights (shared): {first_layer_params}")
    else:
      print(f"  decoder layer weights: {first_layer_params*len(self.dec_layers)}")


class Transformer(tf.keras.Model):
  """Encoder-decoder Transformer with an optional copy mechanism.

  Constructor args mirror Encoder/Decoder; additionally:
    vocab_size: shared input/output vocabulary size.
    pe_input / pe_target: number of absolute positions for encoder / decoder.
    copy_decoder: STANDARD_DECODER, COPY_DECODER or COPY_DECODER2.
    realformer_connections: defaults to NO_REALFORMER (previously `False`,
      which did not match the sentinel the layers compare against).
  """

  def __init__(self, num_layers, d_model, num_heads, dff, vocab_size,
               pe_input, pe_target, rate=0.1,
               relative_radius=8, position_encodings=RELATIVE,
               copy_decoder=STANDARD_DECODER,
               shared_layers=False,
               realformer_connections=NO_REALFORMER):
    super(Transformer, self).__init__()
    self.d_model = d_model
    self.vocab_size = vocab_size
    self.copy_decoder = copy_decoder

    self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                           vocab_size, pe_input, rate,
                           relative_radius=relative_radius,
                           position_encodings=position_encodings,
                           shared_layers=shared_layers,
                           realformer_connections=realformer_connections)

    self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                           vocab_size, pe_target, rate,
                           relative_radius=relative_radius,
                           position_encodings=position_encodings,
                           shared_layers=shared_layers,
                           realformer_connections=realformer_connections)

    self.final_layer = tf.keras.layers.Dense(vocab_size)
    if self.copy_decoder == COPY_DECODER:
      # Projects decoder states into copy-attention queries.
      self.final_layer_copy = tf.keras.layers.Dense(d_model)
      # Per-position mixing weight in (0, 1) between generating and copying.
      self.final_layer_copy_weight = tf.keras.layers.Dense(
          1, activation="sigmoid")
    elif self.copy_decoder == COPY_DECODER2:
      self.final_layer_copy = tf.keras.layers.Dense(d_model)
      self.final_layer_copy_weight = tf.keras.layers.Dense(
          1, activation="sigmoid")
      # We want vocab_size -> vocab_size, but that might be too big.
      # So, we do a low-rank approximation, bringing it down to d_model first,
      # in case d_model < vocab_size:
      if self.d_model < self.vocab_size:
        self.final_layer_copy2a = tf.keras.layers.Dense(self.d_model)
        self.final_layer_copy2b = tf.keras.layers.Dense(self.vocab_size)
      else:
        self.final_layer_copy2 = tf.keras.layers.Dense(self.vocab_size)

  def call(self, inp, tar, training, enc_padding_mask,
           look_ahead_mask, dec_padding_mask,
           enc_relative_ids, dec_relative_ids, dec2enc_relative_ids):
    """Runs the full model with teacher forcing.

    Returns:
      (final_output, enc2enc_attention_for_viz, dec2enc_attention_for_viz,
       dec2dec_attention_for_viz). `final_output` is
      (batch_size, tar_seq_len, vocab_size) with softmax already applied
      (mixed with the copy distribution for copy decoders); the other three
      are per-layer lists of attention logits for visualization.
    """
    # enc_output: (batch_size, inp_seq_len, d_model)
    enc_output, enc2enc_attention_for_viz = self.encoder(
        inp, training, enc_padding_mask, enc_relative_ids)

    # Decoder.call returns (output, self-attention logits, cross-attention
    # logits) -- unpack in that order so the names match their contents.
    (dec_output, dec2dec_attention_for_viz, dec2enc_attention_for_viz) = (
        self.decoder(tar, enc_output, training, look_ahead_mask,
                     dec_padding_mask,
                     dec_relative_ids, dec2enc_relative_ids))

    # (batch_size, tar_seq_len, vocab_size) logits.
    final_output = self.final_layer(dec_output)

    if self.copy_decoder in (COPY_DECODER, COPY_DECODER2):
      # (batch_size, tar_seq_len, d_model)
      copy_output_query = self.final_layer_copy(dec_output)
      # (batch_size, tar_seq_len, 1), in (0, 1) via sigmoid.
      copy_output_weight = self.final_layer_copy_weight(dec_output)
      # Attention from decoder positions over encoder positions with the
      # one-hot input tokens as values; presumably the helper's signature is
      # (query, key, value), giving (batch_size, tar_seq_len, vocab_size).
      copy_output, _ = scaled_dot_product_relative_attention(
          copy_output_query,  # (batch_size, tar_seq_len, d_model)
          enc_output,  # (batch_size, inp_seq_len, d_model)
          tf.one_hot(inp, self.vocab_size))
      if self.copy_decoder == COPY_DECODER2:
        # Learned (optionally low-rank) remapping of the copied distribution.
        if self.d_model < self.vocab_size:
          copy_output = tf.nn.softmax(
              self.final_layer_copy2b(
                  self.final_layer_copy2a(copy_output)))
        else:
          copy_output = tf.nn.softmax(
              self.final_layer_copy2(copy_output))
      final_output = (
          (1 - copy_output_weight) * tf.nn.softmax(final_output, axis=-1) +
          copy_output_weight * copy_output)
    else:
      final_output = tf.nn.softmax(final_output, axis=-1)

    return (final_output, enc2enc_attention_for_viz,
            dec2enc_attention_for_viz, dec2dec_attention_for_viz)

  def detailed_param_count(self):
    """Prints a per-component breakdown of the model's parameter count."""
    print("Transformer parameter counts:")
    self.encoder.detailed_param_count()
    self.decoder.detailed_param_count()
    print(f"  final_layer: {self.final_layer.count_params()}")
    if self.copy_decoder in (COPY_DECODER, COPY_DECODER2):
      print(f"  final_layer_copy: {self.final_layer_copy.count_params()}")
      print(f"  final_layer_copy_weight: {self.final_layer_copy_weight.count_params()}")
    if self.copy_decoder == COPY_DECODER2:
      if self.d_model < self.vocab_size:
        print(f"  final_layer_copy2a: {self.final_layer_copy2a.count_params()}")
        print(f"  final_layer_copy2b: {self.final_layer_copy2b.count_params()}")
      else:
        print(f"  final_layer_copy2: {self.final_layer_copy2.count_params()}")

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Linear warm-up followed by inverse-square-root decay.

  lr(step) = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
  """

  def __init__(self, d_model, warmup_steps=4000):
    super(CustomSchedule, self).__init__()

    self.d_model = tf.cast(d_model, tf.float32)
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    # Depending on the TF/Keras version the optimizer may pass its integer
    # iteration counter; rsqrt is not defined for integer tensors, so cast.
    step = tf.cast(step, tf.float32)
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)

    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)


def create_masks(inp, tar):
  """Builds the attention masks consumed by Transformer.call.

  Args:
    inp: source token ids.
    tar: decoder input token ids.

  Returns:
    (enc_padding_mask, combined_mask, dec_padding_mask):
      enc_padding_mask -- hides source padding from encoder self-attention;
      combined_mask    -- element-wise max of the look-ahead mask and the
                          target padding mask, for decoder self-attention;
      dec_padding_mask -- hides source padding from the decoder's attention
                          over encoder outputs.
  """
  # Source padding is masked both in the encoder and in cross-attention.
  enc_padding_mask = create_padding_mask(inp)
  dec_padding_mask = create_padding_mask(inp)

  # Decoder self-attention must ignore future positions AND target padding.
  target_len = tf.shape(tar)[1]
  future_mask = create_look_ahead_mask(target_len)
  target_padding_mask = create_padding_mask(tar)
  combined_mask = tf.maximum(target_padding_mask, future_mask)

  return enc_padding_mask, combined_mask, dec_padding_mask


def create_relative_ids(inp_len, tar_len, relative_radius, de2enc_ids):
  """Builds clipped relative-position id matrices for the three attentions.

  Entry (i, j) is `relative_radius + clip(i - j, -relative_radius,
  relative_radius)`, i.e. an integer in [0, 2 * relative_radius].

  Args:
    inp_len: padded source length.
    tar_len: padded target length; the decoder sees `tar_len - 1` positions
      (targets are shifted by one for teacher forcing).
    relative_radius: maximum relative distance distinguished.
    de2enc_ids: if True, decoder-to-encoder ids use the clipped distance too;
      otherwise every entry is `relative_radius` (i.e. "distance 0").

  Returns:
    (enc_relative_ids [inp_len, inp_len],
     dec_relative_ids [tar_len - 1, tar_len - 1],
     dec2enc_relative_ids [tar_len - 1, inp_len]) as tf constants.
  """
  def _clipped_relative_ids(num_rows, num_cols):
    # Vectorized form of the per-(i, j) computation above.
    diff = np.arange(num_rows)[:, None] - np.arange(num_cols)[None, :]
    clipped = np.clip(diff, -relative_radius, relative_radius)
    return (relative_radius + clipped).astype(int)

  dec_len = tar_len - 1
  enc_relative_ids = _clipped_relative_ids(inp_len, inp_len)
  dec_relative_ids = _clipped_relative_ids(dec_len, dec_len)
  if de2enc_ids:
    dec2enc_relative_ids = _clipped_relative_ids(dec_len, inp_len)
  else:
    dec2enc_relative_ids = np.full([dec_len, inp_len], relative_radius,
                                   dtype=int)

  return (tf.constant(enc_relative_ids), tf.constant(dec_relative_ids),
          tf.constant(dec2enc_relative_ids))

In [None]:
#@title Experiment execution function

def setup_dataset(dataset):
  """Creates the requested dataset on top of the shared special-token vocab.

  The five general tokens are registered first with their fixed ids; the
  selected dataset builder receives `vocab` and `vocab_to_int` and its
  return value is passed straight through.

  Raises:
    ValueError: if `dataset` is not one of the DATASET_* constants below.
  """
  special_tokens = [(PAD_TOKEN, PAD_TOKEN_IDX),
                    (SEP_TOKEN, SEP_TOKEN_IDX),
                    (END_TOKEN, END_TOKEN_IDX),
                    (START_TOKEN, START_TOKEN_IDX),
                    (END_ITERATION_TOKEN, END_ITERATION_TOKEN_IDX)]
  vocab = [token for token, _ in special_tokens]
  vocab_to_int = dict(special_tokens)

  # Ordered (name, builder) pairs, checked with == in the same order as an
  # if/elif chain would.
  builders = (
      (DATASET_ADDITION_ALIGNED,
       lambda: create_addition_dataset(200000, 512, vocab, vocab_to_int,
                                       reversedigits=False, leftpadding=12,
                                       addAlignmentTokens=False)),
      (DATASET_ADDITION_NEGATIVES,
       lambda: create_addition_dataset(200000, 512, vocab, vocab_to_int,
                                       reversedigits=False, leftpadding=12,
                                       addAlignmentTokens=False,
                                       negativeProbability=0.25)),
      (DATASET_SCAN_LENGTH,
       lambda: create_scan_length_dataset(vocab, vocab_to_int,
                                          map_output_to_input=False)),
      (DATASET_SCAN_ADD_JUMP,
       lambda: create_scan_add_jump_dataset(vocab, vocab_to_int,
                                            map_output_to_input=False)),
      (DATASET_PCFG_PRODUCTIVITY,
       lambda: create_pcfg_dataset("productivity", vocab, vocab_to_int)),
      (DATASET_PCFG_SYSTEMATICITY,
       lambda: create_pcfg_dataset("systematicity", vocab, vocab_to_int)),
      (DATASET_COGS_FULL,
       lambda: create_cogs_dataset(vocab, vocab_to_int, max_len=512)),
      (DATASET_CFQ_MCD1,
       lambda: create_cfq_mcd1_dataset(vocab, vocab_to_int)),
      (DATASET_CFQ_MCD1_INTERMEDIATE,
       lambda: create_cfq_mcd1_dataset(vocab, vocab_to_int,
                                       simplify_cartesians=True)),
      (DATASET_DUPLICATION,
       lambda: create_duplicating_dataset(200000, 1024, vocab, vocab_to_int)),
      (DATASET_REVERSING,
       lambda: create_reversing_dataset(200000, 1024, vocab, vocab_to_int)),
      (DATASET_CARTESIAN,
       lambda: create_cartesian_dataset(200000, 1024, vocab, vocab_to_int)),
      (DATASET_INTERSECTION_BOOLEAN,
       lambda: create_intersection_dataset(200000, 1024, vocab,
                                           vocab_to_int)),
  )
  for name, build in builders:
    if dataset == name:
      return build()
  raise ValueError(f"Undefined dataset type: {dataset}")


def run_evaluation(dataset,
                   num_layers, d_model, dff, num_heads, position_encoding,
                   decoder_type, share_layer_weights,
                   num_epochs,
                   realformer=NO_REALFORMER,
                   num_repetitions=1,
                   save_results_every_n_epochs=None,
                   batch_size=64):
  """Trains and evaluates a Transformer on `dataset`, possibly repeatedly.

  Args:
    dataset: a DATASET_* constant accepted by `setup_dataset`.
    num_layers, d_model, dff, num_heads: Transformer size hyper-parameters.
    position_encoding: one of RELATIVE16, RELATIVE16_BIASONLY,
      RELATIVE16_BIAS, RELATIVE16_D2E, RELATIVE16_D2E_BIASONLY,
      RELATIVE16_D2E_BIAS or ABSOLUTE_SINUSOIDAL.
    decoder_type: passed to Transformer as `copy_decoder`.
    share_layer_weights: passed to Transformer as `shared_layers`.
    num_epochs: training epochs per repetition.
    realformer: passed to Transformer as `realformer_connections`.
    num_repetitions: number of independent training runs averaged together.
    save_results_every_n_epochs: run the detailed evaluation on every
      validation set every this many epochs (defaults to `num_epochs`).
    batch_size: training batch size (also forwarded to
      prepare_tf_dataset_tensors).

  Returns:
    (averages, transformer, dataset_train, dataset_val_list). Each row of
    `averages` is [epoch, acc_token_0, acc_seq_0, acc_token_1, acc_seq_1, ...]
    averaged over repetitions; `transformer` is the last repetition's model.

  Raises:
    ValueError: on an unknown `position_encoding` or `num_repetitions` < 1.
  """
  if num_repetitions < 1:
    raise ValueError(
        f"num_repetitions must be at least 1, got {num_repetitions}")
  if save_results_every_n_epochs is None:
    save_results_every_n_epochs = num_epochs

  # Load and setup the corresponding dataset:
  (vocab, vocab_to_int, input_tensor_train, target_tensor_train,
   input_tensor_val_list, target_tensor_val_list) = setup_dataset(dataset)
  (dataset_train, dataset_val_list) = prepare_tf_dataset_tensors(
      vocab, vocab_to_int, input_tensor_train, target_tensor_train,
      input_tensor_val_list, target_tensor_val_list, batch_size)
  # Presumably all examples are padded to one fixed length (pad_sequences).
  max_len_inp = len(input_tensor_train[0])
  max_len_targ = len(target_tensor_train[0])

  # Translate the composite position-encoding option into
  # (layer-level encoding, relative radius, use decoder-to-encoder ids).
  if position_encoding == RELATIVE16:
    position_encoding = RELATIVE
    relative_radius = 16
    de2enc_ids = False
  elif position_encoding == RELATIVE16_BIASONLY:
    position_encoding = RELATIVE_BIASONLY
    relative_radius = 16
    de2enc_ids = False
  elif position_encoding == RELATIVE16_BIAS:
    position_encoding = RELATIVE_BIAS
    relative_radius = 16
    de2enc_ids = False
  elif position_encoding == RELATIVE16_D2E:
    position_encoding = RELATIVE
    relative_radius = 16
    de2enc_ids = True
  elif position_encoding == RELATIVE16_D2E_BIASONLY:
    position_encoding = RELATIVE_BIASONLY
    relative_radius = 16
    de2enc_ids = True
  elif position_encoding == RELATIVE16_D2E_BIAS:
    position_encoding = RELATIVE_BIAS
    relative_radius = 16
    de2enc_ids = True
  elif position_encoding == ABSOLUTE_SINUSOIDAL:
    position_encoding = ABSOLUTE_SINUSOIDAL
    relative_radius = 0
    de2enc_ids = False
  else:
    raise ValueError(f"Undefined position embeddings type: {position_encoding}")

  dropout_rate = 0.1

  repetition_metrics_l = []
  for repetition in range(num_repetitions):
    print(f"Starting repetition {repetition} ...\n")

    learning_rate = CustomSchedule(d_model)
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                         epsilon=1e-9)
    # Transformer.call already applies softmax, hence from_logits=False.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False, reduction='none')
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')
    eval_loss = tf.keras.metrics.Mean(name='eval_loss')
    eval_accuracy = tf.keras.metrics.Mean(name='eval_accuracy')

    transformer = Transformer(num_layers, d_model, num_heads, dff,
                              len(vocab),
                              pe_input=max_len_inp,
                              pe_target=max_len_targ,
                              rate=dropout_rate,
                              relative_radius=relative_radius,
                              position_encodings=position_encoding,
                              copy_decoder=decoder_type,
                              shared_layers=share_layer_weights,
                              realformer_connections=realformer)

    def loss_function(real, pred):
      """Mean cross-entropy over non-padding target positions."""
      mask = tf.math.logical_not(tf.math.equal(real, PAD_TOKEN_IDX))
      loss_ = loss_object(real, pred)
      mask = tf.cast(mask, dtype=loss_.dtype)
      loss_ *= mask
      return tf.reduce_sum(loss_)/tf.reduce_sum(mask)

    def accuracy_function(real, pred):
      """Token accuracy over non-padding target positions."""
      mask = tf.math.logical_not(tf.math.equal(real, PAD_TOKEN_IDX))
      accuracies = tf.keras.metrics.sparse_categorical_accuracy(real, pred)
      mask = tf.cast(mask, dtype=accuracies.dtype)
      accuracies *= mask
      return tf.reduce_sum(accuracies)/tf.reduce_sum(mask)

    def forward(inp, targ, training):
      """Teacher-forced forward pass; returns (predictions, targ_real)."""
      # The decoder reads targ[:, :-1] and must predict targ[:, 1:].
      targ_inp = targ[:, :-1]
      targ_real = targ[:, 1:]
      enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
          inp, targ_inp)
      enc_relative_ids, dec_relative_ids, dec2enc_relative_ids = (
          create_relative_ids(max_len_inp, max_len_targ, relative_radius,
                              de2enc_ids))
      predictions, _, _, _ = transformer(inp, targ_inp,
                                         training,
                                         enc_padding_mask,
                                         combined_mask,
                                         dec_padding_mask,
                                         enc_relative_ids,
                                         dec_relative_ids,
                                         dec2enc_relative_ids)
      return predictions, targ_real

    train_step_signature = [
        tf.TensorSpec(shape=(None, None), dtype=tf.int32),
        tf.TensorSpec(shape=(None, None), dtype=tf.int32)
    ]

    @tf.function(input_signature=train_step_signature)
    def train_step(inp, targ):
      with tf.GradientTape() as tape:
        predictions, targ_real = forward(inp, targ, True)
        loss = loss_function(targ_real, predictions)
        accuracy = accuracy_function(targ_real, predictions)

      gradients = tape.gradient(loss, transformer.trainable_variables)
      optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
      train_loss(loss)
      train_accuracy(accuracy)

    @tf.function(input_signature=train_step_signature)
    def eval_step(inp, targ):
      # No gradients are needed for evaluation, so no GradientTape.
      predictions, targ_real = forward(inp, targ, False)
      eval_loss(loss_function(targ_real, predictions))
      eval_accuracy(accuracy_function(targ_real, predictions))

    # This method is slow, so, we just want to call it once per epoch, to
    # visualize how the model is doing.
    def eval_step_detailed(inp, targ, max_to_show, vocab):
      """Returns (token accuracy, sequence accuracy) for one batch.

      Predictions are truncated after the first END/PAD token. Token
      accuracy counts every position, including padding. Up to `max_to_show`
      mispredicted examples are printed.
      """
      predictions, targ_real = forward(inp, targ, False)

      predicted_targ = tf.argmax(predictions, 2).numpy()
      inp_numpy = inp.numpy()
      ground_truth = targ_real.numpy()
      # The final batch of a dataset may hold fewer than batch_size rows.
      num_examples = predicted_targ.shape[0]

      accuracy_seq = 0
      accuracy_tok = 0
      total_tok = 0
      shown = 0
      for i in range(num_examples):
        # Once END or PAD is predicted, blank out everything after it.
        for j in range(ground_truth.shape[1] - 1):
          if predicted_targ[i][j] in (END_TOKEN_IDX, PAD_TOKEN_IDX):
            predicted_targ[i][j+1] = PAD_TOKEN_IDX
        accuracy_tok += int(np.sum(predicted_targ[i] == ground_truth[i]))
        total_tok += len(predicted_targ[i])
        if (predicted_targ[i] == ground_truth[i]).all():
          accuracy_seq += 1
        elif shown < max_to_show:
          print("input:     " + decode(inp_numpy[i], vocab))
          print("target:    " + decode(ground_truth[i], vocab))
          print("predicted: " + decode(predicted_targ[i], vocab))
          shown += 1
      return accuracy_tok / float(total_tok), accuracy_seq / float(num_examples)

    def evaluate_in_set(dataset_val, vocab):
      """Averages per-batch token/sequence accuracies over `dataset_val`."""
      steps_per_epoch_val = dataset_val.cardinality()
      n_test_batches = 0
      test_accuracy_token = 0
      test_accuracy_seq = 0
      for (batch, (inp, targ)) in enumerate(dataset_val.take(
          steps_per_epoch_val)):
        # Only print mispredicted examples from the first batch.
        max_to_show = 16 if batch == 0 else 0
        batch_accuracy_token, batch_accuracy_seq = eval_step_detailed(
            inp, targ, max_to_show, vocab)
        test_accuracy_token += batch_accuracy_token
        test_accuracy_seq += batch_accuracy_seq
        n_test_batches += 1

      print (f'Eval accuracy (token level): {test_accuracy_token/n_test_batches}')
      print (f'Eval accuracy (sequence level): {test_accuracy_seq/n_test_batches}\n')
      return (test_accuracy_token/n_test_batches, test_accuracy_seq/n_test_batches)

    def evaluate_all_val_sets(epoch_number):
      """Returns [epoch_number, acc_token_0, acc_seq_0, acc_token_1, ...]."""
      epoch_metrics = [epoch_number]
      for i, dataset_val in enumerate(dataset_val_list):
        print(f"------- Evaluation in dataset_val {i} -------")
        (acc_token, acc_seq) = evaluate_in_set(dataset_val, vocab)
        epoch_metrics.extend([acc_token, acc_seq])
      return epoch_metrics

    steps_per_epoch = len(input_tensor_train)//batch_size
    # The first validation set drives the per-epoch eval loss/accuracy report.
    steps_per_epoch_val = dataset_val_list[0].cardinality()
    repetition_metrics = []
    for epoch in range(num_epochs):
      start = time.time()

      train_loss.reset_states()
      train_accuracy.reset_states()
      eval_loss.reset_states()
      eval_accuracy.reset_states()

      for (batch, (inp, targ)) in enumerate(dataset_train.take(steps_per_epoch)):
        train_step(inp, targ)
        if batch % 100 == 0:
          print ('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
              epoch + 1, batch, train_loss.result(), train_accuracy.result()))
      print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1,
                                                    train_loss.result(),
                                                    train_accuracy.result()))
      for (inp, targ) in dataset_val_list[0].take(steps_per_epoch_val):
        eval_step(inp, targ)
      print ('Epoch {} Eval Loss {:.4f} Eval Accuracy {:.4f}'.format(epoch + 1,
                                                    eval_loss.result(),
                                                    eval_accuracy.result()))

      print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

      if ((epoch+1) % save_results_every_n_epochs) == 0:
        repetition_metrics.append(evaluate_all_val_sets(epoch + 1))

    # Make sure we save the last epoch results if it is not a
    # multiple of save_results_every_n_epochs:
    if (num_epochs % save_results_every_n_epochs) != 0:
      repetition_metrics.append(evaluate_all_val_sets(num_epochs))

    repetition_metrics_l.append(repetition_metrics)

  print("Raw repetition metrics:")
  for repetition_metrics in repetition_metrics_l:
    for epoch_metrics in repetition_metrics:
      print(epoch_metrics)

  # Average over repetitions (column 0 is the epoch number) without mutating
  # the raw per-repetition rows.
  averages = [list(epoch_metrics) for epoch_metrics in repetition_metrics_l[0]]
  for repetition_metrics in repetition_metrics_l[1:]:
    for j, epoch_metrics in enumerate(repetition_metrics):
      for k in range(1, len(epoch_metrics)):
        averages[j][k] += epoch_metrics[k]
  for average_row in averages:
    for k in range(1, len(average_row)):
      average_row[k] /= len(repetition_metrics_l)

  print("Average repetition metrics:")
  for average_metrics in averages:
    print('\t'.join(map(str,average_metrics)))

  print(f"Transformer params: {transformer.count_params()}")
  transformer.detailed_param_count()

  return averages, transformer, dataset_train, dataset_val_list


In [None]:
#@title Individual Experiments

# Run a single experiment. run_evaluation's parameters are:
# - dataset
# - num_layers, d_model, dff, num_heads
# - position_encoding, decoder_type, share_layer_weights
# - num_epochs
# - optional keyword parameters:
#     realformer=NO_REALFORMER, num_repetitions=1,
#     save_results_every_n_epochs=None, batch_size=64

# Example: 4 layers, d_model=128, dff=512, 8 heads, shared layer weights,
# trained for 4 epochs on the duplication task and evaluated after each epoch.
run_evaluation(DATASET_DUPLICATION,
               num_layers=4, d_model=128, dff=512, num_heads=8,
               position_encoding=RELATIVE16_D2E_BIAS,
               decoder_type=COPY_DECODER,
               share_layer_weights=True,
               num_epochs=4,
               num_repetitions=1, save_results_every_n_epochs=1)


In [None]:
#@title Systematic Experiments

# Trains every configuration in `models` on every dataset in `datasets`,
# `num_repetitions` times, appending one TSV row per evaluation point to
# RESULTS_FILE: model name, dataset, parameter count, then the metrics row
# returned by run_evaluation.

# Set the path where you want the results to be stored:
RESULTS_FILE = "/path-to/results.tsv"

# (dataset, number of training epochs)
datasets = [
            (DATASET_ADDITION_ALIGNED, 2),
            (DATASET_ADDITION_NEGATIVES, 10),
            (DATASET_REVERSING, 2),
            (DATASET_DUPLICATION, 4),
            (DATASET_CARTESIAN, 4),
            (DATASET_INTERSECTION_BOOLEAN, 8),
            (DATASET_SCAN_LENGTH, 24),
            (DATASET_SCAN_ADD_JUMP, 24),
            (DATASET_PCFG_PRODUCTIVITY, 20),
            (DATASET_PCFG_SYSTEMATICITY, 20),
            (DATASET_COGS_FULL, 16),
            (DATASET_CFQ_MCD1, 16),
            (DATASET_CFQ_MCD1_INTERMEDIATE, 16)
            ]

num_repetitions = 5

# (name, num_layers, d_model, dff, num_heads, position_encoding, decoder_type,
#  share_layer_weights, realformer, ff_type, decoding_strategy)
# NOTE: ff_type and decoding_strategy are unpacked below but not currently
# passed to run_evaluation.
models = [
          ("abs", 2, 64, 256, 4, ABSOLUTE_SINUSOIDAL, STANDARD_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("rel-e", 2, 64, 256, 4, RELATIVE16, STANDARD_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("rel-b", 2, 64, 256, 4, RELATIVE16_BIASONLY, STANDARD_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("rel-eb", 2, 64, 256, 4, RELATIVE16_BIAS, STANDARD_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("rel2-e", 2, 64, 256, 4, RELATIVE16_D2E, STANDARD_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("rel2-b", 2, 64, 256, 4, RELATIVE16_D2E_BIASONLY, STANDARD_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("rel2-eb", 2, 64, 256, 4, RELATIVE16_D2E_BIAS, STANDARD_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("abs-c", 2, 64, 256, 4, ABSOLUTE_SINUSOIDAL, COPY_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("rel-eb-c", 2, 64, 256, 4, RELATIVE16_BIAS, COPY_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("rel2-eb-c", 2, 64, 256, 4, RELATIVE16_D2E_BIAS, COPY_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("small-2", 2, 64, 256, 4, RELATIVE16_D2E_BIAS, COPY_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("small-4", 4, 64, 256, 4, RELATIVE16_D2E_BIAS, COPY_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("small-6", 6, 64, 256, 4, RELATIVE16_D2E_BIAS, COPY_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("large-2", 2, 128, 512, 8, RELATIVE16_D2E_BIAS, COPY_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("large-4", 4, 128, 512, 8, RELATIVE16_D2E_BIAS, COPY_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("large-6", 6, 128, 512, 8, RELATIVE16_D2E_BIAS, COPY_DECODER, False, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("small-2s", 2, 64, 256, 4, RELATIVE16_D2E_BIAS, COPY_DECODER, True, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("small-4s", 4, 64, 256, 4, RELATIVE16_D2E_BIAS, COPY_DECODER, True, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("small-6s", 6, 64, 256, 4, RELATIVE16_D2E_BIAS, COPY_DECODER, True, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("large-2s", 2, 128, 512, 8, RELATIVE16_D2E_BIAS, COPY_DECODER, True, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("large-4s", 4, 128, 512, 8, RELATIVE16_D2E_BIAS, COPY_DECODER, True, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ("large-6s", 6, 128, 512, 8, RELATIVE16_D2E_BIAS, COPY_DECODER, True, NO_REALFORMER, STANDARD_FF, GREEDY_DECODING),
          ]

for (modelname, num_layers, d_model, dff, num_heads, position_encoding, decoder_type, share_layer_weights, realformer, ff_type, decoding_strategy) in models:
  for (dataset, epochs) in datasets:
    for _ in range(num_repetitions):
      # Each call is a single repetition; its rows are appended separately.
      results, transformer, _, _ = run_evaluation(
          dataset,
          num_layers, d_model, dff, num_heads, position_encoding, decoder_type, share_layer_weights, epochs, realformer=realformer,
          save_results_every_n_epochs=2)
      params = transformer.count_params()
      # tf.io.gfile handles both local paths and remote filesystems; `gfile`
      # on its own is never imported in this notebook.
      with tf.io.gfile.GFile(RESULTS_FILE, "a") as wf:
        for result in results:
          report = [modelname, dataset, params] + result
          report_str = "\t".join(str(x) for x in report)
          print(report_str)
          wf.write(report_str + "\n")

