
[RAG/FiD/BB2] Incremental Decoding #4088

Merged
merged 4 commits into main from rag_fid_incr_decoding on Oct 13, 2021

Conversation

klshuster (Contributor) commented Oct 13, 2021

Patch description

I've implemented incremental decoding for FiD and RAG (and, by extension, BlenderBot2).

CC #4073
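
For readers unfamiliar with the technique, here is a minimal, hedged sketch of what incremental decoding buys us in a transformer decoder (illustrative only, not ParlAI's actual modules; all names below are made up): each layer caches its key/value projections from earlier generation steps, so each new step only runs attention for the newest token instead of re-encoding the whole generated prefix.

```python
import torch
import torch.nn.functional as F


class CachedSelfAttention(torch.nn.Module):
    """Single-head self-attention with an incremental key/value cache (toy example)."""

    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, dim)
        self.v_proj = torch.nn.Linear(dim, dim)

    def forward(self, x, incr_state=None):
        # During incremental decoding, x holds only the newest token: [batch, 1, dim].
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        if incr_state is not None:
            # Reuse keys/values computed at earlier steps instead of recomputing them.
            k = torch.cat([incr_state['k'], k], dim=1)
            v = torch.cat([incr_state['v'], v], dim=1)
        new_incr_state = {'k': k, 'v': v}
        scores = q @ k.transpose(1, 2) / (k.size(-1) ** 0.5)
        out = F.softmax(scores, dim=-1) @ v
        return out, new_incr_state


# Usage: decode one token at a time, threading the cache through each step.
layer = CachedSelfAttention(dim=8)
incr_state = None
for _ in range(3):
    newest_token = torch.randn(2, 1, 8)  # [batch=2, seq=1, dim=8]
    out, incr_state = layer(newest_token, incr_state)
```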

Testing steps

CI

Existing CI, both for RAG and BB2. Local runs:

$ pytest test_rag.py  -x
================================================================================ test session starts ================================================================================
platform linux -- Python 3.7.9, pytest-6.2.1, py-1.10.0, pluggy-1.0.0.dev0
rootdir: /private/home/kshuster/ParlAI, configfile: pytest.ini
plugins: hydra-core-1.0.7, requests-mock-1.8.0, regressions-2.1.1, datadir-1.3.1
collected 38 items

test_rag.py ......................................                                                                                                                            [100%]

=============================================================================== slowest 10 durations ================================================================================
72.64s call     tests/nightly/gpu/test_rag.py::TestRagTfidf::test_rag_token
69.52s call     tests/nightly/gpu/test_rag.py::TestRagZooModels::test_bart_rag_dpr_poly
53.18s call     tests/nightly/gpu/test_rag.py::TestFidZooModels::test_bart_fid_rag_dpr_poly
53.09s call     tests/nightly/gpu/test_rag.py::TestLoadDPRModel::test_load_dpr
52.57s call     tests/nightly/gpu/test_rag.py::TestRagDprPoly::test_rag_sequence
49.35s call     tests/nightly/gpu/test_rag.py::TestRagDprPoly::test_rag_turn
45.17s call     tests/nightly/gpu/test_rag.py::TestRagZooModels::test_bart_rag_sequence
44.22s call     tests/nightly/gpu/test_rag.py::TestRagZooModels::test_bart_rag_turn_do
44.19s call     tests/nightly/gpu/test_rag.py::TestRagZooModels::test_bart_rag_turn_dtt
42.21s call     tests/nightly/gpu/test_rag.py::TestFidZooModels::test_bart_fid_dpr
=================================================================== 38 passed, 16 warnings in 1057.45s (0:17:37) ====================================================================
$ pytest test_bb2.py -x
================================================================================ test session starts ================================================================================
platform linux -- Python 3.7.9, pytest-6.2.1, py-1.10.0, pluggy-1.0.0.dev0
rootdir: /private/home/kshuster/ParlAI, configfile: pytest.ini
plugins: hydra-core-1.0.7, requests-mock-1.8.0, regressions-2.1.1, datadir-1.3.1
collected 20 items

test_bb2.py ....................                                                                                                                                              [100%]

=============================================================================== slowest 10 durations ================================================================================
193.45s call     tests/nightly/gpu/test_bb2.py::TestBB2Search::test_rag
183.39s call     tests/nightly/gpu/test_bb2.py::TestBB2RagTurn::test_rag_turn
172.64s call     tests/nightly/gpu/test_bb2.py::TestBB2QGenParams::test_rag
169.55s call     tests/nightly/gpu/test_bb2.py::TestBB2AdditionalTruncation::test_rag
168.82s call     tests/nightly/gpu/test_bb2.py::TestBB2RagSequence::test_rag
146.15s call     tests/nightly/gpu/test_bb2.py::TestBB2ZooModel::test_zoo_model_3B
116.30s call     tests/nightly/gpu/test_bb2.py::TestBB2ZooModel::test_zoo_model
106.08s call     tests/nightly/gpu/test_bb2.py::TestBB2Rag::test_retrieval_all
101.89s call     tests/nightly/gpu/test_bb2.py::TestBB2GoldDocs::test_rag
100.97s call     tests/nightly/gpu/test_bb2.py::TestBB2MemoryDecoder::test_rag
==================================================================== 20 passed, 4 warnings in 2369.29s (0:39:29) ====================================================================

Empirical Evaluations

Empirical speed tests indicate the following speedups (batch size 16, generation parameters from the hallucination project zoo models). Two observations:

  1. Model performance decreases only marginally.
  2. FiD does not speed up as much as RAG Sequence and RAG Token. This is due to the expanded effective batch size for the RAG models, given their document marginalization; see the sketch after the table.
| Model | Incremental Decoding | PPL | F1 | KF1 | RF1 | Speed (wall-clock time) |
| --- | --- | --- | --- | --- | --- | --- |
| RAG Sequence | Yes | 11.91 | 20.99 | 25.79 | 16.25 | 1:06:17 |
| RAG Sequence | No | 11.91 | 21.01 | 25.85 | 16.28 | 1:23:11 |
| RAG Token | Yes | 12.40 | 22.12 | 23.06 | 17.07 | 44:03 |
| RAG Token | No | 12.40 | 22.10 | 23.03 | 17.07 | 57:25 |
| FiD-RAG | Yes | 12.37 | 22.84 | 27.50 | 18.07 | 49:56 |
| FiD-RAG | No | 12.37 | 22.90 | 27.55 | 18.14 | 50:37 |
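
To make observation 2 concrete, here is a rough sketch of the shape arithmetic (n_docs = 5 is an assumed value for illustration; the zoo models' actual retrieval settings may differ):

```python
# Illustrative only: why RAG models benefit more from incremental decoding than FiD.
batch_size, n_docs = 16, 5  # n_docs = 5 is an assumption for this example

# RAG Sequence / RAG Token decode each (context, document) pair separately and
# marginalize afterwards, so the decoder and its cache run over an expanded batch.
rag_decoder_batch = batch_size * n_docs   # 80 sequences per generation step

# FiD concatenates document encodings in the encoder output, so its decoder batch
# stays at the original size and there is less redundant work for caching to remove.
fid_decoder_batch = batch_size            # 16 sequences per generation step

print(rag_decoder_batch, fid_decoder_batch)  # 80 16
```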

    if (
        KnowledgeAccessMethod(opt['knowledge_access_method'])
        is KnowledgeAccessMethod.ALL
    ):
        self.n_docs *= 2
Contributor:

Question: is this because we need n_docs for search and n_docs for memory? Could you add a small comment on it?

Contributor Author (klshuster):

yes, good call, will add comment

Contributor:

Another question: if this was added in this PR, how was it handled so far?

Contributor Author (klshuster):

we did not require a BlenderBot2Fid-specific model interface prior to this PR; this n_docs, however, is used for incremental decoding
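
For context, a hedged sketch of what the doubling accounts for (the enum and option names come from the snippet above; the helper and surrounding logic are assumptions, not ParlAI's actual interface): when the knowledge access method is ALL, BlenderBot2 retrieves documents from both search and long-term memory, so the incremental-decoding state has to cover twice as many documents per example.

```python
from enum import Enum


class KnowledgeAccessMethod(Enum):
    # Subset of values, for illustration only.
    ALL = 'all'
    SEARCH_ONLY = 'search_only'
    MEMORY_ONLY = 'memory_only'


def effective_n_docs(opt: dict) -> int:
    """Hypothetical helper: number of docs per example seen by the decoder."""
    n_docs = opt['n_docs']
    if KnowledgeAccessMethod(opt['knowledge_access_method']) is KnowledgeAccessMethod.ALL:
        # Documents come from both the search retriever and the memory retriever,
        # so the incremental decoder state must account for twice as many docs.
        n_docs *= 2
    return n_docs


print(effective_n_docs({'n_docs': 5, 'knowledge_access_method': 'all'}))  # -> 10
```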

@mojtaba-komeili (Contributor) left a comment:

LGTM

@klshuster klshuster merged commit ac0d5d3 into main Oct 13, 2021
@klshuster klshuster deleted the rag_fid_incr_decoding branch October 13, 2021 20:21