<a href="https://colab.research.google.com/github/jaekyoungkim/longformer/blob/main/bigbird.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Copyright 2020 The BigBird Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

#The most important directory is core. There are three main files in core.

#attention.py: Contains BigBird linear attention mechanism
#encoder.py: Contains the main long sequence encoder stack
#modeling.py: Contains packaged BERT and seq2seq transformer models with BigBird attention

# py files (the source of each core module follows in the cells below)

In [14]:
############## attention ################

# Copyright 2021 The BigBird Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BigBird Attention Layers."""

from absl import logging
from bigbird.core import recompute_grad
from bigbird.core import utils
import numpy as np
import tensorflow.compat.v2 as tf


# Side length of the dense square mask that full_bigbird_mask builds before
# clipping it to the requested from/to sequence lengths.
MAX_SEQ_LEN = 4096


def get_single_block_row_attention(block_id,
                                   to_start_block_id,
                                   to_end_block_id,
                                   num_rand_blocks,
                                   window_block_left=1,
                                   window_block_right=1,
                                   global_block_left=1,
                                   global_block_right=1):
  """Sample the random key blocks that one query block row attends to.

  Candidates are the block ids in [to_start_block_id, to_end_block_id), visited
  in a random order. A candidate is rejected when it falls inside the sliding
  window around `block_id`, among the first `global_block_left` block ids, or
  among the last `global_block_right` block ids before `to_end_block_id`.
  In addition, block 1 may not pick block `to_end_block_id - 2`, and block
  `to_end_block_id - 2` may not pick block 1.

  Args:
    block_id: int. block id of the query row.
    to_start_block_id: int. first candidate key block id (inclusive).
    to_end_block_id: int. last candidate key block id (exclusive).
    num_rand_blocks: int. number of random blocks to select.
    window_block_left: int. number of window blocks to the left of a block.
    window_block_right: int. number of window blocks to the right of a block.
    global_block_left: int. number of global blocks at the start.
    global_block_right: int. number of global blocks at the end.
  Returns:
    int32 numpy array of selected key block ids. For a positive
    `num_rand_blocks` it holds at most that many ids (fewer if not enough
    legal candidates exist).
  """
  candidates = np.random.permutation(
      np.arange(to_start_block_id, to_end_block_id, dtype=np.int32))

  # Ids this row is never allowed to pick at random.
  forbidden = list(range(block_id - window_block_left,
                         block_id + window_block_right + 1))
  forbidden += range(global_block_left)
  forbidden += range(to_end_block_id - global_block_right, to_end_block_id)
  if block_id == 1:
    forbidden.append(to_end_block_id - 2)
  if block_id == to_end_block_id - 2:
    forbidden.append(1)

  chosen = []
  for candidate in candidates:
    if candidate not in forbidden:
      chosen.append(candidate)
    if len(chosen) == num_rand_blocks:
      break
  return np.array(chosen, dtype=np.int32)


def bigbird_block_rand_mask_with_head(seq_length,
                                      block_size,
                                      num_heads,
                                      plan_from_length,
                                      plan_num_rand_blocks,
                                      window_block_left=1,
                                      window_block_right=1,
                                      global_block_top=1,
                                      global_block_bottom=1,
                                      global_block_left=1,
                                      global_block_right=1):
  """Create a per-head adjacency list of random attention following a plan.

  The sequence is split into plan segments whose (exclusive) end positions are
  given by `plan_from_length`. Each segment contributes
  `plan_num_rand_blocks[k]` random columns. Every row draws those columns from
  the legal key blocks of the corresponding segment via
  `get_single_block_row_attention`.

  Args:
    seq_length: int. length of sequence. Must appear in `plan_from_length`.
    block_size: int. size of block in sequence.
    num_heads: int. total number of heads.
    plan_from_length: list. ending position of each plan segment; the random
      blocks of segment k are chosen from within it.
    plan_num_rand_blocks: list. number of random blocks for each plan segment.
    window_block_left: int. number of blocks of window to left of a block.
    window_block_right: int. number of blocks of window to right of a block.
    global_block_top: int. number of blocks at the top.
    global_block_bottom: int. number of blocks at the bottom.
    global_block_left: int. Number of blocks globally used to the left.
    global_block_right: int. Number of blocks globally used to the right.
  Returns:
    list of `num_heads` int32 numpy arrays. Each has shape
    (seq_length//block_size - global_block_top - global_block_bottom,
     sum(plan_num_rand_blocks[:max_plan_idx+1])); with the default single
    top/bottom global block that is from_seq_length//from_block_size-2 rows.
  Raises:
    ValueError: if `seq_length` is not an element of `plan_from_length`
      (raised by `list.index`).
  """
  # Total number of blocks in the mask
  num_blocks = seq_length//block_size
  # End block index (exclusive) of each plan segment
  plan_block_length = np.array(plan_from_length) // block_size
  # Index of the plan segment that ends at seq_length; later segments are unused
  max_plan_idx = plan_from_length.index(seq_length)
  # Random attention adjacency list: one (num_blocks, total_rand_cols) table per
  # head. Rows outside the visited ranges stay 0.
  rand_attn = [np.zeros((num_blocks,
                         np.sum(plan_num_rand_blocks[:max_plan_idx+1])),
                        dtype=np.int32) for i in range(num_heads)]

  # We will go iteratively over the plan blocks and pick random number of
  # Attention blocks from the legally allowed blocks
  # NOTE(review): each assignment below expects get_single_block_row_attention
  # to return exactly the slot width; a shorter result raises a numpy
  # broadcasting error (or is broadcast if it has length 1).
  for plan_idx in range(max_plan_idx+1):
    rnd_r_cnt = 0
    if plan_idx > 0:
      # set the row for all from_blocks starting from 0 to
      # plan_block_length[plan_idx-1]
      # column index starts from plan_block_length[plan_idx-1] and ends at
      # plan_block_length[plan_idx]
      # i.e. rows of earlier segments get random columns inside this segment,
      # written into this segment's column slot [rnd_r_cnt, curr_r_cnt).
      if plan_num_rand_blocks[plan_idx] > 0:
        rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx]))
        curr_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx+1]))
        for blk_rw_idx in range(global_block_top,
                                plan_block_length[plan_idx-1]):
          for h in range(num_heads):
            # print("head", h, "blk_rw_idx", blk_rw_idx)
            rand_attn[h][blk_rw_idx,
                         rnd_r_cnt:curr_r_cnt] = get_single_block_row_attention(
                             block_id=blk_rw_idx,
                             to_start_block_id=plan_block_length[plan_idx - 1],
                             to_end_block_id=plan_block_length[plan_idx],
                             num_rand_blocks=plan_num_rand_blocks[plan_idx],
                             window_block_left=window_block_left,
                             window_block_right=window_block_right,
                             global_block_left=global_block_left,
                             global_block_right=global_block_right)

      # Rows of this segment get random columns from every earlier segment
      # pl_id, each written into segment pl_id's column slot.
      for pl_id in range(plan_idx):
        if plan_num_rand_blocks[pl_id] == 0:
          continue
        for blk_rw_idx in range(plan_block_length[plan_idx-1],
                                plan_block_length[plan_idx]):
          rnd_r_cnt = 0
          to_start_block_id = 0
          if pl_id > 0:
            rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:pl_id]))
            to_start_block_id = plan_block_length[pl_id-1]
          curr_r_cnt = int(np.sum(plan_num_rand_blocks[:pl_id+1]))
          for h in range(num_heads):
            # print("head", h, "blk_rw_idx", blk_rw_idx)
            rand_attn[h][blk_rw_idx,
                         rnd_r_cnt:curr_r_cnt] = get_single_block_row_attention(
                             block_id=blk_rw_idx,
                             to_start_block_id=to_start_block_id,
                             to_end_block_id=plan_block_length[pl_id],
                             num_rand_blocks=plan_num_rand_blocks[pl_id],
                             window_block_left=window_block_left,
                             window_block_right=window_block_right,
                             global_block_left=global_block_left,
                             global_block_right=global_block_right)

    if plan_num_rand_blocks[plan_idx] == 0:
      continue
    # print("Start from here")
    # Rows of this segment (skipping the top global rows for segment 0) get
    # random columns from within this same segment.
    curr_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx+1]))
    from_start_block_id = global_block_top
    to_start_block_id = 0
    if plan_idx > 0:
      rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx]))
      from_start_block_id = plan_block_length[plan_idx-1]
      to_start_block_id = plan_block_length[plan_idx-1]

    for blk_rw_idx in range(from_start_block_id, plan_block_length[plan_idx]):
      for h in range(num_heads):
        # print("head", h, "blk_rw_idx", blk_rw_idx)
        rand_attn[h][blk_rw_idx,
                     rnd_r_cnt:curr_r_cnt] = get_single_block_row_attention(
                         block_id=blk_rw_idx,
                         to_start_block_id=to_start_block_id,
                         to_end_block_id=plan_block_length[plan_idx],
                         num_rand_blocks=plan_num_rand_blocks[plan_idx],
                         window_block_left=window_block_left,
                         window_block_right=window_block_right,
                         global_block_left=global_block_left,
                         global_block_right=global_block_right)

  # Drop the global top/bottom rows; those blocks attend globally instead.
  for nh in range(num_heads):
    rand_attn[nh] = rand_attn[nh][global_block_top:num_blocks -
                                  global_block_bottom, :]
  return rand_attn


def get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks):
  """Decide how random attention blocks are split across sequence segments.

  Long sequences get a short first segment of (2*r + 5) blocks carrying all r
  random blocks, followed by the remainder carrying none. Medium sequences get
  a first segment of (r + 5) blocks with r//2 random blocks and the remainder
  with the other r - r//2. Short sequences use one segment with all r.

  Args:
    from_seq_length: int. length of from sequence.
    from_block_size: int. size of block in from sequence.
    num_rand_blocks: int. Number of random chunks per row.
  Returns:
    plan_from_length: list of segment ending positions (in tokens).
    plan_num_rand_blocks: list with the number of random blocks per segment.
  """
  num_blocks = from_seq_length // from_block_size

  if 2 * num_rand_blocks + 5 < num_blocks:
    return ([int((2 * num_rand_blocks + 5) * from_block_size), from_seq_length],
            [num_rand_blocks, 0])

  if num_rand_blocks + 5 < num_blocks:
    first_share = num_rand_blocks // 2
    return ([int((num_rand_blocks + 5) * from_block_size), from_seq_length],
            [first_share, num_rand_blocks - first_share])

  return [from_seq_length], [num_rand_blocks]


def bigbird_block_rand_mask(from_seq_length,
                            to_seq_length,
                            from_block_size,
                            to_block_size,
                            num_rand_blocks,
                            last_idx=-1):
  """Create adjacency list of random attention.
  Args:
    from_seq_length: int. length of from sequence.
    to_seq_length: int. length of to sequence.
    from_block_size: int. size of block in from sequence.
    to_block_size: int. size of block in to sequence.
    num_rand_blocks: int. Number of random chunks per row.
    last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence,
      if positive then num_rand_blocks blocks choosen only upto last_idx.
  Returns:
    adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks
  """
  rand_attn = np.zeros(
      (from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32)
  middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
  last = to_seq_length // to_block_size - 1
  if last_idx > (2 * to_block_size):
    last = (last_idx // to_block_size) - 1

  r = num_rand_blocks  # shorthand
  for i in range(1, from_seq_length // from_block_size-1):
    start = i-2
    end = i
    if i == 1:
      rand_attn[i-1, :] = np.random.permutation(middle_seq[2:last])[:r]
    elif i == 2:
      rand_attn[i-1, :] = np.random.permutation(middle_seq[3:last])[:r]
    elif i == from_seq_length // from_block_size - 3:
      rand_attn[i-1, :] = np.random.permutation(middle_seq[:last])[:r]
      # Missing -3: should have been sliced till last-3
    elif i == from_seq_length // from_block_size - 2:
      rand_attn[i-1, :] = np.random.permutation(middle_seq[:last])[:r]
      # Missing -4: should have been sliced till last-4
    else:
      if start > last:
        start = last
        rand_attn[i-1, :] = np.random.permutation(middle_seq[:start])[:r]
      elif (end+1) == last:
        rand_attn[i-1, :] = np.random.permutation(middle_seq[:start])[:r]
      else:
        rand_attn[i-1, :] = np.random.permutation(
            np.concatenate((middle_seq[:start], middle_seq[end+1:last])))[:r]
  return rand_attn


def full_bigbird_mask(from_seq_length,
                      to_seq_length,
                      from_block_size,
                      to_block_size,
                      rand_attn):
  """Calculate BigBird attention pattern as a full dense matrix.

  The pattern is first laid out on a MAX_SEQ_LEN x MAX_SEQ_LEN grid: a 3-block
  sliding window plus the random blocks in `rand_attn` for each middle row
  block, and the first/last row and column blocks set globally. The grid is
  then clipped to [from_seq_length, to_seq_length]. Note that the "last"
  global block is the last block of the MAX_SEQ_LEN grid, so it is clipped
  away whenever the requested lengths are shorter than MAX_SEQ_LEN.

  Args:
    from_seq_length: int. length of from sequence.
    to_seq_length: int. length of to sequence.
    from_block_size: int. size of block in from sequence.
    to_block_size: int. size of block in to sequence.
    rand_attn: adjacency matrix for random attention; needs at least
      MAX_SEQ_LEN//from_block_size-2 rows.
  Returns:
    boolean attention mask matrix of shape [from_seq_length, to_seq_length]
  """
  num_row_blocks = MAX_SEQ_LEN // from_block_size
  dense = np.zeros((MAX_SEQ_LEN, MAX_SEQ_LEN), dtype=np.int32)

  for blk in range(1, num_row_blocks - 1):
    row_span = slice(blk * from_block_size, (blk + 1) * from_block_size)
    # Sliding window: key blocks blk-1, blk, blk+1.
    dense[row_span, (blk - 1) * to_block_size:(blk + 2) * to_block_size] = 1
    # Random key blocks for this row block.
    for key_blk in rand_attn[blk - 1, :]:
      dense[row_span, key_blk * to_block_size:(key_blk + 1) * to_block_size] = 1

  # Global blocks: first/last rows see everything; every row sees the
  # first/last key columns.
  dense[:from_block_size, :] = 1
  dense[:, :to_block_size] = 1
  dense[:, -to_block_size:] = 1
  dense[-from_block_size:, :] = 1

  return np.array(dense[:from_seq_length, :to_seq_length], dtype=bool)


def create_rand_mask_from_inputs(from_blocked_mask,
                                 to_blocked_mask,
                                 rand_attn,
                                 num_attention_heads,
                                 num_rand_blocks,
                                 from_seq_length,
                                 from_block_size,
                                 to_block_size=None):
  """Create the attention mask for each middle query block's random keys.

  Args:
    from_blocked_mask: 3D Tensor of shape [batch_size,
      from_seq_length//from_block_size, from_block_size].
    to_blocked_mask: 3D Tensor of shape [batch_size,
      to_seq_length//to_block_size, to_block_size].
    rand_attn: int Tensor of shape [batch_size, num_attention_heads,
      from_seq_length//from_block_size-2, num_rand_blocks] holding the key
      block ids each middle query block attends to at random.
    num_attention_heads: int. Number of attention heads.
    num_rand_blocks: int. Number of random chunks per row.
    from_seq_length: int. length of from sequence.
    from_block_size: int. size of block in from sequence.
    to_block_size: (optional) int. size of block in to sequence. When None it
      is taken from the last dimension of `to_blocked_mask`.
  Returns:
    Tensor of shape [batch_size, num_attention_heads,
                     from_seq_length//from_block_size-2,
                     from_block_size, num_rand_blocks*to_block_size].
  """
  if to_block_size is None:
    # Each gathered key block carries to_block_size mask entries, so the
    # flattened random-key axis must be num_rand_blocks * to_block_size.
    # (Using from_block_size here broke whenever the two block sizes differ.)
    to_block_size = utils.get_shape_list(to_blocked_mask)[-1]
  num_windows = from_seq_length // from_block_size - 2
  # gather -> [b, h, num_windows, r, to_block_size]; flatten the last two axes.
  rand_mask = tf.reshape(
      tf.gather(to_blocked_mask, rand_attn, batch_dims=1), [
          -1, num_attention_heads, num_windows,
          num_rand_blocks * to_block_size
      ])
  # Outer product of the middle query-block masks with the random key masks.
  rand_mask = tf.einsum("BLQ,BHLK->BHLQK", from_blocked_mask[:, 1:-1],
                        rand_mask)
  return rand_mask


def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask):
  """Build the sliding-window (band) mask for the middle query blocks.

  Middle query block k (the blocks sliced by [2:-2]) is paired with key blocks
  k-1, k and k+1, laid side by side along the last axis.

  Args:
    from_blocked_mask: 3D Tensor of shape [batch_size,
      from_seq_length//from_block_size, from_block_size].
    to_blocked_mask: 3D Tensor of shape [batch_size,
      to_seq_length//to_block_size, to_block_size].
  Returns:
    Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4,
                     from_block_size,  3*to_block_size].
  """
  # Key blocks shifted by -1, 0 and +1 relative to the middle query blocks.
  window_parts = [
      to_blocked_mask[:, lo:hi] for lo, hi in ((1, -3), (2, -2), (3, -1))
  ]
  window_keys = tf.concat(window_parts, 2)
  middle_queries = from_blocked_mask[:, 2:-2]
  band = tf.einsum("BLQ,BLK->BLQK", middle_queries, window_keys)
  # Singleton heads axis so the mask broadcasts over attention heads.
  return tf.expand_dims(band, 1)


def create_attention_mask_from_input_mask(from_mask, to_mask):
  """Build a dense attention mask as the per-example outer product of masks.

  Args:
    from_mask: float32 Tensor of shape [batch_size, from_seq_length].
    to_mask: float32 Tensor of shape [batch_size, to_seq_length].
  Returns:
    float32 Tensor of shape [batch_size, 1, from_seq_length, to_seq_length].
  """
  # [B, F] x [B, T] -> [B, F, T], then a singleton heads axis -> [B, 1, F, T].
  return tf.expand_dims(tf.einsum("BF,BT->BFT", from_mask, to_mask), 1)


def bigbird_block_sparse_attention(query_layer,
                                   key_layer,
                                   value_layer,
                                   band_mask,
                                   from_mask,
                                   to_mask,
                                   from_blocked_mask,
                                   to_blocked_mask,
                                   rand_attn,
                                   num_attention_heads,
                                   size_per_head,
                                   num_rand_blocks,
                                   from_seq_length,
                                   to_seq_length,
                                   from_block_size,
                                   to_block_size):
  """BigBird attention sparse calculation using blocks in linear time.

  Assumes from_seq_length//from_block_size == to_seq_length//to_block_size.
  A pure function with a long argument list to allow easy use outside our
  framework.

  Query blocks are handled in five groups, each with its own key set:
    q[0]    : global row, attends to every key.
    q[1]    : key blocks 0, 1, 2, the last block, and its random blocks.
    q[2:-2] : 3-block sliding window, random blocks, first and last blocks.
    q[-2]   : key blocks 0, -3, -2, -1, and its random blocks.
    q[-1]   : global row, attends to every key.
  Masked positions receive an additive -10000.0 bias before the softmax.

  Args:
    query_layer: float Tensor of shape [batch_size, num_attention_heads,
      from_seq_length, size_per_head]
    key_layer: float Tensor of shape [batch_size, num_attention_heads,
      to_seq_length, size_per_head]
    value_layer: float Tensor of shape [batch_size, num_attention_heads,
      to_seq_length, size_per_head]
    band_mask: float32 Tensor of shape [batch_size, 1,
      from_seq_length//from_block_size-4, from_block_size, 3*to_block_size].
      The values should be 1 or 0. The attention scores will effectively be
      set to -infinity for any positions in the mask that are 0, and will be
      unchanged for positions that are 1.
    from_mask: float32 Tensor of shape [batch_size, 1, from_seq_length, 1].
      The values should be 1 or 0. The attention scores will effectively be set
      to -infinity for any positions in the mask that are 0, and will be
      unchanged for positions that are 1.
    to_mask: float32 Tensor of shape [batch_size, 1, 1, to_seq_length].
      The values should be 1 or 0. The attention scores will effectively be set
      to -infinity for any positions in the mask that are 0, and will be
      unchanged for positions that are 1.
    from_blocked_mask: float32 Tensor of shape [batch_size,
      from_seq_length//from_block_size, from_block_size].
      Same as from_mask, just reshaped.
    to_blocked_mask: float32 Tensor of shape [batch_size,
      to_seq_length//to_block_size, to_block_size].
      Same as to_mask, just reshaped.
    rand_attn: int32 Tensor of shape [num_attention_heads,
      from_seq_length//from_block_size-2, num_rand_blocks] specifying which
      blocks to attend to for each from sequence block (except 2 global ones).
    num_attention_heads: int. Number of attention heads.
    size_per_head: int. Size of each attention head.
    num_rand_blocks: int. Number of random chunks per row.
    from_seq_length: int. length of from sequence.
    to_seq_length: int. length of to sequence.
    from_block_size: int. size of block in from sequence.
    to_block_size: int. size of block in to sequence.
  Returns:
    float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
      size_per_head].
  """
  # Every block-wise query/key pairing below relies on equal block counts.
  assert from_seq_length//from_block_size == to_seq_length//to_block_size

  # repeat for batch size: [h, m//wm-2, r] -> [b, h, m//wm-2, r]
  batch_size = utils.get_shape_list(query_layer)[0]
  rand_attn = tf.expand_dims(rand_attn, 0)
  rand_attn = tf.repeat(rand_attn, batch_size, 0)

  # [b, h, m//wm-2, wm, r*wn] mask for the random key blocks.
  rand_mask = create_rand_mask_from_inputs(
      from_blocked_mask, to_blocked_mask, rand_attn,
      num_attention_heads, num_rand_blocks,
      from_seq_length, from_block_size)

  # Define shorthands
  # b = batch_size
  h = num_attention_heads
  r = num_rand_blocks
  d = size_per_head
  m = from_seq_length
  n = to_seq_length
  wm = from_block_size
  wn = to_block_size

  blocked_query_matrix = tf.reshape(query_layer, (-1, h, m // wm, wm, d))
  blocked_key_matrix = tf.reshape(key_layer, (-1, h, n // wn, wn, d))
  blocked_value_matrix = tf.reshape(value_layer, (-1, h, n // wn, wn, d))
  gathered_key = tf.reshape(
      tf.gather(blocked_key_matrix, rand_attn, batch_dims=2, name="gather_key"),
      (-1, h, m // wm - 2, r * wn, d))  # [b, h, m//wm-2, r, wn, -1] -> [b, h, m//wm-2, r*wn, -1]
  gathered_value = tf.reshape(
      tf.gather(
          blocked_value_matrix, rand_attn, batch_dims=2, name="gather_value"),
      (-1, h, m // wm - 2, r * wn, d))  # [b, h, m//wm-2, r, wn, -1] -> [b, h, m//wm-2, r*wn, -1]

  # q[0]: first query block is global and attends to every key.
  first_product = tf.einsum(
      "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 0],
      key_layer)  # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n]
  first_product = tf.multiply(first_product, 1.0 / np.sqrt(d))
  first_product += (1.0 - to_mask) * -10000.0
  first_attn_weights = tf.nn.softmax(first_product)  # [b, h, wm, n]
  first_context_layer = tf.einsum(
      "BHQK,BHKD->BHQD", first_attn_weights,
      value_layer)  # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1]
  first_context_layer = tf.expand_dims(first_context_layer, 2)

  # q[1]: key blocks 0, 1, 2, last, plus this row's random blocks.
  second_key_mat = tf.concat([
      blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, 1],
      blocked_key_matrix[:, :, 2], blocked_key_matrix[:, :, -1],
      gathered_key[:, :, 0]], 2)  # [b, h, (4+r)*wn, -1]
  second_value_mat = tf.concat([
      blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, 1],
      blocked_value_matrix[:, :, 2], blocked_value_matrix[:, :, -1],
      gathered_value[:, :, 0]], 2)  # [b, h, (4+r)*wn, -1]
  second_product = tf.einsum(
      "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, 1], second_key_mat
  )  # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
  # Padding mask for the 4 fixed key blocks; ones over the r*wn random keys.
  second_seq_pad = tf.concat([
      to_mask[:, :, :, :3 * wn], to_mask[:, :, :, -wn:],
      tf.ones_like(rand_mask[:, :1, 0, :1])], 3)
  # Ones over the 4 fixed key blocks; random-key mask over the rest.
  second_rand_pad = tf.concat(
      [tf.ones_like(second_product[:, :, :, :4 * wn]), rand_mask[:, :, 0]], 3)
  second_product = tf.multiply(second_product, 1.0 / np.sqrt(d))
  # For 0/1 masks the elementwise minimum is a logical AND of both masks.
  second_product += (1.0 -
                     tf.minimum(second_seq_pad, second_rand_pad)) * -10000.0
  second_attn_weights = tf.nn.softmax(second_product)  # [b , h, wm, (4+r)*wn]
  second_context_layer = tf.einsum(
      "BHQK,BHKD->BHQD", second_attn_weights, second_value_mat
  )  # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1]
  second_context_layer = tf.expand_dims(second_context_layer, 2)

  # q[2:-2]: sliding window of key blocks k-1, k, k+1 stacked along keys.
  exp_blocked_key_matrix = tf.concat([
      blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2],
      blocked_key_matrix[:, :, 3:-1]], 3)  # [b, h, m//wm-4, 3*wn, -1]
  exp_blocked_value_matrix = tf.concat([
      blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2],
      blocked_value_matrix[:, :, 3:-1]], 3)  # [b, h, m//wm-4, 3*wn, -1]
  middle_query_matrix = blocked_query_matrix[:, :, 2:-2]
  inner_band_product = tf.einsum(
      "BHLQD,BHLKD->BHLQK", middle_query_matrix, exp_blocked_key_matrix
  )  # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, 3*wn, -1]
  #     ==> [b, h, m//wm-4, wm, 3*wn]
  inner_band_product = tf.multiply(inner_band_product, 1.0 / np.sqrt(d))
  # Random rows 1:-1 of gathered_key correspond to query blocks 2:-2.
  rand_band_product = tf.einsum(
      "BHLQD,BHLKD->BHLQK", middle_query_matrix, gathered_key[:, :, 1:-1]
  )  # [b, h, m//wm-4, wm, -1] x [b, h, m//wm-4, r*wn, -1]
  #     ==> [b, h, m//wm-4, wm, r*wn]
  rand_band_product = tf.multiply(rand_band_product, 1.0 / np.sqrt(d))
  first_band_product = tf.einsum(
      "BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, 0]
  )  # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn]
  first_band_product = tf.multiply(first_band_product, 1.0 / np.sqrt(d))
  last_band_product = tf.einsum(
      "BHLQD,BHKD->BHLQK", middle_query_matrix, blocked_key_matrix[:, :, -1]
  )  # [b, h, m//wm-4, wm, -1] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, wn]
  last_band_product = tf.multiply(last_band_product, 1.0 / np.sqrt(d))
  inner_band_product += (1.0 - band_mask) * -10000.0
  first_band_product += (
      1.0 - tf.expand_dims(to_mask[:, :, :, :wn], 3)) * -10000.0
  last_band_product += (
      1.0 - tf.expand_dims(to_mask[:, :, :, -wn:], 3)) * -10000.0
  rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * -10000.0
  # One softmax over all key groups: [first | window | random | last].
  band_product = tf.concat([
      first_band_product, inner_band_product, rand_band_product,
      last_band_product], -1)  # [b, h, m//wm-4, wm, (5+r)*wn]
  attn_weights = tf.nn.softmax(band_product)  # [b, h, m//wm-4, wm, (5+r)*wn]
  # Slice the weights back into their key groups and apply each to its values.
  context_layer = tf.einsum(
      "BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :, wn:4 * wn],
      exp_blocked_value_matrix
  )  # [b, h, m//wm-4, wm, 3*wn] x [b, h, m//wm-4, 3*wn, -1]
  #     ==> [b, h, m//wm-4, wm, -1]
  context_layer += tf.einsum(
      "BHLQK,BHLKD->BHLQD", attn_weights[:, :, :, :, 4 * wn:-wn],
      gathered_value[:, :, 1:-1]
  )  # [b, h, m//wm-4, wm, r*wn] x [b, h, m//wm-4, r*wn, -1]
  #     ==> [b, h, m//wm-4, wm, -1]
  context_layer += tf.einsum(
      "BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, :wn],
      blocked_value_matrix[:, :, 0]
  )  # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1]
  context_layer += tf.einsum(
      "BHLQK,BHKD->BHLQD", attn_weights[:, :, :, :, -wn:],
      blocked_value_matrix[:, :, -1]
  )  # [b, h, m//wm-4, wm, wn] x [b, h, wn, -1] ==> [b, h, m//wm-4, wm, -1]

  # q[-2]: key blocks 0, -3, -2, -1, plus this row's random blocks.
  second_last_key_mat = tf.concat([
      blocked_key_matrix[:, :, 0], blocked_key_matrix[:, :, -3],
      blocked_key_matrix[:, :, -2], blocked_key_matrix[:, :, -1],
      gathered_key[:, :, -1]], 2)  # [b, h, (4+r)*wn, -1]
  second_last_value_mat = tf.concat([
      blocked_value_matrix[:, :, 0], blocked_value_matrix[:, :, -3],
      blocked_value_matrix[:, :, -2], blocked_value_matrix[:, :, -1],
      gathered_value[:, :, -1]], 2)  # [b, h, (4+r)*wn, -1]
  second_last_product = tf.einsum(
      "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -2], second_last_key_mat
  )  # [b, h, wm, -1] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, (4+r)*wn]
  second_last_seq_pad = tf.concat([
      to_mask[:, :, :, :wn], to_mask[:, :, :, -3 * wn:],
      tf.ones_like(rand_mask[:, :1, 0, :1])], 3)
  second_last_rand_pad = tf.concat(
      [tf.ones_like(second_last_product[:, :, :, :4 * wn]),
       rand_mask[:, :, -1]], 3)
  second_last_product = tf.multiply(second_last_product, 1.0 / np.sqrt(d))
  second_last_product += (
      1.0 - tf.minimum(second_last_seq_pad, second_last_rand_pad)) * -10000.0
  second_last_attn_weights = tf.nn.softmax(
      second_last_product)  # [b, h, wm, (4+r)*wn]
  second_last_context_layer = tf.einsum(
      "BHQK,BHKD->BHQD", second_last_attn_weights, second_last_value_mat
  )  # [b, h, wm, (4+r)*wn] x [b, h, (4+r)*wn, -1] ==> [b, h, wm, -1]
  second_last_context_layer = tf.expand_dims(second_last_context_layer, 2)

  # q[-1]: last query block is global and attends to every key.
  last_product = tf.einsum(
      "BHQD,BHKD->BHQK", blocked_query_matrix[:, :, -1],
      key_layer)  # [b, h, wm, -1] x [b, h, n, -1] ==> [b, h, wm, n]
  last_product = tf.multiply(last_product, 1.0 / np.sqrt(d))
  last_product += (1.0 - to_mask) * -10000.0
  last_attn_weights = tf.nn.softmax(last_product)  # [b, h, wm, n]
  last_context_layer = tf.einsum(
      "BHQK,BHKD->BHQD", last_attn_weights,
      value_layer)  # [b, h, wm, n] x [b, h, n, -1] ==> [b, h, wm, -1]
  last_context_layer = tf.expand_dims(last_context_layer, 2)

  # Reassemble query blocks in order: [b, h, m//wm, wm, -1] -> [b, h, m, -1],
  # zero out padded query positions, then move heads after the sequence axis.
  context_layer = tf.concat([
      first_context_layer, second_context_layer, context_layer,
      second_last_context_layer, last_context_layer
  ], 2)
  context_layer = tf.reshape(context_layer, (-1, h, m, d)) * from_mask
  context_layer = tf.transpose(context_layer, (0, 2, 1, 3))
  return context_layer


class MultiHeadedAttentionLayer(tf.keras.layers.Layer):
  """A multi-headed attention layer.

  It implements following types of multi-headed attention:
  - original_full attention from "Attention is all you Need".
  - simulated_sparse attention from BigBird with full quadratic implementation.
  - block_sparse attention from BigBird with memory efficient linear impl.
  """

  def __init__(self,
               attention_type,
               num_attention_heads=1,
               size_per_head=512,
               num_rand_blocks=3,
               from_seq_length=1024,
               to_seq_length=1024,
               from_block_size=64,
               to_block_size=64,
               attention_probs_dropout_prob=0.0,
               initializer_range=0.02,
               use_bias=True,
               seed=None,
               query_act=None,
               key_act=None,
               value_act=None,
               name=None):
    """Constructor for a multi-headed attention layer.

    Args:
      attention_type: Type of attention, needs to be one of ['original_full',
        'simulated_sparse', 'block_sparse'].
      num_attention_heads: (optional) int. Number of attention heads.
      size_per_head: (optional) int. Size of each attention head.
      num_rand_blocks: (optional) int. Number of random chunks per row.
      from_seq_length: int. (optional) length of from sequence.
      to_seq_length: int. (optional) length of to sequence.
      from_block_size: (optional) int. size of block in from sequence.
      to_block_size: (optional) int. size of block in to sequence.
      attention_probs_dropout_prob: (optional) float. Dropout probability of the
        attention probabilities. Only 'original_full' attention applies it;
        'simulated_sparse' uses an identity function and 'block_sparse' uses
        no dropout.
      initializer_range: (optional) float. Range of the weight initializer.
      use_bias: Whether the layer uses a bias vector.
      seed: (Optional) int. Random seed for generating random mask. When set,
        numpy's global random generator is reseeded with it.
      query_act: (optional) Activation function for the query transform.
      key_act: (optional) Activation function for the key transform.
      value_act: (optional) Activation function for the value transform.
      name: The name scope of this layer.

    Raises:
      NotImplementedError: If `attention_type` is not a supported type.
      AssertionError: For 'block_sparse', if the from and to sequences do not
        split into the same number of blocks.
    """
    # Reject unsupported types before any sub-layers (and their variables)
    # get constructed.
    if attention_type not in ("original_full", "simulated_sparse",
                              "block_sparse"):
      raise NotImplementedError(
          "Attention type {} is not implemented".format(attention_type))

    super(MultiHeadedAttentionLayer, self).__init__(name=name)
    self.num_attention_heads = num_attention_heads
    self.size_per_head = size_per_head
    self.num_rand_blocks = num_rand_blocks
    self.from_seq_length = from_seq_length
    self.to_seq_length = to_seq_length
    self.from_block_size = from_block_size
    self.to_block_size = to_block_size
    self.seed = seed

    with tf.compat.v1.variable_scope(name):
      self.query_layer = utils.Dense3dLayer(
          num_attention_heads, size_per_head,
          utils.create_initializer(initializer_range), query_act,
          "query", head_first=True, use_bias=use_bias)

      self.key_layer = utils.Dense3dLayer(
          num_attention_heads, size_per_head,
          utils.create_initializer(initializer_range), key_act,
          "key", head_first=True, use_bias=use_bias)

      self.value_layer = utils.Dense3dLayer(
          num_attention_heads, size_per_head,
          utils.create_initializer(initializer_range), value_act,
          "value", head_first=True, use_bias=use_bias)

    if attention_type == "original_full":
      logging.info("**** Using original full attention ****")
      self.attention_dropout = recompute_grad.RecomputingDropout(
          attention_probs_dropout_prob)
      self.attn_impl = self.original_full_attention
    elif attention_type == "simulated_sparse":
      logging.info("**** Using simulated sparse attention ****")
      self.attention_dropout = lambda x, training=None: x
      self.rand_attn = self.generate_rand_attn_list()
      self.rand_block_mask = self.convert_attn_list_to_mask(self.rand_attn)
      self.attn_impl = self.bigbird_simulated_attention
    else:  # "block_sparse" (validated above)
      logging.info("**** Using block sparse attention ****")
      assert from_seq_length//from_block_size == to_seq_length//to_block_size, (
          "Error the number of blocks needs to be same!")
      self.attention_dropout = None
      self.rand_attn = self.generate_rand_attn_list()
      self.attn_impl = self.bigbird_block_sparse_attention

  def generate_rand_attn_list(self):
    """Samples the random-block attention pattern for every head.

    If `self.seed` is set, numpy's global RNG is reseeded first. This makes the
    pattern reproducible but also affects other numpy random users.

    Returns:
      int32 constant Tensor that stacks the per-head random-block index arrays
      along axis 0.
    """
    # generate random attention and corresponding masks
    if self.seed is not None:
      np.random.seed(self.seed)
    # Old plans used in the paper: sample at MAX_SEQ_LEN and keep only the
    # first (num_from_blocks - 2) rows needed for this sequence length.
    if self.from_seq_length in [1024, 2048, 3072, 4096]:
      rand_attn = [
          bigbird_block_rand_mask(  # pylint: disable=g-complex-comprehension
              MAX_SEQ_LEN, MAX_SEQ_LEN,
              self.from_block_size, self.to_block_size, self.num_rand_blocks,
              last_idx=1024
          )[:(self.from_seq_length // self.from_block_size - 2)]
          for _ in range(self.num_attention_heads)
      ]
    else:
      plan_from_length, plan_num_rand_blocks = get_rand_attn_plan(
          self.from_seq_length, self.from_block_size, self.num_rand_blocks)
      rand_attn = bigbird_block_rand_mask_with_head(
          seq_length=self.from_seq_length,
          block_size=self.from_block_size,
          num_heads=self.num_attention_heads,
          plan_from_length=plan_from_length,
          plan_num_rand_blocks=plan_num_rand_blocks)
    rand_attn = np.stack(rand_attn, axis=0)
    return tf.constant(rand_attn, dtype=tf.int32)

  def convert_attn_list_to_mask(self, rand_attn):
    """Expands per-head random-block indices into dense attention masks.

    Args:
      rand_attn: per-head random attention indices, as returned by
        `generate_rand_attn_list`; indexed by head along axis 0.

    Returns:
      float32 Tensor of shape [N, F, T] (N heads, F from-length, T to-length)
      holding 1.0 where attention is allowed and 0.0 elsewhere.
    """
    temp_mask = [
        full_bigbird_mask(  # pylint: disable=g-complex-comprehension
            self.from_seq_length, self.to_seq_length,
            self.from_block_size, self.to_block_size,
            rand_attn=rand_attn[i])
        for i in range(self.num_attention_heads)
    ]
    temp_mask = np.stack(temp_mask, axis=0)
    # Round-trip through bool so every non-zero entry becomes exactly 1.0.
    temp_mask = np.array(temp_mask, dtype=bool)
    rand_block_mask = tf.constant(temp_mask, dtype=tf.bool)  # [N, F, T]
    return tf.cast(rand_block_mask, tf.float32)

  @staticmethod
  def _add_head_axis(mask):
    """Inserts a singleton head axis into a rank-3 [B, F, T] mask.

    Attention scores are [B, N, F, T]. A rank-3 [B, F, T] mask would
    right-align its batch axis with the head axis when broadcast, so it is
    reshaped to [B, 1, F, T]. Masks of any other (or unknown) rank are
    returned unchanged.

    Args:
      mask: Tensor or array-like mask.

    Returns:
      The mask, with an axis inserted at position 1 if it was rank 3.
    """
    shape = mask.shape
    rank = shape.rank if isinstance(shape, tf.TensorShape) else len(shape)
    if rank == 3:
      return tf.expand_dims(mask, 1)
    return mask

  def original_full_attention(self,
                              query_layer,
                              key_layer,
                              value_layer,
                              masks,
                              training=None):
    """Full quadratic attention calculation.

    Args:
      query_layer: float Tensor of shape [batch_size, num_attention_heads,
        from_seq_length, size_per_head]
      key_layer: float Tensor of shape [batch_size, num_attention_heads,
        to_seq_length, size_per_head]
      value_layer: float Tensor of shape [batch_size, num_attention_heads,
        to_seq_length, size_per_head]
      masks: a list whose first element is the attention mask, or None for no
        masking. It is a float32 Tensor broadcastable to [batch_size,
        num_attention_heads, from_seq_length, to_seq_length], for example
        [batch_size, 1, from_seq_length, to_seq_length]. A rank-3
        [batch_size, from_seq_length, to_seq_length] mask is also accepted and
        shared across heads. The values should be 1 or 0. The attention scores
        will effectively be set to -infinity for any positions in the mask
        that are 0, and will be unchanged for positions that are 1.
      training: Boolean indicating whether the call is training or inference.

    Returns:
      float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
        size_per_head].
    """
    attention_mask = masks[0]

    # Directly take n^2 dot product between "query" and "key".
    attention_scores = tf.einsum("BNFH,BNTH->BNFT", query_layer, key_layer)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / np.sqrt(float(self.size_per_head)))

    if attention_mask is not None:
      # Make a [B, F, T] mask broadcast over heads instead of over batch.
      attention_mask = self._add_head_axis(attention_mask)

      # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
      # masked positions, this operation will create a tensor which is 0.0 for
      # positions we want to attend and -10000.0 for masked positions.
      adder = (1.0 - attention_mask) * -10000.0

      # Since we are adding it to the raw scores before the softmax, this is
      # effectively the same as removing these entirely.
      attention_scores += adder

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self.attention_dropout(attention_probs, training=training)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.einsum("BNFT,BNTH->BFNH", attention_probs, value_layer)
    return context_layer

  def bigbird_simulated_attention(self,
                                  query_layer,
                                  key_layer,
                                  value_layer,
                                  masks,
                                  training=None):
    """BigBird attention calculation using masks in quadratic time.

    The precomputed per-head BigBird pattern (`self.rand_block_mask`) is
    combined with the caller's mask by element-wise minimum. Full attention is
    then run with the combined mask.

    Args:
      query_layer: float Tensor of shape [batch_size, num_attention_heads,
        from_seq_length, size_per_head]
      key_layer: float Tensor of shape [batch_size, num_attention_heads,
        to_seq_length, size_per_head]
      value_layer: float Tensor of shape [batch_size, num_attention_heads,
        to_seq_length, size_per_head]
      masks: a list whose first element is the attention mask, or None. It is a
        float32 Tensor broadcastable to [batch_size, num_attention_heads,
        from_seq_length, to_seq_length]. A rank-3 [batch_size,
        from_seq_length, to_seq_length] mask is also accepted. The values
        should be 1 or 0. The attention scores will effectively be set to
        -infinity for any positions in the mask that are 0, and will be
        unchanged for positions that are 1.
      training: Boolean indicating whether the call is training or inference.

    Returns:
      float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
        size_per_head].
    """
    attention_mask = masks[0]
    rand_block_mask = tf.expand_dims(self.rand_block_mask, 0)  # [1, N, F, T]
    if attention_mask is not None:
      # [B, 1, F, T] x [1, N, F, T] -> [B, N, F, T]
      attention_mask = tf.minimum(self._add_head_axis(attention_mask),
                                  rand_block_mask)
    else:
      attention_mask = rand_block_mask
    return self.original_full_attention(
        query_layer, key_layer, value_layer, [attention_mask],
        training=training)

  def bigbird_block_sparse_attention(self,
                                     query_layer,
                                     key_layer,
                                     value_layer,
                                     masks,
                                     training=None):
    """BigBird attention sparse calculation using blocks in linear time.

    Args:
      query_layer: float Tensor of shape [batch_size, num_attention_heads,
        from_seq_length, size_per_head]
      key_layer: float Tensor of shape [batch_size, num_attention_heads,
        to_seq_length, size_per_head]
      value_layer: float Tensor of shape [batch_size, num_attention_heads,
        to_seq_length, size_per_head]
      masks: A list of 6 entries. Position 0 (first element) is not used and
        can be None; positions 1 to 5 hold the BigBird masks listed below. In
        the mask, the values should be 1 or 0. The attention scores will
        effectively be set to -infinity for any positions in the mask that are
        0, and will be unchanged for positions that are 1.
            "None": Not needed.
            "band_mask": (optional) float32 Tensor of shape
              [batch_size, 1, from_seq_length//from_block_size-4,
              from_block_size, 3*to_block_size].
            "from_mask": (optional) float32 Tensor of shape
              [batch_size, 1, from_seq_length, 1].
            "to_mask": (optional) float32 Tensor of shape
              [batch_size, 1, 1, to_seq_length].
            "from_blocked_mask": (optional) float32 Tensor of shape
              [batch_size, from_seq_length//from_block_size, from_block_size].
              Same as from_mask, just reshaped.
            "to_blocked_mask": (optional) float32 Tensor of shape
              [batch_size, to_seq_length//to_block_size, to_block_size].
              Same as to_mask, just reshaped.
      training: Boolean indicating whether the call is training or inference.
        Unused here; block-sparse attention applies no dropout.

    Returns:
      float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
        size_per_head].
    """

    (_, band_mask, from_mask, to_mask,
     from_blocked_mask, to_blocked_mask) = masks

    return bigbird_block_sparse_attention(
        query_layer, key_layer, value_layer,
        band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask,
        self.rand_attn, self.num_attention_heads, self.size_per_head,
        self.num_rand_blocks, self.from_seq_length, self.to_seq_length,
        self.from_block_size, self.to_block_size)

  def call(self,
           from_tensor,
           to_tensor,
           masks,
           cache=None,
           decode_i=None,
           training=None):
    """Implements a multi-headed attention layer from from_tensor to to_tensor.

    Args:
      from_tensor: float Tensor of shape [batch_size, from_seq_length,
        from_width]
      to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
      masks: A list of masks used in different attention. Only relevant masks
        need to be supplied and at other positions place None. In the mask,
        the values should be 1 or 0. The attention scores will effectively
        be set to -infinity for any positions in the mask that are 0,
        and will be unchanged for positions that are 1.
            "attention_mask": (optional) float32 Tensor broadcastable to
              [batch_size, num_attention_heads, from_seq_length,
              to_seq_length]; a rank-3 [batch_size, from_seq_length,
              to_seq_length] mask is also accepted.
            "band_mask": (optional) float32 Tensor of shape
              [batch_size, 1, from_seq_length//from_block_size-4,
              from_block_size, 3*to_block_size].
            "from_mask": (optional) float32 Tensor of shape
              [batch_size, 1, from_seq_length, 1].
            "to_mask": (optional) float32 Tensor of shape
              [batch_size, 1, 1, to_seq_length].
            "from_blocked_mask": (optional) float32 Tensor of shape
              [batch_size, from_seq_length//from_block_size, from_block_size].
              Same as from_mask, just reshaped.
            "to_blocked_mask": (optional) float32 Tensor of shape
              [batch_size, to_seq_length//to_block_size, to_block_size].
              Same as to_mask, just reshaped.
      cache: (Used during prediction) A dictionary with tensors containing
        results of previous attentions. It is updated in place with the new
        keys/values. The dictionary must have the items:
            {"k": tensor with shape
                  [batch_size, num_attention_heads, max_len, size_per_head],
             "v": tensor with shape
                  [batch_size, num_attention_heads, max_len, size_per_head]}
      decode_i: (Used during prediction) current location of decoding. The
        new key/value is written into slot `decode_i` of the cache.
      training: Boolean indicating whether the call is training or inference.

    Returns:
      float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
        size_per_head].

    Raises:
      ValueError: Any of the arguments or tensor shapes are invalid.
      NotImplementedError: For unknown attention type.
    """

    # Scalar dimensions referenced here:
    #   b = batch size (number of sequences)
    #   m = `from_tensor` sequence length
    #   n = `to_tensor` sequence length
    #   h = `num_attention_heads`
    #   d = `size_per_head`

    # `query` = [b, h, m, d]
    query = self.query_layer(from_tensor)

    # `key` = [b, h, n, d]
    key = self.key_layer(to_tensor)

    # `value_layer` = [b, h, n, d]
    value = self.value_layer(to_tensor)

    if cache is not None and decode_i is not None:
      # Write this step's key/value into slot `decode_i` of the cache. The
      # one-hot selector is [1, 1, max_len, 1], so it broadcasts against the
      # cached [b, h, max_len, d] tensors.
      max_len = utils.get_shape_list(cache["k"])[2]
      indices_select = tf.reshape(
          tf.one_hot(decode_i, max_len, dtype=to_tensor.dtype),
          [1, 1, max_len, 1])
      key = cache["k"] + key * indices_select
      value = cache["v"] + value * indices_select
      cache["k"] = key
      cache["v"] = value

    contextual_output = self.attn_impl(
        query, key, value, masks, training=training)

    return contextual_output


In [15]:
####################### modeling #########################
# Copyright 2021 The BigBird Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The main BigBird model and related functions."""

import copy

from absl import logging
from bigbird.core import decoder
from bigbird.core import encoder
from bigbird.core import utils
import tensorflow.compat.v2 as tf


class BertModel(tf.keras.layers.Layer):
  """BERT model ("Bidirectional Encoder Representations from Transformers").

  Example usage:

  ```python
  # Already been converted into SentencePiece token ids
  input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
  token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
  params = utils.BigBirdConfig(vocab_size=32000, hidden_size=512,
    num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
  model = modeling.BertModel(params)
  _, pooled_output = model(input_ids=input_ids, token_type_ids=token_type_ids,
                           training=True)
  label_embeddings = tf.compat.v1.get_variable(...)
  logits = tf.matmul(pooled_output, label_embeddings)
  ...
  ```
  """

  def __init__(self, params):
    """Constructor for BertModel.

    Args:
      params: `BigBirdConfig` dictionary. It is deep-copied. Adjustments made
        here (switching to full attention, rounding up max_encoder_length) are
        stored only on `self.params`.
    """
    self.params = copy.deepcopy(params)
    self.scope = params["scope"]
    super(BertModel, self).__init__(name=self.scope)

    # validate params
    # `self.pad` is the identity unless the sparse-attention path needs the
    # sequence length rounded up to a multiple of block_size.
    self.pad = lambda x: x
    if params["max_encoder_length"] <= 512:
      logging.info("Switching to full attention for short sequences")
      self.params["attention_type"] = "original_full"
    if self.params["attention_type"] == "simulated_sparse" or self.params[
        "attention_type"] == "block_sparse":
      if params["max_encoder_length"] % params["block_size"]:
        logging.info("Expand max_encoder_length to next multiple of block_size")
        self.params["max_encoder_length"] = (
            params["max_encoder_length"] // params["block_size"] +
            1) * params["block_size"]
        pad_size = self.params["max_encoder_length"] - params[
            "max_encoder_length"]
        # Right-pad the sequence axis with id 0. `call` treats id 0 as padding
        # when building the input mask.
        paddings = [[0, 0], [0, pad_size]]
        self.pad = lambda x: tf.pad(x, paddings)

    with tf.compat.v1.variable_scope(self.scope, reuse=tf.compat.v1.AUTO_REUSE):
      self.embeder = utils.EmbeddingLayer(
          vocab_size=self.params["vocab_size"],
          emb_dim=self.params["hidden_size"],
          initializer=utils.create_initializer(
              self.params["initializer_range"]),
          scale_emb=self.params["rescale_embedding"],
          use_token_type=True,
          num_token_types=self.params["type_vocab_size"],
          use_position_embeddings=True,
          max_position_embeddings=self.params["max_position_embeddings"],
          dropout_prob=self.params["hidden_dropout_prob"])
      self.encoder = encoder.EncoderStack(self.params)
      self.pooler = utils.SimpleDenseLayer(
          input_size=self.params["hidden_size"],
          output_size=self.params["hidden_size"],
          initializer=utils.create_initializer(
              self.params["initializer_range"]),
          activation=tf.tanh,
          name="pooler/dense")

  def call(self,
           input_ids,
           token_type_ids=None,
           training=None):
    """Runs embedding, the encoder stack and the pooler on a batch of ids.

    Args:
      input_ids: int32 Tensor of shape [batch_size, seq_length]. Id 0 is
        treated as padding.
      token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
        Defaults to all zeros.
      training: Boolean indicating whether the call is training or inference.

    Returns:
      sequence_output: Tensor of shape [batch_size, seq_length, hidden_size].
        When `self.pad` extends the input, seq_length is the padded length.
      pooled_output: Tensor of shape [batch_size, hidden_size]

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
    # pad if needed
    input_ids = self.pad(input_ids)

    if token_type_ids is None:
      token_type_ids = tf.zeros_like(input_ids, dtype=tf.int32)
    else:
      token_type_ids = self.pad(token_type_ids)

    # Perform embedding lookup on the word ids.
    embedding_output = self.embeder(input_ids,
                                    self.params["max_encoder_length"],
                                    token_type_ids=token_type_ids,
                                    training=training)

    # Generate mask: 1 for real tokens (id > 0), 0 for padding (id 0).
    input_mask = tf.cast(input_ids > 0, input_ids.dtype)

    # Run the stacked transformer.
    sequence_output = self.encoder(embedding_output, input_mask, training)

    # The "pooler" converts the encoded sequence tensor of shape
    # [batch_size, seq_length, hidden_size] to a tensor of shape
    # [batch_size, hidden_size]. This is necessary for segment-level
    # (or segment-pair-level) classification tasks where we need a fixed
    # dimensional representation of the segment.
    first_token_tensor = sequence_output[:, 0, :]
    # We "pool" the model by simply taking the hidden state corresponding
    # to the first token. We assume that this has been pre-trained
    pooled_output = self.pooler(first_token_tensor)

    return sequence_output, pooled_output


class TransformerModel(tf.keras.layers.Layer):
  """Encoder-Decoder transformer model.

  Example usage:

  ```python
  # Already been converted into SentencePiece token ids
  input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
  target_ids = tf.constant([[43, 76, 38], [56, 8, 0]])
  params = utils.BigBirdConfig(vocab_size=32000, hidden_size=512,
    num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
  model = modeling.TransformerModel(params)
  predictions, _ = model(input_ids=input_ids, target_ids=target_ids,
                         training=True)
  log_probs, logits, pred_ids = predictions
  ...
  ```
  """

  def __init__(self, params):
    """Constructor for TransformerModel.

    Args:
      params: `BigBirdConfig` dictionary. It is deep-copied. Adjustments made
        here (switching to full attention, rounding up max_encoder_length) are
        stored only on `self.params`.
    """
    self.params = copy.deepcopy(params)
    self.scope = params["scope"]
    super(TransformerModel, self).__init__(name=self.scope)

    # validate params
    # `self.pad` is the identity unless the sparse-attention path needs the
    # encoder length rounded up to a multiple of block_size.
    self.pad = lambda x: x
    if params["max_encoder_length"] <= 512:
      logging.info("Switching to full attention for short sequences")
      self.params["attention_type"] = "original_full"
    if self.params["attention_type"] == "simulated_sparse" or self.params[
        "attention_type"] == "block_sparse":
      if params["max_encoder_length"] % params["block_size"]:
        logging.info("Expand max_encoder_length to next multiple of block_size")
        self.params["max_encoder_length"] = (
            params["max_encoder_length"] // params["block_size"] +
            1) * params["block_size"]
        pad_size = self.params["max_encoder_length"] - params[
            "max_encoder_length"]
        # Right-pad the sequence axis with id 0, which `_encode` masks out.
        paddings = [[0, 0], [0, pad_size]]
        self.pad = lambda x: tf.pad(x, paddings)

    with tf.compat.v1.variable_scope(self.scope, reuse=tf.compat.v1.AUTO_REUSE):
      self.embeder = utils.EmbeddingLayer(
          vocab_size=self.params["vocab_size"],
          emb_dim=self.params["hidden_size"],
          initializer=utils.create_initializer(
              self.params["initializer_range"]),
          scale_emb=self.params["rescale_embedding"],
          use_token_type=False,
          num_token_types=None,
          use_position_embeddings=True,
          max_position_embeddings=self.params["max_position_embeddings"],
          dropout_prob=self.params["hidden_dropout_prob"])
      self.encoder = encoder.EncoderStack(self.params)
      self.decoder = decoder.DecoderStack(self.params)

  def _encode(self, input_ids, training=None):
    """Generate continuous representation for ids.

    Args:
      input_ids: Int tensor with shape [batch_size, input_length]. Id 0 is
        treated as padding.
      training: Boolean indicating whether the call is training or inference.

    Returns:
      A tuple of:
        the encoder output, a float tensor of shape
          [batch_size, input_length, hidden_size];
        the input mask, shape [batch_size, input_length], same dtype as
          `input_ids`, with 1 for real tokens and 0 for padding.
    """
    # pad if needed
    input_ids = self.pad(input_ids)

    # Perform embedding lookup on the word ids.
    input_embs = self.embeder(
        input_ids, self.params["max_encoder_length"], training=training)

    # Generate mask.
    input_mask = tf.where(input_ids > 0,
                          tf.ones_like(input_ids), tf.zeros_like(input_ids))

    # Run the stacked transformer.
    encoder_output = self.encoder(input_embs, input_mask, training=training)

    return encoder_output, input_mask

  def _get_start_token_ids(self, tensor_for_shape):
    """Returns an int32 [batch_size] tensor filled with the start token id 2."""
    start_token_id = 2
    batch_size = utils.get_shape_list(tensor_for_shape)[0]
    return tf.ones([batch_size], dtype=tf.int32) * start_token_id

  def get_inputs_from_targets(self, targets, start_token_ids):
    """Converts target ids to input ids, i.e. adds <s> and removes last.

    Args:
      targets: int tensor [batch_size, max_decoder_length]; 0 marks padding.
      start_token_ids: int32 tensor [batch_size].

    Returns:
      int tensor [batch_size, max_decoder_length]: each row is the start token
      followed by the targets shifted right by one. The last non-zero target
      (presumably </s>) is zeroed out.
    """
    length = tf.math.count_nonzero(targets, axis=1, dtype=tf.int32)
    # Add start token ids.
    inputs = tf.concat([tf.expand_dims(start_token_ids, axis=1), targets], 1)
    # Remove </s> from the input.
    mask = tf.sequence_mask(length, self.params["max_decoder_length"]+1,
                            dtype=tf.int32)
    inputs = (mask * inputs)[:, :-1]
    return inputs

  def _decode(self, target_ids, target_mask, start_token_ids,
              encoder_output, encoder_mask, training=None):
    """Compute likelihood of target tokens under the model.

    Args:
      target_ids: int tensor with shape [batch_size, target_length] (token ids;
        0 is padding). target_length must equal max_decoder_length.
      target_mask: causal self-attention mask for the decoder, from
        `decoder.create_self_attention_mask` (presumably
        [1, 1, max_decoder_length, max_decoder_length]).
      start_token_ids: int32 tensor of shape [batch_size] for first decoder
        input.
      encoder_output: Continuous representation of input sequence. Float tensor
        with shape [batch_size, input_length, hidden_size].
      encoder_mask: mask from `_encode`, shape [batch_size, input_length].
      training: Boolean indicating whether the call is training or inference.

    Returns:
      A tuple (log_probs, logits, pred_ids):
        log_probs: per-token log-probability of `target_ids`, 0 at padding.
        logits: output logits over the vocabulary.
        pred_ids: int32 argmax ids of `logits`.
    """

    # Prepare inputs to decoder layers by shifting targets, embedding ids,
    # adding positional encoding and applying dropout.
    input_ids = self.get_inputs_from_targets(target_ids, start_token_ids)

    input_embs = self.embeder(input_ids, self.params["max_decoder_length"],
                              training=training)

    outputs = self.decoder(input_embs, target_mask,
                           encoder_output, encoder_mask, training=training)

    logits = self.embeder.linear(outputs)
    output_ids = tf.cast(tf.argmax(logits, axis=-1), tf.int32)

    log_probs = -tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=target_ids, logits=logits)
    # Padding positions contribute nothing to the sequence likelihood.
    log_probs = tf.where(target_ids > 0, log_probs,
                         tf.zeros_like(log_probs, tf.float32))

    return (tf.identity(log_probs, name="log_probs"),
            tf.identity(logits, name="logits"),
            tf.cast(output_ids, tf.int32, name="pred_ids"),)

  def _init_cache(self, batch_size):
    """Initialize cache for decoding.

    Returns:
      dict mapping "layer_<i>" to {"k", "v"} zero tensors of shape
      [batch_size, num_attention_heads, max_decoder_length, head_size].
    """

    max_decode_len = self.params["max_decoder_length"]
    num_heads = self.params["num_attention_heads"]
    head_size = self.params["hidden_size"] // num_heads

    cache = {}
    for layer in range(self.params["num_hidden_layers"]):
      cache["layer_%d" % layer] = {
          "k": tf.zeros([batch_size, num_heads, max_decode_len, head_size]),
          "v": tf.zeros([batch_size, num_heads, max_decode_len, head_size]),
      }
    return cache

  def _get_symbols_to_logits_fn(self, decoder_self_attention_mask):
    """Returns a decoding function that calculates logits of the next tokens."""

    max_decode_len = self.params["max_decoder_length"]

    def _symbols_to_logits_fn(target_ids, cache, i):
      """Generate logits for next candidate IDs.

      Args:
        target_ids: Current decoded sequences. int tensor of shape
          [rows, decode_length]; the token at position max(i - 1, 0) is fed
          to the decoder.
        cache: dictionary of values storing the encoder output, encoder-decoder
          attention bias, and previous decoder attention values. It is passed
          to the decoder, which is expected to update it in place.
        i: Loop index (scalar int tensor).

      Returns:
        logits for the next token, shape [rows, vocab_size]. The cache is not
        returned.
      """
      # Size -1 keeps every row, so a non-static batch dimension also works.
      decoder_input = tf.slice(target_ids,
                               [0, tf.maximum(tf.cast(0, i.dtype), i - 1)],
                               [-1, 1])
      self_attention_mask = tf.slice(decoder_self_attention_mask, [0, 0, i, 0],
                                     [1, 1, 1, max_decode_len])

      # Preprocess decoder input by getting embeddings and adding timing signal.
      decoder_input = self.embeder(
          decoder_input, 1, start_pos=i, training=False)

      decoder_output = self.decoder(
          decoder_input, self_attention_mask,
          cache.get("encoder_output"), cache.get("encoder_mask"),
          cache=cache, decode_i=i, training=False)

      logits = self.embeder.linear(decoder_output)
      logits = tf.squeeze(logits, axis=[1])

      return logits

    return _symbols_to_logits_fn

  def _predict(self, target_ids, target_mask, start_token_ids,
               encoder_output, encoder_mask):
    """Beam decode output tokens and probabilities.

    Args:
      target_ids: optional int tensor [batch_size, max_decoder_length]. If
        given, log-probs/logits are computed for it instead of the decoded ids.
      target_mask: causal self-attention mask for the decoder.
      start_token_ids: int32 tensor of shape [batch_size] for first decoder
        input.
      encoder_output: Continuous representation of input sequence. Float
        tensor with shape [batch_size, input_length, hidden_size].
      encoder_mask: mask from `_encode`, shape [batch_size, input_length].

    Returns:
      A tuple of:
        `log_probs`: Log-probs of output tokens.
        `logits`: Logits of output tokens.
        `pred_ids`: Predicted output sequence.
    """
    batch_size = utils.get_shape_list(start_token_ids)[0]
    end_token_id = 1

    # One step logit function.
    symbols_to_logits_fn = self._get_symbols_to_logits_fn(target_mask)

    # Create cache storing decoder attention values for each layer.
    cache = self._init_cache(batch_size)

    if encoder_output is not None:
      # Add encoder output and attention bias to the cache.
      cache["encoder_output"] = encoder_output
      cache["encoder_mask"] = encoder_mask

    # NOTE(review): beam_start/beam_min/beam_max are fixed here; their exact
    # meaning is defined by `decoder.left2right_decode` (not visible here).
    decoded_ids = decoder.left2right_decode(
        symbols_to_logits_fn,
        start_token_ids,
        cache,
        batch_size,
        self.params["max_decoder_length"],
        vocab_size=self.params["vocab_size"],
        beam_size=self.params["beam_size"],
        beam_start=5,
        beam_alpha=self.params["alpha"],
        beam_min=0,
        beam_max=-1,
        eos_id=end_token_id)

    # Get the top sequence for each batch element
    output_ids = tf.cast(decoded_ids, tf.int32, name="pred_ids")

    # Calculate log probs for given sequence if available.
    calc_ids = output_ids if target_ids is None else target_ids
    output_log_probs, output_logits, _ = self._decode(
        calc_ids, target_mask, start_token_ids,
        encoder_output, encoder_mask, training=False)

    return (output_log_probs, output_logits, output_ids)

  def _decode_and_predict(self, target_ids, encoder_output, encoder_mask,
                          training=None):
    """Decodes a sequence given the input and the encoder.

    Teacher-forced decoding is used when `training` is truthy; otherwise beam
    search is used.

    Args:
      target_ids: int tensor [batch_size, max_decoder_length] of target token
        ids. Required for training; optional at inference.
      encoder_output: Continuous representation of input sequence. Float
        tensor with shape [batch_size, input_length, hidden_size].
      encoder_mask: mask from `_encode`, shape [batch_size, input_length].
      training: Boolean indicating whether the call is training or inference.

    Returns:
      A tuple of:
        `log_probs`: Log-probs of output tokens.
        `logits`: Logits of output tokens.
        `pred_ids`: Predicted output sequence.
    """
    # Create initial set of IDs that will be passed into symbols_to_logits_fn.
    start_token_ids = self._get_start_token_ids(encoder_output)

    # Create causal self-attention mask for decoder.
    target_mask = decoder.create_self_attention_mask(
        self.params["max_decoder_length"])

    if training:
      return self._decode(target_ids, target_mask, start_token_ids,
                          encoder_output, encoder_mask, training=True)
    return self._predict(target_ids, target_mask, start_token_ids,
                         encoder_output, encoder_mask)

  def call(self,
           input_ids,
           target_ids=None,
           training=None):
    """Encodes `input_ids` and decodes (training) or beam-searches (inference).

    Args:
      input_ids: int tensor [batch_size, input_length]; 0 is padding.
      target_ids: (optional) int tensor [batch_size, max_decoder_length].
        Required when `training` is truthy.
      training: Boolean indicating whether the call is training or inference.

    Returns:
      A tuple (predictions, encoder_output). `predictions` is the
      (log_probs, logits, pred_ids) tuple from `_decode_and_predict`.
    """
    # Run the inputs through the encoder layer to map the symbol
    # representations to continuous representations.
    encoder_output, encoder_mask = self._encode(input_ids, training=training)

    # Decode.
    predictions = self._decode_and_predict(target_ids, encoder_output,
                                           encoder_mask, training=training)

    return predictions, encoder_output

In [17]:
#######################  beam search #######################

# Copyright 2021 The BigBird Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Beam search branched from Pegasus.
Original source:
https://github.com/google-research/pegasus/blob/master/pegasus/layers/beam_search.py
This beam search implementation is designed for TPU usage only and prefers
flexibility over efficiency. Transformer attention caching is not enabled yet.
Mostly follows the implementation in T2T. Several differences from pure beam search:
1. keeps finished and alive seqs and uses 2 * beam_size to grow alive seqs,
   which means beam_size=1 is not equivalent to greedy decoding.
2. prefers finished seq over alive seqs.
3. prefers lower indices when equal probability (though unlikely).
4. with custom length normalization and constraint.
Notations:
  B: batch_size, M: beam_size, T: max_decode_len, V: vocab_size, U: undefined
"""
# pylint: disable=invalid-name

import tensorflow.compat.v2 as tf


def length_normalization(start, alpha, min_len, max_len, out_of_range_penalty):
  r"""Create length normalization function.

  Combines length penalty from https://arxiv.org/abs/1609.08144,
  and length constraint from https://www.aclweb.org/anthology/W18-2706.pdf.
  scores = \sum_j log(P_j) / ((start + lengths)/(1 + start))**alpha
          + out_of_range_penalty * (length > max_len or length < min_len)

  Args:
    start: int, length normalization start offset.
    alpha: float, [0, 1.0], length normalization power.
    min_len: int, minimum decode length.
    max_len: int, maximum decode length. Values <= 0 disable the upper bound.
    out_of_range_penalty: float, penalty added to the scores of lengths
      outside [min_len, max_len]. Use a negative number to penalize
      out-of-range decodes. Setting it to -inf gives a hard constraint: in-range
      scores are left untouched rather than turned into NaN.
  Returns:
    fn(log_probs_BxM, length)->scores_BxM: a function to normalize sum log
    probabilities of sequence with current decoding lengths.
  """

  def length_norm_fn(log_probs_BxM, length_int):
    """Normalize sum log probabilities given a sequence length."""
    dtype = log_probs_BxM.dtype
    norm_flt = tf.pow(((start + tf.cast(length_int, dtype)) / (1. + start)),
                      alpha)
    log_probs_BxM /= norm_flt
    too_short_bool = tf.less(length_int, min_len)
    # `max_len > 0` lets callers pass max_len <= 0 for "no upper bound".
    too_long_bool = tf.logical_and(tf.greater(length_int, max_len), max_len > 0)
    out_of_range_bool = tf.logical_or(too_long_bool, too_short_bool)
    # Select the penalty instead of multiplying it by a 0/1 mask. With
    # out_of_range_penalty=-inf, `-inf * 0.` is NaN and would poison the
    # score of every in-range sequence.
    penalty = tf.where(out_of_range_bool,
                       tf.cast(out_of_range_penalty, dtype),
                       tf.zeros([], dtype=dtype))
    log_probs_BxM += penalty
    return log_probs_BxM

  return length_norm_fn


def beam_search(symbols_to_logits_fn,
                init_seq_BxT,
                initial_cache_BxU,
                vocab_size,
                beam_size,
                length_norm_fn,
                eos_id=1):
  """Beam search.

  Always runs exactly T decoding steps, where T = init_seq_BxT.shape[1]. There
  is no early exit once all beams have finished. At step i every candidate
  sequence gets one token written at position i.

  Args:
    symbols_to_logits_fn: fn(seq_BxT, cache_BxU, i) -> (logits_BxV, cache_BxU).
      It is called with the beam axis folded into the batch axis, i.e. on
      [B * M, T] sequences and [B * M, ...] cache tensors.
    init_seq_BxT: initial sequence ids. Its shape must be fully static ([B, T]
      is unpacked in Python). Its dtype is also used for the step counter `i`.
    initial_cache_BxU: dictionary of tensors with shape BxU.
    vocab_size: vocabulary size.
    beam_size: beam size.
    length_norm_fn: length normalization function, called as
      fn(log_probs_Bx2M, length) -> scores_Bx2M.
    eos_id: end of sequence.
  Returns:
    Tuple of (beams_BxMxT, scores_BxM). Beam searched sequences and scores.
    Each (batch, beam) slot independently holds the finished sequence if that
    finished slot contains eos_id, and the alive sequence otherwise.
    Finished scores are length-normalized (via length_norm_fn). Alive scores
    are raw summed log probabilities.
  """
  B, T = init_seq_BxT.shape
  M, V = beam_size, vocab_size
  dtype = tf.float32
  int_dtype = init_seq_BxT.dtype

  def _loop_body(i, alive_seq_BxMxT, alive_log_probs_BxM, alive_cache_BxMxU,
                 finished_seq_BxMxT, finished_scores_BxM):
    """Beam search loop body."""
    # Decode one step with beam
    # (fold beams into the batch axis for the model call, then unfold).
    logits_BMxV, cache_BMxU = symbols_to_logits_fn(
        _flatten_beam_dim(alive_seq_BxMxT),
        tf.nest.map_structure(_flatten_beam_dim, alive_cache_BxMxU), i)
    logits_BxMxV = _unflatten_beam_dim(logits_BMxV, M)
    new_cache_BxMxU = tf.nest.map_structure(lambda t: _unflatten_beam_dim(t, M),
                                            cache_BMxU)

    # select top 2 * beam_size and fill alive and finished.
    # Log-softmax over the vocabulary, plus each beam's running log prob.
    log_probs_BxMxV = logits_BxMxV - tf.reduce_logsumexp(
        logits_BxMxV, axis=2, keepdims=True)
    log_probs_BxMxV += tf.expand_dims(alive_log_probs_BxM, axis=2)
    # Flatten (beam, token) pairs into a single axis of size M * V.
    log_probs_BxMV = tf.reshape(log_probs_BxMxV, [B, -1])
    new_log_probs_Bx2M, topk_indices_Bx2M = tf.nn.top_k(log_probs_BxMV, k=2 * M)
    # Recover the source beam (idx // V) and the token id (idx % V).
    topk_beam_Bx2M = topk_indices_Bx2M // V
    topk_seq_Bx2MxT, new_cache_Bx2MxU = _gather_nested(
        [alive_seq_BxMxT, new_cache_BxMxU], topk_beam_Bx2M)
    topk_ids_Bx2M = topk_indices_Bx2M % V
    # Write the chosen token at position i of each candidate.
    new_seq_Bx2MxT = _update_i(topk_seq_Bx2MxT, topk_ids_Bx2M, i)
    # A candidate counts as finished if eos_id appears anywhere in it.
    new_finished_flags_Bx2M = tf.cast(
        tf.reduce_any(tf.equal(new_seq_Bx2MxT, eos_id), axis=-1), dtype)

    # get new alive
    # Finished candidates are pushed towards dtype.min so top_k ranks the
    # unfinished candidates first.
    _, topk_alive_indices_BxM = tf.nn.top_k(
        new_log_probs_Bx2M + new_finished_flags_Bx2M * dtype.min, k=M)
    (alive_seq_BxMxT, alive_log_probs_BxM, alive_cache_BxMxU) = _gather_nested(
        [new_seq_Bx2MxT, new_log_probs_Bx2M, new_cache_Bx2MxU],
        topk_alive_indices_BxM)

    # get new finished
    # Length is i + 1: the loop has now written positions 0..i.
    new_scores_Bx2M = length_norm_fn(new_log_probs_Bx2M, i + 1)
    # Unfinished candidates get ~dtype.min so they cannot displace real
    # finished hypotheses.
    new_scores_Bx2M += (1 - new_finished_flags_Bx2M) * dtype.min
    # Merge the M previous finished hypotheses with the 2M new candidates,
    # then keep the best M.
    finished_seq_Bx3MxT = tf.concat([finished_seq_BxMxT, new_seq_Bx2MxT],
                                    axis=1)
    finished_scores_Bx3M = tf.concat([finished_scores_BxM, new_scores_Bx2M],
                                     axis=1)
    _, topk_finished_indices_BxM = tf.nn.top_k(finished_scores_Bx3M, k=M)
    (finished_seq_BxMxT, finished_scores_BxM) = _gather_nested(
        [finished_seq_Bx3MxT, finished_scores_Bx3M], topk_finished_indices_BxM)

    return [
        i + 1, alive_seq_BxMxT, alive_log_probs_BxM, alive_cache_BxMxU,
        finished_seq_BxMxT, finished_scores_BxM
    ]

  # initialize.
  init_i = tf.constant(0, dtype=int_dtype)
  init_alive_seq_BxMxT = _expand_to_beam_size(init_seq_BxT, M)
  # All M initial beams are identical copies. Only the first starts with log
  # prob 0; the rest get ~dtype.min so step 0 does not pick the same
  # continuation M times (assumes vocab_size >= 2 * beam_size).
  log_probs_1xM = tf.constant([[0.] + [dtype.min] * (M - 1)], dtype=dtype)
  init_alive_log_probs_BxM = tf.tile(log_probs_1xM, [B, 1])
  init_alive_cache_BxMxU = tf.nest.map_structure(
      lambda t: _expand_to_beam_size(t, M), initial_cache_BxU)
  # No finished hypotheses yet: all-zero sequences with ~dtype.min scores.
  init_finished_seq_BxMxT = tf.zeros(tf.shape(init_alive_seq_BxMxT), int_dtype)
  init_finished_scores_BxM = tf.zeros([B, M], dtype=dtype) + dtype.min

  # run loop.
  # The condition is always True, so the loop runs exactly T iterations.
  # The third result (final_alive_scores_BxM) holds the alive log probs.
  (_, final_alive_seq_BxMxT, final_alive_scores_BxM, _,
   final_finished_seq_BxMxT, final_finished_scores_BxM) = tf.while_loop(
       lambda *args: True,  # Always do T iterations
       _loop_body,
       loop_vars=[
           init_i, init_alive_seq_BxMxT, init_alive_log_probs_BxM,
           init_alive_cache_BxMxU, init_finished_seq_BxMxT,
           init_finished_scores_BxM
       ],
       parallel_iterations=1,
       back_prop=False,
       maximum_iterations=T,
   )

  # process finished.
  # Per (batch, beam) slot: use the finished hypothesis if it contains eos_id,
  # otherwise fall back to the alive hypothesis in the same slot.
  final_finished_flag_BxMx1 = tf.reduce_any(
      tf.equal(final_finished_seq_BxMxT, eos_id), axis=-1, keepdims=True)
  final_seq_BxMxT = tf.where(
      tf.tile(final_finished_flag_BxMx1, [1, 1, T]), final_finished_seq_BxMxT,
      final_alive_seq_BxMxT)
  final_scores_BxM = tf.where(
      tf.squeeze(final_finished_flag_BxMx1, axis=-1), final_finished_scores_BxM,
      final_alive_scores_BxM)
  return final_seq_BxMxT, final_scores_BxM


def _update_i(tensor_BxNxT, updates_BxN, i):
  """Return a copy of `tensor_BxNxT` with time step `i` overwritten.

  Args:
    tensor_BxNxT: Tensor of shape [B, N, T]. The shape must be fully static.
    updates_BxN: Tensor of shape [B, N] with the values to write. It is cast
      to the dtype of `tensor_BxNxT`, so int32 ids from `tf.nn.top_k` can be
      written into e.g. an int64 sequence tensor.
    i: scalar integer Tensor (any integer dtype) or Python int; the time index
      to overwrite.

  Returns:
    Tensor of shape [B, N, T], equal to `tensor_BxNxT` except that
    `result[b, n, i] == updates_BxN[b, n]`.
  """
  B, N, T = tensor_BxNxT.shape
  tensor_BNxT = tf.reshape(tensor_BxNxT, [-1, T])
  # tensor_scatter_nd_update requires the updates to match the tensor dtype.
  updates_BN = tf.cast(tf.reshape(updates_BxN, [-1]), tensor_BxNxT.dtype)
  batch_BN = tf.range(B * N, dtype=tf.int32)
  # Both index columns must share a dtype for tf.stack. The beam-search loop
  # counter takes the dtype of the input ids, which need not be int32.
  i_BN = tf.fill([B * N], tf.cast(i, tf.int32))
  ind_BNx2 = tf.stack([batch_BN, i_BN], axis=-1)
  tensor_BNxT = tf.tensor_scatter_nd_update(tensor_BNxT, ind_BNx2, updates_BN)
  return tf.reshape(tensor_BNxT, [B, N, T])


def _expand_to_beam_size(tensor_BxU, beam_size):
  """Insert a beam axis at position 1 and repeat along it.

  [B, ...] -> [B, beam_size, ...]
  """
  with_beam_axis = tf.expand_dims(tensor_BxU, axis=1)
  multiples = [beam_size if axis == 1 else 1
               for axis in range(with_beam_axis.shape.ndims)]
  return tf.tile(with_beam_axis, multiples)


def _flatten_beam_dim(tensor_BxMxU):
  """Merge the batch and beam axes: [B, M, ...] -> [B * M, ...].

  Requires the first two dimensions to be statically known.
  """
  dims = tensor_BxMxU.shape.as_list()
  merged_dims = [dims[0] * dims[1]]
  merged_dims.extend(dims[2:])
  return tf.reshape(tensor_BxMxU, merged_dims)


def _unflatten_beam_dim(tensor_BMxU, M):
  """Split the leading axis into batch and beam axes: [B * M, ...] -> [B, M, ...].

  Requires the leading dimension to be statically known.
  """
  dims = tensor_BMxU.shape.as_list()
  batch = dims[0] // M
  return tf.reshape(tensor_BMxU, [batch, M] + dims[1:])


def _gather_nested(nested_BxMxU, indices_BxN):
  """Select beams by index for every tensor in a nested structure.

  Each tensor of shape [B, M, ...] is gathered along axis 1 with per-batch
  indices `indices_BxN` of shape [B, N], giving [B, N, ...]. The nesting of
  the input structure is preserved.
  """
  return tf.nest.map_structure(
      lambda t_BxMxU: tf.gather(t_BxMxU, indices_BxN, batch_dims=1, axis=1),
      nested_BxMxU)

In [18]:
##################### decoder #######################
# Copyright 2021 The BigBird Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BigBird Decoder Layers."""

from bigbird.core import attention
from bigbird.core import beam_search
from bigbird.core import recompute_grad
from bigbird.core import utils
import tensorflow.compat.v2 as tf


class PrenormDecoderLayer(tf.keras.layers.Layer):
  """Decoder layer of a transformer in Pegasus style.

  The layer_norm is taken before each sub-layer (self-attention,
  cross-attention and feed-forward). Each sub-layer's output is added to its
  un-normalized input (residual connection).
  """

  def __init__(self,
               hidden_size=768,
               intermediate_size=3072,
               intermediate_act_fn=utils.gelu,
               attention_probs_dropout_prob=0.0,
               hidden_dropout_prob=0.1,
               initializer_range=0.02,
               num_attention_heads=12,
               use_bias=True,
               name=None):
    """Constructor of a decoder layer of a transformer in Pegasus style.

    Sub-layers are created inside nested tf.compat.v1.variable_scope blocks.
    NOTE(review): the nesting and creation order presumably fix variable names
    for checkpoint compatibility. Confirm before reordering.

    Args:
      hidden_size: (optional) int. Size of hidden dimension.
      intermediate_size: (optional) int. Size of intermediate dimension.
      intermediate_act_fn: (optional) Activation function for intermediate
        layer.
      attention_probs_dropout_prob: (optional) float. Dropout probability of the
        attention probabilities.
      hidden_dropout_prob: (optional) float. Dropout probability applied to the
        output of each sub-layer before its residual connection.
      initializer_range: (optional) float. Range of the weight initializer.
      num_attention_heads: (optional) int. Number of attention heads. Each head
        gets hidden_size // num_attention_heads dimensions.
      use_bias: (optional) bool. Whether the attention layers and their output
        projections use a bias vector.
      name: The name scope of this layer.
    """
    super(PrenormDecoderLayer, self).__init__(name=name)

    with tf.compat.v1.variable_scope(name):

      attention_head_size = hidden_size // num_attention_heads
      with tf.compat.v1.variable_scope("attention"):
        # Pre-Normalization layer
        with tf.compat.v1.variable_scope("self"):
          self.first_layer_norm = utils.NormLayer(hidden_size)
        # Self-Attention layer
        self.self_attn_layer = attention.MultiHeadedAttentionLayer(
            "original_full", use_bias=use_bias, name="self",
            num_attention_heads=num_attention_heads,
            size_per_head=attention_head_size,
            initializer_range=initializer_range,
            attention_probs_dropout_prob=attention_probs_dropout_prob)
        # Output projection of the self-attention heads
        with tf.compat.v1.variable_scope("output"):
          self.self_proj_layer = utils.Dense3dProjLayer(
              num_attention_heads, attention_head_size,
              utils.create_initializer(initializer_range), None,
              "dense", use_bias)
        # Dropout
        self.self_attn_dropout = recompute_grad.RecomputingDropout(
            hidden_dropout_prob)
        # Pre-Normalization layer
        with tf.compat.v1.variable_scope("encdec"):
          self.second_layer_norm = utils.NormLayer(hidden_size)
        # Cross-Attention layer
        self.cross_attn_layer = attention.MultiHeadedAttentionLayer(
            "original_full", use_bias=use_bias, name="encdec",
            num_attention_heads=num_attention_heads,
            size_per_head=attention_head_size,
            initializer_range=initializer_range,
            attention_probs_dropout_prob=attention_probs_dropout_prob)
        # Output projection of the cross-attention heads
        with tf.compat.v1.variable_scope("encdec_output"):
          self.cross_proj_layer = utils.Dense3dProjLayer(
              num_attention_heads, attention_head_size,
              utils.create_initializer(initializer_range), None,
              "dense", use_bias)
        # Dropout
        self.cross_attn_dropout = recompute_grad.RecomputingDropout(
            hidden_dropout_prob)

      with tf.compat.v1.variable_scope("intermediate"):
        # Normalization layer
        self.third_layer_norm = utils.NormLayer(hidden_size)
        # Feedforward layer (expand hidden_size -> intermediate_size)
        self.expand_layer = utils.Dense2dLayer(
            hidden_size, intermediate_size,
            utils.create_initializer(initializer_range),
            intermediate_act_fn, "dense")

      with tf.compat.v1.variable_scope("output"):
        # Feedforward layer (contract intermediate_size -> hidden_size)
        self.contract_layer = utils.Dense2dLayer(
            intermediate_size, hidden_size,
            utils.create_initializer(initializer_range),
            None, "dense")
        # Dropout
        self.output_dropout = recompute_grad.RecomputingDropout(
            hidden_dropout_prob)

  def call(self,
           layer_input,
           encoder_outputs,
           self_attention_mask,
           attention_mask,
           cache=None,
           decode_i=None,
           training=None):
    """Implements a decoder layer of a transformer in Pegasus style.

    The layer_norm is taken before each sub-layer (pre-norm). Each sub-layer's
    output is added to its un-normalized input.

    Args:
      layer_input: float Tensor of shape [batch_size, seq_length, hidden_size].
      encoder_outputs: encoder output tensor attended to by cross-attention.
        Presumably [batch_size, input_length, hidden_size] (TODO confirm).
      self_attention_mask: bias for decoder self-attention layer. [1, 1,
        target_length, target_length]
      attention_mask: bias for encoder-decoder attention layer. [batch_size, 1,
        1, input_length]
      cache: (Used during prediction) Passed only to the self-attention layer.
        A dictionary with tensors containing results of previous attentions,
        expected to have the items:
            {"k": tensor with shape
                  [batch_size, max_len, num_attention_heads, size_per_head],
             "v": tensor with shape
                  [batch_size, max_len, num_attention_heads, size_per_head]}
      decode_i: (Used during prediction) current location of decoding. Passed
        only to the self-attention layer.
      training: Boolean indicating whether the call is training or inference.
    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size].
    Raises:
      Nothing directly. NOTE(review): errors such as ValueError or
      NotImplementedError may propagate from the attention sub-layers.
    """
    # self-attention
    normalized_layer_input = self.first_layer_norm(layer_input)
    self_attention_output = self.self_attn_layer(
        normalized_layer_input, normalized_layer_input, [self_attention_mask],
        cache=cache, decode_i=decode_i, training=training)

    # Run a linear projection of `hidden_size` then add a residual
    # with `layer_input`.
    self_attention_output = self.self_proj_layer(self_attention_output)
    self_attention_output = self.self_attn_dropout(self_attention_output,
                                                   training=training)
    self_attention_output = self_attention_output + layer_input

    # Cross-attention
    normalized_self_attention_output = self.second_layer_norm(
        self_attention_output)
    attention_output = self.cross_attn_layer(
        normalized_self_attention_output, encoder_outputs, [attention_mask],
        training=training)

    # Run a linear projection of `hidden_size` then add a residual
    # with `self_attention_output`.
    attention_output = self.cross_proj_layer(attention_output)
    attention_output = self.cross_attn_dropout(attention_output,
                                               training=training)
    attention_output = attention_output + self_attention_output

    # The activation is only applied to the "intermediate" hidden layer.
    normalized_attention_output = self.third_layer_norm(attention_output)
    intermediate_output = self.expand_layer(normalized_attention_output)

    # Down-project back to `hidden_size` then add the residual.
    layer_output = self.contract_layer(intermediate_output)
    layer_output = self.output_dropout(layer_output, training=training)
    layer_output = layer_output + attention_output
    return layer_output


class PostnormDecoderLayer(tf.keras.layers.Layer):
  """Decoder layer of a transformer in BERT style.

  The layer_norm is taken after each sub-layer (self-attention,
  cross-attention and feed-forward). It is applied to the sum of the sub-layer
  output and its input (residual connection).
  """

  def __init__(self,
               hidden_size=768,
               intermediate_size=3072,
               intermediate_act_fn=utils.gelu,
               attention_probs_dropout_prob=0.0,
               hidden_dropout_prob=0.1,
               initializer_range=0.02,
               num_attention_heads=12,
               use_bias=True,
               name=None):
    """Constructor of a decoder layer of a transformer in BERT style.

    Sub-layers are created inside nested tf.compat.v1.variable_scope blocks.
    NOTE(review): the nesting and creation order presumably fix variable names
    for checkpoint compatibility. Confirm before reordering.

    Args:
      hidden_size: (optional) int. Size of hidden dimension.
      intermediate_size: (optional) int. Size of intermediate dimension.
      intermediate_act_fn: (optional) Activation function for intermediate
        layer.
      attention_probs_dropout_prob: (optional) float. Dropout probability of the
        attention probabilities.
      hidden_dropout_prob: (optional) float. Dropout probability applied to the
        output of each sub-layer before its residual connection.
      initializer_range: (optional) float. Range of the weight initializer.
      num_attention_heads: (optional) int. Number of attention heads. Each head
        gets hidden_size // num_attention_heads dimensions.
      use_bias: (optional) bool. Whether the attention layers and their output
        projections use a bias vector.
      name: The name scope of this layer.
    """
    super(PostnormDecoderLayer, self).__init__(name=name)

    with tf.compat.v1.variable_scope(name):

      attention_head_size = hidden_size // num_attention_heads
      with tf.compat.v1.variable_scope("attention"):
        # Self-Attention layers
        self.self_attn_layer = attention.MultiHeadedAttentionLayer(
            "original_full", use_bias=use_bias, name="self",
            num_attention_heads=num_attention_heads,
            size_per_head=attention_head_size,
            initializer_range=initializer_range,
            attention_probs_dropout_prob=attention_probs_dropout_prob)

        with tf.compat.v1.variable_scope("output"):
          # Output projection of the self-attention heads
          self.self_proj_layer = utils.Dense3dProjLayer(
              num_attention_heads, attention_head_size,
              utils.create_initializer(initializer_range), None,
              "dense", use_bias)
          # Post-Normalization layer
          self.first_layer_norm = utils.NormLayer(hidden_size)
          # Dropout
          self.self_attn_dropout = recompute_grad.RecomputingDropout(
              hidden_dropout_prob)

        # Cross-Attention layers
        self.cross_attn_layer = attention.MultiHeadedAttentionLayer(
            "original_full", use_bias=use_bias, name="encdec",
            num_attention_heads=num_attention_heads,
            size_per_head=attention_head_size,
            initializer_range=initializer_range,
            attention_probs_dropout_prob=attention_probs_dropout_prob)

        with tf.compat.v1.variable_scope("encdec_output"):
          # Output projection of the cross-attention heads
          self.cross_proj_layer = utils.Dense3dProjLayer(
              num_attention_heads, attention_head_size,
              utils.create_initializer(initializer_range), None,
              "dense", use_bias)
          # Post-Normalization layer
          self.second_layer_norm = utils.NormLayer(hidden_size)
          # Dropout
          self.cross_attn_dropout = recompute_grad.RecomputingDropout(
              hidden_dropout_prob)

      with tf.compat.v1.variable_scope("intermediate"):
        # Feedforward layer (expand hidden_size -> intermediate_size)
        self.expand_layer = utils.Dense2dLayer(
            hidden_size, intermediate_size,
            utils.create_initializer(initializer_range),
            intermediate_act_fn, "dense")

      with tf.compat.v1.variable_scope("output"):
        # Feedforward layer (contract intermediate_size -> hidden_size)
        self.contract_layer = utils.Dense2dLayer(
            intermediate_size, hidden_size,
            utils.create_initializer(initializer_range),
            None, "dense")
        # Normalization layer
        self.third_layer_norm = utils.NormLayer(hidden_size)
        # Dropout
        self.output_dropout = recompute_grad.RecomputingDropout(
            hidden_dropout_prob)

  def call(self,
           layer_input,
           encoder_outputs,
           self_attention_mask,
           attention_mask,
           cache=None,
           decode_i=None,
           training=None):
    """Implements a decoder layer of a transformer in BERT style.

    The layer_norm is taken after each sub-layer (post-norm). It is applied to
    the sum of the sub-layer output and its input.

    Args:
      layer_input: float Tensor of shape [batch_size, seq_length, hidden_size].
      encoder_outputs: encoder output tensor attended to by cross-attention.
        Presumably [batch_size, input_length, hidden_size] (TODO confirm).
      self_attention_mask: bias for decoder self-attention layer. [1, 1,
        target_length, target_length]
      attention_mask: bias for encoder-decoder attention layer. [batch_size, 1,
        1, input_length]
      cache: (Used during prediction) Passed only to the self-attention layer.
        A dictionary with tensors containing results of previous attentions,
        expected to have the items:
            {"k": tensor with shape
                  [batch_size, max_len, num_attention_heads, size_per_head],
             "v": tensor with shape
                  [batch_size, max_len, num_attention_heads, size_per_head]}
      decode_i: (Used during prediction) current location of decoding. Passed
        only to the self-attention layer.
      training: Boolean indicating whether the call is training or inference.
    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size].
    Raises:
      Nothing directly. NOTE(review): errors such as ValueError or
      NotImplementedError may propagate from the attention sub-layers.
    """
    # self-attention
    self_attention_output = self.self_attn_layer(
        layer_input, layer_input, [self_attention_mask],
        cache=cache, decode_i=decode_i, training=training)

    # Run a linear projection of `hidden_size` then add a residual
    # with `layer_input`.
    self_attention_output = self.self_proj_layer(self_attention_output)
    self_attention_output = self.self_attn_dropout(self_attention_output,
                                                   training=training)
    self_attention_output = self.first_layer_norm(
        self_attention_output + layer_input)

    # cross-attention
    attention_output = self.cross_attn_layer(
        self_attention_output, encoder_outputs, [attention_mask],
        training=training)

    # Run a linear projection of `hidden_size` then add a residual
    # with `self_attention_output`.
    attention_output = self.cross_proj_layer(attention_output)
    attention_output = self.cross_attn_dropout(attention_output,
                                               training=training)
    attention_output = self.second_layer_norm(
        attention_output + self_attention_output)

    # The activation is only applied to the "intermediate" hidden layer.
    intermediate_output = self.expand_layer(attention_output)

    # Down-project back to `hidden_size` then add the residual.
    layer_output = self.contract_layer(intermediate_output)
    layer_output = self.output_dropout(layer_output, training=training)
    layer_output = self.third_layer_norm(layer_output + attention_output)
    return layer_output


def add_gradient_recomputation(original_class):
  """Creates a subclass of `original_class` that uses gradient checkpointing.

  The returned class overrides `call`: the parent's forward pass is wrapped
  with `recompute_grad.recompute_grad`. Only `layer_input` and
  `encoder_outputs` are passed to the wrapped function as explicit arguments;
  the masks, cache, decode index and training flag are captured by closure.

  Args:
    original_class: decoder layer class whose `call` accepts
      (layer_input, encoder_outputs, self_attention_mask, attention_mask,
      cache, decode_i, training=...).

  Returns:
    The new subclass, named `RecomputeLayer`.
  """

  class RecomputeLayer(original_class):
    """Transformer layer that recomputes the forward pass during backprop."""

    def call(self,
             layer_input,
             encoder_outputs,
             self_attention_mask,
             attention_mask,
             cache=None,
             decode_i=None,
             training=None):
      parent = super(RecomputeLayer, self)

      def forward(inputs, enc_outputs):
        return parent.call(
            inputs, enc_outputs, self_attention_mask, attention_mask,
            cache, decode_i, training=training)

      checkpointed_forward = recompute_grad.recompute_grad(forward)
      return checkpointed_forward(layer_input, encoder_outputs)

  return RecomputeLayer


class DecoderStack(tf.keras.layers.Layer):
  """Transformer decoder stack: N decoder layers plus one normalization layer."""

  def __init__(self, params):
    """Builds the decoder layers described by `params`.

    Args:
      params: dict of hyperparameters. Keys read here: "couple_encoder_decoder",
        "norm_type" ("prenorm" or "postnorm"), "use_gradient_checkpointing",
        optional "num_decoder_layers" (falls back to "num_hidden_layers" when
        missing or None), "hidden_size", "intermediate_size", "hidden_act",
        "attention_probs_dropout_prob", "hidden_dropout_prob",
        "initializer_range", "num_attention_heads", "use_bias".

    Raises:
      NotImplementedError: if params["norm_type"] is not supported.
    """
    # A decoder coupled to the encoder lives under the "encoder" name.
    scope_name = "encoder" if params["couple_encoder_decoder"] else "decoder"
    super(DecoderStack, self).__init__(name=scope_name)

    self.params = params

    norm_type = params["norm_type"]
    if norm_type == "postnorm":
      decoder_class = PostnormDecoderLayer
    elif norm_type == "prenorm":
      decoder_class = PrenormDecoderLayer
    else:
      raise NotImplementedError(
          "Norm type {} is not implemented".format(params["norm_type"]))

    if params["use_gradient_checkpointing"]:
      decoder_class = add_gradient_recomputation(decoder_class)

    num_layers = self.params.get("num_decoder_layers", None)
    if num_layers is None:
      num_layers = self.params["num_hidden_layers"]

    with tf.compat.v1.variable_scope(scope_name):
      # Decoder layers, created in index order inside this variable scope.
      layers = []
      for layer_idx in range(num_layers):
        layers.append(decoder_class(
            self.params["hidden_size"],
            self.params["intermediate_size"],
            utils.get_activation(self.params["hidden_act"]),
            self.params["attention_probs_dropout_prob"],
            self.params["hidden_dropout_prob"],
            self.params["initializer_range"],
            self.params["num_attention_heads"],
            self.params["use_bias"],
            name="layer_%d" % layer_idx))
      self.decoder_layers = layers

      # Normalization layer
      self.layer_norm = utils.NormLayer(self.params["hidden_size"])

  def call(self,
           decoder_inputs,
           self_attention_mask,
           encoder_outputs,
           encoder_mask,
           cache=None,
           decode_i=None,
           training=None):
    """Runs `decoder_inputs` through every decoder layer in order.

    Args:
      decoder_inputs: tensor with shape
        [batch_size, target_length, hidden_size]
      self_attention_mask: bias for decoder self-attention layer. [1, 1,
        target_length, target_length]
      encoder_outputs: tensors with shape [batch_size, input_length,
        hidden_size]
      encoder_mask: bias for encoder-decoder attention layer. [batch_size,
        input_length]
      cache: (Used during prediction) dict keyed by decoder-layer name; each
        entry is passed to that layer. Each entry presumably holds
            {"k": tensor with shape
                  [batch_size, max_len, num_attention_heads, size_per_head],
             "v": tensor with shape
                  [batch_size, max_len, num_attention_heads, size_per_head]}
      decode_i: (Used during prediction) current location of decoding.
      training: Boolean indicating whether the call is training or inference.
    Returns:
      Output of decoder layer stack. A float32 tensor with shape [batch_size,
        target_length, hidden_size]
    """
    # [batch_size, input_length] -> [batch_size, 1, 1, input_length], so the
    # mask broadcasts over attention heads and query positions.
    attention_mask = tf.cast(
        tf.expand_dims(tf.expand_dims(encoder_mask, 1), 1), tf.float32)

    norm_type = self.params["norm_type"]
    # Post-norm normalizes the inputs on the way in; pre-norm normalizes the
    # final hidden states on the way out.
    if norm_type == "postnorm":
      hidden = self.layer_norm(decoder_inputs)
    else:
      hidden = decoder_inputs

    for layer in self.decoder_layers:
      per_layer_cache = None if cache is None else cache[layer.name]
      hidden = layer(hidden, encoder_outputs, self_attention_mask,
                     attention_mask, per_layer_cache, decode_i,
                     training=training)

    if norm_type == "prenorm":
      hidden = self.layer_norm(hidden)

    return hidden


def create_self_attention_mask(length):
  """Builds a causal self-attention mask of shape [1, 1, length, length].

  Entry [0, 0, q, k] is 1.0 when k <= q (lower triangle, diagonal included)
  and 0.0 otherwise, so each position sees only itself and earlier positions.
  """
  with tf.name_scope("decoder_self_attention_mask"):
    lower_triangle = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
    causal_mask = tf.reshape(lower_triangle, [1, 1, length, length])
  return causal_mask


def inplace_update_i(inp_tensor, updates, i):
  """Writes `updates[b]` into column `i` of every row `b` of `inp_tensor`.

  Despite the name, this is not in place: `tf.tensor_scatter_nd_update`
  returns a new tensor.

  Args:
    inp_tensor: Tensor whose leading dimension (batch) is statically known.
    updates: Tensor with one value per row, e.g. [batch_size] for a 2D input.
    i: scalar index (cast to int32) of the column to overwrite.

  Returns:
    A tensor like `inp_tensor`, with `[b, i]` replaced by `updates[b]`.
  """
  num_rows = inp_tensor.shape[0]
  row_ids = tf.range(num_rows, dtype=tf.int32)
  col_ids = tf.fill([num_rows], tf.cast(i, tf.int32))
  scatter_indices = tf.stack([row_ids, col_ids], axis=-1)
  return tf.tensor_scatter_nd_update(inp_tensor, scatter_indices, updates)


# pylint: disable=invalid-name
def left2right_decode(symbols_to_logits_fn,
                      start_symbols,
                      context_BxU_dict,
                      batch_size,
                      max_decode_len,
                      vocab_size,
                      beam_size=1,
                      beam_start=5,
                      beam_alpha=0.6,
                      beam_min=0,
                      beam_max=-1,
                      eos_id=1):
  """Decodes token ids left to right, greedily or with beam search.

  Notations:
    B: batch_size, V: vocab_size, T: decode_len, U: undefined dimensions
  Args:
    symbols_to_logits_fn: logits = fn(decodes, context, i). Should take
      [batch_size, decoded_ids] and return [batch_size, vocab_size].
    start_symbols: starting ids [batch_size]
    context_BxU_dict: dict of Tensors.
    batch_size: int, decode batch size.
    max_decode_len: int, maximum number of steps to decode.
    vocab_size: int, output vocab size.
    beam_size: Number of beams to decode. 1 selects greedy decoding.
    beam_start: start length for scaling, default to 5.
    beam_alpha: Length penalty for decoding. Should be between 0 (shorter) and 1
      (longer), default to 0.6.
    beam_min: Minimum beam search lengths.
    beam_max: Maximum beam search lengths. Set -1 to use unlimited.
    eos_id: end of token id, default to 1.
  Returns:
    decodes: int32 Tensor[batch, decode_len]. With beam search, this is the
      first beam of each batch element.
  """
  dtype = tf.int32
  # Every row starts as [start_symbol, 0, 0, ..., 0] with max_decode_len entries.
  # NOTE(review): step 0 writes its prediction into column 0, overwriting the
  # start symbol. Presumably symbols_to_logits_fn reads its input before that
  # write; confirm.
  start_Bx1 = tf.cast(tf.expand_dims(start_symbols, 1), dtype=dtype)
  init_dec_BxT = tf.concat(
      [start_Bx1, tf.zeros([batch_size, max_decode_len - 1], dtype=dtype)],
      axis=1)

  if beam_size != 1:
    # beam_search expects fn(...) -> (logits, state). symbols_to_logits_fn
    # returns only logits, so the state is handed back unchanged.
    def _logits_with_state(decodes_BxT, states_BxU_dict, i):
      logits_BxV = symbols_to_logits_fn(decodes_BxT, states_BxU_dict, i)
      return logits_BxV, states_BxU_dict

    norm_fn = beam_search.length_normalization(
        beam_start, beam_alpha, beam_min, beam_max, -1e3)
    beams_BxMxT, _ = beam_search.beam_search(
        _logits_with_state, init_dec_BxT, context_BxU_dict, vocab_size,
        beam_size, norm_fn, eos_id)
    return beams_BxMxT[:, 0, :]

  # Greedy decoding. Beam search with beam_size=1 would not match this: it
  # grows 2 * beam_size candidates, keeps the best not-yet-finished ones
  # alive, and still applies the length penalty.
  def _greedy_step(i, decodes_BxT, cache_BxU_dict):
    logits_BxV = symbols_to_logits_fn(decodes_BxT, cache_BxU_dict, i)
    best_ids_B = tf.argmax(logits_BxV, -1, output_type=tf.int32)
    return i + 1, inplace_update_i(decodes_BxT, best_ids_B, i), cache_BxU_dict

  def _should_continue(i, decodes_BxT, unused_cache_BxU_dict):
    # Stop after max_decode_len steps, or once every row contains eos_id.
    every_row_done = tf.reduce_all(
        tf.reduce_any(tf.equal(decodes_BxT, eos_id), axis=1))
    return tf.logical_and(i < max_decode_len, tf.logical_not(every_row_done))

  _, decodes_BxT, _ = tf.while_loop(
      _should_continue, _greedy_step,
      [tf.constant(0, dtype=dtype), init_dec_BxT, context_BxU_dict])
  return decodes_BxT

In [19]:
########################### encoder ############################
# Copyright 2021 The BigBird Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BigBird Encoder Layers."""

from bigbird.core import attention
from bigbird.core import recompute_grad
from bigbird.core import utils
import tensorflow.compat.v2 as tf


class PrenormEncoderLayer(tf.keras.layers.Layer):
  """Encoder layer of a transformer in Pegasus style.

  The layer_norm is taken before self-attention (pre-normalization). `call`
  computes, in order:
    h = x + Dropout(Proj(SelfAttention(LayerNorm_1(x))))
    y = h + Dropout(Contract(Expand(LayerNorm_2(h))))
  so the residual stream itself is never normalized inside this layer.
  """

  def __init__(self,
               attention_type,
               hidden_size=768,
               intermediate_size=3072,
               intermediate_act_fn=utils.gelu,
               attention_probs_dropout_prob=0.0,
               hidden_dropout_prob=0.1,
               initializer_range=0.02,
               num_attention_heads=12,
               num_rand_blocks=3,
               seq_length=1024,
               block_size=64,
               use_bias=True,
               seed=None,
               name=None):
    """Constructor of an encoder layer of a transformer in Pegasus style.

    Args:
      attention_type: Type of attention, needs to be one of ['original_full',
        'simulated_sparse', 'block_sparse'].
      hidden_size: (optional) int. Size of hidden dimension.
      intermediate_size: (optional) int. Size of intermediate dimension.
      intermediate_act_fn: (optional) Activation function for intermediate
        layer.
      attention_probs_dropout_prob: (optional) float. Dropout probability of the
        attention probabilities.
      hidden_dropout_prob: (optional) float. Dropout probability applied both
        to the projected attention output and to the feed-forward output.
      initializer_range: (optional) float. Range of the weight initializer.
      num_attention_heads: (optional) int. Number of attention heads.
      num_rand_blocks: (optional) int. Number of random chunks per row.
      seq_length: (optional) int. length of sequence.
      block_size: (optional) int. size of block in sequence.
      use_bias: (optional) bool. Whether key/query/value uses a bias vector.
      seed: (Optional) int. Random seed for generating random mask.
      name: The name scope of this layer.
    """
    super(PrenormEncoderLayer, self).__init__(name=name)

    # Sub-layers are built inside nested TF1 variable scopes, so the scope
    # names below become part of the variable names. NOTE(review): presumably
    # this nesting must stay unchanged to keep existing checkpoints loadable.
    with tf.compat.v1.variable_scope(name):

      # NOTE(review): integer division assumes hidden_size is a multiple of
      # num_attention_heads -- confirm callers guarantee this.
      attention_head_size = hidden_size // num_attention_heads
      with tf.compat.v1.variable_scope("attention"):
        # Pre-Normalization layer (applied to the layer input in `call`).
        with tf.compat.v1.variable_scope("self"):
          self.first_layer_norm = utils.NormLayer(hidden_size)
        # Self-Attention layer. seq_length and block_size are each passed
        # twice -- presumably once for the "from" and once for the "to"
        # sequence, which coincide for self-attention.
        self.attn_layer = attention.MultiHeadedAttentionLayer(
            attention_type, num_attention_heads, attention_head_size,
            num_rand_blocks, seq_length, seq_length, block_size, block_size,
            attention_probs_dropout_prob, initializer_range, use_bias,
            seed, name="self")
        # Feedforward layer projecting the per-head attention output back to
        # hidden size (presumably merges heads; see utils.Dense3dProjLayer).
        with tf.compat.v1.variable_scope("output"):
          self.projection_layer = utils.Dense3dProjLayer(
              num_attention_heads, attention_head_size,
              utils.create_initializer(initializer_range), None,
              "dense", use_bias)
        # Dropout on the projected attention output, before the residual add.
        self.attention_dropout = recompute_grad.RecomputingDropout(
            hidden_dropout_prob)

      with tf.compat.v1.variable_scope("intermediate"):
        # Normalization layer (applied before the feed-forward block).
        self.second_layer_norm = utils.NormLayer(hidden_size)
        # Feedforward layer: hidden_size -> intermediate_size with the
        # intermediate activation.
        self.expand_layer = utils.Dense2dLayer(
            hidden_size, intermediate_size,
            utils.create_initializer(initializer_range),
            intermediate_act_fn, "dense")
      with tf.compat.v1.variable_scope("output"):
        # Feedforward layer: intermediate_size -> hidden_size, no activation.
        self.contract_layer = utils.Dense2dLayer(
            intermediate_size, hidden_size,
            utils.create_initializer(initializer_range),
            None, "dense")
        # Dropout on the feed-forward output, before the residual add.
        self.output_dropout = recompute_grad.RecomputingDropout(
            hidden_dropout_prob)

  def call(self,
           layer_input,
           attention_mask=None,
           band_mask=None,
           from_mask=None,
           to_mask=None,
           input_blocked_mask=None,
           training=None):
    """Implements an encoder layer of a transformer in Pegasus style.

    Args:
      layer_input: float Tensor of shape [batch_size, seq_length, hidden_size].
      attention_mask: (optional) float32 Tensor of shape [batch_size,
        seq_length, seq_length]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      band_mask: (optional) float32 Tensor of shape [batch_size, 1,
        seq_length//block_size-4, block_size, 3*block_size].
        The values should be 1 or 0. The attention scores will effectively be
        set to -infinity for any positions in the mask that are 0, and will be
        unchanged for positions that are 1.
      from_mask: (optional) float32 Tensor of shape [batch_size, 1,
        seq_length, 1]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      to_mask: (optional) float32 Tensor of shape [batch_size, 1, 1,
        seq_length]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      input_blocked_mask: (optional) float32 Tensor of shape [batch_size,
        seq_length//block_size, block_size]. Same as from/to_mask, just
        reshaped.
      training: Boolean indicating whether the call is training or inference.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size].

    Raises:
      ValueError: Any of the arguments or tensor shapes are invalid.
      NotImplementedError: For unknown attention type.
    """
    # self-attention on the normalized input; the same tensor serves as both
    # query source and key/value source. input_blocked_mask is passed twice,
    # presumably as the from- and to-blocked masks of self-attention.
    normalized_layer_input = self.first_layer_norm(layer_input)
    attention_output = self.attn_layer(
        normalized_layer_input, normalized_layer_input, [
            attention_mask, band_mask, from_mask, to_mask, input_blocked_mask,
            input_blocked_mask
        ], training=training)

    # Run a linear projection of `hidden_size` then add a residual
    # with the (un-normalized) `layer_input`.
    attention_output = self.projection_layer(attention_output)
    attention_output = self.attention_dropout(attention_output,
                                              training=training)
    attention_output = attention_output + layer_input

    # The activation is only applied to the "intermediate" hidden layer.
    normalized_attention_output = self.second_layer_norm(attention_output)
    intermediate_output = self.expand_layer(normalized_attention_output)

    # Down-project back to `hidden_size` then add the residual.
    layer_output = self.contract_layer(intermediate_output)
    layer_output = self.output_dropout(layer_output, training=training)
    layer_output = layer_output + attention_output
    return layer_output


class PostnormEncoderLayer(tf.keras.layers.Layer):
  """Encoder layer of a transformer in BERT style.

  The layer_norm is taken after self-attention (post-normalization). `call`
  computes, in order:
    h = LayerNorm_1(x + Dropout(Proj(SelfAttention(x))))
    y = LayerNorm_2(h + Dropout(Contract(Expand(h))))
  """

  def __init__(self,
               attention_type,
               hidden_size=768,
               intermediate_size=3072,
               intermediate_act_fn=utils.gelu,
               attention_probs_dropout_prob=0.0,
               hidden_dropout_prob=0.1,
               initializer_range=0.02,
               num_attention_heads=12,
               num_rand_blocks=3,
               seq_length=1024,
               block_size=64,
               use_bias=True,
               seed=None,
               name=None):
    """Constructor of an encoder layer of a transformer in BERT style.

    Args:
      attention_type: Type of attention, needs to be one of ['original_full',
        'simulated_sparse', 'block_sparse'].
      hidden_size: (optional) int. Size of hidden dimension.
      intermediate_size: (optional) int. Size of intermediate dimension.
      intermediate_act_fn: (optional) Activation function for intermediate
        layer.
      attention_probs_dropout_prob: (optional) float. Dropout probability of the
        attention probabilities.
      hidden_dropout_prob: (optional) float. Dropout probability applied both
        to the projected attention output and to the feed-forward output.
      initializer_range: (optional) float. Range of the weight initializer.
      num_attention_heads: (optional) int. Number of attention heads.
      num_rand_blocks: (optional) int. Number of random chunks per row.
      seq_length: (optional) int. length of sequence.
      block_size: (optional) int. size of block in sequence.
      use_bias: (optional) bool. Whether key/query/value uses a bias vector.
      seed: (Optional) int. Random seed for generating random mask.
      name: The name scope of this layer.
    """
    super(PostnormEncoderLayer, self).__init__(name=name)

    # Sub-layers are built inside nested TF1 variable scopes, so the scope
    # names below become part of the variable names. NOTE(review): presumably
    # this nesting must stay unchanged to keep existing checkpoints loadable.
    with tf.compat.v1.variable_scope(name):

      # NOTE(review): integer division assumes hidden_size is a multiple of
      # num_attention_heads -- confirm callers guarantee this.
      attention_head_size = hidden_size // num_attention_heads
      with tf.compat.v1.variable_scope("attention"):
        # Self-Attention layer. seq_length and block_size are each passed
        # twice -- presumably once for the "from" and once for the "to"
        # sequence, which coincide for self-attention.
        self.attn_layer = attention.MultiHeadedAttentionLayer(
            attention_type, num_attention_heads, attention_head_size,
            num_rand_blocks, seq_length, seq_length, block_size, block_size,
            attention_probs_dropout_prob, initializer_range, use_bias,
            seed, name="self")

        with tf.compat.v1.variable_scope("output"):
          # Feedforward layer projecting the per-head attention output back
          # to hidden size (presumably merges heads; see
          # utils.Dense3dProjLayer).
          self.projection_layer = utils.Dense3dProjLayer(
              num_attention_heads, attention_head_size,
              utils.create_initializer(initializer_range), None,
              "dense", use_bias)
          # Post-Normalization layer (applied after the residual add).
          self.first_layer_norm = utils.NormLayer(hidden_size)
          # Dropout on the projected attention output.
          self.attention_dropout = recompute_grad.RecomputingDropout(
              hidden_dropout_prob)

      with tf.compat.v1.variable_scope("intermediate"):
        # Feedforward layer: hidden_size -> intermediate_size with the
        # intermediate activation.
        self.expand_layer = utils.Dense2dLayer(
            hidden_size, intermediate_size,
            utils.create_initializer(initializer_range),
            intermediate_act_fn, "dense")

      with tf.compat.v1.variable_scope("output"):
        # Feedforward layer: intermediate_size -> hidden_size, no activation.
        self.contract_layer = utils.Dense2dLayer(
            intermediate_size, hidden_size,
            utils.create_initializer(initializer_range),
            None, "dense")
        # Normalization layer (applied after the second residual add).
        self.second_layer_norm = utils.NormLayer(hidden_size)
        # Dropout on the feed-forward output.
        self.output_dropout = recompute_grad.RecomputingDropout(
            hidden_dropout_prob)

  def call(self,
           layer_input,
           attention_mask=None,
           band_mask=None,
           from_mask=None,
           to_mask=None,
           input_blocked_mask=None,
           training=None):
    """Implements an encoder layer of a transformer in BERT style.

    Args:
      layer_input: float Tensor of shape [batch_size, seq_length, hidden_size].
      attention_mask: (optional) float32 Tensor of shape [batch_size,
        seq_length, seq_length]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      band_mask: (optional) float32 Tensor of shape [batch_size, 1,
        seq_length//block_size-4, block_size, 3*block_size].
        The values should be 1 or 0. The attention scores will effectively be
        set to -infinity for any positions in the mask that are 0, and will be
        unchanged for positions that are 1.
      from_mask: (optional) float32 Tensor of shape [batch_size, 1,
        seq_length, 1]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      to_mask: (optional) float32 Tensor of shape [batch_size, 1, 1,
        seq_length]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      input_blocked_mask: (optional) float32 Tensor of shape [batch_size,
        seq_length//block_size, block_size]. Same as from/to_mask, just
        reshaped.
      training: Boolean indicating whether the call is training or inference.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size].

    Raises:
      ValueError: Any of the arguments or tensor shapes are invalid.
      NotImplementedError: For unknown attention type.
    """
    # self-attention directly on the raw input; input_blocked_mask is passed
    # twice, presumably as the from- and to-blocked masks of self-attention.
    attention_output = self.attn_layer(
        layer_input, layer_input, [
            attention_mask, band_mask, from_mask, to_mask, input_blocked_mask,
            input_blocked_mask
        ], training=training)

    # Run a linear projection of `hidden_size` then add a residual
    # with `layer_input`, normalizing the sum.
    attention_output = self.projection_layer(attention_output)
    attention_output = self.attention_dropout(attention_output,
                                              training=training)
    attention_output = self.first_layer_norm(attention_output + layer_input)

    # The activation is only applied to the "intermediate" hidden layer.
    intermediate_output = self.expand_layer(attention_output)

    # Down-project back to `hidden_size`, add the residual, then normalize.
    layer_output = self.contract_layer(intermediate_output)
    layer_output = self.output_dropout(layer_output, training=training)
    layer_output = self.second_layer_norm(layer_output + attention_output)
    return layer_output


def add_gradient_recomputation(original_class):
  """Builds a subclass of `original_class` that uses gradient checkpointing.

  The returned class keeps the parent's constructor and `call` signature, but
  routes the forward pass through `recompute_grad.recompute_grad` so that
  activations are recomputed during backprop instead of being stored.

  Args:
    original_class: An encoder layer class whose `call` accepts
      (layer_input, attention_mask, band_mask, from_mask, to_mask,
      input_blocked_mask, training=...).

  Returns:
    A new class named `RecomputeLayer` deriving from `original_class`.
  """

  class RecomputeLayer(original_class):
    """Transformer layer that recomputes the forward pass during backprop."""

    def call(self,
             layer_input,
             attention_mask=None,
             band_mask=None,
             from_mask=None,
             to_mask=None,
             input_blocked_mask=None,
             training=None):
      parent_call = super(RecomputeLayer, self).call

      def forward(*layer_args):
        # `training` is a Python flag, so it is closed over rather than passed
        # through the recomputed (tensor-only) argument list.
        return parent_call(*layer_args, training=training)

      checkpointed_forward = recompute_grad.recompute_grad(forward)
      return checkpointed_forward(layer_input, attention_mask, band_mask,
                                  from_mask, to_mask, input_blocked_mask)

  return RecomputeLayer


class EncoderStack(tf.keras.layers.Layer):
  """Transformer encoder stack: N BigBird encoder layers plus a layer norm.

  `params["norm_type"]` picks the layer flavour. With "prenorm" the shared
  layer norm is applied to the output of the last layer; with "postnorm" it is
  applied to the stack input before the first layer.
  """

  def __init__(self, params):
    scope_name = "encoder"
    super(EncoderStack, self).__init__(name=scope_name)
    self.params = params

    norm_type = params["norm_type"]
    if norm_type == "prenorm":
      layer_cls = PrenormEncoderLayer
    elif norm_type == "postnorm":
      layer_cls = PostnormEncoderLayer
    else:
      raise NotImplementedError(
          "Norm type {} is not implemented".format(norm_type))

    if params["use_gradient_checkpointing"]:
      # Each layer recomputes its forward pass during backprop.
      layer_cls = add_gradient_recomputation(layer_cls)

    with tf.compat.v1.variable_scope(scope_name):
      # Layers are created in index order inside the "encoder" scope; the
      # index doubles as the layer's random seed and its "layer_<idx>" name.
      layers = []
      for idx in range(self.params["num_hidden_layers"]):
        layers.append(layer_cls(
            attention_type=self.params["attention_type"],
            hidden_size=self.params["hidden_size"],
            intermediate_size=self.params["intermediate_size"],
            intermediate_act_fn=utils.get_activation(
                self.params["hidden_act"]),
            attention_probs_dropout_prob=(
                self.params["attention_probs_dropout_prob"]),
            hidden_dropout_prob=self.params["hidden_dropout_prob"],
            initializer_range=self.params["initializer_range"],
            num_attention_heads=self.params["num_attention_heads"],
            num_rand_blocks=self.params["num_rand_blocks"],
            seq_length=self.params["max_encoder_length"],
            block_size=self.params["block_size"],
            use_bias=self.params["use_bias"],
            seed=idx,
            name="layer_%d" % idx))
      self.encoder_layers = layers

      # Normalization layer shared by both norm types (see `call`).
      self.layer_norm = utils.NormLayer(self.params["hidden_size"])

  def call(self,
           encoder_inputs,
           encoder_inputs_mask,
           training=None):
    """Return the output of the encoder layer stack.

    Args:
      encoder_inputs: tensor with shape
        [batch_size, input_length, hidden_size]
      encoder_inputs_mask: Mask for encoder input. [batch_size, input_length]
      training: Boolean indicating whether the call is training or inference.

    Returns:
      Final layer encoder output. float tensor with shape
        [batch_size, input_length, hidden_size]
    """
    params = self.params
    float_mask = tf.cast(encoder_inputs_mask, tf.float32)

    if params["attention_type"] == "block_sparse":
      seq_len = params["max_encoder_length"]
      block_len = params["block_size"]
      # Block-sparse attention consumes the padding mask in blocked, "from"
      # (row) and "to" (column) layouts, plus the derived band mask.
      blocked_mask = tf.reshape(
          float_mask, (-1, seq_len // block_len, block_len))
      from_mask = tf.reshape(float_mask, (-1, 1, seq_len, 1))
      to_mask = tf.reshape(float_mask, (-1, 1, 1, seq_len))
      band_mask = attention.create_band_mask_from_inputs(
          blocked_mask, blocked_mask)
      # Unused masks are 0.0 instead of None for recompute_grad compatibility.
      full_mask = 0.0
    else:
      # Unused masks are 0.0 instead of None for recompute_grad compatibility.
      blocked_mask = from_mask = to_mask = band_mask = 0.0
      full_mask = attention.create_attention_mask_from_input_mask(
          float_mask, float_mask)

    hidden = encoder_inputs
    if params["norm_type"] == "postnorm":
      hidden = self.layer_norm(hidden)

    for encoder_layer in self.encoder_layers:
      hidden = encoder_layer(
          hidden, full_mask, band_mask, from_mask, to_mask, blocked_mask,
          training=training)

    if params["norm_type"] == "prenorm":
      hidden = self.layer_norm(hidden)
    return hidden

In [22]:
######################  optimization ####################
# Copyright 2021 The BigBird Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions and classes related to optimization (weight updates)."""

import re

from absl import logging
import tensorflow.compat.v2 as tf

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import resource_variable_ops


def get_linear_warmup_linear_decay_lr(init_lr, num_train_steps,
                                      num_warmup_steps):
  """Learning rate with linear warmup followed by linear decay to zero.

  While global_step < num_warmup_steps the rate is
  `init_lr * global_step / num_warmup_steps`; afterwards it follows a linear
  (power=1) polynomial decay from `init_lr` to 0 over `num_train_steps`.

  Args:
    init_lr: Peak learning rate.
    num_train_steps: Number of steps over which the rate decays to 0.
    num_warmup_steps: Warmup length; a falsy value disables warmup.

  Returns:
    A scalar float32 learning-rate tensor.
  """
  step = tf.compat.v1.train.get_or_create_global_step()

  peak_lr = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
  decayed_lr = tf.compat.v1.train.polynomial_decay(
      peak_lr, step, num_train_steps,
      end_learning_rate=0.0, power=1.0, cycle=False)

  if not num_warmup_steps:
    return decayed_lr

  in_warmup = tf.cast(
      tf.cast(step, tf.int32) < tf.constant(num_warmup_steps, dtype=tf.int32),
      tf.float32)
  warmup_lr = init_lr * (
      tf.cast(step, tf.float32) / tf.cast(num_warmup_steps, tf.float32))
  # Select between the two schedules with a 0/1 float mask.
  return is_warmup_blend(in_warmup, warmup_lr, decayed_lr) if False else (
      (1.0 - in_warmup) * decayed_lr + in_warmup * warmup_lr)


def get_linear_warmup_rsqrt_decay_lr(init_lr, hidden_size,
                                     num_warmup_steps):
  """Calculate learning rate with linear warmup and rsqrt decay.

  lr(step) = init_lr * hidden_size**-0.5
             * min(1, step / warmup) * max(step, warmup)**-0.5

  Both denominators are clamped to at least 1. This leaves every integer
  step >= 1 unchanged. It avoids the 0/0 and rsqrt(0) at step 0 that would
  otherwise yield a NaN/inf learning rate when `num_warmup_steps` is 0.

  Args:
    init_lr: Base learning rate before the hidden-size scaling.
    hidden_size: Model hidden size; the rate is scaled by its inverse sqrt.
    num_warmup_steps: Number of linear warmup steps (may be 0).

  Returns:
    A scalar float32 learning-rate tensor.
  """
  num_warmup_steps = tf.cast(num_warmup_steps, tf.float32)
  global_step = tf.compat.v1.train.get_or_create_global_step()
  global_step = tf.cast(global_step, tf.float32)
  # Guard against num_warmup_steps == 0. Without it, step 0 computes
  # min(1, 0/0) * rsqrt(max(0, 0)), which is NaN/inf and poisons the first
  # update. For warmup >= 1 this is a no-op.
  safe_warmup_steps = tf.maximum(num_warmup_steps, 1.0)

  learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
  learning_rate *= tf.math.rsqrt(tf.cast(hidden_size, tf.float32))
  # Apply linear warmup
  learning_rate *= tf.minimum(1.0, global_step / safe_warmup_steps)
  # Apply rsqrt decay
  learning_rate *= tf.math.rsqrt(tf.maximum(global_step, safe_warmup_steps))

  return learning_rate


def get_optimizer(params, learning_rate):
  """Gets the optimizer based on the hparams and current mode (TPU vs. CPU/GPU).

  Args:
      params: A dictionary containing training hyperparameters. Reads
        "optimizer" (one of "Adafactor", "Adam", "AdamWeightDecay", "SGD") and
        "use_tpu". Depending on the optimizer it also reads
        "optimizer_beta1", "optimizer_beta2", "optimizer_epsilon" and
        "weight_decay_rate". NOTE: if "Adafactor" is requested but
        tensor2tensor is not installed, `params["optimizer"]` is overwritten
        in place with "Adam".
      learning_rate: A float32 scalar.

  Returns:
    An optimizer instance. It is wrapped in
    `tf.compat.v1.tpu.CrossShardOptimizer` when `params["use_tpu"]` is truthy.

  Raises:
    ValueError: If `params["optimizer"]` does not name a known optimizer.
  """
  optimizer = None

  if params["optimizer"] == "Adafactor":
    try:
      from tensor2tensor.utils import adafactor  # pylint: disable=g-import-not-at-top
      optimizer = adafactor.AdafactorOptimizer(learning_rate=learning_rate)
    except ImportError:
      logging.error("tensor2tensor not installed. Cannot use Adafactor. "
                    "Defaulting to Adam.")
      # Deliberately recorded in `params` so the Adam branch below fires and
      # callers can see which optimizer was actually used.
      params["optimizer"] = "Adam"

  if params["optimizer"] == "Adam":
    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate,
        beta1=params["optimizer_beta1"],
        beta2=params["optimizer_beta2"],
        epsilon=params["optimizer_epsilon"])

  if params["optimizer"] == "AdamWeightDecay":
    optimizer = AdamWeightDecayOptimizer(
        learning_rate,
        weight_decay_rate=params["weight_decay_rate"],
        beta_1=params["optimizer_beta1"],
        beta_2=params["optimizer_beta2"],
        epsilon=params["optimizer_epsilon"],
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

  if params["optimizer"] == "SGD":
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate)

  if optimizer is None:
    raise ValueError("Unknown optimizer: {}.".format(params["optimizer"]))

  if params["use_tpu"]:
    # Average the gradients across TPU cores.
    optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)

  return optimizer


class AdamWeightDecayOptimizer(tf.compat.v1.train.Optimizer):
  """A basic Adam optimizer that includes "correct" L2 weight decay.

  The weight decay is decoupled from the adaptive moments. It is added to the
  Adam step direction rather than to the gradient, so it never enters the `m`
  or `v` slots. Unlike the textbook Adam update, no bias correction is applied
  to `m` and `v`.
  """

  def __init__(self,
               learning_rate,
               weight_decay_rate=0.0,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-6,
               exclude_from_weight_decay=None,
               name="AdamWeightDecayOptimizer"):
    """Constructs an AdamWeightDecayOptimizer.

    Args:
      learning_rate: Scalar step size, multiplied directly into the update.
      weight_decay_rate: Coefficient of the decoupled weight decay; a falsy
        value disables decay for all variables.
      beta_1: Decay rate of the first-moment (`m`) estimate.
      beta_2: Decay rate of the second-moment (`v`) estimate.
      epsilon: Constant added to sqrt(v) in the denominator for stability.
      exclude_from_weight_decay: Optional list of regex patterns; a variable
        whose name matches any of them (via `re.search`) gets no decay.
      name: Name passed to the base `Optimizer`.
    """
    # The first positional argument of the base class is `use_locking`.
    super(AdamWeightDecayOptimizer, self).__init__(False, name)

    self.learning_rate = learning_rate
    self.weight_decay_rate = weight_decay_rate
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.exclude_from_weight_decay = exclude_from_weight_decay

  def _create_slots(self, var_list):
    # Create zero-initialized slots for the first and second moments.
    for v in var_list:
      self._zeros_slot(v, "m", self._name)
      self._zeros_slot(v, "v", self._name)

  def _apply_dense(self, grad, var):
    """Applies a dense gradient; returns a group of the three assign ops."""
    param_name = self._get_variable_name(var.name)
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")

    # Standard Adam update (without bias correction).
    next_m = (
        tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
    next_v = (
        tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
                                                  tf.square(grad)))

    update = next_m / (tf.sqrt(next_v) + self.epsilon)

    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters. This is equivalent to adding the square
    # of the weights to the loss with plain (non-momentum) SGD.
    if self._do_use_weight_decay(param_name):
      update += self.weight_decay_rate * var

    update_with_lr = self.learning_rate * update

    next_param = var - update_with_lr

    return tf.group(
        [var.assign(next_param),
         m.assign(next_m),
         v.assign(next_v)])

  def _resource_apply_dense(self, grad, var):
    """See `tf.train.Optimizer._resource_apply_dense()`."""
    # The dense math is identical for resource variables.
    return self._apply_dense(grad, var)

  def _apply_sparse(self, grad, var):
    """See `tf.train.Optimizer._apply_sparse()`."""
    # Ref-variable path: scatter via tf.compat.v1.scatter_update.
    def scatter_update_fn(x, i, v):
      return tf.compat.v1.scatter_update(x, i, v, use_locking=self._use_locking)
    return self._apply_sparse_shared(
        grad.values, grad.indices, var, scatter_update_fn)

  def _resource_apply_sparse(self, grad, var, indices):
    """See `tf.train.Optimizer._resource_apply_sparse()`."""
    # Resource-variable path: scatter on the handle, then read the updated
    # value only after the scatter has run.
    def scatter_update_fn(x, i, v):
      with tf.control_dependencies(
          [resource_variable_ops.resource_scatter_update(x.handle, i, v)]):
        return x.value()
    return self._apply_sparse_shared(grad, indices, var, scatter_update_fn)

  def _apply_sparse_shared(self, grad, indices, var, scatter_update_fn):
    """Applies sparse gradients to a variable.

    Every row of `m` and `v` is first decayed by beta_1/beta_2. The scaled
    gradient terms are then added only at `indices`. The resulting update
    (including weight decay) is applied to the whole variable, not just the
    touched rows.

    Args:
      grad: A tensor for the `values` of `tf.IndexedSlices`.
      indices: A tensor for the `indices` of `tf.IndexedSlices`.
      var: A `tf.Variable` object.
      scatter_update_fn: A function which performs scattered update to
        a `tf.Variable` object. It takes tuple of (x, i, v) where:
          * x: A `tf.Variable` object which is updated by `i` and `v`,
          * i: A tensor for the `indices` of `tf.IndexedSlices`,
          * v: A tensor for the `values` of `tf.IndexedSlices`,
        and returns a tensor after updating `x`.

    Returns:
      An op which updates `var` with `grad` and `indices`.
    """
    param_name = self._get_variable_name(var.name)
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")

    # m_t = beta1 * m + (1 - beta1) * g_t
    # The control dependency ensures the gather reads the already-decayed m.
    m_scaled_g_values = tf.multiply(1.0 - self.beta_1, grad)
    m_t = m.assign(m * self.beta_1)
    with tf.control_dependencies([m_t]):
      m_slice = tf.gather(m, indices) + m_scaled_g_values
      m_t = scatter_update_fn(m, indices, m_slice)

    # v_t = beta2 * v + (1 - beta2) * g_t^2
    # Same ordering requirement as for m above.
    v_scaled_g_values = tf.multiply(1.0 - self.beta_2, tf.square(grad))
    v_t = v.assign(v * self.beta_2)
    with tf.control_dependencies([v_t]):
      v_slice = tf.gather(v, indices) + v_scaled_g_values
      v_t = scatter_update_fn(v, indices, v_slice)

    update = m_t / (tf.sqrt(v_t) + self.epsilon)

    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters. This is equivalent to adding the square
    # of the weights to the loss with plain (non-momentum) SGD.
    if self._do_use_weight_decay(param_name):
      update += self.weight_decay_rate * var

    update_with_lr = self.learning_rate * update

    next_param = var - update_with_lr

    return tf.group([var.assign(next_param), m_t, v_t])

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`.

    False when weight_decay_rate is falsy or when any pattern in
    `exclude_from_weight_decay` is found (via `re.search`) in the name.
    """
    if not self.weight_decay_rate:
      return False
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _get_variable_name(self, param_name):
    """Get the variable name from the tensor name.

    Strips a trailing ":<digits>" output-index suffix, e.g. "w:0" -> "w";
    names without such a suffix are returned unchanged.
    """
    m = re.match("^(.*):\\d+$", param_name)
    if m is not None:
      param_name = m.group(1)
    return param_name

In [23]:
############################ utils #############################
# Copyright 2021 The BigBird Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper and utility functions."""

import re

from absl import logging
import numpy as np
import tensorflow.compat.v2 as tf


############################### SHAPE UTILS ####################################


def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns the shape of `tensor` as a list, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int or list of ints. If given and the rank of
      `tensor` is not among them, a ValueError is raised (via `assert_rank`).
    name: Optional name of the tensor for the error message. In graph mode it
      defaults to `tensor.name`.

  Returns:
    A list with one entry per dimension. Statically known dimensions are
    Python ints; unknown ones are scalar tensors taken from `tf.shape`.
  """
  if not tf.executing_eagerly() and name is None:
    name = tensor.name

  if expected_rank is not None:
    assert_rank(tensor, expected_rank, name)

  dims = tensor.shape.as_list()
  unknown_axes = [axis for axis, size in enumerate(dims) if size is None]
  if not unknown_axes:
    return dims

  # Fill the statically-unknown slots from the runtime shape.
  runtime_shape = tf.shape(tensor)
  for axis in unknown_axes:
    dims[axis] = runtime_shape[axis]
  return dims


def reshape_to_matrix(input_tensor):
  """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).

  All leading dimensions are flattened into the first axis; the last axis is
  kept as the matrix width.

  Args:
    input_tensor: Tensor of rank >= 2.

  Returns:
    `input_tensor` itself if it is already rank 2. Otherwise a rank-2 tensor
    of shape [prod(leading dims), width].

  Raises:
    ValueError: If `input_tensor` has rank < 2.
  """
  ndims = input_tensor.shape.ndims
  if ndims < 2:
    raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                     (input_tensor.shape))
  if ndims == 2:
    return input_tensor

  # `shape[-1]` is None when the last dimension is only known at run time,
  # and tf.reshape rejects None. get_shape_list returns the static int when
  # available and otherwise the dynamic size (as reshape_from_matrix does).
  width = get_shape_list(input_tensor)[-1]
  output_tensor = tf.reshape(input_tensor, [-1, width])
  return output_tensor


def reshape_from_matrix(output_tensor, orig_shape_list):
  """Restores a rank-2 tensor to the rank described by `orig_shape_list`.

  Args:
    output_tensor: Rank-2 tensor, typically produced by `reshape_to_matrix`.
    orig_shape_list: Shape list (as from `get_shape_list`) of the tensor
      before it was flattened.

  Returns:
    `output_tensor` unchanged when `orig_shape_list` has length 2. Otherwise
    a tensor of shape orig_shape_list[:-1] + [width of output_tensor].
  """
  if len(orig_shape_list) != 2:
    leading_dims = orig_shape_list[:-1]
    last_dim = get_shape_list(output_tensor)[-1]
    output_tensor = tf.reshape(output_tensor, leading_dims + [last_dim])
  return output_tensor


def assert_rank(tensor, expected_rank, name=None):
  """Raises a ValueError unless `tensor` has one of the expected ranks.

  Args:
    tensor: A tf.Tensor to check the rank of.
    expected_rank: Python integer or iterable of integers, the accepted rank(s).
    name: Optional name of the tensor for the error message. Outside eager
      mode it defaults to `tensor.name`.

  Raises:
    ValueError: If the static rank of `tensor` is not an accepted rank.
  """
  if not tf.executing_eagerly() and name is None:
    name = tensor.name

  if isinstance(expected_rank, int):
    accepted_ranks = {expected_rank}
  else:
    accepted_ranks = set(expected_rank)

  actual_rank = tensor.shape.ndims
  if actual_rank in accepted_ranks:
    return

  scope_name = tf.compat.v1.get_variable_scope().name
  raise ValueError(
      "For the tensor `{}` in scope `{}`, the actual rank "
      "`{}` (shape = {}) is not equal to the expected rank `{}`".format(
          name, scope_name, actual_rank, str(tensor.shape),
          str(expected_rank)))


############################### DENSE LAYERS ###################################


def create_initializer(initializer_range=0.02):
  """Returns a truncated-normal initializer with stddev `initializer_range`."""
  truncated_normal = tf.compat.v1.truncated_normal_initializer
  return truncated_normal(stddev=initializer_range)


class Dense3dLayer(tf.keras.layers.Layer):
  """Dense projection whose output is split into attention heads."""

  def __init__(self,
               num_attention_heads,
               size_per_head,
               initializer,
               activation,
               name=None,
               head_first=False,
               use_bias=True):
    """Creates the kernel (and optional bias) in variable scope `name`.

    Args:
      num_attention_heads: Number of heads the output is split into.
      size_per_head: The size per attention head.
      initializer: Kernel initializer.
      activation: Activation function applied to the output, or None.
      name: The name scope of this layer.
      head_first: Whether to output head dimension before or after sequence dim.
      use_bias: Whether the layer uses a bias vector.
    """
    super(Dense3dLayer, self).__init__(name=name)
    self.num_attention_heads = num_attention_heads
    self.size_per_head = size_per_head
    self.initializer = initializer
    self.activation = activation
    self.head_first = head_first
    self.use_bias = use_bias

    hidden_size = num_attention_heads * size_per_head
    with tf.compat.v1.variable_scope(name):
      # Stored as a square [hidden, hidden] matrix; `call` views it as 3D.
      self.w = tf.compat.v1.get_variable(
          name="kernel",
          shape=[hidden_size, hidden_size],
          initializer=initializer)

      if use_bias:
        self.b = tf.compat.v1.get_variable(
            name="bias",
            shape=[hidden_size],
            initializer=tf.zeros_initializer())
      else:
        self.b = None

  def call(self, input_tensor):
    """Projects `input_tensor` and splits the result into heads.

    Args:
      input_tensor: float Tensor of shape [batch, seq_length, hidden_size].

    Returns:
      float Tensor of shape [batch, num_attention_heads, seq_length,
      size_per_head] if `head_first`, else [batch, seq_length,
      num_attention_heads, size_per_head].
    """
    heads, head_dim = self.num_attention_heads, self.size_per_head
    kernel = tf.reshape(self.w, [heads * head_dim, heads, head_dim])
    equation = "abc,cde->adbe" if self.head_first else "abc,cde->abde"
    projected = tf.einsum(equation, input_tensor, kernel)

    if self.use_bias:
      # Bias must broadcast against the head/sequence layout chosen above.
      bias_shape = ([1, heads, 1, head_dim] if self.head_first
                    else [heads, head_dim])
      projected += tf.reshape(self.b, bias_shape)

    if self.activation is None:
      return projected
    return self.activation(projected)


class Dense3dProjLayer(tf.keras.layers.Layer):
  """Dense projection that merges attention heads back into a hidden vector."""

  def __init__(self,
               num_attention_heads,
               size_per_head,
               initializer,
               activation,
               name=None,
               use_bias=True):
    """Creates the projection kernel (and optional bias) in scope `name`.

    Args:
      num_attention_heads: Number of attention heads in the input.
      size_per_head: The size per attention head.
      initializer: Kernel initializer.
      activation: Activation function applied to the output, or None.
      name: The name scope of this layer.
      use_bias: Whether the layer uses a bias vector.
    """
    super(Dense3dProjLayer, self).__init__(name=name)
    self.num_attention_heads = num_attention_heads
    self.size_per_head = size_per_head
    self.initializer = initializer
    self.activation = activation
    self.use_bias = use_bias

    hidden_size = num_attention_heads * size_per_head
    with tf.compat.v1.variable_scope(name):
      # Stored as a square [hidden, hidden] matrix; `call` views it as 3D.
      self.w = tf.compat.v1.get_variable(
          name="kernel",
          shape=[hidden_size, hidden_size],
          initializer=initializer)

      if use_bias:
        self.b = tf.compat.v1.get_variable(
            name="bias",
            shape=[hidden_size],
            initializer=tf.zeros_initializer())
      else:
        self.b = None

  def call(self, input_tensor):
    """Merges heads into a single hidden vector per position.

    Args:
      input_tensor: float Tensor of shape [batch, from_seq_length,
        num_attention_heads, size_per_head].

    Returns:
      float Tensor of shape [batch, from_seq_length,
      num_attention_heads * size_per_head].
    """
    heads, head_dim = self.num_attention_heads, self.size_per_head
    kernel = tf.reshape(self.w, [heads, head_dim, heads * head_dim])
    projected = tf.einsum("BFNH,NHD->BFD", input_tensor, kernel)

    if self.use_bias:
      projected += self.b

    if self.activation is None:
      return projected
    return self.activation(projected)


class Dense2dLayer(tf.keras.layers.Layer):
  """Dense layer with a 2D kernel applied to the last axis of a rank-3 input."""

  def __init__(self,
               input_size,
               output_size,
               initializer,
               activation,
               name=None,
               use_bias=True):
    """Creates the [input_size, output_size] kernel and optional bias.

    Args:
      input_size: The size of input dimension.
      output_size: The size of output dimension.
      initializer: Kernel initializer.
      activation: Activation function applied to the output, or None.
      name: The name scope of this layer.
      use_bias: Whether the layer uses a bias vector.
    """
    super(Dense2dLayer, self).__init__(name=name)
    self.input_size = input_size
    self.output_size = output_size
    self.initializer = initializer
    self.activation = activation
    self.use_bias = use_bias

    with tf.compat.v1.variable_scope(name):
      self.w = tf.compat.v1.get_variable(
          name="kernel",
          shape=[input_size, output_size],
          initializer=initializer)

      if use_bias:
        self.b = tf.compat.v1.get_variable(
            name="bias",
            shape=[output_size],
            initializer=tf.zeros_initializer())
      else:
        self.b = None

  def call(self, input_tensor):
    """Forward pass for dense layer with 2D kernel.

    Args:
      input_tensor: Float tensor with rank 3 whose last dim is `input_size`.

    Returns:
      Float tensor with rank 3 whose last dim is `output_size`.
    """
    outputs = tf.einsum("abc,cd->abd", input_tensor, self.w)
    if self.use_bias:
      outputs += self.b
    return outputs if self.activation is None else self.activation(outputs)


class SimpleDenseLayer(tf.keras.layers.Layer):
  """Dense layer with a 2D kernel applied to a rank-2 (matrix) input."""

  def __init__(self,
               input_size,
               output_size,
               initializer,
               activation,
               name=None,
               use_bias=True):
    """Creates the [input_size, output_size] kernel and optional bias.

    Args:
      input_size: The size of input dimension.
      output_size: The size of output dimension.
      initializer: Kernel initializer.
      activation: Activation function applied to the output, or None.
      name: The name scope of this layer.
      use_bias: Whether the layer uses a bias vector.
    """
    super(SimpleDenseLayer, self).__init__(name=name)
    self.input_size = input_size
    self.output_size = output_size
    self.initializer = initializer
    self.activation = activation
    self.use_bias = use_bias

    with tf.compat.v1.variable_scope(name):
      self.w = tf.compat.v1.get_variable(
          name="kernel",
          shape=[input_size, output_size],
          initializer=initializer)

      if use_bias:
        self.b = tf.compat.v1.get_variable(
            name="bias",
            shape=[output_size],
            initializer=tf.zeros_initializer())
      else:
        self.b = None

  def call(self, input_tensor):
    """Forward pass for dense layer with 2D kernel.

    Args:
      input_tensor: Float tensor with rank 2, shape [rows, input_size].

    Returns:
      Float tensor of shape [rows, output_size].
    """
    outputs = tf.einsum("ab,bc->ac", input_tensor, self.w)
    if self.use_bias:
      outputs += self.b
    return outputs if self.activation is None else self.activation(outputs)


def gelu(x):
  """Gaussian Error Linear Unit, using the tanh approximation.

  Computes x * 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))), a
  smoother version of the RELU.
  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    x: float Tensor to perform activation.

  Returns:
    `x` with the GELU activation applied.
  """
  inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
  return x * (0.5 * (1.0 + tf.tanh(inner)))


def get_activation(activation_string):
  """Maps an activation name to a Python function, e.g. "relu" => `tf.nn.relu`.

  Args:
    activation_string: String name of the activation function (case
      insensitive), or an already-resolved activation.

  Returns:
    The activation function. Returns None for None, an empty string or
    "linear". Any non-string argument is returned as-is.

  Raises:
    ValueError: The `activation_string` does not correspond to a known
      activation.
  """
  # Anything that's not a string is assumed to already be an activation
  # function (or None), so it is passed straight through.
  if not isinstance(activation_string, str):
    return activation_string
  if not activation_string:
    return None

  act = activation_string.lower()
  # Thunks keep the lookups lazy: only the selected entry is ever resolved.
  resolvers = {
      "linear": lambda: None,
      "relu": lambda: tf.nn.relu,
      "gelu": lambda: gelu,
      "tanh": lambda: tf.tanh,
  }
  if act not in resolvers:
    raise ValueError("Unsupported activation: %s" % act)
  return resolvers[act]()


############################## NORM LAYERS #####################################


class NormLayer(tf.keras.layers.Layer):
  """Layer normalization over the last axis.

  Replacement for contrib_layers.layer_norm, with a learned offset (`beta`,
  initialized to zeros) and scale (`gamma`, initialized to ones) of size `hdim`.
  """

  def __init__(self, hdim, dtype=tf.float32, name="LayerNorm"):
    super(NormLayer, self).__init__(name=name)
    self._dtype = dtype

    with tf.compat.v1.variable_scope(name):
      self.beta = tf.compat.v1.get_variable(
          "beta", [hdim], dtype=dtype, initializer=tf.zeros_initializer())
      self.gamma = tf.compat.v1.get_variable(
          "gamma", [hdim], dtype=dtype, initializer=tf.ones_initializer())

  def call(self, inputs):
    """Normalizes `inputs` over the last axis, then scales and shifts.

    Returns:
      A tensor with the same static shape as `inputs`.
    """
    static_shape = inputs.shape
    mean, variance = tf.nn.moments(inputs, [-1], keepdims=True)
    # float16 has a limited representable range, so it needs a larger epsilon.
    epsilon = 1e-12 if self._dtype != tf.float16 else 1e-3
    normalized = tf.nn.batch_normalization(
        inputs,
        mean,
        variance,
        offset=self.beta,
        scale=self.gamma,
        variance_epsilon=epsilon)
    normalized.set_shape(static_shape)
    return normalized


############################# EMBEDDING LAYER ##################################


class EmbeddingLayer(tf.keras.layers.Layer):
  """Word embeddings plus optional token-type and position embeddings."""

  def __init__(self,
               vocab_size,
               emb_dim,
               initializer,
               scale_emb=False,
               use_token_type=False,
               num_token_types=16,
               use_position_embeddings=True,
               max_position_embeddings=4096,
               dropout_prob=0.0,
               name="embeddings"):
    """Creates the embedding tables (float32) in variable scope `name`.

    Args:
      vocab_size: Number of rows in the word embedding table.
      emb_dim: Embedding width shared by all tables.
      initializer: Initializer used for every table.
      scale_emb: If True, word embeddings are multiplied by sqrt(emb_dim).
      use_token_type: Whether to create the token-type table.
      num_token_types: Number of rows in the token-type table.
      use_position_embeddings: Whether to create the position table.
      max_position_embeddings: Number of rows in the position table.
      dropout_prob: Dropout rate applied in `call` when training.
      name: Layer name, also used as the variable scope.
    """
    super(EmbeddingLayer, self).__init__(name=name)
    self.vocab_size = vocab_size
    self.emb_dim = emb_dim
    self.scale_emb = scale_emb
    self.num_token_types = num_token_types
    self.max_position_embeddings = max_position_embeddings
    self.dropout_prob = dropout_prob

    with tf.compat.v1.variable_scope(name):
      self.word_embeddings = tf.compat.v1.get_variable(
          "word_embeddings", [vocab_size, emb_dim],
          dtype=tf.float32, initializer=initializer)

      if use_token_type:
        self.token_type_table = tf.compat.v1.get_variable(
            "token_type_embeddings", [num_token_types, emb_dim],
            dtype=tf.float32, initializer=initializer)
      else:
        self.token_type_table = None

      if use_position_embeddings:
        self.position_embeddings = tf.compat.v1.get_variable(
            "position_embeddings", [max_position_embeddings, emb_dim],
            dtype=tf.float32, initializer=initializer)
      else:
        self.position_embeddings = None

  def call(self,
           input_ids,
           seq_length,
           start_pos=0,
           token_type_ids=None,
           training=None):
    """Embeds `input_ids` and adds the optional type/position embeddings.

    Args:
      input_ids: integer ids to look up in the word table. If None, the
        layer returns None.
      seq_length: number of rows taken from the position table.
      start_pos: first row of the position table to use.
      token_type_ids: ids into the token-type table. Only read when that
        table exists.
      training: dropout is applied only when truthy and `dropout_prob > 0`.

    Returns:
      The summed embeddings, or None if `input_ids` is None.
    """
    if input_ids is None:
      return None

    embedded = tf.nn.embedding_lookup(params=self.word_embeddings,
                                      ids=input_ids)
    if self.scale_emb:
      embedded = embedded * self.emb_dim ** 0.5

    if self.token_type_table is not None:
      # The token-type vocabulary is small, so a one-hot product is used
      # instead of a gather.
      type_one_hot = tf.one_hot(token_type_ids, depth=self.num_token_types)
      embedded += tf.tensordot(type_one_hot, self.token_type_table, 1)

    if self.position_embeddings is not None:
      # Rows [start_pos, start_pos + seq_length) of the table, broadcast over
      # the leading axis. NOTE: start_pos + seq_length is not checked against
      # max_position_embeddings here.
      positions = tf.slice(self.position_embeddings, [start_pos, 0],
                           [seq_length, self.emb_dim])
      embedded += tf.expand_dims(positions, axis=0)

    if training and self.dropout_prob > 0:
      embedded = tf.nn.dropout(embedded, self.dropout_prob)
    return embedded

  def linear(self, x):
    """Computes logits by multiplying `x` with the transposed word table.

    Args:
      x: A float32 tensor with shape [..., hidden_size]
    Returns:
      float32 tensor with shape [..., vocab_size].
    """
    with tf.compat.v1.name_scope("presoftmax_linear"):
      return tf.tensordot(x, self.word_embeddings, [[-1], [1]])


########################## TPU/CHECKPOINT UTILS ################################


def get_estimator(config, model_fn, keep_checkpoint_max=10):
  """Create TPUEstimator object for given config and model_fn.

  Runs on TPU when `config["use_tpu"]` is set. Otherwise it falls back to
  CPU/GPU with a `tf.distribute.MirroredStrategy`.

  Args:
    config: dict of settings. Keys read here: "use_tpu", "tpu_name",
      "tpu_zone", "gcp_project", "train_batch_size", "eval_batch_size",
      "num_tpu_cores", "master", "output_dir", "save_checkpoints_steps",
      "tpu_job_name", "iterations_per_loop", "init_checkpoint". This function
      also WRITES `config["ckpt_var_list"]`, a dict of checkpoint variable
      name -> shape, or {} when there is no init checkpoint.
    model_fn: model function forwarded to the TPUEstimator.
    keep_checkpoint_max: maximum number of recent checkpoints to keep.

  Returns:
    The TPUEstimator. Its `train_batch_size` / `eval_batch_size` attributes
    are overwritten with the effective global batch sizes computed here.
  """
  tpu_cluster_resolver = None
  if config["use_tpu"] and config["tpu_name"]:
    tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        config["tpu_name"],
        zone=config["tpu_zone"],
        project=config["gcp_project"])

  # Batch size book-keeping
  # Estimators handle batch sizes differently among GPUs and TPUs
  # GPU: Estimator needs per core batch size
  # TPU: Estimator needs total batch size, i.e. num_cores * per core batch size
  config_train_batch_size = config["train_batch_size"]     # For estimator
  config_eval_batch_size = config["eval_batch_size"]       # For estimator
  effective_train_batch_size = config["train_batch_size"]  # For human
  effective_eval_batch_size = config["eval_batch_size"]    # For human
  session_config = None
  if config["use_tpu"]:
    sliced_eval_mode = tf.compat.v1.estimator.tpu.InputPipelineConfig.SLICED
    distribute_strategy = None
    config_train_batch_size *= config["num_tpu_cores"]
    config_eval_batch_size *= config["num_tpu_cores"]
    effective_train_batch_size = config_train_batch_size
    effective_eval_batch_size = config_eval_batch_size
  else:
    # NOTE(review): a GPU memory fraction above 1.0 is unusual; presumably it
    # is meant to let TF oversubscribe GPU memory. Confirm it is intentional.
    session_config = tf.compat.v1.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.compat.v1.GPUOptions(
            per_process_gpu_memory_fraction=1.2))
    # Opens a short-lived Session only to log the devices it can see.
    cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
    with tf.compat.v1.Session(cluster_resolver.master(),
                              config=session_config) as sess:
      logging.info(sess.list_devices())
    sliced_eval_mode = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V1
    # devices=None: presumably mirrors over all visible local devices. The
    # human-facing train batch is scaled by the replica count; the matching
    # eval scaling below is commented out.
    distribute_strategy = tf.distribute.MirroredStrategy(devices=None)
    effective_train_batch_size *= distribute_strategy.num_replicas_in_sync
    # effective_eval_batch_size *= distribute_strategy.num_replicas_in_sync

  is_per_host = tf.compat.v1.estimator.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.compat.v1.estimator.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=config["master"],
      model_dir=config["output_dir"],
      save_checkpoints_steps=config["save_checkpoints_steps"],
      keep_checkpoint_max=keep_checkpoint_max,
      train_distribute=distribute_strategy,
      session_config=session_config,
      tpu_config=tf.compat.v1.estimator.tpu.TPUConfig(
          tpu_job_name=config["tpu_job_name"],
          iterations_per_loop=config["iterations_per_loop"],
          num_shards=config["num_tpu_cores"],
          per_host_input_for_training=is_per_host,
          eval_training_input_configuration=sliced_eval_mode))

  if config["init_checkpoint"]:
    # Warm-start from every checkpoint variable whose name does not mention
    # Adam, Adafactor or global_step (presumably optimizer/step state). The
    # kept names are joined, unescaped, into a single regex alternation.
    ckpt_var_list = tf.compat.v1.train.list_variables(config["init_checkpoint"])
    ckpt_var_list = {
        name: shape for name, shape in ckpt_var_list
        if not re.findall("(Adam|Adafactor|global_step)", name)
    }
    vars_to_warm_start = "({})".format("|".join(ckpt_var_list.keys()))
    warm_start_settings = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=config["init_checkpoint"],
        vars_to_warm_start=vars_to_warm_start)
  else:
    ckpt_var_list = {}
    warm_start_settings = None
  # Exposed so callers (e.g. log_variables) can report warm-started variables.
  config["ckpt_var_list"] = ckpt_var_list

  # If no TPU, this will fall back to normal Estimator on CPU or GPU.
  estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
      use_tpu=config["use_tpu"],
      model_fn=model_fn,
      config=run_config,
      train_batch_size=config_train_batch_size,
      eval_batch_size=config_eval_batch_size,
      warm_start_from=warm_start_settings)

  # Overwrite with the effective ("For human") global batch sizes.
  estimator.train_batch_size = effective_train_batch_size
  estimator.eval_batch_size = effective_eval_batch_size

  return estimator


def log_variables(variables, ckpt_var_list):
  """Log trainable variables and whether each one was warm-started.

  Logs a table of (name, "from ckpt" | "random", shape) for every model
  variable. It then warns about checkpoint entries that no model variable
  consumed.

  Args:
    variables: iterable of variables exposing `.name` (e.g. "a/b:0") and
      `.get_shape().as_list()`. May be empty.
    ckpt_var_list: dict mapping checkpoint variable names (without the ":0"
      suffix) to shapes. It is copied, not modified.
  """
  logging.info("**** Trainable Variables ****")

  model_var_list = {var.name: var.get_shape().as_list() for var in variables}
  # int() per variable: np.prod([]) is the float 1.0 for scalar variables,
  # which would otherwise turn the total into a float ("1,234.0").
  num_params = sum(int(np.prod(shape)) for shape in model_var_list.values())
  # default= keeps an empty variable collection from crashing max().
  length = max((len(name) for name in model_var_list), default=0) + 2
  line = "{{:<{}}}{{:<13}}{{}}".format(length)

  logging.info("The model has {} trainable variables "
               "({:,} parameters):\n".format(len(model_var_list), num_params))
  logging.info(line.format("Name", "Initialized", "Shape"))
  logging.info(line.format("----", "-----------", "-----"))

  # Work on a copy so the caller's dict is untouched; whatever remains at the
  # end was present in the checkpoint but unused by the model.
  ckpt_var_list = ckpt_var_list.copy()
  for name, shape in model_var_list.items():
    name = name.split(":")[0]
    if name in ckpt_var_list:
      warm_started = "from ckpt"
      del ckpt_var_list[name]
    else:
      warm_started = "random"
    logging.info(line.format(name, warm_started, shape))

  if ckpt_var_list:
    logging.warning(
        "The warm start checkpoint contained %d variables that were not used "
        "for the model:\n", len(ckpt_var_list))
    for name, shape in ckpt_var_list.items():
      logging.warning(line.format(name, "not used", shape))


def add_scalars_to_summary(summary_dir, scalar_tensors_dict):
  """Creates a host_call function that writes scalar summaries on TPU.

  Args:
    summary_dir: directory the summary file writer writes into.
    scalar_tensors_dict: dict mapping summary names to scalar tensors.

  Returns:
    A `(host_call_fn, tensors)` pair. `tensors` holds every input reshaped
    to shape [1]. `host_call_fn(**tensors)` writes the mean of each entry as
    a scalar summary at the global step and returns the v2 summary ops.
  """
  # All tensors outfed from TPU must keep a batch dimension, hence shape [1].
  outfeed_tensors = {}
  for summary_name, tensor in scalar_tensors_dict.items():
    outfeed_tensors[summary_name] = tf.reshape(tensor, [1])

  def host_call_fn(**kwargs):
    writer = tf.summary.create_file_writer(summary_dir, max_queue=1000)
    always_record = tf.summary.record_if(True)
    with writer.as_default(), always_record:
      for summary_name, values in kwargs.items():
        tf.summary.scalar(summary_name, tf.reduce_mean(values),
                          tf.compat.v1.train.get_or_create_global_step())
      return tf.compat.v1.summary.all_v2_summary_ops()

  return host_call_fn, outfeed_tensors


########################## DEFAULT CONFIG UTILS ################################


def get_default_config():
  """Returns a fresh dict with the default BigBird hyper-parameters.

  A new dict is built on every call, so callers may mutate the result.
  """
  return dict(
      # transformer basic configs
      attention_probs_dropout_prob=0.1,
      hidden_act="gelu",
      hidden_dropout_prob=0.1,
      hidden_size=768,
      initializer_range=0.02,
      intermediate_size=3072,
      max_position_embeddings=4096,
      num_attention_heads=12,
      num_hidden_layers=12,
      type_vocab_size=2,
      use_bias=True,
      rescale_embedding=False,
      scope="bert",
      # sparse mask configs
      attention_type="block_sparse",
      norm_type="postnorm",
      block_size=16,
      num_rand_blocks=3,
      # common bert configs
      max_encoder_length=1024,
      max_decoder_length=64,
      couple_encoder_decoder=False,
      beam_size=5,
      alpha=0.7,
      label_smoothing=0.1,
      weight_decay_rate=0.01,
      optimizer_beta1=0.9,
      optimizer_beta2=0.999,
      optimizer_epsilon=1e-6,
      # TPU settings
      use_tpu=True,
      tpu_name=None,
      tpu_zone=None,
      tpu_job_name=None,
      gcp_project=None,
      master=None,
      num_tpu_cores=8,
      # Kept as a string, matching the original config value.
      iterations_per_loop="1000",
  )

In [24]:
######################### recompute grad ########################
# Copyright 2021 The BigBird Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for rematerialization.
Incubates a version of tf.recompute_grad that is XLA compatible.
"""
import collections
import numbers
import os
import threading
from typing import Deque, List, NamedTuple, Optional, Sequence, Text, Union

from absl import logging
import numpy as np
import tensorflow.compat.v2 as tf

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import ops
from tensorflow.python.ops import custom_gradient


# Remove when https://github.com/tensorflow/tensorflow/pull/45298
# gets merged
def get_variable_by_name(var_name):
  """Retrieves tf.Variable from name in MirroredStrategy (multi-gpu).

  Returns:
    The single trainable global variable whose op name equals `var_name`,
    or None if every match is non-trainable.

  Raises:
    ValueError: if no global variable matches, or if more than one
      trainable variable matches.
  """

  def _has_matching_op_name(var):
    # Objects without `.op.name` can never match.
    try:
      return var_name == var.op.name
    except AttributeError:
      return False

  # The global collection may hold copies of a variable from different
  # replicas; keep only those whose op name matches.
  all_global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
  matching = [var for var in all_global_vars if _has_matching_op_name(var)]
  if not matching:
    raise ValueError('Unsuccessful at finding variable {}.'.format(var_name))

  trainable = [var for var in matching if var.trainable]
  if not trainable:
    # The variable is not trainable.
    return None
  if len(trainable) > 1:
    raise ValueError(
        'Unsuccessful at finding trainable variable {}. '
        'Number of candidates: {}. '
        'Candidates: {}'.format(var_name, len(trainable), trainable))
  return trainable[0]
# Monkeypatch TF's custom_gradient module to use the lookup defined above.
custom_gradient.get_variable_by_name = get_variable_by_name


class RecomputeContext(
    NamedTuple('RecomputeContext', [
        ('is_recomputing', bool),
        ('seed', tf.Tensor),
        ('children', Deque['RecomputeContext']),
    ])):
  """Context for recomputation.

  Instances are immutable named tuples; `recompute_grad` derives variants of
  an existing context with `_replace`. Used as a context manager, an instance
  is pushed onto the thread-local `_context_stack` on entry and popped on exit.

  Attributes:
    is_recomputing: Whether we are in a recomputation phase.
    seed: Scalar integer tensor that should be used with stateless random ops
      for deterministic behavior and correct computation of the gradient.
    children: Nested `RecomputeContext` instances. Used internally by
      `recompute_grad` to track nested instances of `RecomputeContext`.
  """

  def __enter__(self):
    # push() returns its argument, so `with ctx as c` binds c to ctx itself.
    return _context_stack.push(self)

  def __exit__(self, exc_type, exc_value, traceback):
    # pop() raises AssertionError if self is not the innermost context.
    # Returns None, so exceptions raised in the `with` body still propagate.
    _context_stack.pop(self)


# Simplified version of `_DefaultStack` in
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/ops.py.
class _ContextStack(threading.local):
  """A thread-local stack for providing implicit recompute contexts.

  Subclassing `threading.local` gives every thread its own independent stack.
  """

  def __init__(self):
    super(_ContextStack, self).__init__()
    self._stack = []

  def top(self) -> Optional['RecomputeContext']:
    """Returns the innermost pushed context, or None if the stack is empty."""
    return self._stack[-1] if self._stack else None

  def push(self, context: 'RecomputeContext'):
    """Pushes `context` onto this thread's stack and returns it."""
    self._stack.append(context)
    return context

  def pop(self, context: 'RecomputeContext'):
    """Removes `context`, which must be the current top of the stack.

    Raises:
      AssertionError: if the stack is empty or `context` is not its top,
        i.e. contexts were exited out of nesting order. Previously an empty
        stack raised a bare IndexError instead.
    """
    if not self._stack or self._stack[-1] is not context:
      raise AssertionError('Nesting violated for RecomputeContext.')
    self._stack.pop()


# Module-level stack of entered RecomputeContexts. It is thread-local, so each
# thread sees its own stack. Read through get_recompute_context().
_context_stack = _ContextStack()


def get_recompute_context() -> Optional[RecomputeContext]:
  """Returns the innermost entered `RecomputeContext` on this thread.

  Returns:
    The context on top of the thread-local stack, or None when no
    `RecomputeContext` is currently entered on this thread.
  """
  innermost = _context_stack.top()
  return innermost


# Adapted from
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_util.py.
def _get_containing_xla_context(graph: tf.Graph) -> Optional[object]:
  """Returns the first ancestor `XLAControlFlowContext` in the `graph`.

  Walks outward from the graph's current control-flow context through
  `outer_context`. Returns None if no enclosing context reports
  `IsXLAContext()`.
  """
  context = graph._get_control_flow_context()  # pylint: disable=protected-access
  while context and not context.IsXLAContext():
    context = context.outer_context
  return context if context else None


def _in_xla_context(graph: Optional[tf.Graph] = None) -> bool:
  """Detects whether we are in an XLA context.

  Returns True if TF_XLA_FLAGS enables auto-jit level 2. It also returns True
  if `graph` (default: the current default graph) or any graph reachable via
  `outer_graph` has an XLA control-flow context. Returns False once a graph
  without an `outer_graph` attribute is reached.
  """
  if '--tf_xla_auto_jit=2' in os.environ.get('TF_XLA_FLAGS', ''):
    return True

  current = graph if graph is not None else tf.compat.v1.get_default_graph()
  while _get_containing_xla_context(current) is None:
    try:
      current = current.outer_graph
    except AttributeError:
      return False
  return True


def _force_data_dependency(
    first_compute: Sequence[tf.Tensor],
    then_compute: Sequence[tf.Tensor]) -> List[tf.Tensor]:
  """Force all of `then_compute` to depend on all of `first_compute`.

  Uses a dummy data dependency, which is useful when running on TPUs because
  XLA ignores control dependencies. Only supports float arguments.

  The first element of every non-None tensor in `first_compute` is summed.
  That sum is multiplied by `np.finfo(dtype).tiny`, cast, and added to each
  tensor in `then_compute`.
  NOTE(review): the added term is only negligible when the summed elements are
  moderate; very large or non-finite values would leak into the results.

  Args:
    first_compute: Sequence of `Tensor`s to be executed before `then_compute`.
      None entries are skipped.
    then_compute: Sequence of `Tensor`s to be executed after `first_compute`.
      None entries are passed through as None.

  Returns:
    List of `Tensor`s with the same length as `then_compute`.

  Raises:
    ValueError: if ranks are unknown or the dtype of the summed
      `first_compute` elements is not floating.
  """

  def _first_element(x):
    # Slices out element [0, 0, ..., 0] as a scalar; needs a static rank.
    if x.shape.ndims is None:
      raise ValueError('Rank of Tensor %s must be known' % x)
    ndims = x.shape.ndims
    begin = tf.zeros(ndims, dtype=tf.int32)
    size = tf.ones(ndims, dtype=tf.int32)
    return tf.reshape(tf.slice(x, begin, size), [])

  first_compute_sum = tf.add_n(
      [_first_element(x) for x in first_compute if x is not None])
  dtype = first_compute_sum.dtype
  if not dtype.is_floating:
    raise ValueError('_force_data_dependency only supports floating dtypes.')
  # Scale by the dtype's `tiny` so the dependency term is intended to be ~0.
  zero = np.finfo(dtype.as_numpy_dtype).tiny * first_compute_sum
  return [
      x + tf.cast(zero, x.dtype) if x is not None else None
      for x in then_compute
  ]


def _make_seed_if_none(seed: Optional[tf.Tensor]) -> tf.Tensor:
  """Uses the global generator to make a seed if necessary.

  Args:
    seed: an existing seed tensor, or None.

  Returns:
    `seed` unchanged when it is not None. Otherwise a fresh int32 scalar
    drawn from the global random generator.
  """
  if seed is not None:
    return seed
  generator = tf.random.experimental.get_global_generator()
  # The two seeds for stateless random ops don't have individual semantics and
  # are scrambled together, so providing one seed is fine. This makes it easier
  # for users to provide a local seed without worrying about integer overflow.
  # See `make_seeds` in
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/stateful_random_ops.py.
  try:
    return generator.uniform_full_int([], tf.int32, name='recompute_grad_seed')
  except (RuntimeError, TypeError, ValueError, tf.errors.NotFoundError) as e:
    # For a number of reasons, the above operation can fail like using multiple
    # graphs or toggling between eager and graph modes. Reset the generator.
    # `warning` replaces the deprecated `warn` alias.
    logging.warning('Resetting the generator. %s: %s', type(e), e)
    tf.random.experimental.set_global_generator(None)
    generator = tf.random.experimental.get_global_generator()
    return generator.uniform_full_int([], tf.int32, name='recompute_grad_seed')


def recompute_grad(f, seed=None):
  """An eager-compatible version of recompute_grad.

  For f(*args, **kwargs), this supports gradients with respect to args, or to
  gradients with respect to any variables residing in the kwarg 'variables'.
  Note that for keras layer and model objects, this is handled automatically.

  Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
  be able to access the member variables of that object, because `g` returns
  through the wrapper function `inner`.  When recomputing gradients through
  objects that inherit from keras, we suggest keeping a reference to the
  underlying object around for the purpose of accessing these variables.

  Args:
    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.
    seed: Optional seed for random ops. `seed` should be an integer scalar
      `Tensor`. When compiling to XLA, `seed` must have dtype `tf.int32`. If
      `seed` is not provided one will be generated.

  Returns:
   A function `g` that wraps `f`, but which recomputes `f` on the backwards
   pass of a gradient call.
  """

  @tf.custom_gradient
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    # Detect when we're nested and in the backwards pass, so we don't generate
    # an additional seed.
    parent_context = get_recompute_context()
    if parent_context is not None and parent_context.is_recomputing:
      # Use the cached context in the recomputation phase.
      with parent_context.children.popleft()._replace(
          is_recomputing=True) as context:
        result = f(*args, **kwargs)
    else:
      with RecomputeContext(
          is_recomputing=False,
          seed=_make_seed_if_none(seed),
          children=collections.deque()) as context:
        result = f(*args, **kwargs)
        # In the forward pass, build up a tree of recomputation contexts.
        if parent_context is not None and not parent_context.is_recomputing:
          parent_context.children.append(context)

    def grad(*dresult, **grad_kwargs):
      """Gradient function calculation for inner function."""
      variables = grad_kwargs.pop('variables', None)
      if grad_kwargs:
        # Format the names into the message instead of passing them as a
        # second exception argument (which rendered the error as a tuple).
        raise ValueError('Found unexpected kwargs for `grad`: {}'.format(
            list(grad_kwargs.keys())))
      inputs, seed = list(args), context.seed
      if _in_xla_context():
        # XLA ignores control dependencies, so tie the recomputation to the
        # incoming gradients through a dummy data dependency instead.
        inputs = _force_data_dependency(
            tf.nest.flatten(dresult), inputs + [seed])
        seed = inputs.pop()
      with tf.GradientTape() as tape:
        tape.watch(inputs)
        if variables is not None:
          tape.watch(variables)
        with tf.control_dependencies(dresult):
          # Re-run f with the forward pass's seed so random ops replay.
          with context._replace(is_recomputing=True, seed=seed):
            result = f(*inputs, **kwargs)
      kw_vars = []
      if variables is not None:
        kw_vars = list(variables)
      grads = tape.gradient(
          result, list(inputs) + kw_vars, output_gradients=dresult)
      return grads[:len(inputs)], grads[len(inputs):]

    return result, grad

  return inner


######################## STATELESS DROPOUT LAYERS ##############################


def _as_shape(shape: Union[Sequence[int], tf.TensorShape]) -> tf.TensorShape:
  """Returns `shape` as a `tf.TensorShape`, wrapping it only when necessary."""
  if isinstance(shape, tf.TensorShape):
    # Already the right type: hand back the very same object.
    return shape
  return tf.TensorShape(shape)


def _get_noise_shape(
    x: tf.Tensor,
    noise_shape: Optional[Union[Sequence[int], tf.TensorShape]]
) -> Union[tf.Tensor, tf.TensorShape, Sequence[int]]:
  """Computes the shape of the binary mask for dropout.

  Args:
    x: The tensor dropout is applied to.
    noise_shape: Requested mask shape, or `None` to use the full shape of `x`.

  Returns:
    `tf.shape(x)` when `noise_shape` is `None`. Otherwise, when both `x` and
    `noise_shape` have a known and equal rank, a `tf.TensorShape` whose unknown
    (`None`) entries are filled in from the static shape of `x`. In every other
    case `noise_shape` is returned unchanged, leaving validation to the op that
    consumes it.
  """
  # If noise_shape is none return immediately.
  if noise_shape is None:
    return tf.shape(x)

  try:
    # Best effort to figure out the intended shape.
    # If not possible, let the op handle it.
    # In eager mode the exception will show up there.
    noise_shape_ = _as_shape(noise_shape)
  except (TypeError, ValueError):
    return noise_shape

  x_dims = x.shape.dims
  mask_dims = noise_shape_.dims
  # Both ranks must be known before comparing lengths: an unknown-rank
  # TensorShape has `dims is None`, and `len(None)` would raise a TypeError.
  if (x_dims is not None and mask_dims is not None and
      len(x_dims) == len(mask_dims)):
    # Prefer the explicitly requested size; fall back to x's static size
    # (which may itself be None) where the request left it unknown.
    return tf.TensorShape([
        mask_dim.value if mask_dim.value is not None else x_dim.value
        for x_dim, mask_dim in zip(x_dims, mask_dims)
    ])

  return noise_shape


def stateless_dropout(x: tf.Tensor,
                      rate: float,
                      seed: tf.Tensor,
                      noise_shape: Optional[Union[Sequence[int],
                                                  tf.TensorShape]] = None,
                      name: Optional[Text] = None) -> tf.Tensor:
  """Computes dropout: randomly sets elements to zero to prevent overfitting.

  See https://www.tensorflow.org/api_docs/python/tf/nn/dropout.
  This version differs in that the seed is required if the rate is nonzero:
  the keep/drop mask is drawn with `tf.random.stateless_uniform`, so calling
  it again with the same `seed` and mask shape reproduces the same mask.

  Args:
    x: A floating point tensor.
    rate: A scalar `Tensor` with the same type as x. The probability that each
      element is dropped. For example, setting rate=0.1 would drop 10% of input
      elements.
    seed: A shape [2] integer Tensor of seeds to the random number generator.
      Must have dtype `tf.int32` when compiling to XLA.
    noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for
      randomly generated keep/drop flags.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` of the same shape of `x`, with kept elements scaled by
    `1 / (1 - rate)`. If `rate` is statically known to be 0, the
    (tensor-converted) input is returned without any masking.

  Raises:
    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating point
      tensor. `rate=1` is disallowed, because the output would be all zeros,
      which is likely not what was intended.
  """
  with tf.name_scope(name or 'stateless_dropout') as name:
    x = tf.convert_to_tensor(x, name='x')
    if not x.dtype.is_floating:
      raise ValueError('x has to be a floating point tensor since it\'s going '
                       'to be scaled. Got a %s tensor instead.' % x.dtype)
    # Python-number rates can be range-checked up front; tensor rates are only
    # checked for scalar rank below.
    if isinstance(rate, numbers.Real):
      if not (rate >= 0 and rate < 1):
        raise ValueError('rate must be a scalar tensor or a float in the '
                         'range [0, 1), got %g' % rate)
      if rate > 0.5:
        logging.log_first_n(
            logging.WARN, 'Large dropout rate: %g (>0.5). In TensorFlow '
            '2.x, dropout() uses dropout rate instead of keep_prob. '
            'Please ensure that this is intended.', 5, rate)

    # Early return if nothing needs to be dropped.
    if tf.get_static_value(rate) == 0:
      return x

    rate = tf.convert_to_tensor(rate, dtype=x.dtype, name='rate')
    rate.shape.assert_has_rank(0)
    noise_shape = _get_noise_shape(x, noise_shape)
    # Sample a uniform distribution on [0.0, 1.0) and select values larger than
    # rate.
    #
    # NOTE: Random uniform actually can only generate 2^23 floats on [1.0, 2.0)
    # and subtract 1.0.
    random_tensor = tf.random.stateless_uniform(
        noise_shape, seed=seed, dtype=x.dtype)
    keep_prob = 1 - rate
    # Inverted dropout: rescale survivors so the expected value is unchanged.
    scale = 1 / keep_prob
    # NOTE: if (1.0 + rate) - 1 is equal to rate, then we want to consider that
    # float to be selected, hence we use a >= comparison.
    keep_mask = random_tensor >= rate
    ret = x * scale * tf.cast(keep_mask, x.dtype)
    if not tf.executing_eagerly():
      # Graph mode may lose static shape information through the mask; restore
      # it from the input.
      ret.set_shape(x.get_shape())
    return ret


# Reimplements internal function
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/smart_cond.py.
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Calls `true_fn` or `false_fn`, statically when possible.

  When the truth value of `pred` can be determined without running the graph
  (a Python bool or a tensor with a constant value), the matching branch
  function is called directly. Otherwise, and always when `pred` is a
  `tf.Variable` (whose value may change), the choice is deferred to `tf.cond`.

  Arguments:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  if not callable(true_fn):
    raise TypeError('`true_fn` must be callable.')
  if not callable(false_fn):
    raise TypeError('`false_fn` must be callable.')

  static_pred = tf.get_static_value(pred)
  must_defer = static_pred is None or isinstance(pred, tf.Variable)
  if must_defer:
    return tf.cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)

  chosen_branch = true_fn if static_pred else false_fn
  return chosen_branch()


# See https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dropout.
class RecomputingDropout(tf.keras.layers.Layer):
  """Dropout layer compatible with `recompute_grad`.

  Mirrors `tf.keras.layers.Dropout`. When a recompute context is active, the
  mask is produced by `stateless_dropout` from a seed that combines the
  context's seed with a per-layer seed, so the mask is a deterministic
  function of those two seeds. Outside a recompute context it falls back to
  the ordinary `tf.nn.dropout`, unless `force_recomputation` forbids that.
  """

  def __init__(self,
               rate,
               noise_shape=None,
               seed=None,
               force_recomputation=False,
               **kwargs):
    """Creates the layer.

    Args:
      rate: Float between 0 and 1. Fraction of the input units to drop.
      noise_shape: 1D integer tensor representing the shape of the binary
        dropout mask that will be multiplied with the input. For instance, if
        inputs have shape `(batch_size, timesteps, features)` and you want the
        dropout mask to be the same for all timesteps, you can use
        `noise_shape=(batch_size, 1, features)`. `None` entries are filled in
        from the runtime shape of the inputs.
      seed: A Python integer to use as random seed. Also reused as the
        per-layer recompute seed when given.
      force_recomputation: If `True`, then raises an error if called outside a
        recompute context.
      **kwargs: Keyword arguments for `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self.rate = rate
    self.noise_shape = noise_shape
    self.seed = seed
    self.force_recomputation = force_recomputation
    self.supports_masking = True
    # Layer-specific seed, stacked with the recompute context's seed at call
    # time; drawn at random (int32 range) when no explicit seed was supplied.
    if seed is None:
      self._recompute_seed = np.random.randint(-2**31, 2**31, dtype=np.int32)
    else:
      self._recompute_seed = seed

  def _get_noise_shape(self, inputs):
    # Same hook name as Keras `Dropout`: resolves `None` entries of
    # `self.noise_shape` against the dynamic shape of `inputs`, so masks can
    # broadcast over axes whose size is only known at run time.
    if self.noise_shape is None:
      return None

    runtime_shape = tf.shape(inputs)
    resolved = [runtime_shape[axis] if dim is None else dim
                for axis, dim in enumerate(self.noise_shape)]
    return tf.convert_to_tensor(resolved)

  def call(self, inputs, training=None):
    """Applies dropout when training, identity otherwise.

    Args:
      inputs: Input tensor (of any rank).
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).
        Defaults to the Keras learning phase when `None`.

    Returns:
      `inputs` masked according to layer configuration.

    Raises:
      ValueError: If `force_recomputation` is `True` and called outside a
        a recompute context.
    """
    if self.rate == 0:
      return inputs

    if training is None:
      training = tf.keras.backend.learning_phase()

    def dropped_inputs():
      """Returns `inputs` with a random subset of elements dropped."""
      context = get_recompute_context()
      if context is not None:
        # Deterministic mask keyed on (recompute seed, layer seed).
        stateless_seed = tf.stack([context.seed, self._recompute_seed])
        return stateless_dropout(
            inputs,
            rate=self.rate,
            seed=stateless_seed,
            noise_shape=self._get_noise_shape(inputs))

      if self.force_recomputation:
        raise ValueError(
            'RecomputeContext is required when force_recomputation=True.')
      return tf.nn.dropout(
          inputs,
          noise_shape=self._get_noise_shape(inputs),
          seed=self.seed,
          rate=self.rate)

    return smart_cond(training, dropped_inputs, lambda: tf.identity(inputs))

  def compute_output_shape(self, input_shape):
    # Dropout never changes the shape.
    return input_shape

  def get_config(self):
    base_config = super().get_config()
    return {
        **base_config,
        'rate': self.rate,
        'noise_shape': self.noise_shape,
        'seed': self.seed,
        'force_recomputation': self.force_recomputation,
    }

# Main 

In [None]:
# set-up
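# https://chancoding.tistory.com/86 — how to connect a conda environment to Jupyter Notebook
# Activate the target environment (from the Anaconda prompt), then run:
# pip install jupyter notebook
# pip install ipykernel
# python -m ipykernel install --user --name bigbird --display-name "bigbird"
# Fetch and install the bigbird package from the web (GitHub):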
# !pip install git+https://github.com/google-research/bigbird.git -q

In [2]:
!pip install git+https://github.com/google-research/bigbird.git -q
# 해당환경에 bigbird설치함

[K     |████████████████████████████████| 1.2 MB 7.7 MB/s 
[K     |████████████████████████████████| 4.4 MB 37.0 MB/s 
[K     |████████████████████████████████| 1.4 MB 60.2 MB/s 
[K     |████████████████████████████████| 4.0 MB 53.5 MB/s 
[K     |████████████████████████████████| 981 kB 57.8 MB/s 
[K     |████████████████████████████████| 191 kB 57.6 MB/s 
[K     |████████████████████████████████| 5.8 MB 32.2 MB/s 
[K     |████████████████████████████████| 352 kB 51.4 MB/s 
[K     |████████████████████████████████| 1.1 MB 13.0 MB/s 
[K     |████████████████████████████████| 367 kB 41.5 MB/s 
[K     |████████████████████████████████| 366 kB 63.1 MB/s 
[K     |████████████████████████████████| 79 kB 8.1 MB/s 
[K     |████████████████████████████████| 48 kB 3.5 MB/s 
[K     |████████████████████████████████| 251 kB 63.2 MB/s 
[K     |████████████████████████████████| 191 kB 51.5 MB/s 
[K     |████████████████████████████████| 178 kB 60.2 MB/s 
[?25h  Building wheel for bi

In [4]:
from bigbird.core import flags
from bigbird.core import modeling
from bigbird.core import utils
from bigbird.classifier import run_classifier
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tqdm import tqdm
import sys

FLAGS = flags.FLAGS
# absl parses sys.argv below. A notebook kernel's argv typically carries a
# `-f <connection-file>` option, so register a dummy `f` flag (only once, to
# stay safe on re-runs) to keep parsing from failing.
if not hasattr(FLAGS, "f"):
  flags.DEFINE_string("f", "", "")
FLAGS(sys.argv)

tf.enable_v2_behavior()

In [5]:
# setup

In [6]:
# Data: IMDB reviews read through TFDS.
FLAGS.data_dir = "tfds://imdb_reviews/plain_text"
# Attention variant; "block_sparse" presumably selects BigBird's sparse
# attention (TODO confirm accepted values in bigbird.core.flags).
FLAGS.attention_type = "block_sparse"
FLAGS.max_encoder_length = 4096  # presumably max tokens per example; reduce for a quicker demo on free Colab
# Optimisation settings.
FLAGS.learning_rate = 1e-5
FLAGS.num_train_steps = 2000
# Both dropout probabilities are set to zero for this demo.
FLAGS.attention_probs_dropout_prob = 0.0
FLAGS.hidden_dropout_prob = 0.0
# Presumably recomputes activations in the backward pass to save memory
# (cf. the recompute_grad machinery) — TODO confirm.
FLAGS.use_gradient_checkpointing = True
# Passed to input_fn_builder below; presumably selects the tokenizer vocabulary.
FLAGS.vocab_model_file = "gpt2"

In [7]:
# Flag values collected into a dict that serves as the model config below
# (indexed by keys such as "hidden_size", "num_labels" and "scope").
bert_config = flags.as_dictionary()

In [8]:
# define classification model

In [9]:
# BERT-style encoder (BigBird attention) built from the config dict.
model = modeling.BertModel(bert_config)
# Classification head + loss; called in fwd_bwd with the pooled encoder output
# and the labels, returning (loss, log_probs). It receives, positionally, the
# hidden size, the number of labels, the hidden dropout probability and an
# initializer built from initializer_range; it is named under the model scope.
headl = run_classifier.ClassifierLossLayer(
        bert_config["hidden_size"], bert_config["num_labels"],
        bert_config["hidden_dropout_prob"],
        utils.create_initializer(bert_config["initializer_range"]),
        name=bert_config["scope"]+"/classifier")

In [10]:
# NOTE(review): `experimental_compile` was renamed `jit_compile` in newer TF
# releases; kept as-is for compatibility with older TF.
@tf.function(experimental_compile=True)
def fwd_bwd(features, labels):
  """Runs one XLA-compiled forward/backward pass of encoder + classifier head.

  Returns:
    A `(loss, log_probs, grads)` tuple, where `grads` is ordered like
    `model.trainable_weights + headl.trainable_weights`.
  """
  with tf.GradientTape() as tape:
    _, pooled = model(features, training=True)
    loss, log_probs = headl(pooled, labels, True)
  # Gather the variables only after the forward pass, in case layers create
  # their weights lazily on first call.
  variables = model.trainable_weights + headl.trainable_weights
  grads = tape.gradient(loss, variables)
  return loss, log_probs, grads

In [11]:
# Dataset pipeline

In [13]:
# Build the training input function from the flags configured above, then
# materialise the dataset with a batch size of 8 passed via the params dict.
train_input_fn = run_classifier.input_fn_builder(
    data_dir=FLAGS.data_dir,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    substitute_newline=FLAGS.substitute_newline,
    is_training=True,
)
dataset = train_input_fn(dict(batch_size=8))

[1mDownloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Shuffling imdb_reviews-train.tfrecord...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Shuffling imdb_reviews-test.tfrecord...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]

Shuffling imdb_reviews-unsupervised.tfrecord...:   0%|          | 0/50000 [00:00<?, ? examples/s]

[1mDataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.[0m


In [None]:
# Inspect a few batches from the training dataset.
for ex in dataset.take(3):
  print(ex)

In [None]:
# (Optionally) Check outputs