In [1]:
from __future__ import division, print_function, absolute_import

import tflearn
from tflearn.data_utils import to_categorical, pad_sequences
from tflearn.datasets import imdb

# IMDB Dataset loading
train, test, _ = imdb.load_data(path='imdb.pkl', n_words=10000,
                                valid_portion=0.1)
trainX, trainY = train
testX, testY = test

# Data preprocessing
# Sequence padding
trainX = pad_sequences(trainX, maxlen=100, value=0.)
testX = pad_sequences(testX, maxlen=100, value=0.)
# Converting labels to binary vectors
trainY = to_categorical(trainY, nb_classes=2)
testY = to_categorical(testY, nb_classes=2)

# Network building
net = tflearn.input_data([None, 100])
net = tflearn.embedding(net, input_dim=10000, output_dim=128)
net = tflearn.lstm(net, 128, dropout=0.8)
net = tflearn.fully_connected(net, 2, activation='softmax')
net = tflearn.regression(net, optimizer='adam', learning_rate=0.001,
                         loss='categorical_crossentropy')

# Training
model = tflearn.DNN(net, tensorboard_verbose=0)
model.fit(trainX, trainY, validation_set=(testX, testY), show_metric=True,
          batch_size=32)

Training Step: 7040  | total loss: [1m[32m0.07188[0m[0m
| Adam | epoch: 010 | loss: 0.07188 - acc: 0.9835 | val_loss: 0.66640 - val_acc: 0.8084 -- iter: 22500/22500
Training Step: 7040  | total loss: [1m[32m0.07188[0m[0m
| Adam | epoch: 010 | loss: 0.07188 - acc: 0.9835 | val_loss: 0.66640 - val_acc: 0.8084 -- iter: 22500/22500
--


In [8]:
model.predict(testX)

[[0.6326836347579956, 0.367316335439682],
 [0.3748258948326111, 0.6251740455627441],
 [0.01972774788737297, 0.9802722334861755],
 [0.002958473516628146, 0.9970415234565735],
 [0.9891246557235718, 0.010875380598008633],
 [0.017686232924461365, 0.9823137521743774],
 [0.024001287296414375, 0.9759986400604248],
 [0.006551928352564573, 0.9934480786323547],
 [0.4850422143936157, 0.5149577856063843],
 [0.002300182357430458, 0.9976997971534729],
 [0.015558102168142796, 0.9844419360160828],
 [0.9972615242004395, 0.002738514682278037],
 [0.007078435271978378, 0.9929215312004089],
 [0.0022463174536824226, 0.9977536797523499],
 [0.7164676785469055, 0.2835322916507721],
 [0.9891040325164795, 0.01089593768119812],
 [0.9959967136383057, 0.0040033115074038506],
 [0.02902616560459137, 0.9709738492965698],
 [0.2994113266468048, 0.7005887031555176],
 [0.025919489562511444, 0.974080502986908],
 [0.964852511882782, 0.035147525370121],
 [0.005252168048173189, 0.9947478175163269],
 [0.9979396462440491, 0.002