In [None]:
#@title Imports

# Licensed under the Apache License, Version 2.0

import json
import io
import random
import time
import numpy as np
import statistics
import tensorflow.compat.v1 as tf

# `tf` is the TF1 compatibility API (tensorflow.compat.v1), so eager execution
# is switched on explicitly; the assert fails fast if it did not take effect.
tf.enable_eager_execution()
assert tf.executing_eagerly()

# Log the TensorFlow version this notebook is running against.
print(tf.version.VERSION)


In [None]:
#@title "COGS" tagging dataset generation

# Directory containing the sequence-tagged JSON-lines files. This is a
# placeholder to be replaced. NOTE: file names are appended below by plain
# string concatenation, so the value must end with a path separator.
DATA_DIR = 'DATA_DIR'  # Directory containing sequence tagged json files.
# The order (train, test, dev, gen) matters: read_cogs_datasets() indexes this
# list positionally.
DATASET_FILES = [DATA_DIR + ifile for ifile in [
        "train_seqtag.jsonl", "test_seqtag.jsonl", "dev_seqtag.jsonl","gen_seqtag.jsonl"]
]

# Padding entries; each is placed at index 0 of its vocabulary when the
# vocabularies are built in read_cogs_datasets().
PAD_TOKEN = "[PAD]"  # Used by the tokens/role/category/noun_type/verb_name vocabs.
PAD_PARENT = 99999  # Used by the parent vocab (-1 already means "no parent").

# Possible parent encodings.
# The parent encoding is already needed at dataset generation time, because it
# determines both the parent vocabulary and how parent labels are encoded:
#  - absolute: the label is the parent position itself (-1 for no parent).
#  - relative: the label is the offset `parent - position`; no parent is
#    encoded as offset 0.
#  - attention: the label is the parent position; no parent is encoded as the
#    token's own position.
PARENT_ABSOLUTE = "parent_absolute_encoding"
PARENT_RELATIVE = "parent_relative_encoding"
PARENT_ATTENTION = "parent_attention_encoding"


def decode(seq, vocab):
  """Render a sequence of vocabulary ids as a printable string.

  Every id is looked up in `vocab` and followed by ", " (the last one too).
  Decoding stops at the first id equal to 0, which is the padding id.

  Args:
    seq: iterable of integer ids.
    vocab: sequence mapping an id to its value.

  Returns:
    The concatenated string, or "" when `seq` is empty or starts with 0.
  """
  rendered = []
  for token_id in seq:
    if token_id == 0:
      break
    rendered.append(str(vocab[token_id]) + ", ")
  return "".join(rendered)


def read_cogs_datafile(filename):
  """Load a JSON-lines file into a list of decoded examples.

  Args:
    filename: path of the file to read; opened through `tf.io.gfile.GFile`.

  Returns:
    A list holding the `json.loads` result of every line, in file order.
  """
  print(filename)
  with tf.io.gfile.GFile(filename, "r") as reader:
    examples = [json.loads(row) for row in reader]
  print(f"loaded {filename}: {len(examples)} instances.")
  return examples


def split_set_by_distribution(dataset):
  """Group examples by their "distribution" (generalization type) field.

  Only the COGS generalization dataset is annotated with the generalization
  type. The number of examples of every group is printed.

  Args:
    dataset: iterable of examples, each with a "distribution" entry.

  Returns:
    A tuple (split, distributions): `split` maps a distribution name to the
    list of its examples (in input order) and `distributions` lists the names
    in order of first appearance.
  """
  split = {}
  for example in dataset:
    split.setdefault(example["distribution"], []).append(example)
  # Dicts preserve insertion order, i.e. the order of first appearance.
  distributions = list(split)
  for name, members in split.items():
    print(f"{name}: {len(members)}")
  return split, distributions


def create_dataset_feature_tensor(dataset, feature, vocab, max_len, parent_encoding=None):
  """Convert one feature of every example into a padded matrix of vocab ids.

  Every value of `example[feature]` is replaced by its position in `vocab`
  and the resulting rows are post-padded to `max_len` with `pad_sequences`.

  The "parent" feature needs care, because its ids are later compared against
  the model output (for PARENT_ATTENTION, against the attention matrix):
    - PARENT_ABSOLUTE: the raw parent value is looked up; the vocab must start
      with [PAD_PARENT, -1, 0].
    - PARENT_RELATIVE: the offset `parent - position` is looked up; a token
      without parent (-1) is encoded as offset 0, i.e. as pointing to itself.
    - PARENT_ATTENTION: the parent position is looked up; a token without
      parent (-1) is encoded as its own position. The vocab must start with
      [PAD_PARENT, 0, 1].

  Args:
    dataset: iterable of examples, each indexable by `feature`.
    feature: name of the feature to read.
    vocab: list of values; the index of a value is its id.
    max_len: length every row is padded to.
    parent_encoding: one of the PARENT_* constants; only used for "parent".

  Returns:
    The result of `pad_sequences` over the encoded rows.

  Raises:
    ValueError: if `feature` is "parent" and `parent_encoding` is unknown, or
      if a value is not present in `vocab`.
  """
  rows = []
  for example in dataset:
    if feature != "parent":
      row = [vocab.index(value) for value in example[feature]]
    elif parent_encoding == PARENT_ABSOLUTE:
      assert vocab[0] == PAD_PARENT  # padding
      assert vocab[1] == -1  # -1 means no parent
      assert vocab[2] == 0   # parent is the 1st token
      row = [vocab.index(parent) for parent in example[feature]]
    elif parent_encoding == PARENT_RELATIVE:
      assert vocab[0] == PAD_PARENT
      # No parent (-1) is encoded as offset 0, i.e. "self".
      row = [vocab.index(parent - position if parent != -1 else 0)
             for position, parent in enumerate(example[feature])]
    elif parent_encoding == PARENT_ATTENTION:
      # The parent vocab is expected to be [PAD_PARENT, 0, 1, 2, ...].
      assert vocab[0] == PAD_PARENT
      assert vocab[1] == 0
      assert vocab[2] == 1
      # No parent (-1) is encoded as the token's own position.
      row = [vocab.index(parent if parent != -1 else position)
             for position, parent in enumerate(example[feature])]
    else:
      raise ValueError(f"Undefined parent_encoding: {parent_encoding}")
    rows.append(row)
  return tf.keras.preprocessing.sequence.pad_sequences(
      rows, padding="post", maxlen=max_len)


def create_parent_ids_tensor(dataset, max_len):
  """Build the input padding mask of every example.

  Despite its name, the result is a mask with `max_len` entries per example:
  0 at positions that hold an input token and 1 at padding positions, i.e.
  from `len(example["tokens"])` onwards.

  Args:
    dataset: iterable of examples with a "tokens" entry.
    max_len: length of every mask row.

  Returns:
    The result of `pad_sequences` over the mask rows.
  """
  masks = []
  for example in dataset:
    n_tokens = min(len(example["tokens"]), max_len)
    masks.append([0] * n_tokens + [1] * (max_len - n_tokens))
  return tf.keras.preprocessing.sequence.pad_sequences(
      masks, padding="post", maxlen=max_len)


def create_dataset_tensors(dataset,
                           vocabs,
                           max_len,
                           batch_size,
                           show_example=False,
                           parent_encoding=None):
  """Turn a list of examples into a shuffled, batched `tf.data.Dataset`.

  Args:
    dataset: list of examples holding the "tokens", "parent", "role",
      "category", "noun_type" and "verb_name" features.
    vocabs: vocabularies indexed as [tokens, parent, role, category,
      noun_type, verb_name].
    max_len: length every feature row is padded to.
    batch_size: number of examples per batch; an incomplete final batch is
      dropped.
    show_example: if True, print the decoded first example.
    parent_encoding: PARENT_* constant used to encode the "parent" feature.

  Returns:
    A dataset of 7-tuples (tokens, padding mask, parent, role, category,
    noun_type, verb_name), shuffled with a buffer of `len(dataset)` elements.
  """
  def encode(feature, vocab_position, encoding=None):
    # Encode one feature of all examples with the vocab at `vocab_position`.
    return create_dataset_feature_tensor(
        dataset, feature, vocabs[vocab_position], max_len, encoding)

  tokens_tensor = encode("tokens", 0)
  parent_ids_tensor = create_parent_ids_tensor(dataset, max_len)
  parent_tensor = encode("parent", 1, parent_encoding)
  role_tensor = encode("role", 2)
  category_tensor = encode("category", 3)
  noun_type_tensor = encode("noun_type", 4)
  verb_tensor = encode("verb_name", 5)

  features = (tokens_tensor, parent_ids_tensor, parent_tensor, role_tensor,
              category_tensor, noun_type_tensor, verb_tensor)
  batches = tf.data.Dataset.from_tensor_slices(features)
  batches = batches.shuffle(len(dataset))
  batches = batches.batch(batch_size, drop_remainder=True)

  if show_example:
    print("- Sample Example ----------------")
    print(f"tokens: {decode(tokens_tensor[0], vocabs[0])}")
    # The mask has no vocabulary, so it is printed raw.
    print(f"parent_ids: {parent_ids_tensor[0]}")
    decoded_features = (
        ("parent", parent_tensor, 1),
        ("role", role_tensor, 2),
        ("category", category_tensor, 3),
        ("noun_type", noun_type_tensor, 4),
        ("verb_name", verb_tensor, 5),
    )
    for label, tensor, vocab_position in decoded_features:
      print(f"{label}: {decode(tensor[0], vocabs[vocab_position])}")
    print("---------------------------------")

  return batches


def read_cogs_datasets(dataset_files, parent_encoding, batch_size):
  """Load the four COGS tagging files and build vocabularies and datasets.

  Args:
    dataset_files: exactly four paths, in this order: train, test, dev, gen.
    parent_encoding: one of PARENT_ABSOLUTE, PARENT_RELATIVE or
      PARENT_ATTENTION; selects the parent vocabulary and label encoding.
    batch_size: batch size of every returned dataset.

  Returns:
    A tuple (train_dataset, test_sets, test_sets_names, vocabs, max_len):
      - train_dataset: batched training dataset.
      - test_sets: batched datasets for test, dev, every generalization
        distribution of gen (in order of first appearance) and the full gen.
      - test_sets_names: names aligned with `test_sets`.
      - vocabs: [tokens, parent, role, category, noun_type, verb_name].
      - max_len: longest token sequence over all four files, plus one.

  Raises:
    AssertionError: if `dataset_files` does not hold exactly four paths.
    ValueError: if `parent_encoding` is not one of the known encodings.
  """
  # The argument is wrapped in a 1-tuple so that "%s" formatting also works
  # when `dataset_files` is itself a tuple (a bare tuple would be unpacked as
  # the format arguments and raise TypeError instead of AssertionError).
  assert len(dataset_files) == 4, (
      "expected list of dataset paths in this order: train, test, dev, gen; "
      "got %s"
  ) % (dataset_files,)

  train_set = read_cogs_datafile(dataset_files[0])
  test_set = read_cogs_datafile(dataset_files[1])
  dev_set = read_cogs_datafile(dataset_files[2])
  gen_set = read_cogs_datafile(dataset_files[3])

  # Create vocabs, and calculate dataset stats.
  # The entry with index 0 has to be padding, because the loss relies on it.
  # -1 is already used in the tagging dataset to denote "no parent", so the
  # parent vocab uses the dedicated PAD_PARENT value as its padding entry.
  feature_vocabs = {
      "tokens": [PAD_TOKEN],
      "parent": [PAD_PARENT],
      "role": [PAD_TOKEN],
      "category": [PAD_TOKEN],
      "noun_type": [PAD_TOKEN],
      "verb_name": [PAD_TOKEN],
  }

  max_len = 0
  for example in train_set + test_set + dev_set + gen_set:
    # Values are appended in order of first appearance.
    for feature, vocab in feature_vocabs.items():
      for value in example[feature]:
        if value not in vocab:
          vocab.append(value)
    max_len = max(max_len, len(example["tokens"]))

  tokens_vocab = feature_vocabs["tokens"]
  # Parent values actually seen in the data; only used for the report below.
  parent_vocab_raw = feature_vocabs["parent"]
  role_vocab = feature_vocabs["role"]
  category_vocab = feature_vocabs["category"]
  noun_type_vocab = feature_vocabs["noun_type"]
  verb_name_vocab = feature_vocabs["verb_name"]

  # The parent vocab is not collected from the data but derived from the
  # longest sequence, so that every position (or offset) has an id.
  if parent_encoding == PARENT_ABSOLUTE:
    parent_vocab = [PAD_PARENT, -1] + list(range(max_len))
  elif parent_encoding == PARENT_RELATIVE:
    parent_vocab = [PAD_PARENT] + list(range(-max_len+1, max_len))
  elif parent_encoding == PARENT_ATTENTION:
    parent_vocab = [PAD_PARENT] + list(range(max_len))
  else:
    raise ValueError(f"Undefined parent_encoding: {parent_encoding}")

  max_len += 1  # guarantee at least one padding token at the end
                # for "no parent"
                # Padding token is also used to stop decoding in decode().

  gen_distribution_split, gen_distributions = split_set_by_distribution(gen_set)

  print(f"n_distributions: {len(gen_distribution_split)}")
  print(f"max_len: {max_len}")
  print(f"tokens_vocab: {len(tokens_vocab)}  -->> {tokens_vocab}")
  print(f"parent_vocab: {len(parent_vocab)}  -->> {parent_vocab}")
  parent_vocab_missing = sorted(set(parent_vocab) - set(parent_vocab_raw))
  print(f"parent indices missing from the data: {len(parent_vocab_missing)}  -->> {parent_vocab_missing}")
  print(f"role_vocab: {len(role_vocab)}  -->> {role_vocab}")
  print(f"category_vocab: {len(category_vocab)}  -->> {category_vocab}")
  print(f"noun_type_vocab: {len(noun_type_vocab)}  -->> {noun_type_vocab}")
  print(f"verb_name_vocab: {len(verb_name_vocab)}  -->> {verb_name_vocab}")

  vocabs = [tokens_vocab, parent_vocab,
            role_vocab, category_vocab,
            noun_type_vocab, verb_name_vocab]

  train_tf_dataset = create_dataset_tensors(
      train_set, vocabs, max_len, batch_size, show_example=True, parent_encoding=parent_encoding)
  test_tf_dataset = create_dataset_tensors(test_set, vocabs, max_len, batch_size, parent_encoding=parent_encoding)
  dev_tf_dataset = create_dataset_tensors(dev_set, vocabs, max_len, batch_size, parent_encoding=parent_encoding)
  gen_tf_dataset = create_dataset_tensors(gen_set, vocabs, max_len, batch_size, parent_encoding=parent_encoding)

  # One extra evaluation set per generalization type of the gen file.
  gen_split_tf_datasets = [
      create_dataset_tensors(gen_distribution_split[distribution],
                             vocabs,
                             max_len,
                             batch_size,
                             parent_encoding=parent_encoding)
      for distribution in gen_distributions
  ]

  test_sets = [test_tf_dataset, dev_tf_dataset
              ] + gen_split_tf_datasets + [gen_tf_dataset]
  test_sets_names = ["test", "dev"] + gen_distributions + ["gen"]

  return train_tf_dataset, test_sets, test_sets_names, vocabs, max_len


def set_up_cogs(parent_encoding, batch_size):
  """Load the COGS tagging datasets listed in the module-level DATASET_FILES.

  Args:
    parent_encoding: one of the PARENT_* constants.
    batch_size: batch size of the created datasets.

  Returns:
    Whatever `read_cogs_datasets` returns for DATASET_FILES.
  """
  cogs_data = read_cogs_datasets(
      dataset_files=DATASET_FILES,
      parent_encoding=parent_encoding,
      batch_size=batch_size)
  return cogs_data


In [None]:
#@title This code block defines Encoder layer, following this Transformer tutorial (adding relative attention, a copy decoder, and a few other improvements): https://www.tensorflow.org/tutorials/text/transformer

# Attention types (position encoding schemes).
# RELATIVE and RELATIVE_BIAS are the values the model classes compare against:
# both create relative-position embeddings, and RELATIVE_BIAS additionally
# creates a bias embedding per relative id.
RELATIVE = "relative"
RELATIVE_BIAS = "relative_bias"
# Experiment-level names that run_evaluation() resolves to one of the two
# types above plus a clipping radius: RELATIVE8 -> RELATIVE with radius 8,
# RELATIVE16 -> RELATIVE with radius 16, RELATIVE16_BIAS -> RELATIVE_BIAS with
# radius 16.
RELATIVE8 = "relative8"
RELATIVE16 = "relative16"
RELATIVE16_BIAS = "relative16_bias"
# Sinusoidal absolute position encodings added to the token embeddings in the
# encoder; relative ids are ignored in this mode.
ABSOLUTE_SINUSOIDAL = "absolute_sinusoidal"


def get_angles(pos, i, d_model):
  """Compute the angles pos / 10000^(2*(i//2)/d_model) of the sinusoidal encoding.

  Args:
    pos: position index (scalar or array).
    i: embedding dimension index (scalar or array broadcastable with `pos`).
    d_model: embedding size.

  Returns:
    `pos` multiplied by the per-dimension angle rate; dimensions 2k and 2k+1
    share the same rate.
  """
  exponent = (2 * (i // 2)) / np.float32(d_model)
  return pos * (1 / np.power(10000, exponent))


def positional_encoding(position, d_model):
  """Build the sinusoidal positional encoding table.

  Args:
    position: number of positions to encode.
    d_model: embedding size.

  Returns:
    A float32 tensor of shape (1, position, d_model) holding the sine of the
    angle at even dimensions (2i) and the cosine at odd dimensions (2i+1).
  """
  positions = np.arange(position)[:, np.newaxis]
  dimensions = np.arange(d_model)[np.newaxis, :]
  angles = get_angles(positions, dimensions, d_model)

  # sin for even dimension indices, cos for odd ones.
  table = np.where(dimensions % 2 == 0, np.sin(angles), np.cos(angles))

  # Add a leading axis so the table broadcasts over the batch dimension.
  return tf.cast(table[np.newaxis, ...], dtype=tf.float32)


def create_padding_mask(seq):
  """Mark the padding positions (id 0) of a batch of id sequences.

  Args:
    seq: ids of shape (batch_size, seq_len).

  Returns:
    A float32 mask of shape (batch_size, 1, 1, seq_len) with 1.0 at padding
    positions and 0.0 elsewhere; the two extra axes let it broadcast when it
    is added to the attention logits.
  """
  is_padding = tf.math.equal(seq, 0)
  mask = tf.cast(is_padding, tf.float32)
  return mask[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)


def create_look_ahead_mask(size):
  """Build a mask that hides future positions.

  Args:
    size: sequence length.

  Returns:
    A (size, size) tensor with 1 strictly above the diagonal (positions to
    hide) and 0 on and below it.
  """
  lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return 1 - lower_triangle  # (seq_len, seq_len)


def create_masks(inp):
  """Create the masks needed by the model for a batch of input ids.

  Only an encoder padding mask is required.

  Args:
    inp: batch of input token ids.

  Returns:
    The padding mask produced by `create_padding_mask(inp)`.
  """
  return create_padding_mask(inp)


def create_relative_ids(inp_len, relative_radius):
  """Build the matrix of clipped relative-position ids.

  Entry [i][j] is `i - j` clipped to [-relative_radius, relative_radius] and
  shifted by `relative_radius`, so that ids fall in
  [0, 2 * relative_radius].

  Args:
    inp_len: sequence length.
    relative_radius: maximum distance that gets its own id.

  Returns:
    An integer tensor of shape (inp_len, inp_len).
  """
  positions = np.arange(inp_len)
  offsets = positions[:, np.newaxis] - positions[np.newaxis, :]
  clipped = np.clip(offsets, -relative_radius, relative_radius)
  return tf.constant((relative_radius + clipped).astype(int))


def scaled_dot_product_relative_attention(q, k, v,
                                          mask=None,
                                          relative_ids=None,
                                          relative_embeddings=None,
                                          relative_biases=None):
  """Scaled dot-product attention with optional relative-position terms.

  - Without `relative_ids`, or with neither `relative_embeddings` nor
    `relative_biases`, this is regular scaled dot-product attention.
  - q, k, v must have matching leading dimensions.
  - k, v must have matching penultimate dimension: seq_len_k == seq_len_v.
  - The relative terms are added to the q.k logits before the division by
    sqrt(depth), so they are scaled as well.

  Args:
    q: query of shape (..., seq_len_q, depth); rank 4 or 3 (single head) when
      relative embeddings are used.
    k: key of shape (..., seq_len_k, depth).
    v: value of shape (..., seq_len_v, depth_v), or None to only compute the
      attention weights.
    mask: float tensor broadcastable to (..., seq_len_q, seq_len_k); positions
      set to 1 receive a large negative logit. Defaults to None.
    relative_ids: optional (seq_len_q, seq_len_k) relative-position ids.
    relative_embeddings: optional layer mapping relative ids to embeddings of
      size depth, which are matched against the queries.
    relative_biases: optional layer mapping relative ids to a tensor with a
      trailing axis of size 1, added as a bias to the logits.

  Returns:
    A tuple (output, attention_weights); `output` is None when `v` is None.

  Raises:
    ValueError: if relative embeddings are used and `q` is not rank 3 or 4.
  """
  logits = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  if relative_ids is not None:
    if relative_embeddings is not None:
      relative = relative_embeddings(relative_ids)
      if len(q.shape) == 4:
        equation = "bhqd,qkd->bhqk"  # multi-head queries
      elif len(q.shape) == 3:
        equation = "bqd,qkd->bqk"  # single-head queries
      else:
        raise ValueError(f"Query must have dimension 4 or 3 (only 1 head), but has {len(q.shape)}")
      logits = logits + tf.einsum(equation, q, relative)
    if relative_biases is not None:
      logits = logits + tf.squeeze(relative_biases(relative_ids), axis=-1)

  # Scale by the square root of the key depth.
  depth = tf.cast(tf.shape(k)[-1], tf.float32)
  logits = logits / tf.math.sqrt(depth)

  # Masked positions get a very negative logit, i.e. ~0 weight after softmax.
  if mask is not None:
    logits = logits + (mask * -1e9)

  # Normalize over the keys (last axis) so that the weights add up to 1.
  attention_weights = tf.nn.softmax(logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  if v is None:
    return None, attention_weights
  return tf.matmul(attention_weights, v), attention_weights  # (..., seq_len_q, depth_v)


class MultiHeadAttention(tf.keras.layers.Layer):
  """Multi-head attention with optional relative-position embeddings/biases.

  Relative-position embeddings are created for the RELATIVE and RELATIVE_BIAS
  position encodings; RELATIVE_BIAS additionally creates one scalar bias per
  relative id. With any other encoding, plain attention is computed.
  """

  def __init__(self, d_model, num_heads, relative_radius, position_encodings):
    """Create the projection layers and the optional relative-position tables.

    Args:
      d_model: model size; must be divisible by `num_heads`.
      num_heads: number of attention heads.
      relative_radius: clipping radius of the relative ids; the relative
        vocabulary holds `2 * relative_radius + 1` ids.
      position_encodings: one of the attention-type constants.
    """
    super().__init__()
    self.num_heads = num_heads
    self.d_model = d_model
    self.relative_vocab_size = relative_radius*2+1

    assert d_model % self.num_heads == 0

    # Size of each head.
    self.depth = d_model // self.num_heads

    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    self.dense = tf.keras.layers.Dense(d_model)

    uses_relative_embeddings = position_encodings in (RELATIVE, RELATIVE_BIAS)
    self.relative_embeddings = (
        tf.keras.layers.Embedding(self.relative_vocab_size, self.depth)
        if uses_relative_embeddings else None)
    self.relative_biases = (
        tf.keras.layers.Embedding(self.relative_vocab_size, 1)
        if position_encodings == RELATIVE_BIAS else None)

  def split_heads(self, x, batch_size):
    """Reshape (batch_size, seq_len, d_model) to (batch_size, num_heads, seq_len, depth)."""
    per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(per_head, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask=None, relative_ids=None):
    """Apply multi-head attention.

    Args:
      v: values, (batch_size, seq_len_v, d_model).
      k: keys, (batch_size, seq_len_k, d_model).
      q: queries, (batch_size, seq_len_q, d_model).
      mask: optional mask broadcastable to the attention logits.
      relative_ids: optional (seq_len_q, seq_len_k) relative-position ids.

    Returns:
      A tuple (output, attention_weights) with output of shape
      (batch_size, seq_len_q, d_model) and weights of shape
      (batch_size, num_heads, seq_len_q, seq_len_k).
    """
    batch_size = tf.shape(q)[0]

    # Project, then split into heads: (batch_size, num_heads, seq_len, depth).
    q = self.split_heads(self.wq(q), batch_size)
    k = self.split_heads(self.wk(k), batch_size)
    v = self.split_heads(self.wv(v), batch_size)

    # per_head_output.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    per_head_output, attention_weights = scaled_dot_product_relative_attention(
        q, k, v, mask, relative_ids, self.relative_embeddings, self.relative_biases)

    # Merge the heads back: (batch_size, seq_len_q, d_model).
    merged = tf.reshape(
        tf.transpose(per_head_output, perm=[0, 2, 1, 3]),
        (batch_size, -1, self.d_model))

    return self.dense(merged), attention_weights


def point_wise_feed_forward_network(d_model, dff):
  """Build the position-wise feed-forward block of a transformer layer.

  Args:
    d_model: size of the output (model) dimension.
    dff: size of the hidden ReLU layer.

  Returns:
    A Sequential model: Dense(dff, relu) followed by Dense(d_model).
  """
  hidden = tf.keras.layers.Dense(dff, activation="relu")  # (batch_size, seq_len, dff)
  projection = tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  return tf.keras.Sequential([hidden, projection])


class EncoderLayer(tf.keras.layers.Layer):
  """One transformer encoder layer: self-attention plus feed-forward network.

  Each sub-block is followed by dropout, a residual connection and layer
  normalization.
  """

  def __init__(self, d_model, num_heads, dff, position_encodings, rate=0.1, relative_radius=8):
    """Create the sub-layers.

    Args:
      d_model: model size.
      num_heads: number of attention heads.
      dff: hidden size of the feed-forward network.
      position_encodings: attention-type constant passed to the attention.
      rate: dropout rate.
      relative_radius: clipping radius of the relative-position ids.
    """
    super().__init__()

    self.mha = MultiHeadAttention(d_model, num_heads, relative_radius, position_encodings)
    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask, relative_ids):
    """Transform `x` of shape (batch_size, input_seq_len, d_model).

    Args:
      x: layer input.
      training: whether dropout is active.
      mask: attention mask.
      relative_ids: relative-position ids, or None.

    Returns:
      A tensor of the same shape as `x`.
    """
    # Self-attention sub-block; the attention weights are not needed here.
    attended, _ = self.mha(x, x, x, mask, relative_ids)
    attended = self.dropout1(attended, training=training)
    after_attention = self.layernorm1(x + attended)

    # Feed-forward sub-block.
    transformed = self.ffn(after_attention)
    transformed = self.dropout2(transformed, training=training)
    return self.layernorm2(after_attention + transformed)


class Encoder(tf.keras.layers.Layer):
  """Token embedding followed by a stack of encoder layers.

  With ABSOLUTE_SINUSOIDAL position encodings, a sinusoidal table is added to
  the embeddings and relative ids are ignored. With `shared_layers`, a single
  EncoderLayer instance is applied `num_layers` times.
  """

  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               maximum_position_encoding, rate=0.1, relative_radius=8,
               position_encodings=RELATIVE,
               shared_layers=False):
    """Create the embedding and the encoder layers.

    Args:
      num_layers: number of times an encoder layer is applied.
      d_model: model size.
      num_heads: number of attention heads.
      dff: hidden size of the feed-forward networks.
      input_vocab_size: size of the token vocabulary.
      maximum_position_encoding: number of positions of the sinusoidal table.
      rate: dropout rate.
      relative_radius: clipping radius of the relative-position ids.
      position_encodings: attention-type constant.
      shared_layers: whether all layers share the same weights.
    """
    super().__init__()

    self.d_model = d_model
    self.num_layers = num_layers
    self.position_encodings = position_encodings
    self.shared_layers = shared_layers

    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    if self.position_encodings == ABSOLUTE_SINUSOIDAL:
      self.pos_encoding = positional_encoding(maximum_position_encoding,
                                              self.d_model)

    def new_layer():
      return EncoderLayer(d_model, num_heads, dff, position_encodings,
                          rate, relative_radius=relative_radius)

    if self.shared_layers:
      # The same instance repeated: every step uses the same weights.
      self.enc_layers = [new_layer()] * num_layers
    else:
      self.enc_layers = [new_layer() for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask, relative_ids):
    """Encode token ids of shape (batch_size, input_seq_len).

    Args:
      x: token ids.
      training: whether dropout is active.
      mask: attention padding mask.
      relative_ids: relative-position ids; ignored for ABSOLUTE_SINUSOIDAL.

    Returns:
      A tensor of shape (batch_size, input_seq_len, d_model).
    """
    seq_len = tf.shape(x)[1]

    # Embed and scale by sqrt(d_model).
    hidden = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    hidden *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    if self.position_encodings == ABSOLUTE_SINUSOIDAL:
      relative_ids = None
      hidden += self.pos_encoding[:, :seq_len, :]

    hidden = self.dropout(hidden, training=training)

    for layer in self.enc_layers:
      hidden = layer(hidden, training, mask, relative_ids)

    return hidden  # (batch_size, input_seq_len, d_model)

  def detailed_param_count(self):
    """Print the parameter counts of the embedding and of the encoder layers."""
    print(f"  encoder embedding: {self.embedding.count_params()}")
    first_layer = self.enc_layers[0]
    if self.shared_layers:
      print(f"  encoder layer weights (shared): {first_layer.count_params()}")
    else:
      print(f"  encoder layer weights: {first_layer.count_params()*len(self.enc_layers)}")


class COGSModel(tf.keras.Model):
  """Transformer encoder with per-token tagging heads for COGS.

  For every input token the model predicts a parent, a role, a category, a
  noun type and a verb name. How the parent is predicted depends on
  `parent_encoding`:
    - PARENT_ABSOLUTE / PARENT_RELATIVE: a dense softmax layer over the parent
      vocabulary.
    - PARENT_ATTENTION: an attention distribution over the input positions.
  """

  def __init__(self, num_layers, d_model, num_heads, dff, vocabs,
               pe_input, rate=0.1,
               relative_radius=8, position_encodings=RELATIVE,
               shared_layers=False,
               parent_encoding=None):
    """Create the encoder and the output heads.

    Args:
      num_layers: number of encoder layers.
      d_model: model size.
      num_heads: number of attention heads.
      dff: hidden size of the feed-forward networks.
      vocabs: vocabularies indexed as [tokens, parent, role, category,
        noun_type, verb_name].
      pe_input: input length, used for the positional encoding table.
      rate: dropout rate.
      relative_radius: clipping radius of the relative-position ids.
      position_encodings: attention-type constant.
      shared_layers: whether the encoder layers share their weights.
      parent_encoding: one of the PARENT_* constants.

    Raises:
      ValueError: if `parent_encoding` is not a known encoding.
    """
    super(COGSModel, self).__init__()
    self.pe_input = pe_input # input length
    self.tokens_vocab_size = len(vocabs[0])
    self.parent_vocab_size = len(vocabs[1])
    self.role_vocab_size = len(vocabs[2])
    self.category_vocab_size = len(vocabs[3])
    self.noun_type_vocab_size = len(vocabs[4])
    self.verb_name_vocab_size = len(vocabs[5])

    self.position_encodings = position_encodings
    self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                           self.tokens_vocab_size, pe_input, rate,
                           relative_radius=relative_radius,
                           position_encodings=position_encodings,
                           shared_layers=shared_layers)

    self.parent_encoding = parent_encoding
    self.relative_vocab_size = relative_radius * 2 + 1
    if parent_encoding == PARENT_ABSOLUTE or parent_encoding == PARENT_RELATIVE:
      self.parent_layer = tf.keras.layers.Dense(self.parent_vocab_size)
    elif parent_encoding == PARENT_ATTENTION:
      # Note: ABSOLUTE_SINUSOIDAL should also work with parent attention, since
      # the positional embeddings are added in the encoder.
      self.parent_layer_query = tf.keras.layers.Dense(d_model)
      self.parent_layer_key = tf.keras.layers.Dense(d_model)
      if position_encodings == RELATIVE or position_encodings == RELATIVE_BIAS:
        self.parent_relative_embeddings = tf.keras.layers.Embedding(
            self.relative_vocab_size, d_model)
      else:
        self.parent_relative_embeddings = None
      if position_encodings == RELATIVE_BIAS:
        self.parent_relative_biases = tf.keras.layers.Embedding(
            self.relative_vocab_size, 1)
      else:
        self.parent_relative_biases = None
    else:
      raise ValueError(f"Undefined parent_encoding: {parent_encoding}")
    self.role_layer = tf.keras.layers.Dense(self.role_vocab_size)
    self.category_layer = tf.keras.layers.Dense(self.category_vocab_size)
    self.noun_type_layer = tf.keras.layers.Dense(self.noun_type_vocab_size)
    self.verb_name_layer = tf.keras.layers.Dense(self.verb_name_vocab_size)

  def call(self, inp, parent_ids, training, enc_padding_mask, enc_relative_ids):
    """Run the model on a batch of token ids.

    Args:
      inp: token ids of shape (batch_size, inp_seq_len).
      parent_ids: not used by this method.
      training: whether dropout is active.
      enc_padding_mask: padding mask of shape (batch_size, 1, 1, inp_seq_len).
      enc_relative_ids: relative-position ids, or None.

    Returns:
      A tuple (parent, role, category, noun_type, verb_name) of per-token
      probability distributions (softmax outputs, not logits).

    Raises:
      ValueError: if `self.parent_encoding` is not a known encoding.
    """
    enc_output = self.encoder(
        inp, training, enc_padding_mask,
        enc_relative_ids)  # (batch_size, inp_seq_len, d_model)

    if self.parent_encoding == PARENT_ABSOLUTE or self.parent_encoding == PARENT_RELATIVE:
      parent_output = tf.nn.softmax(self.parent_layer(enc_output), axis=-1)
    elif self.parent_encoding == PARENT_ATTENTION:
      parent_query = self.parent_layer_query(enc_output)
      parent_key = self.parent_layer_key(enc_output)
      # Drop the heads axis: the parent attention is single-headed.
      enc_padding_mask_noheads = tf.squeeze(enc_padding_mask, axis=1)
      _, attention = scaled_dot_product_relative_attention(
          parent_query,
          parent_key,
          None,  # No values needed here.
          mask=enc_padding_mask_noheads,
          relative_ids=enc_relative_ids,
          relative_embeddings=self.parent_relative_embeddings,
          relative_biases=self.parent_relative_biases)
      # attention.shape == batch, inp_seq_len, inp_seq_len
      # parent_vocab has the PAD_PARENT entry at index 0, followed by the
      # input positions in order, so the weight of position j has to move to
      # index j + 1. The last position wraps around to index 0.
      parent_output = tf.roll(attention, shift=1, axis=-1)
    else:
      raise ValueError(f"Undefined parent_encoding: {self.parent_encoding}")

    role_output = tf.nn.softmax(self.role_layer(enc_output), axis=-1)
    category_output = tf.nn.softmax(self.category_layer(enc_output), axis=-1)
    noun_type_output = tf.nn.softmax(self.noun_type_layer(enc_output), axis=-1)
    verb_name_output = tf.nn.softmax(self.verb_name_layer(enc_output), axis=-1)

    return parent_output, role_output, category_output, noun_type_output, verb_name_output

  def detailed_param_count(self):
    """Print the parameter counts of the encoder and of every output head.

    Raises:
      ValueError: if `self.parent_encoding` is not a known encoding.
    """
    print("Transformer parameter counts:")
    self.encoder.detailed_param_count()
    if self.parent_encoding == PARENT_ABSOLUTE or self.parent_encoding == PARENT_RELATIVE:
      # Report the parent head itself (role_layer is reported below).
      print(f"  parent_layer: {self.parent_layer.count_params()}")
    elif self.parent_encoding == PARENT_ATTENTION:
      print(f"  parent_layer_query: {self.parent_layer_query.count_params()}")
      print(f"  parent_layer_key: {self.parent_layer_key.count_params()}")
      print("  ... possibly also relative embedding and biases")
    else:
      raise ValueError(f"Undefined parent_encoding: {self.parent_encoding}")
    print(f"  role_layer: {self.role_layer.count_params()}")
    print(f"  category_layer: {self.category_layer.count_params()}")
    print(f"  noun_type_layer: {self.noun_type_layer.count_params()}")
    print(f"  verb_name_layer: {self.verb_name_layer.count_params()}")


class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Learning rate schedule with linear warmup and inverse-sqrt decay.

  lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
  """

  def __init__(self, d_model, warmup_steps=4000):
    """Store the schedule parameters.

    Args:
      d_model: model size; stored as a float32 tensor.
      warmup_steps: number of steps during which the rate grows linearly.
    """
    super(CustomSchedule, self).__init__()

    self.d_model = tf.cast(d_model, tf.float32)
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    """Return the learning rate for `step`.

    Args:
      step: current optimizer step (scalar tensor or number).

    Returns:
      The learning rate as a float32 tensor.
    """
    # The step may arrive as an integer tensor, which rsqrt does not accept
    # and which cannot be multiplied with a float; cast it first.
    step = tf.cast(step, tf.float32)
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)

    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

In [None]:
#@title Experiment execution function

def run_evaluation(dataset,
                   num_layers, d_model, dff, num_heads, position_encoding,
                   parent_encoding,
                   share_layer_weights,
                   num_epochs,
                   num_repetitions=1,
                   save_results_every_n_epochs=None,
                   batch_size=64):
  """Trains and evaluates a COGSModel, optionally over several repetitions.

  For every repetition a fresh model, optimizer and set of metrics is created.
  The model is trained for `num_epochs` epochs; after each epoch the loss and
  accuracy on the first evaluation set are printed, and every
  `save_results_every_n_epochs` epochs (and always after the last epoch) the
  token-level and sequence-level accuracies on all evaluation sets are
  recorded.

  Args:
    dataset: name of the dataset; must be "cogs_tag".
    num_layers: forwarded to COGSModel.
    d_model: forwarded to COGSModel and to the learning-rate schedule.
    dff: forwarded to COGSModel.
    num_heads: forwarded to COGSModel.
    position_encoding: one of RELATIVE8, RELATIVE16, RELATIVE16_BIAS or
      ABSOLUTE_SINUSOIDAL.
    parent_encoding: forwarded to `set_up_cogs` and to COGSModel.
    share_layer_weights: forwarded to COGSModel as `shared_layers`.
    num_epochs: number of training epochs per repetition.
    num_repetitions: number of independent training runs to average over;
      must be at least 1.
    save_results_every_n_epochs: evaluation period in epochs; defaults to
      `num_epochs` (evaluate only once, at the end).
    batch_size: forwarded to `set_up_cogs`.

  Returns:
    A tuple `(averages, transformer, test_sets_names)` where `averages` is a
    list with one row per recorded evaluation, each row being
    `[epoch, acc_token_0, acc_seq_0, acc_token_1, acc_seq_1, ...]` averaged
    over the repetitions, `transformer` is the model of the last repetition,
    and `test_sets_names` is the value returned by `set_up_cogs`.

  Raises:
    ValueError: if `num_repetitions` is smaller than 1 or if
      `position_encoding` is not one of the supported values.
  """
  assert dataset == "cogs_tag"
  # Fail fast: with no repetition there is nothing to average and no model to
  # return (this used to surface late as an IndexError).
  if num_repetitions < 1:
    raise ValueError(f"num_repetitions must be >= 1, got: {num_repetitions}")
  dataset_train, dataset_val_list, test_sets_names, vocabs, max_len = set_up_cogs(parent_encoding, batch_size)

  if save_results_every_n_epochs is None:
    save_results_every_n_epochs = num_epochs

  # Translate the user-facing position encoding name into the model's
  # (encoding type, relative radius) pair.
  if position_encoding == RELATIVE8:
    position_encoding = RELATIVE
    relative_radius = 8
  elif position_encoding == RELATIVE16:
    position_encoding = RELATIVE
    relative_radius = 16
  elif position_encoding == RELATIVE16_BIAS:
    position_encoding = RELATIVE_BIAS
    relative_radius = 16
  elif position_encoding == ABSOLUTE_SINUSOIDAL:
    relative_radius = 0
  else:
    raise ValueError(f"Undefined position embeddings type: {position_encoding}")

  dropout_rate = 0.1

  repetition_metrics_l = []
  for repetition in range(num_repetitions):
    print(f"Starting repetition {repetition} ...\n")

    learning_rate = CustomSchedule(d_model)
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                         epsilon=1e-9)
    # The model heads output probabilities (softmax), hence from_logits=False.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False, reduction="none")
    train_loss = tf.keras.metrics.Mean(name="train_loss")
    train_accuracy = tf.keras.metrics.Mean(name="train_accuracy")
    eval_loss = tf.keras.metrics.Mean(name="eval_loss")
    eval_accuracy = tf.keras.metrics.Mean(name="eval_accuracy")

    transformer = COGSModel(
        num_layers,
        d_model,
        num_heads,
        dff,
        vocabs,
        pe_input=max_len,
        rate=dropout_rate,
        relative_radius=relative_radius,
        position_encodings=position_encoding,
        shared_layers=share_layer_weights,
        parent_encoding=parent_encoding)

    def loss_function(real, pred):
      """Mean cross-entropy over the positions whose target id is not 0."""
      mask = tf.math.logical_not(tf.math.equal(real, 0))
      loss_ = loss_object(real, pred)

      mask = tf.cast(mask, dtype=loss_.dtype)
      loss_ *= mask

      return tf.reduce_sum(loss_)/tf.reduce_sum(mask)

    def accuracy_function(real, pred):
      """Mean accuracy over the positions whose target id is not 0."""
      mask = tf.math.logical_not(tf.math.equal(real, 0))
      matches = tf.keras.metrics.sparse_categorical_accuracy(real, pred)

      mask = tf.cast(mask, dtype=matches.dtype)
      matches *= mask

      return tf.reduce_sum(matches)/tf.reduce_sum(mask)

    def mean_loss_and_accuracy(targets, predictions):
      """Averages the masked loss and accuracy over all output heads.

      `targets` and `predictions` are parallel sequences, one entry per head;
      the heads are accumulated in the given order.
      """
      pairs = list(zip(targets, predictions))
      loss = loss_function(*pairs[0])
      accuracy = accuracy_function(*pairs[0])
      for real, pred in pairs[1:]:
        loss += loss_function(real, pred)
        accuracy += accuracy_function(real, pred)
      return loss / len(pairs), accuracy / len(pairs)

    train_step_signature = [
        tf.TensorSpec(shape=(None, None), dtype=tf.int32),
        tf.TensorSpec(shape=(None, None), dtype=tf.int32),
        tf.TensorSpec(shape=(None, None), dtype=tf.int32),
        tf.TensorSpec(shape=(None, None), dtype=tf.int32),
        tf.TensorSpec(shape=(None, None), dtype=tf.int32),
        tf.TensorSpec(shape=(None, None), dtype=tf.int32),
        tf.TensorSpec(shape=(None, None), dtype=tf.int32)
    ]

    @tf.function(input_signature=train_step_signature)
    def train_step(inp, parent_ids, targ_parent, targ_role, targ_category, targ_noun_type, targ_verb_name):
      """Runs one optimisation step and updates the training metrics."""
      enc_padding_mask = create_masks(inp)
      enc_relative_ids = create_relative_ids(max_len, relative_radius)
      with tf.GradientTape() as tape:
        predictions = transformer(inp, parent_ids, True,
                                  enc_padding_mask,
                                  enc_relative_ids)
        # The loss must be computed while the tape is recording.
        loss, accuracy = mean_loss_and_accuracy(
            (targ_parent, targ_role, targ_category, targ_noun_type,
             targ_verb_name),
            predictions)

      gradients = tape.gradient(loss, transformer.trainable_variables)
      optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
      train_loss(loss)
      train_accuracy(accuracy)

    @tf.function(input_signature=train_step_signature)
    def eval_step(inp, parent_ids, targ_parent, targ_role, targ_category, targ_noun_type, targ_verb_name):
      """Runs the model in inference mode and updates the eval metrics."""
      enc_padding_mask = create_masks(inp)
      enc_relative_ids = create_relative_ids(max_len, relative_radius)
      # No gradients are needed here, so no GradientTape is opened.
      predictions = transformer(inp, parent_ids, False,
                                enc_padding_mask,
                                enc_relative_ids)
      loss, accuracy = mean_loss_and_accuracy(
          (targ_parent, targ_role, targ_category, targ_noun_type,
           targ_verb_name),
          predictions)
      eval_loss(loss)
      eval_accuracy(accuracy)

    def eval_step_detailed(inp, parent_ids, targ_parent, targ_role, targ_category, targ_noun_type, targ_verb_name, max_to_show,
                           vocabs):
      """Computes token- and sequence-level accuracy for one batch.

      Up to `max_to_show` examples with at least one wrong output sequence are
      printed, decoded with `vocabs`.

      Returns:
        `(token_accuracy, sequence_accuracy, per_head_sequence_accuracy)`,
        the last being a list with one entry per output head in the order
        parent, role, category, noun type, verb name.
      """
      enc_padding_mask = create_masks(inp)
      enc_relative_ids = create_relative_ids(max_len, relative_radius)
      (parent_predictions,
       role_predictions,
       category_predictions,
       noun_type_predictions,
       verb_name_predictions) = transformer(inp, parent_ids, False,
                                            enc_padding_mask,
                                            enc_relative_ids)

      predicted_targ_parent = tf.argmax(parent_predictions, 2).numpy()
      predicted_targ_role = tf.argmax(role_predictions, 2).numpy()
      predicted_targ_category = tf.argmax(category_predictions, 2).numpy()
      predicted_targ_noun_type = tf.argmax(noun_type_predictions, 2).numpy()
      predicted_targ_verb_name = tf.argmax(verb_name_predictions, 2).numpy()

      accuracy_seq_detailed = [0, 0, 0, 0, 0]
      accuracy_seq = 0
      accuracy_tok = 0
      total_tok = 0
      shown = 0
      parent_ground_truth = targ_parent.numpy()
      role_ground_truth = targ_role.numpy()
      category_ground_truth = targ_category.numpy()
      noun_type_ground_truth = targ_noun_type.numpy()
      verb_name_ground_truth = targ_verb_name.numpy()
      # Use the real number of examples in this batch rather than the nominal
      # batch_size, so a shorter final batch neither overruns the arrays nor
      # skews the per-batch sequence accuracy.
      n_examples = len(parent_ground_truth)
      for i in range(n_examples):
        # clear out the padding predictions:
        for j in range(len(parent_ground_truth[i])):
          if role_ground_truth[i][j] == 0:
            predicted_targ_parent[i][j] = 0
            predicted_targ_role[i][j] = 0
            predicted_targ_category[i][j] = 0
            predicted_targ_noun_type[i][j] = 0
            predicted_targ_verb_name[i][j] = 0
        for k in range(len(parent_ground_truth[i])):
          if parent_ground_truth[i][k] != 0:
            if predicted_targ_parent[i][k] == parent_ground_truth[i][k]:
              accuracy_tok += 1
            if predicted_targ_role[i][k] == role_ground_truth[i][k]:
              accuracy_tok += 1
            if predicted_targ_category[i][k] == category_ground_truth[i][k]:
              accuracy_tok += 1
            if predicted_targ_noun_type[i][k] == noun_type_ground_truth[i][k]:
              accuracy_tok += 1
            if predicted_targ_verb_name[i][k] == verb_name_ground_truth[i][k]:
              accuracy_tok += 1
            total_tok += 5
        # Whole-sequence match per output head, computed once and reused for
        # both the per-head and the all-heads sequence accuracy.
        sequence_matches = [
            (predicted_targ_parent[i] == parent_ground_truth[i]).all(),
            (predicted_targ_role[i] == role_ground_truth[i]).all(),
            (predicted_targ_category[i] == category_ground_truth[i]).all(),
            (predicted_targ_noun_type[i] == noun_type_ground_truth[i]).all(),
            (predicted_targ_verb_name[i] == verb_name_ground_truth[i]).all(),
        ]
        for head, matched in enumerate(sequence_matches):
          if matched:
            accuracy_seq_detailed[head] += 1
        if all(sequence_matches):
          accuracy_seq += 1
        elif shown < max_to_show:
          print("--------------")
          print("tokens:       " + decode(inp.numpy()[i], vocabs[0]))
          print("parent gt:    " + decode(parent_ground_truth[i], vocabs[1]))
          print("role gt:      " + decode(role_ground_truth[i], vocabs[2]))
          print("category gt:  " + decode(category_ground_truth[i], vocabs[3]))
          print("noun type gt: " + decode(noun_type_ground_truth[i], vocabs[4]))
          print("verb name gt: " + decode(verb_name_ground_truth[i], vocabs[5]))
          print("parent:       " + decode(predicted_targ_parent[i], vocabs[1]))
          print("role:         " + decode(predicted_targ_role[i], vocabs[2]))
          print("category:     " + decode(predicted_targ_category[i], vocabs[3]))
          print("noun type:    " + decode(predicted_targ_noun_type[i], vocabs[4]))
          print("verb name:    " + decode(predicted_targ_verb_name[i], vocabs[5]))
          shown += 1
      for i in range(5):
        accuracy_seq_detailed[i] /= float(n_examples)
      return accuracy_tok / float(total_tok), accuracy_seq / float(n_examples), accuracy_seq_detailed

    def evaluate_in_set(dataset_val, vocabs):
      """Prints and returns the accuracies averaged over the set's batches.

      Returns:
        `(token_accuracy, sequence_accuracy)`, each the mean of the per-batch
        values.
      """
      steps_per_epoch_val = dataset_val.cardinality()
      n_test_batches = 0
      test_accuracy_token = 0
      test_accuracy_seq = 0
      test_accuracy_seq_detailed = [0, 0, 0, 0, 0]
      for (inp, parent_ids,
           targ_parent, targ_role,
           targ_category, targ_noun_type,
           targ_verb_name) in dataset_val.take(steps_per_epoch_val):
        batch_accuracy_token, batch_accuracy_seq, batch_accuracy_seq_detailed = eval_step_detailed(
            inp, parent_ids, targ_parent, targ_role,
            targ_category, targ_noun_type,
            targ_verb_name, 0, vocabs)
        test_accuracy_token += batch_accuracy_token
        test_accuracy_seq += batch_accuracy_seq
        for i in range(5):
          test_accuracy_seq_detailed[i] += batch_accuracy_seq_detailed[i]
        n_test_batches += 1

      print(
          f"Eval accuracy (token level): {test_accuracy_token/n_test_batches}")
      print(
          f"Eval accuracy (sequence level): {test_accuracy_seq/n_test_batches}")
      names = ["parent", "role", "category", "noun type", "verb name"]
      for i in range(5):
        print(f"Eval accuracy (sequence level: {names[i]}): {test_accuracy_seq_detailed[i]/n_test_batches}")
      print("")
      return (test_accuracy_token/n_test_batches, test_accuracy_seq/n_test_batches)

    def evaluate_all_sets(epoch_number):
      """Evaluates every validation set; returns [epoch, acc_token, acc_seq, ...]."""
      epoch_metrics = [epoch_number]
      for i in range(len(dataset_val_list)):
        print(f"------- Evaluation in dataset_val {i} -------")
        (acc_token, acc_seq) = evaluate_in_set(dataset_val_list[i], vocabs)
        epoch_metrics.append(acc_token)
        epoch_metrics.append(acc_seq)
      return epoch_metrics

    steps_per_epoch = dataset_train.cardinality()
    steps_per_epoch_val = dataset_val_list[0].cardinality()
    print(f"steps_per_epoch: {steps_per_epoch}")
    repetition_metrics = []
    for epoch in range(num_epochs):
      start = time.time()

      train_loss.reset_states()
      train_accuracy.reset_states()
      eval_loss.reset_states()
      eval_accuracy.reset_states()

      for (batch, (inp, parent_ids,
                   targ_parent, targ_role,
                   targ_category, targ_noun_type,
                   targ_verb_name)) in enumerate(dataset_train.take(steps_per_epoch)):
        train_step(inp, parent_ids, targ_parent, targ_role,
                   targ_category, targ_noun_type,
                   targ_verb_name)
        if batch % 100 == 0:
          print("Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}".format(
              epoch + 1, batch, train_loss.result(), train_accuracy.result()))
      print("Epoch {} Loss {:.4f} Accuracy {:.4f}".format(
          epoch + 1, train_loss.result(), train_accuracy.result()))
      # Per-epoch monitoring only uses the first evaluation set.
      for (inp, parent_ids, targ_parent, targ_role, targ_category,
           targ_noun_type, targ_verb_name) in dataset_val_list[0].take(
               steps_per_epoch_val):
        eval_step(inp, parent_ids, targ_parent, targ_role, targ_category,
                  targ_noun_type, targ_verb_name)
      print("Epoch {} Eval Loss {:.4f} Eval Accuracy {:.4f}".format(
          epoch + 1, eval_loss.result(), eval_accuracy.result()))

      print("Time taken for 1 epoch: {} secs\n".format(time.time() - start))

      if ((epoch + 1) % save_results_every_n_epochs) == 0:
        repetition_metrics.append(evaluate_all_sets(epoch + 1))

    # Make sure we save the last epoch results if it is not a
    # multiple of save_results_every_n_epochs:
    if (num_epochs % save_results_every_n_epochs) != 0:
      repetition_metrics.append(evaluate_all_sets(num_epochs))

    repetition_metrics_l.append(repetition_metrics)

  print("Raw repetition metrics:")
  for repetition_metrics in repetition_metrics_l:
    for epoch_metrics in repetition_metrics:
      print(epoch_metrics)

  # Copy the rows of the first repetition so that averaging does not overwrite
  # the raw metrics collected above.
  averages = [list(epoch_metrics) for epoch_metrics in repetition_metrics_l[0]]
  for i in range(1, len(repetition_metrics_l)):
    for j in range(len(repetition_metrics_l[i])):
      epoch_metrics = repetition_metrics_l[i][j]
      # skip the epoch number:
      for k in range(1, len(epoch_metrics)):
        averages[j][k] += epoch_metrics[k]
  for j in range(len(averages)):
    for k in range(1, len(averages[j])):
      averages[j][k] /= len(repetition_metrics_l)

  print("Average repetition metrics:")
  for average_metrics in averages:
    print("\t".join(map(str, average_metrics)))

  # `transformer` is the model trained in the last repetition.
  print(f"Transformer params: {transformer.count_params()}")
  transformer.detailed_param_count()

  return averages, transformer, test_sets_names


def pretty_print_results(result, test_sets_names):
  """Print one row of results as two tables: space aligned, then CSV.

  Args:
    result: a flat list `[epoch, acc_token_0, acc_seq_0, acc_token_1, ...]`
      holding one (acc_token, acc_seq) pair per entry of `test_sets_names`
      after the leading epoch number.
    test_sets_names: names of the test sets, in the same order as the pairs
      in `result`. Names are truncated to 20 characters when printed.
  """
  assert len(result) == 2*len(test_sets_names) + 1

  def _rows():
    # Yields (truncated name, token accuracy, sequence accuracy) per test set;
    # the pair for set `idx` sits right after the epoch number at index 0.
    for idx, name in enumerate(test_sets_names):
      yield name[:20], result[2 * idx + 1], result[2 * idx + 2]

  # Human-readable, column-aligned table.
  print("%-20s  acc_token  acc_seq" % "")
  for row in _rows():
    print("%-20s  %.3f      %.3f" % row)
  print()
  # Same content as comma-separated values.
  print(",acc_token,acc_seq")
  for row in _rows():
    print("%s,%.3f,%.3f" % row)


In [None]:
#@title Experiment - base

# Each experiment is a single `run_evaluation` call. Its arguments are:
#   - dataset
#   - num_layers, d_model, dff, num_heads
#   - position_encoding, parent_encoding, share_layer_weights
#   - num_epochs
# plus the optional keyword arguments (shown with their defaults):
#   - num_repetitions=1
#   - save_results_every_n_epochs=None
#   - batch_size=64

# Example: absolute sinusoidal positions with absolute parent encoding and no
# weight sharing across layers.
results, transformer, test_sets_names = run_evaluation(
    dataset='cogs_tag',
    num_layers=2,
    d_model=64,
    dff=256,
    num_heads=4,
    position_encoding=ABSOLUTE_SINUSOIDAL,
    parent_encoding=PARENT_ABSOLUTE,
    share_layer_weights=False,
    num_epochs=16,
    num_repetitions=1,
    save_results_every_n_epochs=None,
    batch_size=64)

# Report the metrics of the last recorded evaluation.
pretty_print_results(results[-1], test_sets_names)

In [None]:
#@title Experiment - tuned

# Relative position biases (radius 16), attention-based parent encoding and
# weights shared across layers.
results, transformer, test_sets_names = run_evaluation(
    dataset='cogs_tag',
    num_layers=2,
    d_model=64,
    dff=256,
    num_heads=4,
    position_encoding=RELATIVE16_BIAS,
    parent_encoding=PARENT_ATTENTION,
    share_layer_weights=True,
    num_epochs=16,
    num_repetitions=1,
    save_results_every_n_epochs=None,
    batch_size=64)

# Report the metrics of the last recorded evaluation.
pretty_print_results(results[-1], test_sets_names)