In [None]:
# Install the h5py package (if it isn't installed yet) by running the following in your command-line interface:
#   pip install h5py
# or, from inside this notebook's kernel: %pip install h5py

In [64]:
import numpy as np
from ch10 import instantiate_seq2seq_model

In [65]:
try:
    import cPickle as pickle
except ImportError:
    import pickle

from io import open

# Character vocabularies / index tables and encoder-decoder size statistics,
# unpickled from ../data.
# NOTE(review): pickle.load can execute arbitrary code; only load these files
# if they come from a trusted source (e.g. your own training run).
with open("../data/characters_stats.pkl", "rb") as stats_file:
    (input_characters,
     target_characters,
     input_token_index,
     target_token_index) = pickle.load(stats_file)

with open("../data/encoder_decoder_stats.pkl", "rb") as stats_file:
    (num_encoder_tokens,
     num_decoder_tokens,
     max_encoder_seq_length,
     max_decoder_seq_length) = pickle.load(stats_file)

In [66]:
# create a dict to look up predicted tokens
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())

In [67]:
# Hidden-state size of the seq2seq model; presumably must match the value used
# when the weight files loaded below were saved (TODO confirm against training code).
LATENT_DIM = 256
# Only the encoder and decoder models are used for inference here; the first
# returned model is discarded.
_, encoder_model, decoder_model = instantiate_seq2seq_model(
    num_encoder_tokens,
    num_decoder_tokens,
    latent_dim=LATENT_DIM,
)

In [68]:
# Restore the trained encoder weights from disk.
encoder_weights_path = '../data/encoder_seq2seq.hd5'
encoder_model.load_weights(encoder_weights_path)

In [69]:
# Restore the trained decoder weights from disk.
decoder_weights_path = '../data/decoder_seq2seq.hd5'
decoder_model.load_weights(decoder_weights_path)

In [70]:
def decode_sequence(input_seq):
    """Greedily decode one encoded input sequence into a reply string.

    The encoder turns ``input_seq`` into the decoder's initial states. Decoding
    starts from the ``'\\t'`` start character and, at each step, feeds back the
    single most likely character (argmax) together with the updated LSTM
    states ``h`` and ``c``.

    Decoding stops when the ``'\\n'`` stop character is produced, or when the
    decoded text grows longer than ``max_decoder_seq_length``.

    Args:
        input_seq: one-hot encoded input for a batch of size 1, as built by
            ``response`` (shape ``(1, max_encoder_seq_length, num_encoder_tokens)``).

    Returns:
        The decoded string, including the trailing ``'\\n'`` when it was generated.
    """
    def one_hot_step(token_index):
        # A length-1 decoder input with a single active token.
        step = np.zeros((1, 1, num_decoder_tokens))
        step[0, 0, token_index] = 1.
        return step

    # The encoder output seeds the decoder's recurrent state.
    decoder_states = encoder_model.predict(input_seq)
    decoder_input = one_hot_step(target_token_index['\t'])

    decoded = ''
    while True:
        probabilities, state_h, state_c = decoder_model.predict(
            [decoder_input] + decoder_states)

        # Greedy choice: take the most probable token of the last time step.
        best_index = np.argmax(probabilities[0, -1, :])
        next_char = reverse_target_char_index[best_index]
        decoded += next_char

        if next_char == '\n' or len(decoded) > max_decoder_seq_length:
            break

        # Feed the chosen token and the new states into the next step.
        decoder_input = one_hot_step(best_index)
        decoder_states = [state_h, state_c]

    return decoded

In [71]:
def response(input_text):
    """Encode ``input_text``, decode a reply with the seq2seq model, and print it.

    The text is lowercased and one-hot encoded into an array of shape
    ``(1, max_encoder_seq_length, num_encoder_tokens)``.

    Two kinds of input are handled with a warning instead of an exception:
      * Characters missing from ``input_token_index`` (the encoder vocabulary)
        are skipped. Previously they raised ``KeyError``.
      * Text longer than ``max_encoder_seq_length`` characters is truncated.
        Previously it raised ``IndexError``.

    Args:
        input_text: the user's message.

    Returns:
        None. The decoded reply is printed.
    """
    import warnings

    lowered = input_text.lower()

    unknown = sorted({char for char in lowered if char not in input_token_index})
    if unknown:
        warnings.warn('Skipping characters not in the encoder vocabulary: %r' % (unknown,))
    known_chars = [char for char in lowered if char in input_token_index]

    # The encoder input has a fixed time dimension; anything longer cannot be encoded.
    if len(known_chars) > max_encoder_seq_length:
        warnings.warn('Input truncated to %d characters' % max_encoder_seq_length)
        known_chars = known_chars[:max_encoder_seq_length]

    input_seq = np.zeros((1, max_encoder_seq_length, num_encoder_tokens), dtype='float32')
    for t, char in enumerate(known_chars):
        input_seq[0, t, input_token_index[char]] = 1.

    decoded_sentence = decode_sequence(input_seq)
    print('Decoded sentence:', decoded_sentence)

In [72]:
# A simple greeting to sanity-check the decoder.
input_text = 'hello. how are you?'
response(input_text=input_text)

Decoded sentence: hello, mr. president.



In [73]:
# Mixed-case input is fine: response() lowercases the text before encoding it.
input_text = 'Do you cheer for football?'
response(input_text=input_text)

Decoded sentence: i don't know what that means a lot of enemenee.



In [None]:
# A follow-up question; each call is decoded independently of previous ones.
input_text = 'What about basketball?'
response(input_text=input_text)