## Facebook FastText
by @寒小阳

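fastText is an open-source tool from Facebook for word vectors and text classification. It offers little academic novelty; its strengths are a simple model and very fast training. A quick trial shows that it is easy to use and produces good results, good enough for production use.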

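In short, fastText maps every word in a document to a vector through a lookup table, averages those vectors, and feeds the average directly into a linear classifier to get the classification result. fastText closely resembles the deep averaging network (DAN, shown in the figure below) from ACL-15; it is a simplified version with the intermediate hidden layers removed. The paper points out that for some simple classification tasks, comparable results can be obtained without a complex network architecture.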
![](fast_text.png)
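
The lookup, average, and linear-classifier pipeline described above can be sketched in a few lines of plain Python. This is only an illustration, not fastText's implementation: the function names, the 2-dimensional vectors, and the weights are all made-up values chosen for the example.

```python
import math

def average_vectors(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def classify(tokens, lookup, weights, bias):
    """Look up each known token, average the vectors, then apply a linear layer and softmax."""
    vectors = [lookup[t] for t in tokens if t in lookup]
    doc = average_vectors(vectors)
    scores = [sum(w * x for w, x in zip(row, doc)) + b
              for row, b in zip(weights, bias)]
    return softmax(scores)

# Toy lookup table and 2-class linear layer (hypothetical values).
lookup = {"goal": [1.0, 0.0], "match": [0.8, 0.2], "engine": [0.0, 1.0]}
weights = [[2.0, -1.0],   # class 0, e.g. "sports"
           [-1.0, 2.0]]   # class 1, e.g. "car"
bias = [0.0, 0.0]

probs = classify(["goal", "match"], lookup, weights, bias)
print(probs)
```

There is no hidden layer between the averaged vector and the output scores, which is the simplification relative to DAN mentioned above.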


### fastText architecture


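The fastText paper mentions two tricks:

- hierarchical softmax

When the number of classes is large, a Huffman coding tree is built to speed up the computation of the softmax layer, the same trick used earlier in word2vec.

- N-gram features

Using only unigrams loses word-order information, so N-gram features are added to compensate, and hashing is used to reduce the storage needed for the N-grams.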
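
Both tricks can be illustrated with a small sketch. The first function builds a Huffman tree over label frequencies and reports each label's code length, so frequent labels end up closer to the root. The second maps adjacent word pairs to bucket ids instead of storing the pairs themselves. The function names and frequencies are hypothetical, and `zlib.crc32` is only a stand-in: fastText uses its own hash function internally.

```python
import heapq
import zlib

def huffman_code_lengths(freqs):
    """Return {label: code length} for a Huffman tree built from label frequencies."""
    # Heap items: (total frequency, tie-breaker, {label: depth so far}).
    heap = [(f, i, {label: 0}) for i, (label, f) in enumerate(freqs.items())]
    heapq.heapify(heap)
    counter = len(heap)
    while len(heap) > 1:
        f1, _, d1 = heapq.heappop(heap)
        f2, _, d2 = heapq.heappop(heap)
        # Merging two subtrees pushes every label in them one level deeper.
        merged = {label: depth + 1 for label, depth in {**d1, **d2}.items()}
        heapq.heappush(heap, (f1 + f2, counter, merged))
        counter += 1
    return heap[0][2]

def hashed_bigrams(tokens, num_buckets):
    """Map each adjacent token pair to a bucket id in [0, num_buckets)."""
    ids = []
    for left, right in zip(tokens, tokens[1:]):
        key = (left + " " + right).encode("utf-8")
        ids.append(zlib.crc32(key) % num_buckets)  # stand-in hash, not fastText's
    return ids

lengths = huffman_code_lengths(
    {"sports": 50, "car": 25, "military": 15, "technology": 10})
print(lengths)
print(hashed_bigrams(["中国队", "率先", "打破", "僵局"], num_buckets=1000))
```

With hashing, the memory used for N-gram features is bounded by the number of buckets rather than by the number of distinct N-grams, at the cost of occasional collisions.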

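## fastText supervised learning (classification) example

The package containing the fastText Python interface can be installed with pip install fasttext.

For text classification, fastText requires the text to be stored in the following format: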

__label__2 , birchas chaim , yeshiva birchas chaim is a orthodox jewish mesivta high school in lakewood township new jersey . it was founded by rabbi shmuel zalmen stein in 2001 after his father rabbi chaim stein asked him to open a branch of telshe yeshiva in lakewood . as of the 2009-10 school year the school had an enrollment of 76 students and 6 . 6 classroom teachers ( on a fte basis ) for a student–teacher ratio of 11 . 5 1 . 

__label__6 , motor torpedo boat pt-41 , motor torpedo boat pt-41 was a pt-20-class motor torpedo boat of the united states navy built by the electric launch company of bayonne new jersey . the boat was laid down as motor boat submarine chaser ptc-21 but was reclassified as pt-41 prior to its launch on 8 july 1941 and was completed on 23 july 1941 . 

__label__11 , passiflora picturata , passiflora picturata is a species of passion flower in the passifloraceae family . 

__label__13 , naya din nai raat , naya din nai raat is a 1974 bollywood drama film directed by a . bhimsingh . the film is famous as sanjeev kumar reprised the nine-role epic performance by sivaji ganesan in navarathri ( 1964 ) which was also previously reprised by akkineni nageswara rao in navarathri ( telugu 1966 ) . this film had enhanced his status and reputation as an actor in hindi cinema .

The leading __label__ is a prefix, which can also be customized; what follows __label__ is the class.

We define our 5 classes as:
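
To make the format concrete, here is a minimal sketch that separates the label from the remaining tokens of such a line. The helper `split_labels` is a hypothetical name introduced for illustration and is not part of the fasttext package.

```python
def split_labels(line, prefix="__label__"):
    """Split a fastText-style line into (labels, remaining tokens)."""
    labels, tokens = [], []
    for token in line.split():
        if token.startswith(prefix):
            labels.append(token[len(prefix):])
        else:
            tokens.append(token)
    return labels, tokens

labels, tokens = split_labels(
    "__label__11 , passiflora picturata , passiflora picturata is a species")
print(labels)      # ['11']
print(tokens[:3])  # [',', 'passiflora', 'picturata']
```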

- 1:technology
- 2:car
- 3:entertainment
- 4:military
- 5:sports

### Generating the text format

In [1]:
import jieba
import pandas as pd
import random

cate_dic = {'technology':1, 'car':2, 'entertainment':3, 'military':4, 'sports':5}

df_technology = pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/technology_news.csv", encoding='utf-8')
df_technology = df_technology.dropna()

df_car = pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/car_news.csv", encoding='utf-8')
df_car = df_car.dropna()

df_entertainment = pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/entertainment_news.csv", encoding='utf-8')
df_entertainment = df_entertainment.dropna()

df_military = pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/military_news.csv", encoding='utf-8')
df_military = df_military.dropna()

df_sports = pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/sports_news.csv", encoding='utf-8')
df_sports = df_sports.dropna()

technology = df_technology.content.values.tolist()[1000:21000]
car = df_car.content.values.tolist()[1000:21000]
entertainment = df_entertainment.content.values.tolist()[:20000]
military = df_military.content.values.tolist()[:20000]
sports = df_sports.content.values.tolist()[:20000]

In [2]:
stopwords=pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/stopwords.txt",index_col=False,quoting=3,sep="\t",names=['stopword'], encoding='utf-8')
# Keep the stop words in a set so that each membership test is a fast lookup
stopwords=set(stopwords['stopword'].values)

def preprocess_text(content_lines, sentences, category):
    for line in content_lines:
        try:
            segs=jieba.lcut(line)
            segs = filter(lambda x:len(x)>1, segs)
            segs = filter(lambda x:x not in stopwords, segs)
            sentences.append("__label__"+str(category)+" , "+" ".join(segs))
        except Exception:
            print (line)
            continue 

In [3]:
# Generate the training data
sentences = []

preprocess_text(technology, sentences, cate_dic['technology'])
preprocess_text(car, sentences, cate_dic['car'])
preprocess_text(entertainment, sentences, cate_dic['entertainment'])
preprocess_text(military, sentences, cate_dic['military'])
preprocess_text(sports, sentences, cate_dic['sports'])

random.shuffle(sentences)

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.761 seconds.
Prefix dict has been built successfully.


In [4]:
sentences[0]

'__label__3 , 指引 一言以蔽之 遏制 天价 片酬 思路 明星 中心 编剧 中心 收入 激励 核心 要素 打破 利益 分配 格局'

In [5]:
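print ("writing data to fasttext format...")
# Use a context manager so the file is flushed and closed; write UTF-8 explicitly for the Chinese text
with open('train_data.txt', 'w', encoding='utf-8') as out:
    for sentence in sentences:
        out.write(sentence+"\n")
print ("done!")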

writing data to fasttext format...
done!


### Training a model with fastText

In [6]:
import fasttext
#https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module

classifier = fasttext.train_supervised("train_data.txt", lr=0.1, dim=100, epoch=5 ,word_ngrams=2, loss='softmax')
classifier.save_model("model_file.bin")

# classifier = fasttext.supervised('train_data.txt', 
#                                        model='classifier.model', 
#                                        label_prefix='__label__')

Read 2M words
Number of words:  138518
Number of labels: 5
Progress: 100.0% words/sec/thread: 1067180 lr:  0.000000 avg.loss:  0.353969 ETA:   0h 0m 0s


### Evaluating the model

In [7]:
# Note: this evaluates on the training file itself, so the scores below are optimistic
result = classifier.test('train_data.txt')

In [8]:
result

(87573, 0.9829399472440136, 0.9829399472440136)

In [9]:
print ('P@1:', result[1])
print ('R@1:', result[2])
print ('Number of examples:', result[0])

P@1: 0.9829399472440136
R@1: 0.9829399472440136
Number of examples: 87573
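
The evaluation above runs on train_data.txt, the same file used for training, so it likely overstates how the model performs on unseen text. A minimal sketch of holding out a validation portion before training, where the helper name, the 20% ratio, and the seed are illustrative assumptions:

```python
import random

def split_train_valid(lines, valid_ratio=0.2, seed=42):
    """Shuffle a copy of the lines and split off a validation portion."""
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)
    n_valid = int(len(shuffled) * valid_ratio)
    return shuffled[n_valid:], shuffled[:n_valid]

# Stand-in lines in the same format as the training sentences.
example_lines = ["__label__%d , text %d" % (i % 5 + 1, i) for i in range(10)]
train_lines, valid_lines = split_train_valid(example_lines)
print(len(train_lines), len(valid_lines))  # 8 2
```

Each list would then be written to its own file, with the model trained on the first file and classifier.test called on the second.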


### Making predictions

In [10]:
label_to_cate = {1:'technology', 2:'car', 3:'entertainment', 4:'military', 5:'sports'}

texts = ['中新网 日电 2018 预赛 亚洲区 强赛 中国队 韩国队 较量 比赛 上半场 分钟 主场 作战 中国队 率先 打破 场上 僵局 利用 角球 机会 大宝 前点 攻门 得手 中国队 领先']
labels = classifier.predict(texts)
print (labels)
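# Strip the "__label__" prefix rather than reading only the last character, so multi-digit labels also work
print (label_to_cate[int(labels[0][0][0].replace('__label__', ''))])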

([['__label__5']], [array([0.9999566], dtype=float32)])
sports


In [11]:
# labels = classifier.predict_proba(texts)  # not needed: .predict already returns the probabilities
# print (labels)

### Top-K predictions

In [12]:
labels = classifier.predict(texts, k=3)
print (labels)
# labels = classifier.predict_proba(texts, k=3)
# print (labels)

([['__label__5', '__label__4', '__label__3']], [array([9.9995661e-01, 4.5699966e-05, 2.1432630e-05], dtype=float32)])
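
The nested structure returned for k=3 can be turned into readable (category, probability) pairs. The sketch below uses a hypothetical helper `top_k_categories` and copies the values from the output above, with plain lists standing in for the numpy arrays.

```python
def top_k_categories(prediction, label_to_cate, prefix="__label__"):
    """Convert a (labels, probabilities) result into (category, probability) pairs per text."""
    all_labels, all_probs = prediction
    results = []
    for labels, probs in zip(all_labels, all_probs):
        pairs = [(label_to_cate[int(label[len(prefix):])], float(p))
                 for label, p in zip(labels, probs)]
        results.append(pairs)
    return results

label_to_cate = {1: 'technology', 2: 'car', 3: 'entertainment', 4: 'military', 5: 'sports'}
prediction = ([['__label__5', '__label__4', '__label__3']],
              [[9.9995661e-01, 4.5699966e-05, 2.1432630e-05]])

ranked = top_k_categories(prediction, label_to_cate)
for category, prob in ranked[0]:
    print(category, prob)
```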


## fastText unsupervised learning on text

In [13]:
def preprocess_text_unsupervised(content_lines, sentences, category):
    for line in content_lines:
        try:
            segs=jieba.lcut(line)
            segs = filter(lambda x:len(x)>1 and x not in stopwords, segs)
            sentences.append(" ".join(segs))
        except Exception:
            print (line)
            continue
# Generate the unsupervised training data
sentences = []

preprocess_text_unsupervised(technology, sentences, cate_dic['technology'])

In [14]:
preprocess_text_unsupervised(car, sentences, cate_dic['car'])

In [15]:
preprocess_text_unsupervised(entertainment, sentences, cate_dic['entertainment'])

In [16]:
preprocess_text_unsupervised(military, sentences, cate_dic['military'])

In [17]:
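# Use the unsupervised helper here as well. The original run called preprocess_text, which
# prepends "__label__5 , " to every sports line; that is likely why the log below reports
# 1 label and ',' shows up in the vocabulary.
preprocess_text_unsupervised(sports, sentences, cate_dic['sports'])

print ("writing data to fasttext unsupervised learning format...")
# Use a context manager so the file is flushed and closed; write UTF-8 explicitly for the Chinese text
with open('unsupervised_train_data.txt', 'w', encoding='utf-8') as out:
    for sentence in sentences:
        out.write(sentence+"\n")
print ("done!")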

writing data to fasttext unsupervised learning format...
done!


In [18]:
import fasttext
model = fasttext.train_unsupervised("unsupervised_train_data.txt", model='skipgram', lr=0.05, dim=100, ws=5, epoch=5)
model.save_model("skipgram_file.bin")
# Skipgram model

# model = fasttext.skipgram('unsupervised_train_data.txt', 'model')
# print (model.words) # list of words in dictionary

# # CBOW model
# model = fasttext.cbow('unsupervised_train_data.txt', 'model')
# print (model.words) # list of words in dictionary

Read 2M words
Number of words:  41643
Number of labels: 1
Progress: 100.0% words/sec/thread:   51988 lr:  0.000000 avg.loss:  1.810080 ETA:   0h 0m 0s
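
In the training call above, model='skipgram' selects the skip-gram objective and ws=5 sets the size of the context window. As a rough illustration of what a skip-gram model trains on, the sketch below lists (center, context) pairs for a fixed window. It is a simplification under stated assumptions: the function name is hypothetical, and fastText additionally uses subsampling, negative sampling, and a randomly sized window.

```python
def skipgram_pairs(tokens, window):
    """List (center, context) pairs for every token within `window` positions."""
    pairs = []
    for i, center in enumerate(tokens):
        lo = max(0, i - window)
        hi = min(len(tokens), i + window + 1)
        for j in range(lo, hi):
            if j != i:
                pairs.append((center, tokens[j]))
    return pairs

pairs = skipgram_pairs(["中国队", "率先", "打破", "僵局"], window=1)
for center, context in pairs:
    print(center, context)
```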


In [19]:
print(model.words[0:100])
print (model['赛季'])

['</s>', ',', '中国', '发展', '汽车', '用户', '技术', '比赛', '市场', '平台', '服务', '电影', '产品', '2017', '企业', '数据', '北京', '公司', '互联网', '行业', '手机', '提供', '内容', '美国', '未来', '时间', '工作', '品牌', '日电', '网络', '智能', '国家', '合作', '观众', '能力', '系统', '世界', '全球', '领域', '球员', '直播', '创新', '训练', '提升', '中新网', '国际', '足球', '希望', '国内', '信息', '战略', '节目', '方式', '情况', '活动', '媒体', '生活', '去年', '球队', '视频', '项目', '带来', '科技', '发布', '包括', '消费者', '现场', '产业', '相关', '体验', '俱乐部', '建设', '增长', '体育', '显示', '超过', '全国', '需求', '设计', '人工智能', '百度', '城市', '模式', '赛事', '关注', '打造', '新能源', '上海', '参加', '推出', '文化', '业务', '表现', '功能', '选择', '集团', '导演', '海军', '拥有', '发现']
[ 0.38464445  0.5256937   0.16470358  0.283395   -0.23804471 -0.48227912
 -0.20009351  0.45830518 -0.09409188  0.39611045 -0.23909107 -0.25167876
 -0.09522913  0.11575998 -0.00747415 -0.11214487 -0.4335881  -0.57697177
 -0.5558986   0.0837727  -0.12260437  0.70339715 -0.24817887  0.5360349
 -0.27602118  0.10558875 -0.52715206  0.5204655   0.07781375  0.24900925
 -0.20513883 -0.03990277 

## Comparison with gensim's word2vec

In [21]:
import jieba
import pandas as pd
import random

cate_dic = {'technology':1, 'car':2, 'entertainment':3, 'military':4, 'sports':5}

df_technology = pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/technology_news.csv", encoding='utf-8')
df_technology = df_technology.dropna()

df_car = pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/car_news.csv", encoding='utf-8')
df_car = df_car.dropna()

df_entertainment = pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/entertainment_news.csv", encoding='utf-8')
df_entertainment = df_entertainment.dropna()

df_military = pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/military_news.csv", encoding='utf-8')
df_military = df_military.dropna()

df_sports = pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/sports_news.csv", encoding='utf-8')
df_sports = df_sports.dropna()

technology = df_technology.content.values.tolist()[1000:21000]
car = df_car.content.values.tolist()[1000:21000]
entertainment = df_entertainment.content.values.tolist()[:20000]
military = df_military.content.values.tolist()[:20000]
sports = df_sports.content.values.tolist()[:20000]

stopwords=pd.read_csv("/usr/local/codeData/jupyterData/jupyter-bd/textCategorization/data/stopwords.txt",index_col=False,quoting=3,sep="\t",names=['stopword'], encoding='utf-8')
# Keep the stop words in a set so that each membership test is a fast lookup
stopwords=set(stopwords['stopword'].values)

In [22]:
technology[0]

'\u3000\u3000去年年中，周星驰、徐克的《西游伏妖篇》电影刚刚公布上线日期，完美世界在众多版权争夺方中脱颖而出，拿下了《西游伏妖篇》电影的独家手游开发版权，并计划同名手游将与电影同期发布。1月24日，由《西游伏妖篇》电影版权方正版授权，完美世界、聚力互娱联合发行的手游《西游伏妖篇》正式在iOS开放公测。然而，就在该作上架苹果商店的同时，苹果商店内却已有数款以“西游伏妖篇”或近似字样为名的无授权游戏产品横行肆意。'

In [23]:
def preprocess_text_unsupervised(content_lines, sentences, category):
    for line in content_lines:
        try:
            segs=jieba.lcut(line)
            segs = filter(lambda x:len(x)>1 and x not in stopwords, segs)
#             segs = filter(lambda x:len(x)>1, segs)
#             segs = filter(lambda x:x not in stopwords, segs)
            sentences.append(list(segs))
        except Exception:
            print (line)
            continue
# Generate the unsupervised training data
sentences = []

preprocess_text_unsupervised(technology, sentences, cate_dic['technology'])

preprocess_text_unsupervised(car, sentences, cate_dic['car'])
preprocess_text_unsupervised(entertainment, sentences, cate_dic['entertainment'])
preprocess_text_unsupervised(military, sentences, cate_dic['military'])
preprocess_text_unsupervised(sports, sentences, cate_dic['sports'])

In [24]:
from gensim.models import Word2Vec
# gensim >= 4.0 uses vector_size; on gensim 3.x the same parameter is named size
model = Word2Vec(sentences, vector_size=100, window=5, min_count=5, workers=4)
model.save("gensim_word2vec.model")

In [25]:
model.wv["西游"]

array([ 1.34200644e+00,  1.33819652e+00, -8.88573587e-01,  1.07016146e+00,
       -1.63786304e+00, -1.00203626e-01, -8.65577579e-01,  9.92751300e-01,
        6.39973819e-01,  7.53422976e-01,  1.93136215e-01,  4.39125568e-01,
       -1.42659056e+00,  1.98122300e-03,  5.18594801e-01, -1.07973099e+00,
        1.23641300e+00,  7.67392755e-01,  9.01606143e-01,  7.35682189e-01,
        1.39535522e+00,  1.38254046e+00, -3.41796994e-01,  1.65568507e+00,
        2.05397153e+00,  2.41655130e-02, -2.17170358e-01,  1.02553163e-02,
       -2.21165705e+00, -1.29946005e+00, -6.98596597e-01,  8.50975811e-01,
        1.29141903e+00,  1.62934586e-01,  1.08236182e+00,  9.67248917e-01,
       -1.19605029e+00, -1.53455043e+00, -1.40964997e+00, -6.85151935e-01,
       -9.69715774e-01, -2.48186179e-02, -3.69338065e-01,  3.51844914e-02,
       -1.84719825e+00, -1.90779552e-01,  9.79875922e-01, -1.17010140e+00,
       -1.41697556e-01,  1.41309559e+00,  3.65389079e-01,  1.05244970e+00,
       -1.11130261e+00, -

In [26]:
model.wv.most_similar('赛季')

[('亚冠', 0.8700935244560242),
 ('OL3', 0.8449038863182068),
 ('国安', 0.8443719148635864),
 ('本赛季', 0.8385491371154785),
 ('中甲', 0.8265472650527954),
 ('辽足', 0.8256725072860718),
 ('延后', 0.8246681690216064),
 ('中超联赛', 0.8166651725769043),
 ('国奥队', 0.8166406154632568),
 ('BIG4', 0.8158341646194458)]
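
The scores returned by most_similar are cosine similarities between word vectors. A minimal sketch of that ranking in plain Python, using made-up 2-dimensional vectors rather than the trained model (the function names and all values are hypothetical):

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def most_similar(query, vectors, topn=2):
    """Rank every other word by cosine similarity to the query word."""
    scored = [(word, cosine_similarity(vectors[query], vec))
              for word, vec in vectors.items() if word != query]
    return sorted(scored, key=lambda item: item[1], reverse=True)[:topn]

# Toy vectors for illustration only.
vectors = {"赛季": [1.0, 0.2], "亚冠": [0.9, 0.3], "汽车": [-0.2, 1.0]}
neighbours = most_similar("赛季", vectors)
print(neighbours)
```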