In [1]:
import tensorflow as tf
import numpy as np
from sklearn.metrics import classification_report

In [2]:
# Hyper-parameters shared by the data pipeline, the model and the training loop.
VOCAB_SIZE = 5000    # passed as num_words to imdb.load_data and as vocab size to the embedding
MAX_LEN = 400        # maxlen used when padding the train/test sequences
BATCH_SIZE = 32      # batch size for both the training and the prediction datasets
EMBED_DIM = 50       # embedding dimension given to embed_sequence
FILTERS = 250        # output channels of each conv1d branch; also the hidden dense layer width
N_CLASS = 2          # number of logits produced by the final dense layer
N_EPOCH = 2          # number of times the training dataset is repeated
DISPLAY_STEP = 50    # training progress is printed every DISPLAY_STEP global steps (and at step 1)

In [3]:
# Load the IMDB reviews (restricted via num_words=VOCAB_SIZE) and bring every
# sequence in both splits to a common length of MAX_LEN with pad_sequences.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.imdb.load_data(
    num_words=VOCAB_SIZE)
X_train, X_test = (
    tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=MAX_LEN)
    for sequences in (X_train, X_test)
)

In [4]:
def pipeline(mode):
    """Build a one-shot `tf.data` iterator over the IMDB arrays for `mode`.

    Args:
        mode: compared against `tf.estimator.ModeKeys.TRAIN` and
            `tf.estimator.ModeKeys.PREDICT`.
            TRAIN yields `(X_train, y_train)` batches of `BATCH_SIZE`,
            repeated `N_EPOCH` times.
            PREDICT yields `X_test` batches of `BATCH_SIZE` (no labels),
            for a single pass.

    Returns:
        A one-shot iterator; call `.get_next()` on it to obtain the batch
        tensor(s).

    Raises:
        ValueError: if `mode` is neither TRAIN nor PREDICT. Previously this
            surfaced as an opaque UnboundLocalError on `dataset`.
    """
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
        dataset = dataset.batch(BATCH_SIZE).repeat(N_EPOCH)
    elif mode == tf.estimator.ModeKeys.PREDICT:
        dataset = tf.data.Dataset.from_tensor_slices(X_test)
        dataset = dataset.batch(BATCH_SIZE)
    else:
        raise ValueError(
            "Unsupported mode %r; expected %r or %r" % (
                mode,
                tf.estimator.ModeKeys.TRAIN,
                tf.estimator.ModeKeys.PREDICT))
    return dataset.make_one_shot_iterator()

In [5]:
def forward(x, reuse, is_training):
    """Compute class logits with a multi-width 1-D convolutional text model.

    The input is embedded, passed through dropout (rate 0.2, only active when
    `is_training` is true), then fed to three parallel conv1d branches with
    kernel widths 3, 4 and 5. Each branch is max-pooled over its whole time
    axis and reshaped to `(batch, FILTERS)`; the branches are concatenated
    and followed by a ReLU dense layer and a linear output layer.

    Args:
        x: batch of token-id sequences (presumably shape `(batch, MAX_LEN)`
            as produced by `pipeline` -- confirm against the caller).
        reuse: forwarded to `tf.variable_scope('model', ...)`; pass True to
            share the variables created by an earlier call.
        is_training: forwarded to the dropout layer's `training` flag.

    Returns:
        Logits tensor with `N_CLASS` units in the last dimension.
    """
    with tf.variable_scope('model', reuse=reuse):
        embedded = tf.contrib.layers.embed_sequence(x, VOCAB_SIZE, EMBED_DIM)
        embedded = tf.layers.dropout(embedded, 0.2, training=is_training)

        def conv_branch(kernel_width):
            # Convolve, then pool with a window as long as the conv output's
            # static time dimension so a single value per filter remains.
            conv = tf.layers.conv1d(
                embedded, FILTERS, kernel_width, activation=tf.nn.relu)
            time_steps = conv.get_shape().as_list()[1]
            pooled = tf.layers.max_pooling1d(conv, time_steps, 1)
            # Drop the singleton time axis: (batch, 1, FILTERS) -> (batch, FILTERS).
            return tf.reshape(pooled, (tf.shape(embedded)[0], FILTERS))

        branches = [conv_branch(width) for width in (3, 4, 5)]
        hidden = tf.layers.dense(tf.concat(branches, -1), FILTERS, tf.nn.relu)
        logits = tf.layers.dense(hidden, N_CLASS)
    return logits

In [6]:
# Registry of the graph nodes that later cells run through the session.
ops = dict()

# Training inputs and the training-mode forward pass (creates the variables).
train_iterator = pipeline('train')
X_train_batch, y_train_batch = train_iterator.get_next()
logits_train_batch = forward(X_train_batch, reuse=False, is_training=True)

# Step counter incremented by the optimizer; it drives the learning-rate decay.
ops['global_step'] = tf.Variable(0, trainable=False)
ops['lr'] = tf.train.exponential_decay(
    learning_rate=5e-3,
    global_step=ops['global_step'],
    decay_steps=1400,
    decay_rate=0.2)

# Mean cross-entropy over the batch, with integer class labels.
per_example_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=y_train_batch, logits=logits_train_batch)
ops['loss'] = tf.reduce_mean(per_example_xent)

optimizer = tf.train.AdamOptimizer(ops['lr'])
ops['train'] = optimizer.minimize(
    ops['loss'], global_step=ops['global_step'])

Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead


In [7]:
# Train until the one-shot training iterator is exhausted.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
while True:
    try:
        # Fetch loss and learning rate in the SAME run as the update.
        # Running ops['loss'] in a separate sess.run would pull a fresh batch
        # from the iterator: that batch would be skipped for training, and an
        # OutOfRangeError raised there (outside this try) would crash the cell.
        # The reported loss is therefore the loss of the batch just trained on.
        _train_out, loss, lr = sess.run([ops['train'], ops['loss'], ops['lr']])
    except tf.errors.OutOfRangeError:
        break
    # Reading the counter does not touch the input iterator, so it is safe
    # to do in its own run after the update has been applied.
    step = sess.run(ops['global_step'])
    if step % DISPLAY_STEP == 0 or step == 1:
        print("Step %d | Loss %.3f | LR: %.4f" % (step, loss, lr))

Step 1 | Loss 0.767 | LR: 0.0050
Step 50 | Loss 0.413 | LR: 0.0047
Step 100 | Loss 0.567 | LR: 0.0045
Step 150 | Loss 0.567 | LR: 0.0042
Step 200 | Loss 0.344 | LR: 0.0040
Step 250 | Loss 0.410 | LR: 0.0038
Step 300 | Loss 0.378 | LR: 0.0035
Step 350 | Loss 0.244 | LR: 0.0033
Step 400 | Loss 0.210 | LR: 0.0032
Step 450 | Loss 0.130 | LR: 0.0030
Step 500 | Loss 0.403 | LR: 0.0028
Step 550 | Loss 0.387 | LR: 0.0027
Step 600 | Loss 0.242 | LR: 0.0025
Step 650 | Loss 0.306 | LR: 0.0024
Step 700 | Loss 0.200 | LR: 0.0022
Step 750 | Loss 0.248 | LR: 0.0021
Step 800 | Loss 0.154 | LR: 0.0020
Step 850 | Loss 0.232 | LR: 0.0019
Step 900 | Loss 0.118 | LR: 0.0018
Step 950 | Loss 0.206 | LR: 0.0017
Step 1000 | Loss 0.148 | LR: 0.0016
Step 1050 | Loss 0.121 | LR: 0.0015
Step 1100 | Loss 0.186 | LR: 0.0014
Step 1150 | Loss 0.219 | LR: 0.0013
Step 1200 | Loss 0.327 | LR: 0.0013
Step 1250 | Loss 0.105 | LR: 0.0012
Step 1300 | Loss 0.226 | LR: 0.0011
Step 1350 | Loss 0.208 | LR: 0.0011
Step 1400 | Los

In [8]:
# Inference graph: shares the trained variables (reuse=True) and runs with
# is_training=False, reading batches from the prediction pipeline.
X_test_batch = pipeline('infer').get_next()
ops['predict'] = forward(X_test_batch, reuse=True, is_training=False)

In [9]:
def _drain(session, fetch):
    """Run `fetch` repeatedly, collecting results until OutOfRangeError."""
    collected = []
    while True:
        try:
            collected.append(session.run(fetch))
        except tf.errors.OutOfRangeError:
            return collected


# Per-batch logits for the whole prediction pipeline, then the arg-max class.
y_pred_li = _drain(sess, ops['predict'])
y_pred = np.vstack(y_pred_li).argmax(axis=1)

accuracy = np.mean(y_pred == y_test)
print("Accuracy: %.4f" % accuracy)
print(classification_report(y_test, y_pred))

Accuracy: 0.8952
             precision    recall  f1-score   support

          0       0.92      0.87      0.89     12500
          1       0.88      0.92      0.90     12500

avg / total       0.90      0.90      0.90     25000

