[RAG] Propagating of n_docs as parameter to all RagModel's related functions #7891

Merged
merged 14 commits into from Oct 19, 2020

lalitpagaria
Contributor

What does this PR do?

Fixes #7874

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patrickvonplaten

Contributor

@patrickvonplaten patrickvonplaten left a comment

Besides a small suggestion for the docstrings, this PR looks great! Thanks a lot, @lalitpagaria!

@patrickvonplaten
Contributor

@lhoestq, it would be great if you could review as well.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@lalitpagaria
Contributor Author

@patrickvonplaten Thanks for the review.

While working on this PR, I found that RagTokenForGeneration computes batch_size as follows:

batch_size = context_input_ids.shape[0] // n_docs

So an issue can still occur when ((context_input_ids.shape[0] % n_docs) != 0), but I can't think of a solution to address this.
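
To make the concern concrete, here is a minimal sketch in plain Python integers rather than tensors. The helper name `infer_batch_size` and the row counts are assumptions chosen for illustration; only the floor-division expression comes from the comment above.

```python
def infer_batch_size(num_context_rows: int, n_docs: int) -> int:
    """Hypothetical helper mirroring `context_input_ids.shape[0] // n_docs`."""
    return num_context_rows // n_docs


# Assumed setup: 2 questions, 3 retrieved documents each -> 6 context rows.
rows = 2 * 3

print(infer_batch_size(rows, n_docs=3))  # 2: matches the number of questions
print(infer_batch_size(rows, n_docs=2))  # 3: no remainder, yet the wrong batch size
print(infer_batch_size(7, n_docs=2))     # 3: 7 % 2 == 1, so one row is unaccounted for
```

Floor division never fails on its own, so a mismatched n_docs only surfaces later as a shape error or as silently skipped rows.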

Member

@lhoestq lhoestq left a comment

Thanks, looks good to me :)

src/transformers/modeling_rag.py (review thread, outdated and resolved)
@lhoestq
Member

lhoestq commented Oct 19, 2020

> @patrickvonplaten Thanks for the review.
>
> While working on this PR, I found that RagTokenForGeneration computes batch_size as follows:
>
> batch_size = context_input_ids.shape[0] // n_docs
>
> So an issue can still occur when ((context_input_ids.shape[0] % n_docs) != 0), but I can't think of a solution to address this.

The first dimension of context_input_ids is always supposed to be n_docs times the number of input questions.

@lalitpagaria
Contributor Author

> > @patrickvonplaten Thanks for the review.
> > While working on this PR, I found that RagTokenForGeneration computes batch_size as follows:
> >
> > batch_size = context_input_ids.shape[0] // n_docs
> >
> > So an issue can still occur when ((context_input_ids.shape[0] % n_docs) != 0), but I can't think of a solution to address this.
>
> context_input_ids is always supposed to have a size of n_docs times the number of input questions

It would be better to state this explicitly with an assert. WDYT?
In one of my test cases I used n_docs=3 for the retriever and n_docs=2 for the generator, and it failed.

@lhoestq
Member

lhoestq commented Oct 19, 2020

> It would be better if we mention it explicitly by assert. WDYT?
> In one of my test case I used n_docs=3 for retriever and n_docs=2 for generator and it failed

Yes indeed. Also, if ((context_input_ids.shape[0] % n_docs) != 0), we should raise an error; otherwise some retrieved documents will be ignored during generation.
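
A check of that kind might look roughly like the sketch below. It works on a plain row count instead of a tensor, and the name `checked_batch_size` and the error message are hypothetical, not taken from modeling_rag.py.

```python
def checked_batch_size(num_context_rows: int, n_docs: int) -> int:
    """Hypothetical helper: validate the row count before dividing by n_docs."""
    if num_context_rows % n_docs != 0:
        # Failing loudly avoids silently ignoring the leftover retrieved documents.
        raise ValueError(
            f"context_input_ids has {num_context_rows} rows, "
            f"which is not a multiple of n_docs={n_docs}"
        )
    return num_context_rows // n_docs


print(checked_batch_size(6, n_docs=3))  # 2

try:
    checked_batch_size(7, n_docs=2)
except ValueError as err:
    print(f"rejected: {err}")
```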

@patrickvonplaten
Contributor

Yes @lalitpagaria - it would be nice if you could add an assert statement verifying that n_docs is correctly set. n_docs should be the same for both the retriever and the generator.

…s should be the same for both retriever and generator.
@lalitpagaria
Contributor Author

@patrickvonplaten @lhoestq I added asserts in two places, along with a supporting unit test; please verify. Pardon my naming of the test function, and please suggest a proper name :)

> n_docs should be the same for both retriever and generator.

This can't be checked if the generator does not know about the retriever, hence the check on ((context_input_ids.shape[0] % n_docs) != 0).

tests/test_modeling_rag.py (review thread, outdated and resolved)
@lalitpagaria
Contributor Author

lalitpagaria commented Oct 19, 2020

@patrickvonplaten and @lhoestq Thanks for the review. I liked the test coverage of this project. Initially I struggled, but later everything worked nicely. You can merge whenever you want.

@patrickvonplaten
Contributor

Slow tests pass => ready to merge

@patrickvonplaten patrickvonplaten merged commit 0193c82 into huggingface:master Oct 19, 2020
@patrickvonplaten
Contributor

Good job @lalitpagaria !

@lalitpagaria lalitpagaria deleted the propogate_n_docs_as_param branch October 19, 2020 13:41
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
…nctions (huggingface#7891)

* Propagating n_docs as parameter to all RagModel's related functions that defaults to self.config.n_docs

* Making n_docs parameter's default value to None in marginalize function

* Fixing code quality issues

* Handle the special case when generator is of T5PreTrainedModel instance type. T5PreTrainedModel do not have n_docs as parameter

* T5PreTrainedModel do not have n_docs as parameter

* Addressing review comment

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Correcting comment by addressing review comment

* Adding assert statement verifying that n_docs is correctly set. n_docs should be the same for both retriever and generator.

* Fixing flake8 reported issue

* Correcting test datasets for rag

* Using doc_scores instead of context_input_ids to check assert as in RagSequenceForGeneration context_input_ids can be null

* doc_scores second dimension have number of retrieved docs

* Changing assert comment

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
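
Two of the commit messages above describe patterns that are easy to sketch: letting n_docs default to the config value, and validating it against doc_scores, whose second dimension holds the number of retrieved documents. The following is a minimal sketch that uses nested Python lists in place of tensors. `ToyConfig`, `resolve_n_docs` and `check_doc_scores` are hypothetical names, and the strict equality check is an assumption for illustration; the merged assert may be worded or structured differently.

```python
from typing import List, Optional


class ToyConfig:
    """Hypothetical stand-in for a config object carrying a default n_docs."""

    def __init__(self, n_docs: int = 5) -> None:
        self.n_docs = n_docs


def resolve_n_docs(config: ToyConfig, n_docs: Optional[int] = None) -> int:
    """Use the caller's n_docs if given, otherwise fall back to the config value."""
    return n_docs if n_docs is not None else config.n_docs


def check_doc_scores(doc_scores: List[List[float]], n_docs: int) -> None:
    """Assert that each question has exactly n_docs scores (assumes a non-empty batch)."""
    retrieved = len(doc_scores[0])  # plays the role of doc_scores.shape[1]
    assert retrieved == n_docs, (
        f"doc_scores holds {retrieved} documents per question, but n_docs={n_docs}"
    )


config = ToyConfig(n_docs=3)
scores = [[0.9, 0.5, 0.1], [0.7, 0.2, 0.1]]  # 2 questions x 3 retrieved docs

check_doc_scores(scores, resolve_n_docs(config))  # passes: 3 == 3

try:
    check_doc_scores(scores, resolve_n_docs(config, n_docs=2))
except AssertionError as err:
    print(f"rejected: {err}")
```

In this sketch, comparing against the second dimension of doc_scores catches the case of 2 questions with 3 retrieved documents and n_docs=2, which a modulo check on the 6 context rows would let through because 6 % 2 == 0.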
fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
Development

Successfully merging this pull request may close these issues.

[Rag] extend_enc_output fails when number of retrieved documents not equal to RagConfig.n_docs
3 participants