In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.keras import Model
import math
import json

In [2]:
def gelu(x):
    """Gaussian Error Linear Unit, using the tanh approximation.

    Evaluates ``x * 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    element-wise.

    Args:
        x: Tensor to activate; it is passed to ``tf.pow`` and ``tf.tanh``.

    Returns:
        ``x`` multiplied by the approximated Gaussian CDF of ``x``.
    """
    scale = np.sqrt(2 / np.pi)
    cubic_term = 0.044715 * tf.pow(x, 3)
    inner = scale * (x + cubic_term)
    approx_cdf = 0.5 * (1.0 + tf.tanh(inner))
    return x * approx_cdf

In [3]:
def reshape_matrix(input_tensor):
    """Collapse a tensor of rank >= 2 into a rank-2 matrix.

    Args:
        input_tensor: Object exposing ``shape.ndims`` and ``shape[-1]``.

    Returns:
        ``input_tensor`` itself when it is already rank 2; otherwise the result
        of ``tf.reshape(input_tensor, [-1, last_dim])`` where ``last_dim`` is
        the size of the final axis.

    Raises:
        ValueError: If the tensor rank is below 2.
    """
    rank = input_tensor.shape.ndims
    if rank < 2:
        raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                         (input_tensor.shape))

    if rank > 2:
        # Keep the last axis and fold every leading axis into one.
        last_dim = input_tensor.shape[-1]
        return tf.reshape(input_tensor, [-1, last_dim])

    return input_tensor



In [4]:
def attention_layer(seq_length=128,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    from_seq_length=None,
                    to_seq_length=None):
    """Build a Keras functional model that performs multi-head self-attention.

    The model takes a single 2D input of width
    ``num_attention_heads * size_per_head`` (the flattened token
    representations), projects it to queries, keys and values, and returns the
    attention-weighted values.

    Args:
        seq_length: Sequence length used to un-flatten the input when neither
            ``from_seq_length`` nor ``to_seq_length`` is given.
        attention_mask: Optional mask. It is expanded at axis 1 and every
            position whose mask value is 0 receives a -10000.0 penalty before
            the softmax. Presumably shaped [batch, from_seq, to_seq] -- confirm
            against the caller.
        num_attention_heads: Number of attention heads.
        size_per_head: Width of each attention head.
        query_act: Activation of the query projection.
        key_act: Activation of the key projection.
        value_act: Activation of the value projection.
        attention_probs_dropout_prob: Dropout rate applied to the attention
            probabilities.
        initializer_range: Standard deviation of the truncated-normal
            initializer used for the query/key/value projection kernels.
        do_return_2d_tensor: If True the output is reshaped to
            ``[-1, num_attention_heads * size_per_head]``; otherwise to
            ``[-1, seq_length, num_attention_heads, size_per_head]``.
        from_seq_length: Optional explicit sequence length. When given it
            takes precedence over ``seq_length``.
        to_seq_length: Optional explicit sequence length. Must equal
            ``from_seq_length`` because queries and keys come from the same
            input.

    Returns:
        A ``tf.keras.Model`` mapping the flattened input to the context tensor.

    Raises:
        ValueError: If ``from_seq_length`` and ``to_seq_length`` disagree.
    """
    # Callers such as transformer_model pass the real length only through
    # from_seq_length / to_seq_length, so those must override the default
    # seq_length instead of being ignored.
    if from_seq_length is None:
        from_seq_length = seq_length if to_seq_length is None else to_seq_length
    if to_seq_length is None:
        to_seq_length = from_seq_length
    if from_seq_length != to_seq_length:
        raise ValueError(
            "attention_layer only supports self-attention: from_seq_length (%s) "
            "must equal to_seq_length (%s)" % (from_seq_length, to_seq_length))
    seq_length = from_seq_length

    projection_width = num_attention_heads * size_per_head

    def transpose_for_scores(input_tensor):
        # [batch*seq, heads*size] -> [batch, heads, seq, size]
        output_tensor = tf.reshape(
            input_tensor, [-1, seq_length, num_attention_heads, size_per_head])
        return tf.transpose(output_tensor, [0, 2, 1, 3])

    def projection(activation):
        # A fresh initializer per layer so the three kernels are drawn
        # independently.
        return tf.keras.layers.Dense(
            projection_width,
            activation=activation,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=initializer_range))

    inputs = tf.keras.Input([projection_width])
    reshape_inputs = reshape_matrix(inputs)
    query_layer = projection(query_act)(reshape_inputs)
    key_layer = projection(key_act)(reshape_inputs)
    value_layer = projection(value_act)(reshape_inputs)

    query_layer = transpose_for_scores(query_layer)
    key_layer = transpose_for_scores(key_layer)

    # Scaled dot-product scores: [batch, heads, seq, seq].
    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    attention_scores = tf.multiply(attention_scores, 1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # Insert a heads axis so the mask broadcasts over every head.
        attention_mask = tf.expand_dims(attention_mask, axis=[1])
        # 0.0 where attention is allowed, -10000.0 where it is masked out.
        adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
        attention_scores += adder

    attention_probs = tf.keras.layers.Softmax()(attention_scores)
    attention_probs = tf.keras.layers.Dropout(attention_probs_dropout_prob)(attention_probs)

    value_layer = transpose_for_scores(value_layer)

    context_layer = tf.matmul(attention_probs, value_layer)
    # Back to [batch, seq, heads, size].
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

    if do_return_2d_tensor:
        context_layer = tf.reshape(context_layer, [-1, projection_width])
    else:
        # NOTE(review): this keeps the heads axis separate (4D output) rather
        # than merging it into the hidden width -- confirm this is intended.
        context_layer = tf.reshape(
            context_layer, [-1, seq_length, num_attention_heads, size_per_head])

    return Model(inputs, context_layer)

In [5]:
def transformer_model(seq_length=128,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      intermediate_act_fn=gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False):
    """Build a Keras functional model of stacked Transformer encoder layers.

    Each layer applies multi-head self-attention, an output projection with
    dropout and a residual layer normalization, then a two-layer feed-forward
    block with dropout and a second residual layer normalization.

    Args:
        seq_length: Length of the input sequences.
        attention_mask: Optional mask forwarded unchanged to every
            ``attention_layer``.
        hidden_size: Width of the token representations.
        num_hidden_layers: Number of encoder layers to stack.
        num_attention_heads: Number of attention heads per layer.
        intermediate_size: Width of the feed-forward intermediate layer.
        intermediate_act_fn: Activation of the intermediate layer (a callable
            or a Keras activation name).
        hidden_dropout_prob: Dropout rate after the attention output
            projection and after the feed-forward output projection.
        attention_probs_dropout_prob: Dropout rate on attention probabilities.
        initializer_range: Standard deviation of the truncated-normal kernel
            initializer used for the Dense layers.
        do_return_all_layers: If True the model outputs every layer's result;
            otherwise only the last layer's.

    Returns:
        A ``tf.keras.Model`` whose input is ``[batch, seq_length, hidden_size]``
        and whose output is one tensor (or a list with one tensor per layer)
        of shape ``[batch, seq_length, hidden_size]``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads``.
    """
    # Validate before creating any Keras objects.
    if hidden_size % num_attention_heads != 0:
        raise ValueError("The hidden size (%d) is not a multiple of the number of attention heads (%d)" % (hidden_size, num_attention_heads))

    attention_head_size = hidden_size // num_attention_heads

    def dense(units, activation=None):
        # A fresh initializer per layer so kernels are drawn independently.
        return tf.keras.layers.Dense(
            units,
            activation=activation,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=initializer_range))

    inputs = tf.keras.Input([seq_length, hidden_size])
    # All layers operate on the flattened [batch*seq, hidden] representation.
    prev_output = reshape_matrix(inputs)

    all_layer_outputs = []
    for _ in range(num_hidden_layers):
        layer_input = prev_output

        # All heads are computed at once by a single attention_layer call. The
        # alternative is to compute the heads separately and concatenate them;
        # the paper mentions that the two approaches are effectively the same.
        # seq_length is passed explicitly so the attention reshape uses this
        # model's length rather than attention_layer's own default.
        attention_output = attention_layer(
            seq_length=seq_length,
            attention_mask=attention_mask,
            num_attention_heads=num_attention_heads,
            size_per_head=attention_head_size,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            initializer_range=initializer_range,
            do_return_2d_tensor=True,
            from_seq_length=seq_length,
            to_seq_length=seq_length)(layer_input)

        # Attention output projection + residual connection.
        attention_output = dense(hidden_size)(attention_output)
        attention_output = tf.keras.layers.Dropout(hidden_dropout_prob)(attention_output)
        attention_output = tf.keras.layers.LayerNormalization()(attention_output + layer_input)

        # Feed-forward block + residual connection.
        intermediate_output = dense(intermediate_size, activation=intermediate_act_fn)(attention_output)

        layer_output = dense(hidden_size)(intermediate_output)
        layer_output = tf.keras.layers.Dropout(hidden_dropout_prob)(layer_output)
        layer_output = tf.keras.layers.LayerNormalization()(layer_output + attention_output)

        prev_output = layer_output
        all_layer_outputs.append(layer_output)

    def restore_sequence_axis(tensor):
        # Undo the flattening: [batch*seq, hidden] -> [batch, seq, hidden].
        return tf.reshape(tensor, [-1, seq_length, hidden_size])

    if do_return_all_layers:
        final_outputs = [restore_sequence_axis(output) for output in all_layer_outputs]
        return Model(inputs, final_outputs)

    return Model(inputs, restore_sequence_axis(prev_output))

In [11]:
class embedding_postprocessor(tf.keras.layers.Layer):
    """Add token-type and position embeddings, then layer-norm and dropout.

    The layer expects word embeddings whose last two axes are
    ``[seq_length, width]`` and returns a tensor of the same shape.

    Args:
        seq_length: Number of positions to take from the position table.
        hidden_size: Stored on the layer; the embedding width itself is taken
            from the input shape when the layer is built.
        use_token_type: Whether to add token-type embeddings.
        token_type_ids: Token-type ids, required when ``use_token_type`` is
            True. They are flattened and must contain one id per input token.
        token_type_vocab_size: Number of rows in the token-type table.
        token_type_embedding_name: Name of the token-type weight.
        use_position_embeddings: Whether to add position embeddings.
        position_embedding_name: Name of the position weight.
        initializer_range: Standard deviation of the truncated-normal
            initializer of both embedding tables.
        max_position_embeddings: Number of rows in the position table.
        dropout_prob: Dropout rate applied to the final output.

    Raises:
        ValueError: If position embeddings are enabled and ``seq_length``
            exceeds ``max_position_embeddings``.
    """

    def __init__(self,
                seq_length,
                hidden_size,
                use_token_type=False,
                token_type_ids=None,
                token_type_vocab_size=16,
                token_type_embedding_name="token_type_embeddings",
                use_position_embeddings=True,
                position_embedding_name="position_embeddings",
                initializer_range=0.02,
                max_position_embeddings=512,
                dropout_prob=0.1):

        super(embedding_postprocessor, self).__init__()

        # Fail at construction time instead of inside tf.slice during call().
        if use_position_embeddings and seq_length > max_position_embeddings:
            raise ValueError(
                "seq_length (%d) must not exceed max_position_embeddings (%d)"
                % (seq_length, max_position_embeddings))

        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.use_token_type = use_token_type
        self.token_type_ids = token_type_ids
        self.token_type_vocab_size = token_type_vocab_size
        self.token_type_embedding_name = token_type_embedding_name
        self.use_position_embeddings = use_position_embeddings
        self.position_embedding_name = position_embedding_name
        self.initializer_range = initializer_range
        self.max_position_embeddings = max_position_embeddings
        self.dropout_prob = dropout_prob
        self.layer_norm = tf.keras.layers.LayerNormalization()
        self.dropout = tf.keras.layers.Dropout(self.dropout_prob)

    def build(self, input_shape):
        """Create the token-type and position embedding tables."""
        width = input_shape[-1]
        # Both tables are always created (even if one is unused) so the
        # layer's weight set does not depend on the flags. Each weight gets
        # its own initializer instance so the draws are independent.
        self.token_embedding = self.add_weight(
            name=self.token_type_embedding_name,
            shape=(self.token_type_vocab_size, width),
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=self.initializer_range))
        self.position_embedding = self.add_weight(
            name=self.position_embedding_name,
            shape=(self.max_position_embeddings, width),
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=self.initializer_range))

    def call(self, inputs):
        """Apply the enabled embeddings, layer normalization and dropout.

        Raises:
            ValueError: If ``use_token_type`` is True but no ``token_type_ids``
                were supplied.
        """
        output = inputs

        if self.use_token_type:
            if self.token_type_ids is None:
                raise ValueError("'token_type_ids' must be specified if 'use_token_type' is True.")

            flat_token_type_ids = tf.reshape(self.token_type_ids, [-1])
            # Kept as an attribute as well, as the layer has always exposed it.
            self.flat_token_type_ids = flat_token_type_ids
            # The vocabulary is small, so the lookup is a one-hot matmul.
            one_hot_ids = tf.one_hot(flat_token_type_ids,
                                     depth=self.token_type_vocab_size,
                                     dtype=self.token_embedding.dtype)
            token_type_embeddings = tf.matmul(one_hot_ids, self.token_embedding)
            # Dynamic shape so an unknown batch dimension is handled.
            token_type_embeddings = tf.reshape(token_type_embeddings, tf.shape(inputs))

            output += token_type_embeddings

        if self.use_position_embeddings:
            # First seq_length rows of the position table.
            position_embeddings = tf.slice(self.position_embedding, [0, 0], [self.seq_length, -1])

            # Prefix with 1s so the table broadcasts over every leading axis.
            num_dims = len(output.shape)
            position_broadcast_shape = [1] * (num_dims - 2)
            position_broadcast_shape.extend([self.seq_length, inputs.shape[-1]])
            position_embeddings = tf.reshape(position_embeddings, position_broadcast_shape)

            output += position_embeddings

        output = self.layer_norm(output)
        output = self.dropout(output)

        return output

In [13]:
class BertModel(tf.keras.layers.Layer):
    """BERT encoder: token embedding, embedding post-processing, Transformer.

    Args:
        config: Mapping with the keys ``vocab_size``, ``hidden_size``,
            ``seq_length``, ``type_vocab_size``, ``initializer_range``,
            ``max_position_embeddings``, ``hidden_dropout_prob``,
            ``attention_probs_dropout_prob``, ``num_hidden_layers``,
            ``num_attention_heads`` and ``intermediate_size``. The optional
            key ``hidden_act`` selects the feed-forward activation
            (default ``'gelu'``).
        is_training: When False both dropout probabilities are forced to 0.0
            before the sublayers are built.
        input_masks: Optional padding mask that can be reshaped to
            ``[-1, 1, seq_length]`` (1 = attend, 0 = masked). It must be a
            concrete array/tensor because it is baked into the Transformer
            when this layer is constructed, and its batch size has to match
            the inputs the layer is later called with. ``None`` means every
            position is attended to.
        token_type_ids: Optional token-type ids, forwarded to the embedding
            post-processor (which is created with ``use_token_type=False``).
        use_one_hot_embedding: Stored on the layer; not otherwise used here.
    """

    def __init__(self,
                 config,
                 is_training,
                input_masks=None,
                token_type_ids=None,
                use_one_hot_embedding=False):

        super(BertModel, self).__init__()
        self.config = config
        self.is_training = is_training
        self.input_masks = input_masks
        self.token_type_ids = token_type_ids
        self.use_one_hot_embedding = use_one_hot_embedding

        # The dropout override has to happen before the sublayers are created:
        # the rates are fixed at construction, so changing a config copy later
        # (inside call) has no effect. The caller's config is left untouched.
        config = dict(config)
        if not is_training:
            config["hidden_dropout_prob"] = 0.0
            config["attention_probs_dropout_prob"] = 0.0

        # Assigning a mask as an attribute on an already-built Transformer
        # model does not change its graph, so a user-supplied mask is built
        # here and handed to the constructor. With no mask every position is
        # attended to, which is the same as passing None.
        attention_mask = None
        if input_masks is not None:
            attention_mask = self._build_attention_mask(input_masks, config["seq_length"])

        self.embedding = tf.keras.layers.Embedding(
            config["vocab_size"],
            config["hidden_size"],
            embeddings_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=config["initializer_range"]))
        self.embedding_postprocessor = embedding_postprocessor(config["seq_length"],
                                                        config["hidden_size"],
                                                       use_token_type=False,
                                                       token_type_ids=token_type_ids,
                                                       token_type_vocab_size=config["type_vocab_size"],
                                                       token_type_embedding_name="token_type_embeddings",
                                                       use_position_embeddings=True,
                                                       position_embedding_name="position_embeddings",
                                                       initializer_range=config["initializer_range"],
                                                       max_position_embeddings=config["max_position_embeddings"],
                                                       dropout_prob=config["hidden_dropout_prob"])
        self.transformer_model = transformer_model(seq_length=config["seq_length"],
                                                attention_mask=attention_mask,
                                                hidden_size=config["hidden_size"],
                                                num_hidden_layers=config["num_hidden_layers"],
                                                num_attention_heads=config["num_attention_heads"],
                                                intermediate_size=config["intermediate_size"],
                                                intermediate_act_fn=config.get("hidden_act", "gelu"),
                                                hidden_dropout_prob=config["hidden_dropout_prob"],
                                                attention_probs_dropout_prob=config["attention_probs_dropout_prob"],
                                                initializer_range=config["initializer_range"],
                                                do_return_all_layers=True)

    @staticmethod
    def _build_attention_mask(input_masks, seq_length):
        """Expand a [batch, seq] padding mask to a [batch, seq, seq] mask."""
        to_mask = tf.cast(tf.reshape(input_masks, [-1, 1, seq_length]), tf.float32)
        # [batch, seq, 1] ones times [batch, 1, seq] mask -> [batch, seq, seq].
        broadcast_ones = tf.ones_like(tf.transpose(to_mask, [0, 2, 1]))
        return broadcast_ones * to_mask

    def call(self, inputs):
        """Encode a batch of token ids.

        Args:
            inputs: Token ids; they are fed to the embedding layer.

        Returns:
            The output of the last Transformer layer.
        """
        # No shape-dependent tensors are created here, so the batch dimension
        # may be unknown, and the constructor arguments are never mutated.
        # Intermediate results stay available as attributes, as before.
        self.embedding_output = self.embedding_postprocessor(self.embedding(inputs))

        # The Transformer returns one output per layer.
        self.all_encoder_layer = self.transformer_model(self.embedding_output)

        self.sequence_output = self.all_encoder_layer[-1]

        return self.sequence_output
        

In [14]:
# Hyperparameters of the BERT encoder built below.
bert_config = dict(
    vocab_size=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act='gelu',
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=16,
    initializer_range=0.02,
    seq_length=128,
)

In [15]:
# Instantiate the BERT encoder layer from bert_config in non-training mode.
bert = BertModel(config=bert_config, is_training=False)

In [16]:
# Wrap the BERT layer in a Sequential model that takes token-id sequences.
# The sequence length is read from bert_config so it cannot drift from the
# length the encoder was built with, and the ids are declared as int32, the
# type the embedding lookup works on. The batch size stays fixed at 1.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer([bert_config["seq_length"]], batch_size=1, dtype="int32"),
    bert
])


In [17]:
# Print the layers, output shapes and parameter counts of the model.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bert_model_1 (BertModel)     (1, 128, 768)             110037504 
Total params: 110,037,504
Trainable params: 110,037,504
Non-trainable params: 0
_________________________________________________________________
