![title](img/self_attn.png)

In [1]:
import tensorflow as tf
import numpy as np
from sklearn.metrics import classification_report

In [2]:
VOCAB_SIZE = 10000
MAX_LEN = 400
HIDDEN_DIM = 50
BATCH_SIZE = 32
LR = {'start': 5e-3, 'end': 5e-4, 'steps': 1500}
N_EPOCH = 2
N_CLASS = 2

In [3]:
def embed_seq(x, vocab_sz, embed_dim, name):
    embedding = tf.get_variable(name, [vocab_sz, embed_dim])
    return tf.nn.embedding_lookup(embedding, x)


def position_embedding(inputs, repr_dim=HIDDEN_DIM):
    T = inputs.get_shape().as_list()[1]
    x = tf.range(T)                            # (T)
    x = tf.expand_dims(x, 0)                   # (1, T)
    x = tf.tile(x, [tf.shape(inputs)[0], 1])   # (N, T)
    return embed_seq(x, T, repr_dim, 'position_embedding')


def google_position_encoding(inputs, repr_dim=HIDDEN_DIM):
    T = inputs.get_shape().as_list()[1]
    pos_idx = tf.tile(tf.expand_dims(tf.range(T), 0), [tf.shape(inputs)[0], 1])

    pos_enc = np.array(
        [[pos / np.power(10000, 2*i/repr_dim) for i in range(repr_dim)] for pos in range(T)])
    pos_enc[:, 0::2] = np.sin(pos_enc[:, 0::2])  
    pos_enc[:, 1::2] = np.cos(pos_enc[:, 1::2]) 

    lookup_table = tf.convert_to_tensor(pos_enc, tf.float32)
    return tf.nn.embedding_lookup(lookup_table, pos_idx)


def importance_weighting(inputs):
    fn = lambda x, k_sz: tf.layers.conv1d(x,
                                          filters=1,
                                          kernel_size=k_sz,
                                          activation=tf.nn.relu,
                                          padding='same')
    return tf.squeeze(fn(inputs, 5), -1)


def forward(inputs, reuse, is_training):
    with tf.variable_scope('model', reuse=reuse):
        V = embed_seq(inputs, VOCAB_SIZE, HIDDEN_DIM, 'word_embedding')
        x = V + google_position_encoding(inputs)
        x = tf.layers.dropout(x, 0.1, training=is_training)
        
        # alignment
        align = importance_weighting(x)
        # masking
        masks = tf.sign(inputs)
        paddings = tf.fill(tf.shape(align), float('-inf'))
        align = tf.where(tf.equal(masks, 0), paddings, align)
        # probability
        align = tf.expand_dims(tf.nn.softmax(align), -1)
        # weighted sum
        x = tf.squeeze(tf.matmul(tf.transpose(V, [0,2,1]), align), -1)
        
        logits = tf.layers.dense(x, N_CLASS)
    return logits


def model_fn(features, labels, mode, params):
    logits = forward(features, reuse=False, is_training=True)
    preds = forward(features, reuse=True, is_training=False)
    
    if mode == tf.estimator.ModeKeys.PREDICT:
        preds = tf.argmax(preds, -1)
        return tf.estimator.EstimatorSpec(mode, predictions=preds)
    
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()

        lr_op = tf.train.exponential_decay(
            LR['start'], global_step, LR['steps'], LR['end']/LR['start'])

        loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels))

        train_op = tf.train.AdamOptimizer(lr_op).minimize(
            loss_op, global_step=global_step)

        lth = tf.train.LoggingTensorHook({'lr': lr_op}, every_n_iter=100)
        
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss_op, train_op=train_op, training_hooks=[lth])

In [4]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=VOCAB_SIZE)
X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, MAX_LEN, padding='post')
X_test = tf.keras.preprocessing.sequence.pad_sequences(X_test, MAX_LEN, padding='post')

estimator = tf.estimator.Estimator(model_fn)

for _ in range(N_EPOCH):
    estimator.train(tf.estimator.inputs.numpy_input_fn(
        x = X_train, y = y_train,
        batch_size = BATCH_SIZE,
        shuffle = True))
    y_pred = np.fromiter(estimator.predict(tf.estimator.inputs.numpy_input_fn(
        x = X_test,
        batch_size = BATCH_SIZE,
        shuffle = False)), np.int32)
    print("\nValidation Accuracy: %.4f\n" % (y_pred==y_test).mean())

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpxrmf9r0c', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x122563d68>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1 into /var/folders/sx/