In [5]:
import tensorflow as tf
import tensorlayer as tl
import numpy as np
from tensorlayer.cost import cross_entropy_seq, cross_entropy_seq_with_mask
from tqdm import tqdm
from sklearn.utils import shuffle
from data.twitter import data
from tensorlayer.models.seq2seq import Seq2seq
from tensorlayer.models.seq2seq_with_attention import Seq2seqLuongAttention
import os

# Load saved model and required vectorizer for data preprocessing

In [6]:
def initial_setup(data_corpus):
    metadata, idx_q, idx_a = data.load_data(PATH='data/{}/'.format(data_corpus))
    (trainX, trainY), (testX, testY), (validX, validY) = data.split_dataset(idx_q, idx_a)
    trainX = tl.prepro.remove_pad_sequences(trainX.tolist())
    trainY = tl.prepro.remove_pad_sequences(trainY.tolist())
    testX = tl.prepro.remove_pad_sequences(testX.tolist())
    testY = tl.prepro.remove_pad_sequences(testY.tolist())
    validX = tl.prepro.remove_pad_sequences(validX.tolist())
    validY = tl.prepro.remove_pad_sequences(validY.tolist())
    return metadata, trainX, trainY, testX, testY, validX, validY


data_corpus = "twitter"
#data preprocessing
metadata, trainX, trainY, testX, testY, validX, validY = initial_setup(data_corpus)


batch_size = 32
src_vocab_size = len(metadata['idx2w']) # 8002 (0~8001)
emb_dim = 1024

word2idx = metadata['w2idx']   # dict  word 2 index
idx2word = metadata['idx2w']   # list index 2 word

unk_id = word2idx['unk']   # 1
pad_id = word2idx['_']     # 0

start_id = src_vocab_size  # 8002
end_id = src_vocab_size + 1  # 8003

word2idx.update({'start_id': start_id})
word2idx.update({'end_id': end_id})
idx2word = idx2word + ['start_id', 'end_id']

src_vocab_size = tgt_vocab_size = src_vocab_size + 2

num_epochs = 1
vocabulary_size = src_vocab_size
decoder_seq_length = 20


def inference(seed, top_n):
    model_.eval()
    seed_id = [word2idx.get(w, unk_id) for w in seed.split(" ")]
    sentence_id = model_(inputs=[[seed_id]], seq_length=20, start_token=start_id, top_n = top_n)
    sentence = []
    for w_id in sentence_id[0]:
        w = idx2word[w_id]
        if w == 'end_id':
            break
        sentence = sentence + [w]
    return sentence

# create model object
model_ = Seq2seq(
        decoder_seq_length = decoder_seq_length,
        cell_enc=tf.keras.layers.GRUCell,
        cell_dec=tf.keras.layers.GRUCell,
        n_layer=3,
        n_units=256,
        embedding_layer=tl.layers.Embedding(vocabulary_size=vocabulary_size, embedding_size=emb_dim),
        )
    
optimizer = tf.optimizers.Adam(learning_rate=0.001)

# Load the pretrained model
load_weights = tl.files.load_npz(name='model.npz')
tl.files.assign_weights(load_weights, model_)


[TL] Embedding embedding_2: (8004, 1024)
[TL] RNN rnn_7: cell: GRUCell, n_units: 256
[TL] RNN rnn_8: cell: GRUCell, n_units: 256
[TL] RNN rnn_9: cell: GRUCell, n_units: 256
[TL] RNN rnn_10: cell: GRUCell, n_units: 256
[TL] RNN rnn_11: cell: GRUCell, n_units: 256
[TL] RNN rnn_12: cell: GRUCell, n_units: 256
[TL] Reshape reshape_4
[TL] Dense  dense_2: 8004 No Activation
[TL] Reshape reshape_5
[TL] Reshape reshape_6


[<tf.Variable 'UnreadVariable' shape=(1024, 768) dtype=float32, numpy=
 array([[ 0.2162723 , -0.06633563, -0.39152405, ..., -0.06113872,
          0.06669622,  0.13834399],
        [-0.05662856, -0.10133345, -0.03916978, ...,  0.18382408,
          0.07260212, -0.1275649 ],
        [ 0.08877234,  0.08594512,  0.32486606, ..., -0.3008156 ,
         -0.1724596 , -0.10252984],
        ...,
        [ 0.0380473 ,  0.5325081 , -0.09565561, ...,  0.04314809,
         -0.19246054,  0.07743181],
        [-0.36855173, -0.23458353,  0.12079325, ..., -0.27613902,
          0.01232626,  0.34530547],
        [-0.23640853,  0.32341713, -0.03062935, ...,  0.22687641,
          0.15495224,  0.12186057]], dtype=float32)>,
 <tf.Variable 'UnreadVariable' shape=(256, 768) dtype=float32, numpy=
 array([[-0.05579139,  0.3407311 ,  0.08869815, ...,  0.1716413 ,
          0.1292236 , -0.03037982],
        [-0.09042484,  0.04391455,  0.08342439, ..., -0.24368715,
          0.00197989, -0.34206912],
        [ 0.

# Test the model on some basic sentences

In [15]:
test_questions = ["How are you?", "what do you do ?"]

In [16]:
 for question in test_questions:
    print("Query >", question)
    top_n = 3 # number of top answers needed
    for i in range(top_n):
        sentence = inference(question, top_n)
        print(" >", ' '.join(sentence))

Query > the important ones are still here
 > i think we have to do this before we die
 > i agree i think its not
 > i know you can never ask for icloud unless they are
Query > who won the first presidential debate
 > trump won the debate
 > trump won the election debate
 > trump is going on the popcorn camp and not neither
