In [34]:
import fasttext
import pandas as pd
import jieba
import re
import csv
from torchtext.data.utils import get_tokenizer
from tqdm.notebook import tqdm
tqdm.pandas()
from sklearn.model_selection import train_test_split

# 1.将数据处理成fasttext所需格式

In [40]:
def get_fasttext_data(df, tokenizer, stopwords):
    df['label'] = '__label__' + df['label']
    def process_text(title):
        tokens = [token for token in tokenizer(title.strip()) if token not in stopwords]
        return ' '.join(tokens)
    
    df['text'] = df['title'].progress_map(process_text)
    df['label_text'] = df['label'] + ',' + df['text']
    return df[['label_text']]

In [58]:
tokenizer = get_tokenizer('spacy', language='zh_core_web_sm')
stopwords = [line.strip() for line in open('/home/gechengze/project/nlp-notebook/stopwords/cn_stopwords.txt',
                                           'r', encoding='utf-8').readlines()]

In [83]:
df = pd.read_csv('../../../datasets/THUCNews/title.csv')
df_train, df_test = train_test_split(df, test_size=0.2)

df_train = get_fasttext_data(df_train, tokenizer, stopwords)
df_train.to_csv('./train.txt', header=None, index=False, quoting=csv.QUOTE_NONE, escapechar=' ')

df_test = get_fasttext_data(df_test, tokenizer, stopwords)
df_test.to_csv('./test.txt', header=None, index=False, quoting=csv.QUOTE_NONE, escapechar=' ')

  0%|          | 0/663933 [00:00<?, ?it/s]

  0%|          | 0/165984 [00:00<?, ?it/s]

# 2.训练模型

In [84]:
# 训练模型
model = fasttext.train_supervised('train.txt', minCount=5, epoch=100)

Read 6M words
Number of words:  73772
Number of labels: 14
Progress: 100.0% words/sec/thread: 1169093 lr:  0.000000 avg.loss:  0.099906 ETA:   0h 0m 0s


# 3.预测

In [85]:
print(model.predict('盘点 明星 私生子 爱情 事故 戏外 戏图'))
print(model.predict('微软 员工 微博 泄密 手机 遭 解雇'))

(('__label__娱乐',), array([1.00001001]))
(('__label__科技',), array([0.9972623]))


In [86]:
# 训练集的precision和recall
print(model.test('train.txt'))

(663933, 0.9887699511848335, 0.9887699511848335)


In [87]:
# 测试集的precision和recall
print(model.test('test.txt'))

(165984, 0.8734094852515906, 0.8734094852515906)
