In [1]:
import tensorflow as tf
import numpy as np
from sklearn.metrics import classification_report

tf.logging.set_verbosity(tf.logging.INFO)

In [2]:
VOCAB_SIZE = 20000
EMBED_DIM = 128
RNN_SIZE = 128
CLIP_NORM = 5.0
BATCH_SIZE = 32
DISPLAY_STEP = 50
N_EPOCH = 2
N_CLASS = 2

In [3]:
def sort_by_len(x, y):
    idx = sorted(range(len(x)), key=lambda i: len(x[i]))
    return x[idx], y[idx]

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=VOCAB_SIZE)
X_train, y_train = sort_by_len(X_train, y_train)
X_test, y_test = sort_by_len(X_test, y_test)

In [4]:
def pad_sentence_batch(sent_batch):
    max_seq_len = max([len(sent) for sent in sent_batch])
    padded_seqs = [(sent + [0]*(max_seq_len - len(sent))) for sent in sent_batch]
    return padded_seqs

def next_train_batch():
    for i in range(0, len(X_train), BATCH_SIZE):
        padded_seqs = pad_sentence_batch(X_train[i : i+BATCH_SIZE])
        yield padded_seqs, y_train[i : i+BATCH_SIZE]
        
def next_test_batch():
    for i in range(0, len(X_test), BATCH_SIZE):
        padded_seqs = pad_sentence_batch(X_test[i : i+BATCH_SIZE])
        yield padded_seqs
        
def train_input_fn():
    dataset = tf.data.Dataset.from_generator(next_train_batch, (tf.int32,tf.int64),
        (tf.TensorShape([None,None]),tf.TensorShape([None])))
    dataset = dataset.repeat(N_EPOCH)
    iterator = dataset.make_one_shot_iterator()
    X_train_batch, y_train_batch = iterator.get_next()
    return {'_': X_train_batch}, y_train_batch

def predict_input_fn():
    dataset = tf.data.Dataset.from_generator(next_test_batch, tf.int32,
        tf.TensorShape([None,None]))
    iterator = dataset.make_one_shot_iterator()
    return {'_': iterator.get_next()}

In [5]:
def rnn_cell():
    return tf.nn.rnn_cell.GRUCell(RNN_SIZE//2, kernel_initializer=tf.orthogonal_initializer())

def forward(inputs, reuse, is_training):
    with tf.variable_scope('model', reuse=reuse):
        x = tf.contrib.layers.embed_sequence(inputs, VOCAB_SIZE, EMBED_DIM)
        x = tf.layers.dropout(x, 0.2, training=is_training)
        _, bi_states = tf.nn.bidirectional_dynamic_rnn(
            rnn_cell(), rnn_cell(), x, tf.count_nonzero(inputs, 1), dtype=tf.float32)
        x = tf.concat(bi_states, -1)
        logits = tf.layers.dense(x, N_CLASS)
    return logits

def clip_grads(loss):
    params = tf.trainable_variables()
    grads = tf.gradients(loss, params)
    clipped_grads, _ = tf.clip_by_global_norm(grads, CLIP_NORM)
    return zip(clipped_grads, params)

def model_fn(features, labels, mode, params):
    logits = forward(features['_'], reuse=False, is_training=True)
    preds = forward(features['_'], reuse=True, is_training=False)
    
    if mode == tf.estimator.ModeKeys.PREDICT:
        preds = tf.argmax(preds, -1)
        return tf.estimator.EstimatorSpec(mode, predictions=preds)
    
    global_step = tf.train.get_global_step()
    
    lr_op = tf.train.exponential_decay(5e-3, global_step, 1400, 0.2)
    
    loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels))
    
    train_op = tf.train.AdamOptimizer(lr_op).apply_gradients(
        clip_grads(loss_op), global_step=global_step)
    
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss_op, train_op=train_op)

In [6]:
estimator = tf.estimator.Estimator(model_fn)
estimator.train(train_input_fn)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp1ey5ftat', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x1218d0cc0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1 into /var/folders/sx/fv0r97j96fz8njp14dt5g794000

<tensorflow.python.estimator.estimator.Estimator at 0x121364cc0>

In [7]:
y_pred = np.array(list(estimator.predict(predict_input_fn)))
print("Accuracy: %.4f" % (y_pred==y_test).mean())
print(classification_report(y_test, y_pred))

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp1ey5ftat/model.ckpt-1564
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
Accuracy: 0.8950
             precision    recall  f1-score   support

          0       0.88      0.92      0.90     12500
          1       0.92      0.87      0.89     12500

avg / total       0.90      0.90      0.89     25000

