In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Enable eager execution (TF 1.x API) so ops in the cells below run immediately
# instead of building a graph.
# NOTE(review): per the TF 1.x docs this must be called before any ops/graphs are
# created, so keep this cell right after the imports. Under TF 2.x eager is the
# default and this name lives under tf.compat.v1 — confirm the target TF version.
tf.enable_eager_execution()

In [2]:
def layer_norm(inputs, eps=1e-6):
    """Normalise ``inputs`` over their last axis (the LayerNorm in LayerNorm(x + Sublayer(x))).

    Args:
        inputs: tensor whose final axis is the feature axis being normalised.
        eps: small constant added to the standard deviation to avoid division by zero.

    Returns:
        ``gamma * (inputs - mean) / (std + eps) + beta``, where mean/std are taken
        over the last axis (kept broadcastable) and ``beta``/``gamma`` have the
        shape of that last axis.

    NOTE(review): ``beta``/``gamma`` come from ``tf.get_variable`` inside a plain
    function. Under eager execution this presumably creates a fresh zeros/ones
    variable on every call that no Keras model tracks, so the affine parameters
    may never be trained. Confirm before relying on a learnable scale/shift.
    """
    last_axis_shape = inputs.get_shape()[-1:]

    # Pass along the mean and standard deviation across the feature axis.
    mu = tf.keras.backend.mean(inputs, [-1], keepdims=True)
    sigma = tf.keras.backend.std(inputs, [-1], keepdims=True)

    shift = tf.get_variable("beta", initializer=tf.zeros(last_axis_shape))
    scale = tf.get_variable("gamma", initializer=tf.ones(last_axis_shape))

    normalised = scale * (inputs - mu) / (sigma + eps)
    return normalised + shift

In [3]:
def sublayer_connection(inputs, sublayer, dropout=0.2, training=None):
    """Residual connection followed by layer norm: LayerNorm(x + Dropout(Sublayer(x))).

    Args:
        inputs: the residual-branch tensor ``x``.
        sublayer: the already-computed sublayer OUTPUT tensor ``Sublayer(x)``
            (not a callable). It must be addable to ``inputs``.
        dropout: dropout rate applied to ``sublayer`` before the residual add.
        training: forwarded to ``tf.keras.layers.Dropout``. ``None`` (the default)
            keeps the previous behaviour of deferring to the Keras learning phase.
            That likely leaves dropout disabled outside Keras training loops
            (confirm). Pass ``True`` to force dropout on.

    Returns:
        ``layer_norm(inputs + dropped)``, where ``dropped`` is ``sublayer`` after dropout.
    """
    dropped = tf.keras.layers.Dropout(dropout)(sublayer, training=training)
    return layer_norm(inputs + dropped)

In [4]:
def positional_encoding(dim, sentence_length):
    """Sinusoidal positional encoding ("Attention Is All You Need", sec. 3.5).

        PE[pos, 2k]   = sin(pos / 10000^(2k / dim))
        PE[pos, 2k+1] = cos(pos / 10000^(2k / dim))

    Args:
        dim: model (embedding) dimension, i.e. the number of columns.
        sentence_length: number of positions, i.e. the number of rows.

    Returns:
        float32 constant tensor of shape ``[sentence_length, dim]``.
    """
    positions = np.arange(sentence_length, dtype=np.float64)[:, np.newaxis]
    # Columns 2k and 2k+1 form a sin/cos pair that shares one frequency, so the
    # exponent uses the pair index k = col // 2, not the raw column index.
    pair_index = np.arange(dim) // 2
    angles = positions / np.power(10000.0, 2.0 * pair_index / dim)

    encoded = np.empty_like(angles)
    # Split sin/cos per column (not over a flattened array), so the even/odd
    # assignment stays aligned with the feature axis even when ``dim`` is odd.
    encoded[:, 0::2] = np.sin(angles[:, 0::2])
    encoded[:, 1::2] = np.cos(angles[:, 1::2])

    return tf.constant(encoded, dtype=tf.float32)

In [None]:
class MultiHeadAttention(tf.keras.Model):  # In TF 2.x this could subclass tf.keras.layers.Layer instead
    """Multi-head scaled dot-product attention.

    Heads are formed by splitting the projected features into ``heads`` equal
    chunks along the last axis and stacking those chunks along axis 0. Internally
    the batch axis is therefore ``heads * batch``, in head-major order.

    Args:
        num_units: output width of the Q/K/V/output projections. It must be
            divisible by ``heads``.
        heads: number of attention heads.
        sub_masked: if True, also apply a lower-triangular mask so that query
            position t can only attend to key positions <= t (decoder self-attention).
    """

    def __init__(self, num_units, heads, sub_masked=False):
        super(MultiHeadAttention, self).__init__()
        # tf.split below needs equal-sized chunks; fail early with a clear message.
        if num_units % heads != 0:
            raise ValueError(
                "num_units (%d) must be divisible by heads (%d)" % (num_units, heads))

        self.heads = heads
        self.sub_masked = sub_masked

        self.query_dense = tf.keras.layers.Dense(num_units, use_bias=False)
        self.key_dense = tf.keras.layers.Dense(num_units, use_bias=False)
        self.value_dense = tf.keras.layers.Dense(num_units, use_bias=False)
        self.out_dense = tf.keras.layers.Dense(num_units, use_bias=False)

    def scaled_dot_product_attention(self, query, key, value, key_mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V with optional padding and causal masks.

        Args:
            query: rank-3 tensor ``[N, q_len, d_k]`` (the transpose perm below
                requires rank 3).
            key: ``[N, k_len, d_k]``.
            value: ``[N, k_len, d_v]``.
            key_mask: optional ``[N, k_len]`` tensor. Zero entries mark key
                positions that must not be attended to. ``None`` disables
                padding masking.

        Returns:
            ``[N, q_len, d_v]`` tensor of attention-weighted values.
        """
        # Scale by sqrt of the per-head depth d_k (the last axis of ``key``).
        depth = float(key.get_shape().as_list()[-1])
        scores = tf.matmul(query, tf.transpose(key, perm=[0, 2, 1])) / tf.sqrt(depth)

        # Start from an all-ones mask of the full score shape; tf.where below
        # (TF 1.x) needs all operands to share that shape.
        masks = tf.ones_like(scores)
        if key_mask is not None:
            # [N, k_len] -> [N, 1, k_len], which broadcasts over query positions.
            masks = tf.cast(tf.logical_and(tf.cast(masks, tf.bool),
                                           tf.cast(tf.expand_dims(key_mask, 1), tf.bool)),
                            tf.float32)
        if self.sub_masked:
            diag_vals = tf.ones_like(scores[0, :, :])
            tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()
            subsequent_masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(scores)[0], 1, 1])
            masks = tf.cast(tf.logical_and(tf.cast(masks, tf.bool),
                                           tf.cast(subsequent_masks, tf.bool)),
                            tf.float32)
        # A large negative fill drives masked positions to ~0 weight after softmax.
        neg_inf = tf.ones_like(masks) * (-2 ** 32 + 1)
        scores = tf.where(tf.equal(masks, 0), neg_inf, scores)

        attention_map = tf.nn.softmax(scores)

        return tf.matmul(attention_map, value)

    def call(self, query, key, value, key_mask=None):
        """Project Q/K/V, attend per head, merge heads, and project the output.

        Args:
            query: query tensor, presumably ``[batch, q_len, features]``. TODO: confirm with callers.
            key: key tensor, presumably ``[batch, k_len, features]``.
            value: value tensor, presumably ``[batch, k_len, features]``.
            key_mask: optional per-example key mask, assumed ``[batch, k_len]``.
                Zero entries are masked out. ``None`` disables padding masking.

        Returns:
            Tensor of shape ``[batch, q_len, num_units]``.
        """
        query = self.query_dense(query)
        key = self.key_dense(key)
        value = self.value_dense(value)

        # [batch, len, num_units] -> [heads * batch, len, num_units / heads].
        query = tf.concat(tf.split(query, self.heads, axis=-1), axis=0)
        key = tf.concat(tf.split(key, self.heads, axis=-1), axis=0)
        value = tf.concat(tf.split(value, self.heads, axis=-1), axis=0)

        if key_mask is not None:
            # Heads were stacked on axis 0 in head-major order. Repeat the
            # per-example mask once per head so it lines up with that layout;
            # without this the mask only broadcasts when heads == 1 or batch == 1.
            key_mask = tf.tile(key_mask, [self.heads, 1])

        attention_map = self.scaled_dot_product_attention(query, key, value, key_mask)

        # Undo the head stacking: [heads * batch, q_len, d] -> [batch, q_len, heads * d].
        attn_outputs = tf.concat(tf.split(attention_map, self.heads, axis=0), axis=-1)

        return self.out_dense(attn_outputs)

In [10]:
class PositionWiseFeedForward(tf.keras.Model):
    """Two-layer feed-forward block: Dense(feature_shape)(ReLU(Dense(num_units)(x))).

    Args:
        num_units: width of the hidden, ReLU-activated dense layer.
        feature_shape: width of the final linear dense layer.
    """

    def __init__(self, num_units, feature_shape):
        super(PositionWiseFeedForward, self).__init__()

        # Hidden expansion with ReLU, then a linear projection to feature_shape.
        self.inner_dense = tf.keras.layers.Dense(num_units, activation=tf.nn.relu)
        self.output_dense = tf.keras.layers.Dense(feature_shape)

    def call(self, inputs):
        """Apply the hidden ReLU layer, then the output layer, to ``inputs``."""
        return self.output_dense(self.inner_dense(inputs))

In [7]:
class Encoder(tf.keras.Model):
    """Stack of ``num_layers`` Transformer encoder layers.

    Each layer is self-attention followed by a position-wise feed-forward
    network. Both are wrapped in ``sublayer_connection`` (dropout, residual add,
    layer norm).

    Args:
        model_dims: attention width, also used as the FFN output width.
        ffn_dims: hidden width of the position-wise feed-forward network.
        attn_heads: number of attention heads.
        num_layers: number of stacked encoder layers.
    """

    def __init__(self, model_dims, ffn_dims, attn_heads, num_layers=1):
        super(Encoder, self).__init__()

        self.self_attention = [MultiHeadAttention(model_dims, attn_heads) for _ in range(num_layers)]
        self.position_feedforward = [PositionWiseFeedForward(ffn_dims, model_dims) for _ in range(num_layers)]

    def call(self, inputs, src_mask):
        """Run ``inputs`` through every encoder layer in order.

        Args:
            inputs: source representations, presumably ``[batch, src_len, model_dims]``.
                TODO: confirm with caller.
            src_mask: key mask for the source positions, passed to each
                self-attention layer.

        Returns:
            Output of the last encoder layer, or ``inputs`` unchanged when there
            are no layers (previously ``None``).
        """
        hidden = inputs

        for i, (s_a, p_f) in enumerate(zip(self.self_attention, self.position_feedforward)):
            with tf.variable_scope('encoder_layer_' + str(i + 1)):
                attention_layer = sublayer_connection(hidden, s_a(hidden, hidden, hidden, src_mask))
                hidden = sublayer_connection(attention_layer, p_f(attention_layer))

        return hidden

In [9]:
class Decoder(tf.keras.Model):
    """Stack of ``num_layers`` Transformer decoder layers.

    Each layer applies three sublayers, each wrapped in ``sublayer_connection``
    (dropout, residual add, layer norm):
      1. causally masked self-attention;
      2. encoder-decoder attention, with queries from the decoder and keys/values
         from ``encoder_outputs``;
      3. a position-wise feed-forward network.

    Args:
        model_dims: attention width, also used as the FFN output width.
        ffn_dims: hidden width of the position-wise feed-forward network.
        attn_heads: number of attention heads.
        num_layers: number of stacked decoder layers.
    """

    def __init__(self, model_dims, ffn_dims, attn_heads, num_layers=1):
        super(Decoder, self).__init__()

        self.self_attention = [MultiHeadAttention(model_dims, attn_heads, sub_masked=True) for _ in range(num_layers)]
        self.encoder_decoder_attention = [MultiHeadAttention(model_dims, attn_heads) for _ in range(num_layers)]
        self.position_feedforward = [PositionWiseFeedForward(ffn_dims, model_dims) for _ in range(num_layers)]

    def call(self, inputs, encoder_outputs, src_mask, tgt_mask):
        """Run ``inputs`` through every decoder layer in order.

        Args:
            inputs: target-side representations, presumably
                ``[batch, tgt_len, model_dims]``. TODO: confirm with caller.
            encoder_outputs: encoder output, used as keys/values in the
                encoder-decoder attention.
            src_mask: key mask over the encoder (source) positions.
            tgt_mask: key mask over the target positions, used by self-attention.

        Returns:
            Output of the last decoder layer, or ``inputs`` unchanged when there
            are no layers (previously ``None``).
        """
        hidden = inputs

        layers = zip(self.self_attention, self.encoder_decoder_attention, self.position_feedforward)
        for i, (s_a, ed_a, p_f) in enumerate(layers):
            with tf.variable_scope('decoder_layer_' + str(i + 1)):
                masked_attention_layer = sublayer_connection(hidden, s_a(hidden, hidden, hidden, tgt_mask))
                attention_layer = sublayer_connection(
                    masked_attention_layer,
                    ed_a(masked_attention_layer, encoder_outputs, encoder_outputs, src_mask))
                hidden = sublayer_connection(attention_layer, p_f(attention_layer))

        return hidden