Skip to content

基于向量召回的检索式对话系统解决方案,dense retrieval,FAQ……

Notifications You must be signed in to change notification settings

iseesaw/FAQ-Semantic-Retrieval

Repository files navigation

FAQ-Semantic-Retrieval

一种 FAQ 向量语义检索解决方案

项目介绍

FAQ 的处理流程一般为:

  • 问题理解,对用户 query 进行改写以及向量表示
  • 召回模块,在问题集上进行候选问题召回,获得 topk(基于关键字的倒排索引 vs 基于向量的语义召回)
  • 排序模块,对 topk 进行精排序

本项目着眼于 召回模块向量检索 的实现,适用于 小规模 FAQ 问题集(候选问题集<10万)的系统快速搭建

FAQ 语义检索

传统召回模块基于关键字检索

随着语义表示模型的增强、预训练模型的发展,基于 BERT 向量的语义检索得到广泛应用

  • 对候选问题集合进行向量编码,得到 corpus 向量矩阵
  • 当用户输入 query 时,同样进行编码得到 query 向量表示
  • 然后进行语义检索(矩阵操作,KNN,FAISS)

本项目针对小规模 FAQ 问题集直接计算 query 和 corpus 向量矩阵的余弦相似度,从而获得 topk 候选问题 $$ score = \frac{V_{query} \cdot V_{corpus}}{||V_{query}|| \cdot ||V_{corpus}||} $$ 句向量获取解决方案

Python Lib Framework Desc Example
bert-as-serivce TensorFlow 高并发服务调用,支持 fine-tune,较难拓展其他模型 getting-started
Sentence-Transformers PyTorch 接口简单易用,支持各种模型调用,支持 fine-turn(单GPU) using-Sentence-Transformers-model
using-Transformers-model
🤗 Transformers PyTorch 自定义程度高,支持各种模型调用,支持 fine-turn(多GPU) sentence-embeddings-with-Transformers
  • Sentence-Transformers 进行小规模数据的单 GPU fine-tune 实验(尚不支持多 GPU 训练,Multi-GPU-training #311 ;实现了多种 Ranking loss 可供参考)
  • Transformers 进行大规模数据的多 GPU fine-tune 训练(推荐自定义模型使用 Trainer 进行训练)
  • 实际使用过程中 Sentence-TransformersTransformers 模型基本互通互用,前者多了 Pooling 层(Mean/Max/CLS Pooling) ,可参考 Example
  • 🔥 实际上线推荐直接使用 Transformers 封装,Sentence-Transformers 在 CPU 服务器上运行存在位置问题。

BERT 微调与蒸馏

在句向量获取中可以直接使用 bert-base-chinese 作为编码器,但在特定领域数据上可能需要进一步 fine-tune 来获取更好的效果

fine-tune 过程主要进行文本相似度计算任务,亦句对分类任务;此处是为获得更好的句向量,因此使用**双塔模型(SiameseNetwork ,孪生网络)**微调,而非常用的基于表示的模型 BertForSequenceClassification

BertForSiameseNetwork

BertForSiameseNetwork 主要步骤如下

  • Encoding,使用(同一个) BERT 分别对 query 和 candidate 进行编码
  • Pooling,对最后一层进行池化操作获得句子表示(Mean/Max/CLS Pooling)
  • Computing,计算两个向量的余弦相似度(或其他度量函数),计算 loss 进行反向传播

损失函数

模型训练使用的损失函数为 Ranking loss,不同于CrossEntropy 和 MSE 进行分类和回归任务,Ranking loss 目的是预测输入样本对(即上述双塔模型中 $u$$v$ 之间)之间的相对距离(度量学习任务

  • Contrastive Loss

    来自 LeCun Dimensionality Reduction by Learning an Invariant Mapping

    sentence-transformers 源码实现

    • 公式形式如下,其中 $u, v$ 为 BERT 编码的向量表示,$y$ 为对应的标签(1 表示正样本,0 表示负样本), $\tau$ 为超参数 $$ L(u_i, v_i, y_i) = y_i ||u_i, v_i|| + (1 - y_i) \max(0, \tau - ||u_i, v_i|| $$

    • 公式意义为:对于正样本,输出特征向量之间距离要尽量小;而对于负样本,输出特征向量间距离要尽量大;但是若负样本间距太大(即容易区分的简单负样本,间距大于 $\tau$则不处理,让模型关注更加难以区分的样本

  • OnlineContrastive Loss

    • 属于 online negative sampling ,与 Contrastive Loss 类似

    • 参考 sentence-transformers 源码实现 ,在每个 batch 内,选择最难以区分的正例样本和负例样本进行 loss 计算(容易识别的正例和负例样本则忽略)

    • 公式形式如下,如果正样本距离小于 $\tau_1$ 则不处理,如果负样本距离大于 $\tau_0$ 则不处理,实现过程中 $\tau_0, \tau_1$ 可以分别取负/正样本的平均距离值 $$ L(u_i, v_i, y_i) \begin{cases} \max (0, ||u_i, v_i|| - \tau_1) ,\ if \ y_i=1 & \ \max (0, \tau_0 - ||u_i, v_i||),\ if \ y_i =0 \end{cases} $$

本项目使用 OnlineContrastive Loss ,更多 Ranking loss 信息可参考博客 Understanding Ranking Loss, Contrastive Loss, Margin Loss, Triplet Loss, Hinge Loss and all those confusing names ,以及 SentenceTransformers 中的 Loss API 以及 PyTorch 中的 margin_ranking_loss

数据集

  • 文本相似度数据集

    • 相关论文比赛发布的数据集可见 文本相似度数据集 ,大部分为金融等特定领域文本,其中 LCQMC 提供基于百度知道的约 20w+ 开放域问题数据集,可供模型测试

      data total positive negative
      training 238766 138574 100192
      validation 8802 4402 4400
      test 12500 6250 6250
    • 除此以外,百度千言项目发布了文本相似度评测,包含 LCQMC/BQ Corpus/PAWS-X 等数据集,可供参考

  • FAQ数据集

    • 内部给定的 FAQ 数据集形式如下,包括各种”主题/问题“,每种“主题/问题”可以有多种不同表达形式的问题 post,同时对应多种形式的回复 resp

    • 检索时只需要将 query 与所有 post 进行相似度计算,从而召回最相似的 post ,然后获取对应的 “主题/问题” 的所有回复 resp ,最后随机返回一个回复即可

      {
         "晚安": {
             "post": [
                  "10点了,我要睡觉了",
                  "唉,该休息了",
               		...
              ],
             "resp": [
                  "祝你做个好梦",
                  "祝你有个好梦,晚安!",
                  ...
              ]
           },
        	"感谢": {
            	"post": [
              		"多谢了",
             		 "非常感谢",
                ...
            	],
            	"resp": [
              		"助人为乐为快乐之本",
              		"别客气",
                	...
            	]
          },
        	...
      }
    • 内部FAQ数据包括两个版本

      • chitchat-faq-small,主要是小规模闲聊FAQ,1500主题问题(topic)、2万多不同形式问题(post)
      • entity-faq-large,主要是大规模实体FAQ(涉及业务问题),大约3-5千主题问题(topic)、12万不同形式问题(post)

负采样

对于每个 query,需要获得与其相似的 positve candidate 以及不相似的 negtive candidate,从而构成正样本和负样本作为模型输入,即 (query, candidate)

⚠️ 此处为 offline negtive sampling,即在训练前采样构造负样本,区别于 online negtive sampling,后者在训练中的每个 batch 内进行动态的负采样(可以通过相关损失函数实现,如 OnlineContrastive Loss

两种方法可以根据任务特性进行选择,online negtive sampling 对于数据集有一定的要求,需要确保每个 batch 内的 query 是不相似的,但是效率更高

对于 offline negtive sampling 主要使用以下两种方式采样:

  • 全局负采样
    • 在整个数据集上进行正态分布采样,很难产生高难度的负样本
  • 局部负采样
    • 首先使用少量人工标注数据预训练的 BERT 模型对候选问题集合进行编码
    • 然后使用无监督聚类 ,如 Kmeans
    • 最后在每个 query 所在聚类簇中进行采样

实验中对已有 FAQ 数据集中所有主题的 post 进行 9:1 划分得到训练集和测试集,负采样结果对比

chitchat-faq-small 为需要上线的 FAQ 闲聊数据,entity-faq-large 为辅助数据

dataset topics posts positive(sampling) negative(sampling) total(sampling)
chitchat-faq-small
train
1468 18267 5w+ 5w+ 10w+
chitchat-faq-small
test
768 2030 2984 7148 10132
chitchat-faq-small 1500 20297 - - -
entity-faq-large - 12w+ 50w+ 50w+ 100w+

模型蒸馏

使用基于 Transformers 的模型蒸馏工具 TextBrewer ,主要参考 官方 入门示例cmrc2018示例

FAQ Web服务

Web API

  • Web 框架选择
    • Flask + Gunicorn + gevent + nginx ,进程管理(崩溃自动重启)(uwsgi 同理,gunicorn 更简单)
    • 🔥 FastAPI + uvicorn(崩溃自动重启),最快的Python Web框架(实测的确比 Flask 快几倍)
  • cache 缓存机制(保存最近的query对应的topic,命中后直接返回)
    • Flask 相关
      • flask-caching (默认缓存500,超时300秒),使用 set/get 进行数据操作;项目来源于 pallets/werkzeug (werkzeug 版本0.4以后弃用 cache)
    • Python 3.2 以上自带(FastAPI 中可使用)

Locust 压力测试

使用 Locust 编写压力测试脚本

使用说明

主要依赖参考 requirements.txt

pip install -r requirements.txt

负采样

python sampling.py \
	--filename='faq/train_faq.json' \
	--model_name_or_path='./model/bert-base-chinese' \
	--is_transformers=True \
	--hyper_beta=2 \
	--num_pos=5 \
	--local_num_negs=3 \
	--global_num_negs=2 \
	--output_dir='./samples'

主要参数说明

  • --filenamefaq 数据集,按前文所述组织为 {topic: {post:[], resp:[]}} 格式
  • --model_name_or_path ,用于句向量编码的 Transformers 预训练模型位置(bert-base-chinese 或者基于人工标注数据微调后的模型)
  • --hyper_beta聚类数超参数,聚类类别为 n_cluster=num_topics/hyper_beta ,其中 num_topics 为上述数据中的主题数,hyper_beta 默认为 2(过小可能无法采样到足够局部负样本)
  • --num_pos正采样个数,默认 5(注意正负比例应为 1:1)
  • --local_num_negs局部负采样个数,默认 3(该值太大时,可能没有那么多局部负样本,需要适当调低正采样个数,保证正负比例为 1:1)
  • --global_num_negs全局负采样个数,默认 2
  • --is_split ,是否进行训练集拆分,默认 False(建议直接在 faq 数据上进行拆分,然后使用评估语义召回效果)
  • --test_size ,测试集比例,默认 0.1
  • --output_dir ,采样结果文件保存位置(sentence1, sentence2, label 形式的 csv 文件)

BERT 微调

  • 参考 Sentence-Transformers 的 raining_OnlineConstrativeLoss.py 修改,适合单 GPU 小规模样本训练

    • 模型训练

      CUDA_VISIBLE_DEVICES=0 python sentence_transformers_train.py \
        --do_train \
        --model_name_or_path='./model/bert-base-chinese' \
        --trainset_path='./lcqmc/LCQMC_train.csv' \
        --devset_path='./lcqmc/LCQMC_dev.csv' \
        --testset_path='./lcqmc/LCQMC_test.csv' \
        --train_batch_size=128 \
        --eval_batch_size=128 \
        --model_save_path
    • 主要参数说明

      • 模型预测时则使用 --do_eval
      • 数据集为 sentence1, sentence2, label 形式的 csv 文件
      • 16G 显存设置 batch size 为 128
  • 使用 Transformers 自定义数据集和 BertForSiameseNetwork 模型并使用 Trainer 训练,适合多 GPU 大规模样本训练

    CUDA_VISIBLE_DEVICES=0,1,2,3 python transformers_trainer.py \
    	--do_train=True \
    	--do_eval=True \
    	--do_predict=False \
    	--model_name_or_path='./model/bert-base-chinese' \
    	--trainset_path='./samples/merge.csv' \
    	--devset_path='./samples/test.csv' \
    	--testset_path='./samples/test.csv' \
    	--output_dir='./output/transformers-merge-bert'
  • 使用 TransformersBertForSequenceClassification 进行句对分类对比实验

    CUDA_VISIBLE_DEVICES=0 python bert_for_seq_classification.py \
    	--do_train=True \
    	--do_eval=True \
    	--do_predict=False \
      --trainset_path='./lcqmc/LCQMC_train.csv' \
      --devset_path='./lcqmc/LCQMC_dev.csv' \
      --testset_path='./lcqmc/LCQMC_test.csv' \
      --output_dir='./output/transformers-bert-for-seq-classify'

模型蒸馏

使用 TextBrewer 以及前文自定义的 SiameseNetwork 进行模型蒸馏

CUDA_VISIBLE_DEVICES=0 python model_distillation.py \
	--teacher_model='./output/transformers-merge-bert' \
	--student_config='./distills/bert_config_L3.json' \
	--bert_model='./model/bert-base-chinese' \
	--train_file='./samples/train.csv' \
	--test_file='./samples/test.csv' \
	--output_dir='./distills/outputs/bert-L3'

主要参数说明:

  • 此处使用的 bert_config_L3.json 作为学生模型参数,更多参数 student_config 或者自定义
  • 3层应用于特定任务效果不错,但对于句向量获取,至少得蒸馏 6层
  • 学生模型可以使用 bert-base-chinese 的前几层初始化

Web服务

  • 服务启动(gunicornuvicorn 均支持多进程启动以及失败重启)

    • Flask

      gunicorn -w 1 -b 127.0.0.1:8888 faq_app_flask:app
    • FastAPI 🔥 (推荐) ​

      uvicorn faq_app_fastapi:app --reload --port=8888
  • 压力测试 Locust ,实现脚本参考 locust_test.py

结果及分析

微调结果

对于SiameseNetwork,需要在开发集上确定最佳阈值,然后测试集上使用该阈值进行句对相似度结果评价

句对测试集评价结果,此处为 LCQMC 的实验结果

model acc(dev/test) f1(dev/test)
BertForSeqClassify
🚂 lcqmc
0.8832/0.8600 0.8848/0.8706
SiameseNetwork
🚂 lcqmc
0.8818/0.8705 0.8810/0.8701

基于表示和基于交互的模型效果差别并不大

语义召回结果

此处为 FAQ 数据集的召回结果评估,将训练集 post 作为 corpus,测试集 post 作为 query 进行相似度计算

model hit@1(chitchat-faq-small) hit@1(entity-faq-large)
lucene bm25 (origin) 0.6679 -
bert-base-chinese 0.7394 0.7745
bert-base-chinese
👆 6 layers
0.7276 -
SiameseNetwork
🚂 chit-faq-small
0.8567 0.8500
SiameseNetwork
🚂 chitchat-faq-small + entity-faq-large
0.8980 0.9961
👆 6 layers 🔥 0.9128 0.8201
  • chitchat-faq-small
    • 测试集 hit@1 大约 85% 左右
    • 错误原因主要是 hflqa 数据问题
      • 数据质量问题,部分 topic 意思相同,可以合并
      • 一些不常用表达或者表达不完整的句子
      • 正常对话的召回率还是不错的
  • chitchat-faq-small + entity-faq-large
    • 2000 chitchat-faq-small 测试集,6层比12层效果好一个点,hit@1 达到 90%
    • 10000 entity-faq-large 测试集,12层 hit@1 达到 99%,6层只有 82%
    • 底层学到了较为基础的特征,在偏向闲聊的 chitchat-faq-small 上仅使用6层效果超过12层(没有蒸馏必要)
    • 高层学到了较为高级的特征,在偏向实体的 entity-faq-large 上12层效果远超于6层
    • 另外,entity-faq-large 数量规模远大于 chitchat-faq-small ,因此最后几层分类器偏向于从 entity-faq-large 学到的信息,因此在 chitchat-faq-small 小效果略有下降;同时能够避免 chitchat-faq-small 数据过拟合

Web服务压测

  • 运行命令说明

    总共 100 个模拟用户,启动时每秒递增 10 个,压力测试持续 3 分钟

    locust  -f locust_test.py  --host=http://127.0.0.1:8889/module --headless -u 100 -r 10 -t 3m
  • ⌛ 配置 4核8G CPU (6层小模型占用内存约 700MB)

    • 小服务器上 bert-as-service 服务非常不稳定(tensorflow各种报错), 效率不如简单封装的 TransformersEncoder
    • FastAPI 框架速度远胜于 Flask,的确堪称最快的 Python Web 框架
    • cache 的使用能够大大提高并发量和响应速度(最大缓存均设置为500
    • 最终推荐配置 🔥 TransformersEncoder + FastAPI + functools.lru_cache
model Web Cache Users req/s reqs fails Avg Min Max Median fails/s
lucene bm25 (origin) flask werkzeug 1000 271.75 48969 0 91 3 398 79 0.00
BertSiameseNet
​​6 layers
Transformers
flask flask-caching 1000 24.55 4424 654 28005 680 161199 11000 3.63
BertSiameseNet
6 layers
Transformers
fastapi lru_cache 1000 130.87 23566 1725 3884 6 127347 26 9.58
lucene bm25 (origin) flask werkzeug 100 27.66 4973 1 32 6 60077 10 0.01
BertSiameseNet
6 layers
bert-as-service
flask flask-caching 100 5.49 987 0 13730 357 17884 14000 0.00
BertSiameseNet
6 layers
Transformers
flask flask-caching 100 5.93 1066 0 12379 236 17062 12000 0.00
BertSiameseNet
🔥 6 layers
Transformers
fastapi lru_cache 100 22.19 3993 0 824 10 2402 880 0.00
BertSiameseNet
6 layers
transformers
fastapi None 100 18.17 1900 0 1876 138 3469 1900 0.00

使用 bert-as-service 遇到的一些问题:

测试服部署时报错 src/tcmalloc.cc:277] Attempt to free invalid pointer,通过改变import顺序来解决,在 numpy 之后 import pytorch

更多

About

基于向量召回的检索式对话系统解决方案,dense retrieval,FAQ……

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages