In [1]:
import os
import sys
import time
from datetime import timedelta

import numpy as np
import tensorflow as tf
from sklearn import metrics

from cnn_model import TCNNConfig, TextCNNwithAttention
from data.cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab

# Dataset location: the train / test / validation splits and the vocabulary file
# all live directly under base_dir.
base_dir = './data'
train_dir, test_dir, val_dir, vocab_dir = (
    os.path.join(base_dir, file_name)
    for file_name in ('cnews.train.txt', 'cnews.test.txt', 'cnews.val.txt', 'cnews.vocab.txt')
)

# Checkpoint directory, and the path under which the best-validation model is saved.
save_dir = './checkpoints/textcnn'
save_path = os.path.join(save_dir, 'best_validation')

def get_time_dif(start_time):
    """Return the time elapsed since ``start_time`` as a ``timedelta``.

    Args:
        start_time: A timestamp in seconds, as returned by ``time.time()``.

    Returns:
        ``datetime.timedelta`` holding the elapsed time rounded to whole seconds.
    """
    elapsed_seconds = time.time() - start_time
    whole_seconds = int(round(elapsed_seconds))
    return timedelta(seconds=whole_seconds)


def feed_data(x_batch, y_batch, keep_prob):
    """Build the feed dict for one batch.

    Maps the module-level ``model``'s ``input_x``, ``input_y`` and ``keep_prob``
    attributes to the given values, in that order.

    Args:
        x_batch: Value fed for ``model.input_x``.
        y_batch: Value fed for ``model.input_y``.
        keep_prob: Value fed for ``model.keep_prob``.

    Returns:
        A dict with exactly those three entries.
    """
    model_inputs = (model.input_x, model.input_y, model.keep_prob)
    fed_values = (x_batch, y_batch, keep_prob)
    return dict(zip(model_inputs, fed_values))


def evaluate(sess, x_, y_, batch_size=128):
    """Compute the average loss and accuracy of the module-level ``model`` on a dataset.

    The data is processed in batches obtained from ``batch_iter``. Each batch's
    loss and accuracy are weighted by the batch length, so the result is a
    per-sample average even when the batches have different sizes.

    Args:
        sess: Session object whose ``run`` method evaluates ``model.loss`` and
            ``model.acc``.
        x_: Inputs; must support ``len()``.
        y_: Labels matching ``x_``.
        batch_size: Batch size passed to ``batch_iter``. Defaults to 128, the
            value that was previously hard-coded.

    Returns:
        Tuple ``(average_loss, average_accuracy)``.

    Raises:
        ZeroDivisionError: If ``len(x_)`` is 0 and the batch loop completes.
    """
    data_len = len(x_)
    total_loss = 0.0
    total_acc = 0.0
    for x_batch, y_batch in batch_iter(x_, y_, batch_size):
        batch_len = len(x_batch)
        # keep_prob is fed as 1.0 for evaluation (presumably disabling dropout;
        # confirm against the model definition).
        feed_dict = feed_data(x_batch, y_batch, 1.0)
        loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
        # Weight by batch length so a shorter batch does not skew the average.
        total_loss += loss * batch_len
        total_acc += acc * batch_len

    return total_loss / data_len, total_acc / data_len

In [2]:
# Build the model configuration, make sure a vocabulary file exists, then
# construct the model with the vocabulary size actually read from disk.
print('Configuring CNN model...')
config = TCNNConfig()
if not os.path.exists(vocab_dir):  # rebuild the vocabulary if the file does not exist
    build_vocab(train_dir, vocab_dir, config.vocab_size)
categories, cat_to_id = read_category()
words, word_to_id = read_vocab(vocab_dir)
# Override the configured vocabulary size with the number of entries read from the file.
config.vocab_size = len(words)
model = TextCNNwithAttention(config)

Configuring CNN model...


In [3]:
print("Configuring TensorBoard and Saver...")
# TensorBoard setup. When retraining, delete the tensorboard folder first,
# otherwise the new curves are drawn over the old ones.
tensorboard_dir = './tensorboard/textcnn'
# exist_ok=True creates the directory only when needed, without a separate
# exists() check that could race with another process.
os.makedirs(tensorboard_dir, exist_ok=True)

# Record the model's loss and accuracy as scalar summaries and merge them into
# one op that is evaluated during training.
tf.summary.scalar("loss", model.loss)
tf.summary.scalar("accuracy", model.acc)
merged_summary = tf.summary.merge_all()
writer = tf.summary.FileWriter(tensorboard_dir)

# Saver setup: checkpoints are written under save_dir.
saver = tf.train.Saver()
os.makedirs(save_dir, exist_ok=True)

print("Loading training and validation data...")
# Load the training and validation sets, timing how long the loading takes.
start_time = time.time()
x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
time_dif = get_time_dif(start_time)
print("Time usage:", time_dif)

# Create the session, initialise all global variables, and hand the graph to
# the summary writer.
session = tf.Session()
session.run(tf.global_variables_initializer())
writer.add_graph(session.graph)

# Training loop with periodic TensorBoard logging, periodic evaluation on the
# validation set, best-model checkpointing and early stopping.
print('Training and evaluating...')
start_time = time.time()
total_batch = 0  # number of optimisation steps run so far
best_acc_val = 0.0  # best validation accuracy seen so far
last_improved = 0  # step at which the validation accuracy last improved
require_improvement = 1000  # stop early after more than this many steps without improvement

flag = False  # set when early stopping triggers, so the epoch loop exits too
for epoch in range(config.num_epochs):
    print('Epoch:', epoch + 1)
    batch_train = batch_iter(x_train, y_train, config.batch_size)
    for x_batch, y_batch in batch_train:
        feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

        if total_batch % config.save_per_batch == 0:
            # Every save_per_batch steps, write the training summaries to TensorBoard.
            s = session.run(merged_summary, feed_dict=feed_dict)
            writer.add_summary(s, total_batch)

        if total_batch % config.print_per_batch == 0:
            # Every print_per_batch steps, report performance on the current
            # training batch and on the validation set.
            # A separate feed dict with keep_prob 1.0 is used here. Overwriting
            # keep_prob in `feed_dict` would leak into the optimisation step
            # below and train this batch with keep_prob 1.0 instead of
            # config.dropout_keep_prob.
            eval_feed_dict = feed_data(x_batch, y_batch, 1.0)
            loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=eval_feed_dict)
            loss_val, acc_val = evaluate(session, x_val, y_val)

            if acc_val > best_acc_val:
                # New best validation accuracy: remember it and save the model.
                best_acc_val = acc_val
                last_improved = total_batch
                saver.save(sess=session, save_path=save_path)
                improved_str = '*'
            else:
                improved_str = ''

            time_dif = get_time_dif(start_time)
            msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                  + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
            print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

        # Run one optimisation step with the training keep_prob.
        session.run(model.optim, feed_dict=feed_dict)
        total_batch += 1

        if total_batch - last_improved > require_improvement:
            # Validation accuracy has not improved for a long time: stop early.
            print("No optimization for a long time, auto-stopping...")
            flag = True
            break  # leave the batch loop
    if flag:  # leave the epoch loop as well
        break

Configuring TensorBoard and Saver...
Loading training and validation data...
Time usage: 0:00:12
Training and evaluating...
Epoch: 1
Iter:      0, Train Loss:    2.3, Train Acc:  17.19%, Val Loss:    2.3, Val Acc:  10.72%, Time: 0:00:01 *
Iter:    100, Train Loss:    1.1, Train Acc:  56.25%, Val Loss:    1.4, Val Acc:  53.30%, Time: 0:00:04 *
Iter:    200, Train Loss:   0.34, Train Acc:  92.19%, Val Loss:    0.6, Val Acc:  81.66%, Time: 0:00:07 *
Iter:    300, Train Loss:   0.28, Train Acc:  92.19%, Val Loss:   0.41, Val Acc:  88.82%, Time: 0:00:10 *
Iter:    400, Train Loss:   0.13, Train Acc:  96.88%, Val Loss:   0.39, Val Acc:  88.52%, Time: 0:00:12 
Iter:    500, Train Loss:   0.39, Train Acc:  89.06%, Val Loss:   0.35, Val Acc:  89.54%, Time: 0:00:15 *
Iter:    600, Train Loss:   0.37, Train Acc:  90.62%, Val Loss:   0.33, Val Acc:  90.04%, Time: 0:00:18 *
Iter:    700, Train Loss:   0.14, Train Acc:  98.44%, Val Loss:   0.33, Val Acc:  90.08%, Time: 0:00:21 *
Epoch: 2
Iter:    80

In [4]:
# Load the test set, restore the saved checkpoint into a new session and report
# the test loss and accuracy.
print("Loading test data...")
start_time = time.time()
x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

# NOTE: this rebinds `session` to a new tf.Session; a session previously bound
# to this name is not closed here.
session = tf.Session()
session.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess=session, save_path=save_path)  # load the saved model

print('Testing...')
loss_test, acc_test = evaluate(session, x_test, y_test)
msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
print(msg.format(loss_test, acc_test))

# Predict the class of every test sample, batch by batch.
batch_size = 128
data_len = len(x_test)
# Ceiling division in pure integer arithmetic: no float rounding, and zero
# batches (rather than one empty batch) when the test set is empty.
num_batch = (data_len + batch_size - 1) // batch_size

y_test_cls = np.argmax(y_test, 1)
y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # holds the predicted classes
for i in range(num_batch):  # process one batch at a time
    start_id = i * batch_size
    end_id = min(start_id + batch_size, data_len)  # the last batch may be shorter
    feed_dict = {
        model.input_x: x_test[start_id:end_id],
        model.keep_prob: 1.0
    }
    y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

# Evaluation: per-class precision, recall and F1-score, labelled with the
# category names.
print("Precision, Recall and F1-Score...")
print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

# Confusion matrix of true classes against predicted classes.
print("Confusion Matrix...")
cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
print(cm)

# Time elapsed since start_time was last set.
time_dif = get_time_dif(start_time)
print("Time usage:", time_dif)

Loading test data...
INFO:tensorflow:Restoring parameters from ./checkpoints/textcnn/best_validation
Testing...
Test Loss:   0.18, Test Acc:  94.73%
Precision, Recall and F1-Score...
             precision    recall  f1-score   support

         体育       0.99      0.99      0.99      1000
         财经       0.95      0.98      0.96      1000
         房产       0.96      0.97      0.97      1000
         家居       0.97      0.78      0.86      1000
         教育       0.90      0.92      0.91      1000
         科技       0.95      0.97      0.96      1000
         时尚       0.94      0.96      0.95      1000
         时政       0.92      0.95      0.93      1000
         游戏       0.94      0.97      0.96      1000
         娱乐       0.96      0.97      0.96      1000

avg / total       0.95      0.95      0.95     10000

Confusion Matrix...
[[992   0   0   0   4   0   0   0   2   2]
 [  0 982   8   0   3   1   0   6   0   0]
 [  0   3 975   2   3   2   3  10   1   1]
 [  0  23  20 777  54  19  32