diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1ee75ff
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,13 @@
+*.DS_Store
+build/
+
+.vscode
+.idea
+.project
+.cproject
+.pydevproject
+.settings/
+.test_env/
+third_party/
+
+*~
diff --git a/AUTHORS b/AUTHORS
new file mode 100644
index 0000000..b75a7e3
--- /dev/null
+++ b/AUTHORS
@@ -0,0 +1,14 @@
+# Names should be added to this file like so:
+# Name or Organization
+
+Baidu.com, Inc.
+
+# Initial version authors:
+Jiang Di
+Chen Zeyu
+Jiang Jiajun
+Lian Rongzhong
+Li Chen
+Bao Siqi
+
+# Partial list of contributors:
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..00e96c3
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,27 @@
+Copyright (c) 2017, Baidu, Inc.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of Baidu, Inc. nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..6fe2fa1
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,76 @@
+ifdef config
+include $(config)
+endif
+
+ifndef DEPS_PATH
+DEPS_PATH = $(shell pwd)/third_party
+endif
+
+ifndef PROTOC
+PROTOC = ${DEPS_PATH}/bin/protoc
+endif
+
+CXX=g++
+CXXFLAGS=-pipe \
+         -W \
+         -Wall \
+         -fPIC \
+         -std=c++11 \
+         -fno-omit-frame-pointer \
+         -fpermissive \
+         -O3 \
+         -ffast-math \
+         -funroll-all-loops
+
+INCPATH=-I./include/ \
+        -I./include/familia \
+        -I./third_party/include
+
+LDFLAGS_SO = -L$(DEPS_PATH)/lib -L./build/ -lfamilia -lprotobuf -lglog -lgflags
+
+.PHONY: all
+all: familia
+	@echo $(SOURCES)
+	@echo $(OBJS)
+	$(CXX) $(CXXFLAGS) $(INCPATH) build/demo/inference_demo.o -Xlinker "-(" $(LDFLAGS_SO) -Xlinker "-)" -o inference_demo
+	$(CXX) $(CXXFLAGS) $(INCPATH) build/demo/semantic_matching_demo.o -Xlinker "-(" $(LDFLAGS_SO) -Xlinker "-)" -o semantic_matching_demo
+	$(CXX) $(CXXFLAGS) $(INCPATH) build/demo/word_distance_demo.o -Xlinker "-(" $(LDFLAGS_SO) -Xlinker "-)" -o word_distance_demo
+	$(CXX) $(CXXFLAGS) $(INCPATH) build/demo/topic_word_demo.o -Xlinker "-(" $(LDFLAGS_SO) -Xlinker "-)" -o topic_word_demo
+
+include depends.mk
+
+.PHONY: clean
+clean:
+	rm -rf inference_demo
+	rm -rf semantic_matching_demo
+	rm -rf word_distance_demo
+	rm -rf topic_word_demo
+	rm -rf build
+	find src -name "*.pb.[ch]*" -delete
+
+# third-party dependencies
+deps: ${GLOGS} ${GFLAGS} ${PROTOBUF}
+	@echo "dependencies installed!"
+
+familia: build/libfamilia.a
+
+OBJS = $(addprefix build/, vose_alias.o inference_engine.o model.o vocab.o document.o sampler.o config.o util.o semantic_matching.o tokenizer.o \
+	demo/inference_demo.o \
+	demo/semantic_matching_demo.o \
+	demo/word_distance_demo.o \
+	demo/topic_word_demo.o)
+
+build/libfamilia.a: include/config.pb.h $(OBJS)
+	@echo Target $@;
+	ar crv $@ $(filter %.o, $?)
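+
+# A note on the two commands of the pattern rule below (descriptive only,
+# not extra build logic): the first $(CXX) invocation uses -MM -MT to write
+# a build/%.d fragment listing the headers each object depends on, and the
+# second invocation performs the actual compilation. The generated .d
+# fragments only take effect if they are included somewhere, e.g. via
+# "-include $(OBJS:.o=.d)".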
+
+build/%.o: src/%.cpp
+	@mkdir -p $(@D)
+	$(CXX) $(INCPATH) $(CXXFLAGS) -MM -MT build/$*.o $< >build/$*.d
+	$(CXX) $(INCPATH) $(CXXFLAGS) -c $< -o $@
+
+# build proto
+include/config.pb.h src/config.cpp : proto/config.proto
+	$(PROTOC) --cpp_out=./src --proto_path=./proto $<
+	mv src/config.pb.h ./include/familia
+	mv src/config.pb.cc ./src/config.cpp
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..38e1b5a
--- /dev/null
+++ b/README.md
@@ -0,0 +1,114 @@
+# Familia
+
+# Building the code
+The third-party dependencies are gflags, glog, and protobuf; the compiler must support C++11, e.g. g++ >= 4.8.
+By default the dependencies are fetched and installed automatically.
+
+    git clone ssh://g@gitlab.baidu.com:8022/chenzeyu01/familia.git
+    sh build.sh # also fetches and installs the third-party dependencies
+
+# Model download
+
+    cd model
+    sh download_model.sh
+
+# Notes
+* If you see errors about missing shared libraries such as libglog.so or libgflags.so, add third_party to the LD_LIBRARY_PATH environment variable.
+
+
+    export LD_LIBRARY_PATH=./third_party/lib:$LD_LIBRARY_PATH
+
+
+# Running the demos
+## Document topic inference
+
+    sh run_inference_demo.sh # runs the document topic inference demo
+
+Once the program is running, documents are read from standard input, one document per line, and the topic distribution of each document is printed, as shown below:
+
+    请输入需要推断主题分布的文档:
+    百度又一次展示了自动驾驶领域领导者的大气风范,发布了一项名为“Apollo(阿波罗)”的新计划,向汽车行业及自动驾驶领域的合作伙伴提供一个开放、完整、安全的软件平台,帮助他们结合车辆和硬件系统,快速搭建一套属于自己的完整的自动驾驶系统。
+
+    文档主题分布:
+    9159:0.103704 4296:0.072840 7486:0.058025 1378:0.037037 1073:0.037037 2414:0.037037 3935:0.034568 5921:0.032099 7380:0.032099 8643:0.032099 4757:0.030864 6808:0.025926 7185:0.022222 4091:0.019753 1167:0.017284 8843:0.016049 5292:0.014815 2507:0.014815 9914:0.013580 2520:0.011111 7658:0.011111 249:0.011111 2017:0.009877 2995:0.008642 4021:0.008642 7163:0.008642 9336:0.007407 1438:0.007407 136:0.007407 7095:0.007407 2313:0.007407 4309:0.007407 1314:0.006173 3573:0.006173 9529:0.006173 477:0.004938 6446:0.004938 281:0.004938 4072:0.004938 9082:0.004938 847:0.004938 27:0.004938 5872:0.004938 2720:0.004938 1322:0.004938 8848:0.003704 7765:0.003704 7838:0.003704 7891:0.003704 7918:0.003704 1592:0.003704 7107:0.003704 1766:0.003704 1812:0.003704 6726:0.003704 6513:0.003704 5660:0.003704 8996:0.003704 1434:0.003704 3407:0.003704 2285:0.003704 500:0.003704 3615:0.003704 3766:0.003704 4704:0.002469 1449:0.002469 9599:0.002469 7779:0.002469 2565:0.002469 7425:0.002469 1665:0.002469 9473:0.002469 9395:0.002469 872:0.002469 8411:0.002469 8606:0.002469 4490:0.002469 8722:0.002469 386:0.002469 4817:0.002469 8826:0.002469 1219:0.002469 75:0.002469 8859:0.002469 7716:0.001235 9280:0.001235 1399:0.001235 9304:0.001235 1:0.001235 9536:0.001235 8099:0.001235 8266:0.001235 1175:0.001235 91:0.001235 5809:0.001235 3087:0.001235 3265:0.001235 3752:0.001235 3832:0.001235 3908:0.001235 2515:0.001235 1046:0.001235 804:0.001235 1953:0.001235 5263:0.001235 428:0.001235 5514:0.001235 5624:0.001235 7696:0.001235 5826:0.001235 5906:0.001235 6196:0.001235 6240:0.001235 6378:0.001235 1896:0.001235 6875:0.001235 6917:0.001235 244:0.001235 7469:0.001235 1516:0.001235 7488:0.001235
+
+Before each colon is a topic ID, and after it the probability of that topic; entries are sorted by topic probability in descending order.
+Other models can be selected by changing --work_dir and --conf_file in the script, e.g.
+
+    --work_dir="./model/webpage/" --conf_file="lda.conf" # webpage LDA topic model
+    --work_dir="./model/webpage/" --conf_file="slda.conf" # webpage SentenceLDA topic model
+
+## Semantic matching
+
+    sh run_semantic_matching_demo.sh # runs the semantic matching demo
+
+The default mode computes the semantic match between a short text and a long text; a run looks like this:
+
+    请输入短文本:
+    百度宣布阿波罗计划 开放自动驾驶技术有望改变汽车产业
+    请输入长文本:
+    百度又一次展示了自动驾驶领域领导者的大气风范,发布了一项名为“Apollo(阿波罗)”的新计划,向汽车行业及自动驾驶领域的合作伙伴提供一个开放、完整、安全的软件平台,帮助他们结合车辆和硬件系统,快速搭建一套属于自己的完整的自动驾驶系统。
+    LDA sim = 0.0133234    TWE sim = 0.128288
+
+Changing the --mode flag in the script to 1 switches to long-text/long-text semantic similarity mode; a run looks like this:
+
+    请输入文档1:
+    在人工智能发展得最为系统化的硅谷,AI工程师们的薪水远高于其他领域的同行。随着人工智能概念的不断深入人心,人工智能的人才愈发的紧俏,时至今日,大学刚毕业的博士也能坐拥八九十万的年薪,与资深的硅谷工程师相媲美。
+    请输入文档2:
+    在国内,部分企业早已瞄准人才的短板,走在了业界的前面。百度是最早进行AI的人才培养布局的,他们同国内诸多高校开展合作,共建工程实验室,在数据开放和资源共享上进行各种合作。这种方式类似美国在人工智能教育领域推行的“硅谷-斯坦福”校企联动模式,一方面斯坦福大学为硅谷提供了人才和科研成果,另一方面硅谷为斯坦福大学提供资金支持和大数据,以助力他们的科研能有更大的突破。
+    Jensen Shannon Divergence = 1.13397
+    Hellinger Distance = 0.889616
+
+## Nearest word queries
+
+    sh run_word_distance_demo.sh # runs the nearest word query demo
+
+Once the program is running, words are read from standard input, one word per line, and the K words nearest to each input word are printed, as shown below:
+
+    请输入词语:	篮球
+    Word                    Cosine distance
+    --------------------------------
+    足球                    0.903682
+    网球                    0.842661
+    羽毛球                  0.836915
+    足球比赛                0.809366
+    五人制足球              0.799211
+    美式足球                0.791207
+    中国足球                0.788995
+    乒乓球                  0.788278
+    五人制                  0.784913
+    足球新闻                0.783203
+
+Each row is one word, and the number is its cosine distance to the input word; rows are sorted in descending order. Other models can be selected by changing --work_dir and --conf_file in the script, and --top_k sets how many words are shown, e.g.
+
+    --work_dir="./model/webpage/" --conf_file="lda.conf" --top_k=10 # webpage LDA topic model, show the 10 nearest words
+    --work_dir="./model/webpage/" --conf_file="slda.conf" --top_k=20 # webpage SentenceLDA topic model, show the 20 nearest words
+
+## Topic word queries
+In a TWE model, the cosine similarity between a topic embedding and a word embedding measures how relevant the word is to that topic, which yields the K nearest words of every topic. Similarly, an LDA model provides the generation probability of every word under each topic. The topic word demo shows the topic words produced by both models.
+
+    sh run_topic_word_demo.sh # runs the topic word query demo
+
+Once the program is running, topic ids are read from standard input, one id per line, and the K nearest words of that topic under both the TWE model and the topic model are printed, as shown below:
+
+    请输入主题编号(0-10000):	105
+    TWE result                LDA result
+    ------------------------------------
+    卫生检疫                  国家
+    检验检疫                  出入境
+    上海口岸                  外籍
+    外经贸部                  检验检疫
+    正式批准                  检验检疫局
+    认监委                    国外
+    卫生注册证书              互认
+    检验检疫局                要闻
+    资格认可                  奖励旅游
+    许可制度                  公布
+
+Each row contains two words: the first is retrieved by TWE and the second by the topic model, sorted by relevance in descending order. Another TWE model can be selected via --work_dir and --emb_file in the script, and --topic_words_file points at the topic words of the corresponding topic model, e.g.
+
+    --work_dir="./model/webpage/" --emb_file="webpage_twe_lda.model" --topic_words_file="topic_words.lda.txt" # TWE model trained from the webpage LDA topic model, plus its topic words
+
+# Notes
+* The code ships with a simple built-in FMM tokenizer that only does forward matching against the vocabulary of the topic model. It is intended for the demos only; if you need higher tokenization and semantic accuracy, use an open-source tokenizer and import the topic model vocabulary through its custom dictionary feature.
+
diff --git a/benchmark.sh b/benchmark.sh
new file mode 100644
index 0000000..917144c
--- /dev/null
+++ b/benchmark.sh
@@ -0,0 +1,33 @@
+export LD_LIBRARY_PATH=./third_party/lib:$LD_LIBRARY_PATH
+
+if [ -d news_t1000 ];then
+    echo "model file downloaded already"
+else
+    rm -rf news_t1000
+    mkdir news_t1000
+
+    echo "get example model..."
+    cd news_t1000
+    wget ftp://nj03-rp-m22nlp062.nj03.baidu.com/home/disk0/chenzeyu01/public/infer_data/news_t1000/*
+
+    cd ..
+    rm -rf input.sample
+
+    echo "get input data..."
+    wget ftp://nj03-rp-m22nlp062.nj03.baidu.com/home/disk0/chenzeyu01/public/infer_data/input.sample
+fi
+
+echo "running infer program..."
+
+# ./lda-infer model_path, lda-infer.conf, #burn_in_iter #total_iter
+# cat input.merge.sample | ./lda-infer ./news_t1000 lda_infer.conf > infer.result
+#cat example/example.txt | ./inference_demo --work_dir="./news_t1000" --conf_file="model.conf" > infer.result
+cat example/input.sample | ./test --work_dir="./news_t1000" --conf_file="model.conf" > infer.result
+#head input.sample | ./lda-infer ./news_t1000 lda_infer.conf > infer.result
+
+echo "infer result stored in infer.result"
+python scripts/jsd.py infer.result news_t1000/doc_topic.txt 1000
+#python tools/jsd.py infer.slda.mh.result infer.slda.gs.result 1000
+#python tools/jsd.py infer.result infer.slda.gs.result 1000
+echo "new alias_table jsd / new random / fix_random_seed = 0.135918819452 / 0.146278813358 / 0.135296716886"
+echo "original time cost 4.22496 sec"
diff --git a/build.sh b/build.sh
new file mode 100644
index 0000000..cbf609a
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,3 @@
+mkdir -p third_party
+make deps
+make clean && make -j4;
diff --git a/depends.mk b/depends.mk
new file mode 100644
index 0000000..25f29cc
--- /dev/null
+++ b/depends.mk
@@ -0,0 +1,38 @@
+# Install dependencies
+
+URL=http://raw.githubusercontent.com/ZeyuChen/third_party/master/package/
+ifndef WGET
+  WGET = wget --no-check-certificate
+endif
+
+# protobuf
+PROTOBUF = ${DEPS_PATH}/include/google/protobuf/message.h
+${PROTOBUF}:
+	$(eval FILE=protobuf-2.5.0.tar.gz)
+	$(eval DIR=protobuf-2.5.0)
+	rm -rf $(FILE) $(DIR)
+	$(WGET) $(URL)/$(FILE) && tar -zxf $(FILE)
+	cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure --disable-shared -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
+	rm -rf $(FILE) $(DIR)
+protobuf: | ${PROTOBUF}
+
+# gflags
+GFLAGS = ${DEPS_PATH}/include/google/gflags.h
+${GFLAGS}:
+	$(eval FILE=gflags-2.0-no-svn-files.tar.gz)
+	$(eval DIR=gflags-2.0)
+	rm -rf $(FILE) $(DIR)
+	$(WGET) $(URL)/$(FILE) && tar -zxf $(FILE)
+	cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
+	rm -rf $(FILE) $(DIR)
+gflags: | ${GFLAGS}
+
+# glog
+GLOGS = ${DEPS_PATH}/include/glog/logging.h
+${GLOGS}:
+	$(eval FILE=glog-0.3.3.tar.gz)
+	$(eval DIR=glog-0.3.3)
+	rm -rf $(FILE) $(DIR)
+	$(WGET) $(URL)/$(FILE) && tar -zxf $(FILE)
+	cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) --with-gflags=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
+	rm -rf $(FILE) $(DIR)
+glog: | ${GLOGS}
diff --git a/include/familia/document.h b/include/familia/document.h
new file mode 100644
index 0000000..657320b
--- /dev/null
+++ b/include/familia/document.h
@@ -0,0 +1,116 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
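+//
+// Usage sketch (mirrors how InferenceEngine::infer drives these structures;
+// the values below are made up):
+//
+//   LDADoc doc;
+//   doc.init(num_topics);                      // allocate per-topic counters
+//   doc.add_token({/*topic=*/3, /*id=*/42});   // token with a random initial topic
+//   // ... resample topics; after each post-burn-in pass call:
+//   doc.accumulate_topic_sum();
+//   std::vector<Topic> dist;
+//   doc.topic_dist(dist);                      // sparse, sorted p(topic|doc)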
+//
+// Author: chenzeyu01@baidu.com
+
+#ifndef FAMILIA_DOCUMENT_H
+#define FAMILIA_DOCUMENT_H
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+namespace familia {
+
+// basic topic structure: a topic id together with its probability
+struct Topic {
+    int tid; // topic id
+    double prob; // topic probability
+    bool operator<(const Topic& t) const {
+        return prob > t.prob; // sort by topic probability in descending order
+    }
+};
+
+// basic storage unit of an LDA document: a word id and its topic id
+struct Token {
+    int topic;
+    int id;
+};
+
+// basic storage unit of a SentenceLDA document: the word ids of a sentence and their shared topic id
+struct Sentence {
+    int topic;
+    std::vector<int> tokens;
+};
+
+// holds the result of LDA model inference
+class LDADoc {
+public:
+    LDADoc() = default;
+
+    explicit LDADoc(int num_topics) {
+        init(num_topics);
+    }
+
+    // initializes the document structure for the given number of topics
+    void init(int num_topics);
+
+    // adds a new token
+    void add_token(const Token& token);
+
+    Token& token(size_t index) {
+        return _tokens[index];
+    }
+
+    // assigns new_topic to the index-th token and updates the document topic distribution accordingly
+    void set_topic(int index, int new_topic);
+
+    // returns the number of tokens in the document
+    inline size_t size() const {
+        return _tokens.size();
+    }
+
+    inline size_t topic_sum(int topic_id) const {
+        return _topic_sum[topic_id];
+    }
+
+    // returns the document topic distribution in sparse format, by default sorted by topic probability in descending order
+    void topic_dist(std::vector<Topic>& topic_dist, bool sort = true) const;
+
+    // returns the document topic distribution in dense format
+    void dense_topic_dist(std::vector<float>& dense_dist) const;
+
+    // accumulates the result of each sampling round to approximate a smoother distribution
+    void accumulate_topic_sum();
+
+protected:
+    // number of topics
+    int _num_topics;
+    // inference result storage
+    std::vector<Token> _tokens;
+    // per-topic counts of the document within one sampling round
+    std::vector<int> _topic_sum;
+    // per-topic counts accumulated over multiple sampling rounds
+    std::vector<int64_t> _accum_topic_sum;
+};
+
+// SentenceLDA document
+// inherits from LDADoc and adds the add_sentence interface
+class SLDADoc : public LDADoc {
+public:
+    SLDADoc() = default;
+
+    void init(int num_topics);
+
+    // adds a new sentence
+    void add_sentence(const Sentence& sent);
+
+    // assigns new_topic to the index-th sentence and updates the document topic distribution accordingly
+    void set_topic(int index, int new_topic);
+
+    // returns the number of sentences in the document
+    inline size_t size() const {
+        return _sentences.size();
+    }
+
+    inline Sentence& sent(size_t index) {
+        return _sentences[index];
+    }
+
+private:
+    // the document is a collection of sentences, each with a single topic
+    std::vector<Sentence> _sentences;
+};
+} // namespace familia
+#endif // FAMILIA_DOCUMENT_H
diff --git a/include/familia/inference_engine.h b/include/familia/inference_engine.h
new file mode 100644
index 0000000..7a5db17
--- /dev/null
+++ b/include/familia/inference_engine.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
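+//
+// Usage sketch (this is how the demos wire things up; the paths are examples):
+//
+//   InferenceEngine engine("./model/news", "lda.conf");
+//   std::vector<std::string> tokens = {...};   // segmented input text
+//   LDADoc doc;
+//   engine.infer(tokens, doc);
+//   std::vector<Topic> dist;
+//   doc.topic_dist(dist);                      // sparse p(topic|doc)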
+//
+// Author: chenzeyu01@baidu.com
+
+#ifndef FAMILIA_INFERENCE_ENGINE_H
+#define FAMILIA_INFERENCE_ENGINE_H
+
+#include <memory>
+
+#include "familia/util.h"
+#include "familia/config.pb.h"
+#include "familia/vocab.h"
+#include "familia/model.h"
+#include "familia/sampler.h"
+#include "familia/document.h"
+
+namespace familia {
+
+// sampler types
+enum class SamplerType {
+    GibbsSampling = 0,
+    MetropolisHastings = 1
+};
+
+// The inference engine supports topic inference for both the LDA and SentenceLDA models, which share the same storage format.
+// Both Gibbs sampling and Metropolis-Hastings sampling are available.
+class InferenceEngine {
+public:
+    ~InferenceEngine() = default;
+
+    // uses the Metropolis-Hastings sampler by default
+    InferenceEngine(const std::string& work_dir,
+                    const std::string& conf_file,
+                    SamplerType type = SamplerType::MetropolisHastings);
+
+    // runs LDA topic inference on the input and stores the result in doc,
+    // where input is a collection of segmented tokens
+    int infer(const std::vector<std::string>& input, LDADoc& doc);
+
+    // runs SentenceLDA topic inference on the input and stores the result in doc,
+    // where input is a collection of sentences
+    int infer(const std::vector<std::vector<std::string>>& input, SLDADoc& doc);
+
+    // REQUIRE: total_iter must be greater than burn_in_iter; the larger total_iter is, the smoother the resulting document topic distribution
+    void lda_infer(LDADoc& doc, int burn_in_iter, int total_iter) const;
+
+    // REQUIRE: total_iter must be greater than burn_in_iter; the larger total_iter is, the smoother the resulting document topic distribution
+    void slda_infer(SLDADoc& doc, int burn_in_iter, int total_iter) const;
+
+    // returns the model pointer so that model parameters can be accessed
+    inline std::shared_ptr<TopicModel> get_model() {
+        return _model;
+    }
+
+    // returns the model type: LDA or SentenceLDA
+    ModelType model_type() {
+        return _model->type();
+    }
+
+private:
+    // pointer to the model structure
+    std::shared_ptr<TopicModel> _model;
+    // pointer to the sampler, owned by the InferenceEngine
+    std::unique_ptr<Sampler> _sampler;
+};
+} // namespace familia
+#endif // FAMILIA_INFERENCE_ENGINE_H
diff --git a/include/familia/model.h b/include/familia/model.h
new file mode 100644
index 0000000..027d908
--- /dev/null
+++ b/include/familia/model.h
@@ -0,0 +1,115 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
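+//
+// On-disk layout of the word-topic file, as parsed by load_word_topic()
+// (the numbers here are made up for illustration): one line per word,
+//
+//   <term_id> <topic_id>:<count> <topic_id>:<count> ...
+//
+// e.g. "42 7:3 1024:11" means word 42 was assigned 3 times to topic 7 and
+// 11 times to topic 1024 in the training corpus.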
+//
+// Author: chenzeyu01@baidu.com
+
+#ifndef FAMILIA_MODEL_H
+#define FAMILIA_MODEL_H
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <string>
+#include <vector>
+
+#include "familia/config.pb.h"
+#include "familia/util.h"
+#include "familia/vocab.h"
+
+namespace familia {
+
+// topic count: key is the topic id, value is the count
+typedef std::pair<int, int> TopicCount;
+// a list of topic counts forms a topic distribution; TopicDist = Topic Distribution
+typedef std::vector<TopicCount> TopicDist;
+
+// Topic model storage structure, holding the vocabulary and the word-topic counts.
+// LDA and SentenceLDA use the same model storage format.
+class TopicModel {
+public:
+    TopicModel() = delete;
+
+    TopicModel(const std::string& work_dir, const ModelConfig& config);
+
+    inline int term_id(const std::string& term) const {
+        return _vocab.get_id(term);
+    }
+
+    // loads the word-topic counts and the vocabulary file
+    void load_model(const std::string& word_topic_path, const std::string& vocab_path);
+
+    // returns the count of a word under a topic; the model is stored sparsely, so 0 is returned if the entry is absent
+    int word_topic(int word_id, int topic_id) const {
+        // binary search
+        auto it = std::lower_bound(_word_topic[word_id].begin(),
+                                   _word_topic[word_id].end(),
+                                   std::make_pair(topic_id, std::numeric_limits<int>::min()));
+        if (it != _word_topic[word_id].end() && it->first == topic_id) {
+            return it->second;
+        } else {
+            return 0;
+        }
+    }
+
+    // returns the topic distribution of a word
+    TopicDist& word_topic(int term_id) {
+        return _word_topic.at(term_id);
+    }
+
+    // returns the topic sum of the given topic id
+    uint64_t topic_sum(int topic_id) const;
+
+    // returns the vector of topic sums
+    std::vector<uint64_t>& topic_sum() {
+        return _topic_sum;
+    }
+
+    inline int num_topics() const {
+        return _num_topics;
+    }
+
+    inline size_t vocab_size() const {
+        return _vocab.size();
+    }
+
+    inline float alpha() const {
+        return _alpha;
+    }
+
+    inline float alpha_sum() const {
+        return _alpha_sum;
+    }
+
+    inline float beta() const {
+        return _beta;
+    }
+
+    inline float beta_sum() const {
+        return _beta_sum;
+    }
+
+    inline ModelType type() const {
+        return _type;
+    }
+
+private:
+    // loads the word-topic parameters
+    void load_word_topic(const std::string& word_topic_path);
+    // word-topic model parameters
+    std::vector<TopicDist> _word_topic;
+    // total count of each topic across all words
+    std::vector<uint64_t> _topic_sum;
+    // vocabulary of the model
+    Vocab _vocab;
+    // number of topics
+    int _num_topics;
+    // topic model hyperparameters
+    float _alpha;
+    float _alpha_sum;
+    float _beta;
+    float _beta_sum;
+    // model type
+    ModelType _type;
+};
+} // namespace familia
+#endif // FAMILIA_MODEL_H
diff --git a/include/familia/sampler.h b/include/familia/sampler.h
new file mode 100644
index 0000000..0de6e7c
--- /dev/null
+++ b/include/familia/sampler.h
@@ -0,0 +1,132 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
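+//
+// Background note on the Metropolis-Hastings sampler below: each step draws
+// a candidate topic t' from a proposal distribution q and accepts it over
+// the current topic t with probability min(1, (p(t') * q(t)) / (p(t) * q(t'))),
+// where p is the collapsed posterior computed by proportional_funtion and q
+// is either the doc proposal or the word proposal distribution. This is the
+// accept/reject step implemented in sampler.cpp.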
+//
+// Author: chenzeyu01@baidu.com
+
+#ifndef FAMILIA_LDA_SAMPLER_H
+#define FAMILIA_LDA_SAMPLER_H
+
+#include "familia/document.h"
+#include "familia/vose_alias.h"
+#include "familia/model.h"
+#include "familia/util.h"
+
+#include <memory>
+
+namespace familia {
+
+typedef std::vector<int> TopicIndex;
+
+// sampler interface
+class Sampler {
+public:
+    virtual ~Sampler() = default;
+
+    // samples topics for an LDA document
+    virtual void sample_doc(LDADoc& doc) = 0;
+
+    // samples topics for a SentenceLDA document
+    virtual void sample_doc(SLDADoc& doc) = 0;
+};
+
+// Metropolis-Hastings based sampler, implemented for both the LDA and SentenceLDA models
+class MHSampler : public Sampler {
+public:
+    // the default number of MH steps is 2
+    MHSampler(std::shared_ptr<TopicModel> model) : _model(model) {
+        construct_alias_table();
+    }
+
+    void sample_doc(LDADoc& doc) override;
+
+    void sample_doc(SLDADoc& doc) override;
+
+    // no copying allowed
+    MHSampler(const MHSampler&) = delete;
+    MHSampler& operator=(const MHSampler&) = delete;
+
+private:
+    // builds the alias tables from the LDA model parameters
+    int construct_alias_table();
+
+    // samples a topic for one token of a document; returns the sampled topic id
+    int sample_token(LDADoc& doc, Token& token);
+
+    // samples a topic for one sentence of a document; returns the sampled topic id
+    int sample_sentence(SLDADoc& doc, Sentence& sent);
+
+    // doc proposal for LDA
+    int doc_proposal(LDADoc& doc, Token& token);
+
+    // doc proposal for Sentence-LDA
+    int doc_proposal(SLDADoc& doc, Sentence& sent);
+
+    // word proposal for LDA
+    int word_proposal(LDADoc& doc, Token& token, int old_topic);
+
+    // word proposal for Sentence-LDA
+    int word_proposal(SLDADoc& doc, Sentence& sent, int old_topic);
+
+    // proportional function for the LDA model
+    float proportional_funtion(LDADoc& doc, Token& token, int new_topic);
+
+    // proportional function for the SLDA model
+    float proportional_funtion(SLDADoc& doc, Sentence& sent, int new_topic);
+
+    // word proposal distribution for LDA and Sentence-LDA
+    float word_proposal_distribution(int word_id, int topic);
+
+    // doc proposal distribution for LDA and Sentence-LDA
+    float doc_proposal_distribution(LDADoc& doc, int topic);
+
+    // proposes a topic id for the given word id with the Metropolis-Hastings method
+    int propose(int word_id);
+
+    // LDA model pointer, shared by the sampler and the inference engine
+    std::shared_ptr<TopicModel> _model;
+
+    // per-word topic index mapping
+    std::vector<TopicIndex> _topic_indexes;
+
+    // per-word alias tables built with Vose's alias method (word-proposal, non-prior part)
+    std::vector<VoseAlias> _alias_tables;
+
+    // per-word sums of topic probabilities (word-proposal, non-prior part)
+    std::vector<double> _prob_sum;
+
+    // alias table of the prior built with Vose's alias method (word-proposal, prior part)
+    VoseAlias _beta_alias;
+
+    // sum of the prior probabilities over topics (word-proposal, prior part)
+    double _beta_prior_sum;
+
+    // number of Metropolis-Hastings steps, 2 by default
+    static constexpr int _mh_steps = 2;
+};
+
+// Gibbs sampler, implementing the sampling algorithms of both the LDA and SentenceLDA models
+class GibbsSampler : public Sampler {
+public:
+    GibbsSampler(std::shared_ptr<TopicModel> model) : _model(model) {
+    }
+
+    // samples LDA topics for the document; the results are stored in doc
+    void sample_doc(LDADoc& doc) override;
+
+    // samples a topic for every sentence of the document with the Sentence-LDA model; the results are stored in doc
+    // the SentenceLDA sampling formula is rearranged to keep the numerical computation precise
+    void sample_doc(SLDADoc& doc) override;
+
+    // no copying allowed
+    GibbsSampler(const GibbsSampler&) = delete;
+    GibbsSampler& operator=(const GibbsSampler&) = delete;
+private:
+    int sample_token(LDADoc& doc, Token& token);
+
+    int sample_sentence(SLDADoc& doc, Sentence& sent);
+
+    std::shared_ptr<TopicModel> _model;
+};
+} // namespace familia
+#endif // FAMILIA_LDA_SAMPLER_H
diff --git a/include/familia/semantic_matching.h b/include/familia/semantic_matching.h
new file mode 100644
index 0000000..9e58439
--- /dev/null
+++ b/include/familia/semantic_matching.h
@@ -0,0 +1,218 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: jiangjiajun@baidu.com
+
+#ifndef FAMILIA_SEMANTIC_MATCHING_H
+#define FAMILIA_SEMANTIC_MATCHING_H
+
+#include "familia/model.h"
+#include "familia/document.h"
+
+#include <cmath>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace familia {
+
+constexpr double EPS = 1e-06; // epsilon
+
+typedef std::vector<float> Embedding;
+typedef std::vector<float> Distribution;
+
+// a candidate word together with its distance
+struct WordAndDis {
+    std::string word;
+    float distance;
+};
+
+// Topical Word Embedding (TWE) model class:
+// loads the model and provides access to the embeddings
+class TopicalWordEmbedding {
+public:
+    TopicalWordEmbedding(const std::string& work_dir,
+                         const std::string& emb_file) {
+        const std::string emb_path = work_dir + "/" + emb_file;
+        CHECK_EQ(load_emb(emb_path), 0) << "Failed to load Topical Word Embedding!";
+    }
+
+    ~TopicalWordEmbedding() = default;
+
+    // loads the Topical Word Embedding model
+    int load_emb(const std::string& emb_file);
+
+    // returns the embedding of the topic with the given topic id
+    Embedding& topic_emb(int topic_id);
+
+    // returns the embedding of the given word
+    Embedding& word_emb(const std::string& term);
+
+    // returns the K words nearest to the given word
+    void nearest_words(const std::string& word,
+                       std::vector<WordAndDis>& candidates);
+
+    // returns the K words nearest to the given topic
+    void nearest_words_around_topic(int topic_index,
+                                    std::vector<WordAndDis>& candidates);
+
+    // checks whether the word is present in the TWE model
+    bool contains_word(const std::string& term) const;
+
+    // returns the number of topics
+    int num_topics() const;
+
+private:
+    // word embeddings
+    std::unordered_map<std::string, Embedding> _word_emb;
+    // topic embeddings
+    std::vector<Embedding> _topic_emb;
+    // number of topics
+    int _num_topics;
+    // embedding size of the TWE model
+    int _emb_size;
+    // vocabulary size of the word embeddings in TWE
+    int _vocab_size;
+};
+
+// semantic matching metrics
+class SemanticMatching {
+public:
+    // computes the L2 norm of an embedding
+    // NOTE: could be vectorized with SSE; left unoptimized here for readability
+    static float l2_norm(const Embedding& vec) {
+        float result = 0.0;
+        for (size_t i = 0; i < vec.size(); ++i) {
+            result += vec[i] * vec[i];
+        }
+
+        return sqrt(result);
+    }
+
+    // computes the cosine similarity of two embeddings
+    static float cosine_similarity(const Embedding& vec1, const Embedding& vec2) {
+        float result = 0.0;
+        float norm1 = l2_norm(vec1);
+        float norm2 = l2_norm(vec2);
+
+        // NOTE: could be vectorized with SSE; left unoptimized here for readability
+        for (size_t i = 0; i < vec1.size(); ++i) {
+            result += vec1[i] * vec2[i];
+        }
+        result = result / norm1 / norm2;
+        return result;
+    }
+
+    // measures the similarity of a short text to a long text as the likelihood
+    // of generating the short text from the long text's topic distribution
+    static float likelihood_based_similarity(const std::vector<std::string>& terms,
+                                             const std::vector<Topic>& doc_topic_dist,
+                                             std::shared_ptr<TopicModel> model) {
+        int num_of_term_in_vocab = 0;
+        float result = 0.0;
+
+        for (size_t i = 0; i < terms.size(); ++i) {
+            int term_id = model->term_id(terms[i]);
+            if (term_id == OOV) {
+                continue;
+            }
+
+            // count only the words present in the vocabulary
+            num_of_term_in_vocab += 1;
+            for (size_t j = 0; j < doc_topic_dist.size(); ++j) {
+                int topic_id = doc_topic_dist[j].tid;
+                float prob = doc_topic_dist[j].prob;
+                result += model->word_topic(term_id, topic_id) * 1.0 /
+                          model->topic_sum(topic_id) * prob;
+            }
+        }
+
+        if (num_of_term_in_vocab == 0) {
+            return result;
+        }
+
+        return result / num_of_term_in_vocab;
+    }
+
+
+    // computes the short-text/long-text similarity based on Topical Word Embeddings (TWE):
+    // given the segmented short text, the topic distribution of the long text and a TWE model,
+    // returns the semantic similarity between the two texts
+    static float twe_based_similarity(const std::vector<std::string>& terms,
+                                      const std::vector<Topic>& doc_topic_dist,
+                                      TopicalWordEmbedding& twe) {
+        int short_text_length = terms.size();
+        float result = 0.0;
+
+        for (size_t i = 0; i < terms.size(); ++i) {
+            if (!twe.contains_word(terms[i])) {
+                short_text_length--;
+                continue;
+            }
+            Embedding& word_emb = twe.word_emb(terms[i]);
+            for (const auto& topic : doc_topic_dist) {
+                Embedding& topic_emb = twe.topic_emb(topic.tid);
+                result += cosine_similarity(word_emb, topic_emb) * topic.prob;
+            }
+        }
+
+        if (short_text_length == 0) { // none of the short text's words are in the vocabulary
+            return 0.0;
+        }
+
+        return result / short_text_length; // normalize by the short text length
+    }
+
+    // Kullback-Leibler Divergence
+    // D(P||Q) = \sum_i P(i) \ln \frac{P(i)}{Q(i)}
+    // REQUIRE: both distributions must have the same dimension
+    static float kullback_leibler_divergence(Distribution& dist1, Distribution& dist2) {
+        CHECK_EQ(dist1.size(), dist2.size());
+        float result = 0.0;
+        for (size_t i = 0; i < dist1.size(); ++i) {
+            dist2[i] = dist2[i] < EPS ? EPS : dist2[i];
+            result += dist1[i] * log(dist1[i] / dist2[i]);
+        }
+
+        return result;
+    }
+
+    // Jensen-Shannon Divergence
+    // REQUIRE: both distributions must have the same dimension
+    static float jensen_shannon_divergence(Distribution& dist1, Distribution& dist2) {
+        CHECK_EQ(dist1.size(), dist2.size());
+        // clamp probabilities that fall below epsilon
+        for (size_t i = 0; i < dist1.size(); ++i) {
+            dist1[i] = dist1[i] < EPS ? EPS : dist1[i];
+            dist2[i] = dist2[i] < EPS ? EPS : dist2[i];
+        }
+
+        Distribution mean(dist1.size(), 0);
+
+        for (size_t i = 0; i < dist1.size(); ++i) {
+            mean[i] = (dist1[i] + dist2[i]) * 0.5;
+        }
+
+        float jsd = kullback_leibler_divergence(dist1, mean) * 0.5 +
+                    kullback_leibler_divergence(dist2, mean) * 0.5;
+        return jsd;
+    }
+
+    // Hellinger Distance
+    // REQUIRE: both distributions must have the same dimension
+    static float hellinger_distance(Distribution& dist1, Distribution& dist2) {
+        CHECK_EQ(dist1.size(), dist2.size());
+
+        // NOTE: could be vectorized with SSE; left unoptimized here for readability
+        float result = 0.0;
+        for (size_t i = 0; i < dist1.size(); ++i) {
+            float tmp = sqrt(dist1[i]) - sqrt(dist2[i]);
+            result += tmp * tmp;
+        }
+
+        // 1/√2 = 0.7071067812
+        result = sqrt(result) * 0.7071067812;
+        return result;
+    }
+};
+} // namespace familia
+#endif // FAMILIA_SEMANTIC_MATCHING_H
diff --git a/include/familia/tokenizer.h b/include/familia/tokenizer.h
new file mode 100644
index 0000000..718fe71
--- /dev/null
+++ b/include/familia/tokenizer.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
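+//
+// The FMM (forward maximum matching) strategy in one sentence (a sketch of
+// the standard algorithm; the actual implementation lives in
+// src/tokenizer.cpp): scan the text left to right and, at each position,
+// greedily emit the longest vocabulary word (up to _max_word_len characters)
+// starting there; if no vocabulary word matches, advance by one character.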
+//
+// Author: chenzeyu01@baidu.com
+
+#ifndef FAMILIA_TOKENIZER_H
+#define FAMILIA_TOKENIZER_H
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "familia/util.h"
+
+namespace familia {
+
+// tokenizer base class
+class Tokenizer {
+public:
+    Tokenizer() = default;
+
+    virtual ~Tokenizer() = default;
+
+    virtual void tokenize(const std::string& text, std::vector<std::string>& result) const = 0;
+};
+
+// Simple FMM tokenizer, intended only for the topic model demos, not for real applications
+// NOTE: this tokenizer only recognizes the words of the topic model vocabulary
+class SimpleTokenizer : public Tokenizer {
+public:
+    SimpleTokenizer(const std::string& vocab_path) : _max_word_len(1) {
+        load_vocab(vocab_path);
+    }
+
+    ~SimpleTokenizer() = default;
+
+    // tokenizes the input text; the result is stored in result
+    void tokenize(const std::string& text, std::vector<std::string>& result) const override;
+
+    // checks whether word is present in the vocabulary
+    bool contains(const std::string& word) const;
+private:
+    // loads the tokenization dictionary, which is shared with the topic model
+    void load_vocab(const std::string& vocab_path);
+
+    // checks whether the character is an English letter
+    static bool is_eng_char(char c) {
+        // 'A' - 'Z' & 'a' - 'z'
+        return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
+    }
+
+    // returns the lowercase version of the character, or the character itself if it has no lowercase form
+    static char tolower(char c) {
+        if (c >= 'A' && c <= 'Z') {
+            return 'a' + (c - 'A');
+        } else {
+            return c;
+        }
+    }
+
+    // maximum word length in the vocabulary
+    int _max_word_len;
+    // dictionary data structure
+    std::unordered_set<std::string> _vocab;
+};
+} // namespace familia
+#endif // FAMILIA_TOKENIZER_H
diff --git a/include/familia/util.h b/include/familia/util.h
new file mode 100644
index 0000000..a30c734
--- /dev/null
+++ b/include/familia/util.h
@@ -0,0 +1,93 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: chenzeyu01@baidu.com
+
+#ifndef FAMILIA_UTIL_H
+#define FAMILIA_UTIL_H
+
+#include <atomic>
+#include <ctime>
+#include <fstream>
+#include <random>
+#include <string>
+#include <vector>
+
+#include <glog/logging.h>
+#include <google/protobuf/text_format.h>
+
+namespace familia {
+
+// returns a random engine that also works under multiple threads
+inline std::default_random_engine& local_random_engine() {
+    struct engine_wrapper_t {
+        std::default_random_engine engine;
+        engine_wrapper_t() {
+            static std::atomic<unsigned long> x(0);
+            // TODO: change to random seed
+            // std::seed_seq sseq = {x++, x++, x++, (unsigned long)time(NULL)};
+            // fixed seed
+            std::seed_seq sseq = {x++, x++, x++, x++};
+            engine.seed(sseq);
+        }
+    };
+    static engine_wrapper_t r;
+    return r.engine;
+}
+
+// fixes the random seed and resets the distribution
+inline void fix_random_seed(int seed = 2147483647) {
+    auto& engine = local_random_engine();
+    engine.seed(seed);
+    static std::uniform_real_distribution<double> distribution(0.0, 1.0);
+    distribution.reset(); // reset the distribution so the next sample does not depend on previously generated state
+}
+
+// returns a random double in [min, max), by default in [0, 1)
+// NOTE: the distribution is static, so min and max are fixed by the first
+// call; every call site in this code base uses the default 0~1 range
+inline double rand(double min = 0.0, double max = 1.0) {
+    auto& engine = local_random_engine();
+    static std::uniform_real_distribution<double> distribution(min, max);
+
+    return distribution(engine);
+}
+
+// returns a random integer in [0, k - 1]
+inline int rand_k(int k) {
+    return static_cast<int>(rand(0.0, 1.0) * k);
+}
+
+template <typename T>
+int load_prototxt(const std::string& config_file, T& proto) {
+    LOG(INFO) << "Loading prototxt: " << config_file;
+    std::ifstream fin(config_file);
+    CHECK(fin) << "Failed to open prototxt file: " << config_file;
+
+    if (fin.fail()) {
+        LOG(FATAL) << "Open prototxt file failed: " << config_file;
+        return -1;
+    }
+
+    fin.seekg(0, std::ios::end);
+    int file_size = fin.tellg();
+    fin.seekg(0, std::ios::beg);
+
+    std::vector<char> file_content_buffer(file_size, ' ');
+    fin.read(file_content_buffer.data(), file_size);
+
+    std::string proto_str(file_content_buffer.data(), file_size);
+
+    if (!google::protobuf::TextFormat::ParseFromString(proto_str, &proto)) {
+        LOG(FATAL) << "Failed to load config: " << config_file;
+        return -1;
+    }
+
+    fin.close();
+
+    return 0;
+}
+
+// simple split function that splits text on the given separator
+void split(std::vector<std::string>& result, const std::string& text, char separator);
+
+} // namespace familia
+#endif // FAMILIA_UTIL_H
diff --git a/include/familia/vocab.h b/include/familia/vocab.h
new file mode 100644
index 0000000..25c4cb3
--- /dev/null
+++ b/include/familia/vocab.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: chenzeyu01@baidu.com
+
+#ifndef FAMILIA_VOCAB_H
+#define FAMILIA_VOCAB_H
+
+#include <string>
+#include <unordered_map>
+
+namespace familia {
+// OOV: out of vocabulary, marks a word that is absent from the vocabulary
+constexpr int OOV = -1;
+
+// vocabulary of the topic model
+// maps words to word ids; if a word is not in the vocabulary, OOV (-1) is returned
+class Vocab {
+public:
+    Vocab() = default;
+    // returns the word id of the given word
+    int get_id(const std::string& word) const;
+
+    // loads the vocabulary
+    void load(const std::string& vocab_file);
+
+    // returns the vocabulary size
+    size_t size() const;
+
+    // no copying allowed
+    Vocab(const Vocab&) = delete;
+    Vocab& operator=(const Vocab&) = delete;
+private:
+    // mapping from word to id
+    std::unordered_map<std::string, int> _term2id;
+};
+} // namespace familia
+#endif // FAMILIA_VOCAB_H
diff --git a/include/familia/vose_alias.h b/include/familia/vose_alias.h
new file mode 100644
index 0000000..bd492cb
--- /dev/null
+++ b/include/familia/vose_alias.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: chenzeyu01@baidu.com
+
+#ifndef FAMILIA_VOSE_ALIAS_H
+#define FAMILIA_VOSE_ALIAS_H
+
+#include <vector>
+
+#include "familia/util.h"
+
+namespace familia {
+// Numerically stable implementation of Vose's alias method
+// for more details see http://www.keithschwarz.com/darts-dice-coins/
+class VoseAlias {
+public:
+    VoseAlias() = default;
+
+    // initializes the alias table from the input distribution
+    void initialize(const std::vector<double>& distribution);
+
+    // draws a sample from the given distribution
+    int generate() const;
+
+    // dimension of the discrete distribution
+    inline size_t size() const {
+        return _prob.size();
+    }
+
+    // no copying allowed
+    VoseAlias(const VoseAlias&) = delete;
+    VoseAlias& operator=(const VoseAlias&) = delete;
+
+private:
+    // alias table
+    std::vector<int> _alias;
+    // probability table
+    std::vector<double> _prob;
+};
+} // namespace familia
+#endif // FAMILIA_VOSE_ALIAS_H
diff --git a/run_inference_demo.sh b/run_inference_demo.sh
new file mode 100755
index 0000000..eaa34bd
--- /dev/null
+++ b/run_inference_demo.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+export LD_LIBRARY_PATH=./third_party/lib:$LD_LIBRARY_PATH
+
+cd model
+sh download_model.sh
+cd ..
+
+./inference_demo --work_dir="./model/news" --conf_file="lda.conf"
diff --git a/run_semantic_matching_demo.sh b/run_semantic_matching_demo.sh
new file mode 100755
index 0000000..afab48c
--- /dev/null
+++ b/run_semantic_matching_demo.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+export LD_LIBRARY_PATH=./third_party/lib:$LD_LIBRARY_PATH
+
+cd model
+sh download_model.sh
+cd ..
+
+# mode = 0 computes the topical semantic similarity between a short text and a long text
+# mode = 1 computes the topical semantic similarity between two long texts
+./semantic_matching_demo --work_dir="./model/news" --conf_file="lda.conf" --emb_file="news_twe_lda.model" --mode=0
diff --git a/run_topic_word_demo.sh b/run_topic_word_demo.sh
new file mode 100755
index 0000000..077b6a5
--- /dev/null
+++ b/run_topic_word_demo.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+export LD_LIBRARY_PATH=./third_party/lib:$LD_LIBRARY_PATH
+
+cd model
+sh download_model.sh
+cd ..
+
+./topic_word_demo --work_dir="./model/news" --emb_file="news_twe_lda.model" --topic_words_file="topic_words.lda.txt"
diff --git a/run_word_distance_demo.sh b/run_word_distance_demo.sh
new file mode 100755
index 0000000..9fcdfd4
--- /dev/null
+++ b/run_word_distance_demo.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+export LD_LIBRARY_PATH=./third_party/lib:$LD_LIBRARY_PATH
+
+cd model
+sh download_model.sh
+cd ..
+
+./word_distance_demo --work_dir="./model/news" --emb_file="news_twe_lda.model" --top_k=20
diff --git a/src/demo/inference_demo.cpp b/src/demo/inference_demo.cpp
new file mode 100644
index 0000000..d968d8c
--- /dev/null
+++ b/src/demo/inference_demo.cpp
@@ -0,0 +1,88 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: chenzeyu01@baidu.com
+
+#include "familia/inference_engine.h"
+#include "familia/tokenizer.h"
+#include "familia/util.h"
+
+#include <cstdio>
+#include <iostream>
+#include <vector>
+
+using std::string;
+using std::vector;
+using std::cin;
+using std::cout;
+using std::endl;
+using namespace familia; // no lint
+
+DEFINE_string(work_dir, "./", "working directory");
+DEFINE_string(conf_file, "lda.conf", "model configuration file");
+
+// prints the topic distribution of a document
+void print_doc_topic_dist(const vector<Topic>& topics) {
+    printf("Document Topic Distribution:\n");
+    for (size_t i = 0; i < topics.size(); ++i) {
+        printf("%d:%f ", topics[i].tid, topics[i].prob);
+    }
+    printf("\n");
+}
+
+int main(int argc, char* argv[]) {
+    GOOGLE_PROTOBUF_VERIFY_VERSION;
+    google::SetVersionString("1.0.0.0");
+    string usage = string("Usage: ./inference_demo --work_dir=\"PATH/TO/MODEL\" ") +
+                   string("--conf_file=\"lda.conf\" ");
+    google::SetUsageMessage(usage);
+    google::ParseCommandLineFlags(&argc, &argv, true);
+
+    InferenceEngine engine(FLAGS_work_dir, FLAGS_conf_file, SamplerType::MetropolisHastings);
+
+    std::string vocab_path = FLAGS_work_dir + "/vocab_info.txt";
+    Tokenizer* tokenizer = new SimpleTokenizer(vocab_path);
+
+    string line;
+    vector<vector<string>> sentences;
+    while (true) {
+        cout << "请输入文档:" << endl;
+        if (!getline(cin, line)) {
+            break; // stop on EOF
+        }
+        vector<string> input;
+        tokenizer->tokenize(line, input);
+        if (engine.model_type() == ModelType::LDA) {
+            LDADoc doc;
+            engine.infer(input, doc);
+            vector<Topic> topics;
+            doc.topic_dist(topics);
+            print_doc_topic_dist(topics);
+        } else if (engine.model_type() == ModelType::SLDA) {
+            vector<string> sent;
+            for (size_t i = 0; i < input.size(); ++i) {
+                sent.push_back(input[i]);
+                // to sidestep sentence boundary detection, treat every 5-gram as one sentence;
+                // n should not be too large, or numerical precision degrades during sampling
+                if (sent.size() % 5 == 0) {
+                    sentences.push_back(sent);
+                    sent.clear();
+                }
+            }
+
+            // the remaining words form the last sentence
+            if (sent.size() > 0) {
+                sentences.push_back(sent);
+            }
+
+            SLDADoc doc;
+            engine.infer(sentences, doc);
+            vector<Topic> topics;
+            doc.topic_dist(topics);
+            print_doc_topic_dist(topics);
+            sentences.clear();
+        }
+    }
+
+    delete tokenizer;
+
+    return 0;
+}
diff --git a/src/demo/semantic_matching_demo.cpp b/src/demo/semantic_matching_demo.cpp
new file mode 100644
index 0000000..f2d66b3
--- /dev/null
+++ b/src/demo/semantic_matching_demo.cpp
@@ -0,0 +1,169 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
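+//
+// Metric note (derived from SemanticMatching::likelihood_based_similarity):
+// the short-text/long-text score is the average, over the short text's
+// in-vocabulary words w, of sum_t p(w|t) * p(t|doc), where p(w|t) is
+// estimated as count(w,t) / topic_sum(t) from the model and p(t|doc) is the
+// inferred topic distribution of the long text.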
+//
+// Author: chenzeyu01@baidu.com
+
+#include "familia/inference_engine.h"
+#include "familia/semantic_matching.h"
+#include "familia/tokenizer.h"
+
+#include <fstream>
+#include <iostream>
+
+using std::string;
+using std::vector;
+using std::ifstream;
+using std::cin;
+using std::cout;
+using std::endl;
+
+DEFINE_string(work_dir, "./", "working directory");
+DEFINE_string(conf_file, "lda.conf", "model configuration");
+DEFINE_string(emb_file, "./", "Topical Word Embedding (TWE) file");
+DEFINE_int32(mode, 0, "0: query-doc similarity 1: doc-doc semantic distance");
+
+namespace familia {
+
+// topic model semantic matching demo class
+class SemanticMatchingDemo {
+public:
+    SemanticMatchingDemo() : _engine(FLAGS_work_dir, FLAGS_conf_file),
+                             _twe(FLAGS_work_dir, FLAGS_emb_file) {
+        // initialize the tokenizer with the topic model vocabulary
+        _tokenizer = new SimpleTokenizer(FLAGS_work_dir + "/vocab_info.txt");
+    }
+
+    ~SemanticMatchingDemo() {
+        delete _tokenizer; // the demo owns the tokenizer
+    }
+
+    // computes the similarity between a query (short text) and a content (long text)
+    // available metrics:
+    // 1. the likelihood of generating the query from the content's topic distribution; larger means more similar
+    // 2. a similarity based on the TWE model
+    void cal_query_content_similarity(const string& query, const string& content) {
+        // tokenize
+        vector<string> q_tokens, c_tokens;
+        _tokenizer->tokenize(query, q_tokens);
+        _tokenizer->tokenize(content, c_tokens);
+        print_tokens("Query Tokens", q_tokens);
+        print_tokens("Content Tokens", c_tokens);
+
+        // infer the topics of the long text and fetch its topic distribution
+        LDADoc doc;
+        _engine.infer(c_tokens, doc);
+        vector<Topic> doc_topic_dist;
+        doc.topic_dist(doc_topic_dist);
+
+        float lda_sim = SemanticMatching::likelihood_based_similarity(q_tokens,
+                                                                      doc_topic_dist,
+                                                                      _engine.get_model());
+        float twe_sim = SemanticMatching::twe_based_similarity(q_tokens, doc_topic_dist, _twe);
+
+        cout << "LDA sim = " << lda_sim << "\t"
+             << "TWE sim = " << twe_sim << endl;
+    }
+
+    // computes the similarity between two long texts
+    // available metrics are the common distances between distributions:
+    // Jensen-Shannon divergence and Hellinger distance
+    void cal_doc_distance(const string& doc_text1, const string& doc_text2) {
+        // tokenize
+        vector<string> doc1_tokens, doc2_tokens;
+        _tokenizer->tokenize(doc_text1, doc1_tokens);
+        _tokenizer->tokenize(doc_text2, doc2_tokens);
+        print_tokens("Doc1 Tokens", doc1_tokens);
+        print_tokens("Doc2 Tokens", doc2_tokens);
+
+        // document topic inference: feed in the tokens, the results are stored in LDADoc
+        LDADoc doc1, doc2;
+        _engine.infer(doc1_tokens, doc1);
+        _engine.infer(doc2_tokens, doc2);
+
+        // JSD requires dense distributions,
+        // so fetch the dense document topic distributions
+        vector<float> dense_dist1;
+        vector<float> dense_dist2;
+        doc1.dense_topic_dist(dense_dist1);
+        doc2.dense_topic_dist(dense_dist2);
+
+        // compute the distances between the distributions; smaller values mean semantically more similar documents
+        float jsd = SemanticMatching::jensen_shannon_divergence(dense_dist1, dense_dist2);
+        float hd = SemanticMatching::hellinger_distance(dense_dist1, dense_dist2);
+        cout << "Jensen Shannon Divergence = " << jsd << endl
+             << "Hellinger Distance = " << hd << endl;
+    }
+
+    // prints the tokenization result
+    void print_tokens(const string& title, const vector<string>& tokens) {
+        cout << title << ": ";
+        for (size_t i = 0; i < tokens.size(); ++i) {
+            cout << tokens[i] << " ";
+        }
+        cout << endl;
+    }
+
+private:
+    InferenceEngine _engine;
+    // Topical Word Embedding model
+    TopicalWordEmbedding _twe;
+    // tokenizer
+    Tokenizer* _tokenizer;
+};
+} // namespace familia
+
+// semantic matching types
+enum MatchingType {
+    QueryDocSim = 0, // short text vs long text similarity
+    DocDistance = 1  // distance between documents
+};
+
+int main(int argc, char* argv[]) {
+    GOOGLE_PROTOBUF_VERIFY_VERSION;
+    google::SetVersionString("1.0.0.0");
+    string usage = string("Usage: ./semantic_matching_demo --work_dir=\"PATH/TO/MODEL\" ") +
+                   string("--conf_file=\"lda.conf\" ") +
+                   string("--emb_file=\"webpage_twe_lda.model\" ");
+    google::SetUsageMessage(usage);
+    google::ParseCommandLineFlags(&argc, &argv, true);
+
+    familia::SemanticMatchingDemo sm_demo;
+    // short text vs long text similarity
+    if (FLAGS_mode == MatchingType::QueryDocSim) {
+        string query;
+        string doc;
+        while (true) {
+            cout << "请输入短文本:" << endl;
+            if (!getline(cin, query)) {
+                break; // stop on EOF
+            }
+            if (query.size() == 0) {
+                LOG(ERROR) << "Empty input!";
+                continue;
+            }
+            cout << "请输入长文本:" << endl;
+            if (!getline(cin, doc)) {
+                break; // stop on EOF
+            }
+            if (doc.size() == 0) {
+                LOG(ERROR) << "Empty input!";
+                continue;
+            }
+            sm_demo.cal_query_content_similarity(query, doc);
+        }
+    } else if (FLAGS_mode == MatchingType::DocDistance) {
+        string doc1;
+        string doc2;
+        while (true) {
+            cout << "请输入文档1:" << endl;
+            if (!getline(cin, doc1)) {
+                break; // stop on EOF
+            }
+            if (doc1.size() == 0) {
+                LOG(ERROR) << "Empty input!";
+                continue;
+            }
+            cout << "请输入文档2:" << endl;
+            if (!getline(cin, doc2)) {
+                break; // stop on EOF
+            }
+            if (doc2.size() == 0) {
+                LOG(ERROR) << "Empty input!";
+                continue;
+            }
+            sm_demo.cal_doc_distance(doc1, doc2);
+        }
+    }
+
+    return 0;
+}
diff --git a/src/demo/topic_word_demo.cpp b/src/demo/topic_word_demo.cpp
new file mode 100644
index 0000000..a5d3717
--- /dev/null
+++ b/src/demo/topic_word_demo.cpp
@@ -0,0 +1,119 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: lianrongzhong@baidu.com
+
+#include "familia/semantic_matching.h"
+#include "familia/util.h"
+
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <unordered_map>
+
+using std::string;
+using std::vector;
+using std::unordered_map;
+using std::make_pair;
+using std::cin;
+using std::cout;
+using std::endl;
+
+DEFINE_string(work_dir, "./", "Working directory");
+DEFINE_string(emb_file, "./", "Topical Word Embedding (TWE) file");
+DEFINE_string(topic_words_file, "./", "Topic word file");
+
+namespace familia {
+// demo class that displays topic words
+class TopicWordDemo {
+public:
+    TopicWordDemo() : _twe(FLAGS_work_dir, FLAGS_emb_file) {
+        // load the words returned for each topic of the topic model
+        load_topic_words(FLAGS_work_dir + "/" + FLAGS_topic_words_file);
+    }
+
+    ~TopicWordDemo() = default;
+
+    // shows the words retrieved for the same topic by the two different methods
+    void show_topics(int topic_id) {
+        // fetch the words most related to the topic under the TWE model
+        vector<WordAndDis> items(_top_k);
+        _twe.nearest_words_around_topic(topic_id, items);
+        print_result(items, _topic_words[topic_id]);
+    }
+
+    // prints the results
+    void print_result(const vector<WordAndDis>& items,
+                      const vector<string>& words) {
+        cout << "TWE result                    LDA result" << endl;
+        cout << "---------------------------------------------" << endl;
+        for (int i = 0; i < _top_k; i++) {
+            cout << std::left << std::setw(30) << items[i].word << "\t\t" << words[i] << endl;
+        }
+        cout << endl;
+    }
+
+    int num_topics() const {
+        return _twe.num_topics();
+    }
+
+private:
+    // reads the words shown for each topic of the topic model
+    void load_topic_words(const string& topic_words_file) {
+        std::ifstream fin(topic_words_file, std::ios::in);
+        CHECK(fin) << "Failed to open topic word file!";
+        string line;
+        for (int t = 0; t < num_topics(); t++) {
+            // read the first line of each topic block and parse the top-k value
+            getline(fin, line);
+            vector<string> cols;
+            split(cols, line, '\t');
+            CHECK_EQ(cols.size(), 2) << "Format of the topic_words file error!";
+            int topk = std::stoi(cols[1]);
+            // skip the extra line
+            getline(fin, line);
+            // read the top k words
+            vector<string> words;
+            for (int i = 0; i < topk; i++) {
+                string line;
+                vector<string> cols;
+                getline(fin, line);
+                split(cols, line, '\t');
+                words.push_back(cols[0]);
+            }
+            _topic_words[t] = words;
+        }
+    }
+
+    TopicalWordEmbedding _twe;
+    // the highest-probability words of each topic in LDA
+    unordered_map<int, vector<string>> _topic_words;
+    // number of words shown per topic, 10 by default
+    static constexpr int _top_k = 10;
+};
+} // namespace familia
+
+int main(int argc, char* argv[]) {
+    GOOGLE_PROTOBUF_VERIFY_VERSION;
+    google::SetVersionString("1.0.0.0");
string("Usage: ./topic_distance_demo --work_dir=\"PATH/TO/MODEL\" ") + + string("--emb_file=\"webpage_twe_lda.model\" ") + + string("--topic_words_file=\"topic_words.txt\" "); + google::SetUsageMessage(usage); + google::ParseCommandLineFlags(&argc, &argv, true); + + familia::TopicWordDemo tw_demo; + + string line; + while(true) { + cout << "请输入主题编号(0-" << tw_demo.num_topics() - 1 << "):\t"; + getline(cin, line); + int topic_id = std::stoi(line); + CHECK_GE(topic_id, 0) << "Topic out of range!"; + CHECK_LT(topic_id, tw_demo.num_topics()) << "Topic out of range!"; + // 展示主题下TWE与LDA召回的词 + tw_demo.show_topics(topic_id); + } + return 0; +} diff --git a/src/demo/word_distance_demo.cpp b/src/demo/word_distance_demo.cpp new file mode 100644 index 0000000..c9bc6d8 --- /dev/null +++ b/src/demo/word_distance_demo.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// Author: lianrongzhong@baidu.com + +#include "familia/semantic_matching.h" + +#include +#include +#include + +using std::string; +using std::vector; +using std::cin; +using std::cout; +using std::endl; + +DEFINE_string(work_dir, "./", "working directory"); +DEFINE_string(emb_file, "./", "Topical Word Embedding (TWE) file"); +DEFINE_int32(top_k, 20, "the nearest k words"); + +namespace familia { +// 寻找距离词最邻近的K个词Demo类 +class WordDistanceDemo { +public: + WordDistanceDemo() : _twe(FLAGS_work_dir, FLAGS_emb_file) { + } + + ~WordDistanceDemo() = default; + + // 获取词距离最近的k个词,并打印出来 + void find_nearest_words(const string& word, int k) { + vector items(k); + // 如果词不存在模型词典,则返回 + if (!_twe.contains_word(word)) { + cout << word << " is out of vocabulary." << endl; + return; + } + + _twe.nearest_words(word, items); + print_result(items); + } + + // 打印结果 + void print_result(const std::vector& items) { + cout << "Word Cosine Distance " << endl; + cout << "--------------------------------------------" << endl; + for (const auto& item : items) { + cout << std::left << std::setw(24) << item.word << "\t" + << item.distance << endl; + } + cout << endl; + } + +private: + // Topic Word Embedding模型 + TopicalWordEmbedding _twe; +}; +} // namespace familia + +int main(int argc, char* argv[]) { + GOOGLE_PROTOBUF_VERIFY_VERSION; + google::SetVersionString("1.0.0.0"); + string usage = string("Usage: ./word_distance_demo --work_dir=\"PATH/TO/MODEL\" ") + + string("--emb_file=\"webpage_twe_lda.model\" ") + + string("--top_k=\"20\" "); + google::SetUsageMessage(usage); + google::ParseCommandLineFlags(&argc, &argv, true); + + familia::WordDistanceDemo wd_demo; + + string word; + while (true) { + cout << "请输入词语:\t"; + getline(cin, word); + wd_demo.find_nearest_words(word, FLAGS_top_k); + } + + return 0; +} diff --git a/src/document.cpp b/src/document.cpp new file mode 100644 index 0000000..fac3bd0 --- /dev/null +++ b/src/document.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+//
+// Author: chenzeyu01@baidu.com
+
+#include "familia/document.h"
+#include "familia/util.h"
+
+using std::vector;
+using std::string;
+
+namespace familia {
+
+// -------------LDA Begin---------------
+void LDADoc::init(int num_topics) {
+    _num_topics = num_topics;
+    _tokens.clear();
+    _topic_sum.resize(_num_topics, 0);
+    _accum_topic_sum.resize(_num_topics, 0);
+}
+
+void LDADoc::add_token(const Token& token) {
+    CHECK_GE(token.topic, 0) << "Topic " << token.topic << " out of range!";
+    CHECK_LT(token.topic, _num_topics) << "Topic " << token.topic << " out of range!";
+    _tokens.push_back(token);
+    _topic_sum[token.topic]++;
+}
+
+void LDADoc::set_topic(int index, int new_topic) {
+    CHECK_GE(new_topic, 0) << "Topic " << new_topic << " out of range!";
+    CHECK_LT(new_topic, _num_topics) << "Topic " << new_topic << " out of range!";
+    int old_topic = _tokens[index].topic;
+    if (new_topic == old_topic) {
+        return;
+    }
+    _tokens[index].topic = new_topic;
+    _topic_sum[old_topic]--;
+    _topic_sum[new_topic]++;
+}
+
+void LDADoc::topic_dist(vector<Topic>& topic_dist, bool sort) const {
+    topic_dist.clear();
+    size_t sum = 0;
+    for (int i = 0; i < _num_topics; ++i) {
+        sum += _accum_topic_sum[i];
+    }
+    if (sum == 0) {
+        return; // return an empty result
+    }
+    for (int i = 0; i < _num_topics; ++i) {
+        // skip zero entries to obtain a sparse topic distribution
+        if (_accum_topic_sum[i] == 0) {
+            continue;
+        }
+        topic_dist.push_back({i, _accum_topic_sum[i] * 1.0 / sum});
+    }
+    if (sort) {
+        std::sort(topic_dist.begin(), topic_dist.end());
+    }
+}
+
+void LDADoc::dense_topic_dist(vector<float>& dense_dist) const {
+    dense_dist.clear();
+    dense_dist.resize(_num_topics, 0.0);
+    size_t sum = 0;
+    for (int i = 0; i < _num_topics; ++i) {
+        sum += _accum_topic_sum[i];
+    }
+    if (sum == 0) {
+        return; // return a zero vector
+    }
+    for (int i = 0; i < _num_topics; ++i) {
+        dense_dist[i] = _accum_topic_sum[i] * 1.0 / sum;
+    }
+}
+
+void LDADoc::accumulate_topic_sum() {
+    for (int i = 0; i < _num_topics; ++i) {
+        _accum_topic_sum[i] += _topic_sum[i];
+    }
+}
+// -------------LDA End---------------
+
+// --------Sentence-LDA Begin---------
+void SLDADoc::init(int num_topics) {
+    _num_topics = num_topics;
+    _sentences.clear();
+    _topic_sum.resize(_num_topics, 0);
+    _accum_topic_sum.resize(_num_topics, 0);
+}
+
+void SLDADoc::add_sentence(const Sentence& sent) {
+    CHECK_GE(sent.topic, 0) << "Topic " << sent.topic << " out of range!";
+    CHECK_LT(sent.topic, _num_topics) << "Topic " << sent.topic << " out of range!";
+    _sentences.push_back(sent);
+    _topic_sum[sent.topic]++;
+}
+
+void SLDADoc::set_topic(int index, int new_topic) {
+    CHECK_GE(new_topic, 0) << "Topic " << new_topic << " out of range!";
+    CHECK_LT(new_topic, _num_topics) << "Topic " << new_topic << " out of range!";
+    int old_topic = _sentences[index].topic;
+    if (new_topic == old_topic) {
+        return;
+    }
+    _sentences[index].topic = new_topic;
+    _topic_sum[old_topic]--;
+    _topic_sum[new_topic]++;
+}
+// --------Sentence-LDA End---------
+} // namespace familia
diff --git a/src/inference_engine.cpp b/src/inference_engine.cpp
new file mode 100644
index 0000000..6a9f2f0
--- /dev/null
+++ b/src/inference_engine.cpp
@@ -0,0 +1,101 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: chenzeyu01@baidu.com
+
+#include "familia/inference_engine.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace familia {
+
+InferenceEngine::InferenceEngine(const std::string& work_dir,
+                                 const std::string& conf_file,
+                                 SamplerType type) {
+    LOG(INFO) << "Inference Engine initializing...";
+    // load the model configuration and the model
+    ModelConfig config;
+    load_prototxt(work_dir + "/" + conf_file, config);
+    _model = std::make_shared<TopicModel>(work_dir, config);
+
+    // initialize the sampler according to the configuration
+    if (type == SamplerType::GibbsSampling) {
+        _sampler = std::unique_ptr<Sampler>(new GibbsSampler(_model));
+    } else if (type == SamplerType::MetropolisHastings) {
+        _sampler = std::unique_ptr<Sampler>(new MHSampler(_model));
+    }
+
+    LOG(INFO) << "InferenceEngine initialize successfully!";
+}
+
+int InferenceEngine::infer(const std::vector<std::string>& input, LDADoc& doc) {
+    fix_random_seed(); // fix the random seed so that repeated inference on the same input yields a stable topic distribution
+    doc.init(_model->num_topics());
+    for (const auto& token : input) {
+        int id = _model->term_id(token);
+        if (id != OOV) {
+            int init_topic = rand_k(_model->num_topics());
+            doc.add_token({init_topic, id});
+        }
+    }
+
+    lda_infer(doc, 20, 50);
+
+    return 0;
+}
+
+int InferenceEngine::infer(const std::vector<std::vector<std::string>>& input, SLDADoc& doc) {
+    fix_random_seed(); // fix the random seed so that repeated inference on the same input yields a stable topic distribution
+    doc.init(_model->num_topics());
+    std::vector<int> words;
+    int init_topic;
+    for (const auto& sent : input) {
+        for (const auto& token : sent) {
+            int id = _model->term_id(token);
+            if (id != OOV) {
+                words.push_back(id);
+            }
+        }
+        // random initialization
+        init_topic = rand_k(_model->num_topics());
+        doc.add_sentence({init_topic, words});
+        words.clear();
+    }
+
+    slda_infer(doc, 20, 50);
+
+    return 0;
+}
+
+void InferenceEngine::lda_infer(LDADoc& doc, int burn_in_iter, int total_iter) const {
+    CHECK_GE(burn_in_iter, 0);
+    CHECK_GT(total_iter, 0);
+    CHECK_GT(total_iter, burn_in_iter);
+
+    for (int iter = 0; iter < total_iter; ++iter) {
+        _sampler->sample_doc(doc);
+        if (iter >= burn_in_iter) {
+            // after burn-in, accumulate the result of every sampling round to obtain a smoother distribution
+            doc.accumulate_topic_sum();
+        }
+    }
+}
+
+void InferenceEngine::slda_infer(SLDADoc& doc, int burn_in_iter, int total_iter) const {
+    CHECK_GE(burn_in_iter, 0);
+    CHECK_GT(total_iter, 0);
+    CHECK_GT(total_iter, burn_in_iter);
+
+    for (int iter = 0; iter < total_iter; ++iter) {
+        _sampler->sample_doc(doc);
+        if (iter >= burn_in_iter) {
+            // after burn-in, accumulate the result of every sampling round to obtain a smoother distribution
+            doc.accumulate_topic_sum();
+        }
+    }
+}
+} // namespace familia
diff --git a/src/model.cpp b/src/model.cpp
new file mode 100644
index 0000000..eba7c76
--- /dev/null
+++ b/src/model.cpp
@@ -0,0 +1,85 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: chenzeyu01@baidu.com
+
+#include "familia/model.h"
+
+#include <algorithm>
+#include <fstream>
+
+namespace familia {
+
+TopicModel::TopicModel(const std::string& work_dir, const ModelConfig& config) {
+    _num_topics = config.num_topics();
+    _beta = config.beta();
+    _alpha = config.alpha();
+    _alpha_sum = _alpha * _num_topics;
+    _topic_sum = std::vector<uint64_t>(_num_topics, 0);
+    _type = config.type();
+
+    // load the model
+    load_model(work_dir + "/" + config.word_topic_file(), work_dir + "/" + config.vocab_file());
+}
+
+uint64_t TopicModel::topic_sum(int topic_id) const {
+    return _topic_sum.at(topic_id);
+}
+
+void TopicModel::load_model(const std::string& word_topic_path,
+                            const std::string& vocab_path) {
+    LOG(INFO) << "Loading model: " << word_topic_path;
+    LOG(INFO) << "Loading vocab: " << vocab_path;
+
+    // loading vocabulary
+    _vocab.load(vocab_path);
+
+    _beta_sum = _beta * _vocab.size();
+    _word_topic = std::vector<TopicDist>(_vocab.size());
+
+    load_word_topic(word_topic_path);
+
+    LOG(INFO) << "Model Info: #num_topics = " << num_topics() << " #vocab_size = " << vocab_size()
+              << " alpha = " << alpha() << " beta = " << beta();
+}
+
+void TopicModel::load_word_topic(const std::string& word_topic_path) {
+    LOG(INFO) << "Loading word topic from " << word_topic_path;
+    std::ifstream fin(word_topic_path.c_str(), std::ios::in);
+    CHECK(fin) << "Failed to open word topic file!";
+
+    std::string line;
+    while (getline(fin, line)) {
+        std::vector<std::string> fields;
+        split(fields, line, ' ');
+
+        CHECK_GT(fields.size(), 0) << "Model file format error!";
+
+        int term_id = std::stoi(fields[0]);
+
+        CHECK_LT(term_id, vocab_size()) << "Term id out of range!";
+        CHECK_GE(term_id, 0) << "Term id out of range!";
+
+        for (size_t i = 1; i < fields.size(); ++i) {
+            std::vector<std::string> topic_count;
+            split(topic_count, fields[i], ':');
+            CHECK_EQ(topic_count.size(), 2) << "Topic count format error!";
+
+            int topic_id = std::stoi(topic_count[0]);
+            CHECK_GE(topic_id, 0) << "Topic out of range!";
+            CHECK_LT(topic_id, _num_topics) << "Topic out of range!";
+
+            int count = std::stoi(topic_count[1]);
+            CHECK_GT(count, 0) << "Topic count error!";
+
+            _word_topic[term_id].emplace_back(topic_id, count);
+            _topic_sum[topic_id] += count;
+        }
+        // sort by topic id
+        std::sort(_word_topic[term_id].begin(), _word_topic[term_id].end());
+    }
+
+    fin.close();
+    LOG(INFO) << "Word topic load successfully!";
+}
+} // namespace familia
diff --git a/src/sampler.cpp b/src/sampler.cpp
new file mode 100644
index 0000000..7d76a62
--- /dev/null
+++ b/src/sampler.cpp
@@ -0,0 +1,340 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
diff --git a/src/sampler.cpp b/src/sampler.cpp
new file mode 100644
index 0000000..7d76a62
--- /dev/null
+++ b/src/sampler.cpp
@@ -0,0 +1,340 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: chenzeyu01@baidu.com
+
+#include "familia/sampler.h"
+
+namespace familia {
+
+void MHSampler::sample_doc(LDADoc& doc) {
+    for (size_t i = 0; i < doc.size(); ++i) {
+        int new_topic = sample_token(doc, doc.token(i));
+        doc.set_topic(i, new_topic);
+    }
+}
+
+void MHSampler::sample_doc(SLDADoc& doc) {
+    int new_topic = 0;
+    for (size_t i = 0; i < doc.size(); ++i) {
+        new_topic = sample_sentence(doc, doc.sent(i));
+        doc.set_topic(i, new_topic);
+    }
+}
+
+int MHSampler::propose(int word_id) {
+    // Decide whether to draw the sample from the word's alias table or from the prior's
+    double dart = rand() * (_prob_sum[word_id] + _beta_prior_sum);
+    int topic = -1;
+    if (dart < _prob_sum[word_id]) {
+        int idx = _alias_tables[word_id].generate(); // draw a sample from the word's alias table
+        topic = _topic_indexes[word_id][idx]; // map the sampled index back to the real topic id
+    } else { // the dart hit the prior part
+        // The prior alias table is dense over all topics, so no id mapping is needed
+        topic = _beta_alias.generate();
+    }
+
+    return topic;
+}
+
+int MHSampler::sample_token(LDADoc& doc, Token& token) {
+    int new_topic = token.topic;
+    for (int i = 0; i < _mh_steps; ++i) {
+        int doc_proposed_topic = doc_proposal(doc, token);
+        new_topic = word_proposal(doc, token, doc_proposed_topic);
+    }
+
+    return new_topic;
+}
+
+int MHSampler::sample_sentence(SLDADoc& doc, Sentence& sent) {
+    int new_topic = sent.topic;
+    for (int i = 0; i < _mh_steps; ++i) {
+        int doc_proposed_topic = doc_proposal(doc, sent);
+        new_topic = word_proposal(doc, sent, doc_proposed_topic);
+    }
+
+    return new_topic;
+}
+
+int MHSampler::doc_proposal(LDADoc& doc, Token& token) {
+    int old_topic = token.topic;
+    int new_topic = old_topic;
+
+    double dart = rand() * (doc.size() + _model->alpha_sum());
+    if (dart < doc.size()) {
+        int token_index = static_cast<int>(dart);
+        new_topic = doc.token(token_index).topic;
+    } else {
+        // The dart hit the document prior, so sample a topic uniformly at random
+        new_topic = rand_k(_model->num_topics());
+    }
+
+    if (new_topic != old_topic) {
+        float proposal_old = doc_proposal_distribution(doc, old_topic);
+        float proposal_new = doc_proposal_distribution(doc, new_topic);
+        float proportion_old = proportional_funtion(doc, token, old_topic);
+        float proportion_new = proportional_funtion(doc, token, new_topic);
+        double transition_prob = (proportion_new * proposal_old) / (proportion_old * proposal_new);
+        double rejection = rand();
+        int mask = -(rejection < transition_prob); // all-ones if accepted, all-zeros if rejected
+        return (new_topic & mask) | (old_topic & ~mask); // branchless select, avoids an if branch
+    }
+
+    return new_topic;
+}
+
+int MHSampler::doc_proposal(SLDADoc& doc, Sentence& sent) {
+    int old_topic = sent.topic;
+    int new_topic = -1;
+
+    double dart = rand() * (doc.size() + _model->alpha_sum());
+    if (dart < doc.size()) {
+        int sent_index = static_cast<int>(dart);
+        new_topic = doc.sent(sent_index).topic;
+    } else {
+        // The dart hit the document prior, so sample a topic uniformly at random
+        new_topic = rand_k(_model->num_topics());
+    }
+
+    if (new_topic != old_topic) {
+        float proportion_old = proportional_funtion(doc, sent, old_topic);
+        float proportion_new = proportional_funtion(doc, sent, new_topic);
+        float proposal_old = doc_proposal_distribution(doc, old_topic);
+        float proposal_new = doc_proposal_distribution(doc, new_topic);
+        double transition_prob = (proportion_new * proposal_old) / (proportion_old * proposal_new);
+        double rejection = rand();
+        int mask = -(rejection < transition_prob);
+        return (new_topic & mask) | (old_topic & ~mask);
+    }
+
+    return new_topic;
+}
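+
+// Why the bit trick above works: -(cond) maps a bool to 0 or to ~0 (two's
+// complement -1, all bits set), so OR-ing the two masked topics picks exactly
+// one operand without a data-dependent branch. A minimal equivalent sketch:
+//
+//   int select(bool accept, int a, int b) {
+//       int mask = -static_cast<int>(accept); // true -> ~0, false -> 0
+//       return (a & mask) | (b & ~mask);      // accept ? a : b
+//   }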
+
+int MHSampler::word_proposal(LDADoc& doc, Token& token, int old_topic) {
+    int new_topic = propose(token.id); // propose a new topic from the alias tables
+    if (new_topic != old_topic) {
+        float proposal_old = word_proposal_distribution(token.id, old_topic);
+        float proposal_new = word_proposal_distribution(token.id, new_topic);
+        float proportion_old = proportional_funtion(doc, token, old_topic);
+        float proportion_new = proportional_funtion(doc, token, new_topic);
+        double transition_prob = (proportion_new * proposal_old) / (proportion_old * proposal_new);
+        double rejection = rand();
+        int mask = -(rejection < transition_prob);
+        return (new_topic & mask) | (old_topic & ~mask);
+    }
+
+    return new_topic;
+}
+
+// Word proposal for Sentence-LDA
+int MHSampler::word_proposal(SLDADoc& doc, Sentence& sent, int old_topic) {
+    int new_topic = old_topic;
+    for (const auto& word_id : sent.tokens) {
+        new_topic = propose(word_id); // propose a new topic from the alias tables
+        if (new_topic != old_topic) {
+            float proportion_old = proportional_funtion(doc, sent, old_topic);
+            float proportion_new = proportional_funtion(doc, sent, new_topic);
+            float proposal_old = word_proposal_distribution(word_id, old_topic);
+            float proposal_new = word_proposal_distribution(word_id, new_topic);
+            double transition_prob = (proportion_new * proposal_old) /
+                                     (proportion_old * proposal_new);
+
+            double rejection = rand();
+            int mask = -(rejection < transition_prob);
+            new_topic = (new_topic & mask) | (old_topic & ~mask);
+        }
+    }
+
+    return new_topic;
+}
+
+float MHSampler::proportional_funtion(LDADoc& doc, Token& token, int new_topic) {
+    int old_topic = token.topic;
+    float dt_alpha = doc.topic_sum(new_topic) + _model->alpha();
+    float wt_beta = _model->word_topic(token.id, new_topic) + _model->beta();
+    float t_sum_beta_sum = _model->topic_sum(new_topic) + _model->beta_sum();
+    // Exclude the token's current assignment from the counts when evaluating its own topic
+    if (new_topic == old_topic && wt_beta > 1) {
+        if (dt_alpha > 1) {
+            dt_alpha -= 1;
+        }
+        wt_beta -= 1;
+        t_sum_beta_sum -= 1;
+    }
+
+    return dt_alpha * wt_beta / t_sum_beta_sum;
+}
+
+float MHSampler::proportional_funtion(SLDADoc& doc, Sentence& sent, int new_topic) {
+    int old_topic = sent.topic;
+    float result = doc.topic_sum(new_topic) + _model->alpha();
+    if (new_topic == old_topic) {
+        result -= 1;
+    }
+    for (const auto& word_id : sent.tokens) {
+        float wt_beta = _model->word_topic(word_id, new_topic) + _model->beta();
+        float t_sum_beta_sum = _model->topic_sum(new_topic) + _model->beta_sum();
+        if (new_topic == old_topic && wt_beta > 1) {
+            wt_beta -= 1;
+            t_sum_beta_sum -= 1;
+        }
+
+        result *= wt_beta / t_sum_beta_sum;
+    }
+
+    return result;
+}
+
+float MHSampler::doc_proposal_distribution(LDADoc& doc, int topic) {
+    return doc.topic_sum(topic) + _model->alpha();
+}
+
+float MHSampler::word_proposal_distribution(int word_id, int topic) {
+    float wt_beta = _model->word_topic(word_id, topic) + _model->beta();
+    float t_sum_beta_sum = _model->topic_sum(topic) + _model->beta_sum();
+
+    return wt_beta / t_sum_beta_sum;
+}
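+
+// Why two alias tables: construct_alias_table below exploits the decomposition
+// of the word-proposal distribution (with N_wt the word-topic count, N_t the
+// topic total, and beta_sum = beta * V for vocabulary size V):
+//
+//   q(t | w) ∝ (N_wt + beta) / (N_t + beta_sum)
+//            = N_wt / (N_t + beta_sum)  +  beta / (N_t + beta_sum)
+//
+// The first term is sparse (non-zero only for topics observed with word w) and
+// gets a per-word alias table indexed through _topic_indexes[w]; the second
+// term is dense but identical for every word, so a single prior table
+// (_beta_alias) suffices. propose() then mixes the two in proportion to
+// _prob_sum[w] and _beta_prior_sum.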
+
+int MHSampler::construct_alias_table() {
+    size_t vocab_size = _model->vocab_size();
+    _topic_indexes.resize(vocab_size);
+    _alias_tables.resize(vocab_size);
+    _prob_sum.resize(vocab_size);
+
+    // Build an alias table for each word (prior part excluded)
+    std::vector<double> dist;
+    for (size_t i = 0; i < vocab_size; ++i) {
+        dist.clear();
+        double prob_sum = 0;
+        for (auto& iter : _model->word_topic(i)) {
+            int topic_id = iter.first;                      // topic index
+            int word_topic_count = iter.second;             // topic count
+            size_t topic_sum = _model->topic_sum(topic_id); // topic sum
+
+            _topic_indexes[i].push_back(topic_id);
+            double q = word_topic_count / (topic_sum + _model->beta_sum());
+            dist.push_back(q);
+            prob_sum += q;
+        }
+        _prob_sum[i] = prob_sum;
+
+        if (dist.size() > 0) {
+            _alias_tables[i].initialize(dist);
+        }
+    }
+
+    // Build the alias table for the prior parameter beta
+    _beta_prior_sum = 0;
+    std::vector<double> beta_dist(_model->num_topics(), 0.0);
+    for (int i = 0; i < _model->num_topics(); ++i) {
+        beta_dist[i] = _model->beta() / (_model->topic_sum(i) + _model->beta_sum());
+        _beta_prior_sum += beta_dist[i];
+    }
+    _beta_alias.initialize(beta_dist);
+
+    return 0;
+}
+
+void GibbsSampler::sample_doc(LDADoc& doc) {
+    int new_topic = -1;
+    for (size_t i = 0; i < doc.size(); ++i) {
+        new_topic = sample_token(doc, doc.token(i));
+        doc.set_topic(i, new_topic);
+    }
+}
+
+void GibbsSampler::sample_doc(SLDADoc& doc) {
+    int new_topic = -1;
+    for (size_t i = 0; i < doc.size(); ++i) {
+        new_topic = sample_sentence(doc, doc.sent(i));
+        doc.set_topic(i, new_topic);
+    }
+}
+
+int GibbsSampler::sample_token(LDADoc& doc, Token& token) {
+    int old_topic = token.topic;
+    int num_topics = _model->num_topics();
+    std::vector<float> accum_prob(num_topics, 0.0);
+    std::vector<float> prob(num_topics, 0.0);
+    float sum = 0.0;
+    float dt_alpha = 0.0;
+    float wt_beta = 0.0;
+    float t_sum_beta_sum = 0.0;
+    for (int t = 0; t < num_topics; ++t) {
+        dt_alpha = doc.topic_sum(t) + _model->alpha();
+        wt_beta = _model->word_topic(token.id, t) + _model->beta();
+        t_sum_beta_sum = _model->topic_sum(t) + _model->beta_sum();
+        if (t == old_topic && wt_beta > 1) {
+            if (dt_alpha > 1) {
+                dt_alpha -= 1;
+            }
+            wt_beta -= 1;
+            t_sum_beta_sum -= 1;
+        }
+        prob[t] = dt_alpha * wt_beta / t_sum_beta_sum;
+        sum += prob[t];
+        accum_prob[t] = (t == 0 ? prob[t] : accum_prob[t - 1] + prob[t]);
+    }
+
+    double dart = rand() * sum;
+    if (dart <= accum_prob[0]) {
+        return 0;
+    }
+    for (int t = 1; t < num_topics; ++t) {
+        if (dart > accum_prob[t - 1] && dart <= accum_prob[t]) {
+            return t;
+        }
+    }
+
+    return num_topics - 1; // fall back to the last topic id
+}
+
+int GibbsSampler::sample_sentence(SLDADoc& doc, Sentence& sent) {
+    int old_topic = sent.topic;
+    int num_topics = _model->num_topics();
+    std::vector<float> accum_prob(num_topics, 0.0);
+    std::vector<float> prob(num_topics, 0.0);
+    float sum = 0.0;
+    float dt_alpha = 0.0;
+    float t_sum_beta_sum = 0.0;
+    float wt_beta = 0.0;
+    // For numerical stability, the following is an approximate sampling implementation for SentenceLDA
+    // TODO: add a reference for the approximation method
+    for (int t = 0; t < num_topics; ++t) {
+        dt_alpha = doc.topic_sum(t) + _model->alpha();
+        t_sum_beta_sum = _model->topic_sum(t) + _model->beta_sum();
+        if (t == old_topic) {
+            if (dt_alpha > 1) {
+                dt_alpha -= 1;
+            }
+            if (t_sum_beta_sum > 1) {
+                t_sum_beta_sum -= 1;
+            }
+        }
+        prob[t] = dt_alpha;
+        for (size_t i = 0; i < sent.tokens.size(); ++i) {
+            int w = sent.tokens[i];
+            wt_beta = _model->word_topic(w, t) + _model->beta();
+            if (t == old_topic && wt_beta > 1) {
+                wt_beta -= 1;
+            }
+            // NOTE: for very long sentences this running product can underflow and lose precision
+            prob[t] *= wt_beta / t_sum_beta_sum;
+        }
+        sum += prob[t];
+        accum_prob[t] = (t == 0 ? prob[t] : accum_prob[t - 1] + prob[t]);
+    }
+    double dart = rand() * sum;
+    if (dart <= accum_prob[0]) {
+        return 0;
+    }
+    for (int t = 1; t < num_topics; ++t) {
+        if (dart > accum_prob[t - 1] && dart <= accum_prob[t]) {
+            return t;
+        }
+    }
+
+    return num_topics - 1; // fall back to the last topic id
+}
+} // namespace familia
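+
+// A note on the draw above: both Gibbs samplers pick a topic by a linear scan
+// over the cumulative array, which is O(K) per token. Since accum_prob is
+// non-decreasing, a binary search is a drop-in alternative for large topic
+// counts; a minimal sketch (not part of this change):
+//
+//   #include <algorithm>
+//   #include <vector>
+//   int draw(const std::vector<float>& accum_prob, double dart) {
+//       auto it = std::lower_bound(accum_prob.begin(), accum_prob.end(),
+//                                  static_cast<float>(dart));
+//       if (it == accum_prob.end()) return static_cast<int>(accum_prob.size()) - 1;
+//       return static_cast<int>(it - accum_prob.begin());
+//   }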
diff --git a/src/semantic_matching.cpp b/src/semantic_matching.cpp
new file mode 100644
index 0000000..d18c1e4
--- /dev/null
+++ b/src/semantic_matching.cpp
@@ -0,0 +1,121 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: jiangjiajun@baidu.com
+
+#include "familia/semantic_matching.h"
+#include "familia/util.h"
+
+using std::vector;
+using std::string;
+
+namespace familia {
+
+// Return the topic embedding for the given topic id
+Embedding& TopicalWordEmbedding::topic_emb(int topic_id) {
+    CHECK_GE(topic_id, 0) << "Topic id out of range";
+    CHECK_LT(topic_id, _num_topics) << "Topic id out of range";
+
+    return _topic_emb[topic_id];
+}
+
+Embedding& TopicalWordEmbedding::word_emb(const string& term) {
+    CHECK(contains_word(term)) << term << " out of vocabulary!";
+    return _word_emb[term];
+}
+
+bool TopicalWordEmbedding::contains_word(const string& term) const {
+    return _word_emb.find(term) != _word_emb.end();
+}
+
+int TopicalWordEmbedding::num_topics() const {
+    return _num_topics;
+}
+
+int TopicalWordEmbedding::load_emb(const string& emb_file) {
+    LOG(INFO) << "Loading Topical Word Embedding (TWE)...";
+    FILE* fin_emb = fopen(emb_file.c_str(), "rb");
+    CHECK(fin_emb) << "Failed to open embedding file!";
+
+    fscanf(fin_emb, "%d%d%d\n", &_vocab_size, &_num_topics, &_emb_size);
+
+    LOG(INFO) << "#word = " << _vocab_size
+              << " #topic = " << _num_topics
+              << " #emb_size = " << _emb_size;
+
+    const int MAX_TOKEN_LENGTH = 50;
+    char term[MAX_TOKEN_LENGTH];
+    Embedding emb(_emb_size, 0);
+    int total_num = _vocab_size + _num_topics;
+    // TWE model storage format:
+    //   word plaintext <space> binary embedding <newline>
+    //   (rows 0 .. vocab_size - 1 hold the word embeddings)
+    //   _topic_#ID <space> binary embedding <newline>
+    //   (the following num_topics rows hold the topic embeddings)
+    for (int i = 0; i < total_num; ++i) {
+        if (i % 100000 == 0) {
+            LOG(INFO) << "Loading embedding #id = " << i;
+        }
+        fscanf(fin_emb, "%s", term); // NOTE: assumes every token is shorter than MAX_TOKEN_LENGTH
+        fgetc(fin_emb); // skip the space
+        if (i < _vocab_size) {
+            // Load a word embedding
+            _word_emb[term] = Embedding(_emb_size, 0);
+            fread(_word_emb[term].data(), sizeof(Embedding::value_type), _emb_size, fin_emb);
+            fgetc(fin_emb); // skip the '\n'
+        } else {
+            // Load a topic embedding
+            fread(emb.data(), sizeof(Embedding::value_type), _emb_size, fin_emb);
+            fgetc(fin_emb); // skip the '\n'
+            _topic_emb.push_back(emb);
+        }
+    }
+    fclose(fin_emb);
+    LOG(INFO) << "Loaded Topical Word Embedding (TWE) successfully!";
+
+    return 0;
+}
+
+// WordAndDis reconstructed here as the item type declared in semantic_matching.h
+void TopicalWordEmbedding::nearest_words(const string& word,
+                                         std::vector<WordAndDis>& items) {
+    Embedding& target_word_emb = word_emb(word);
+    int num_k = items.size();
+    for (const auto& it : _word_emb) {
+        // Skip the target word itself
+        if (it.first == word) {
+            continue;
+        }
+        float dist = SemanticMatching::cosine_similarity(target_word_emb, it.second);
+        // Insert into a fixed-size list kept sorted by descending similarity
+        for (int i = 0; i < num_k; i++) {
+            if (dist > items[i].distance) {
+                for (int j = num_k - 1; j > i; j--) {
+                    items[j] = items[j - 1];
+                }
+                items[i].word = it.first;
+                items[i].distance = dist;
+                break;
+            }
+        }
+    }
+}
+
+void TopicalWordEmbedding::nearest_words_around_topic(int topic_id,
+                                                      std::vector<WordAndDis>& items) {
+    Embedding& target_topic_emb = topic_emb(topic_id);
+    int num_k = items.size();
+    for (const auto& it : _word_emb) {
+        float dist = SemanticMatching::cosine_similarity(target_topic_emb, it.second);
+        for (int i = 0; i < num_k; i++) {
+            if (dist > items[i].distance) {
+                for (int j = num_k - 1; j > i; j--) {
+                    items[j] = items[j - 1];
+                }
+                items[i].word = it.first;
+                items[i].distance = dist;
+                break;
+            }
+        }
+    }
+}
+} // namespace familia
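+
+// A minimal usage sketch for the k-NN routines above (not part of the build;
+// WordAndDis and the twe instance are assumptions based on this file's usage).
+// `items` must be pre-sized to k, and since entries are only overwritten when
+// beaten, the initial distance should sit below any cosine similarity:
+//
+//   // assume a TopicalWordEmbedding `twe` whose load_emb() has already run
+//   std::vector<familia::WordAndDis> items(10);
+//   for (auto& item : items) { item.distance = -2.0f; } // below the cosine minimum of -1
+//   twe.nearest_words("apple", items); // "apple" is a placeholder query word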
diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp
new file mode 100644
index 0000000..f763cad
--- /dev/null
+++ b/src/tokenizer.cpp
@@ -0,0 +1,77 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: chenzeyu01@baidu.com
+
+#include "familia/tokenizer.h"
+
+#include <algorithm>
+#include <cctype>
+#include <fstream>
+
+namespace familia {
+
+void SimpleTokenizer::tokenize(const std::string& text, std::vector<std::string>& result) const {
+    result.clear();
+    std::string word;
+    std::string found_word;
+    int text_len = static_cast<int>(text.size());
+    for (int i = 0; i < text_len; ++i) {
+        word.clear();
+        found_word = "";
+        // Branch that handles English characters
+        if (is_eng_char(text[i])) {
+            // Iterate one position past the end of the string so a trailing English run is also flushed
+            for (int j = i; j <= text_len; ++j) {
+                // Keep consuming English characters until a non-English character is reached
+                if (j < text_len && is_eng_char(text[j])) {
+                    // The vocabulary only contains lowercase words, so lowercase every English character
+                    word.push_back(tolower(text[j]));
+                } else {
+                    // Look up the collected English run in the vocabulary
+                    if (_vocab.find(word) != _vocab.end()) {
+                        result.push_back(word);
+                    }
+                    i = j - 1;
+                    break;
+                }
+            }
+        } else {
+            // Forward maximum matching at character granularity
+            for (int j = i; j < i + _max_word_len && j < text_len; ++j) {
+                word.push_back(text[j]);
+                if (_vocab.find(word) != _vocab.end()) {
+                    found_word = word;
+                }
+            }
+            if (found_word.size() > 0) {
+                result.push_back(found_word);
+                i += found_word.size() - 1;
+            }
+        }
+    }
+}
+
+bool SimpleTokenizer::contains(const std::string& word) const {
+    return _vocab.find(word) != _vocab.end();
+}
+
+void SimpleTokenizer::load_vocab(const std::string& vocab_path) {
+    LOG(INFO) << "Loading vocabulary file from " << vocab_path;
+    std::ifstream fin(vocab_path, std::ios::in);
+    CHECK(fin) << "Failed to open vocabulary file!";
+
+    std::string line;
+    int vocab_size = 0;
+    while (getline(fin, line)) {
+        std::vector<std::string> fields;
+        split(fields, line, '\t');
+        CHECK_GE(fields.size(), 2);
+        std::string word = fields[1];
+        _max_word_len = std::max(static_cast<int>(word.size()), _max_word_len);
+        _vocab.insert(word);
+        ++vocab_size;
+    }
+    LOG(INFO) << "Vocabulary loaded successfully! #vocab_size = " << vocab_size;
+}
+} // namespace familia
diff --git a/src/util.cpp b/src/util.cpp
new file mode 100644
index 0000000..794bc58
--- /dev/null
+++ b/src/util.cpp
@@ -0,0 +1,22 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: chenzeyu01@baidu.com
+
+#include "familia/util.h"
+
+namespace familia {
+
+void split(std::vector<std::string>& result, const std::string& text, char separator) {
+    size_t start = 0;
+    size_t end = 0;
+    while ((end = text.find(separator, start)) != std::string::npos) {
+        std::string substr = text.substr(start, end - start);
+        result.push_back(std::move(substr));
+        start = end + 1;
+    }
+    // NOTE: if the input contains no separator, the whole input becomes the single field
+    result.push_back(text.substr(start));
+}
+} // namespace familia
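+
+// Behavior sketch for split() above: fields are appended, never replaced, so
+// callers must pass a fresh vector or clear() a reused one (as vocab.cpp does):
+//
+//   std::vector<std::string> out;
+//   familia::split(out, "7:15", ':'); // out == {"7", "15"}
+//   familia::split(out, "abc", ':');  // appends "abc" -> {"7", "15", "abc"}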
diff --git a/src/vocab.cpp b/src/vocab.cpp
new file mode 100644
index 0000000..bf62b70
--- /dev/null
+++ b/src/vocab.cpp
@@ -0,0 +1,47 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: chenzeyu01@baidu.com
+
+#include "familia/vocab.h"
+#include "familia/util.h"
+
+#include <fstream>
+#include <string>
+
+namespace familia {
+
+int Vocab::get_id(const std::string& term) const {
+    auto it = _term2id.find(term);
+    return it == _term2id.end() ? OOV : it->second;
+}
+
+size_t Vocab::size() const {
+    return _term2id.size();
+}
+
+void Vocab::load(const std::string& vocab_file) {
+    _term2id.clear();
+    std::ifstream fin(vocab_file, std::ios::in);
+    CHECK(fin) << "Failed to open vocab file!";
+
+    std::string line;
+    std::vector<std::string> term_id;
+    while (getline(fin, line)) {
+        term_id.clear();
+        split(term_id, line, '\t');
+        CHECK_EQ(term_id.size(), 5) << "Vocabulary file [" << vocab_file << "] format error!";
+        std::string term = term_id[1];
+        int id = std::stoi(term_id[2]);
+        if (_term2id.find(term) != _term2id.end()) {
+            LOG(ERROR) << "Duplicate word [" << term << "] in vocab file";
+            continue;
+        }
+        _term2id[term] = id;
+    }
+    fin.close();
+
+    LOG(INFO) << "Loaded vocabulary successfully! #vocabulary size = " << size();
+}
+} // namespace familia
diff --git a/src/vose_alias.cpp b/src/vose_alias.cpp
new file mode 100644
index 0000000..c825dec
--- /dev/null
+++ b/src/vose_alias.cpp
@@ -0,0 +1,68 @@
+// Copyright (c) 2017, Baidu.com, Inc. All Rights Reserved
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Author: chenzeyu01@baidu.com
+
+#include "familia/vose_alias.h"
+#include "familia/util.h"
+
+#include <queue>
+
+namespace familia {
+
+void VoseAlias::initialize(const std::vector<double>& distribution) {
+    int size = static_cast<int>(distribution.size());
+    _prob.resize(size);
+    _alias.resize(size);
+    std::vector<double> p(size, 0.0);
+    double sum = 0;
+    for (int i = 0; i < size; ++i) {
+        sum += distribution[i];
+    }
+    for (int i = 0; i < size; ++i) {
+        p[i] = distribution[i] / sum * size; // scale up probability
+    }
+
+    std::queue<int> large;
+    std::queue<int> small;
+    for (int i = 0; i < size; ++i) {
+        if (p[i] < 1.0) {
+            small.push(i);
+        } else {
+            large.push(i);
+        }
+    }
+    while (!small.empty() && !large.empty()) {
+        int l = small.front();
+        int g = large.front();
+        small.pop();
+        large.pop();
+        _prob[l] = p[l];
+        _alias[l] = g;
+        p[g] = p[g] + p[l] - 1; // a more numerically stable option
+        if (p[g] < 1.0) {
+            small.push(g);
+        } else {
+            large.push(g);
+        }
+    }
+    while (!large.empty()) {
+        int g = large.front();
+        large.pop();
+        _prob[g] = 1.0;
+    }
+    while (!small.empty()) {
+        int l = small.front();
+        small.pop();
+        _prob[l] = 1.0;
+    }
+}
+
+int VoseAlias::generate() const {
+    int dart1 = rand_k(size()); // pick a column uniformly
+    double dart2 = rand();      // uniform draw, assumed in [0, 1) per util.h
+
+    // Keep the column's own index with probability _prob[dart1], otherwise take its alias
+    return dart2 < _prob[dart1] ? dart1 : _alias[dart1];
+}
+} // namespace familia
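+
+// Worked sketch for the alias method above: for the distribution
+// {0.5, 0.3, 0.2}, initialize() scales to p = {1.5, 0.9, 0.6}; pairing the
+// small columns with the large one yields _prob = {1.0, 0.9, 0.6} with
+// _alias[1] = _alias[2] = 0, and generate() then reproduces the original
+// probabilities in O(1) per draw:
+//
+//   familia::VoseAlias alias;
+//   alias.initialize({0.5, 0.3, 0.2});
+//   int topic = alias.generate(); // 0 w.p. 0.5, 1 w.p. 0.3, 2 w.p. 0.2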