In [1]:
def parse_depth_string(depth_str):
  depth_config = depth_str.split("x")
  if len(depth_config) == 1:
    depth_config.append(1)
  assert len(depth_config) == 2, "Require two-element depth config."

  return list(map(int, depth_config))

In [2]:
block_size = "6-3x2-3x2"
# block_size = "8-8-8"
block_size = "4-4-4"

In [3]:
block_size = block_size.split("-")
n_block = len(block_size)
block_rep = []
block_param = []
print(block_size, n_block)
for i, _ in enumerate(block_size):
  block_size_i = parse_depth_string(block_size[i])
  print(block_size_i)
  block_param.append(block_size_i[0])
  block_rep.append(block_size_i[1])

['4', '4', '4'] 3
[4, 1]
[4, 1]
[4, 1]


In [31]:
block_param

[4, 4, 4]

In [32]:
block_rep

[1, 1, 1]

In [33]:
[
    ["FunnelLayer" for _ in range(block_size)]
    for block_index, block_size in enumerate(block_param)
]

[['FunnelLayer', 'FunnelLayer', 'FunnelLayer', 'FunnelLayer'],
 ['FunnelLayer', 'FunnelLayer', 'FunnelLayer', 'FunnelLayer'],
 ['FunnelLayer', 'FunnelLayer', 'FunnelLayer', 'FunnelLayer']]

In [1]:
import torch.nn as nn
import torch
from torch.nn import functional as F

In [2]:
from transformers import FunnelConfig



In [5]:
config = FunnelConfig()

In [6]:
config

FunnelConfig {
  "activation_dropout": 0.0,
  "attention_dropout": 0.1,
  "attention_type": "relative_shift",
  "block_repeats": [
    1,
    1,
    1
  ],
  "block_sizes": [
    4,
    4,
    4
  ],
  "d_head": 64,
  "d_inner": 3072,
  "d_model": 768,
  "hidden_act": "gelu_new",
  "hidden_dropout": 0.1,
  "initializer_range": 0.1,
  "initializer_std": null,
  "layer_norm_eps": 1e-09,
  "max_position_embeddings": 512,
  "model_type": "funnel",
  "n_head": 12,
  "num_decoder_layers": 2,
  "pool_q_only": true,
  "pooling_type": "mean",
  "separate_cls": true,
  "truncate_seq": true,
  "type_vocab_size": 3,
  "vocab_size": 30522
}

### Inputs

In [10]:
class FunnelEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        embeddings = self.layer_norm(inputs_embeds)
        embeddings = self.dropout(embeddings)
        return embeddings

In [11]:
fe = FunnelEmbeddings(fc)

In [12]:
input_ids = torch.LongTensor([[101, 2057, 1012, 102], [101, 1, 2, 102]])
input_ids.shape

torch.Size([2, 4])

In [13]:
input_embeds = fe(input_ids)
input_embeds.size()

torch.Size([2, 4, 768])

In [14]:
attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
attention_mask = attention_mask.type_as(input_embeds)
attention_mask.shape

torch.Size([2, 4])

In [15]:
token_type_ids = torch.zeros(
    input_ids.shape, dtype=torch.long, device=input_ids.device)
token_type_ids.shape

torch.Size([2, 4])

### Type, Mask

In [16]:
def token_type_ids_to_mat(token_type_ids):
    """Convert `token_type_ids` to `token_type_mat`."""
    token_type_mat = token_type_ids[:, :, None] == token_type_ids[:, None]
    # Treat <cls> as in the same segment as both A & B
    cls_ids = token_type_ids == 2
    cls_mat = cls_ids[:, :, None] | cls_ids[:, None]
    return cls_mat | token_type_mat

In [17]:
seq_len = input_embeds.size(1)
seq_len

4

In [18]:
token_type_mat = token_type_ids_to_mat(token_type_ids)
token_type_mat.shape

torch.Size([2, 4, 4])

In [19]:
input_embeds.shape

torch.Size([2, 4, 768])

In [20]:
token_type_ids

tensor([[0, 0, 0, 0],
        [0, 0, 0, 0]])

In [21]:
cls_mask = (
    F.pad(input_embeds.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0))
)
cls_mask.shape

torch.Size([4, 4])

In [22]:
cls_mask

tensor([[0., 0., 0., 0.],
        [0., 1., 1., 1.],
        [0., 1., 1., 1.],
        [0., 1., 1., 1.]])

In [23]:
x = torch.randint(100, [2, 511])
print(x.shape)
xx = x[:, None, :, None]
xxx = F.avg_pool2d(xx, (2,1), stride=(2,1), ceil_mode=True)
xxx[:, 0, :, 0].shape

torch.Size([2, 511])


torch.Size([2, 256])

In [24]:
x = torch.randint(100, [2, 511, 768])
print(x.shape)
xx = x[:, None, :, :]
xxx = F.avg_pool2d(xx, (2,1), stride=(2,1), ceil_mode=True)
xxx[:, 0].shape

torch.Size([2, 511, 768])


torch.Size([2, 256, 768])

In [25]:
input_embeds.ndim

3

In [26]:
x = input_embeds
xx = x[:, None, :, :]
xxx = F.avg_pool2d(xx, (2,1), stride=(2,1), ceil_mode=True)
xxx[:, 0].shape

torch.Size([2, 2, 768])

### Position

In [446]:
def relative_pos(pos, stride, pooled_pos=None, shift=1):
    """
    Build the relative positional vector between `pos` and `pooled_pos`.
    """
    if pooled_pos is None:
        pooled_pos = pos

    ref_point = pooled_pos[0] - pos[0]
    num_remove = shift * len(pooled_pos)
    max_dist = ref_point + num_remove * stride
    min_dist = pooled_pos[0] - pos[-1]

    return torch.arange(max_dist, min_dist - 1, -stride, dtype=torch.long, device=pos.device)

In [447]:
def stride_pool_pos(pos_id, block_index):
    """
    Pool `pos_id` while keeping the cls token separate (if `config.separate_cls=True`).
    """
    if config.separate_cls:
        # Under separate <cls>, we treat the <cls> as the first token in
        # the previous block of the 1st real block. Since the 1st real
        # block always has position 1, the position of the previous block
        # will be at `1 - 2 ** block_index`.
        cls_pos = pos_id.new_tensor([-(2 ** block_index) + 1])
        pooled_pos_id = pos_id[1:-1] if config.truncate_seq else pos_id[1:]
        return torch.cat([cls_pos, pooled_pos_id[::2]], 0)
    else:
        return pos_id[::2]

In [448]:
def get_position_embeds(seq_len, dtype, device):
    """
    Create and cache inputs related to relative position encoding. Those are very different depending on whether we
    are using the factorized or the relative shift attention:

    For the factorized attention, it returns the matrices (phi, pi, psi, omega) used in the paper, appendix A.2.2,
    final formula.

    For the relative shif attention, it returns all possible vectors R used in the paper, appendix A.2.1, final
    formula.

    Paper link: https://arxiv.org/abs/2006.03236
    """
    d_model = config.d_model
    if config.attention_type == "factorized":
        # Notations from the paper, appending A.2.2, final formula.
        # We need to create and return the matrics phi, psi, pi and omega.
        pos_seq = torch.arange(0, seq_len, 1.0, dtype=dtype, device=device)
        freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=dtype, device=device)
        inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
        sinusoid = pos_seq[:, None] * inv_freq[None]
        sin_embed = torch.sin(sinusoid)
        sin_embed_d = nn.Dropout(config.hidden_dropout)(sin_embed)
        cos_embed = torch.cos(sinusoid)
        cos_embed_d = nn.Dropout(config.hidden_dropout)(cos_embed)
        # This is different from the formula on the paper...
        phi = torch.cat([sin_embed_d, sin_embed_d], dim=-1)
        psi = torch.cat([cos_embed, sin_embed], dim=-1)
        pi = torch.cat([cos_embed_d, cos_embed_d], dim=-1)
        omega = torch.cat([-sin_embed, cos_embed], dim=-1)
        return (phi, pi, psi, omega)
    else:
        # Notations from the paper, appending A.2.1, final formula.
        # We need to create and return all the possible vectors R for all blocks and shifts.
        freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=dtype, device=device)
        inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
        # Maximum relative positions for the first input
        rel_pos_id = torch.arange(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype, device=device)
        zero_offset = seq_len * 2
        sinusoid = rel_pos_id[:, None] * inv_freq[None]
        sin_embed = nn.Dropout(config.hidden_dropout)(torch.sin(sinusoid))
        cos_embed = nn.Dropout(config.hidden_dropout)(torch.cos(sinusoid))
        pos_embed = torch.cat([sin_embed, cos_embed], dim=-1)

        pos = torch.arange(0, seq_len, dtype=dtype, device=device)
        xpos = pos
        pooled_pos = pos
        position_embeds_list = []
        for block_index in range(0, config.num_blocks):
            # For each block with block_index > 0, we need two types position embeddings:
            #   - Attention(pooled-q, unpooled-kv)
            #   - Attention(pooled-q, pooled-kv)
            # For block_index = 0 we only need the second one and leave the first one as None.

            # First type
            if block_index == 0:
                position_embeds_pooling = None
            else:
                pooled_pos = stride_pool_pos(pos, block_index)

                # construct rel_pos_id
                stride = 2 ** (block_index - 1)
                rel_pos = relative_pos(pos, stride, pooled_pos, shift=2)
                print("FirstType rel pos", rel_pos)
                rel_pos = rel_pos[:, None] + zero_offset
                rel_pos = rel_pos.expand(rel_pos.size(0), d_model)
                position_embeds_pooling = torch.gather(pos_embed, 0, rel_pos)

            # Second type
            pos = pooled_pos
            stride = 2 ** block_index
            rel_pos = relative_pos(pos, stride)
            print("SecondType rel pos", rel_pos)

            rel_pos = rel_pos[:, None] + zero_offset
            rel_pos = rel_pos.expand(rel_pos.size(0), d_model)
            position_embeds_no_pooling = torch.gather(pos_embed, 0, rel_pos)

            position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling])

            print()
        return position_embeds_list

In [283]:
config.attention_type = "factorized"

In [284]:
position_embeds = get_position_embeds(
    seq_len, input_embeds.dtype, input_embeds.device)

In [285]:
for i in range(len(position_embeds)):
    print(position_embeds[i].shape)

torch.Size([4, 768])
torch.Size([4, 768])
torch.Size([4, 768])
torch.Size([4, 768])


In [449]:
config.attention_type = ""
position_embeds = get_position_embeds(
    seq_len, input_embeds.dtype, input_embeds.device)

SecondType rel pos tensor([ 4,  3,  2,  1,  0, -1, -2, -3])

FirstType rel pos tensor([ 3,  2,  1,  0, -1, -2, -3, -4])
SecondType rel pos tensor([ 4,  2,  0, -2])

FirstType rel pos tensor([ 2,  0, -2, -4])
SecondType rel pos tensor([4, 0])



In [450]:
for item in position_embeds:
    for sub in item:
        if sub != None:
            print(sub.shape)
        else:
            print("None")
    print()

torch.Size([8, 768])
None

torch.Size([4, 768])
torch.Size([8, 768])

torch.Size([2, 768])
torch.Size([4, 768])



In [240]:
d_model = 768
freq_seq = torch.arange(0, d_model // 2, 1.0)

In [241]:
inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))

In [103]:
rel_pos_id = torch.arange(-seq_len * 2, seq_len * 2, 1.0)

In [106]:
rel_pos_id

tensor([-8., -7., -6., -5., -4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.,  5.,
         6.,  7.])

In [107]:
sinusoid = rel_pos_id[:, None] * inv_freq[None]

In [109]:
sinusoid.shape

torch.Size([16, 384])

In [110]:
sin_embed = nn.Dropout(config.hidden_dropout)(torch.sin(sinusoid))
cos_embed = nn.Dropout(config.hidden_dropout)(torch.cos(sinusoid))
pos_embed = torch.cat([sin_embed, cos_embed], dim=-1)

In [113]:
pos_embed.shape

torch.Size([16, 768])

In [115]:
pos = torch.arange(0, seq_len)
pooled_pos = pos
position_embeds_list = []

In [123]:
for i in range((seq_len)):
    pooled_pos = stride_pool_pos(pos, i)
    print(pooled_pos)

tensor([0, 1])
tensor([-1,  1])
tensor([-3,  1])
tensor([-7,  1])


In [125]:
config.separate_cls

True

In [128]:
pos

tensor([0, 1, 2, 3])

In [127]:
pos[::2]

tensor([0, 2])

In [140]:
pos[::2]

tensor([0, 2])

In [150]:
cls_pos = pos.new_tensor([-(2 ** 3) + 1])
pooled_pos_id = pos[1:-1]
torch.cat([cls_pos, pooled_pos_id[::2]], 0)

tensor([-7,  1])

In [151]:
pos

tensor([0, 1, 2, 3])

In [162]:
pos[::2]

tensor([0, 2])

In [163]:
cls_pos

tensor([-7])

In [210]:
pos

tensor([0, 1, 2, 3])

In [226]:
pos_embed.shape

torch.Size([16, 768])

In [225]:
for block_index in range(1, 3):
    pooled_pos = stride_pool_pos(pos, block_index)
    stride = 2 ** (block_index - 1)
    rel_pos = relative_pos(pos, stride, pooled_pos, shift=2)
    print("pooled_pos: ", pooled_pos)
    print("stride: ", stride)
    print("rel_pos: ", rel_pos)
    
    zero_offset = seq_len * 2

    rel_pos = rel_pos[:, None] + zero_offset
    print(rel_pos.shape)
    rel_pos = rel_pos.expand(rel_pos.size(0), d_model)
    print(rel_pos.shape)
    position_embeds_pooling = torch.gather(pos_embed, 0, rel_pos)
    print(position_embeds_pooling.shape)
    
    print()

ref_point:  tensor(-1)
num_remove:  4
tensor(3) tensor(-4)
pooled_pos:  tensor([-1,  1])
stride:  1
rel_pos:  tensor([ 3,  2,  1,  0, -1, -2, -3, -4])
torch.Size([8, 1])
torch.Size([8, 768])
torch.Size([8, 768])

ref_point:  tensor(-3)
num_remove:  4
tensor(5) tensor(-6)
pooled_pos:  tensor([-3,  1])
stride:  2
rel_pos:  tensor([ 5,  3,  1, -1, -3, -5])
torch.Size([6, 1])
torch.Size([6, 768])
torch.Size([6, 768])



In [191]:
def relative_pos(pos, stride, pooled_pos=None, shift=1):
        if pooled_pos is None:
            pooled_pos = pos

        ref_point = pooled_pos[0] - pos[0]
        print("ref_point: ", ref_point)
        num_remove = shift * len(pooled_pos)
        print("num_remove: ", num_remove)
        max_dist = ref_point + num_remove * stride
        min_dist = pooled_pos[0] - pos[-1]
        print(max_dist, min_dist)
        return torch.arange(max_dist, min_dist - 1, -stride, dtype=torch.long, device=pos.device)

In [193]:
pos

tensor([0, 1, 2, 3])

In [199]:
relative_pos(pos, 1, [-3, 1], shift=3)

ref_point:  tensor(-3)
num_remove:  6
tensor(3) tensor(-6)


tensor([ 3,  2,  1,  0, -1, -2, -3, -4, -5, -6])

In [288]:
config.pooling_type

'mean'

In [289]:
config.pool_q_only

True

In [290]:
token_type_mat

tensor([[[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]],

        [[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]]])

In [291]:
token_type_mat.shape

torch.Size([2, 4, 4])

### Pool

In [295]:
 def stride_pool( tensor, axis):
        """
        Perform pooling by stride slicing the tensor along the given axis.
        """
        if tensor is None:
            return None

        # Do the stride pool recursively if axis is a list or a tuple of ints.
        if isinstance(axis, (list, tuple)):
            for ax in axis:
                tensor = self.stride_pool(tensor, ax)
            return tensor

        # Do the stride pool recursively if tensor is a list or tuple of tensors.
        if isinstance(tensor, (tuple, list)):
            return type(tensor)(self.stride_pool(x, axis) for x in tensor)

        # Deal with negative axis
        axis %= tensor.ndim

        axis_slice = (
            slice(None, -1, 2) if config.separate_cls and config.truncate_seq else slice(None, None, 2)
        )
        enc_slice = [slice(None)] * axis + [axis_slice]
        if config.separate_cls:
            cls_slice = [slice(None)] * axis + [slice(None, 1)]
            tensor = torch.cat([tensor[cls_slice], tensor], axis=axis)
        return tensor[enc_slice]

In [333]:
stride_pool(token_type_mat, 1)

tensor([[[True, True, True, True],
         [True, True, True, True]],

        [[True, True, True, True],
         [True, True, True, True]]])

In [297]:
_.shape

torch.Size([2, 2, 4])

In [342]:
stride_pool(cls_mask, 0)

tensor([[0., 0., 0., 0.],
        [0., 1., 1., 1.]])

In [343]:
_.shape

torch.Size([2, 4])

In [318]:
stride_pool(torch.LongTensor([[101, 2057, 1012, 102, 101], [101, 1, 2, 102, 101]]), 1).shape

torch.Size([2, 3])

In [316]:
torch.LongTensor([[101, 2057, 1012, 102], [101, 1, 2, 102]]).shape

torch.Size([2, 4])

In [325]:
def pool_tensor(tensor, mode="mean", stride=2):
        """Apply 1D pooling to a tensor of size [B x T (x H)]."""
        if tensor is None:
            return None

        # Do the pool recursively if tensor is a list or tuple of tensors.
        if isinstance(tensor, (tuple, list)):
            return type(tensor)(pool_tensor(tensor, mode=mode, stride=stride) for x in tensor)

        if config.separate_cls:
            suffix = tensor[:, :-1] if config.truncate_seq else tensor
            tensor = torch.cat([tensor[:, :1], suffix], dim=1)

        ndim = tensor.ndim
        if ndim == 2:
            tensor = tensor[:, None, :, None]
        elif ndim == 3:
            tensor = tensor[:, None, :, :]
        # Stride is applied on the second-to-last dimension.
        stride = (stride, 1)

        if mode == "mean":
            tensor = F.avg_pool2d(tensor, stride, stride=stride, ceil_mode=True)
        elif mode == "max":
            tensor = F.max_pool2d(tensor, stride, stride=stride, ceil_mode=True)
        elif mode == "min":
            tensor = -F.max_pool2d(-tensor, stride, stride=stride, ceil_mode=True)
        else:
            raise NotImplementedError("The supported modes are 'mean', 'max' and 'min'.")

        if ndim == 2:
            return tensor[:, 0, :, 0]
        elif ndim == 3:
            return tensor[:, 0]
        return tensor

In [330]:
pool_tensor(torch.LongTensor([[101, 2057, 1012, 102, 101], [101, 1, 2, 102, 101]]))

tensor([[ 101, 1534,  102],
        [ 101,    1,  102]])

In [331]:
_.shape

torch.Size([2, 3])

In [328]:
input_embeds.shape

torch.Size([2, 4, 768])

In [332]:
torch.LongTensor([[101, 2057, 1012, 102, 101], [101, 1, 2, 102, 101]]).shape

torch.Size([2, 5])

In [337]:
stride_pool(token_type_mat, 1).shape

torch.Size([2, 2, 4])

In [335]:
stride_pool(stride_pool(token_type_mat, 1), 2).shape

torch.Size([2, 2, 2])

In [380]:
def pre_attention_pooling( output, attention_inputs):
    """ Pool `output` and the proper parts of `attention_inputs` before the attention layer. """
    position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
    token_type_mat = stride_pool(token_type_mat, 1)
    cls_mask = stride_pool(cls_mask, 0)
    output = pool_tensor(output, mode=config.pooling_type)
    attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
    return output, attention_inputs

### Attention

In [344]:
config.d_model, config.n_head, config.d_head

(768, 12, 64)

In [551]:
class FunnelRelMultiheadAttention(nn.Module):
    def __init__(self, config, block_index):
        super().__init__()
        self.config = config
        self.block_index = block_index
        d_model, n_head, d_head = config.d_model, config.n_head, config.d_head

        self.hidden_dropout = nn.Dropout(config.hidden_dropout)
        self.attention_dropout = nn.Dropout(config.attention_dropout)

        self.q_head = nn.Linear(d_model, n_head * d_head, bias=False)
        self.k_head = nn.Linear(d_model, n_head * d_head)
        self.v_head = nn.Linear(d_model, n_head * d_head)

        self.r_w_bias = nn.Parameter(torch.zeros([n_head, d_head]))
        self.r_r_bias = nn.Parameter(torch.zeros([n_head, d_head]))
        self.r_kernel = nn.Parameter(torch.zeros([d_model, n_head, d_head]))
        self.r_s_bias = nn.Parameter(torch.zeros([n_head, d_head]))
        self.seg_embed = nn.Parameter(torch.zeros([2, n_head, d_head]))

        self.post_proj = nn.Linear(n_head * d_head, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=config.layer_norm_eps)
        self.scale = 1.0 / (d_head ** 0.5)
    
    def forward(self, query, key, value, attention_inputs, output_attentions=False, i=1):
        # query has shape batch_size x seq_len x d_model
        # key and value have shapes batch_size x context_len x d_model
        position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs

        batch_size, seq_len, _ = query.shape
        context_len = key.shape[1]
        n_head, d_head = self.config.n_head, self.config.d_head

        # Shape batch_size x seq_len x n_head x d_head
        q_head = self.q_head(query).view(batch_size, seq_len, n_head, d_head)
        print(q_head.shape)
        # Shapes batch_size x context_len x n_head x d_head
        k_head = self.k_head(key).view(batch_size, context_len, n_head, d_head)
        print(k_head.shape)
        v_head = self.v_head(value).view(batch_size, context_len, n_head, d_head)

        q_head = q_head * self.scale
        # Shape n_head x d_head
        r_w_bias = self.r_w_bias * self.scale
        # Shapes batch_size x n_head x seq_len x context_len
        content_score = torch.einsum("bind,bjnd->bnij", q_head + r_w_bias, k_head)
        
        positional_attn = self.relative_positional_attention(position_embeds, q_head, context_len, cls_mask, i)
        token_type_attn = self.relative_token_type_attention(token_type_mat, q_head, cls_mask)
        print("content score: ", content_score.shape)
        print("positional score: ", positional_attn.shape)
        print("token_type score", token_type_attn.shape)

        # merge attention scores
        attn_score = content_score# + positional_attn + token_type_attn

        # precision safe in case of mixed precision training
        dtype = attn_score.dtype
        attn_score = attn_score.float()
        # perform masking
        if attention_mask is not None:
            attn_score = attn_score - 1e6 * (1 - attention_mask[:, None, None].float())
        # attention probability
        attn_prob = torch.softmax(attn_score, dim=-1, dtype=dtype)
        attn_prob = self.attention_dropout(attn_prob)

        # attention output, shape batch_size x seq_len x n_head x d_head
        attn_vec = torch.einsum("bnij,bjnd->bind", attn_prob, v_head)

        # Shape shape batch_size x seq_len x d_model
        attn_out = self.post_proj(attn_vec.reshape(batch_size, seq_len, n_head * d_head))
        attn_out = self.hidden_dropout(attn_out)

        output = self.layer_norm(query + attn_out)
        return (output, attn_prob) if output_attentions else (output,)
    
    def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None, i=1):
        """ Relative attention score for the positional encodings """
        # q_head has shape batch_size x sea_len x n_head x d_head

        shift = 2 if q_head.shape[1] != context_len else 1
        # Notations from the paper, appending A.2.1, final formula (https://arxiv.org/abs/2006.03236)
        # Grab the proper positional encoding, shape max_rel_len x d_model
        r = position_embeds[i][shift - 1]
        # Shape n_head x d_head
        v = self.r_r_bias * self.scale
        # Shape d_model x n_head x d_head
        w_r = self.r_kernel

        # Shape max_rel_len x n_head x d_model
        r_head = torch.einsum("td,dnh->tnh", r, w_r)
        # Shape batch_size x n_head x seq_len x max_rel_len
        print("q_head + v shape: ", (q_head + v).shape)
        print("q_head: ", q_head.shape)
        print("V", v.shape)
        positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
        # Shape batch_size x n_head x seq_len x context_len
        print(positional_attn.shape, context_len, shift)
        positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
        print(positional_attn.shape)
        if cls_mask is not None:
            positional_attn *= cls_mask
        return positional_attn
    
    def relative_token_type_attention(self, token_type_mat, q, cls_mask=None):
        # q => batch_size × context_len × n_head × d_head
        # token_type_mat => batch_size × context_len × seq_len
        batch_size, context_len, seq_len = token_type_mat.shape
        r_s_bias = self.r_s_bias * self.scale
        # batch_size × n_head × context_len × 2
        token_type_bias = torch.einsum("bind,snd->bnis", q + r_s_bias, self.seg_embed)
        # batch_size × n_head × context_len × seq_len
        token_type_mat = token_type_mat[:, None].expand(
            [batch_size, q.shape[2], context_len, seq_len])
        # batch_size × n_head × context_len
        diff_token_type, same_token_type = torch.split(token_type_bias, 1, dim=-1)
        # batch_size × n_head × context_len × seq_len
        print("token_type_mat shape", token_type_mat.shape)
#         print(token_type_mat)
        token_type_attn = torch.where(
            token_type_mat, 
            same_token_type.expand(token_type_mat.shape),
            diff_token_type.expand(token_type_mat.shape)
            )
#         print(token_type_attn)
        if cls_mask is not None:
            token_type_attn *= cls_mask
        return token_type_attn

In [545]:
def _relative_shift_gather(positional_attn, context_len, shift):
    batch_size, n_head, seq_len, max_rel_len = positional_attn.shape
    # max_rel_len = 2 * context_len + shift -1 is the numbers of possible relative positions i-j

    # What's next is the same as doing the following gather, which might be clearer code but less efficient.
    # idxs = context_len + torch.arange(0, context_len).unsqueeze(0) - torch.arange(0, seq_len).unsqueeze(1)
    # # matrix of context_len + i-j
    # return positional_attn.gather(3, idxs.expand([batch_size, n_head, context_len, context_len]))

    positional_attn = torch.reshape(positional_attn, [batch_size, n_head, max_rel_len, seq_len])
    positional_attn = positional_attn[:, :, shift:, :]
    positional_attn = torch.reshape(positional_attn, [batch_size, n_head, seq_len, max_rel_len - shift])
    positional_attn = positional_attn[..., :context_len]
    return positional_attn

In [546]:
fra = FunnelRelMultiheadAttention(config, 0)

In [539]:
position_embeds

[[tensor([[-0.8409, -0.7684, -0.6909,  ...,  1.1111,  1.1111,  1.1111],
          [ 0.1568,  0.2345,  0.3093,  ...,  1.1111,  1.1111,  1.1111],
          [ 1.0103,  1.0311,  1.0492,  ...,  1.1111,  1.1111,  1.1111],
          ...,
          [-0.0000, -0.9205, -0.9058,  ...,  1.1111,  1.1111,  1.1111],
          [-1.0103, -0.0000, -1.0492,  ...,  1.1111,  1.1111,  1.1111],
          [-0.1568, -0.2345, -0.3093,  ...,  1.1111,  1.1111,  0.0000]]),
  None],
 [tensor([[-0.8409, -0.7684, -0.6909,  ...,  1.1111,  1.1111,  1.1111],
          [ 1.0103,  1.0311,  1.0492,  ...,  1.1111,  1.1111,  1.1111],
          [ 0.0000,  0.0000,  0.0000,  ...,  1.1111,  1.1111,  1.1111],
          [-1.0103, -0.0000, -1.0492,  ...,  1.1111,  1.1111,  1.1111]]),
  tensor([[ 0.1568,  0.2345,  0.3093,  ...,  1.1111,  1.1111,  1.1111],
          [ 1.0103,  1.0311,  1.0492,  ...,  1.1111,  1.1111,  1.1111],
          [ 0.9350,  0.0000,  0.9058,  ...,  1.1111,  0.0000,  1.1111],
          ...,
          [-1.0103, -

In [472]:
token_type_mat.shape, attention_mask.shape, cls_mask.shape

(torch.Size([2, 1, 4]), torch.Size([2, 4]), torch.Size([1, 4]))

In [366]:
attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)

In [466]:
pooled_hidden, attention_inputs = pre_attention_pooling(input_embeds, attention_inputs)

In [468]:
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs

In [593]:
config = FunnelConfig()
fe = FunnelEmbeddings(config)
input_ids = torch.LongTensor([[101, 2057, 1012, 102], [101, 1, 2, 102]])
input_embeds = fe(input_ids)
attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
attention_mask = attention_mask.type_as(input_embeds)
token_type_ids = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device)

seq_len = input_embeds.size(1)
token_type_mat = token_type_ids_to_mat(token_type_ids)
cls_mask = (F.pad(input_embeds.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0)))
position_embeds = get_position_embeds(seq_len, input_embeds.dtype, input_embeds.device)

attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
pooled_hidden, attention_inputs = pre_attention_pooling(input_embeds, attention_inputs)
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs

fra = FunnelRelMultiheadAttention(config, 2)
flh, = fra(pooled_hidden, input_embeds, input_embeds, attention_inputs, i=1)
flh

SecondType rel pos tensor([ 4,  3,  2,  1,  0, -1, -2, -3])

FirstType rel pos tensor([ 3,  2,  1,  0, -1, -2, -3, -4])
SecondType rel pos tensor([ 4,  2,  0, -2])

FirstType rel pos tensor([ 2,  0, -2, -4])
SecondType rel pos tensor([4, 0])

torch.Size([2, 2, 12, 64])
torch.Size([2, 4, 12, 64])
q_head + v shape:  torch.Size([2, 2, 12, 64])
q_head:  torch.Size([2, 2, 12, 64])
V torch.Size([12, 64])
torch.Size([2, 12, 2, 8]) 4 2
torch.Size([2, 12, 2, 4])
token_type_mat shape torch.Size([2, 12, 2, 4])
content score:  torch.Size([2, 12, 2, 4])
positional score:  torch.Size([2, 12, 2, 4])
token_type score torch.Size([2, 12, 2, 4])


tensor([[[ 1.0451, -0.0450, -1.1673,  ...,  0.0464,  1.0833,  1.7608],
         [ 1.1183, -0.5268,  1.4921,  ...,  0.5391, -0.3932, -0.4550]],

        [[ 1.0034,  0.0594, -0.1742,  ...,  0.5220,  0.9122,  1.6947],
         [-0.5062, -1.0373, -0.5044,  ...,  0.6907,  0.4824,  1.2550]]],
       grad_fn=<NativeLayerNormBackward>)

In [594]:
flh.shape

torch.Size([2, 2, 768])

### Decoder

In [513]:
config.num_decoder_layers

2

In [514]:
def upsample(x, stride, target_len, separate_cls=True, truncate_seq=False):
    """Upsample tensor `x` to match `target_len` by repeating the tokens `stride` time on the sequence length
    dimension."""
    if stride == 1:
        return x
    if separate_cls:
        cls = x[:, :1]
        x = x[:, 1:]
    output = torch.repeat_interleave(x, repeats=stride, dim=1)
    if separate_cls:
        if truncate_seq:
            output = nn.functional.pad(output, (0, 0, 0, stride - 1, 0, 0))
        output = output[:, : target_len - 1]
        output = torch.cat([cls, output], dim=1)
    else:
        output = output[:, :target_len]
    return output

In [605]:
class FunnelPositionwiseFFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear_1 = nn.Linear(config.d_model, config.d_inner)
        self.activation_function = ACT2FN[config.hidden_act]
        self.activation_dropout = nn.Dropout(config.activation_dropout)
        self.linear_2 = nn.Linear(config.d_inner, config.d_model)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps)

    def forward(self, hidden):
        h = self.linear_1(hidden)
        h = self.activation_function(h)
        h = self.activation_dropout(h)
        h = self.linear_2(h)
        h = self.dropout(h)
        return self.layer_norm(hidden + h)


class FunnelLayer(nn.Module):
    def __init__(self, config, block_index):
        super().__init__()
        self.attention = FunnelRelMultiheadAttention(config, block_index)
        self.ffn = FunnelPositionwiseFFN(config)

    def forward(self, query, key, value, attention_inputs, output_attentions=False, i=0):
        attn = self.attention(query, key, value, attention_inputs, output_attentions=output_attentions, i=0)
        output = self.ffn(attn[0])
        return (output, attn[1]) if output_attentions else (output,)

import math
def gelu_new(x):
    """Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
    Also see https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

ACT2FN = {
    "gelu_new": gelu_new,
}

In [610]:
class FunnelDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [FunnelLayer(config, 0) for _ in range(config.num_decoder_layers)]
        )

    def forward(
        self,
        final_hidden,
        first_block_hidden,
        attention_mask=None,
        token_type_ids=None
    ):
        upsampled_hidden = upsample(
            final_hidden,
            stride=2 ** (len(self.config.block_sizes) - 1),
            target_len=first_block_hidden.shape[1],
            separate_cls=self.config.separate_cls,
            truncate_seq=self.config.truncate_seq,
        )

        hidden = upsampled_hidden + first_block_hidden
        
        seq_len = hidden.size(1)
        position_embeds = get_position_embeds(seq_len, hidden.dtype, hidden.device)
        token_type_mat = token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
        cls_mask = (
            F.pad(hidden.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0))
            if self.config.separate_cls
            else None
        )
        attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)

        for layer in self.layers:
            layer_output = layer(hidden, hidden, hidden, attention_inputs, output_attentions=False, i=0)
            hidden = layer_output[0]
            print("Final Print", hidden.shape)

In [611]:
fd = FunnelDecoder(config)

In [612]:
config = FunnelConfig()
fe = FunnelEmbeddings(config)
input_ids = torch.LongTensor([[101, 2057, 1012, 102], [101, 1, 2, 102]])
input_embeds = fe(input_ids)
attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
attention_mask = attention_mask.type_as(input_embeds)
token_type_ids = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device)

seq_len = input_embeds.size(1)
token_type_mat = token_type_ids_to_mat(token_type_ids)
cls_mask = (F.pad(input_embeds.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0)))
position_embeds = get_position_embeds(seq_len, input_embeds.dtype, input_embeds.device)

attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
fra = FunnelRelMultiheadAttention(config, 0)
fth, = fra(input_embeds, input_embeds, input_embeds, attention_inputs, output_attentions=False, i=0)
fth.shape

SecondType rel pos tensor([ 4,  3,  2,  1,  0, -1, -2, -3])

FirstType rel pos tensor([ 3,  2,  1,  0, -1, -2, -3, -4])
SecondType rel pos tensor([ 4,  2,  0, -2])

FirstType rel pos tensor([ 2,  0, -2, -4])
SecondType rel pos tensor([4, 0])

torch.Size([2, 4, 12, 64])
torch.Size([2, 4, 12, 64])
q_head + v shape:  torch.Size([2, 4, 12, 64])
q_head:  torch.Size([2, 4, 12, 64])
V torch.Size([12, 64])
torch.Size([2, 12, 4, 8]) 4 1
torch.Size([2, 12, 4, 4])
token_type_mat shape torch.Size([2, 12, 4, 4])
content score:  torch.Size([2, 12, 4, 4])
positional score:  torch.Size([2, 12, 4, 4])
token_type score torch.Size([2, 12, 4, 4])


torch.Size([2, 4, 768])

In [613]:
fd(flh, fth, attention_mask, token_type_ids)

SecondType rel pos tensor([ 4,  3,  2,  1,  0, -1, -2, -3])

FirstType rel pos tensor([ 3,  2,  1,  0, -1, -2, -3, -4])
SecondType rel pos tensor([ 4,  2,  0, -2])

FirstType rel pos tensor([ 2,  0, -2, -4])
SecondType rel pos tensor([4, 0])

torch.Size([2, 4, 12, 64])
torch.Size([2, 4, 12, 64])
q_head + v shape:  torch.Size([2, 4, 12, 64])
q_head:  torch.Size([2, 4, 12, 64])
V torch.Size([12, 64])
torch.Size([2, 12, 4, 8]) 4 1
torch.Size([2, 12, 4, 4])
token_type_mat shape torch.Size([2, 12, 4, 4])
content score:  torch.Size([2, 12, 4, 4])
positional score:  torch.Size([2, 12, 4, 4])
token_type score torch.Size([2, 12, 4, 4])
Final Print torch.Size([2, 4, 768])
torch.Size([2, 4, 12, 64])
torch.Size([2, 4, 12, 64])
q_head + v shape:  torch.Size([2, 4, 12, 64])
q_head:  torch.Size([2, 4, 12, 64])
V torch.Size([12, 64])
torch.Size([2, 12, 4, 8]) 4 1
torch.Size([2, 12, 4, 4])
token_type_mat shape torch.Size([2, 12, 4, 4])
content score:  torch.Size([2, 12, 4, 4])
positional score:  torch.

### Tokenizer

In [735]:
import json
import os
import random

from absl import flags
import absl.logging as _logging

import numpy as np
import tensorflow.compat.v1 as tf

import collections
tf.disable_v2_behavior()


In [622]:
sorted(tf.io.gfile.glob("/Users/HaoShaochun/Yam/All4NLP/*.md"))[0::1]

['/Users/HaoShaochun/Yam/All4NLP/README.md']

In [623]:
x = list("DFSFSDFSD")

In [624]:
x

['D', 'F', 'S', 'F', 'S', 'D', 'F', 'S', 'D']

In [626]:
x[0::1]

['D', 'F', 'S', 'F', 'S', 'D', 'F', 'S', 'D']

In [628]:
bt = BasicTokenizer()

In [644]:
bt.convert_text_to_tokens("范德萨发生i love you cleaning the booking")

['范', '德', '萨', '发', '生', 'i', 'love', 'you', 'cleaning', 'the', 'booking']

In [641]:
def whitespace_tokenize(text):
  """Runs basic whitespace cleaning and splitting on a piece of text."""
  text = text.strip()
  if not text:
    return []
  tokens = text.split()
  return tokens

In [636]:
import unicodedata

In [651]:
unicodedata.category("，")

'Po'

In [762]:
import six
def convert_to_unicode(text):
  """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
  if six.PY3:
    if isinstance(text, str):
      return text
    elif isinstance(text, bytes):
      return text.decode("utf-8", "ignore")
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  elif six.PY2:
    if isinstance(text, str):
      return text.decode("utf-8", "ignore")
    elif isinstance(text, unicode):
      return text
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  else:
    raise ValueError("Not running on Python2 or Python 3?")
def _is_whitespace(char):
  """Checks whether `chars` is a whitespace character."""
  # \t, \n, and \r are technically control characters but we treat them
  # as whitespace since they are generally considered as such.
  if char == " " or char == "\t" or char == "\n" or char == "\r":
    return True
  cat = unicodedata.category(char)
  if cat == "Zs":
    return True
  return False


def _is_control(char):
  """Checks whether `chars` is a control character."""
  # These are technically control characters but we count them as whitespace
  # characters.
  if char == "\t" or char == "\n" or char == "\r":
    return False
  cat = unicodedata.category(char)
  if cat in ("Cc", "Cf"):
    return True
  return False


def _is_punctuation(char):
  """Checks whether `chars` is a punctuation character."""
  cp = ord(char)
  # We treat all non-letter/number ASCII as punctuation.
  # Characters such as "^", "$", and "`" are not in the Unicode
  # Punctuation class but we treat them as punctuation anyways, for
  # consistency.
  if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
      (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
    return True
  cat = unicodedata.category(char)
  if cat.startswith("P"):
    return True
  return False

In [763]:
class BasicTokenizer(object):
  """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

  def __init__(self, do_lower_case=True):
    """Constructs a BasicTokenizer.
    Args:
      do_lower_case: Whether to lower case the input.
    """
    self.do_lower_case = do_lower_case

  def convert_text_to_tokens(self, text):
    """Tokenizes a piece of text."""
    text = convert_to_unicode(text)
    text = self._text(text)

    # This was added on November 1st, 2018 for the multilingual and Chinese
    # models. This is also applied to the English models now, but it doesn't
    # matter since the English models were not trained on any Chinese data
    # and generally don't have any Chinese data in them (there are Chinese
    # characters in the vocabulary because Wikipedia does have some Chinese
    # words in the English Wikipedia.).
    text = self._tokenize_chinese_chars(text)

    orig_tokens = whitespace_tokenize(text)
    split_tokens = []
    for token in orig_tokens:
      if self.do_lower_case:
        token = token.lower()
        token = self._run_strip_accents(token)
      split_tokens.extend(self._run_split_on_punc(token))

    output_tokens = whitespace_tokenize(" ".join(split_tokens))
    return output_tokens

  def _run_strip_accents(self, text):
    """Strips accents from a piece of text."""
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
      cat = unicodedata.category(char)
      if cat == "Mn":
        continue
      output.append(char)
    return "".join(output)

  def _run_split_on_punc(self, text):
    """Splits punctuation on a piece of text."""
    chars = list(text)
    i = 0
    start_new_word = True
    output = []
    while i < len(chars):
      char = chars[i]
      if _is_punctuation(char):
        output.append([char])
        start_new_word = True
      else:
        if start_new_word:
          output.append([])
        start_new_word = False
        output[-1].append(char)
      i += 1

    return ["".join(x) for x in output]

  def _tokenize_chinese_chars(self, text):
    """Adds whitespace around any CJK character."""
    output = []
    for char in text:
      cp = ord(char)
      if self._is_chinese_char(cp):
        output.append(" ")
        output.append(char)
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)

  def _is_chinese_char(self, cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and handled
    # like the all of the other languages.
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
        (cp >= 0x3400 and cp <= 0x4DBF) or  #
        (cp >= 0x20000 and cp <= 0x2A6DF) or  #
        (cp >= 0x2A700 and cp <= 0x2B73F) or  #
        (cp >= 0x2B740 and cp <= 0x2B81F) or  #
        (cp >= 0x2B820 and cp <= 0x2CEAF) or
        (cp >= 0xF900 and cp <= 0xFAFF) or  #
        (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
      return True

    return False

  def _text(self, text):
    """Performs invalid character removal and whitespace cleanup on text."""
    output = []
    for char in text:
      cp = ord(char)
      if cp == 0 or cp == 0xfffd or _is_control(char):
        continue
      if _is_whitespace(char):
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)

In [764]:
class WordpieceTokenizer(object):
  """Runs WordPiece tokenziation."""

  def __init__(self, vocab, unk_token="<unk>", max_input_chars_per_word=200):
    self.vocab = vocab
    self.unk_token = unk_token
    self.max_input_chars_per_word = max_input_chars_per_word

  def convert_text_to_tokens(self, text):
    """Tokenizes a piece of text into its word pieces.
    This uses a greedy longest-match-first algorithm to perform tokenization
    using the given vocabulary.
    For example:
      input = "unaffable"
      output = ["un", "##aff", "##able"]
    Args:
      text: A single token or whitespace separated tokens. This should have
        already been passed through `BasicTokenizer.
    Returns:
      A list of wordpiece tokens.
    """

    text = convert_to_unicode(text)

    output_tokens = []
    for token in whitespace_tokenize(text):
      chars = list(token)
      if len(chars) > self.max_input_chars_per_word:
        output_tokens.append(self.unk_token)
        continue

      is_bad = False
      start = 0
      sub_tokens = []
      while start < len(chars):
        end = len(chars)
        cur_substr = None
#         print(start, end)
        while start < end:
          substr = "".join(chars[start:end])
#           print(substr)
          if start > 0:
            substr = "##" + substr
          if substr in self.vocab:
            cur_substr = substr
            break
          end -= 1
        if cur_substr is None:
          is_bad = True
          break
        sub_tokens.append(cur_substr)
        start = end

      if is_bad:
        output_tokens.append(self.unk_token)
      else:
        output_tokens.extend(sub_tokens)
    return output_tokens

class FullTokenizer(object):
  """Runs end-to-end tokenziation."""

  def __init__(self, vocab_file, do_lower_case=True):
    self.vocab = load_vocab(vocab_file)
    self.inv_vocab = {v: k for k, v in self.vocab.items()}
    self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
    self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

  def convert_text_to_tokens(self, text):
    split_tokens = []
    for token in self.basic_tokenizer.convert_text_to_tokens(text):
      for sub_token in self.wordpiece_tokenizer.convert_text_to_tokens(token):
        split_tokens.append(sub_token)

    return split_tokens

  def get_token_id(self, token):
    return self.vocab[token]

  def convert_tokens_to_ids(self, tokens):
    return convert_by_vocab(self.vocab, tokens)

  def convert_ids_to_tokens(self, ids):
    return convert_by_vocab(self.inv_vocab, ids)

  def convert_text_to_ids(self, text):
    tokens = self.convert_text_to_tokens(text)
    return self.convert_tokens_to_ids(tokens)

  def convert_ids_to_text(self, ids):
    tokens = self.convert_ids_to_tokens(ids)
    return " ".join(tokens)

  def is_start_id(self, token_id):
    token = self.inv_vocab[token_id]
    return not token.startswith("##")

  def is_func_id(self, token_id):
    token = self.inv_vocab[token_id]
    return self.is_func_token(token)

  def is_func_token(self, token):
    return token != "<unk>" and token.startswith("<") and token.endswith(">")

  def get_vocab_size(self):
    return len(self.vocab)

def load_vocab(vocab_file):
  """Loads a vocabulary file into a dictionary."""
  vocab = collections.OrderedDict()
  with tf.io.gfile.GFile(vocab_file, "r") as reader:
    while True:
      token = convert_to_unicode(reader.readline())
      if not token:
        break
      token = token.strip()
      if token not in vocab:
        vocab[token] = len(vocab)
  return vocab

def convert_by_vocab(vocab, items):
  """Converts a sequence of [tokens|ids] using the vocab."""
  output = []
  for item in items:
    output.append(vocab[item])
  return output


In [712]:
wt = WordpieceTokenizer(vocab)
wt.convert_text_to_tokens("lovingx")

lovingx
0 7
lovingx
loving
lovin
lovi
lov
lo
2 7
vingx
ving
6 7
x


['lo', '##ving', '##x']

In [713]:
"##ving" in vocab

True

In [708]:
import pnlp

In [709]:
vocab = pnlp.read_lines("/Volumes/YamHd/Lab/1Models/chinese_rbt3_pytorch/vocab.txt")

In [719]:
hist, bins = np.histogram([np.random.randint(1000) for i in range(100)],
                            bins=[0, 64, 128, 256, 512, 1024, 2048, 102400])

In [723]:
perm_indices = np.random.permutation(1000)

In [725]:
perm_indices.shape

(1000,)

In [728]:
np.logical_not([True, True, False])

array([False, False,  True])

### PrepareData

In [800]:
def create_pretrain_data(input_paths, tokenizer):
  """Load data and call corresponding create_func."""
  input_shards = []

  # working structure used to store each document
  input_data, sent_ids = [], []
  end_of_doc = False

  # monitor doc length and number of tokens
  doc_length = []
  total_num_tok = 0

  for input_path in input_paths:
    sent_id, line_cnt = True, 0

    tf.logging.info("Start processing %s", input_path)
    for line in tf.io.gfile.GFile(input_path):
      if line_cnt % 100000 == 0:
        tf.logging.info("Loading line %d", line_cnt)

      if not line.strip():
        # encounter an empty line (end of a document)
        end_of_doc = True
        cur_sent = []
      else:
        cur_sent = tokenizer.convert_text_to_ids(line.strip())

      if cur_sent:
        input_data.extend(cur_sent)
        sent_ids.extend([sent_id] * len(cur_sent))
        sent_id = not sent_id

      if end_of_doc:
        # monitor over doc lengths
        doc_length.append(len(input_data))

        # only retain docs longer than `min_doc_len`
        if len(input_data) >= max(1, 1):
          input_data = np.array(input_data, dtype=np.int64)
          sent_ids = np.array(sent_ids, dtype=np.bool)
          input_shards.append((input_data, sent_ids))
          total_num_tok += len(input_data)

        # refresh working structs
        input_data, sent_ids = [], []
        end_of_doc = False

      line_cnt += 1

    tf.logging.info("Finish %s with %d lines.", input_path, line_cnt)

  print(input_shards)
  tf.logging.info("[Task %d] Total number tokens: %d", 0,
                  total_num_tok)

  hist, bins = np.histogram(doc_length,
                            bins=[0, 64, 128, 256, 512, 1024, 2048, 102400])
  percent = hist / np.sum(hist)
  tf.logging.info("***** Doc length histogram *****")
  for pct, l, r in zip(percent, bins[:-1], bins[1:]):
    tf.logging.info("  - [%d, %d]: %.4f", l, r, pct)

  # Randomly shuffle input shards (with a fixed but unique random seed)
  np.random.seed(100 * 0 + 0)
  perm_indices = np.random.permutation(len(input_shards))

  input_data_list, sent_ids_list = [], []
  prev_sent_id = None
  for perm_idx in perm_indices:
    input_data, sent_ids = input_shards[perm_idx]
    tf.logging.debug("Idx %d: data %s sent %s", perm_idx,
                     input_data.shape, sent_ids.shape)
    # make sure the `send_ids[0] == not prev_sent_id`
    if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
      sent_ids = np.logical_not(sent_ids)

    # append to temporary list
    input_data_list.append(input_data)
    sent_ids_list.append(sent_ids)

    # update `prev_sent_id`
    prev_sent_id = sent_ids[-1]

  # concat into a flat np.ndarray
  input_data = np.concatenate(input_data_list)
  sent_ids = np.concatenate(sent_ids_list)
  return (input_data, sent_ids)

In [801]:
vocab_file = "/Volumes/YamHd/Lab/1Models/chinese_rbt3_pytorch/vocab.txt"
tokenizer = FullTokenizer(vocab_file)

In [802]:
tokenizer.convert_text_to_tokens("i lvoe you")

['i', 'lv', '##oe', 'you']

In [803]:
tokenizer.convert_text_to_ids("i lvoe you")

[151, 8289, 10115, 8357]

In [804]:
file_paths = sorted(tf.io.gfile.glob("./data/*.txt"))
task_file_paths = file_paths[0::1]

In [805]:
tokenizer.convert_text_to_ids("Doc1 another paragraph.")

[9656, 8148, 9064, 11759, 9519, 8332, 8181, 12220, 119]

In [806]:
data, sent_ids = create_pretrain_data(task_file_paths, tokenizer)

INFO:tensorflow:Start processing ./data/funnel.txt
INFO:tensorflow:Loading line 0
INFO:tensorflow:Finish ./data/funnel.txt with 24 lines.
[(array([ 9656,  8148,  8975,  8174, 13017,  8168,  9575, 12827,   118,
         162, 10477,  8118, 12725,  8755,  8995,   162, 11944,  8303,
        8663,  8174, 10245, 10862,  8332, 11336,  8857,   119,   165,
        8963,  8370, 10253, 11418, 10050,  8179,  8847,  8326,  9273,
        9683,   143,   107, 10234,  8957, 11762,   107,  8134,  9401,
        8354,  8511,   119,  8217,  9470, 12183,  8877,   117, 11136,
        8995, 13050, 11927, 10474,  8118,   131,   133,   147,  9133,
         135,  8330,  8533,  8914,  8829,  8221,   118, 10110, 11227,
        9980, 10936,  8118,   119]), array([ True,  True, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  Tru

In [807]:
def _split_a_and_b(data, sent_ids, begin_idx, tot_len):
  """Split two segments from `data` starting from the index `begin_idx`."""

  data_len = data.shape[0]
  if begin_idx + tot_len >= data_len:
    tf.logging.info("Not enough data: "
                    "begin_idx %d + tot_len %d >= data_len %d",
                    begin_idx, tot_len, data_len)
    return None

  end_idx = begin_idx + 1
  cut_points = []
  while end_idx < data_len:
    if sent_ids[end_idx] != sent_ids[end_idx - 1]:
      if end_idx - begin_idx >= tot_len: break
      cut_points.append(end_idx)
    end_idx += 1

  a_begin = begin_idx
  if not cut_points or random.random() < 0.5:
    # negative pair
    label = 0
    if not cut_points:
      a_end = end_idx
    else:
      a_end = random.choice(cut_points)

    b_len = max(1, tot_len - (a_end - a_begin))
    b_begin = random.randint(0, data_len - b_len)
    b_end = b_begin + b_len

    # locate a complete sentence for `b`
    while b_begin > 0 and sent_ids[b_begin - 1] == sent_ids[b_begin]:
      b_begin -= 1
    while b_end < data_len and sent_ids[b_end - 1] == sent_ids[b_end]:
      b_end += 1

    new_begin = a_end
  else:
    # positive pair
    label = 1
    a_end = random.choice(cut_points)
    b_begin = a_end
    b_end = end_idx

    new_begin = b_end

  # truncate both a & b
  while a_end - a_begin + b_end - b_begin > tot_len:
    # truncate a (only right)
    if a_end - a_begin > b_end - b_begin:
      a_end -= 1
    # truncate b (both left and right)
    else:
      if random.random() < 0.5:
        b_end -= 1
      else:
        b_begin += 1

  ret = [data[a_begin: a_end], data[b_begin: b_end], label, new_begin]

  return ret


In [830]:
_split_a_and_b(
        data,
        sent_ids,
        begin_idx=0,
        tot_len=5)

[array([9656, 8152]), array([ 107,  114, 8310]), 1, 18]

In [827]:
data

array([ 9656,  8152,   113,   107,  9796, 10203,  8641,  9844, 11762,
         107,   114,  8310,  8533,  8914,  8829,  8303,   119,   119,
        8174, 10234,  8957, 11762,  8310, 10564,  8303,  8303,   119,
         133,   147,  9133,   135, 11668,  8168,  8281,  8174, 10380,
        9796,  9049, 12865,  8511, 11927, 10474,  8118,   117,  8997,
        8376,  9197,  8510, 10380,  9178,  8118,  8205,  8282,  8727,
        8372, 11522,  9233, 11112,  8118,   131,   119,  9656,  8159,
        9684,   118, 10110,  9233, 11112,  8118,   131,  8554,  9178,
       10110,  8910, 10079, 11211,  8174,  9519, 13185, 12017,  8205,
        8174,  9684,  9575, 12827,   118,   162, 10477,  8118, 12725,
        8196,  9264,   119,   113,   107,  9447, 10802,  8180,   107,
         116,   107,  9333, 10260,  8180,   107,   114,   119, 10288,
        8513,  8174,  8792, 12579,  8333,  9401, 11874,  8291, 10208,
        9007,  8625, 10631, 11603,  8995,  8847,  8788, 11667,   119,
         133,   147,

In [828]:
sent_ids

array([ True,  True, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False,  True,  True,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,