In [1]:
import jieba
import os
import fasttext
import logging

#### Prepare Data

In [None]:
basedir = r"D:/自学/NLP Basic/data/fastText_data/THUCNews/"
dir_list = ["时政","星座","财经","教育","娱乐","时尚","游戏","家居","房产","彩票","科技","体育","社会","股票"]

ftrain = open(basedir + "news_fasttext_train.txt","w", encoding="utf-8")
ftest = open(basedir + "news_fasttext_test.txt","w", encoding="utf-8")

num = -1
for e in dir_list:
    num += 1
    indir = basedir + e + '/'
    files = os.listdir(indir)
    count = 0
    for fileName in files:
        count += 1            
        filepath = indir + fileName
        with open(filepath,'r', encoding="utf-8") as fr:
            text = fr.read()
        text = str(text.encode("utf-8"),"utf-8")
        seg_text = jieba.cut(text.replace("\t"," ").replace("\n"," "))
        outline = " ".join(seg_text)
        outline = outline + "\t__label__" + e + "\n"

        if count < 10000:
            ftrain.write(outline)
            ftrain.flush()
            continue
        elif count  < 20000:
            ftest.write(outline)
            ftest.flush()
            continue
        else:
            break

ftrain.close()
ftest.close()
print("Get File Done")

#### Train fastText with THUCNews data

In [8]:
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

classifier = fasttext.train_supervised("../data/fastText_data/THUCNews/news_fasttext_train.txt", label_prefix="__label__")

print("Training Done!")

Training Done!


In [9]:
result = classifier.test("../data/fastText_data/THUCNews/news_fasttext_test.txt")
print('precision: ', result[1])

precision： 0.8833808183895069


#### Get Prediction and Evaluate

In [14]:
labels_right = []
texts = []
with open("../data/fastText_data/THUCNews/news_fasttext_test.txt", encoding="utf-8") as fr:
    for line in fr:
        line = str(line.encode("utf-8"), 'utf-8').rstrip()
        labels_right.append(line.split("\t")[1].replace("__label__",""))
        texts.append(line.split("\t")[0])

labels_predict = [term[0] for term in classifier.predict(texts)[0]]


text_labels = list(set(labels_right))
text_predict_labels = list(set(labels_predict))

A = dict.fromkeys(text_labels,0)          # Number of correct predictions in each category
B = dict.fromkeys(text_labels,0)          # Number of categories in test dataset
C = dict.fromkeys(text_predict_labels,0)  # Number of each category in prediction result
for i in range(0,len(labels_right)):
    B[labels_right[i]] += 1
    C[labels_predict[i]] += 1
    if labels_right[i] == labels_predict[i].replace('__label__', ''):
        A[labels_right[i]] += 1


for key in B:
    try:
        r = float(A[key]) / float(B[key])
        p = float(A[key]) / float(C['__label__' + key])
        f = p * r * 2 / (p + r)
        print("%s:\t precision:%f\t recall:%f\t f1-score:%f" % (key, p, r, f))
    except:
        print("error:", key, "right:", A.get(key,0), "real:", B.get(key,0), "predict:", C.get(key,0))

房产:	 precision:0.849284	 recall:0.972600	 f1-score:0.906769
时尚:	 precision:0.743872	 recall:0.963787	 f1-score:0.839669
娱乐:	 precision:0.940328	 recall:0.866700	 f1-score:0.902014
家居:	 precision:0.923861	 recall:0.888200	 f1-score:0.905680
体育:	 precision:0.986564	 recall:0.844400	 f1-score:0.909963
财经:	 precision:0.922133	 recall:0.894100	 f1-score:0.907900
股票:	 precision:0.799554	 recall:0.752300	 f1-score:0.775207
游戏:	 precision:0.972219	 recall:0.948400	 f1-score:0.960162
科技:	 precision:0.857346	 recall:0.964600	 f1-score:0.907816
社会:	 precision:0.858730	 recall:0.919700	 f1-score:0.888170
时政:	 precision:0.906178	 recall:0.712800	 f1-score:0.797940
教育:	 precision:0.892991	 recall:0.926300	 f1-score:0.909341
