## 0.导入库

### 0.1 常用的库

In [1]:
import os
import numpy as np
import time

### 0.2 需要使用的库

In [2]:
import pandas as pd
import pickle
import re

### 0.3 基本方法

In [3]:
import sys
# Progress-line helper: rewrite the current terminal line in place.
def print_flush(print_string):
    """Print ``print_string`` terminated by a carriage return, then flush stdout.

    Ending with '\\r' (instead of a newline) moves the cursor back to the start
    of the line, so the next call overwrites this output -- a live progress bar.
    """
    print(print_string, end='\r', flush=True)

# 导入深度学习库tensorflow    
import tensorflow as tf    
# Session factory with on-demand GPU memory allocation.
def get_session():
    """Create and return a TF1 ``tf.Session`` with ``gpu_options.allow_growth`` on.

    With allow_growth enabled TensorFlow grows its GPU memory use as needed
    instead of reserving the memory up front.
    """
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    return tf.Session(config=session_config)

  from ._conv import register_converters as _register_converters


## 1. 文件内容的列表

### 1.1 文本文件路径的列表

In [4]:
def get_txtFilePath_list(root_dirPath):
    """Return the paths of all files exactly one directory level below ``root_dirPath``.

    Only the immediate sub-directories of ``root_dirPath`` are visited, and
    inside each of them only its direct files are listed (no deeper recursion).
    Files lying directly in ``root_dirPath`` are ignored. Ordering follows what
    ``os.walk`` reports; ``get_label_list`` walks the tree the same way, and the
    notebook relies on both lists staying index-aligned.
    """
    _, category_dirNames, _ = next(os.walk(root_dirPath))
    collected_paths = []
    for category_dirName in category_dirNames:
        category_dirPath = os.path.join(root_dirPath, category_dirName)
        _, _, file_names = next(os.walk(category_dirPath))
        collected_paths += [os.path.join(category_dirPath, name) for name in file_names]
    return collected_paths


# Collect every article path under the THUCNews root; each sub-directory of the
# root holds the articles of one category (the sub-directory name is the label).
root_dirPath = '../resources/THUCNews/'
txtFilePath_list = get_txtFilePath_list(root_dirPath)
print('文本文件路径的列表长度:', len(txtFilePath_list))

文本文件路径的列表长度: 836075


### 1.2 读取所有文本文件

In [5]:
def get_fileContent(txtFilePath):
    """Read ``txtFilePath`` as UTF-8 text and return its whole content."""
    with open(txtFilePath, encoding='utf8') as text_file:
        return text_file.read()


# Read every article, collapse each run of whitespace into one space and keep
# only the first `sequence_length` characters (the model's input length).
sequence_length = 600
sample_quantity = len(txtFilePath_list)
# Raw string: '\s' inside a plain string literal is an invalid escape sequence
# (DeprecationWarning / SyntaxWarning on recent Python versions).
whitespace_pattern = re.compile(r'\s+')
startTime = time.time()
content_list = []
for index, txtFilePath in enumerate(txtFilePath_list, 1):
    fileContent = get_fileContent(txtFilePath)
    content_list.append(whitespace_pattern.sub(' ', fileContent)[:sequence_length])
    # Refresh the progress bar every 100 files and after the last file.
    if index % 100 == 0 or index == sample_quantity:
        percent = index / sample_quantity * 100
        half_percent_int = int(percent) // 2  # the bar is 50 characters wide
        usedTime = time.time() - startTime
        # Guard against a zero elapsed time on coarse clocks.
        speed = index / usedTime if usedTime > 0 else float('inf')
        print_string = (
            '%d/ %d ' % (index, sample_quantity)
            + '>' * half_percent_int + ' ' * (50 - half_percent_int)
            + ' 进度百分比:%.2f%%' % percent
            + ' 读取速度:%.2f文件/秒' % speed
            + ' 总共花费时间:%.2f秒' % usedTime
        )
        print_flush(print_string)

836075/ 836075 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 进度百分比:100.00% 读取速度:2088.18文件/秒 总共花费时间:400.38秒

### 1.3 把文件内容列表保存为pickle文件

In [6]:
# Cache the preprocessed article texts so the slow reading step of section 1.2
# (about 400 s in the recorded run) does not have to be repeated.
pickleFilePath = '../resources/content_list.pickle'
with open(pickleFilePath, 'wb') as file:
    pickle.dump(content_list, file)

### 1.4 从pickle文件加载文件内容列表

In [4]:
# Reload the cached article texts written in section 1.3.
# NOTE(review): pickle.load can execute arbitrary code -- only load a file this
# notebook produced itself.
pickleFilePath = '../resources/content_list.pickle'
with open(pickleFilePath, 'rb') as file:
    content_list = pickle.load(file)

## 2. 样本标签的列表

### 2.1 获取样本标签的列表

In [5]:
def get_label_list(root_dirPath):
    """Return one label per file one level below ``root_dirPath``.

    The label of a file is the name of the immediate sub-directory of
    ``root_dirPath`` that contains it. Directories are walked the same way as in
    ``get_txtFilePath_list``, so the result is intended to be index-aligned with
    that function's output. Files directly in ``root_dirPath`` are ignored.
    """
    _, category_dirNames, _ = next(os.walk(root_dirPath))
    labels = []
    for category_dirName in category_dirNames:
        _, _, file_names = next(os.walk(os.path.join(root_dirPath, category_dirName)))
        labels.extend(category_dirName for _ in file_names)
    return labels


# Build the per-article label list (category = sub-directory name) and show
# how many articles each category has.
root_dirPath = '../resources/THUCNews/'
label_list = get_label_list(root_dirPath)
print('样本标签的列表长度:', len(label_list))
# Top-level pd.value_counts(...) is deprecated (pandas >= 2.1); the Series
# method gives the same counts and works on every pandas version.
pd.Series(label_list).value_counts()

样本标签的列表长度: 836075


科技    162929
股票    154398
体育    131604
娱乐     92632
时政     63086
社会     50849
教育     41936
财经     37098
家居     32586
游戏     24373
房产     20050
时尚     13368
彩票      7588
星座      3578
dtype: int64

## 3. 字列表

### 3.1 根据文件内容列表，统计计数获得出现次数排名前6999的字
#### 字表共7000项：id 0 保留给'PAD'，出现次数排名第7000及以后的字（以及序列填充位置）统一映射为'PAD'（id 0）

In [9]:
from collections import Counter

def get_word_list(content_list, size):
    """Build a character vocabulary with ``size`` entries.

    Entry 0 is the literal token 'PAD'; entries 1..size-1 are the ``size - 1``
    most frequent characters across all strings in ``content_list``, in the
    order ``Counter.most_common`` returns them. Progress is printed after every
    1000 contents and after the last one.
    """
    char_counter = Counter()
    total = len(content_list)
    begin = time.time()
    for done, text in enumerate(content_list, start=1):
        char_counter.update(text)
        # Only report on every 1000th item and on the final item.
        if done % 1000 and done != total:
            continue
        elapsed = time.time() - begin
        print_flush('%d/ %d' % (done, total)
                    + ' 进度百分比: %.2f%%' % (done / total * 100)
                    + ' 花费时间: %.2f秒' % elapsed)
    return ['PAD'] + [char for char, _ in char_counter.most_common(size - 1)]


# Vocabulary: 'PAD' (id 0) plus the 6999 most frequent characters.
vocabulary_size = 7000
word_list = get_word_list(content_list, vocabulary_size)

836075/ 836075 进度百分比: 100.00% 花费时间: 81.30秒

### 3.2 把字列表保存为pickle文件

In [10]:
# Cache the vocabulary so it need not be recounted (about 80 s in the recorded run).
pickleFilePath = '../resources/word_list.pickle'
with open(pickleFilePath, 'wb') as file:
    pickle.dump(word_list, file)

### 3.3 从pickle文件加载字列表

In [6]:
# Reload the cached vocabulary written in section 3.2.
# NOTE(review): pickle.load can execute arbitrary code -- only load a file this
# notebook produced itself.
pickleFilePath = '../resources/word_list.pickle'
with open(pickleFilePath, 'rb') as file:
    word_list = pickle.load(file)

## 4.数据准备

### 4.1 get_X

In [7]:
import keras

# sequence_length: how many leading characters of each article the model sees.
# 600 is an empirical choice -- the first 600 characters are taken to be enough
# to decide an article's category.
sequence_length = 600
# Map each vocabulary entry to its position in word_list ('PAD' -> 0).
word2id_dict = {word: word_id for word_id, word in enumerate(word_list)}
    
    
# Character-id sequence for a single article.
def get_id_list(index):
    """Return the vocabulary ids of one article's first ``sequence_length`` characters.

    ``index`` is either the article text itself (a ``str``) or an index into the
    global ``content_list``. Characters missing from ``word2id_dict`` map to
    id 0, the 'PAD' entry.
    """
    if isinstance(index, str):
        content = index
    else:
        content = content_list[index]
    return [word2id_dict.get(word, 0) for word in content[:sequence_length]]


# Model input matrix for several articles.
def get_X(indexes):
    """Return the padded id matrix for ``indexes``, one row per entry.

    Every entry is converted by ``get_id_list``; keras ``pad_sequences`` then
    pads each id list with 0 at the front (its default ``padding='pre'``) up to
    ``sequence_length`` columns.
    """
    return keras.preprocessing.sequence.pad_sequences(
        [get_id_list(item) for item in indexes], maxlen=sequence_length)

Using TensorFlow backend.


### 4.2 get_Y

In [8]:
from sklearn.preprocessing import LabelEncoder

# Encode the category names as integer class ids (classes_ is sorted).
labelEncoder = LabelEncoder()
labelEncoder.fit(label_list)
category_quantity = len(labelEncoder.classes_)


# One-hot label matrix for several articles.
def get_Y(indexes):
    """Return the one-hot labels of the samples at ``indexes``.

    Labels are looked up in the global ``label_list``, turned into integer class
    ids by ``labelEncoder`` and expanded by keras ``to_categorical`` into
    ``category_quantity`` columns.
    """
    class_ids = labelEncoder.transform([label_list[i] for i in indexes])
    return keras.utils.to_categorical(class_ids, category_quantity)

### 4.3 使用带权重的抽样策略：计算每个样本被抽中的概率，使每个类别被抽中的概率相同（缓解类别不平衡）

In [9]:
def get_probability_list(label_list):
    """Per-sample sampling probabilities that make every category equally likely.

    A sample of category c gets probability ``1 / (K * n_c)``, where K is the
    number of distinct labels and n_c the number of samples labelled c. The
    values sum to 1 (up to float rounding), so when they are used as ``p`` in
    ``np.random.choice`` each category is drawn with probability 1/K, however
    imbalanced ``label_list`` is.

    Args:
        label_list: sequence of labels, one per sample.

    Returns:
        list of probabilities aligned with ``label_list``; ``[]`` for empty input
        (previously this raised ZeroDivisionError).
    """
    if len(label_list) == 0:
        return []
    # Top-level pd.value_counts(...) is deprecated (pandas >= 2.1); the Series
    # method is equivalent and available in every pandas version.
    count_series = pd.Series(label_list).value_counts()
    category_weights = 1 / len(count_series)
    label2weights_dict = dict(zip(count_series.index, category_weights / count_series))
    return [label2weights_dict[label] for label in label_list]

### 4.4 批量数据生产者线程

In [10]:
import threading
from sklearn.model_selection import train_test_split

# Split sample indexes into train/test (train_test_split's default test_size of
# 0.25; fixed random_state so the split is reproducible) and precompute the
# class-balancing sampling probabilities for the training part.
# NOTE(review): the batch producers sample with the unseeded global NumPy RNG,
# so the training batches themselves are not reproducible.
sample_quantity = len(label_list)
index_1d_array = np.arange(sample_quantity)
train_index_1d_array, test_index_1d_array = train_test_split(index_1d_array, random_state=2019)
train_label_list = [label_list[k] for k in train_index_1d_array]
train_probability_list = get_probability_list(train_label_list)
batch_size = 128


class BatchDataThread(threading.Thread):
    """Background producer that keeps a shared queue topped up with training batches.

    The thread starts itself on construction. While fewer than
    ``max_queue_size`` batches wait in ``queue``, it draws ``batch_size``
    training indexes with replacement, weighted by ``train_probability_list``,
    and puts an ``(X as int32, Y as float32)`` tuple onto the queue.

    The thread is a daemon, so a producer that is still running never blocks
    interpreter/kernel shutdown. Call ``stop()`` to end it explicitly.

    Args:
        queue: the ``queue.Queue`` shared with the consumer.
        max_queue_size: produce only while ``queue.qsize()`` is below this
            (default 4).
    """

    def __init__(self, queue, max_queue_size=4):
        super(BatchDataThread, self).__init__(daemon=True)
        self.queue = queue
        self.max_queue_size = max_queue_size
        # Public stop signal. The private Thread._is_stopped flag is not an API
        # and is always False while run() executes, so it cannot end the loop.
        self._stop_event = threading.Event()
        self.start()

    def stop(self):
        """Ask the producer loop to exit after its current iteration."""
        self._stop_event.set()

    def run(self):
        while not self._stop_event.is_set():
            if self.queue.qsize() < self.max_queue_size:
                selected_indexes = np.random.choice(
                    train_index_1d_array, size=batch_size, p=train_probability_list)
                batch_X = get_X(selected_indexes)
                batch_Y = get_Y(selected_indexes)
                self.queue.put((batch_X.astype('int32'), batch_Y.astype('float32')))
            else:
                # Enough batches are queued: back off briefly (and wake up at
                # once if stop() is called) instead of spinning.
                self._stop_event.wait(0.0001)

### 4.5 批量数据生成器类

In [11]:
import queue

class BatchDataGenerator(object):
    """Endless iterator over (X, Y) batches produced by background threads.

    On construction ``worker_quantity`` ``BatchDataThread`` producers are created,
    all feeding the same unbounded ``queue.Queue``. ``next()`` blocks until a
    batch is available.
    """

    def __init__(self, worker_quantity=4):
        self.queue = queue.Queue()
        for _ in range(worker_quantity):
            # The producer starts itself inside its constructor.
            BatchDataThread(self.queue)

    def __iter__(self):
        return self

    def __next__(self):
        return self.queue.get()
    
    
# Start the default 4 producer threads immediately. Nothing in this notebook
# stops them, so they keep producing batches for the rest of the kernel session.
batchData_generator = BatchDataGenerator()    

## 5.搭建神经网络

In [12]:
# Text-CNN graph (TF1 graph mode): character embedding -> three parallel 1-D
# convolutions (kernel widths 3/5/7) -> max over the sequence -> dense + ReLU
# -> dense logits; softmax for prediction, softmax cross-entropy for training.
# NOTE(review): tf.get_variable / tf.layers derive variable names from this
# creation order (presumably 'embedding', 'conv1d', 'conv1d_1', ...); the
# checkpoint saved/restored in sections 11-12 likely depends on them, so keep
# the layer order unchanged.
tf.reset_default_graph()
X_holder = tf.placeholder(tf.int32, [None, sequence_length])  # character ids
Y_holder = tf.placeholder(tf.float32, [None, category_quantity])  # one-hot labels
data_0 = X_holder # N * sequence_length (600)
# Must be >= len(word_list): ids are positions in word_list and index the rows
# of the embedding matrix below.
vocabulary_size = 7000
embedding_size = 100
layer_1 = tf.get_variable('embedding', [vocabulary_size, embedding_size])
data_1 = tf.nn.embedding_lookup(layer_1, data_0) # N * 600 * 100
filter_quantiy = 128  # filters per convolution branch
layer_2 = tf.layers.conv1d # kernel 3 * 100 * 128
data_2 = layer_2(data_1, filter_quantiy, 3, padding='SAME') # N * 600 * 128
layer_3 = tf.layers.conv1d # kernel 5 * 100 * 128
data_3 = layer_3(data_1, filter_quantiy, 5, padding='SAME') # N * 600 * 128
layer_4 = tf.layers.conv1d # kernel 7 * 100 * 128
data_4 = layer_4(data_1, filter_quantiy, 7, padding='SAME') # N * 600 * 128
layer_5 = tf.concat
data_5 = layer_5([data_2, data_3, data_4], axis=2) # N * 600 * 384 (concat along channels)
layer_6 = tf.reduce_max
data_6 = layer_6(data_5, [1]) # N * 384 (max over the sequence axis)
layer_7 = tf.layers.dense # weights 384 * 128
fc1_units = 128
data_7 = layer_7(data_6, fc1_units) # N * 128
layer_8 = tf.nn.relu
data_8 = layer_8(data_7) # N * 128
layer_9 = tf.layers.dense
data_9 = layer_9(data_8, category_quantity) # N * category_quantity logits (14 in this run)
layer_10 = tf.nn.softmax
data_10 = layer_10(data_9) # N * category_quantity class probabilities
layer_11 = tf.nn.softmax_cross_entropy_with_logits_v2
data_11 = layer_11(labels=Y_holder, logits=data_9) # N per-sample losses (computed from logits)
loss = tf.reduce_mean(data_11) # scalar mean loss
learning_rate = 5e-4
optimizer = tf.train.AdamOptimizer(learning_rate)
train = optimizer.minimize(loss)
isCorrect = tf.equal(tf.argmax(Y_holder, 1), tf.argmax(data_10, 1))
accuracy = tf.reduce_mean(tf.cast(isCorrect, tf.float32))

W0828 09:59:24.074238  1516 deprecation.py:506] From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\ops\init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0828 09:59:24.089231  1516 deprecation.py:323] From <ipython-input-12-a0548f6617c7>:11: conv1d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.keras.layers.Conv1D` instead.
W0828 09:59:24.303164  1516 deprecation.py:323] From <ipython-input-12-a0548f6617c7>:22: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.dense instead.


## 6.参数初始化

In [13]:
# Initialise all graph variables in a new session (GPU memory grows on demand).
init = tf.global_variables_initializer()
session = get_session()
session.run(init)

## 7.模型训练

In [25]:
# Train for `train_steps` mini-batches pulled from the background producers.
# On every `print_interval`-th step the batch loss/accuracy are fetched in the
# SAME session.run call as the parameter update, instead of a second forward
# pass over the batch just for reporting.
train_steps = 10000
print_interval = 2
startTime = time.time()
for step in range(1, train_steps + 1):
    batch_X, batch_Y = next(batchData_generator)
    feed_dict = {X_holder: batch_X, Y_holder: batch_Y}
    if step % print_interval != 0:
        session.run(train, feed_dict)
        continue
    _, loss_value, accuracy_value = session.run([train, loss, accuracy], feed_dict)
    usedTime = time.time() - startTime
    speed = step / usedTime
    print_string = '步数:%d 损失值:%.4f 准确率:%.4f 训练速度:%.2f步/秒' %(
        step, loss_value, accuracy_value, speed)
    print_flush(print_string)

步数:10000 损失值:0.0470 准确率:0.9844 训练速度:11.10步/秒

## 8.模型测试

In [43]:
import warnings
# NOTE(review): with no category given this silences *all* warnings for the rest
# of the kernel session, not just in this cell; consider scoping it with
# warnings.catch_warnings() around the specific calls that need it.
warnings.filterwarnings("ignore")

def predict(input_content):
    """Predict the category of a single article.

    Args:
        input_content: raw article text (``str``) or an index into the global
            ``content_list``; both are handled by ``get_X`` via ``get_id_list``.

    Returns:
        the predicted label, one of ``labelEncoder.classes_``.
    """
    # Reuse the batch input pipeline rather than duplicating the id/padding logic.
    X = get_X([input_content])                # 1 * sequence_length
    Y = session.run(data_10, {X_holder: X})   # 1 * category_quantity probabilities
    y = np.argmax(Y, axis=1)                  # 1 class id
    return labelEncoder.inverse_transform(y)[0]

# Sanity check: classify one random test article (unseeded, so a different one
# on every run) and compare the prediction with its true label, then classify a
# free-form string.
selected_index = np.random.choice(test_index_1d_array, 1)[0]
selected_content = content_list[selected_index]
true_label = label_list[selected_index]
predict_label = predict(selected_content)
print('选出文本内容为: ', selected_content)
print('真实标签: ', true_label)
print('预测标签: ', predict_label, '\n')
print('对于任意文本做分类预测，例如:')
input_content = "足球篮球"
print('predict("%s") :' %input_content, predict(input_content))

选出文本内容为:  高房价痛的不仅是普通人 明星买房也三思(图) 那些成名后依旧“蜗居”的明星…… 电视剧《蜗居》的热播，让人们深刻体会了高房价之下的残酷人生，高房价甚至毁掉了年青一代的幸福。日前，有网友发博文曝光了现在还在租房子住的十大当红明星，引发热议，紧接着，更多的明星公开表示自己“买不起房”，一部明星版《蜗居》俨然在上演。这些明星，果真买不起房吗？还是，挣钱再多也难以赶上飙涨的房价？还是，他们要的不是房子而是性价比？策划：周娴 撰文：记者 张素芹 房价凭什么疯涨？ 《蜗居》主创海清文章，买的都是二手房 海清、文章买了二手房，邬君梅和张嘉译坦言每月还房贷……《蜗居》中的明星们众口一词：“房价这么高，凭什么？” 饰演郭海萍的海清曾“抱怨”：现实生活中自己也是房奴。迟疑了很久，直到北京房价大涨，才终于出手花了多年积蓄在三环外买了一处二手房。“买房子缺钱，把我逼得去买彩票，我在剧组一个月时间里，每周都去买次彩票，每次买十块钱。” 饰演小贝的文章，也和海清一样，倾尽自己和家人的全力，买了套二手房：“我买房不是为了投资，是真的要去住。自从我开枝散叶，家里有老人有小孩后，一家人就真的没地住了。”文章如是说。 点评：纷繁复杂的原因，使涨价成为2009年楼市的关键词。京城，更是居不易。从今年春节之后的“小阳春”，到“红五月”、“金七银八”，再到一手房9月成交的小幅收窄、10月的小幅回升，涨价，成了贯穿今年京城楼
真实标签:  娱乐
预测标签:  娱乐 

对于任意文本做分类预测，例如:
predict("足球篮球") : 体育


## 9.混淆矩阵

In [44]:
from sklearn.metrics import confusion_matrix

def predict_test(batch_size=100):
    """Predict labels for every sample in the global ``test_index_1d_array``.

    The graph is run on consecutive slices of ``batch_size`` test indexes, with
    progress printed after each slice.

    Args:
        batch_size: number of samples fed per ``session.run`` call (default 100).

    Returns:
        predicted labels (via ``labelEncoder.inverse_transform``) in the same
        order as ``test_index_1d_array``; empty when there are no test samples.
    """
    startTime = time.time()
    test_sample_quantity = len(test_index_1d_array)
    predict_id_list = []
    usedTime = 0.0  # defined even if the loop body never runs
    for start in range(0, test_sample_quantity, batch_size):
        part_index_1d_array = test_index_1d_array[start: start + batch_size]
        batch_X = get_X(part_index_1d_array)
        predict_Y = session.run(data_10, {X_holder: batch_X})
        # Keep only the arg-max class id per sample instead of full probability rows.
        predict_id_list.extend(np.argmax(predict_Y, axis=1))
        usedTime = time.time() - startTime
        # Report the number of samples processed so far, not the slice start offset.
        done = min(start + batch_size, test_sample_quantity)
        print_flush('%d/ %d 花费时间:%.2f秒' % (done, test_sample_quantity, usedTime))
    print_flush('%d/ %d 花费时间:%.2f秒' % (test_sample_quantity, test_sample_quantity, usedTime))
    # dtype=int keeps an empty result a valid 1-D id array.
    y = np.array(predict_id_list, dtype=int)
    return labelEncoder.inverse_transform(y)


# Confusion matrix on the test split: rows are true labels, columns are
# predicted labels.
# NOTE(review): confusion_matrix is not given labels=..., so its order is the
# sorted union of the labels present; it matches labelEncoder.classes_ (also
# sorted) only while every class occurs in the test or predicted labels.
test_label_list = [label_list[k] for k in test_index_1d_array]
predict_label_list = predict_test()
pd.DataFrame(confusion_matrix(test_label_list, predict_label_list), 
             columns=labelEncoder.classes_,
             index=labelEncoder.classes_ )

209019/ 209019 花费时间:71.07秒

Unnamed: 0,体育,娱乐,家居,彩票,房产,教育,时尚,时政,星座,游戏,社会,科技,股票,财经
体育,32559,109,12,44,4,14,5,38,0,35,76,81,4,7
娱乐,112,22133,58,1,32,26,44,57,5,52,256,311,6,19
家居,15,70,7896,0,57,11,23,12,4,9,38,159,14,17
彩票,35,4,1,1835,1,0,0,4,0,4,44,7,0,2
房产,0,8,39,0,4714,5,1,26,0,3,41,53,48,40
教育,12,23,24,4,18,9750,6,81,4,16,317,198,6,15
时尚,7,36,47,0,7,3,3125,6,0,9,10,33,0,3
时政,56,67,7,4,56,93,13,14313,0,17,476,501,77,67
星座,2,0,3,0,0,0,3,0,899,0,1,0,0,1
游戏,4,4,3,0,0,2,2,3,0,5815,12,203,1,0


## 10.报告表

In [45]:
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

def eval_model(y_true, y_pred, labels):
    """Per-class and support-weighted overall classification report.

    Args:
        y_true: true labels.
        y_pred: predicted labels, aligned with ``y_true``.
        labels: classes to report, in the order the rows should appear. They are
            passed to ``precision_recall_fscore_support``, so the metric arrays
            are computed for exactly these classes in exactly this order.

    Returns:
        DataFrame with columns Label, Precision, Recall, F1, Support: one row per
        entry of ``labels`` (index 0..len(labels)-1), plus a final '总体'
        (overall) row at index 999 holding the support-weighted averages and the
        total support.
    """
    # Without labels=..., sklearn orders its output by the sorted union of
    # y_true and y_pred, which can silently misalign with -- or differ in length
    # from -- the `labels` column below.
    p, r, f1, s = precision_recall_fscore_support(y_true, y_pred, labels=labels)
    # Overall figures are averages weighted by each class's support.
    tot_p = np.average(p, weights=s)
    tot_r = np.average(r, weights=s)
    tot_f1 = np.average(f1, weights=s)
    tot_s = np.sum(s)
    per_class = pd.DataFrame({
        'Label': labels,
        'Precision': p,
        'Recall': r,
        'F1': f1,
        'Support': s,
    })
    overall = pd.DataFrame({
        'Label': ['总体'],
        'Precision': [tot_p],
        'Recall': [tot_r],
        'F1': [tot_f1],
        'Support': [tot_s],
    }, index=[999])
    report = pd.concat([per_class, overall])
    return report[['Label', 'Precision', 'Recall', 'F1', 'Support']]

# Per-class precision/recall/F1/support on the test split plus a weighted total row.
eval_model(test_label_list, predict_label_list, labelEncoder.classes_)

Unnamed: 0,Label,Precision,Recall,F1,Support
0,体育,0.990629,0.986995,0.988809,32988
1,娱乐,0.976269,0.957641,0.966865,23112
2,家居,0.951669,0.948468,0.950066,8325
3,彩票,0.953742,0.947341,0.950531,1937
4,房产,0.909336,0.946967,0.92777,4978
5,教育,0.964201,0.930876,0.947246,10474
6,时尚,0.96302,0.951004,0.956974,3286
7,时政,0.931775,0.908935,0.920213,15747
8,星座,0.984666,0.988999,0.986828,909
9,游戏,0.892419,0.961316,0.925587,6049


## 11.模型保存

In [23]:
# Save the graph's variables to a TF1 checkpoint (a Saver with no arguments
# covers all saveable variables of the current graph).
saver = tf.train.Saver()
ckptFilePath = '../resources/trained_model/textCnn.ckpt'
saver.save(session, ckptFilePath)

'../resources/trained_model/textCnn.ckpt'

## 12.模型加载

In [24]:
# Restore trained variables into a new session. The section-5 graph must
# already be built in this kernel, since the Saver is created from it.
# NOTE(review): any existing `session` is replaced without being closed first.
saver = tf.train.Saver()
session = get_session()
ckptFilePath = '../resources/trained_model/textCnn.ckpt'
saver.restore(session, ckptFilePath)

W0827 23:09:13.827638 17324 deprecation.py:323] From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
