In [15]:
import tensorflow as tf
import math

In [16]:
class BertConfig():

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        gradient_checkpointing=False,
        position_embedding_type="absolute",
        use_cache=True
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.gradient_checkpointing = gradient_checkpointing
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        

In [21]:
class TFBertEmbedding(tf.keras.layers.Layer):
    def __init__(self, config):
        super().__init__()
        self.vocab_size = config.vocab_size # 30522
        self.type_vocab_size = config.type_vocab_size # 2
        self.hidden_size = config.hidden_size # 768
        self.max_position_embeddings = config.max_position_embeddings # 512
        self.embeddings_sum = tf.keras.layers.Add()
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps)
        self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
        
    def build(self, input_shape):
        super().build(input_shape)
        self.weight = self.add_weight(shape=[self.vocab_size, self.hidden_size])
        self.token_type_embedding = self.add_weight(shape=[self.type_vocab_size, self.hidden_size])
        self.position_embedding = self.add_weight(shape=[self.max_position_embeddings, self.hidden_size])
        
    def call(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        training=False
    ):
        inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
        position_embeds = tf.gather(params=self.position_embedding, indices=position_ids)
        token_type_embeds = tf.gather(params=self.token_type_embedding, indices=token_type_ids)
        final_embedding = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
        final_embedding = self.LayerNorm(inputs=final_embedding)
        final_embedding = self.dropout(inputs=final_embedding, training=training)
        # shape = [batch_size, seq_len, hidden_size]
        return final_embedding 

In [29]:
class TFBertSelfAttention(tf.keras.layers.Layer):
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads # 12
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 768/12
        self.all_head_size = self.num_attention_heads * self.attention_head_size # 768
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size) # sqrt(768)
        self.query = tf.keras.layers.Dense(units=self.all_head_size)
        self.key = tf.keras.layers.Dense(units=self.all_head_size)
        self.value = tf.keras.layers.Dense(units=self.all_head_size)
        self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
        
    def transpose_for_scores(self, tensor, batch_size):
        tensor = tf.reshape(tensor=tensor, shape=[batch_size, -1, self.num_attention_heads, self.attention_head_size])
        return tf.transpose(tensor, perm=[0, 2, 1, 3])
    
    def call(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        output_attentions,
        traning=False
    ):
        batch_size = hidden_states.shape.as_list()[0]
        # mixed layer includes all the heads
        # shape = [batch_size, seq_len, all_head_size]
        mixed_query_layer = self.query(inputs=hidden_states)
        mixed_key_layer = self.key(inputs=hidden_states)
        mixed_value_layer = self.value(inputs=hidden_states)        
        # shape = [batch_size, num_attention_heads=12, seq_len=512, attention_head_size=768/12]
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
        # shape = [batch_size, num_attention_heads, seq_len, seq_len]
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)
        attention_scores = tf.add(attention_scores, attention_mask)
        attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
        attention_probs = tf.multiply(attention_probs, head_mask)
        # shape = [batch_size, num_attention_heads, seq_len, attention_head_size]
        attention_output = tf.matmul(attention_probs, value_layer)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        # shape = [batch_size, seq_len, all_head_size]
        outputs = tf.reshape(attention_output, shape=[batch_size, -1, self.all_head_size])
        return outputs
        

In [30]:
config = BertConfig()
h = TFBertEmbedding(config)([1,2,3],[0,0,0],[1,1,1])

<tf.Tensor: shape=(3, 1, 768), dtype=float32, numpy=
array([[[-1.0565512 , -1.1522021 , -1.804117  , ..., -0.21669677,
          1.4941266 , -1.4558128 ]],

       [[-1.0245496 , -1.4149065 , -1.8073409 , ..., -0.08891329,
          1.6286561 , -1.2214249 ]],

       [[-0.70252836, -1.2980103 , -2.0195045 , ..., -0.19623536,
          1.535229  , -1.154598  ]]], dtype=float32)>

In [16]:
a = tf.constant([12,3], dtype=tf.int32)
a.shape.as_list()

[2]

In [7]:
a = tf.fill(dims=(1,2,3), value=1)
b = tf.fill(dims=[1,2,3], value=1)
tf.matmul(a, b, transpose_b=True)

<tf.Tensor: shape=(1, 2, 2), dtype=int32, numpy=
array([[[3, 3],
        [3, 3]]])>