Author: [Zhedong Zheng](https://github.com/zhedongzheng)

<img src="img/self_attn.png" width="300">

In [1]:
import tensorflow as tf
import numpy as np
import sklearn

In [2]:
# --- data / model sizes ---
VOCAB_SIZE = 5000   # vocabulary cap for the IMDB loader and the embedding table
EMBED_DIM = 50      # embedding width passed to embed_sequence
N_CLASS = 2         # number of output logits

# --- training schedule ---
BATCH_SIZE = 32     # examples per generated batch
N_EPOCH = 2         # number of train-then-validate rounds
# Learning-rate settings fed to tf.train.exponential_decay in model_fn:
# initial rate 'start', decay period 'steps', decay rate 'end' / 'start'.
LR = {
    'start': 5e-3,
    'end': 5e-4,
    'steps': 1500,
}

In [3]:
def sort_by_len(x, y):
    """Reorder ``x`` and ``y`` together so the items of ``x`` ascend in length.

    Both arguments are indexed with a list of positions, so they must
    support that form of indexing (NumPy arrays do). Items of equal length
    keep their original relative order because ``sorted`` is stable.

    Returns:
        The pair ``(x[order], y[order])``.
    """
    lengths = [len(item) for item in x]
    order = sorted(range(len(lengths)), key=lengths.__getitem__)
    return x[order], y[order]

def pad_sentence_batch(sent_batch):
    """Right-pad every sequence in a batch with zeros to a common length.

    Args:
        sent_batch: iterable of token-id sequences (lists, tuples, ...).

    Returns:
        A list of new lists, each as long as the longest input sequence.
        An empty batch yields an empty list.
    """
    # Copy each sequence into a list so tuples and other sequence types can
    # be padded with list concatenation, and the inputs are never mutated.
    rows = [list(sent) for sent in sent_batch]
    # default=0 keeps max() from raising ValueError on an empty batch.
    max_seq_len = max((len(row) for row in rows), default=0)
    return [row + [0] * (max_seq_len - len(row)) for row in rows]

def next_train_batch(X_train, y_train):
    """Yield ``(padded_inputs, labels)`` pairs, BATCH_SIZE examples at a time.

    Inputs are zero-padded per batch by ``pad_sentence_batch``; the labels
    are the matching slice of ``y_train``. The last batch may be smaller.
    """
    for start in range(0, len(X_train), BATCH_SIZE):
        window = slice(start, start + BATCH_SIZE)
        yield pad_sentence_batch(X_train[window]), y_train[window]
        
def next_test_batch(X_test):
    """Yield zero-padded input batches of up to BATCH_SIZE examples each.

    Padding is done per batch by ``pad_sentence_batch``; no labels are
    produced. The last batch may be smaller.
    """
    for start in range(0, len(X_test), BATCH_SIZE):
        window = slice(start, start + BATCH_SIZE)
        yield pad_sentence_batch(X_test[window])
        
def train_input_fn(X_train, y_train):
    """Build the training input tensors from the Python batch generator.

    Wraps ``next_train_batch`` in a ``tf.data.Dataset`` and returns the
    ``get_next()`` result of a one-shot iterator: a pair of int32 inputs
    (rank 2) and int64 labels (rank 1), both with unknown dimensions.
    """
    def batches():
        return next_train_batch(X_train, y_train)

    output_types = (tf.int32, tf.int64)
    output_shapes = (
        tf.TensorShape([None, None]),  # padded token ids
        tf.TensorShape([None]),        # labels
    )
    dataset = tf.data.Dataset.from_generator(
        batches, output_types, output_shapes)
    return dataset.make_one_shot_iterator().get_next()

def predict_input_fn(X_test):
    """Build the prediction input tensor from the Python batch generator.

    Wraps ``next_test_batch`` in a ``tf.data.Dataset`` and returns the
    ``get_next()`` result of a one-shot iterator: a rank-2 int32 tensor of
    padded token ids with unknown dimensions.
    """
    def batches():
        return next_test_batch(X_test)

    output_type = tf.int32
    output_shape = tf.TensorShape([None, None])
    dataset = tf.data.Dataset.from_generator(
        batches, output_type, output_shape)
    return dataset.make_one_shot_iterator().get_next()

In [4]:
def forward(inputs, mode):
    """Compute class logits by attention-pooling the token embeddings.

    Args:
        inputs: integer token-id tensor; ids equal to 0 are treated as
            padding. Presumably [batch, time], as produced by the input
            functions in this file.
        mode: a ``tf.estimator.ModeKeys`` value; dropout is only active
            for TRAIN.

    Returns:
        The output of a dense layer with N_CLASS units.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN

    embedded = tf.contrib.layers.embed_sequence(inputs, VOCAB_SIZE, EMBED_DIM)
    dropped = tf.layers.dropout(embedded, 0.2, training=training)

    # alignment: one tanh-activated score per position
    scores = tf.layers.dense(dropped, 1, tf.tanh)
    scores = tf.squeeze(scores, -1)

    # masking: padded positions get -inf so softmax assigns them zero weight
    is_pad = tf.equal(tf.sign(inputs), 0)
    neg_inf = tf.fill(tf.shape(scores), float('-inf'))
    scores = tf.where(is_pad, neg_inf, scores)

    # probability
    weights = tf.nn.softmax(scores)
    weights = tf.expand_dims(weights, -1)

    # weighted sum over positions, taken on the embeddings before dropout
    pooled = tf.matmul(embedded, weights, transpose_a=True)
    pooled = tf.squeeze(pooled, -1)

    return tf.layers.dense(pooled, N_CLASS)


def model_fn(features, labels, mode):
    """Estimator model function for PREDICT, EVAL and TRAIN.

    Args:
        features: input tensor handed to ``forward``.
        labels: integer class labels; unused in PREDICT mode.
        mode: a ``tf.estimator.ModeKeys`` value.

    Returns:
        A ``tf.estimator.EstimatorSpec``:
        - PREDICT: ``predictions`` is the argmax class index.
        - EVAL: mean cross-entropy loss plus an ``accuracy`` metric.
        - TRAIN: mean cross-entropy loss, an Adam train op using an
          exponentially decayed learning rate, and a hook that logs the
          learning rate every 100 iterations.

    Raises:
        ValueError: if ``mode`` is not one of the three estimator modes.
    """
    logits = forward(features, mode)
    preds = tf.argmax(logits, -1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=preds)

    loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels))

    if mode == tf.estimator.ModeKeys.EVAL:
        # Without this branch the function returned None in EVAL mode, so
        # estimator.evaluate() could not be used.
        metrics = {
            'accuracy': tf.metrics.accuracy(labels=labels, predictions=preds),
        }
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss_op, eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()

        # Decay rate LR['end'] / LR['start'], applied over LR['steps'] steps.
        lr_op = tf.train.exponential_decay(
            LR['start'], global_step, LR['steps'], LR['end']/LR['start'])

        train_op = tf.train.AdamOptimizer(lr_op).minimize(
            loss_op, global_step=global_step)

        lth = tf.train.LoggingTensorHook({'lr': lr_op}, every_n_iter=100)

        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss_op, train_op=train_op, training_hooks=[lth])

    raise ValueError('Unsupported mode: %r' % (mode,))

In [5]:
# Load IMDB restricted to the VOCAB_SIZE most frequent words, then order both
# splits by review length so each batch pads to a similar length.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(
    num_words=VOCAB_SIZE)
X_train, y_train = sort_by_len(X_train, y_train)
X_test, y_test = sort_by_len(X_test, y_test)

estimator = tf.estimator.Estimator(model_fn=model_fn)

# Each round: one pass over the training generator, then score the test split.
for _ in range(N_EPOCH):
    estimator.train(lambda: train_input_fn(X_train, y_train))
    predictions = estimator.predict(lambda: predict_input_fn(X_test))
    y_pred = np.fromiter(predictions, np.int32)
    accuracy = (y_pred == y_test).mean()
    print("\nValidation Accuracy: %.4f\n" % accuracy)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp3pcf0w9u', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11f8fc630>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1 into /var/folders/sx/