Skip to content

Latest commit

 

History

History

RocketQA

RocketQA PyTorch re-implementation

百度开源的信息检索框架

其原代码在 GitHub 上有, 但是是在飞桨框架下实现的, 这里给出 PyTorch 实现的主要代码。 这里使用的是 DuReader Retrieval 数据集, 由于实验需要非常多的 GPU 资源, 同时测试时间也非常的漫长, 因此无法将实验完全做完, 有条件的可以尝试完成, 没有条件的看看就好 ...

GPU 需求的资源多主要体现在以下一些方面:

  • 训练需求的显存大, 这主要体现在 RocketQA V2, batch size 为 1 也不一定能训练的起来
    • 在百度的源码中, MSMARCO 数据集的 passage 长度设置是 96, DuReader Retrieval 的 passage 长度是 384, 你可以尝试将其调小 (在 DuReader Retrieval 的论文中, 作者说明了尽量让文本长度大于 256, 文本长度调的太小可能会影响性能)
    • 在百度的源码中, MSMARCO 数据集的 list 大小为 128, NQ 数据集的 list 大小为 32, 你可以尝试调小 list 的大小, 这样也可以减轻数据增强部分的计算量
    • 可以改变 query 的计算方式, 详见代码注释, 性能可能也会下降 (因为训练变简单了)
  • 测试的运算量大, 由于 DPR 架构的特点, 每一次测试都需要对800万的 passage 文本进行句向量编码, 单张3090需要16个小时才能编码完成 (含分词时间和 faiss 索引建立时间)
  • 数据增强运算量也很大, 训练如果有90万个 query, 每一个 query 只对 top-50 进行清洗, 需要进行4500万次的句向量编码
    • 百度给了 MSMARCO 和 NQ 数据集增强后的结果, 但是对于 DuReader Retrieval, 只给出了 negative passage (已经很人性化啦), 如果要跑数据增强, 一定要确保代码的正确性再跑, 或者分阶段跑 !!!

如果你有 4 张 3090, 并且有两个星期的时间, 在代码不出错的情况下, 应该可以完成整套 RocketQA 的实验

百度在 PaddleNLP 中有相关的集成, 采用 SimCSE 对训练好的 RocketQA query 编码器进行微调, 具体可以参考: 政务问答案例

相关论文地址:

文件说明:

  • DuReader Retrieval 数据的预处理和分词: code
  • RocketQA v2 模型架构和显存测试: code
  • PAIR 模型架构: code
  • RocketQA v1 dual encoder 完整的训练代码: code
  • RocketQA v1 dual encoder 生成测试文件的代码: code
  • 测试代码: code

requirements:

faiss==1.7.2
torch==1.13.0
transformers==4.24.0
pytorch-lightning==1.8.2

RocketQA v1 dual encoder 运行方式:

python examples/RocketQA/01_train_basic_dual_encoder.py
python examples/RocketQA/02_test_basic_dual_encoder.py
python examples/RocketQA/evaluation.py
Model MRR@10 recall@1 recall@50
dual-encoder 24.46 15.40 72.60

效果不好, 我认为的主要原因有:

原代码中 batch size 是 128, 我这里的 batch size 是 20, 差距有点大, 在对比学习中, 负样本的数量对模型的训练影响还是很大的

原代码中训练了 3 个 epochs, 我只训练了 2 个 epochs, 多训练一些效果肯能会好一些

改进方案:

  • 有能力的话增加 batch size, 并多训练一轮
  • 增加 eval 过程, 评测方式是计算 accuracy (multiple choice 是否选择正确), 并用其选择模型