In [1]:
# original code: https://github.com/fchollet/keras/blob/master/examples/lstm_seq2seq.py
from keras.models import Model
from keras.layers import Input, LSTM, Dense
import numpy as np

Using TensorFlow backend.


In [2]:
# Training hyper-parameters.
batch_size, n_epoch = 128, 50
latent_dim = 256  # number of LSTM units (also the size of the h/c state vectors)

# Data settings. The tab-separated sentence-pair file is not bundled; download the
# fra-eng set from http://www.manythings.org/anki/ and unpack it relative to the
# notebook's working directory.
n_samples = 10000  # maximum number of sentence pairs to read
data_path = 'fra-eng/fra.txt'

In [3]:
# Load up to n_samples (source, target) sentence pairs and build character vocabularies.
input_texts, target_texts = [], []
input_vocab, target_vocab = set(), set()

# Read as UTF-8 explicitly: the French side is full of accented characters and the
# platform default encoding is not guaranteed to be UTF-8. `with` closes the file.
with open(data_path, encoding='utf-8') as fh:
    # Drop blank lines (e.g. the trailing one) instead of assuming exactly one exists.
    lines = [line for line in fh.read().split('\n') if line.strip()]

for line in lines[:n_samples]:
    # Newer releases of the manythings files carry a third (attribution) column;
    # only the first two fields -- source and target sentence -- are used.
    in_txt, tg_txt = line.split('\t')[:2]
    tg_txt = '\t' + tg_txt + '\n'  # '\t' marks <start> and '\n' marks <end> for the decoder
    input_texts.append(in_txt)
    target_texts.append(tg_txt)
    input_vocab.update(in_txt)
    target_vocab.update(tg_txt)

assert input_texts, 'no sentence pairs were read from {}'.format(data_path)

# Sorting makes the char -> index assignment deterministic across runs.
input_vocab = sorted(input_vocab)
target_vocab = sorted(target_vocab)
n_input_vocab = len(input_vocab)
n_target_vocab = len(target_vocab)
max_encoder_seq_len = max(len(txt) for txt in input_texts)
max_decoder_seq_len = max(len(txt) for txt in target_texts)

In [4]:
# Report the corpus statistics; these values size the one-hot arrays.
print('\n'.join(
    '{} {}'.format(label, value)
    for label, value in (
        ('n_input_vocab', n_input_vocab),
        ('n_target_vocab', n_target_vocab),
        ('max enc len', max_encoder_seq_len),
        ('max dec len', max_decoder_seq_len),
    )
))

n_input_vocab 71
n_target_vocab 93
max enc len 16
max dec len 59


In [5]:
# Character <-> index lookup tables; an index is the character's position in the vocab list.
input_w2i = {ch: idx for idx, ch in enumerate(input_vocab)}
input_i2w = dict(enumerate(input_vocab))
target_w2i = {ch: idx for idx, ch in enumerate(target_vocab)}
target_i2w = dict(enumerate(target_vocab))

In [6]:
# One-hot encode every pair into (n_pairs, time_steps, vocab) arrays.
# float32 matches Keras' default floatx and halves memory versus NumPy's float64 default;
# the decoder arrays alone are n_pairs * max_decoder_seq_len * n_target_vocab elements each.
n_pairs = len(input_texts)
enc_input_data = np.zeros((n_pairs, max_encoder_seq_len, n_input_vocab), dtype='float32')
dec_input_data = np.zeros((n_pairs, max_decoder_seq_len, n_target_vocab), dtype='float32')

for i, (in_text, tg_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(in_text):
        enc_input_data[i, t, input_w2i[char]] = 1.
    for t, char in enumerate(tg_text):
        dec_input_data[i, t, target_w2i[char]] = 1.

# Teacher-forcing target: the decoder input shifted one step earlier in time
# (step t of the target is step t+1 of the input); the last time step stays all-zero.
dec_target_data = np.zeros_like(dec_input_data)
dec_target_data[:, :-1, :] = dec_input_data[:, 1:, :]

In [7]:
# Encoder: read the one-hot source sequence and keep only the final LSTM states,
# which become the decoder's initial state.
encoder_inputs = Input(shape=(None, n_input_vocab))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
encoder_states = [state_h, state_c]  # encoder_outputs is discarded

# Decoder: trained with teacher forcing -- it is fed the target sequence (dec_input_data)
# and learns to predict the same sequence one step ahead (dec_target_data).
decoder_inputs = Input(shape=(None, n_target_vocab))
decoder = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder(decoder_inputs, initial_state=encoder_states)
decoder_dense = Dense(n_target_vocab, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
# summary() prints the table itself and returns None; wrapping it in print() adds a stray "None".
model.summary()

# NOTE(review): validation_split holds out the *last* 20% of the arrays without shuffling,
# so validation is whatever pairs come last in the file (presumably the longest sentences
# if the file is length-ordered) -- confirm before trusting val metrics.
history = model.fit([enc_input_data, dec_input_data], dec_target_data,
                    batch_size=batch_size,
                    epochs=n_epoch,
                    validation_split=.2)

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_1 (InputLayer)             (None, None, 71)      0                                            
____________________________________________________________________________________________________
input_2 (InputLayer)             (None, None, 93)      0                                            
____________________________________________________________________________________________________
lstm_1 (LSTM)                    [(None, 256), (None,  335872      input_1[0][0]                    
____________________________________________________________________________________________________
lstm_2 (LSTM)                    [(None, None, 256), ( 358400      input_2[0][0]                    
                                                                   lstm_1[0][1]            

<keras.callbacks.History at 0x7f8391e66fd0>

In [8]:
# Persist the trained model (architecture, weights, optimizer state); the .h5 suffix selects HDF5.
model.save(filepath='seq2seq.h5')

# Inference

In [9]:
# Inference encoder: maps a source sequence to its final [h, c] LSTM states.
encoder_model = Model(encoder_inputs, encoder_states)

# Inference decoder: reuses the trained `decoder` LSTM and `decoder_dense` layers (shared
# weights), but takes the previous [h, c] states as explicit inputs so it can be stepped
# one character at a time.
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
# NOTE: rebinds decoder_outputs / state_h / state_c from the training cell.
decoder_outputs, state_h, state_c = decoder(decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states
)

# summary() prints the table and returns None, so print the label separately
# (passing summary() to print() produced "Encoder Model None").
print('Encoder Model')
encoder_model.summary()
print('Decoder Model')
decoder_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, None, 71)          0         
_________________________________________________________________
lstm_1 (LSTM)                [(None, 256), (None, 256) 335872    
Total params: 335,872
Trainable params: 335,872
Non-trainable params: 0
_________________________________________________________________
Encoder Model None
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_2 (InputLayer)             (None, None, 93)      0                                            
____________________________________________________________________________________________________
input_3 (InputLayer)             (None, 256)           0                                            
________

In [12]:
def decode_sequence(input_seq, max_len=None):
    """Greedily decode one encoded source sequence into a target string.

    Args:
        input_seq: a single one-hot encoded source sequence with a leading batch axis
            of 1, e.g. ``enc_input_data[i:i + 1]``.
        max_len: maximum number of characters to emit; defaults to
            ``max_decoder_seq_len``.

    Returns:
        The decoded characters, without the ``'\\n'`` end-of-sequence marker.
        Decoding stops when the model emits that marker or after ``max_len`` characters.
    """
    if max_len is None:
        max_len = max_decoder_seq_len

    states_value = encoder_model.predict(input_seq)

    # Seed the decoder with the '\t' start character; shape is (batch=1, steps=1, vocab).
    target_seq = np.zeros((1, 1, n_target_vocab))
    target_seq[0, 0, target_w2i['\t']] = 1.

    decoded_chars = []
    for _ in range(max_len):
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
        sampled_token_index = int(np.argmax(output_tokens[0, -1, :]))
        sampled_char = target_i2w[sampled_token_index]
        if sampled_char == '\n':
            break  # end marker: stop without leaking it into the result
        decoded_chars.append(sampled_char)

        # Feed the sampled character and the updated states back in for the next step.
        target_seq = np.zeros((1, 1, n_target_vocab))
        target_seq[0, 0, sampled_token_index] = 1.
        states_value = [h, c]

    return ''.join(decoded_chars)

# Show greedy decodes for the first few training pairs. Cap the count so a small
# n_samples cannot index past the end of input_texts.
n_show = min(100, len(input_texts))
for seq_idx in range(n_show):
    # Slicing (rather than indexing) keeps the batch axis: (1, max_encoder_seq_len, n_input_vocab).
    input_seq = enc_input_data[seq_idx: seq_idx + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('----')
    print('input sentence:', input_texts[seq_idx])
    print('decoded sentence:', decoded_sentence)

----
input sentence: Go.
decoded sentence: Va !

----
input sentence: Run!
decoded sentence: Dépêche-toi.

----
input sentence: Run!
decoded sentence: Dépêche-toi.

----
input sentence: Wow!
decoded sentence: Bien vuelle chonde !

----
input sentence: Fire!
decoded sentence: Oubliez !

----
input sentence: Help!
decoded sentence: Va donce tour !

----
input sentence: Jump.
decoded sentence: Tout la monde !

----
input sentence: Stop!
decoded sentence: Arrêtez !

----
input sentence: Stop!
decoded sentence: Arrêtez !

----
input sentence: Stop!
decoded sentence: Arrêtez !

----
input sentence: Wait!
decoded sentence: Attends !

----
input sentence: Wait!
decoded sentence: Attends !

----
input sentence: I see.
decoded sentence: Je vous ai sauvées.

----
input sentence: I try.
decoded sentence: Je vous ai sauvées.

----
input sentence: I won!
decoded sentence: Je l'ai vu.

----
input sentence: I won!
decoded sentence: Je l'ai vu.

----
input sentence: Oh no!
decoded sentence: Ou le tour 