In [108]:
import torch
import torch.nn as nn
import math
from torch import _softmax_backward_data as _softmax_backward_data
import numpy as np

import copy
from collections.abc import Sequence

## Basic

In [146]:
def build_relative_position(query_size, key_size):
    """ Build relative position according to the query and key

    We assume the absolute position of query :math:`P_q` is range from 
    (0, query_size) and the absolute position of key :math:`P_k` is range from (0, key_size),
    The relative positions from query to key is
    
    :math:`R_{q \\rightarrow k} = P_q - P_k`

    Args:
        query_size (int): the length of query
        key_size (int): the length of key

    Return:
        :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size]

    """

    q_ids = np.arange(0, query_size)
    k_ids = np.arange(0, key_size)
    rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0],1))
    rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
    rel_pos_ids = rel_pos_ids[:query_size, :]
    rel_pos_ids = rel_pos_ids.unsqueeze(0)
    return rel_pos_ids

In [103]:
class BertLayerNorm(nn.Module):
  """LayerNorm module in the TF style (epsilon inside the square root).
  """

  def __init__(self, size, eps=1e-12):
    super().__init__()
    self.weight = nn.Parameter(torch.ones(size))
    self.bias = nn.Parameter(torch.zeros(size))
    self.variance_epsilon = eps

  def forward(self, x):
    input_type = x.dtype
    x = x.float()
    u = x.mean(-1, keepdim=True)
    s = (x - u).pow(2).mean(-1, keepdim=True)
    x = (x - u) / torch.sqrt(s + self.variance_epsilon)
    x = x.to(input_type)
    y = self.weight * x + self.bias
    return y
class BertSelfOutput(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    self.LayerNorm = BertLayerNorm(config.hidden_size, config.layer_norm_eps)
    self.dropout = StableDropout(config.hidden_dropout_prob)
    self.config = config

  def forward(self, hidden_states, input_states, mask=None):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.dropout(hidden_states)
    hidden_states += input_states
    hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
    return hidden_states

class BertAttention(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.self = DisentangledSelfAttention(config)
    self.output = BertSelfOutput(config)
    self.config = config

  def forward(self, hidden_states, attention_mask, 
              return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
    self_output = self.self(hidden_states, attention_mask, return_att, 
                            query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
    if return_att:
      self_output, att_matrix = self_output
    if query_states is None:
      query_states = hidden_states
    attention_output = self.output(self_output, query_states, attention_mask)

    if return_att:
      return (attention_output, att_matrix)
    else:
      return attention_output

class BertIntermediate(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
    self.intermediate_act_fn = ACT2FN[config.hidden_act] \
      if isinstance(config.hidden_act, str) else config.hidden_act

  def forward(self, hidden_states):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.intermediate_act_fn(hidden_states)
    return hidden_states

class BertOutput(nn.Module):
  def __init__(self, config):
    super(BertOutput, self).__init__()
    self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
    self.LayerNorm = BertLayerNorm(config.hidden_size, config.layer_norm_eps)
    self.dropout = StableDropout(config.hidden_dropout_prob)
    self.config = config

  def forward(self, hidden_states, input_states, mask=None):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.dropout(hidden_states)
    hidden_states += input_states
    hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
    return hidden_states

class BertLayer(nn.Module):
  def __init__(self, config):
    super(BertLayer, self).__init__()
    self.attention = BertAttention(config)
    self.intermediate = BertIntermediate(config)
    self.output = BertOutput(config)

  def forward(self, hidden_states, attention_mask, return_att=False, 
              query_states=None, relative_pos=None, rel_embeddings=None):
    attention_output = self.attention(hidden_states, attention_mask, return_att=return_att, \
      query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
    if return_att:
      attention_output, att_matrix = attention_output
    intermediate_output = self.intermediate(attention_output)
    layer_output = self.output(intermediate_output, attention_output, attention_mask)
    if return_att:
      return (layer_output, att_matrix)
    else:
      return layer_output

In [109]:
def gelu(x):
  """Implementation of the gelu activation function.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
  """
  return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
  return x * torch.sigmoid(x)

def linear_act(x):
  return x

ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, 
          "swish": swish, "tanh": torch.nn.functional.tanh, 
          "linear": linear_act, 'sigmoid': torch.sigmoid}

## Config

In [62]:
class ModelConfig:
    def __init__(self):        
        self.hidden_size = 768
        self.num_hidden_layers = 12
        self.num_attention_heads = 12
        self.hidden_act = "gelu"
        self.intermediate_size = 3072
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.max_position_embeddings = 512
        self.type_vocab_size = 0
        self.initializer_range = 0.02
        self.layer_norm_eps = 1e-7
        self.padding_idx = 0
        self.vocab_size = 21128
        
        self.relative_attention = True
        self.max_relative_positions = 512
        self.position_biased_input = True
        self.pos_att_type = "p2c|c2p"

In [63]:
config = ModelConfig()

## Embedding

In [80]:
# 与 Bert 一致
class BertEmbeddings(nn.Module):
  """Construct the embeddings from word, position and token_type embeddings.
  """
  def __init__(self, config):
    super(BertEmbeddings, self).__init__()
    padding_idx = getattr(config, 'padding_idx', 0)
    self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
    self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx = padding_idx)

    self.position_biased_input = getattr(config, 'position_biased_input', True)
    if not self.position_biased_input:
      self.position_embeddings = None
    else:
      self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)

    if config.type_vocab_size>0:
      self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
    
    if self.embedding_size != config.hidden_size:
      self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
    self.LayerNorm = BertLayerNorm(config.hidden_size, config.layer_norm_eps)
    self.dropout = StableDropout(config.hidden_dropout_prob)
    self.output_to_half = False
    self.config = config

  def forward(self, input_ids, token_type_ids=None, position_ids=None, mask = None):
    seq_length = input_ids.size(1)
    if position_ids is None:
      position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)
      position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    if token_type_ids is None:
      token_type_ids = torch.zeros_like(input_ids)

    words_embeddings = self.word_embeddings(input_ids)
    print("words_embeddings shape: ", words_embeddings.shape)
    if self.position_embeddings is not None:
      position_embeddings = self.position_embeddings(position_ids.long())
      print("position_embeddings shape: ", position_embeddings.shape)
    else:
      position_embeddings = torch.zeros_like(words_embeddings)

    embeddings = words_embeddings
    if self.position_biased_input:
      embeddings += position_embeddings
    if self.config.type_vocab_size>0:
      token_type_embeddings = self.token_type_embeddings(token_type_ids)
      embeddings += token_type_embeddings

    if self.embedding_size != self.config.hidden_size:
      embeddings = self.embed_proj(embeddings)

    embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, mask)
    embeddings = self.dropout(embeddings)
    return embeddings

In [337]:
input_ids = torch.tensor([[31, 51, 99, 10, 9, 8, 7, 6], [15, 5, 4, 3, 2, 1, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0]])
token_type_ids = torch.zeros_like(input_ids)

In [338]:
embeddings = BertEmbeddings(config)
embeddings

BertEmbeddings(
  (word_embeddings): Embedding(21128, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (LayerNorm): BertLayerNorm()
  (dropout): StableDropout()
)

In [339]:
embedding_output = embeddings(input_ids, token_type_ids, mask=attention_mask)
embedding_output.shape

words_embeddings shape:  torch.Size([2, 8, 768])
position_embeddings shape:  torch.Size([2, 8, 768])


torch.Size([2, 8, 768])

## Distengled Attentin

In [340]:
def get_attention_mask(attention_mask):
    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
    att_mask = extended_attention_mask.byte()
    attention_mask = att_mask*att_mask.squeeze(-2).unsqueeze(-1)
    return attention_mask

In [341]:
attention_mask

tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 0, 0]])

In [342]:
get_attention_mask(attention_mask)

tensor([[[[1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1]]],


        [[[1, 1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0]]]], dtype=torch.uint8)

In [343]:
def get_rel_pos(hidden_states, relative_pos=None):
    q = hidden_states.size(-2)
    relative_pos = build_relative_position(q, hidden_states.size(-2))
    return relative_pos

In [344]:
get_rel_pos(embedding_output)

tensor([[[ 0, -1, -2, -3, -4, -5, -6, -7],
         [ 1,  0, -1, -2, -3, -4, -5, -6],
         [ 2,  1,  0, -1, -2, -3, -4, -5],
         [ 3,  2,  1,  0, -1, -2, -3, -4],
         [ 4,  3,  2,  1,  0, -1, -2, -3],
         [ 5,  4,  3,  2,  1,  0, -1, -2],
         [ 6,  5,  4,  3,  2,  1,  0, -1],
         [ 7,  6,  5,  4,  3,  2,  1,  0]]])

In [346]:
build_relative_position(8, 8)

tensor([[[ 0, -1, -2, -3, -4, -5, -6, -7],
         [ 1,  0, -1, -2, -3, -4, -5, -6],
         [ 2,  1,  0, -1, -2, -3, -4, -5],
         [ 3,  2,  1,  0, -1, -2, -3, -4],
         [ 4,  3,  2,  1,  0, -1, -2, -3],
         [ 5,  4,  3,  2,  1,  0, -1, -2],
         [ 6,  5,  4,  3,  2,  1,  0, -1],
         [ 7,  6,  5,  4,  3,  2,  1,  0]]])

In [347]:
_.shape

torch.Size([1, 8, 8])

In [348]:
input_ids.shape

torch.Size([2, 8])

In [349]:
class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        layer = BertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
        self.relative_attention = getattr(config, 'relative_attention', False)
        self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
        self.rel_embeddings = nn.Embedding(self.max_relative_positions*2, config.hidden_size)

    def get_attention_mask(self, attention_mask):
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        att_mask = extended_attention_mask.byte()
        attention_mask = att_mask*att_mask.squeeze(-2).unsqueeze(-1)
        return attention_mask

    def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
        q = hidden_states.size(-2)
        relative_pos = build_relative_position(q, hidden_states.size(-2))
        return relative_pos

    def forward(self, hidden_states, attention_mask, 
                output_all_encoded_layers=True, return_att=True,
                query_states = None, relative_pos=None):
        
        attention_mask = self.get_attention_mask(attention_mask)
        relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

        all_encoder_layers = []
        att_matrixs = []
        next_kv = hidden_states
        rel_embeddings = self.rel_embeddings.weight
        
        for i, layer_module in enumerate(self.layer):
            output_states = layer_module(next_kv, 
                                         attention_mask, 
                                         return_att, 
                                         query_states = query_states, 
                                         relative_pos=relative_pos, 
                                         rel_embeddings=rel_embeddings)
            output_states, att_m = output_states
            next_kv = output_states
            all_encoder_layers.append(output_states)
            att_matrixs.append(att_m)
        return (all_encoder_layers, att_matrixs)

In [350]:
encoder = BertEncoder(config)

In [351]:
encoded_layers = encoder(embedding_output, attention_mask, output_all_encoded_layers=True, return_att=True)

In [352]:
all_encoder_layers, att_matrixs = encoded_layers

In [353]:
len(all_encoder_layers)

12

In [354]:
all_encoder_layers[-1].shape

torch.Size([2, 8, 768])

In [355]:
len(att_matrixs)

12

In [356]:
att_matrixs[-1].shape

torch.Size([2, 12, 8, 8])

In [357]:
########

In [358]:
attn = DisentangledSelfAttention(config)

In [359]:
attn

DisentangledSelfAttention(
  (in_proj): Linear(in_features=768, out_features=2304, bias=False)
  (pos_dropout): StableDropout()
  (pos_proj): Linear(in_features=768, out_features=768, bias=False)
  (pos_q_proj): Linear(in_features=768, out_features=768, bias=True)
  (dropout): StableDropout()
)

In [360]:
embedding_output.shape

torch.Size([2, 8, 768])

In [361]:
get_attention_mask(attention_mask).shape

torch.Size([2, 1, 8, 8])

In [362]:
rel_embeddings = nn.Embedding(config.max_relative_positions*2, config.hidden_size)

In [363]:
(context_layer, attention_probs) = attn(embedding_output, 
     get_attention_mask(attention_mask), 
     return_att=True, query_states=None, 
     relative_pos=build_relative_position(embedding_output.size(-2), embedding_output.size(-2)), # seq_len
     rel_embeddings=rel_embeddings.weight)

In [364]:
context_layer.shape

torch.Size([2, 8, 768])

In [365]:
attention_probs.shape

torch.Size([2, 12, 8, 8])

In [366]:
attention_mask.shape

torch.Size([2, 8])

In [367]:
get_attention_mask(attention_mask).shape

torch.Size([2, 1, 8, 8])

In [368]:
build_relative_position(embedding_output.size(-2), embedding_output.size(-2)).dim()

3

In [369]:
build_relative_position(embedding_output.size(-2), embedding_output.size(-2)).unsqueeze(1).shape

torch.Size([1, 1, 8, 8])

In [370]:
pos_dropout = StableDropout(config.hidden_dropout_prob)
pos_dropout

StableDropout()

In [371]:
pos_dropout(rel_embeddings.weight).shape

torch.Size([1024, 768])

In [372]:
def transpose_for_scores(x):
    new_x_shape = x.size()[:-1] + (config.num_attention_heads, -1)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)

In [377]:
x = torch.rand(torch.Size([2, 8, 2304]))

In [378]:
transpose_for_scores(x).shape

torch.Size([2, 12, 8, 192])

In [379]:
q,k,v = transpose_for_scores(x).chunk(3, dim=-1)

In [380]:
q.shape, k.shape, v.shape

(torch.Size([2, 12, 8, 64]),
 torch.Size([2, 12, 8, 64]),
 torch.Size([2, 12, 8, 64]))

In [241]:
q_bias = torch.nn.Parameter(torch.zeros((768), dtype=torch.float))
transpose_for_scores(q_bias.unsqueeze(0).unsqueeze(0)).shape

torch.Size([1, 12, 1, 64])

In [283]:
config.max_relative_positions

512

In [333]:
embedding_output.shape

torch.Size([2, 3, 768])

In [387]:
class DisentangledSelfAttention(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 768/12 = 64
        self.all_head_size = self.num_attention_heads * self.attention_head_size # 12 * 64 = 768
        self.in_proj = torch.nn.Linear(config.hidden_size, self.all_head_size*3, bias=False) # 768, 768*3
        self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
        self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
        self.pos_att_type = ["p2c", "c2p"]

        self.max_relative_positions = config.max_relative_positions
        self.pos_dropout = StableDropout(config.hidden_dropout_prob)

        self.pos_proj = torch.nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.pos_q_proj = torch.nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = StableDropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, 
                return_att=False, query_states=None, 
                relative_pos=None, rel_embeddings=None):
        """  Call the module
        Args:
            hidden_states (:obj:`torch.FloatTensor`):
                Input states to the module usally the output from previous layer, 
                it will be the Q,K and V in `Attention(Q,K,V)`

            attention_mask (:obj:`torch.ByteTensor`):
                An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, 
                `N` is the maxium sequence length in which element [i,j] = `1` means 
                the `i` th token in the input can attend to the `j` th token.

            return_att (:obj:`bool`, optional):
                Whether return the attention maxitrix.

            query_states (:obj:`torch.FloatTensor`, optional):
                The `Q` state in `Attention(Q,K,V)`.

            relative_pos (:obj:`torch.LongTensor`):
                The relative position encoding between the tokens in the sequence. 
                It's of shape [`B`, `N`, `N`] with values ranging in [`-max_relative_positions`, 
                `max_relative_positions`].

            rel_embeddings (:obj:`torch.FloatTensor`):
                The embedding of relative distances. 
                It's a tensor of shape [:math:`2 \\times \\text{max_relative_positions}`, `hidden_size`].
        """
        # (batch_size, seq_len, hidden_size * 3)
        qp = self.in_proj(hidden_states)
        # (batch_size, num_attention_heads, seq_len, seq_len * attention_head_size).chunk(3, dim=-1) =>
        # (batch_size, num_attention_heads, seq_len, attention_head_size)
        query_layer,key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
        
        query_layer += self.transpose_for_scores(self.q_bias.unsqueeze(0).unsqueeze(0))
        value_layer += self.transpose_for_scores(self.v_bias.unsqueeze(0).unsqueeze(0))

        rel_att = None
        # Take the dot product between "query" and "key" to get the raw attention scores.
        scale_factor = 1
        if 'c2p' in self.pos_att_type:
            scale_factor += 1
        if 'p2c' in self.pos_att_type:
            scale_factor += 1
        if 'p2p' in self.pos_att_type:
            scale_factor += 1
        scale = math.sqrt(query_layer.size(-1)*scale_factor)
        query_layer = query_layer/scale
        # (batch_size, num_attention_heads, query_size, key_size)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        
        # 本文定义的额外计算 Attention 分数
        rel_embeddings = self.pos_dropout(rel_embeddings)
        # (batch_size, num_attention_heads, query_size, key_size)
        rel_att = self.disentangled_att_bias(query_layer, key_layer, 
                                             relative_pos, rel_embeddings, 
                                             scale_factor)

        attention_scores = attention_scores + rel_att
        
        attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
        context_layer = context_layer.view(*new_context_layer_shape)
        
        return (context_layer, attention_probs)

    def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
        # query_layer: (batch_size, num_attention_heads, query_seq_len, attention_head_size)
        # key_layer: like query_layer
        # relative_pos: (1, query_size, key_size)
        # rel_embeddings: (max_relative_positions*2, hidden_size)
        # scale_factor: 3
        
        # relative_pos.dim()==3:
        # (1, query_size, key_size) => (1, 1, query_size, key_size)
        relative_pos = relative_pos.unsqueeze(1)

        # int number
        att_span = min(max(query_layer.size(-2), key_layer.size(-2)), self.max_relative_positions)
        relative_pos = relative_pos.long().to(query_layer.device)
        # (1, att_span*2, hidden_size)
        # 层间共享的 P
        rel_embeddings = rel_embeddings[self.max_relative_positions - att_span:
                                        self.max_relative_positions + att_span, :].unsqueeze(0)
        
        # 位置 Kr
        if 'c2p' in self.pos_att_type:
            # without bias
            # (1, att_span*2, hidden_size)
            pos_key_layer = self.pos_proj(rel_embeddings)
            # (1, num_attention_heads, att_span*2, attention_head_size)
            pos_key_layer = self.transpose_for_scores(pos_key_layer)
        # 位置 Qr
        if 'p2c' in self.pos_att_type:
            # with bias
            # (1, att_span*2, hidden_size)
            pos_query_layer = self.pos_q_proj(rel_embeddings)
            # (1, num_attention_heads, att_span*2, attention_head_size)
            pos_query_layer = self.transpose_for_scores(pos_query_layer)

        score = 0
        # content->position
        if 'c2p' in self.pos_att_type:
            # (batch_size, num_attention_heads, query_size, att_span * 2)
            c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
            # (1, 1, query_size, key_size)  # i-j+k, [0, 2*k)
            c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span*2-1)
            # (batch_size, num_attention_heads, query_size, key_size)
            # expand(batch_size, num_attention_heads, query_size, key_size)
            c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos.expand(
                [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]))
            score += c2p_att

        # position->content
        if 'p2c' in self.pos_att_type:
            pos_query_layer /= math.sqrt(pos_query_layer.size(-1)*scale_factor)
            # j-i+k, [0, 2*k), δ(j,i)
            p2c_pos = torch.clamp(-relative_pos + att_span, 0, att_span*2-1)
            # (batch_size, num_attention_heads, key_size, att_span * 2)
            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
            # expand(batch_size, num_attention_heads, key_size, query_size)
            # transpose to => (batch_size, num_attention_heads, query_size, key_size)
            p2c_att = torch.gather(p2c_att, dim=-1, index=p2c_pos.expand(
                [key_layer.size(0), key_layer.size(1), key_layer.size(2), relative_pos.size(-2)])
                                  ).transpose(-1, -2)
            # expand 里面稍微改了一下，以前是这样的：
            # [query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]
            score += p2c_att
        return score

In [389]:
attn = DisentangledSelfAttention(config)
(context_layer, attention_probs) = attn(embedding_output, 
     get_attention_mask(attention_mask), 
     return_att=True, query_states=None, 
     relative_pos=build_relative_position(embedding_output.size(-2), embedding_output.size(-2)), # seq_len
     rel_embeddings=rel_embeddings.weight)

In [394]:
context_layer.shape, attention_probs.shape

(torch.Size([2, 8, 768]), torch.Size([2, 12, 8, 8]))

In [336]:
build_relative_position(embedding_output.size(-2), embedding_output.size(-2)).unsqueeze(1).shape

torch.Size([1, 1, 3, 3])

In [315]:
x = build_relative_position(embedding_output.size(-2), embedding_output.size(-2)).unsqueeze(1)
x

tensor([[[[ 0, -1, -2],
          [ 1,  0, -1],
          [ 2,  1,  0]]]])

In [316]:
x.expand([2, 12, 3, 3]).shape

torch.Size([2, 12, 3, 3])

In [319]:
torch.gather?

In [322]:
t = torch.tensor([[1,2],[3,4]])
t

tensor([[1, 2],
        [3, 4]])

In [328]:
torch.gather(t, -1, torch.tensor([[0,0],[1,0]]))

tensor([[1, 1],
        [4, 3]])