# 神经机器翻译系统

用TensorFlow搭建seq2seq模型实现了一个简单的神经机器翻译系统，实现英语翻译为法语。Encoder使用双向LSTM。Decoder采用了attention机制。

使用多张计算图分别处理train，eval和infer，并分别在不同的session中进行训练和推断。参数共享用Saver。

增加了tensorboard可视化。

Encoder和Decoder使用了多层LSTM。decoder的initial state采用前向encoder state。

Decoder使用Beam search。

Train model中增加了dropout。

详见TensorFlow教程：https://tensorflow.google.cn/tutorials/seq2seq 和谷歌论文：https://arxiv.org/abs/1609.08144

### 导入包
检查TensorFlow版本和GPU情况

In [1]:
from distutils.version import LooseVersion
import warnings, os
import numpy as np
import tensorflow as tf
from tqdm import tqdm

# Check TensorFlow Version
print('TensorFlow Version: {}'.format(tf.__version__))

# Check for a GPU
if not tf.test.gpu_device_name():
    warnings.warn('No GPU found. Please use a GPU to train your neural network.')
else:
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))

  from ._conv import register_converters as _register_converters


TensorFlow Version: 1.8.0


  if sys.path[0] == '':


### 超参数设置

In [2]:
source_path = 'en-fr/small_vocab_en'
target_path = 'en-fr/small_vocab_fr'
checkpoint_path = './tmp-model.ckpt'
batch_size = 256
num_units = 64
num_layers = 3
max_gradient_norm = 5.0
learning_rate = 0.001
epoch = 6
beam_width = 3
keep_prob = 0.7

### 建立lookup table文件
sos句子开始。eos句子结束。

In [3]:
l=[]
with open(source_path, 'r', encoding='utf-8') as f:
    for line in f.readlines():
        l += line.split()
f.close()
unique_words_src = ['eos'] + list(set(l))

with open('en-fr/words_en', 'w', encoding='utf-8') as f:
    for word in unique_words_src:
        f.write(word + '\n')
f.close()

In [4]:
l=[]
with open(target_path, 'r', encoding='utf-8') as f:
    for line in f.readlines():
        l += line.split()
f.close()
unique_words_tar = ['sos'] + ['eos'] + list(set(l))
print(len(l), len(unique_words_tar))
#print ('rusty' in unique_words_en)

with open('en-fr/words_fr', 'w', encoding='utf-8') as f:
    for word in unique_words_tar:
        f.write(word + '\n')
f.close()

1961295 357


### 使用预训练词向量 FastText

In [5]:
embed_file_src = os.path.join('.', 'fasttext', 'wiki-news-300d-1M.vec')
embed_file_tar = os.path.join('.', 'fasttext', 'cc.fr.300.vec')
embed_size = 300

In [6]:
def get_coefs(word,*arr): return word, np.asarray(arr, dtype='float32')
embeddings_index_src = dict(get_coefs(*o.rstrip().rsplit(' ')) for o in open(embed_file_src, encoding = 'utf-8'))
embeddings_index_tar = dict(get_coefs(*o.rstrip().rsplit(' ')) for o in open(embed_file_tar, encoding = 'utf-8'))

#### 建立两种语言的embedding matrix

In [7]:
embedding_matrix_src = np.random.normal(size=(len(unique_words_src), embed_size), scale=0.01)
embedding_matrix_tar = np.random.normal(size=(len(unique_words_tar), embed_size), scale=0.01)

for i, word in enumerate(unique_words_src):
    embedding_vector = embeddings_index_src.get(word)
    if embedding_vector is not None: embedding_matrix_src[i] = embedding_vector
        
for i, word in enumerate(unique_words_tar):
    embedding_vector = embeddings_index_tar.get(word)
    if embedding_vector is not None: embedding_matrix_tar[i] = embedding_vector

### 生成lookup table函数

In [8]:
def BuildLookupTable(source_words_path, target_words_path):
    lookup_src = tf.contrib.lookup.index_table_from_file(source_words_path)
    lookup_tar = tf.contrib.lookup.index_table_from_file(target_words_path)
    lookup_translate = tf.contrib.lookup.index_to_string_table_from_file(target_words_path)
    return lookup_src, lookup_tar, lookup_translate

### 输入训练文本预处理函数
预处理source和target dataset。文本转成单词id。target开头加一个sos。分batch并pad。末尾用eos补足到最大长度。
这里不需要drop remainder，iterator会自动计算最后一批的样本量。但是后面不能再使用batch size

In [9]:
def BuildTrainDataset(source_path, target_path, src_eos_id, tar_eos_id):
    
    # source
    source_dataset = tf.data.TextLineDataset(source_path)
    source_dataset = source_dataset.map(lambda string: tf.string_split([string]).values)
    source_dataset = source_dataset.map(lambda words: (words, tf.size(words)))
    source_dataset = source_dataset.map(lambda words, size: (lookup_src.lookup(words), size))
    
    # target
    target_dataset = tf.data.TextLineDataset(target_path)
    target_dataset = target_dataset.map(lambda string: tf.string_split([tf.string_join([tf.constant('sos'), string], separator=' ')]).values)
    target_dataset = target_dataset.map(lambda words: (words, tf.size(words)))
    target_dataset = target_dataset.map(lambda words, size: (lookup_tar.lookup(words), size))
    
    # zip source and target
    source_target_dataset = tf.data.Dataset.zip((source_dataset, target_dataset))

    # batch and pad
    batched_dataset = source_target_dataset.padded_batch(
        batch_size,
        padded_shapes=((tf.TensorShape([None]),  # source vectors of unknown size
                        tf.TensorShape([])),     # size(source)
                       (tf.TensorShape([None]),  # target vectors of unknown size
                        tf.TensorShape([]))),    # size(target)
        padding_values=((src_eos_id,  # source vectors padded on the right with src_eos_id
                         0),          # size(source) -- unused
                        (tar_eos_id,  # target vectors padded on the right with tar_eos_id
                         0)))         # size(target) -- unused
    
    return batched_dataset

### Build the train model function
Input: batched and padded dataset iterator ((source, source_lengths), (target, target_lengths))

Output: command to run in train session

这里注意不能用list乘法构建list，会是同一个cell重复。list乘法对于不可变对象是复制值，而可变对象是复制引用。

In [10]:
def BuildTrainModel(train_iterator):
    ((source, source_lengths), (target, target_lengths)) = train_iterator.get_next()
    encoder_inputs = tf.transpose(source, [1,0]) # to time major
    decoder_inputs = tf.transpose(target, [1,0])
    decoder_outputs = tf.pad(decoder_inputs[1:], tf.constant([[0,1],[0,0]]), constant_values=tar_eos_id)

    shape = tf.shape(decoder_outputs)
    target_weights = tf.to_double(tf.where(tf.equal(decoder_outputs, tf.fill(shape, tar_eos_id)), tf.zeros(shape), tf.ones(shape)))
            
    embedding_encoder = tf.Variable(embedding_matrix_src, name='embedding_encoder')
    embedding_decoder = tf.Variable(embedding_matrix_tar, name='embedding_decoder')
    
    # Embedding layer
    encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, encoder_inputs)
    decoder_emb_inp = tf.nn.embedding_lookup(embedding_decoder, decoder_inputs)
    
    # Encoder
    # Construct forward and backward cells
    forward_cells = [tf.contrib.rnn.DropoutWrapper(
        tf.nn.rnn_cell.BasicLSTMCell(num_units), input_keep_prob=keep_prob, output_keep_prob=keep_prob) for _ in range(num_layers)]
    backward_cells = [tf.contrib.rnn.DropoutWrapper(
        tf.nn.rnn_cell.BasicLSTMCell(num_units), input_keep_prob=keep_prob, output_keep_prob=keep_prob) for _ in range(num_layers)]
    
    encoder_outputs, encoder_states_fw, encoder_states_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
        forward_cells, backward_cells, encoder_emb_inp, dtype=tf.float64, 
        sequence_length=source_lengths, time_major=True)
    #encoder_states: the final states, one tensor per layer, of the forward/backward rnn

    # Attention
    attention_states = tf.transpose(encoder_outputs, [1, 0, 2])
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units, attention_states, memory_sequence_length=source_lengths, dtype=tf.float64)
    decoder_cells = [tf.contrib.rnn.DropoutWrapper(
        tf.nn.rnn_cell.BasicLSTMCell(num_units), input_keep_prob=keep_prob, output_keep_prob=keep_prob) for _ in range(num_layers)]
    decoder_cell = tf.contrib.rnn.MultiRNNCell(decoder_cells)
    
    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism, attention_layer_size=num_units)
    initial_state = decoder_cell.zero_state(dtype=tf.float64, batch_size=tf.shape(encoder_inputs)[1])
    initial_state = initial_state.clone(
        cell_state = encoder_states_fw)
    
    # Projection layer on the top
    projection_layer = tf.layers.Dense(len(unique_words_tar), use_bias=False, name='projection')
    
    # Decoder for training
    helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, target_lengths, time_major=True)
    decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, initial_state, output_layer=projection_layer)
    outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder, output_time_major=True, impute_finished=True)
    logits = outputs.rnn_output
    
    # Loss
    crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=decoder_outputs, logits=logits)
    train_loss = tf.reduce_sum(crossent * target_weights)/ tf.to_double(tf.shape(encoder_inputs)[1])
    tf.summary.scalar('train_loss', train_loss)
    
    # Gradient
    params = tf.trainable_variables()
    gradients = tf.gradients(train_loss, params)
    clipped_gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm)
    
    # Optimization
    optimizer = tf.train.AdamOptimizer(learning_rate)
    update_step = optimizer.apply_gradients(zip(clipped_gradients, params))
    
    return train_loss, update_step

### 设置train graph

In [11]:
train_graph = tf.Graph()
with train_graph.as_default():
    
    # Build the lookup table
    lookup_src, lookup_tar, lookup_translate = BuildLookupTable('en-fr/words_en', 'en-fr/words_fr')
    
    # set the sos and eos
    src_eos_id=lookup_src.lookup(tf.constant('eos')) #0 in source vocab
    tar_sos_id=lookup_tar.lookup(tf.constant('sos')) #0 in target vocab
    tar_eos_id=lookup_tar.lookup(tf.constant('eos')) #1 in target vocab
    
    # Preprocess the text dataset
    batched_dataset = BuildTrainDataset(source_path, target_path, src_eos_id, tar_eos_id)
    
    # Build the train model
    train_iterator = batched_dataset.make_initializable_iterator()
    train_model = BuildTrainModel(train_iterator)
    initializer = tf.global_variables_initializer()
    table_initializer = tf.tables_initializer()
    train_saver = tf.train.Saver(max_to_keep=2)
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('./vis/train', train_graph)

### Run the train session to train the model and save the variables

In [12]:
train_sess = tf.Session(graph=train_graph)
train_sess.run(initializer)
train_sess.run(table_initializer)
train_sess.run(train_iterator.initializer)

#saver.restore(sess, './tmp-model.ckpt-11')
for i in tqdm(range(epoch)):
    train_sess.run(train_iterator.initializer)
    n_batch=0
    while True:
        try:
            summary, (cost, _) = train_sess.run([merged, train_model])
            train_writer.add_summary(summary)
            n_batch+=1
            print (n_batch)
            if n_batch % 10 == 0:
                print (cost)
        except tf.errors.OutOfRangeError:
            print (cost)
            break
    model_path = train_saver.save(train_sess, checkpoint_path, global_step=i+1)

  0%|                                                    | 0/6 [00:00<?, ?it/s]

1
2
3
4
5
6
7
8
9
10
71.16295964861769
11
12
13
14
15
16
17
18
19
20
61.72331034784902
21
22
23
24
25
26
27
28
29
30
58.19069562833698
31
32
33
34
35
36
37
38
39
40
56.979478704062174
41
42
43
44
45
46
47
48
49
50
52.38928705160461
51
52
53
54
55
56
57
58
59
60
50.57255799183794
61
62
63
64
65
66
67
68
69
70
45.60350089214758
71
72
73
74
75
76
77
78
79
80
43.35324462623079
81
82
83
84
85
86
87
88
89
90
40.344163592039436
91
92
93
94
95
96
97
98
99
100
37.76521949437898
101
102
103
104
105
106
107
108
109
110
36.97331796171849
111
112
113
114
115
116
117
118
119
120
36.373063294875614
121
122
123
124
125
126
127
128
129
130
34.967439948269316
131
132
133
134
135
136
137
138
139
140
34.8652553600875
141
142
143
144
145
146
147
148
149
150
33.9796331662217
151
152
153
154
155
156
157
158
159
160
32.75297392830049
161
162
163
164
165
166
167
168
169
170
31.502517118976996
171
172
173
174
175
176
177
178
179
180
30.025585961497104
181
182
183
184
185
186
187
188
189
190
30.166312162354615
1

 17%|██████▋                                 | 1/6 [30:28<2:32:22, 1828.46s/it]

1
2
3
4
5
6
7
8
9
10
16.088901281942046
11
12
13
14
15
16
17
18
19
20
16.17709101667645
21
22
23
24
25
26
27
28
29
30
15.490097231045866
31
32
33
34
35
36
37
38
39
40
15.625679672430618
41
42
43
44
45
46
47
48
49
50
15.086064536776464
51
52
53
54
55
56
57
58
59
60
15.934091400179373
61
62
63
64
65
66
67
68
69
70
14.568984653231569
71
72
73
74
75
76
77
78
79
80
15.185642203025992
81
82
83
84
85
86
87
88
89
90
15.039941326926836
91
92
93
94
95
96
97
98
99
100
14.53970222453043
101
102
103
104
105
106
107
108
109
110
14.616660389682643
111
112
113
114
115
116
117
118
119
120
14.488761314564897
121
122
123
124
125
126
127
128
129
130
14.937028744304232
131
132
133
134
135
136
137
138
139
140
14.884943223614263
141
142
143
144
145
146
147
148
149
150
15.325488909679882
151
152
153
154
155
156
157
158
159
160
14.417665190773718
161
162
163
164
165
166
167
168
169
170
14.871363178463161
171
172
173
174
175
176
177
178
179
180
13.622161354331649
181
182
183
184
185
186
187
188
189
190
13.56295

 33%|████████████▋                         | 2/6 [1:02:42<2:05:24, 1881.18s/it]

1
2
3
4
5
6
7
8
9
10
10.520289209446872
11
12
13
14
15
16
17
18
19
20
10.78357818287829
21
22
23
24
25
26
27
28
29
30
10.351307716517383
31
32
33
34
35
36
37
38
39
40
10.303158931052515
41
42
43
44
45
46
47
48
49
50
9.997888148431462
51
52
53
54
55
56
57
58
59
60
10.500125341634892
61
62
63
64
65
66
67
68
69
70
9.709786047549645
71
72
73
74
75
76
77
78
79
80
10.01612181501276
81
82
83
84
85
86
87
88
89
90
9.791030685895779
91
92
93
94
95
96
97
98
99
100
9.943471112864387
101
102
103
104
105
106
107
108
109
110
9.652438877670416
111
112
113
114
115
116
117
118
119
120
9.247871787651192
121
122
123
124
125
126
127
128
129
130
9.598787569640159
131
132
133
134
135
136
137
138
139
140
9.530764378288566
141
142
143
144
145
146
147
148
149
150
9.871210131510608
151
152
153
154
155
156
157
158
159
160
9.302098792058134
161
162
163
164
165
166
167
168
169
170
9.857060839842493
171
172
173
174
175
176
177
178
179
180
9.265565532406772
181
182
183
184
185
186
187
188
189
190
8.637747526069523
19

 50%|███████████████████                   | 3/6 [1:35:18<1:35:18, 1906.25s/it]

1
2
3
4
5
6
7
8
9
10
6.020228653425491
11
12
13
14
15
16
17
18
19
20
5.6087511098026415
21
22
23
24
25
26
27
28
29
30
5.661652070438629
31
32
33
34
35
36
37
38
39
40
5.481503597689633
41
42
43
44
45
46
47
48
49
50
5.85930319632252
51
52
53
54
55
56
57
58
59
60
5.331358933720232
61
62
63
64
65
66
67
68
69
70
5.00263880135082
71
72
73
74
75
76
77
78
79
80
5.534681519096724
81
82
83
84
85
86
87
88
89
90
5.259370476512393
91
92
93
94
95
96
97
98
99
100
5.0763383262641995
101
102
103
104
105
106
107
108
109
110
4.85708490305797
111
112
113
114
115
116
117
118
119
120
4.779967916421295
121
122
123
124
125
126
127
128
129
130
4.943532385455992
131
132
133
134
135
136
137
138
139
140
4.933598722979468
141
142
143
144
145
146
147
148
149
150
5.160010184376738
151
152
153
154
155
156
157
158
159
160
4.59609796811966
161
162
163
164
165
166
167
168
169
170
4.940333992787195
171
172
173
174
175
176
177
178
179
180
4.672982006578678
181
182
183
184
185
186
187
188
189
190
3.976719503850151
191
192


 67%|█████████████████████████▎            | 4/6 [2:06:45<1:03:22, 1901.37s/it]

1
2
3
4
5
6
7
8
9
10
2.727049236105876
11
12
13
14
15
16
17
18
19
20
2.89326743727791
21
22
23
24
25
26
27
28
29
30
2.6882483817681297
31
32
33
34
35
36
37
38
39
40
2.7099757866068073
41
42
43
44
45
46
47
48
49
50
2.9769458482946143
51
52
53
54
55
56
57
58
59
60
2.7620121947797722
61
62
63
64
65
66
67
68
69
70
2.3845014944688927
71
72
73
74
75
76
77
78
79
80
2.861998473701473
81
82
83
84
85
86
87
88
89
90
2.8206754675371055
91
92
93
94
95
96
97
98
99
100
2.7222350099234536
101
102
103
104
105
106
107
108
109
110
2.6148599133965105
111
112
113
114
115
116
117
118
119
120
2.441656007156396
121
122
123
124
125
126
127
128
129
130
2.896682936850471
131
132
133
134
135
136
137
138
139
140
2.7925443594268518
141
142
143
144
145
146
147
148
149
150
3.06307510903727
151
152
153
154
155
156
157
158
159
160
2.7016900658218432
161
162
163
164
165
166
167
168
169
170
2.9953851421221946
171
172
173
174
175
176
177
178
179
180
2.810399438510383
181
182
183
184
185
186
187
188
189
190
2.3426007488649

 83%|█████████████████████████████████▎      | 5/6 [2:38:50<31:46, 1906.11s/it]

1
2
3
4
5
6
7
8
9
10
2.107270368356595
11
12
13
14
15
16
17
18
19
20
2.158997189276194
21
22
23
24
25
26
27
28
29
30
1.9418684850828916
31
32
33
34
35
36
37
38
39
40
2.0920370466727976
41
42
43
44
45
46
47
48
49
50
2.0308499499165027
51
52
53
54
55
56
57
58
59
60
1.9982767652087365
61
62
63
64
65
66
67
68
69
70
1.8176824068989443
71
72
73
74
75
76
77
78
79
80
2.304824922824392
81
82
83
84
85
86
87
88
89
90
2.0447900617284365
91
92
93
94
95
96
97
98
99
100
2.290624638379529
101
102
103
104
105
106
107
108
109
110
2.0311662973942193
111
112
113
114
115
116
117
118
119
120
1.8861005369258028
121
122
123
124
125
126
127
128
129
130
2.2275693393525122
131
132
133
134
135
136
137
138
139
140
2.1641036804760345
141
142
143
144
145
146
147
148
149
150
2.3209252048892854
151
152
153
154
155
156
157
158
159
160
2.0368893295666908
161
162
163
164
165
166
167
168
169
170
2.2513293437724595
171
172
173
174
175
176
177
178
179
180
2.160973282232327
181
182
183
184
185
186
187
188
189
190
1.760193044

100%|████████████████████████████████████████| 6/6 [3:10:32<00:00, 1905.46s/it]


### 输入测试文本预处理函数
预处理用于infer的source dataset。文本转成单词id。分batch并pad。末尾用eos补足到最大长度。
这里不需要drop remainder，iterator会自动计算最后一批的样本量。但是后面不能再使用batch size

In [13]:
def BuildTestDataset(source_path, src_eos_id):
    
    # source
    source_dataset = tf.data.TextLineDataset(source_path)
    source_dataset = source_dataset.map(lambda string: tf.string_split([string]).values)
    source_dataset = source_dataset.map(lambda words: (words, tf.size(words)))
    source_dataset = source_dataset.map(lambda words, size: (lookup_src.lookup(words), size))

    # batch and pad
    batched_dataset = source_dataset.padded_batch(
        batch_size,
        padded_shapes=(tf.TensorShape([None]),  # source vectors of unknown size
                        tf.TensorShape([])),     # size(source)
        padding_values=(src_eos_id,  # source vectors padded on the right with src_eos_id
                         0))          # size(source) -- unused
    
    return batched_dataset

### Build the infer model function
Input: batched and padded dataset iterator ((source, source_lengths), (target, target_lengths))

Output: command to run in infer session

In [14]:
def BuildInferModel(test_iterator, tar_sos_id, tar_eos_id):
    (source, source_lengths) = test_iterator.get_next()
    encoder_inputs = tf.transpose(source, [1,0])
    
    embedding_encoder = tf.Variable(embedding_matrix_src, name='embedding_encoder')
    embedding_decoder = tf.Variable(embedding_matrix_tar, name='embedding_decoder')
    
    # Embedding layer
    encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, encoder_inputs)
    
    # Encoder
    # Construct forward and backward cells
    forward_cells = [tf.nn.rnn_cell.BasicLSTMCell(num_units) for _ in range(num_layers)]
    backward_cells = [tf.nn.rnn_cell.BasicLSTMCell(num_units) for _ in range(num_layers)]
    
    encoder_outputs, encoder_states_fw, encoder_states_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
        forward_cells, backward_cells, encoder_emb_inp, dtype=tf.float64, 
        sequence_length=source_lengths, time_major=True)
    #encoder_states: the final states, one tensor per layer, of the forward/backward rnn
    
    # Attention
    attention_states = tf.transpose(encoder_outputs, [1, 0, 2])
    tiled_attention_states = tf.contrib.seq2seq.tile_batch(attention_states, multiplier=beam_width)
    decoder_initial_state = tf.contrib.seq2seq.tile_batch(encoder_states_fw, multiplier=beam_width)
    tiled_source_lengths = tf.contrib.seq2seq.tile_batch(source_lengths, multiplier=beam_width)
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units, 
        memory=tiled_attention_states, 
        memory_sequence_length=tiled_source_lengths, 
        dtype=tf.float64)
    decoder_cells = [tf.nn.rnn_cell.BasicLSTMCell(num_units) for _ in range(num_layers)]
    decoder_cell = tf.contrib.rnn.MultiRNNCell(decoder_cells)
    
    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism, attention_layer_size=num_units)
    initial_state = decoder_cell.zero_state(dtype=tf.float64, batch_size=tf.shape(encoder_inputs)[1] * beam_width)
    initial_state = initial_state.clone(cell_state = decoder_initial_state)
    
    # Projection layer on the top
    projection_layer = tf.layers.Dense(len(unique_words_tar), use_bias=False, name='projection')
    
    # Decoder to infer using beam search
    start_tokens = tf.fill([tf.shape(encoder_inputs)[1]], tf.to_int32(tar_sos_id))
    end_token = tf.to_int32(tar_eos_id)
    
    decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell = decoder_cell, 
        embedding = embedding_decoder,
        start_tokens = start_tokens,
        end_token = end_token,
        initial_state = initial_state, 
        beam_width = beam_width,
        output_layer=projection_layer,
        length_penalty_weight=0.0) # Data type bug if length_penalty_weight=0.6?
    maximum_iterations = tf.round(tf.reduce_max(tiled_source_lengths) * 2)
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder, 
        maximum_iterations=maximum_iterations, 
        output_time_major=True, 
        impute_finished=False) # when using beam search, bug if impute_finished=True?
    
    beam_id = tf.to_int64(outputs.predicted_ids) # [time_step, batch_size, beam_width] Beams are ordered from best to worst
    translation_id = tf.unstack(beam_id, axis=2)[0] #[time_step, batch_size]
    #print (translation_id, beam_id)
    translation = lookup_translate.lookup(translation_id)
    
    return translation

### 设置infer graph

In [15]:
infer_graph = tf.Graph()
with infer_graph.as_default():
    
    # Build the lookup table
    lookup_src, lookup_tar, lookup_translate = BuildLookupTable('en-fr/words_en', 'en-fr/words_fr')
    
    # set the sos and eos
    src_eos_id=lookup_src.lookup(tf.constant('eos')) #0 in source vocab
    tar_sos_id=lookup_tar.lookup(tf.constant('sos')) #0 in target vocab
    tar_eos_id=lookup_tar.lookup(tf.constant('eos')) #1 in target vocab
    # Preprocess the text dataset
    batched_dataset = BuildTestDataset(source_path = source_path, src_eos_id = src_eos_id)
    
    # Build the train model
    test_iterator = batched_dataset.make_initializable_iterator()
    infer_model = BuildInferModel(test_iterator, tar_sos_id, tar_eos_id)
    infer_saver = tf.train.Saver()
    table_initializer = tf.tables_initializer()
    infer_writer = tf.summary.FileWriter('./vis/infer', infer_graph)

### Run the infer session to translate new sentences

In [16]:
infer_sess = tf.Session(graph=infer_graph)
infer_sess.run(table_initializer)
infer_saver.restore(infer_sess, model_path)
infer_sess.run(test_iterator.initializer)
n_batch=0
f = open('en-fr/trans_fr', 'w', encoding='utf-8')
while True:
    try:
        tar_sentences = infer_sess.run(infer_model)
        n_batch+=1
        tar_sentences = np.transpose(tar_sentences)
        for sentence in tar_sentences:
            for word in sentence:
                if word == 'eos': break
                f.write(word.decode('utf-8') + ' ')
            f.write('\n')
        print(n_batch)
    except tf.errors.OutOfRangeError:
        break
f.close()

# Close the session manually and release resources
train_sess.close()
infer_sess.close()

INFO:tensorflow:Restoring parameters from ./tmp-model.ckpt-6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262

saver = tf.train.Saver()
with tf.Session() as sess: #debug
    sess.run(tf.global_variables_initializer())
    tf.tables_initializer().run()
    sess.run(batched_iterator.initializer)
    n_batch=0
    ei, di, do = sess.run([encoder_inputs, decoder_inputs, decoder_outputs])
            #print (np.shape(ei), np.shape(di), np.shape(do))
    n_batch+=1
    print(n_batch)
    model_path = saver.save(sess, './tmp-model.ckpt')