In [1]:
import tensorflow as tf
tf.__version__

'2.0.0-rc1'

In [5]:
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
  """Samples a permutation of the factorization order.
     Creates perm_mask and target_mask accordingly.
  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    targets: int64 Tensor in shape [seq_len], target ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected for
      partial prediction.
    perm_size: the length of longest permutation. Could be set to be reuse_len.
      Should not be larger than reuse_len or there will be data leaks.
    seq_len: int, sequence length.
  Returns:
    perm_mask: float32 Tensor in shape [seq_len, seq_len] consisted of 0 and 1.
    If perm_mask[i][j] == 1, it means the ith token (in original order) cannot
    attend to the jth token
    (in original order). This case will happen only when the ith token's
    permutated position <= the jth token's permutated position,
    and the jth token is masked or is func token. If perm_mask[i][j] == 0, it
    means the ith token (in original order) can attend to the jth token
    (in original order). Note that non-masked tokens can be attended by all
    other tokens, which is different from the description in original paper.
    new_targets: int64 Tensor in shape [seq_len], target token ids to be
    predicted in XLNet.
    In XLNet, target doesn't need to be shifted one position.
    target_mask: float32 Tensor in shape [seq_len] consisted of 0 and 1. If
    target_mask[i] == 1,
    the ith token needs to be predicted and mask will be used as input. This
    token will count for loss.
    If target_mask[i] == 0, token (or [SEP], [CLS]) will be used as input. This
    token will not count for loss.
    inputs_k: int64 Tensor in shape [seq_len], input ids.
    inputs_q: float32 Tensor in shape [seq_len], the same as target_mask.
  """

  # Generate permutation indices
  index = tf.range(seq_len, dtype=tf.int64)
  index = tf.transpose(tf.reshape(index, [-1, perm_size]))
  index = tf.random.shuffle(index)
  index = tf.reshape(tf.transpose(index), [-1])

  # `perm_mask` and `target_mask`
  # non-functional tokens
  non_func_tokens = tf.logical_not(
      tf.logical_or(tf.equal(inputs, 0), tf.equal(inputs, 1)))

  non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
  masked_or_func_tokens = tf.logical_not(non_mask_tokens)

  # Set the permutation indices of non-masked (& non-funcional) tokens to the
  # smallest index (-1):
  # (1) they can be seen by all other positions
  # (2) they cannot see masked positions, so there won"t be information leak
  smallest_index = -tf.ones([seq_len], dtype=tf.int64)
  rev_index = tf.where(non_mask_tokens, smallest_index, index)

  # Create `target_mask`: non-funcional and masked tokens
  # 1: use mask as input and have loss
  # 0: use token (or [SEP], [CLS]) as input and do not have loss
  target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
  target_mask = tf.cast(target_tokens, tf.float32)

  # Create `perm_mask`
  # `target_tokens` cannot see themselves
  self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

  # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
  # 0: can attend if i > j or j is non-masked
  perm_mask = tf.logical_and(self_rev_index[:, None] <= rev_index[None, :],
                             masked_or_func_tokens)
  perm_mask = tf.cast(perm_mask, tf.float32)

  # new target: [next token] for LM and [curr token] (self) for PLM
  new_targets = tf.concat([inputs[0:1], targets[:-1]], axis=0)

  # construct inputs_k
  inputs_k = inputs

  # construct inputs_q
  inputs_q = target_mask

  return perm_mask, new_targets, target_mask, inputs_k, inputs_q

In [8]:
inputs = [2,3,4,5,6,7]
targets = [2,3,4,5,6,7]
is_masked = [False, True, False, False, True, False]
perm_size = 2
seq_len = 6
_local_perm(inputs, targets, is_masked, perm_size, seq_len)

(<tf.Tensor: id=142, shape=(6, 6), dtype=float32, numpy=
 array([[0., 1., 0., 1., 1., 0.],
        [0., 1., 0., 1., 1., 0.],
        [0., 1., 0., 1., 1., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 1., 0.]], dtype=float32)>,
 <tf.Tensor: id=146, shape=(6,), dtype=int32, numpy=array([2, 2, 3, 4, 5, 6], dtype=int32)>,
 <tf.Tensor: id=128, shape=(6,), dtype=float32, numpy=array([0., 1., 0., 0., 0., 0.], dtype=float32)>,
 [2, 3, 4, 5, 6, 7],
 <tf.Tensor: id=128, shape=(6,), dtype=float32, numpy=array([0., 1., 0., 0., 0., 0.], dtype=float32)>)

In [1]:
from models_and_trainers.BERT_style_modules import BERTStyleDecoder

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [2]:
decoder = BERTStyleDecoder()

In [7]:
unset_keys = decoder.load_pretrained()
for k in list(unset_keys)[0]:
    print(k)

decoder.layers.0.multihead_attn.in_proj_weight
decoder.layers.0.multihead_attn.in_proj_bias
decoder.layers.0.multihead_attn.out_proj.weight
decoder.layers.0.multihead_attn.out_proj.bias
decoder.layers.0.norm3.weight
decoder.layers.0.norm3.bias
decoder.layers.1.multihead_attn.in_proj_weight
decoder.layers.1.multihead_attn.in_proj_bias
decoder.layers.1.multihead_attn.out_proj.weight
decoder.layers.1.multihead_attn.out_proj.bias
decoder.layers.1.norm3.weight
decoder.layers.1.norm3.bias
decoder.layers.2.multihead_attn.in_proj_weight
decoder.layers.2.multihead_attn.in_proj_bias
decoder.layers.2.multihead_attn.out_proj.weight
decoder.layers.2.multihead_attn.out_proj.bias
decoder.layers.2.norm3.weight
decoder.layers.2.norm3.bias
decoder.layers.3.multihead_attn.in_proj_weight
decoder.layers.3.multihead_attn.in_proj_bias
decoder.layers.3.multihead_attn.out_proj.weight
decoder.layers.3.multihead_attn.out_proj.bias
decoder.layers.3.norm3.weight
decoder.layers.3.norm3.bias
decoder.layers.4.multihe

In [3]:
keys = decoder.state_dict().keys()
for k in keys:
    print(k)

embedder.word_embeddings.weight
embedder.position_embeddings.weight
embedder.token_type_embeddings.weight
embedder.embedding_layer_norm.weight
embedder.embedding_layer_norm.bias
decoder.layers.0.self_attn.in_proj_weight
decoder.layers.0.self_attn.in_proj_bias
decoder.layers.0.self_attn.out_proj.weight
decoder.layers.0.self_attn.out_proj.bias
decoder.layers.0.multihead_attn.in_proj_weight
decoder.layers.0.multihead_attn.in_proj_bias
decoder.layers.0.multihead_attn.out_proj.weight
decoder.layers.0.multihead_attn.out_proj.bias
decoder.layers.0.linear1.weight
decoder.layers.0.linear1.bias
decoder.layers.0.linear2.weight
decoder.layers.0.linear2.bias
decoder.layers.0.norm1.weight
decoder.layers.0.norm1.bias
decoder.layers.0.norm2.weight
decoder.layers.0.norm2.bias
decoder.layers.0.norm3.weight
decoder.layers.0.norm3.bias
decoder.layers.1.self_attn.in_proj_weight
decoder.layers.1.self_attn.in_proj_bias
decoder.layers.1.self_attn.out_proj.weight
decoder.layers.1.self_attn.out_proj.bias
decode