In [18]:
%load_ext autoreload
%autoreload 2
import sys

sys.path.append('../')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
import time
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

from seq2seq_tf2.batcher import train_batch_generator
from utils.data_loader import build_dataset,load_dataset,preprocess_sentence,load_test_dataset
from utils.wv_loader import load_embedding_matrix,load_vocab
from utils.config import *
from utils.params_utils import *
from utils.gpu_utils import config_gpu
from utils.plot_utils import plot_attention

from gensim.models.word2vec import LineSentence, Word2Vec
from tqdm import tqdm

import warnings

warnings.filterwarnings("ignore")

In [20]:
# 配置GPU
config_gpu()

1 Physical GPUs, 1 Logical GPUs


# 预处理数据

In [21]:
%%time
# build_dataset(train_data_path,test_data_path)

Wall time: 0 ns


# 参数设置  数据加载

In [64]:
# 加载vocab
vocab,reverse_vocab=load_vocab(vocab_path)

# 加载预训练权重
embedding_matrix=load_embedding_matrix()

params = {}
params["vocab_size"] = len(vocab)
params["embed_size"] = 500
params["enc_units"] = 512
params["attn_units"] = 512
params["dec_units"] = 512
params["batch_size"] = 32
params["epochs"] = 5
params["max_enc_len"] = 200
params["max_dec_len"] = 41 

# 加载数据集
dataset, steps_per_epoch = train_batch_generator(batch_size=32,
                                                 max_enc_len=params["max_enc_len"],
                                                 max_dec_len=params["max_dec_len"])
test_X = load_test_dataset(params["max_dec_len"])

max_enc_len 200
load_train_dataset返回到train_batch_generator里边的训练集训练数据train_X： [[31816   415   903 ... 31818 31818 31818]
 [31816   813 31819 ... 31818 31818 31818]
 [31816  1393    88 ...  3321  6567  2232]
 ...
 [31816   225   894 ... 31818 31818 31818]
 [31816 12684  3145 ... 31818 31818 31818]
 [31816  3275    75 ...   409     1     3]]
load_train_dataset返回到train_batch_generator里边的训练集测试数据train_Y： [[31816   326   391 ... 31818 31818 31818]
 [31816   326   391 ... 31818 31818 31818]
 [31816    80     8 ... 31818 31818 31818]
 ...
 [31816    32    23 ... 31818 31818 31818]
 [31816    32    23 ... 31818 31818 31818]
 [31816    32    23 ... 31818 31818 31818]]
load_train_dataset返回到train_batch_generator里边的训练集训练数据的形状： (82873, 200)
load_train_dataset返回到train_batch_generator里边的训练集测试数据的形状： (82873, 41)


# 模型训练

## 构建模型

In [65]:
from seq2seq_tf2.seq2seq_model import Seq2Seq

# 传入参数，构建初始化模型框架   checkpoint的保存的参数，加载之前必须构建好模型的框架
model=Seq2Seq(params)

## 读取训练好的模型

In [66]:
from utils.config import checkpoint_dir, checkpoint_prefix

In [67]:
ckpt = tf.train.Checkpoint(Seq2Seq=model)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=5)

ckpt.restore(ckpt_manager.latest_checkpoint)
print("Model restored")

Model restored


## 训练

In [68]:
# 构建优化器，定义损失函数
optimizer = tf.keras.optimizers.Adam(name='Adam',learning_rate=0.001)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

# 这里做了修改，即计算loss的时候将这两个部分的loss都进行取出，不计算
pad_index=vocab['<PAD>']
nuk_index=vocab['<UNK>']

def loss_function(real, pred):
    pad_mask = tf.math.equal(real, pad_index)  # 1
    nuk_mask = tf.math.equal(real, nuk_index)  # 1
    mask = tf.math.logical_not(tf.math.logical_or(pad_mask,nuk_mask))  # 0， 0
    
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    return tf.reduce_mean(loss_)

[@tf.function](https://zhuanlan.zhihu.com/p/67192636)

In [69]:
@tf.function
def train_step(inp, targ):
    loss = 0
    
    with tf.GradientTape() as tape:
        # 1. 构建encoder
        enc_output, enc_hidden = model.call_encoder(inp)
        # 2. 复制
        dec_hidden = enc_hidden
        # 3. <START> * BATCH_SIZE 
        dec_input = tf.expand_dims([vocab['<START>']] * params["batch_size"], 1)
        
        # 逐个预测序列
        predictions, _ = model(dec_input, dec_hidden, enc_output, targ)
        
        batch_loss = loss_function(targ[:, 1:], predictions)

        variables = model.encoder.trainable_variables + model.decoder.trainable_variables+ model.attention.trainable_variables
    
        gradients = tape.gradient(batch_loss, variables)

        optimizer.apply_gradients(zip(gradients, variables))

        return batch_loss

In [73]:
epochs = params["epochs"]
# 如果检查点存在，则恢复最新的检查点。
# if ckpt_manager.latest_checkpoint:
#     ckpt.restore(ckpt_manager.latest_checkpoint)
#     print ('Latest checkpoint restored!!')
    
for epoch in range(epochs):
    start = time.time()
    total_loss = 0

    for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
#         print('第{}个batch:   \n 输入训练数据： {} \n 输入标签： {}'.format(batch, inp.shape, targ.shape))
        batch_loss = train_step(inp, targ)
        total_loss += batch_loss

        if batch % 1 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                         batch,
                                                         batch_loss.numpy()))
    # saving (checkpoint) the model every 2 epochs
    if (epoch + 1) % 2 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                             ckpt_save_path))

    print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                      total_loss / steps_per_epoch))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

Epoch 1 Batch 0 Loss 0.7275
Epoch 1 Batch 1 Loss 0.7556
Epoch 1 Batch 2 Loss 0.7951
Epoch 1 Batch 3 Loss 0.7539
Epoch 1 Batch 4 Loss 1.1512
Epoch 1 Batch 5 Loss 0.8018
Epoch 1 Batch 6 Loss 0.8899
Epoch 1 Batch 7 Loss 0.8596
Epoch 1 Batch 8 Loss 0.7371
Epoch 1 Batch 9 Loss 0.6927
Epoch 1 Batch 10 Loss 0.7799
Epoch 1 Batch 11 Loss 0.6522
Epoch 1 Batch 12 Loss 0.8805
Epoch 1 Batch 13 Loss 0.9190
Epoch 1 Batch 14 Loss 1.0017
Epoch 1 Batch 15 Loss 0.6937
Epoch 1 Batch 16 Loss 0.7246
Epoch 1 Batch 17 Loss 0.8001
Epoch 1 Batch 18 Loss 0.6500
Epoch 1 Batch 19 Loss 0.9564
Epoch 1 Batch 20 Loss 0.9028
Epoch 1 Batch 21 Loss 0.8553
Epoch 1 Batch 22 Loss 1.0096
Epoch 1 Batch 23 Loss 0.9836
Epoch 1 Batch 24 Loss 0.8970
Epoch 1 Batch 25 Loss 0.9343
Epoch 1 Batch 26 Loss 0.9130
Epoch 1 Batch 27 Loss 0.7455
Epoch 1 Batch 28 Loss 0.9140
Epoch 1 Batch 29 Loss 0.9401
Epoch 1 Batch 30 Loss 0.9019
Epoch 1 Batch 31 Loss 0.7424
Epoch 1 Batch 32 Loss 0.9070
Epoch 1 Batch 33 Loss 0.9948
Epoch 1 Batch 34 Loss 0.

Epoch 1 Batch 277 Loss 0.9510
Epoch 1 Batch 278 Loss 1.0159
Epoch 1 Batch 279 Loss 0.7384
Epoch 1 Batch 280 Loss 0.9928
Epoch 1 Batch 281 Loss 0.7907
Epoch 1 Batch 282 Loss 1.0060
Epoch 1 Batch 283 Loss 1.0430
Epoch 1 Batch 284 Loss 0.8651
Epoch 1 Batch 285 Loss 0.8046
Epoch 1 Batch 286 Loss 1.0006
Epoch 1 Batch 287 Loss 0.6978
Epoch 1 Batch 288 Loss 1.0299
Epoch 1 Batch 289 Loss 0.8078
Epoch 1 Batch 290 Loss 0.7376
Epoch 1 Batch 291 Loss 0.9218
Epoch 1 Batch 292 Loss 0.8474
Epoch 1 Batch 293 Loss 0.7635
Epoch 1 Batch 294 Loss 0.9600
Epoch 1 Batch 295 Loss 0.8271
Epoch 1 Batch 296 Loss 0.9536
Epoch 1 Batch 297 Loss 0.9197
Epoch 1 Batch 298 Loss 0.9641
Epoch 1 Batch 299 Loss 0.7884
Epoch 1 Batch 300 Loss 0.6760
Epoch 1 Batch 301 Loss 0.6597
Epoch 1 Batch 302 Loss 0.7773
Epoch 1 Batch 303 Loss 0.8068
Epoch 1 Batch 304 Loss 0.8664
Epoch 1 Batch 305 Loss 0.9372
Epoch 1 Batch 306 Loss 0.8518
Epoch 1 Batch 307 Loss 0.7012
Epoch 1 Batch 308 Loss 0.9305
Epoch 1 Batch 309 Loss 1.0776
Epoch 1 Ba

Epoch 1 Batch 551 Loss 0.9388
Epoch 1 Batch 552 Loss 0.6740
Epoch 1 Batch 553 Loss 0.8205
Epoch 1 Batch 554 Loss 0.8150
Epoch 1 Batch 555 Loss 0.7889
Epoch 1 Batch 556 Loss 0.9143
Epoch 1 Batch 557 Loss 0.8073
Epoch 1 Batch 558 Loss 0.8273
Epoch 1 Batch 559 Loss 0.7772
Epoch 1 Batch 560 Loss 0.8171
Epoch 1 Batch 561 Loss 0.8687
Epoch 1 Batch 562 Loss 0.9253
Epoch 1 Batch 563 Loss 0.9018
Epoch 1 Batch 564 Loss 0.8203
Epoch 1 Batch 565 Loss 0.9274
Epoch 1 Batch 566 Loss 1.0673
Epoch 1 Batch 567 Loss 0.9125
Epoch 1 Batch 568 Loss 1.1722
Epoch 1 Batch 569 Loss 1.0352
Epoch 1 Batch 570 Loss 0.8585
Epoch 1 Batch 571 Loss 0.8840
Epoch 1 Batch 572 Loss 0.8432
Epoch 1 Batch 573 Loss 0.9771
Epoch 1 Batch 574 Loss 0.9469
Epoch 1 Batch 575 Loss 0.8176
Epoch 1 Batch 576 Loss 1.1008
Epoch 1 Batch 577 Loss 0.8849
Epoch 1 Batch 578 Loss 0.8796
Epoch 1 Batch 579 Loss 0.9519
Epoch 1 Batch 580 Loss 0.9949
Epoch 1 Batch 581 Loss 1.0187
Epoch 1 Batch 582 Loss 0.9292
Epoch 1 Batch 583 Loss 0.8192
Epoch 1 Ba

Epoch 1 Batch 825 Loss 1.1490
Epoch 1 Batch 826 Loss 0.7939
Epoch 1 Batch 827 Loss 0.9111
Epoch 1 Batch 828 Loss 0.9998
Epoch 1 Batch 829 Loss 1.0661
Epoch 1 Batch 830 Loss 1.0174
Epoch 1 Batch 831 Loss 1.1186
Epoch 1 Batch 832 Loss 0.9854
Epoch 1 Batch 833 Loss 0.9676
Epoch 1 Batch 834 Loss 1.0047
Epoch 1 Batch 835 Loss 0.7514
Epoch 1 Batch 836 Loss 1.0771
Epoch 1 Batch 837 Loss 0.8091
Epoch 1 Batch 838 Loss 1.0204
Epoch 1 Batch 839 Loss 0.8639
Epoch 1 Batch 840 Loss 0.9451
Epoch 1 Batch 841 Loss 0.8551
Epoch 1 Batch 842 Loss 0.9294
Epoch 1 Batch 843 Loss 0.7970
Epoch 1 Batch 844 Loss 0.8807
Epoch 1 Batch 845 Loss 0.8660
Epoch 1 Batch 846 Loss 0.9267
Epoch 1 Batch 847 Loss 1.0449
Epoch 1 Batch 848 Loss 0.7321
Epoch 1 Batch 849 Loss 0.7741
Epoch 1 Batch 850 Loss 0.9878
Epoch 1 Batch 851 Loss 0.9402
Epoch 1 Batch 852 Loss 0.8678
Epoch 1 Batch 853 Loss 1.0991
Epoch 1 Batch 854 Loss 0.9206
Epoch 1 Batch 855 Loss 0.9950
Epoch 1 Batch 856 Loss 0.9520
Epoch 1 Batch 857 Loss 0.7868
Epoch 1 Ba

Epoch 1 Batch 1095 Loss 0.9831
Epoch 1 Batch 1096 Loss 0.8544
Epoch 1 Batch 1097 Loss 1.1044
Epoch 1 Batch 1098 Loss 1.1040
Epoch 1 Batch 1099 Loss 0.9273
Epoch 1 Batch 1100 Loss 1.0553
Epoch 1 Batch 1101 Loss 1.0157
Epoch 1 Batch 1102 Loss 0.9646
Epoch 1 Batch 1103 Loss 0.8534
Epoch 1 Batch 1104 Loss 1.2524
Epoch 1 Batch 1105 Loss 0.8636
Epoch 1 Batch 1106 Loss 1.0627
Epoch 1 Batch 1107 Loss 0.6658
Epoch 1 Batch 1108 Loss 0.6959
Epoch 1 Batch 1109 Loss 1.0721
Epoch 1 Batch 1110 Loss 0.7938
Epoch 1 Batch 1111 Loss 1.1080
Epoch 1 Batch 1112 Loss 0.9887
Epoch 1 Batch 1113 Loss 1.0119
Epoch 1 Batch 1114 Loss 1.0932
Epoch 1 Batch 1115 Loss 0.7979
Epoch 1 Batch 1116 Loss 1.0075
Epoch 1 Batch 1117 Loss 0.8766
Epoch 1 Batch 1118 Loss 1.1590
Epoch 1 Batch 1119 Loss 0.9234
Epoch 1 Batch 1120 Loss 0.9386
Epoch 1 Batch 1121 Loss 0.9387
Epoch 1 Batch 1122 Loss 0.9656
Epoch 1 Batch 1123 Loss 0.9544
Epoch 1 Batch 1124 Loss 0.8702
Epoch 1 Batch 1125 Loss 0.7672
Epoch 1 Batch 1126 Loss 1.0336
Epoch 1 

Epoch 1 Batch 1360 Loss 0.8536
Epoch 1 Batch 1361 Loss 0.9692
Epoch 1 Batch 1362 Loss 1.1149
Epoch 1 Batch 1363 Loss 1.0488
Epoch 1 Batch 1364 Loss 0.9073
Epoch 1 Batch 1365 Loss 1.0296
Epoch 1 Batch 1366 Loss 0.8180
Epoch 1 Batch 1367 Loss 0.9119
Epoch 1 Batch 1368 Loss 1.0547
Epoch 1 Batch 1369 Loss 0.9708
Epoch 1 Batch 1370 Loss 0.7114
Epoch 1 Batch 1371 Loss 0.7076
Epoch 1 Batch 1372 Loss 0.9346
Epoch 1 Batch 1373 Loss 0.8180
Epoch 1 Batch 1374 Loss 0.8897
Epoch 1 Batch 1375 Loss 0.7849
Epoch 1 Batch 1376 Loss 0.7917
Epoch 1 Batch 1377 Loss 0.8902
Epoch 1 Batch 1378 Loss 1.0540
Epoch 1 Batch 1379 Loss 1.1474
Epoch 1 Batch 1380 Loss 0.9359
Epoch 1 Batch 1381 Loss 0.9616
Epoch 1 Batch 1382 Loss 0.9207
Epoch 1 Batch 1383 Loss 0.8000
Epoch 1 Batch 1384 Loss 0.8976
Epoch 1 Batch 1385 Loss 0.9403
Epoch 1 Batch 1386 Loss 0.9644
Epoch 1 Batch 1387 Loss 0.8345
Epoch 1 Batch 1388 Loss 0.7283
Epoch 1 Batch 1389 Loss 0.8566
Epoch 1 Batch 1390 Loss 0.8594
Epoch 1 Batch 1391 Loss 0.8447
Epoch 1 

Epoch 1 Batch 1625 Loss 0.9708
Epoch 1 Batch 1626 Loss 0.8552
Epoch 1 Batch 1627 Loss 1.0164
Epoch 1 Batch 1628 Loss 0.9041
Epoch 1 Batch 1629 Loss 0.9254
Epoch 1 Batch 1630 Loss 0.9340
Epoch 1 Batch 1631 Loss 0.9488
Epoch 1 Batch 1632 Loss 0.9559
Epoch 1 Batch 1633 Loss 0.8573
Epoch 1 Batch 1634 Loss 0.9528
Epoch 1 Batch 1635 Loss 1.0162
Epoch 1 Batch 1636 Loss 0.9615
Epoch 1 Batch 1637 Loss 1.0081
Epoch 1 Batch 1638 Loss 1.2133
Epoch 1 Batch 1639 Loss 0.9501
Epoch 1 Batch 1640 Loss 0.8603
Epoch 1 Batch 1641 Loss 0.9037
Epoch 1 Batch 1642 Loss 1.1091
Epoch 1 Batch 1643 Loss 0.9231
Epoch 1 Batch 1644 Loss 1.0154
Epoch 1 Batch 1645 Loss 0.8274
Epoch 1 Batch 1646 Loss 0.9644
Epoch 1 Batch 1647 Loss 1.0111
Epoch 1 Batch 1648 Loss 0.9704
Epoch 1 Batch 1649 Loss 0.9386
Epoch 1 Batch 1650 Loss 1.0428
Epoch 1 Batch 1651 Loss 1.1065
Epoch 1 Batch 1652 Loss 0.9685
Epoch 1 Batch 1653 Loss 0.8484
Epoch 1 Batch 1654 Loss 0.9541
Epoch 1 Batch 1655 Loss 1.1601
Epoch 1 Batch 1656 Loss 0.9992
Epoch 1 

Epoch 1 Batch 1890 Loss 0.7992
Epoch 1 Batch 1891 Loss 0.9578
Epoch 1 Batch 1892 Loss 1.0512
Epoch 1 Batch 1893 Loss 1.0352
Epoch 1 Batch 1894 Loss 1.1564
Epoch 1 Batch 1895 Loss 0.7373
Epoch 1 Batch 1896 Loss 1.0410
Epoch 1 Batch 1897 Loss 0.9133
Epoch 1 Batch 1898 Loss 0.9281
Epoch 1 Batch 1899 Loss 1.0627
Epoch 1 Batch 1900 Loss 0.8343
Epoch 1 Batch 1901 Loss 1.0404
Epoch 1 Batch 1902 Loss 0.7364
Epoch 1 Batch 1903 Loss 0.9794
Epoch 1 Batch 1904 Loss 0.8312
Epoch 1 Batch 1905 Loss 1.2267
Epoch 1 Batch 1906 Loss 0.9683
Epoch 1 Batch 1907 Loss 0.8255
Epoch 1 Batch 1908 Loss 0.9284
Epoch 1 Batch 1909 Loss 1.1617
Epoch 1 Batch 1910 Loss 1.0739
Epoch 1 Batch 1911 Loss 0.9392
Epoch 1 Batch 1912 Loss 1.0022
Epoch 1 Batch 1913 Loss 0.9622
Epoch 1 Batch 1914 Loss 1.0287
Epoch 1 Batch 1915 Loss 0.8078
Epoch 1 Batch 1916 Loss 0.9795
Epoch 1 Batch 1917 Loss 0.8821
Epoch 1 Batch 1918 Loss 0.9543
Epoch 1 Batch 1919 Loss 0.9935
Epoch 1 Batch 1920 Loss 1.3013
Epoch 1 Batch 1921 Loss 0.9995
Epoch 1 

Epoch 1 Batch 2155 Loss 0.7105
Epoch 1 Batch 2156 Loss 0.9611
Epoch 1 Batch 2157 Loss 0.7394
Epoch 1 Batch 2158 Loss 0.8613
Epoch 1 Batch 2159 Loss 1.1283
Epoch 1 Batch 2160 Loss 0.8348
Epoch 1 Batch 2161 Loss 1.0238
Epoch 1 Batch 2162 Loss 0.9496
Epoch 1 Batch 2163 Loss 0.9995
Epoch 1 Batch 2164 Loss 0.8916
Epoch 1 Batch 2165 Loss 0.9198
Epoch 1 Batch 2166 Loss 0.7738
Epoch 1 Batch 2167 Loss 0.8763
Epoch 1 Batch 2168 Loss 0.8287
Epoch 1 Batch 2169 Loss 0.7469
Epoch 1 Batch 2170 Loss 0.9489
Epoch 1 Batch 2171 Loss 0.8101
Epoch 1 Batch 2172 Loss 0.9008
Epoch 1 Batch 2173 Loss 0.9207
Epoch 1 Batch 2174 Loss 0.9031
Epoch 1 Batch 2175 Loss 0.9473
Epoch 1 Batch 2176 Loss 1.2229
Epoch 1 Batch 2177 Loss 1.1129
Epoch 1 Batch 2178 Loss 0.9330
Epoch 1 Batch 2179 Loss 0.9319
Epoch 1 Batch 2180 Loss 1.1033
Epoch 1 Batch 2181 Loss 0.8760
Epoch 1 Batch 2182 Loss 0.8292
Epoch 1 Batch 2183 Loss 1.1210
Epoch 1 Batch 2184 Loss 0.7225
Epoch 1 Batch 2185 Loss 0.9643
Epoch 1 Batch 2186 Loss 0.8924
Epoch 1 

Epoch 1 Batch 2420 Loss 0.9319
Epoch 1 Batch 2421 Loss 0.7659
Epoch 1 Batch 2422 Loss 1.0614
Epoch 1 Batch 2423 Loss 0.9844
Epoch 1 Batch 2424 Loss 0.8317
Epoch 1 Batch 2425 Loss 1.0684
Epoch 1 Batch 2426 Loss 1.0954
Epoch 1 Batch 2427 Loss 1.0728
Epoch 1 Batch 2428 Loss 0.8780
Epoch 1 Batch 2429 Loss 0.9096
Epoch 1 Batch 2430 Loss 0.9847
Epoch 1 Batch 2431 Loss 0.9570
Epoch 1 Batch 2432 Loss 1.0934
Epoch 1 Batch 2433 Loss 0.7019
Epoch 1 Batch 2434 Loss 0.7767
Epoch 1 Batch 2435 Loss 0.7556
Epoch 1 Batch 2436 Loss 0.8304
Epoch 1 Batch 2437 Loss 1.0913
Epoch 1 Batch 2438 Loss 1.0961
Epoch 1 Batch 2439 Loss 0.7741
Epoch 1 Batch 2440 Loss 1.1380
Epoch 1 Batch 2441 Loss 1.1333
Epoch 1 Batch 2442 Loss 0.9665
Epoch 1 Batch 2443 Loss 0.7963
Epoch 1 Batch 2444 Loss 0.9862
Epoch 1 Batch 2445 Loss 0.8301
Epoch 1 Batch 2446 Loss 0.9506
Epoch 1 Batch 2447 Loss 0.8574
Epoch 1 Batch 2448 Loss 0.8507
Epoch 1 Batch 2449 Loss 1.0471
Epoch 1 Batch 2450 Loss 1.1104
Epoch 1 Batch 2451 Loss 0.7832
Epoch 1 

KeyboardInterrupt: 

# 载入模型

In [74]:
# 如果检查点存在，则恢复最新的检查点。
ckpt.restore(ckpt_manager.latest_checkpoint)
print("Model restored")

Model restored


# 预测

In [75]:
max_length_targ = 41
max_length_inp = 200
units = 512

In [76]:
def evaluate(model,inputs):
    attention_plot = np.zeros((max_length_targ, max_length_inp))
    
    inputs = tf.convert_to_tensor(inputs)

    result = ''
    
    hidden = [tf.zeros((1, units))]
    enc_output, enc_hidden = model.encoder(inputs, hidden)

    dec_hidden = enc_hidden
    
    dec_input = tf.expand_dims([vocab['<START>']], 0)
    
    context_vector, _ = model.attention(dec_hidden, enc_output)

    for t in range(max_length_targ):
        # max_length_targ：要预测的这句话的最大的长度，如果是40，就会执行40个循环
        # 要么遇到结尾符，要么运行至整个循环结束
        
        context_vector, attention_weights = model.attention(dec_hidden, enc_output)
        
        # 预测的时候也是一样，拿到model以后，一步一步的进行decode
        # 这里输入三个参数之后，得到结果中依旧会有一个dec_hidden,这是这个时间步的隐藏层的输出
        # 而第一次传进去的dec_hidden是encoder层输出的隐藏层信息（输入为<START>时）
        # 再往后就是decoder层中的当前进行预测的时间步的上一个时间步输出的隐藏状态
        # 实现了dec_hidden这个变量的复用，就是实现了每循环一次就对dec_hidden进行更新，下一次循环
        # 的时候，把更新后的dec_hidden传进去
        predictions, dec_hidden = model.decoder(dec_input,
                                         dec_hidden,
                                         enc_output,
                                         context_vector)

        # storing the attention weights to plot later on
        attention_weights = tf.reshape(attention_weights, (-1, ))
        
        attention_plot[t] = attention_weights.numpy()
        # 拿到预测结果之后来取概率值最大的ID
        predicted_id = tf.argmax(predictions[0]).numpy()

        result += reverse_vocab[predicted_id] + ' '
        # 如果概率值最大的ID对应的是<STOP>，表示到达句尾，就直接返回这句话
        if reverse_vocab[predicted_id] == '<STOP>':
            return result, sentence, attention_plot

        # the predicted ID is fed back into the model
        dec_input = tf.expand_dims([predicted_id], 0)

    return result, sentence, attention_plot

In [77]:
def translate(sentence):
    sentence = preprocess_sentence(sentence,max_length_inp,vocab)
    
    result, sentence, attention_plot = evaluate(model,sentence)

    print('Input: %s' % (sentence))
    print('Predicted translation: {}'.format(result))

    attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]
    plot_attention(attention_plot, sentence.split(' '), result.split(' '))

## 单句预测

In [95]:
# sentence='北京 汽车 BJ 20 自动挡 最低 配 <UNK> 速 续航 技师说'

# translate(sentence)

## 批量预测

In [79]:
def batch_predict(inps):
    """
    这里的输入就是一个batch_size大小的输入，比如说有32句话，句子长度为200，那么这里就是32*200为
    大小的输入
    """
    # 判断输入长度
    batch_size=len(inps)
    # 开辟结果存储list
    preidicts=[''] * batch_size
    
    inps = tf.convert_to_tensor(inps)
    # 0. 初始化隐藏层输入
    hidden = [tf.zeros((batch_size, units))]
    # 1. 构建encoder
    enc_output, enc_hidden = model.encoder(inps, hidden)
    # 2. 复制
    dec_hidden = enc_hidden
    # 3. <START> * BATCH_SIZE   为传进来的batch_size大小的句子集填充<START>
    # 训练的时候也一样，会初始化很多的<START>
    dec_input = tf.expand_dims([vocab['<START>']] * batch_size, 1)
    
    context_vector, _ = model.attention(dec_hidden, enc_output)
    # Teacher forcing - feeding the target as the next input
    for t in range(max_length_targ):
        # 计算上下文
        context_vector, attention_weights = model.attention(dec_hidden, enc_output)
        # 单步预测
        # 拿到一个batch所有的<START>之后，全部输入到decoder里边
        predictions, dec_hidden = model.decoder(dec_input,
                                         dec_hidden,
                                         enc_output,
                                         context_vector)
        
        # id转换 贪婪搜索  拿到预测的结果，取预测结果的概率最大值
        # 这里的axis=1表示横向取最大值，相当于每一个句子里边取概率最大的那一个词
        # 这里由于是一个batch一个batch的预测，一次预测的是32个句子，则第一步输入32个<START>
        # 这里就会得到32个概率最高的词对应的index
        predicted_ids = tf.argmax(predictions,axis=1).numpy()
        
        # 将这里得到的32index分别赋值到不同的句子里边去，用一个字典来保存32个句子
        for index,predicted_id in enumerate(predicted_ids):
            preidicts[index]+= reverse_vocab[predicted_id] + ' '
            # 这里就是把预测出来的index对应的词放到字典中来
        
        # using teacher forcing
        dec_input = tf.expand_dims(predicted_ids, 1)

    results=[]  # 最后返回结果
    for preidict in preidicts:
        # 去掉句子前后空格
        preidict=preidict.strip()
        # 句子小于max len就结束了 截断
        if '<STOP>' in preidict:
            # 截断stop
            preidict=preidict[:preidict.index('<STOP>')]
        # 保存结果
        results.append(preidict)
    return results

In [80]:
from tqdm import tqdm
import math

In [81]:
# 这里是预测结果的代码，就是将要预测的data_X(测试集，有20000个),和batch大小传进去
# 进行一个batch一个batch的预测
# 将预测出来的结果拼接在一起，预测结束后会拿到20000个句子
def model_predict(data_X,batch_size):
    # 存储结果
    results=[]
    # 样本数量
    sample_size=len(data_X)
    # batch 操作轮数 math.ceil向上取整 小数 +1
    # 因为最后一个batch可能不足一个batch size 大小 ,但是依然需要计算  
    steps_epoch = math.ceil(sample_size/batch_size)
    # [0,steps_epoch)
    for i in tqdm(range(steps_epoch)):
        batch_data = data_X[i*batch_size:(i+1)*batch_size]
        results+=batch_predict(batch_data)
    return results

In [82]:
%%time
results=model_predict(test_X,batch_size=32)

100%|████████████████████████████████████████████████████████████████████████████████| 625/625 [05:01<00:00,  2.07it/s]

Wall time: 5min 1s





# 构建提交结果文件

In [83]:
# 读入测试集数据
test_df=pd.read_csv(test_data_path)
test_df.head()

Unnamed: 0,QID,Brand,Model,Question,Dialogue
0,Q1,大众(进口),高尔夫(进口),我的帕萨特烧机油怎么办怎么办？,技师说：你好，请问你的车跑了多少公里了，如果在保修期内，可以到当地的4店里面进行检查维修。如...
1,Q2,一汽-大众奥迪,奥迪A6,修一下多少钱是换还是修,技师说：你好师傅！抛光处理一下就好了！50元左右就好了，希望能够帮到你！祝你生活愉快！
2,Q3,上汽大众,帕萨特,帕萨特领域 喇叭坏了 店里说方向盘里线坏了 换一根两三百不等 感觉太贵,技师说：你好，气囊油丝坏了吗，这个价格不贵。可以更换。
3,Q4,南京菲亚特,派力奥,发动机漏气会有什么征兆？,技师说：你好！一：发动机没力，并伴有“啪啪”的漏气声音。二：发动机没力，并伴有排气管冒黑烟。...
4,Q5,东风本田,思铂睿,请问 那天右后胎扎了订，补了胎后跑高速80多开始有点抖，110时速以上抖动明显，以为是未做动...,技师说：你好师傅！可能前轮平衡快脱落或者不平衡造成的！建议前轮做一下动平衡就好了！希望能够帮...


In [84]:
# 判断一下是否有空值   没啥实际的作用看着
for idx,result in enumerate(results):
    if result=='':print(idx)

In [85]:
# 赋值结果  将前边得到结果赋值给预测的这一列
test_df['Prediction'] = results
#　提取ID和预测结果两列  再从上边导入的测试集中拿到QID这一列
test_df = test_df[['QID','Prediction']]

In [86]:
test_df

Unnamed: 0,QID,Prediction
0,Q1,烧 机油 ， 需要 检查 ， 维修 店 进行 维修 ， 检查 机油 消耗 过大 活塞环 气门...
1,Q2,师傅 ， 抛光 处理 一下 ！
2,Q3,气囊 游丝 ， 价格 不 贵 ， 更换 气囊 。
3,Q4,分析 发动机 无 反应 ， 遇到 噗 噗声 噗 噗声 流水 生 属于 正常 现象 ， 白烟 ...
4,Q5,客户 解释 ， 轮胎 动平衡 问题 ， 轮胎 动平衡 问题 ， 轮胎 动平衡 问题 ， 轮胎...
...,...,...
19995,Q19996,， 进气 压力 传感器 进气 VVT VVT 链轮 都 会 出现 ， 进气 VVT VVT ...
19996,Q19997,原厂 配件 汽车厂家 。
19997,Q19998,车辆 不要 水洗 ， 拆掉 电瓶 负极 拆掉 避免 电瓶 亏损 ， 不用 经常 跑跑 高速 ...
19998,Q19999,砂纸 进行 焊接 一层 垫 之后 ， 前轮 压着 一点 深 一点 。


## 结果处理

In [89]:
# 最后结果的补救措施，就是把结果中，！。去掉，把所有的空格删掉
def submit_proc(sentence):
    sentence=sentence.lstrip(' ，！。')
    sentence=sentence.replace(' ','')
    if sentence=='':
        sentence='随时联系'
    return sentence

In [90]:
test_df['Prediction']=test_df['Prediction'].apply(submit_proc)  # 进行一下预处理

In [91]:
test_df.head()

Unnamed: 0,QID,Prediction
0,Q1,烧机油，需要检查，维修店进行维修，检查机油消耗过大活塞环气门间隙过大气门气门间隙过大气门气门...
1,Q2,师傅，抛光处理一下！
2,Q3,气囊游丝，价格不贵，更换气囊。
3,Q4,分析发动机无反应，遇到噗噗声噗噗声流水生属于正常现象，白烟，多数发动机漏气，需要检查发动机缸...
4,Q5,客户解释，轮胎动平衡问题，轮胎动平衡问题，轮胎动平衡问题，轮胎动平衡问题，轮胎动平衡问题，轮...


# 保存结果

In [92]:
from utils.file_utils import get_result_filename

In [94]:
# 获取结果存储路径  随机生成函数名的方法，防止多次生成结果，会搞混，这里在文件名中加上时间戳，batch大小，
# 训练了多少轮，最大长度，embedding_size这些信息，这样的话好处就是，经过不断地修改，发现结果一直不如
# 之前的某一轮的结果好，这样的话就可以直接找到，并且看到当时的参数设置信息
result_save_path = get_result_filename(params["batch_size"],params["epochs"] , params["max_enc_len"], params["embed_size"], commit='_4_1_submit_proc_add_masks_loss_seq2seq_code')