In [1]:
from dataset.build_dataset import build_dataset
from readingcomprehension.models.luke import LukeForReadingComprehensionWithLoss
import mindspore.dataset as ds
import os
import numpy as np
from mindspore.mindrecord import FileWriter
import json

# Squad 数据集

In [2]:
FEATURES_FILE = "./data/json_features.npy"
features = np.load(FEATURES_FILE)

In [3]:
list_dict = []
for item in features:
    dict_temp = json.loads(item)
    list_dict.append(dict_temp)

In [4]:
SQUAD_MINDRECORD_FILE = "./data/squad_features.mindrecord"

if os.path.exists(SQUAD_MINDRECORD_FILE):
    os.remove(SQUAD_MINDRECORD_FILE)
    os.remove(SQUAD_MINDRECORD_FILE + ".db")

writer = FileWriter(file_name=SQUAD_MINDRECORD_FILE, shard_num=1)

data_schema = {
    "word_ids": {"type": "int32", "shape": [-1]},
    "word_segment_ids": {"type": "int32", "shape": [-1]},
    "word_attention_mask": {"type": "int32", "shape": [-1]},
    "entity_ids": {"type": "int32", "shape": [-1]},
    "entity_position_ids": {"type": "int32", "shape": [-1]},
    "entity_segment_ids": {"type": "int32", "shape": [-1]},
    "entity_attention_mask": {"type": "int32", "shape": [-1]},
    "start_positions": {"type": "int32", "shape": [-1]},
    "end_positions": {"type": "int32", "shape": [-1]}
}
writer.add_schema(data_schema, "it is a preprocessed squad dataset")

data = []
i = 0
for item in list_dict:
    i += 1
    sample = {
        "word_ids": np.array(item["word_ids"], dtype=np.int32),
        "word_segment_ids": np.array(item["word_segment_ids"], dtype=np.int32),
        "word_attention_mask": np.array(item["word_attention_mask"], dtype=np.int32),
        "entity_ids": np.array(item["entity_ids"], dtype=np.int32),
        "entity_position_ids": np.array(item["entity_position_ids"], dtype=np.int32),
        "entity_segment_ids": np.array(item["entity_segment_ids"], dtype=np.int32),
        "entity_attention_mask": np.array(item["entity_attention_mask"], dtype=np.int32),
        "start_positions": np.array(item["start_positions"], dtype=np.int32),
        "end_positions": np.array(item["end_positions"], dtype=np.int32),
    }

    data.append(sample)
    #print(sample)
    if i % 10 == 0:
        writer.write_raw_data(data)
        data = []

if data:
    writer.write_raw_data(data)

writer.commit()

MSRStatus.SUCCESS

In [5]:
data_set = ds.MindDataset(dataset_file=SQUAD_MINDRECORD_FILE)
count = 0
for item in data_set.create_dict_iterator():
    #print(item)
    count += 1
print("Got {} samples".format(count))

Got 269 samples


# model

In [6]:
from readingcomprehension.models.luke import LukeForReadingComprehension
import mindspore.common.dtype as mstype
from model.bert_model import BertConfig
from mindspore import context
from model.luke import LukeModel, EntityAwareEncoder
import numpy as np
from mindspore import Tensor, context
from mindspore import dtype as mstype
import mindspore.ops as ops
import mindspore.nn as nn
from model.bert_model import BertOutput
from mindspore.common.initializer import TruncatedNormal
import math
from mindspore.ops import composite as C
import mindspore
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

In [10]:
data_sample = next(data_set.create_dict_iterator())
data_sample

{'end_positions': Tensor(shape=[1], dtype=Int32, value= [105]),
 'entity_attention_mask': Tensor(shape=[2], dtype=Int32, value= [0, 0]),
 'entity_ids': Tensor(shape=[2], dtype=Int32, value= [0, 0]),
 'entity_position_ids': Tensor(shape=[60], dtype=Int32, value= [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
  -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
  -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]),
 'entity_segment_ids': Tensor(shape=[2], dtype=Int32, value= [0, 0]),
 'start_positions': Tensor(shape=[1], dtype=Int32, value= [104]),
 'word_attention_mask': Tensor(shape=[165], dtype=Int32, value= [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

# RobertaEmbeddings

In [11]:
class RobertaEmbeddings(nn.Cell):
    def __init__(self, config):
        super(RobertaEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size,
                                            padding_idx=config.pad_token_id
                                            )
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm([config.hidden_size],
                                      epsilon=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        # self.register_buffer("position_ids", nn.Range(config.max_position_embeddings).expand((1, -1)))
        # self.register_buffer("token_type_ids",
        #                      ops.Zeros(self.position_ids.size(), dtype=mstype.int64),  # dtype used to torch.long
        #                      persistent=False)
        # End copy
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size,
                                                padding_idx=self.padding_idx)

    def construct(self,
                  input_ids=None,
                  token_type_ids=None,
                  position_ids=None,
                  inputs_embeds=None,
                  past_key_values_length=0):
        if position_ids is None:
            if input_ids is not None:
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
            else:
                position_ids = create_position_ids_from_input_ids(inputs_embeds)
        #if input_ids is not None:
        input_shape = input_ids.shape
        seq_length = input_shape[1]
        if token_type_ids is None:
            token_type_ids = ops.Zeros(input_shape, dtype=mstype.int64)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        position_embeddings = self.position_embeddings(position_ids)
        embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.
    Args:
       x: torch.Tensor x:
    Returns: torch.Tensor
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    pad_id = np.array(padding_idx)
    mask = Tensor(1 * np.array(input_ids.asnumpy() != pad_id))
    #mask = input_ids.ne(padding_idx).int()  # 可能有问题
    cumsum = ops.CumSum()
    incremental_indices = (cumsum(mask, 1) + past_key_values_length) * mask
    return incremental_indices + padding_idx


In [12]:
op_stack = ops.Stack()
word_ids = op_stack([data_sample["word_ids"], data_sample["word_ids"]])
word_segment_ids = op_stack([data_sample["word_segment_ids"], data_sample["word_segment_ids"]])
embeddings = RobertaEmbeddings(luke_net_cfg)
word_embeddings = embeddings.construct(word_ids, word_segment_ids)
word_embeddings

Tensor(shape=[2, 165, 768], dtype=Float32, value=
[[[-4.41752732e-001, -4.44808930e-001, 1.44897386e-001 ... 1.53909969e+000, 7.29057074e-001, -1.06630422e-001],
  [5.97358465e-001, 7.25465938e-002, 9.16226804e-001 ... -3.36068273e-001, -1.52510810e+000, -3.72017056e-001],
  [-2.41841629e-001, 5.47284521e-002, 1.64154339e+000 ... 2.51613528e-001, -8.57514024e-001, -1.07438004e+000],
  ...
  [-1.61258146e-001, -3.13901573e-001, 3.23314726e-001 ... 1.81788588e+000, -4.56612408e-001, 2.11932003e-001],
  [3.16331476e-001, -4.44033027e-001, -1.08146501e+000 ... 2.19673419e+000, -1.20870686e+000, 1.40268490e-001],
  [1.21113503e+000, -1.07970285e+000, -8.74109268e-001 ... 8.76560450e-001, -3.61373484e-001, 1.10155976e+000]],
 [[-4.41752732e-001, -4.44808930e-001, 1.44897386e-001 ... 1.53909969e+000, 7.29057074e-001, -1.06630422e-001],
  [5.97358465e-001, 7.25465938e-002, 9.16226804e-001 ... -3.36068273e-001, -1.52510810e+000, -3.72017056e-001],
  [-2.41841629e-001, 5.47284521e-002, 1.6415433

# EntityEmbeddings

In [13]:
class EntityEmbeddings(nn.Cell):
    """entity embeddings for luke model"""

    def __init__(self, config):
        super(EntityEmbeddings, self).__init__()
        self.config = config
        #config.entity_vocab_size = 20
        #config.entity_emb_size = config.hidden_size
        #config.layer_norm_eps = 1e-6

        self.entity_embeddings = nn.Embedding(config.entity_vocab_size, config.entity_emb_size, padding_idx=0)
        
        if config.entity_emb_size != config.hidden_size:
            self.entity_embedding_dense = nn.Dense(config.entity_emb_size, config.hidden_size, has_bias=False)
            
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        
        # TODO：[config.hidden_size] 和 torch有区别
        self.layer_norm = nn.LayerNorm([config.hidden_size], epsilon=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.unsqueezee = ops.ExpandDims()

    def construct(self, entity_ids, position_ids, token_type_ids=None):
        """EntityEmbeddings for luke"""
        if token_type_ids is None:
            token_type_ids = ops.zeros_like(entity_ids)

        entity_embeddings = self.entity_embeddings(entity_ids)
        if self.config.entity_emb_size != self.config.hidden_size:
            entity_embeddings = self.entity_embedding_dense(entity_embeddings)
        entity_position_ids_int = clamp(position_ids)
        entity_position_ids_int = Tensor(entity_position_ids_int.asnumpy().astype(np.int32))
        position_embeddings = self.position_embeddings(entity_position_ids_int)
        #position_embeddings = self.position_embeddings(position_ids)
        position_embedding_mask = 1*self.unsqueezee((position_ids != -1), -1)
        position_embeddings = position_embeddings * position_embedding_mask
        position_embeddings = ops.reduce_sum(position_embeddings, -2)
        position_embeddings = position_embeddings / clamp(ops.reduce_sum(position_embedding_mask, -2), minimum=1e-7)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = entity_embeddings + position_embeddings + token_type_embeddings
        #embeddings = self.layer_norm(embeddings)
        #embeddings = self.dropout(embeddings)
        return embeddings


def clamp(x, minimum=0.0):
    mask = x > minimum
    x = x * mask + minimum
    return x

In [14]:
net_EntityEmbeddings = EntityEmbeddings(luke_net_cfg)
entity_ids = op_stack([data_sample["entity_ids"],data_sample["entity_ids"]])
entity_position_ids = op_stack([data_sample["entity_position_ids"],data_sample["entity_position_ids"]])
entity_segment_ids = op_stack([data_sample["entity_segment_ids"],data_sample["entity_segment_ids"]])
eg_EntityEmbeddings = net_EntityEmbeddings.construct(entity_ids, entity_position_ids, entity_segment_ids)
eg_EntityEmbeddings

Tensor(shape=[2, 2, 768], dtype=Float32, value=
[[[-7.71686435e-003, -9.18955076e-003, -2.75919517e-003 ... -1.93111412e-002, -1.82587970e-002, 7.03898724e-003],
  [-7.71686435e-003, -9.18955076e-003, -2.75919517e-003 ... -1.93111412e-002, -1.82587970e-002, 7.03898724e-003]],
 [[-7.71686435e-003, -9.18955076e-003, -2.75919517e-003 ... -1.93111412e-002, -1.82587970e-002, 7.03898724e-003],
  [-7.71686435e-003, -9.18955076e-003, -2.75919517e-003 ... -1.93111412e-002, -1.82587970e-002, 7.03898724e-003]]])

# attention_mask

In [15]:
def _compute_extended_attention_mask(word_attention_mask, entity_attention_mask):
    attention_mask = word_attention_mask
    if entity_attention_mask is not None:
        op_Concat = ops.Concat(axis = 1)
        attention_mask = op_Concat((attention_mask, entity_attention_mask))
    unsqueezee = ops.ExpandDims()
    extended_attention_mask = unsqueezee(unsqueezee(attention_mask, 1), 2)
    extended_attention_mask = extended_attention_mask.astype(mstype.float32)
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    return extended_attention_mask

In [16]:
word_attention_mask = op_stack([data_sample["word_attention_mask"],data_sample["word_attention_mask"]])
entity_attention_mask = op_stack([data_sample["entity_attention_mask"],data_sample["entity_attention_mask"]])
attention_mask = _compute_extended_attention_mask(word_attention_mask, entity_attention_mask)
attention_mask

Tensor(shape=[2, 1, 1, 167], dtype=Float32, value=
[[[[-0.00000000e+000, -0.00000000e+000, -0.00000000e+000 ... -0.00000000e+000, -1.00000000e+004, -1.00000000e+004]]],
 [[[-0.00000000e+000, -0.00000000e+000, -0.00000000e+000 ... -0.00000000e+000, -1.00000000e+004, -1.00000000e+004]]]])

# self-attention

In [17]:
class EntityAwareSelfAttention(nn.Cell):
    """EntityAwareSelfAttention"""

    def __init__(self, config):
        super(EntityAwareSelfAttention, self).__init__()

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Dense(config.hidden_size, self.all_head_size)
        self.key = nn.Dense(config.hidden_size, self.all_head_size)
        self.value = nn.Dense(config.hidden_size, self.all_head_size)

        self.w2e_query = nn.Dense(config.hidden_size, self.all_head_size)
        self.e2w_query = nn.Dense(config.hidden_size, self.all_head_size)
        self.e2e_query = nn.Dense(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.concat = ops.Concat(1)
        self.concat2 = ops.Concat(2)
        self.concat3 = ops.Concat(3)
        self.sotfmax = ops.Softmax()
        self.shape = ops.Shape()
        self.reshape = ops.Reshape()
        self.transpose = ops.Transpose()
        self.softmax = ops.Softmax(axis = -1)
        
    def transpose_for_scores(self, x):
        new_x_shape = ops.shape(x)[:-1] + (self.num_attention_heads, self.attention_head_size)
        out = self.reshape(x, new_x_shape)
        out = self.transpose(out, (0, 2, 1, 3))
        return out

    def construct(self, word_hidden_states, entity_hidden_states, attention_mask):
        """EntityAwareSelfAttention construct"""
        word_size = self.shape(word_hidden_states)[1]
        w2w_query_layer = self.transpose_for_scores(self.query(word_hidden_states))
        w2e_query_layer = self.transpose_for_scores(self.w2e_query(word_hidden_states))
        e2w_query_layer = self.transpose_for_scores(self.e2w_query(entity_hidden_states))
        e2e_query_layer = self.transpose_for_scores(self.e2e_query(entity_hidden_states))

        key_layer = self.transpose_for_scores(self.key(self.concat([word_hidden_states, entity_hidden_states])))

        w2w_key_layer = key_layer[:, :, :word_size, :]
        e2w_key_layer = key_layer[:, :, :word_size, :]
        w2e_key_layer = key_layer[:, :, word_size:, :]
        e2e_key_layer = key_layer[:, :, word_size:, :]

        w2w_attention_scores = ops.matmul(w2w_query_layer, ops.transpose(w2w_key_layer, (0,1, 3, 2)))
        w2e_attention_scores = ops.matmul(w2e_query_layer, ops.transpose(w2e_key_layer, (0,1, 3, 2)))
        e2w_attention_scores = ops.matmul(e2w_query_layer, ops.transpose(e2w_key_layer, (0,1, 3, 2)))
        e2e_attention_scores = ops.matmul(e2e_query_layer, ops.transpose(e2e_key_layer, (0,1, 3, 2)))

        word_attention_scores = self.concat3([w2w_attention_scores, w2e_attention_scores])
        entity_attention_scores = self.concat3([e2w_attention_scores, e2e_attention_scores])
        attention_scores = self.concat2([word_attention_scores, entity_attention_scores])

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask

        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)

        value_layer = self.transpose_for_scores(
            self.value(self.concat([word_hidden_states, entity_hidden_states]))
        )
        context_layer = ops.matmul(attention_probs, value_layer)

        context_layer = ops.transpose(context_layer, (0, 2, 1, 3))
        new_context_layer_shape = ops.shape(context_layer)[:-2] + (self.all_head_size,)
        context_layer = self.reshape(context_layer, new_context_layer_shape)

        return context_layer[:, :word_size, :], context_layer[:, word_size:, :]

In [18]:
config = luke_net_cfg
self_attention = EntityAwareSelfAttention(config)
self_attention.construct(word_embeddings, eg_EntityEmbeddings, attention_mask)

(Tensor(shape=[2, 165, 768], dtype=Float32, value=
 [[[2.53133923e-001, -1.60536572e-001, 8.10929239e-002 ... -1.18726261e-001, 9.20061991e-002, 1.52696553e-003],
   [2.52583355e-001, -1.61234885e-001, 7.85283968e-002 ... -1.19530343e-001, 8.81919041e-002, 3.52382869e-003],
   [2.51733571e-001, -1.60819903e-001, 8.06009918e-002 ... -1.18606493e-001, 8.90120715e-002, 3.00622871e-003],
   ...
   [2.52933174e-001, -1.61249921e-001, 7.85813108e-002 ... -1.18925765e-001, 9.12281349e-002, 2.40319758e-003],
   [2.52820879e-001, -1.61966741e-001, 7.76201263e-002 ... -1.16530523e-001, 8.90737846e-002, 5.29047986e-003],
   [2.54186511e-001, -1.60900801e-001, 7.78411701e-002 ... -1.19053677e-001, 9.07880366e-002, 3.46294721e-003]],
  [[2.53133923e-001, -1.60536572e-001, 8.10929239e-002 ... -1.18726261e-001, 9.20061991e-002, 1.52696553e-003],
   [2.52583355e-001, -1.61234885e-001, 7.85283968e-002 ... -1.19530343e-001, 8.81919041e-002, 3.52382869e-003],
   [2.51733571e-001, -1.60819903e-001, 8.0600

# EntityAwareAttention

In [19]:
class BertOutput(nn.Cell):
    """
    Apply a linear computation to hidden status and a residual computation to input.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        dropout_prob (float): The dropout probability. Default: 0.1.
        compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 initializer_range=0.02,
                 dropout_prob=0.1,
                 compute_type=mstype.float32):
        super(BertOutput, self).__init__()
        self.dense = nn.Dense(in_channels, out_channels,
                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.dropout_prob = dropout_prob
        self.add = P.Add()
        self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
        self.cast = P.Cast()

    def construct(self, hidden_status, input_tensor):
        output = self.dense(hidden_status)
        output = self.dropout(output)
        output = self.add(input_tensor, output)
        output = self.layernorm(output)
        return output

In [20]:

class BertSelfOutput(nn.Cell):
    def __init__(self, config, compute_type=mstype.float32):
        super().__init__()
        self.dense = nn.Dense(config.hidden_size, config.hidden_size,
                             weight_init=TruncatedNormal(config.initializer_range)).to_float(compute_type)
        self.LayerNorm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps).to_float(compute_type)
        self.dropout = nn.Dropout(1 - config.hidden_dropout_prob)
        self.add = P.Add()

    def construct(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.add(input_tensor, hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

In [21]:
print(config.hidden_dropout_prob)

0.1


In [22]:
class EntityAwareAttention(nn.Cell):
    """EntityAwareAttention"""

    def __init__(self, config):
        super(EntityAwareAttention, self).__init__()
        self.self_attention = EntityAwareSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.concat = ops.Concat(1)

    def construct(self, word_hidden_states, entity_hidden_states, attention_mask):
        word_self_output, entity_self_output = self.self_attention.construct(word_hidden_states, entity_hidden_states, attention_mask)
        hidden_states = self.concat([word_hidden_states, entity_hidden_states])
        self_output = self.concat([word_self_output, entity_self_output])
        out = self.output.construct(hidden_states, self_output)
        out1 = out[:, : ops.shape(word_hidden_states)[1], :]
        out2 = out[:, ops.shape(word_hidden_states)[1]:, :]
        #return output[:, : ops.shape(word_hidden_states)[1], :], output[:, ops.shape(word_hidden_states)[1]:, :]
        return out1, out2

In [23]:
AwareAttention = EntityAwareAttention(config)
out1, out2 = AwareAttention.construct(word_embeddings, eg_EntityEmbeddings, attention_mask)
out2.shape

(2, 2, 768)

# EntityAwareLayer

In [24]:
class EntityAwareLayer(nn.Cell):
    """EntityAwareLayer"""

    def __init__(self, config):
        super(EntityAwareLayer, self).__init__()

        self.attention = EntityAwareAttention(config)
        self.intermediate = nn.Dense(config.hidden_size, 
                                     config.intermediate_size,
                                     activation=config.hidden_act,
                                     weight_init=TruncatedNormal(config.initializer_range)).to_float(mstype.float32)
        self.output = BertOutput(config.intermediate_size, config.hidden_size)
        self.concat = ops.Concat(1)

    def construct(self, word_hidden_states, entity_hidden_states, attention_mask):
        word_attention_output, entity_attention_output = self.attention.construct(
            word_hidden_states, entity_hidden_states, attention_mask
        )
        attention_output = self.concat([word_attention_output, entity_attention_output])
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output.construct(intermediate_output, attention_output)

        return layer_output[:, : ops.shape(word_hidden_states)[1], :], \
               layer_output[:, ops.shape(word_hidden_states)[1]:, :]

In [25]:
EntityAwareLayer__ = EntityAwareLayer(config)
out1, out2 = EntityAwareLayer__.construct(word_embeddings, eg_EntityEmbeddings, attention_mask)
out2.shape

(2, 2, 768)

# EntityAwareEncoder

In [26]:
class EntityAwareEncoder(nn.Cell):
    """EntityAwareEncoder"""

    def __init__(self, config):
        super(EntityAwareEncoder, self).__init__()
        #self.layer = EntityAwareLayer(config)
        self.layer = nn.CellList([EntityAwareLayer(config) for _ in range(config.num_hidden_layers)])

    def construct(self, word_hidden_states, entity_hidden_states, attention_mask):
        for layer_module in self.layer:
            word_hidden_states, entity_hidden_states = layer_module.construct(
                word_hidden_states, entity_hidden_states, attention_mask
            )
        return word_hidden_states, entity_hidden_states

In [27]:
EntityAwareEncoder__ = EntityAwareEncoder(config)
EntityAwareEncoder__.construct(word_embeddings, eg_EntityEmbeddings, attention_mask)

(Tensor(shape=[2, 165, 768], dtype=Float32, value=
 [[[4.15360510e-001, -3.95976394e-001, -6.53767526e-001 ... -9.21365440e-001, 6.43256068e-001, -4.25003141e-001],
   [-9.43257008e-003, -3.21244985e-001, -5.20340323e-001 ... -1.01637685e+000, 1.72726680e-002, -1.83341965e-001],
   [-1.15032345e-001, -5.66492736e-001, -8.94666910e-001 ... -9.95096624e-001, 3.00080210e-001, -3.62655103e-001],
   ...
   [2.13470399e-001, -6.99665666e-001, -1.01645136e+000 ... -1.11550820e+000, 4.22653645e-001, -1.28931506e-002],
   [-2.62470096e-001, -1.01830363e+000, -6.96891069e-001 ... -1.49254012e+000, 3.85883212e-001, -8.99146870e-002],
   [1.96873173e-001, -8.99896383e-001, -7.28093684e-001 ... -1.01188862e+000, 6.34365797e-001, -9.88355130e-002]],
  [[4.15360510e-001, -3.95976394e-001, -6.53767526e-001 ... -9.21365440e-001, 6.43256068e-001, -4.25003141e-001],
   [-9.43257008e-003, -3.21244985e-001, -5.20340323e-001 ... -1.01637685e+000, 1.72726680e-002, -1.83341965e-001],
   [-1.15032345e-001, -5.

# LukeEntityAwareAttentionModel

In [28]:
class LukeEntityAwareAttentionModel(LukeModel):
    """LukeEntityAwareAttentionModel"""

    def __init__(self, config):
        super(LukeEntityAwareAttentionModel, self).__init__(config)
        self.config = config
        self.encoder = EntityAwareEncoder(config)

    def construct(self, word_ids, word_segment_ids, word_attention_mask, entity_ids,
                  entity_position_ids, entity_segment_ids, entity_attention_mask):
        word_embeddings = self.embeddings.construct(word_ids, word_segment_ids)
        entity_embeddings = self.entity_embeddings.construct(entity_ids, entity_position_ids, entity_segment_ids)
        attention_mask = self._compute_extended_attention_mask(word_attention_mask, entity_attention_mask)

        return self.encoder.construct(word_embeddings, entity_embeddings, attention_mask)

In [29]:
data_sample

{'end_positions': Tensor(shape=[1], dtype=Int32, value= [105]),
 'entity_attention_mask': Tensor(shape=[2], dtype=Int32, value= [0, 0]),
 'entity_ids': Tensor(shape=[2], dtype=Int32, value= [0, 0]),
 'entity_position_ids': Tensor(shape=[60], dtype=Int32, value= [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
  -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
  -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]),
 'entity_segment_ids': Tensor(shape=[2], dtype=Int32, value= [0, 0]),
 'start_positions': Tensor(shape=[1], dtype=Int32, value= [104]),
 'word_attention_mask': Tensor(shape=[165], dtype=Int32, value= [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [30]:
op_stack = ops.Stack()
word_ids = op_stack([data_sample["word_ids"], data_sample["word_ids"]])
word_segment_ids = op_stack([data_sample["word_segment_ids"], data_sample["word_segment_ids"]])
word_attention_mask = op_stack([data_sample["word_attention_mask"], data_sample["word_attention_mask"]])
entity_ids = op_stack([data_sample["entity_ids"], data_sample["entity_ids"]])
entity_position_ids = op_stack([data_sample["entity_position_ids"], data_sample["entity_position_ids"]])
entity_segment_ids = op_stack([data_sample["entity_segment_ids"], data_sample["entity_segment_ids"]])
entity_attention_mask = op_stack([data_sample["entity_attention_mask"], data_sample["entity_attention_mask"]])
start_positions = op_stack([data_sample["start_positions"], data_sample["start_positions"]])
end_positions = op_stack([data_sample["end_positions"], data_sample["end_positions"]])

In [31]:
LukeEntityAwareAttentionModel__ = LukeEntityAwareAttentionModel(config)
LukeEntityAwareAttentionModel__.construct(word_ids,
                                          word_segment_ids,
                                          word_attention_mask,
                                          entity_ids,
                                          entity_position_ids,
                                          entity_segment_ids,
                                          entity_attention_mask
                                          )

(Tensor(shape=[2, 165, 768], dtype=Float32, value=
 [[[2.49526367e-001, -7.46389508e-001, 3.81608248e-001 ... -5.47386035e-002, -1.27832508e+000, 6.00578904e-001],
   [9.78615761e-001, -5.02685308e-001, 3.67616683e-001 ... -1.78594932e-001, -9.16401267e-001, 5.65950215e-001],
   [6.86164558e-001, -5.35537004e-001, 5.53446352e-001 ... -7.36731440e-002, -1.02447784e+000, 4.10418957e-001],
   ...
   [5.61239362e-001, -6.25160217e-001, 6.52561709e-002 ... 3.27793539e-001, -1.26642430e+000, 7.57280111e-001],
   [7.46241808e-001, -1.15103817e+000, 1.91354737e-001 ... 2.21213371e-001, -8.41529369e-001, 6.64219141e-001],
   [8.04365695e-001, -7.37675309e-001, -1.86596975e-001 ... -3.32862735e-002, -1.07152200e+000, 2.90190727e-001]],
  [[2.49526367e-001, -7.46389508e-001, 3.81608248e-001 ... -5.47386035e-002, -1.27832508e+000, 6.00578904e-001],
   [9.78615761e-001, -5.02685308e-001, 3.67616683e-001 ... -1.78594932e-001, -9.16401267e-001, 5.65950215e-001],
   [6.86164558e-001, -5.35537004e-001,

# LukeForReadingComprehension

In [32]:
word_ids.shape

(2, 165)

In [33]:
ops.shape(word_ids)[1]

165

In [34]:
class LukeForReadingComprehension(LukeEntityAwareAttentionModel):
    """Luke for reading comprehension task"""

    def __init__(self, config):
        super(LukeForReadingComprehension, self).__init__(config)
        self.LukeEntityAwareAttentionModel = super(LukeForReadingComprehension, self)
        self.qa_outputs = nn.Dense(self.config.hidden_size, 2)
        self.split = ops.Split(-1, 2)
        self.squeeze = ops.Squeeze(-1)
        self.shape = ops.Shape()

    def construct(
            self,
            word_ids,
            word_segment_ids,
            word_attention_mask,
            entity_ids,
            entity_position_ids,
            entity_segment_ids,
            entity_attention_mask,
            start_positions=None,
            end_positions=None,
    ):
        """LukeForReadingComprehension construct"""
        encoder_outputs = self.LukeEntityAwareAttentionModel.construct(
            word_ids,
            word_segment_ids,
            word_attention_mask,
            entity_ids,
            entity_position_ids,
            entity_segment_ids,
            entity_attention_mask,
        )

        word_hidden_states = encoder_outputs[0][:, : ops.shape(word_ids)[1], :]
        logits = self.qa_outputs(word_hidden_states)
        start_logits, end_logits = self.split(logits)
        start_logits = self.squeeze(start_logits)
        end_logits = self.squeeze(end_logits)
        
        if start_positions is not None and end_positions is not None:
            if len(self.shape(start_positions)) > 1:
                start_positions = self.squeeze(start_positions)
            if len(self.shape(end_positions)) > 1:
                end_positions = self.squeeze(end_positions)

            ignored_index = ops.shape(start_logits)[1]
            start_positions = C.clip_by_value(start_positions, 0, ignored_index)
            end_positions = C.clip_by_value(end_positions, 0, ignored_index)

            loss_fct = nn.SoftmaxCrossEntropyWithLogits(sparse = True)
            #loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,)
        else:
            outputs = tuple()
        return outputs + (start_logits, end_logits,)

In [35]:
LukeForReadingComprehension__ = LukeForReadingComprehension(config)
LukeForReadingComprehension__.construct(word_ids,
                                        word_segment_ids,
                                        word_attention_mask,
                                        entity_ids,
                                        entity_position_ids,
                                        entity_segment_ids,
                                        entity_attention_mask,
                                        start_positions,
                                        end_positions
                                        )

(Tensor(shape=[2], dtype=Float32, value= [5.13476753e+000, 5.13476753e+000]),
 Tensor(shape=[2, 165], dtype=Float32, value=
 [[2.52272874e-001, 3.20962906e-001, 1.44353583e-001 ... 1.16958171e-001, 1.70200512e-001, 2.72789359e-001],
  [2.52272874e-001, 3.20962906e-001, 1.44353583e-001 ... 1.16958171e-001, 1.70200512e-001, 2.72789359e-001]]),
 Tensor(shape=[2, 165], dtype=Float32, value=
 [[2.52272874e-001, 3.20962906e-001, 1.44353583e-001 ... 1.16958171e-001, 1.70200512e-001, 2.72789359e-001],
  [2.52272874e-001, 3.20962906e-001, 1.44353583e-001 ... 1.16958171e-001, 1.70200512e-001, 2.72789359e-001]]))

In [36]:
LukeForReadingComprehension__

LukeForReadingComprehension<
  (encoder): EntityAwareEncoder<
    (layer): CellList<
      (0): EntityAwareLayer<
        (attention): EntityAwareAttention<
          (self_attention): EntityAwareSelfAttention<
            (query): Dense<input_channels=768, output_channels=768, has_bias=True>
            (key): Dense<input_channels=768, output_channels=768, has_bias=True>
            (value): Dense<input_channels=768, output_channels=768, has_bias=True>
            (w2e_query): Dense<input_channels=768, output_channels=768, has_bias=True>
            (e2w_query): Dense<input_channels=768, output_channels=768, has_bias=True>
            (e2e_query): Dense<input_channels=768, output_channels=768, has_bias=True>
            (dropout): Dropout<keep_prob=0.1>
            >
          (output): BertSelfOutput<
            (dense): Dense<input_channels=768, output_channels=768, has_bias=True>
            (LayerNorm): LayerNorm<normalized_shape=(768,), begin_norm_axis=-1, begin_params_axis=-1, 