In [1]:
import torch
import torch.nn as nn
from transformers import (
    BertModel,
    BertPreTrainedModel,
    MegatronBertPreTrainedModel,
    MegatronBertModel,
    BertConfig
)

In [5]:
# Scratch check: a lookup table with 40 rows of width 128. The bare last
# expression displays the (randomly initialised) weight matrix.
emb = nn.Embedding(num_embeddings=40, embedding_dim=128)
emb.weight

Parameter containing:
tensor([[-0.7881, -0.2588,  1.5303,  ..., -1.2407,  0.5551,  1.4715],
        [-0.0267, -1.1805,  1.2889,  ...,  1.9129,  0.3833,  0.9492],
        [ 0.4281,  0.2788, -0.3647,  ...,  0.1450,  0.9270, -0.5054],
        ...,
        [-0.0603, -0.3262, -0.5943,  ..., -0.3791, -1.2080, -0.2433],
        [-1.2797,  0.3160,  1.3186,  ..., -0.0179,  0.0253, -0.1372],
        [-2.2144,  0.6131, -0.3709,  ...,  1.1943, -0.6113, -0.1236]],
       requires_grad=True)

In [7]:
# Look up two ten-row windows of the table (ids 0-9 and ids 30-39) and show
# both results as a tuple.
tuple(emb(torch.arange(first_id, first_id + 10)) for first_id in (0, 30))

(tensor([[-0.7881, -0.2588,  1.5303,  ..., -1.2407,  0.5551,  1.4715],
         [-0.0267, -1.1805,  1.2889,  ...,  1.9129,  0.3833,  0.9492],
         [ 0.4281,  0.2788, -0.3647,  ...,  0.1450,  0.9270, -0.5054],
         ...,
         [-0.3826, -0.3967, -0.1523,  ...,  1.9263, -0.7281, -0.4230],
         [-0.1222,  0.5273,  0.0361,  ...,  0.6581, -0.7048,  2.0117],
         [ 0.1944,  0.5211, -0.5999,  ..., -0.4132, -1.1221, -0.3146]],
        grad_fn=<EmbeddingBackward0>),
 tensor([[ 0.1055,  2.8484, -0.5538,  ..., -0.5746, -0.0499, -0.2440],
         [ 0.3305, -0.1416,  0.9070,  ...,  1.1257, -1.8558,  0.4499],
         [-0.9627, -0.4911,  0.7415,  ...,  1.0141,  0.4602, -0.8878],
         ...,
         [-0.0603, -0.3262, -0.5943,  ..., -0.3791, -1.2080, -0.2433],
         [-1.2797,  0.3160,  1.3186,  ..., -0.0179,  0.0253, -0.1372],
         [-2.2144,  0.6131, -0.3709,  ...,  1.1943, -0.6113, -0.1236]],
        grad_fn=<EmbeddingBackward0>))

In [2]:
class MultiNonLinearClassifier(nn.Module):
    """Two-layer feed-forward head: Linear -> activation -> Dropout -> Linear.

    Args:
        hidden_size: size of the last dimension of the input features.
        num_label: size of the last dimension of the output.
        dropout_rate: dropout probability applied after the activation.
        act_func: activation name; one of "gelu", "relu" or "tanh".
        intermediate_hidden_size: width of the hidden layer; falls back to
            ``hidden_size`` when None.
    """

    def __init__(self, hidden_size, num_label, dropout_rate, act_func="gelu", intermediate_hidden_size=None):
        super(MultiNonLinearClassifier, self).__init__()
        self.num_label = num_label
        self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size
        self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size)
        self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label)
        self.dropout = nn.Dropout(dropout_rate)
        self.act_func = act_func

    def forward(self, input_features):
        """Project ``input_features`` to ``num_label`` outputs along the last dim.

        Raises:
            ValueError: if ``act_func`` is not "gelu", "relu" or "tanh".
        """
        features_output1 = self.classifier1(input_features)
        # The activations are reached through ``nn.functional`` / ``torch``,
        # which the notebook already imports; the alias ``F`` is never
        # imported, so referring to it fails on a fresh kernel.
        if self.act_func == "gelu":
            features_output1 = nn.functional.gelu(features_output1)
        elif self.act_func == "relu":
            features_output1 = nn.functional.relu(features_output1)
        elif self.act_func == "tanh":
            features_output1 = torch.tanh(features_output1)
        else:
            raise ValueError(
                "Unsupported act_func {!r}; expected 'gelu', 'relu' or 'tanh'".format(self.act_func)
            )
        features_output1 = self.dropout(features_output1)
        features_output2 = self.classifier2(features_output1)
        return features_output2

In [3]:
class PrefixEncoder(torch.nn.Module):
    r'''
    Maps prefix position ids to per-layer key/value features.

    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)

    When ``config.prefix_projection`` is truthy the ids are embedded at
    ``hidden_size`` and passed through a Linear -> Tanh -> Linear stack;
    otherwise a single embedding table of the full output width is used.
    '''
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if not self.prefix_projection:
            # Direct lookup: one row of size 2*layers*hidden per prefix position.
            self.embedding = torch.nn.Embedding(
                config.pre_seq_len,
                config.num_hidden_layers * 2 * config.hidden_size,
            )
        else:
            # Small embedding followed by a two-layer MLP that widens it.
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            widen_in = torch.nn.Linear(config.hidden_size, config.prefix_hidden_size)
            widen_out = torch.nn.Linear(
                config.prefix_hidden_size,
                config.num_hidden_layers * 2 * config.hidden_size,
            )
            self.trans = torch.nn.Sequential(widen_in, torch.nn.Tanh(), widen_out)

    def forward(self, prefix: torch.Tensor):
        embedded = self.embedding(prefix)
        if not self.prefix_projection:
            return embedded
        return self.trans(embedded)

In [12]:
class BertPrefixQueryNER(BertPreTrainedModel):
    """BERT encoder with start / end / span-match heads and a learned prefix.

    A ``PrefixEncoder`` produces key/value tensors for ``pre_seq_len`` virtual
    positions per layer; they are handed to ``BertModel`` as
    ``past_key_values``. Three heads read the encoder output: per-token start
    and end scores, and a score for every (start, end) pair.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)

        self.start_outputs = nn.Linear(config.hidden_size, 1)
        self.end_outputs = nn.Linear(config.hidden_size, 1)
        self.span_embedding = MultiNonLinearClassifier(
            config.hidden_size * 2,
            1,
            config.mrc_dropout,
            intermediate_hidden_size=config.classifier_intermediate_hidden_size,
        )

        self.hidden_size = config.hidden_size

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads
        self.prefix_tokens = torch.arange(self.pre_seq_len)
        self.prefix_encoder = PrefixEncoder(config)
        # get_prompt() applies dropout to the generated prefix, so the module
        # has to exist; without it get_prompt() raises AttributeError.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.init_weights()

    def get_prompt(self, batch_size):
        """Build the per-layer prefix key/value tensors for ``batch_size`` rows.

        Returns:
            A tuple of ``n_layer`` tensors, each of shape
            [2, batch_size, n_head, pre_seq_len, n_embd].
        """
        prefix_tokens = (
            self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        )
        past_key_values = self.prefix_encoder(prefix_tokens)
        # [batch, pre_seq_len, n_layer * 2, n_head, n_embd]
        past_key_values = past_key_values.view(
            batch_size, self.pre_seq_len, self.n_layer * 2, self.n_head, self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        # -> [n_layer * 2, batch, n_head, pre_seq_len, n_embd], then cut into
        # chunks of 2 along dim 0 (one chunk per layer).
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """
        Args:
            input_ids: bert input tokens, tensor of shape [batch, seq_len]
            token_type_ids: 0 for query, 1 for context, tensor of shape [batch, seq_len]
            attention_mask: attention mask, tensor of shape [batch, seq_len];
                when None every input position is attended to
        Returns:
            start_logits: start/non-start logits of shape [batch, seq_len]
            end_logits: end/non-end logits of shape [batch, seq_len]
            span_logits: start-end-match logits of shape [batch, seq_len, seq_len]
        """
        batch_size = input_ids.shape[0]
        if attention_mask is None:
            # The signature allows None, but the concatenation below needs a
            # real tensor; default to "attend everywhere".
            attention_mask = torch.ones_like(input_ids)

        past_key_values = self.get_prompt(batch_size=batch_size)
        # Prefix positions are always visible. Built with the caller's dtype
        # and device so the concatenation never mixes tensor types.
        prefix_attention_mask = torch.ones(
            batch_size,
            self.pre_seq_len,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        bert_outputs = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )

        sequence_heatmap = bert_outputs[0]  # [batch, seq_len, hidden]
        seq_len = sequence_heatmap.size(1)

        # [batch, seq_len, 1] -> [batch, seq_len]
        start_logits = self.start_outputs(sequence_heatmap).squeeze(-1)
        end_logits = self.end_outputs(sequence_heatmap).squeeze(-1)

        # For every position i, pair its features with those of every
        # position j so the span head can score (i, j) as (start, end).
        # [batch, seq_len, seq_len, hidden]
        start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
        # [batch, seq_len, seq_len, hidden]
        end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
        # [batch, seq_len, seq_len, hidden*2]
        span_matrix = torch.cat([start_extend, end_extend], 3)
        # [batch, seq_len, seq_len]
        span_logits = self.span_embedding(span_matrix).squeeze(-1)

        return start_logits, end_logits, span_logits


In [13]:
class BertQueryNerConfig(BertConfig):
    """BertConfig extended with three span-classifier settings.

    Each setting is taken from ``kwargs`` when present, otherwise from the
    default listed in ``__init__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        extra_defaults = (
            ("mrc_dropout", 0.1),
            ("classifier_intermediate_hidden_size", 1024),
            ("classifier_act_func", "gelu"),
        )
        for field, fallback in extra_defaults:
            setattr(self, field, kwargs.get(field, fallback))

In [20]:
# Load the bert-base-uncased configuration and override the dropout and
# span-classifier settings for this experiment.
bert_config = BertQueryNerConfig.from_pretrained(
    "bert-base-uncased",
    mrc_dropout=0.1,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.0,
    classifier_intermediate_hidden_size=128,
    classifier_act_func="gelu",
)

# Prefix settings read by PrefixEncoder / BertPrefixQueryNER; assigned as
# plain attributes on the loaded config.
bert_config.prefix_projection = True   # use the embedding + MLP branch
bert_config.prefix_hidden_size = 128   # hidden width of that MLP
bert_config.pre_seq_len = 32           # number of prefix positions

In [21]:
# Instantiate the prefix model from the bert-base-uncased checkpoint using the
# config built above. The log below lists checkpoint weights that were not
# used (the `cls.*` entries) and notes that some model weights are newly
# initialized.
model = BertPrefixQueryNER.from_pretrained("bert-base-uncased", config=bert_config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertPrefixQueryNER: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertPrefixQueryNER from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertPrefixQueryNER from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertPrefixQueryNER were not initialized from the model checkpoint at bert-base-uncased and are newly initialized

In [22]:
# Display the module tree of the loaded model (the printed repr is long).
model

BertPrefixQueryNER(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine

In [24]:
# Parameters that belong to the BERT backbone.
bert_param = sum(weight.numel() for weight in model.bert.parameters())
print('bert param is {}'.format(bert_param))

# Everything in the model minus the backbone, i.e. the parameters that live
# outside `model.bert`.
all_param = sum(weight.numel() for weight in model.parameters())
total_param = all_param - bert_param
print('total param is {}'.format(total_param))

bert param is 109482240
total param is 2699139
