diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..e69de29b diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..e26d1aba --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,76 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. 
+ +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..ca1b377f --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,35 @@ +# Contributing to DPR +We want to make contributing to this project as easy and transparent as +possible. + +## Our Development Process +TBD + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `master`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +## Coding Style +* 2 spaces for indentation rather than tabs +* 120 character line length +* ... + +## License +By contributing to Facebook AI Research Dense Passage Retriever toolkit, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..50f2e656 --- /dev/null +++ b/LICENSE @@ -0,0 +1,399 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. 
Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. 
Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. 
Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. 
+ +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. 
The Licensor shall not be bound by any additional or different
+     terms or conditions communicated by You unless expressly agreed.
+
+  b. Any arrangements, understandings, or agreements regarding the
+     Licensed Material not stated herein are separate from and
+     independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+  a. For the avoidance of doubt, this Public License does not, and
+     shall not be interpreted to, reduce, limit, restrict, or impose
+     conditions on any use of the Licensed Material that could lawfully
+     be made without permission under this Public License.
+
+  b. To the extent possible, if any provision of this Public License is
+     deemed unenforceable, it shall be automatically reformed to the
+     minimum extent necessary to make it enforceable. If the provision
+     cannot be reformed, it shall be severed from this Public License
+     without affecting the enforceability of the remaining terms and
+     conditions.
+
+  c. No term or condition of this Public License will be waived and no
+     failure to comply consented to unless expressly agreed to by the
+     Licensor.
+
+  d. Nothing in this Public License constitutes or may be interpreted
+     as a limitation upon, or waiver of, any privileges and immunities
+     that apply to the Licensor or You, including from the legal
+     processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 00000000..ff0da4bf
--- /dev/null
+++ b/README.md
@@ -0,0 +1,216 @@
+# Dense Passage Retriever
+------------------------------------------------------------------------------------------------------------------
+
+Dense Passage Retriever (DPR) is a set of tools and models for open-domain Q&A tasks.
+It is based on [this](https://arxiv.org/abs/2004.04906) research work and provides state-of-the-art results for multiple Q&A datasets.
+
+
+### Features
+1. Dense retriever model based on a biencoder architecture.
+2. Extractive Q&A reader & ranker joint model inspired by [this](https://arxiv.org/abs/1911.03868) paper.
+3. Related data pre- and post-processing tools.
+4. Dense retriever component for inference-time logic based on a FAISS index.
+
+
+### Installation
+
+Install from source. Python virtual or conda environments are recommended.
+
+```bash
+git clone git@github.com:fairinternal/DPR.git
+cd DPR
+pip install .
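+# Optionally, an editable install may be more convenient for development
+# (standard pip behavior, not a DPR-specific requirement):
+# pip install -e .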
+```
+
+DPR is tested on Python 3.6+ and PyTorch 1.2.0+.
+DPR relies on third-party libraries for encoder code implementations.
+It currently supports Huggingface BERT, Pytext BERT and Fairseq RoBERTa encoder models.
+Due to the generality of the tokenization process, DPR uses Huggingface tokenizers as of now, so Huggingface is the only required dependency; Pytext and Fairseq are optional.
+Install them separately if you want to use those encoders.
+
+
+
+### Resources & Data formats
+First you need to prepare data for either retriever or reader training.
+Each of the DPR components has its own input/output data formats. You can see the format descriptions below.
+DPR provides NQ & Trivia preprocessed datasets (and model checkpoints) to be downloaded from the cloud using our data/download_data.py tool. One needs to specify the resource name to be downloaded. Run 'python data/download_data.py' to see all options.
+
+```bash
+python data/download_data.py --resource {key from download_data.py's RESOURCES_MAP} [optional --output_dir {your location}]
+```
+The resource name matching is prefix based, so if you need to download all data resources, just use --resource data.
+
+### Retriever input data format
+The data format of the Retriever training data is JSON.
+It contains pools of two types of negative passages per question, as well as positive passages and some additional information.
+
+```
+[
+  {
+    "question": "....",
+    "answers": ["...", "...", ... ],
+    "positive_ctxs": [
+      {
+        "title": "...",
+        "text": "...."
+      }
+    ],
+    "negative_ctxs": [...],
+    "hard_negative_ctxs": [...]
+  },
+  ...
+]
+```
+
+The structure of the elements in negative_ctxs & hard_negative_ctxs is exactly the same as in positive_ctxs.
+The preprocessed data available for downloading also contains some extra attributes which may be useful for model modifications (like BM25 scores per passage), but they are not currently in use by DPR.
+
+You can download the prepared NQ dataset used in the paper by using the 'data.retriever.nq' key prefix. Only dev & train subsets are available in this format.
+We also provide question & answer only CSV data files for all train/dev/test splits. Those are used for model evaluation since our NQ preprocessing step loses a part of the original sample set.
+Use the 'data.retriever.qas.*' resource keys to get the respective sets for evaluation.
+
+```bash
+python data/download_data.py --resource data.retriever [optional --output_dir {your location}]
+```
+
+
+### Retriever training
+Retriever training quality depends on its effective batch size. The setup reported in the paper used 8 x 32GB GPUs.
+In order to start training on one machine:
+```bash
+python train_dense_encoder.py --encoder_model_type {hf_bert | pytext_bert | fairseq_roberta} --pretrained_model_cfg {bert-base-uncased | roberta-base} --train_file {train files glob expression} --dev_file {dev files glob expression} --output_dir {dir to save checkpoints}
+```
+
+Notes:
+- If you use pytext_bert or fairseq_roberta, you need to download pre-trained weights and specify the --pretrained_file parameter. Specify the dir location of the downloaded files for the 'pretrained.fairseq.roberta-base' resource prefix for the roberta model, or the file path for pytext bert (resource name 'pretrained.pytext.bert-base.model').
+- Validation and checkpoint saving happen according to the --eval_per_epoch parameter value.
+- There is no stop condition besides the specified number of epochs to train.
+- Every evaluation saves a model checkpoint.
+- The best checkpoint is logged in the train process output.
+- The regular NLL classification loss validation for biencoder training can be replaced with average rank evaluation. It aggregates passage and question vectors from the input data passage pools, calculates a large similarity matrix for those representations, and then averages the rank of the gold passage for each question. We found this metric to correlate better with the final retrieval performance than the NLL classification loss. Note, however, that this average rank validation works differently in DistributedDataParallel vs DataParallel PyTorch modes. See the val_av_rank_* set of parameters to enable this mode and modify its settings.
+
+See the section 'Best hyperparameter settings' below for an e2e example of our best setups.
+
+### Generating representations for a large documents set
+
+Generating representation vectors for the static documents dataset is a highly parallelizable process which can take up to a few days if computed on a single GPU. You might want to use multiple available GPU servers by running the script on each of them independently and specifying their own shards.
+
+```bash
+python generate_dense_embeddings.py --model_file {path to biencoder checkpoint} --ctx_file {path to psgs_w100.tsv file} --shard_id {shard_num, 0-based} --num_shards {total number of shards} --out_file ${out files location + name PREFIX}
+```
+Note: you can use a much larger batch size here compared to training mode. For example, setting --batch_size 128 for a 2-GPU (16GB) server should work fine.
+
+### Retriever validation against the entire set of documents
+
+```bash
+python dense_retriever.py --model_file ${path to biencoder checkpoint} --ctx_file {path to all documents .tsv file} --qa_file {path to test|dev .csv file} --encoded_ctx_file "{encoded document files glob expression}" --out_file {path to output json file with results}
+```
+
+The tool writes retrieved results for subsequent reader model training into the specified out_file.
+It is a JSON file with the following format:
+
+```
+[
+    {
+        "question": "...",
+        "answers": ["...", "...", ... ],
+        "ctxs": [
+            {
+                "id": "...", # passage id from database tsv file
+                "title": "",
+                "text": "....",
+                "score": "...",  # retriever score
+                "has_answer": true|false
+            },
+            ...
+        ]
+    },
+    ...
+]
+```
+Results are sorted by their similarity score, from most relevant to least relevant.
+
+By default, dense_retriever uses an exhaustive search process, but you can opt in to use an HNSW FAISS index with the --hnsw_index flag.
+Note that this index may be of limited use from the research point of view, since its fast retrieval comes at the cost of much longer indexing time and higher RAM usage.
+The similarity score provided is the dot product in the (default) case of exhaustive search, and the L2 distance in a modified representation space in the case of the HNSW index.
+
+
+### Optional reader model input data pre-processing
+Since the reader model uses a specific combination of positive and negative passages for each question and also needs to know the answer span location in BPE-tokenized form, it is recommended to preprocess and serialize the output from the retriever model before starting reader training. This saves hours at train time.
+If you don't run this preprocessing, the reader training pipeline checks whether the input file(s) extension is .pkl and, if not, preprocesses and caches the results automatically in the same folder as the original files.
+
+```bash
+python preprocess_reader_data.py --retriever_results {path to a file with results from dense_retriever.py} --gold_passages {path to gold passages info} --do_lower_case --pretrained_model_cfg {pretrained_cfg} --encoder_model_type {hf_bert | pytext_bert | fairseq_roberta} --out_file {path for output files} --is_train_set
+```
+
+
+
+### Reader model training
+```bash
+python train_reader.py --encoder_model_type {hf_bert | pytext_bert | fairseq_roberta} --pretrained_model_cfg {bert-base-uncased | roberta-base} --train_file "{glob expression for train files from the retriever/preprocessing steps above}" --dev_file "{glob expression for dev files}" --output_dir {path to output dir}
+```
+
+Notes:
+- If you use pytext_bert or fairseq_roberta, you need to download pre-trained weights and specify the --pretrained_file parameter. Specify the dir location of the downloaded files for the 'pretrained.fairseq.roberta-base' resource prefix for the roberta model, or the file path for pytext bert (resource name 'pretrained.pytext.bert-base.model').
+- The reader training pipeline does model validation every --eval_step batches.
+- Like the biencoder, it saves model checkpoints on every validation.
+- Like the biencoder, there is no stop condition besides the specified number of epochs to train.
+- Like the biencoder, there is no best checkpoint selection logic, so one needs to select a checkpoint based on dev set validation performance, which is logged in the train process output.
+- Our current code only calculates the Exact Match metric.
+
+### Distributed training
+Use PyTorch's distributed training launcher tool:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node={WORLD_SIZE} {non-distributed script name & parameters}
+```
+Note:
+- All batch size related parameters are specified per GPU in distributed mode (DistributedDataParallel) and for all available GPUs in DataParallel (single node, multi-GPU) mode.
+
+### Best hyperparameter settings
+
+An e2e example with the best settings for the NQ dataset.
+
+#### 1. Download all retriever training and validation data:
+
+```bash
+python data/download_data.py --resource data.wikipedia_split.psgs_w100
+python data/download_data.py --resource data.retriever.nq
+python data/download_data.py --resource data.retriever.qas.nq
+```
+
+#### 2. Biencoder (Retriever) training in single set mode.
+
+We used distributed training mode on a single server with 8 x 32GB GPUs:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node=8 train_dense_encoder.py --max_grad_norm 2.0 --encoder_model_type hf_bert --pretrained_model_cfg bert-base-uncased --seed 12345 --sequence_length 256 --warmup_steps 1237 --batch_size 16 --do_lower_case --train_file "{glob expression to train files for 'data.retriever.nq' resource}" --dev_file {path to downloaded data.retriever.qas.nq-dev resource} --output_dir {your output dir} --learning_rate 2e-05 --num_train_epochs 40 --dev_batch_size 16 --val_av_rank_start_epoch 30
+```
+This takes about a day to complete the training for 40 epochs. It switches to average rank validation at epoch 30, and the value should be around 25 at the end.
+The best checkpoint for the biencoder is usually the last one, but it should not be much different if you take any checkpoint after epoch ~25.
+
+#### 3. Generate embeddings for Wikipedia.
+Just use the instructions from "Generating representations for a large documents set" above. It takes about 40 minutes to produce 21 million passage representation vectors on 50 2-GPU servers.
+
+#### 4. Evaluate retrieval accuracy and generate top passage results for each of the train/dev/test datasets.
+
+```bash
+python dense_retriever.py --model_file {path to checkpoint file from step 2} --ctx_file {path to psgs_w100.tsv file} --qa_file {path to test/dev qas file} --encoded_ctx_file "{glob expression for generated files from step 3}" --out_file {path for output json files} --n-docs 100 --validation_workers 32 --batch_size 64
+```
+
+Adjust batch_size based on the available number of GPUs; 64 should work for a 2-GPU server.
+
+#### 5. Reader training
+We trained the reader model for large datasets using a single server with 8 x 32GB GPUs.
+
+```bash
+python train_reader.py --seed 42 --learning_rate 1e-5 --eval_step 2000 --do_lower_case --eval_top_docs 50 --encoder_model_type hf_bert --pretrained_model_cfg bert-base-uncased --train_file "{glob expression for train output files from step 4}" --dev_file {glob expression for dev output file from step 4} --warmup_steps 0 --sequence_length 350 --batch_size 16 --passages_per_question 24 --num_train_epochs 100000 --dev_batch_size 72 --passages_per_question_predict 50 --output_dir {your save dir path}
+```
+
+We found that using the learning rate above works best with a static schedule, so one needs to stop training manually based on evaluation performance dynamics.
+Our best results were achieved after 16-18 training epochs, or ~60k model updates.
+
+We provide all input and intermediate results for the e2e pipeline for the NQ dataset, and most of the similar resources for Trivia.
+
+### Misc.
+- TREC validation requires regexp-based matching. We support only retriever validation in regexp mode. See the --match parameter options.
+- WebQ validation requires entity normalization, which is not included as of now.
+
+#### License
+DPR is CC-BY-NC 4.0 licensed as of now.
diff --git a/data/download_data.py b/data/download_data.py
new file mode 100644
index 00000000..ba708296
--- /dev/null
+++ b/data/download_data.py
@@ -0,0 +1,346 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+ +""" + Command line tool to download various preprocessed data sources & checkpoints for DPR +""" + +import gzip +import os +import pathlib + +import argparse +import wget + +NQ_LICENSE_FILES = [ + 'https://dl.fbaipublicfiles.com/dpr/nq_license/LICENSE', + 'https://dl.fbaipublicfiles.com/dpr/nq_license/README', +] + +RESOURCES_MAP = { + 'data.wikipedia_split.psgs_w100': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz', + 'original_ext': '.tsv', + 'compressed': True, + 'desc': 'Entire wikipedia passages set obtain by splitting all pages into 100-word segments (no overlap)' + }, + 'data.retriever.nq-dev': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz', + 'original_ext': '.json', + 'compressed': True, + 'desc': 'NQ dev subset with passages pools for the Retriever train time validation', + 'license_files': NQ_LICENSE_FILES, + }, + + 'data.retriever.nq-train': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz', + 'original_ext': '.json', + 'compressed': True, + 'desc': 'NQ train subset with passages pools for the Retriever training', + 'license_files': NQ_LICENSE_FILES, + }, + + 'data.retriever.trivia-dev': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-trivia-dev.json.gz', + 'original_ext': '.json', + 'compressed': True, + 'desc': 'TriviaQA dev subset with passages pools for the Retriever train time validation' + }, + + 'data.retriever.trivia-train': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-trivia-train.json.gz', + 'original_ext': '.json', + 'compressed': True, + 'desc': 'TriviaQA train subset with passages pools for the Retriever training' + }, + + 'data.retriever.qas.nq-dev': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-dev.qa.csv', + 'original_ext': '.csv', + 'compressed': False, + 'desc': 'NQ dev subset for Retriever validation and IR results generation', + 'license_files': NQ_LICENSE_FILES, + }, + + 'data.retriever.qas.nq-test': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-test.qa.csv', + 'original_ext': '.csv', + 'compressed': False, + 'desc': 'NQ test subset for Retriever validation and IR results generation', + 'license_files': NQ_LICENSE_FILES, + }, + + 'data.retriever.qas.nq-train': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-train.qa.csv', + 'original_ext': '.csv', + 'compressed': False, + 'desc': 'NQ train subset for Retriever validation and IR results generation', + 'license_files': NQ_LICENSE_FILES, + }, + + # + 'data.retriever.qas.trivia-dev': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-dev.qa.csv.gz', + 'original_ext': '.csv', + 'compressed': True, + 'desc': 'Trivia dev subset for Retriever validation and IR results generation' + }, + + 'data.retriever.qas.trivia-test': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-test.qa.csv.gz', + 'original_ext': '.csv', + 'compressed': True, + 'desc': 'Trivia test subset for Retriever validation and IR results generation' + }, + + 'data.retriever.qas.trivia-train': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever/trivia-train.qa.csv.gz', + 'original_ext': '.csv', + 'compressed': True, + 'desc': 'Trivia train subset for Retriever validation and IR results generation' + }, + + 'data.gold_passages_info.nq_train': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-train_gold_info.json.gz', + 
'original_ext': '.json', + 'compressed': True, + 'desc': 'Original NQ (our train subset) gold positive passages and alternative question tokenization', + 'license_files': NQ_LICENSE_FILES, + }, + + 'data.gold_passages_info.nq_dev': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-dev_gold_info.json.gz', + 'original_ext': '.json', + 'compressed': True, + 'desc': 'Original NQ (our dev subset) gold positive passages and alternative question tokenization', + 'license_files': NQ_LICENSE_FILES, + }, + + 'data.gold_passages_info.nq_test': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-test_gold_info.json.gz', + 'original_ext': '.json', + 'compressed': True, + 'desc': 'Original NQ (our test, original dev subset) gold positive passages and alternative question ' + 'tokenization', + 'license_files': NQ_LICENSE_FILES, + }, + + 'pretrained.fairseq.roberta-base.dict': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/pretrained/fairseq/roberta/dict.txt', + 'original_ext': '.txt', + 'compressed': False, + 'desc': 'Dictionary for pretrained fairseq roberta model' + }, + + 'pretrained.fairseq.roberta-base.model': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/pretrained/fairseq/roberta/model.pt', + 'original_ext': '.pt', + 'compressed': False, + 'desc': 'Weights for pretrained fairseq roberta base model' + }, + + 'pretrained.pytext.bert-base.model': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/pretrained/pytext/bert/bert-base-uncased.pt', + 'original_ext': '.pt', + 'compressed': False, + 'desc': 'Weights for pretrained pytext bert base model' + }, + + 'data.retriever_results.nq.single.test': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-test.json.gz', + 'original_ext': '.json', + 'compressed': True, + 'desc': 'Retrieval results of NQ test dataset for the encoder trained on NQ', + 'license_files': NQ_LICENSE_FILES, + }, + 'data.retriever_results.nq.single.dev': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-dev.json.gz', + 'original_ext': '.json', + 'compressed': True, + 'desc': 'Retrieval results of NQ dev dataset for the encoder trained on NQ', + 'license_files': NQ_LICENSE_FILES, + }, + 'data.retriever_results.nq.single.train': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-train.json.gz', + 'original_ext': '.json', + 'compressed': True, + 'desc': 'Retrieval results of NQ train dataset for the encoder trained on NQ', + 'license_files': NQ_LICENSE_FILES, + }, + + 'checkpoint.retriever.single.nq.bert-base-encoder': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/checkpoint/retriever/single/nq/hf_bert_base.cp', + 'original_ext': '.cp', + 'compressed': False, + 'desc': 'Biencoder weights trained on NQ data and HF bert-base-uncased model' + }, + + 'checkpoint.retriever.multiset.bert-base-encoder': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/checkpoint/retriver/multiset/hf_bert_base.cp', + 'original_ext': '.cp', + 'compressed': False, + 'desc': 'Biencoder weights trained on multi set data and HF bert-base-uncased model' + }, + + 'data.reader.nq.single.train': { + 's3_url': ['https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/train.{}.pkl'.format(i) for i in range(8)], + 'original_ext': '.pkl', + 'compressed': False, + 'desc': 'Reader model NQ train dataset input data preprocessed from retriever results (also trained on NQ)', + 'license_files': NQ_LICENSE_FILES, + }, + + 'data.reader.nq.single.dev': { + 's3_url': 
'https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/dev.0.pkl', + 'original_ext': '.pkl', + 'compressed': False, + 'desc': 'Reader model NQ dev dataset input data preprocessed from retriever results (also trained on NQ)', + 'license_files': NQ_LICENSE_FILES, + }, + + 'data.reader.nq.single.test': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/reader/nq/single/test.0.pkl', + 'original_ext': '.pkl', + 'compressed': False, + 'desc': 'Reader model NQ test dataset input data preprocessed from retriever results (also trained on NQ)', + 'license_files': NQ_LICENSE_FILES, + }, + + 'data.reader.trivia.multi-hybrid.train': { + 's3_url': ['https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid/train.{}.pkl'.format(i) for i in + range(8)], + 'original_ext': '.pkl', + 'compressed': False, + 'desc': 'Reader model Trivia train dataset input data preprocessed from hybrid retriever results ' + '(where dense part is trained on multiset)' + }, + + 'data.reader.trivia.multi-hybrid.dev': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid/dev.0.pkl', + 'original_ext': '.pkl', + 'compressed': False, + 'desc': 'Reader model Trivia dev dataset input data preprocessed from hybrid retriever results ' + '(where dense part is trained on multiset)' + }, + + 'data.reader.trivia.multi-hybrid.test': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/data/reader/trivia/multi-hybrid//test.0.pkl', + 'original_ext': '.pkl', + 'compressed': False, + 'desc': 'Reader model Trivia test dataset input data preprocessed from hybrid retriever results ' + '(where dense part is trained on multiset)' + }, + + 'checkpoint.reader.nq-single.hf-bert-base': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-single/hf_bert_base.cp', + 'original_ext': '.cp', + 'compressed': False, + 'desc': 'Reader weights trained on NQ-single retriever results and HF bert-base-uncased model' + }, + + 'checkpoint.reader.nq-trivia-hybrid.hf-bert-base': { + 's3_url': 'https://dl.fbaipublicfiles.com/dpr/checkpoint/reader/nq-trivia-hybrid/hf_bert_base.cp', + 'original_ext': '.cp', + 'compressed': False, + 'desc': 'Reader weights trained on Trivia multi hybrid retriever results and HF bert-base-uncased model' + }, +} + + +def unpack(gzip_file: str, out_file: str): + print('Uncompressing ', gzip_file) + input = gzip.GzipFile(gzip_file, 'rb') + s = input.read() + input.close() + output = open(out_file, 'wb') + output.write(s) + output.close() + print('Saved to ', out_file) + + +def download_resource(s3_url: str, original_ext: str, compressed: bool, resource_key: str, out_dir: str) -> str: + print('Loading from ', s3_url) + + # create local dir + path_names = resource_key.split('.') + + root_dir = out_dir if out_dir else './' + save_root = os.path.join(root_dir, *path_names[:-1]) # last segment is for file name + + pathlib.Path(save_root).mkdir(parents=True, exist_ok=True) + + local_file = os.path.join(save_root, path_names[-1] + ('.tmp' if compressed else original_ext)) + + wget.download(s3_url, out=local_file) + + print('Saved to ', local_file) + + if compressed: + uncompressed_file = os.path.join(save_root, path_names[-1] + original_ext) + unpack(local_file, uncompressed_file) + os.remove(local_file) + return save_root + + +def download_file(s3_url: str, out_dir: str, file_name: str): + print('Loading from ', s3_url) + local_file = os.path.join(out_dir, file_name) + wget.download(s3_url, out=local_file) + print('Saved to ', local_file) + + +def download(resource_key: str, out_dir: str = 
None): + if resource_key not in RESOURCES_MAP: + # match by prefix + resources = [k for k in RESOURCES_MAP.keys() if k.startswith(resource_key)] + if resources: + for key in resources: + download(key, out_dir) + else: + print('no resources found for specified key') + return + download_info = RESOURCES_MAP[resource_key] + + s3_url = download_info['s3_url'] + + save_root_dir = None + if isinstance(s3_url, list): + for i, url in enumerate(s3_url): + save_root_dir = download_resource(url, download_info['original_ext'], download_info['compressed'], + '{}_{}'.format(resource_key, i), out_dir) + else: + save_root_dir = download_resource(s3_url, download_info['original_ext'], download_info['compressed'], + resource_key, out_dir) + + license_files = download_info.get('license_files', None) + if not license_files: + return + + download_file(license_files[0], save_root_dir, 'LICENSE') + download_file(license_files[1], save_root_dir, 'README') + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument("--output_dir", default="./", type=str, + help="The output directory to download file") + + parser.add_argument("--resource", type=str, + help="Resource name. See RESOURCES_MAP for all possible values") + + args = parser.parse_args() + if args.resource: + download(args.resource, args.output_dir) + else: + print('Please specify resource value. Possible options are:') + for k, v in RESOURCES_MAP.items(): + print('Resource key={} description: {}'.format(k, v['desc'])) + + +if __name__ == '__main__': + main() diff --git a/dense_retriever.py b/dense_retriever.py new file mode 100644 index 00000000..17eacbc2 --- /dev/null +++ b/dense_retriever.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" + Command line tool to get dense results and validate them +""" + +import argparse +import csv +import glob +import json +import logging +import pickle +import time +from typing import List, Tuple, Dict, Iterator + +import numpy as np +import torch +from torch import Tensor as T +from torch import nn + +from dpr.data.qa_validation import calculate_matches +from dpr.models import init_biencoder_components +from dpr.options import add_encoder_params, setup_args_gpu, print_args, set_encoder_params_from_state, \ + add_tokenizer_params, add_cuda_params +from dpr.utils.data_utils import Tensorizer +from dpr.utils.model_utils import setup_for_distributed_mode, get_model_obj, load_states_from_checkpoint +from dpr.indexer.faiss_indexers import DenseIndexer, DenseHNSWFlatIndexer, DenseFlatIndexer + +logger = logging.getLogger() +logger.setLevel(logging.INFO) +if (logger.hasHandlers()): + logger.handlers.clear() +console = logging.StreamHandler() +logger.addHandler(console) + + +class DenseRetriever(object): + """ + Does passage retrieving over the provided index and question encoder + """ + def __init__(self, question_encoder: nn.Module, batch_size: int, tensorizer: Tensorizer, index: DenseIndexer): + self.question_encoder = question_encoder + self.batch_size = batch_size + self.tensorizer = tensorizer + self.index = index + + def generate_question_vectors(self, questions: List[str]) -> T: + n = len(questions) + bsz = self.batch_size + query_vectors = [] + + self.question_encoder.eval() + + with torch.no_grad(): + for j, batch_start in enumerate(range(0, n, bsz)): + + batch_token_tensors = [self.tensorizer.text_to_tensor(q) for q in + questions[batch_start:batch_start + bsz]] + + q_ids_batch = torch.stack(batch_token_tensors, dim=0).cuda() + q_seg_batch = torch.zeros_like(q_ids_batch).cuda() + q_attn_mask = self.tensorizer.get_attn_mask(q_ids_batch) + _, out, _ = self.question_encoder(q_ids_batch, q_seg_batch, q_attn_mask) + + query_vectors.extend(out.cpu().split(1, dim=0)) + + if len(query_vectors) % 100 == 0: + logger.info('Encoded queries %d', len(query_vectors)) + + query_tensor = torch.cat(query_vectors, dim=0) + + logger.info('Total encoded queries tensor %s', query_tensor.size()) + + assert query_tensor.size(0) == len(questions) + return query_tensor + + def index_encoded_data(self, vector_files: List[str], buffer_size: int = 50000): + """ + Indexes encoded passages takes form a list of files + :param vector_files: file names to get passages vectors from + :param buffer_size: size of a buffer (amount of passages) to send for the indexing at once + :return: + """ + buffer = [] + for i, item in enumerate(iterate_encoded_files(vector_files)): + db_id, doc_vector = item + buffer.append((db_id, doc_vector)) + if 0 < buffer_size == len(buffer): + self.index.index_data(buffer) + buffer = [] + self.index.index_data(buffer) + logger.info('Data indexing completed.') + + def get_top_docs(self, query_vectors: np.array, top_docs: int = 100) -> List[Tuple[List[object], List[float]]]: + """ + Does the retrieval of the best matching passages given the query vectors batch + :param query_vectors: + :param top_docs: + :return: + """ + time0 = time.time() + results = self.index.search_knn(query_vectors, top_docs) + logger.info('index search time: %f sec.', time.time() - time0) + return results + + +def parse_qa_csv_file(location) -> Iterator[Tuple[str, List[str]]]: + with open(location) as ifile: + reader = csv.reader(ifile, delimiter='\t') + for row in reader: + question = row[0] + answers = eval(row[1]) 
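+            # the answers column holds a Python-style list literal, e.g. ['answer1', 'answer2']
+            # (see the --qa_file help string below), hence the eval()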
+ yield question, answers + + +def validate(passages: Dict[object, Tuple[str, str]], answers: List[List[str]], + result_ctx_ids: List[Tuple[List[object], List[float]]], + workers_num: int, match_type: str) -> List[List[bool]]: + match_stats = calculate_matches(passages, answers, result_ctx_ids, workers_num, match_type) + top_k_hits = match_stats.top_k_hits + + logger.info('Validation results: top k documents hits %s', top_k_hits) + top_k_hits = [v / len(result_ctx_ids) for v in top_k_hits] + logger.info('Validation results: top k documents hits accuracy %s', top_k_hits) + return match_stats.questions_doc_hits + + +def load_passages(ctx_file: str) -> Dict[object, Tuple[str, str]]: + docs = {} + logger.info('Reading data from: %s', ctx_file) + with open(ctx_file) as tsvfile: + reader = csv.reader(tsvfile, delimiter='\t', ) + # file format: doc_id, doc_text, title + for row in reader: + if row[0] != 'id': + docs[row[0]] = (row[1], row[2]) + return docs + + +def save_results(passages: Dict[object, Tuple[str, str]], questions: List[str], answers: List[List[str]], + top_passages_and_scores: List[Tuple[List[object], List[float]]], per_question_hits: List[List[bool]], + out_file: str + ): + # join passages text with the result ids, their questions and assigning has|no answer labels + merged_data = [] + assert len(per_question_hits) == len(questions) == len(answers) + for i, q in enumerate(questions): + q_answers = answers[i] + results_and_scores = top_passages_and_scores[i] + hits = per_question_hits[i] + docs = [passages[doc_id] for doc_id in results_and_scores[0]] + scores = [str(score) for score in results_and_scores[1]] + ctxs_num = len(hits) + + merged_data.append({ + 'question': q, + 'answers': q_answers, + 'ctxs': [ + { + 'id': results_and_scores[0][c], + 'title': docs[c][1], + 'text': docs[c][0], + 'score': scores[c], + 'has_answer': hits[c], + } for c in range(ctxs_num) + ] + }) + + with open(out_file, "w") as writer: + writer.write(json.dumps(merged_data, indent=4) + "\n") + logger.info('Saved results * scores to %s', out_file) + + +def iterate_encoded_files(vector_files: list) -> Iterator[Tuple[object, np.array]]: + for i, file in enumerate(vector_files): + logger.info('Reading file %s', file) + with open(file, "rb") as reader: + doc_vectors = pickle.load(reader) + for doc in doc_vectors: + db_id, doc_vector = doc + yield db_id, doc_vector + + +def main(args): + saved_state = load_states_from_checkpoint(args.model_file) + set_encoder_params_from_state(saved_state.encoder_params, args) + + tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True) + + encoder = encoder.question_model + + encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu, + args.local_rank, + args.fp16) + encoder.eval() + + # load weights from the model file + model_to_load = get_model_obj(encoder) + logger.info('Loading saved model state ...') + + prefix_len = len('question_model.') + question_encoder_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if + key.startswith('question_model.')} + model_to_load.load_state_dict(question_encoder_state) + vector_size = model_to_load.get_out_size() + logger.info('Encoder vector_size=%d', vector_size) + + index_buffer_sz = args.index_buffer + if args.hnsw_index: + index = DenseHNSWFlatIndexer(vector_size) + index_buffer_sz = -1 # encode all at once + else: + index = DenseFlatIndexer(vector_size) + + retriever = DenseRetriever(encoder, args.batch_size, tensorizer, index) + + # 
index all passages + ctx_files_pattern = args.encoded_ctx_file + input_paths = glob.glob(ctx_files_pattern) + logger.info('Reading all passages data from files: %s', input_paths) + retriever.index_encoded_data(input_paths, buffer_size=index_buffer_sz) + + # get questions & answers + questions = [] + question_answers = [] + + for ds_item in parse_qa_csv_file(args.qa_file): + question, answers = ds_item + questions.append(question) + question_answers.append(answers) + + questions_tensor = retriever.generate_question_vectors(questions) + + # get top k results + top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), args.n_docs) + + all_passages = load_passages(args.ctx_file) + + if len(all_passages) == 0: + raise RuntimeError('No passages data found. Please specify ctx_file param properly.') + + questions_doc_hits = validate(all_passages, question_answers, top_ids_and_scores, args.validation_workers, + args.match) + + if args.out_file: + save_results(all_passages, questions, question_answers, top_ids_and_scores, questions_doc_hits, args.out_file) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + add_encoder_params(parser) + add_tokenizer_params(parser) + add_cuda_params(parser) + + parser.add_argument('--qa_file', required=True, type=str, default=None, + help="Question and answers file of the format: question \\t ['answer1','answer2', ...]") + parser.add_argument('--ctx_file', required=True, type=str, default=None, + help="All passages file in the tsv format: id \\t passage_text \\t title") + parser.add_argument('--encoded_ctx_file', type=str, default=None, + help='Glob path to encoded passages (from generate_dense_embeddings tool)') + parser.add_argument('--out_file', type=str, default=None, + help='output .tsv file path to write results to ') + parser.add_argument('--match', type=str, default='string', choices=['regex', 'string'], + help="Answer matching logic type") + parser.add_argument('--n-docs', type=int, default=5, help="Amount of top docs to return") + parser.add_argument('--validation_workers', type=int, default=16, + help="Number of parallel processes to validate results") + parser.add_argument('--batch_size', type=int, default=32, help="Batch size for question encoder forward pass") + parser.add_argument('--index_buffer', type=int, default=50000, + help="Temporal memory data buffer size (in samples) for indexer") + parser.add_argument("--hnsw_index", action='store_true', help='If enabled, use inference time efficient HNSW index') + + args = parser.parse_args() + + assert args.model_file, 'Please specify --model_file checkpoint to init model weights' + + setup_args_gpu(args) + print_args(args) + main(args) diff --git a/dpr/__init__.py b/dpr/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dpr/data/__init__.py b/dpr/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dpr/data/qa_validation.py b/dpr/data/qa_validation.py new file mode 100644 index 00000000..f4a63318 --- /dev/null +++ b/dpr/data/qa_validation.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" + Set of utilities for Q&A results validation tasks - Retriver passage validation and Reader predicted answer validation +""" + +import collections +import logging +import string +import unicodedata +from functools import partial +from multiprocessing import Pool as ProcessPool +from typing import Tuple, List, Dict + +import regex as re + +from dpr.utils.tokenizers import SimpleTokenizer + +logger = logging.getLogger(__name__) + +QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits']) + + +def calculate_matches(all_docs: Dict[object, Tuple[str, str]], answers: List[List[str]], + closest_docs: List[Tuple[List[object], List[float]]], workers_num: int, + match_type: str) -> QAMatchStats: + """ + Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of + documents and results. It internally forks multiple sub-processes for evaluation and then merges results + :param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title) + :param answers: list of answers's list. One list per question + :param closest_docs: document ids of the top results along with their scores + :param workers_num: amount of parallel threads to process data + :param match_type: type of answer matching. Refer to has_answer code for available options + :return: matching information tuple. + top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of + valid matches across an entire dataset. + questions_doc_hits - more detailed info with answer matches for every question and every retrieved document + """ + global dpr_all_documents + dpr_all_documents = all_docs + + tok_opts = {} + tokenizer = SimpleTokenizer(**tok_opts) + + processes = ProcessPool( + processes=workers_num, + ) + + logger.info('Matching answers in top docs...') + + get_score_partial = partial(check_answer, match_type=match_type, tokenizer=tokenizer) + + questions_answers_docs = zip(answers, closest_docs) + + scores = processes.map(get_score_partial, questions_answers_docs) + + logger.info('Per question validation results len=%d', len(scores)) + + n_docs = len(closest_docs[0][0]) + top_k_hits = [0] * n_docs + for question_hits in scores: + best_hit = next((i for i, x in enumerate(question_hits) if x), None) + if best_hit is not None: + top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] + + return QAMatchStats(top_k_hits, scores) + + +def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]: + """Search through all the top docs to see if they have any of the answers.""" + answers, (doc_ids, doc_scores) = questions_answers_docs + + global dpr_all_documents + hits = [] + + for i, doc_id in enumerate(doc_ids): + doc = dpr_all_documents[doc_id] + text = doc[0] + + answer_found = False + if text is None: # cannot find the document for some reason + logger.warning("no doc in db") + hits.append(False) + continue + + if has_answer(answers, text, tokenizer, match_type): + answer_found = True + hits.append(answer_found) + return hits + + +def has_answer(answers, text, tokenizer, match_type) -> bool: + """Check if a document contains an answer string. + If `match_type` is string, token matching is done between the text and answer. + If `match_type` is regex, we search the whole text with the regex. 
+ """ + text = _normalize(text) + + if match_type == 'string': + # Answer is a list of possible strings + text = tokenizer.tokenize(text).words(uncased=True) + + for single_answer in answers: + single_answer = _normalize(single_answer) + single_answer = tokenizer.tokenize(single_answer) + single_answer = single_answer.words(uncased=True) + + for i in range(0, len(text) - len(single_answer) + 1): + if single_answer == text[i: i + len(single_answer)]: + return True + + elif match_type == 'regex': + # Answer is a regex + for single_answer in answers: + single_answer = _normalize(single_answer) + if regex_match(text, single_answer): + return True + return False + + +def regex_match(text, pattern): + """Test if a regex pattern is contained within a text.""" + try: + pattern = re.compile( + pattern, + flags=re.IGNORECASE + re.UNICODE + re.MULTILINE, + ) + except BaseException: + return False + return pattern.search(text) is not None + + +# function for the reader model answer validation +def exact_match_score(prediction, ground_truth): + return _normalize_answer(prediction) == _normalize_answer(ground_truth) + + +def _normalize_answer(s): + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def _normalize(text): + return unicodedata.normalize('NFD', text) diff --git a/dpr/data/reader_data.py b/dpr/data/reader_data.py new file mode 100644 index 00000000..2d882feb --- /dev/null +++ b/dpr/data/reader_data.py @@ -0,0 +1,422 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" + Set of utilities for the Reader model related data processing tasks +""" + +import collections +import json +import logging +import math +import multiprocessing +import pickle +from functools import partial +from typing import Tuple, List, Dict, Iterable, Optional + +import torch +from torch import Tensor as T +from tqdm import tqdm + +from dpr.utils.data_utils import Tensorizer + +logger = logging.getLogger() + + +class ReaderPassage(object): + """ + Container to collect and cache all Q&A passages related attributes before generating the reader input + """ + + def __init__(self, id=None, text: str = None, title: str = None, score=None, + has_answer: bool = None): + self.id = id + # string passage representations + self.passage_text = text + self.title = title + self.score = score + self.has_answer = has_answer + self.passage_token_ids = None + # offset of the actual passage (i.e. 
not a question or may be title) in the sequence_ids + self.passage_offset = None + self.answers_spans = None + # passage token ids + self.sequence_ids = None + + def on_serialize(self): + # store only final sequence_ids and the ctx offset + self.sequence_ids = self.sequence_ids.numpy() + self.passage_text = None + self.title = None + self.passage_token_ids = None + + def on_deserialize(self): + self.sequence_ids = torch.tensor(self.sequence_ids) + + +class ReaderSample(object): + """ + Container to collect all Q&A passages data per singe question + """ + + def __init__(self, question: str, answers: List, positive_passages: List[ReaderPassage] = [], + negative_passages: List[ReaderPassage] = [], + passages: List[ReaderPassage] = [], + ): + self.question = question + self.answers = answers + self.positive_passages = positive_passages + self.negative_passages = negative_passages + self.passages = passages + + def on_serialize(self): + for passage in self.passages + self.positive_passages + self.negative_passages: + passage.on_serialize() + + def on_deserialize(self): + for passage in self.passages + self.positive_passages + self.negative_passages: + passage.on_deserialize() + + +SpanPrediction = collections.namedtuple('SpanPrediction', + ['prediction_text', 'span_score', 'relevance_score', 'passage_index', + 'passage_token_ids']) + +# configuration for reader model passage selection +ReaderPreprocessingCfg = collections.namedtuple('ReaderPreprocessingCfg', + ['use_tailing_sep', 'skip_no_positves', 'include_gold_passage', + 'gold_page_only_positives', 'max_positives', 'max_negatives', + 'min_negatives', 'max_retriever_passages']) + +DEFAULT_PREPROCESSING_CFG_TRAIN = ReaderPreprocessingCfg(use_tailing_sep=False, skip_no_positves=True, + include_gold_passage=False, gold_page_only_positives=True, + max_positives=20, max_negatives=50, min_negatives=150, + max_retriever_passages=200) + +DEFAULT_EVAL_PASSAGES = 100 + + +def preprocess_retriever_data(samples: List[Dict], gold_info_file: Optional[str], tensorizer: Tensorizer, + cfg: ReaderPreprocessingCfg = DEFAULT_PREPROCESSING_CFG_TRAIN, + is_train_set: bool = True, + ) -> Iterable[ReaderSample]: + """ + Converts retriever results into reader training data. + :param samples: samples from the retriever's json file results + :param gold_info_file: optional path for the 'gold passages & questions' file. 
Required to get best results for NQ + :param tensorizer: Tensorizer object for text to model input tensors conversions + :param cfg: ReaderPreprocessingCfg object with positive and negative passage selection parameters + :param is_train_set: if the data should be processed as a train set + :return: iterable of ReaderSample objects which can be consumed by the reader model + """ + sep_tensor = tensorizer.get_pair_separator_ids() # separator can be a multi token + + gold_passage_map, canonical_questions = _get_gold_ctx_dict(gold_info_file) if gold_info_file else ({}, {}) + + no_positive_passages = 0 + positives_from_gold = 0 + + def create_reader_sample_ids(sample: ReaderPassage, question: str): + question_and_title = tensorizer.text_to_tensor(sample.title, title=question, add_special_tokens=True) + if sample.passage_token_ids is None: + sample.passage_token_ids = tensorizer.text_to_tensor(sample.passage_text, add_special_tokens=False) + + all_concatenated, shift = _concat_pair(question_and_title, sample.passage_token_ids, + tailing_sep=sep_tensor if cfg.use_tailing_sep else None) + + sample.sequence_ids = all_concatenated + sample.passage_offset = shift + assert shift > 1 + if sample.has_answer and is_train_set: + sample.answers_spans = [(span[0] + shift, span[1] + shift) for span in sample.answers_spans] + return sample + + for sample in samples: + question = sample['question'] + + if question in canonical_questions: + question = canonical_questions[question] + + positive_passages, negative_passages = _select_reader_passages(sample, question, + tensorizer, + gold_passage_map, + cfg.gold_page_only_positives, + cfg.max_positives, cfg.max_negatives, + cfg.min_negatives, + cfg.max_retriever_passages, + cfg.include_gold_passage, + is_train_set, + ) + # create concatenated sequence ids for each passage and adjust answer spans + positive_passages = [create_reader_sample_ids(s, question) for s in positive_passages] + negative_passages = [create_reader_sample_ids(s, question) for s in negative_passages] + + if is_train_set and len(positive_passages) == 0: + no_positive_passages += 1 + if cfg.skip_no_positves: + continue + + if next(iter(ctx for ctx in positive_passages if ctx.score == -1), None): + positives_from_gold += 1 + + if is_train_set: + yield ReaderSample(question, sample['answers'], positive_passages=positive_passages, + negative_passages=negative_passages) + else: + yield ReaderSample(question, sample['answers'], passages=negative_passages) + + logger.info('no positive passages samples: %d', no_positive_passages) + logger.info('positive passages from gold samples: %d', positives_from_gold) + + +def convert_retriever_results(is_train_set: bool, input_file: str, out_file_prefix: str, + gold_passages_file: str, + tensorizer: Tensorizer, + num_workers: int = 8) -> List[str]: + """ + Converts the file with dense retriever(or any compatible file format) results into the reader input data and + serializes them into a set of files. + Conversion splits the input data into multiple chunks and processes them in parallel. Each chunk results are stored + in a separate file with name out_file_prefix.{number}.pkl + :param is_train_set: if the data should be processed for a train set (i.e. with answer span detection) + :param input_file: path to a json file with data to convert + :param out_file_prefix: output path prefix. + :param gold_passages_file: optional path for the 'gold passages & questions' file. 
Required to get best results for NQ + :param tensorizer: Tensorizer object for text to model input tensors conversions + :param num_workers: the number of parallel processes for conversion + :return: names of files with serialized results + """ + with open(input_file, 'r', encoding="utf-8") as f: + samples = json.loads("".join(f.readlines())) + logger.info("Loaded %d questions + retrieval results from %s", len(samples), input_file) + workers = multiprocessing.Pool(num_workers) + ds_size = len(samples) + step = max(math.ceil(ds_size / num_workers), 1) + chunks = [samples[i:i + step] for i in range(0, ds_size, step)] + chunks = [(i, chunks[i]) for i in range(len(chunks))] + + logger.info("Split data into %d chunks", len(chunks)) + + processed = 0 + _parse_batch = partial(_preprocess_reader_samples_chunk, out_file_prefix=out_file_prefix, + gold_passages_file=gold_passages_file, tensorizer=tensorizer, + is_train_set=is_train_set) + serialized_files = [] + for file_name in workers.map(_parse_batch, chunks): + processed += 1 + serialized_files.append(file_name) + logger.info('Chunks processed %d', processed) + logger.info('Data saved to %s', file_name) + logger.info('Preprocessed data stored in %s', serialized_files) + return serialized_files + + +def get_best_spans(tensorizer: Tensorizer, start_logits: List, end_logits: List, ctx_ids: List, max_answer_length: int, + passage_idx: int, relevance_score: float, top_spans: int = 1) -> List[SpanPrediction]: + """ + Finds the best answer span for the extractive Q&A model + """ + scores = [] + for (i, s) in enumerate(start_logits): + for (j, e) in enumerate(end_logits[i:i + max_answer_length]): + scores.append(((i, i + j), s + e)) + + scores = sorted(scores, key=lambda x: x[1], reverse=True) + + chosen_span_intervals = [] + best_spans = [] + + for (start_index, end_index), score in scores: + assert start_index <= end_index + length = end_index - start_index + 1 + assert length <= max_answer_length + + if any([start_index <= prev_start_index <= prev_end_index <= end_index or + prev_start_index <= start_index <= end_index <= prev_end_index + for (prev_start_index, prev_end_index) in chosen_span_intervals]): + continue + + # extend bpe subtokens to full tokens + start_index, end_index = _extend_span_to_full_words(tensorizer, ctx_ids, + (start_index, end_index)) + + predicted_answer = tensorizer.to_string(ctx_ids[start_index:end_index + 1]) + best_spans.append(SpanPrediction(predicted_answer, score, relevance_score, passage_idx, ctx_ids)) + chosen_span_intervals.append((start_index, end_index)) + + if len(chosen_span_intervals) == top_spans: + break + return best_spans + + +def _select_reader_passages(sample: Dict, + question: str, + tensorizer: Tensorizer, gold_passage_map: Dict[str, ReaderPassage], + gold_page_only_positives: bool, + max_positives: int, + max1_negatives: int, + max2_negatives: int, + max_retriever_passages: int, + include_gold_passage: bool, + is_train_set: bool + ) -> Tuple[List[ReaderPassage], List[ReaderPassage]]: + answers = sample['answers'] + + ctxs = [ReaderPassage(**ctx) for ctx in sample['ctxs']][0:max_retriever_passages] + answers_token_ids = [tensorizer.text_to_tensor(a, add_special_tokens=False) for a in answers] + + if is_train_set: + positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs)) + negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs)) + else: + positive_samples = [] + negative_samples = ctxs + + positive_ctxs_from_gold_page = list( + filter(lambda ctx: 
_is_from_gold_wiki_page(gold_passage_map, ctx.title, question), + positive_samples)) if gold_page_only_positives else [] + + def find_answer_spans(ctx: ReaderPassage): + if ctx.has_answer: + if ctx.passage_token_ids is None: + ctx.passage_token_ids = tensorizer.text_to_tensor(ctx.passage_text, add_special_tokens=False) + + answer_spans = [_find_answer_positions(ctx.passage_token_ids, answers_token_ids[i]) for i in + range(len(answers))] + + # flatten spans list + answer_spans = [item for sublist in answer_spans for item in sublist] + answers_spans = list(filter(None, answer_spans)) + ctx.answers_spans = answers_spans + + if not answers_spans: + logger.warning('No answer found in passage id=%s text=%s, answers=%s, question=%s', ctx.id, + ctx.passage_text, + answers, question) + + ctx.has_answer = bool(answers_spans) + + return ctx + + # check if any of the selected ctx+ has answer spans + selected_positive_ctxs = list( + filter(lambda ctx: ctx.has_answer, [find_answer_spans(ctx) for ctx in positive_ctxs_from_gold_page])) + + if not selected_positive_ctxs: # fallback to positive ctx not from gold pages + selected_positive_ctxs = list( + filter(lambda ctx: ctx.has_answer, [find_answer_spans(ctx) for ctx in positive_samples]) + )[0:max_positives] + + # optionally include gold passage itself if it is still not in the positives list + if include_gold_passage and question in gold_passage_map: + gold_passage = gold_passage_map[question] + included_gold_passage = next(iter(ctx for ctx in selected_positive_ctxs if ctx.id == gold_passage.id), None) + if not included_gold_passage: + gold_passage = find_answer_spans(gold_passage) + if not gold_passage.has_answer: + logger.warning('No answer found in gold passage %s', gold_passage) + else: + selected_positive_ctxs.append(gold_passage) + + max_negatives = min(max(10 * len(selected_positive_ctxs), max1_negatives), + max2_negatives) if is_train_set else DEFAULT_EVAL_PASSAGES + negative_samples = negative_samples[0:max_negatives] + return selected_positive_ctxs, negative_samples + + +def _find_answer_positions(ctx_ids: T, answer: T) -> List[Tuple[int, int]]: + c_len = ctx_ids.size(0) + a_len = answer.size(0) + answer_occurences = [] + for i in range(0, c_len - a_len + 1): + if (answer == ctx_ids[i: i + a_len]).all(): + answer_occurences.append((i, i + a_len - 1)) + return answer_occurences + + +def _concat_pair(t1: T, t2: T, middle_sep: T = None, tailing_sep: T = None): + middle = ([middle_sep] if middle_sep else []) + r = [t1] + middle + [t2] + ([tailing_sep] if tailing_sep else []) + return torch.cat(r, dim=0), t1.size(0) + len(middle) + + +def _get_gold_ctx_dict(file: str) -> Tuple[Dict[str, ReaderPassage], Dict[str, str]]: + gold_passage_infos = {} # question|question_tokens -> ReaderPassage (with title and gold ctx) + + # original NQ dataset has 2 forms of same question - original, and tokenized. + # Tokenized form is not fully consisted with the original question if tokenized by some encoder tokenizers + # Specifically, this is the case for the BERT tokenizer. + # Depending of which form was used for retriever training and results generation, it may be useful to convert + # all questions to the canonical original representation. 
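+    # Both question forms are therefore covered below: gold_passage_infos receives an entry under the
+    # original question and under its tokenized variant, and original_questions maps the tokenized
+    # form back to the original one, so later lookups succeed regardless of which form the
+    # retriever results carry.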
+ original_questions = {} # question from tokens -> original question (NQ only) + + with open(file, 'r', encoding="utf-8") as f: + logger.info('Reading file %s' % file) + data = json.load(f)['data'] + + for sample in data: + question = sample['question'] + question_from_tokens = sample['question_tokens'] if 'question_tokens' in sample else question + original_questions[question_from_tokens] = question + title = sample['title'].lower() + context = sample['context'] # Note: This one is cased + rp = ReaderPassage(sample['example_id'], text=context, title=title) + if question in gold_passage_infos: + logger.info('Duplicate question %s', question) + rp_exist = gold_passage_infos[question] + logger.info('Duplicate question gold info: title new =%s | old title=%s', title, rp_exist.title) + logger.info('Duplicate question gold info: new ctx =%s ', context) + logger.info('Duplicate question gold info: old ctx =%s ', rp_exist.passage_text) + + gold_passage_infos[question] = rp + gold_passage_infos[question_from_tokens] = rp + return gold_passage_infos, original_questions + + +def _is_from_gold_wiki_page(gold_passage_map: Dict[str, ReaderPassage], passage_title: str, question: str): + gold_info = gold_passage_map.get(question, None) + if gold_info: + return passage_title.lower() == gold_info.title.lower() + return False + + +def _extend_span_to_full_words(tensorizer: Tensorizer, tokens: List[int], span: Tuple[int, int]) -> Tuple[int, int]: + start_index, end_index = span + max_len = len(tokens) + while start_index > 0 and tensorizer.is_sub_word_id(tokens[start_index]): + start_index -= 1 + + while end_index < max_len - 1 and tensorizer.is_sub_word_id(tokens[end_index + 1]): + end_index += 1 + + return start_index, end_index + + +def _preprocess_reader_samples_chunk(samples: List, out_file_prefix: str, gold_passages_file: str, + tensorizer: Tensorizer, + is_train_set: bool) -> str: + chunk_id, samples = samples + logger.info('Start batch %d', len(samples)) + iterator = preprocess_retriever_data( + samples, + gold_passages_file, + tensorizer, + is_train_set=is_train_set, + ) + + results = [] + + iterator = tqdm(iterator) + for i, r in enumerate(iterator): + r.on_serialize() + results.append(r) + + out_file = out_file_prefix + '.' + str(chunk_id) + '.pkl' + with open(out_file, mode='wb') as f: + logger.info('Serialize %d results to %s', len(results), out_file) + pickle.dump(results, f) + return out_file diff --git a/dpr/indexer/faiss_indexers.py b/dpr/indexer/faiss_indexers.py new file mode 100644 index 00000000..4178719b --- /dev/null +++ b/dpr/indexer/faiss_indexers.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
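A brief usage sketch for the flat index defined below in this file (illustrative only; it assumes the faiss package is installed, and the ids, sizes and random vectors are placeholders):

import numpy as np
from dpr.indexer.faiss_indexers import DenseFlatIndexer

vector_size = 8
index = DenseFlatIndexer(vector_size)

# index_data expects a list of (db_id, vector) pairs; vectors must be float32
data = [(str(i), np.random.rand(vector_size).astype('float32')) for i in range(100)]
index.index_data(data)

# search_knn returns one (db_ids, scores) tuple per query row
queries = np.random.rand(2, vector_size).astype('float32')
for db_ids, scores in index.search_knn(queries, top_docs=5):
    print(db_ids, scores)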
+ +""" + FAISS-based index components for dense retriver +""" + +import logging +import pickle +from typing import List, Tuple + +import faiss +import numpy as np + +logger = logging.getLogger() + + +class DenseIndexer(object): + + def __init__(self, buffer_size: int = 50000): + self.buffer_size = buffer_size + self.index_id_to_db_id = [] + self.index = None + + def index_data(self, data: List[Tuple[object, np.array]]): + raise NotImplementedError + + def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: + raise NotImplementedError + + def serialize(self, file: str): + logger.info('Serializing index to %s', file) + + index_file = file + '.index.dpr' + meta_file = file + '.index_meta.dpr' + + faiss.write_index(self.index, index_file) + with open(meta_file, mode='wb') as f: + pickle.dump(self.index_id_to_db_id, f) + + def deserialize_from(self, file: str): + logger.info('Loading index from %s', file) + + index_file = file + '.index.dpr' + meta_file = file + '.index_meta.dpr' + + self.index = faiss.read_index(index_file) + logger.info('Loaded index of type %s and size %d', type(self.index), self.index.ntotal) + + with open(meta_file, "rb") as reader: + self.index_id_to_db_id = pickle.load(reader) + assert len( + self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size' + + def _update_id_mapping(self, db_ids: List): + self.index_id_to_db_id.extend(db_ids) + + +class DenseFlatIndexer(DenseIndexer): + + def __init__(self, vector_sz: int, buffer_size: int = 50000): + super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size) + self.index = faiss.IndexFlatIP(vector_sz) + + def index_data(self, data: List[Tuple[object, np.array]]): + n = len(data) + # indexing in batches is beneficial for many faiss index types + for i in range(0, n, self.buffer_size): + db_ids = [t[0] for t in data[i:i + self.buffer_size]] + vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]] + vectors = np.concatenate(vectors, axis=0) + self._update_id_mapping(db_ids) + self.index.add(vectors) + + indexed_cnt = len(self.index_id_to_db_id) + logger.info('Total data indexed %d', indexed_cnt) + + def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: + scores, indexes = self.index.search(query_vectors, top_docs) + # convert to external ids + db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] + result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] + return result + + +class DenseHNSWFlatIndexer(DenseIndexer): + """ + Efficient index for retrieval. 
Note: default settings are for hugh accuracy but also high RAM usage + """ + + def __init__(self, vector_sz: int, buffer_size: int = 50000, store_n: int = 512 + , ef_search: int = 128, ef_construction: int = 200): + super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size) + + # IndexHNSWFlat supports L2 similarity only + # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension + index = faiss.IndexHNSWFlat(vector_sz + 1, store_n) + index.hnsw.efSearch = ef_search + index.hnsw.efConstruction = ef_construction + self.index = index + self.phi = 0 + + def index_data(self, data: List[Tuple[object, np.array]]): + n = len(data) + + # max norm is required before putting all vectors in the index to convert inner product similarity to L2 + if self.phi > 0: + raise RuntimeError('DPR HNSWF index needs to index all data at once,' + 'results will be unpredictable otherwise.') + phi = 0 + for i, item in enumerate(data): + id, doc_vector = item + norms = (doc_vector ** 2).sum() + phi = max(phi, norms) + logger.info('HNSWF DotProduct -> L2 space phi={}'.format(phi)) + self.phi = 0 + + # indexing in batches is beneficial for many faiss index types + for i in range(0, n, self.buffer_size): + db_ids = [t[0] for t in data[i:i + self.buffer_size]] + vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]] + + norms = [(doc_vector ** 2).sum() for doc_vector in vectors] + aux_dims = [np.sqrt(phi - norm) for norm in norms] + hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in + enumerate(vectors)] + hnsw_vectors = np.concatenate(hnsw_vectors, axis=0) + + self._update_id_mapping(db_ids) + self.index.add(hnsw_vectors) + logger.info('data indexed %d', len(self.index_id_to_db_id)) + + indexed_cnt = len(self.index_id_to_db_id) + logger.info('Total data indexed %d', indexed_cnt) + + def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: + + aux_dim = np.zeros(len(query_vectors), dtype='float32') + query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1))) + logger.info('query_hnsw_vectors %s', query_nhsw_vectors.shape) + scores, indexes = self.index.search(query_nhsw_vectors, top_docs) + # convert to external ids + db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] + result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] + return result + + def deserialize_from(self, file: str): + super(DenseHNSWFlatIndexer, self).deserialize_from(file) + # to trigger warning on subsequent indexing + self.phi = 1 diff --git a/dpr/models/__init__.py b/dpr/models/__init__.py new file mode 100644 index 00000000..cebca5c1 --- /dev/null +++ b/dpr/models/__init__.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
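A self-contained numeric check (illustrative, not part of the sources) of the inner-product to L2 conversion used by the HNSW indexer above: with phi = max ||d||^2 over the collection, augmenting each document with sqrt(phi - ||d||^2) and each query with 0 gives ||q' - d'||^2 = ||q||^2 + phi - 2*q.d, so ascending L2 distance orders documents exactly like descending dot product.

import numpy as np

rng = np.random.default_rng(0)
docs = rng.normal(size=(5, 4)).astype('float32')
query = rng.normal(size=(4,)).astype('float32')

norms = (docs ** 2).sum(axis=1, keepdims=True)
phi = norms.max()                                   # max squared L2 norm over the collection
docs_aug = np.hstack([docs, np.sqrt(phi - norms)])  # one extra coordinate per document
query_aug = np.hstack([query, [0.0]])               # queries get a zero in that coordinate

ip_rank = np.argsort(-(docs @ query))                            # descending inner product
l2_rank = np.argsort(((docs_aug - query_aug) ** 2).sum(axis=1))  # ascending L2 distance
assert (ip_rank == l2_rank).all()                                # identical ranking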
+ +import importlib + +""" + 'Router'-like set of methods for component initialization with lazy imports +""" + + +def init_hf_bert_biencoder(args, **kwargs): + if importlib.util.find_spec("transformers") is None: + raise RuntimeError('Please install transformers lib') + from .hf_models import get_bert_biencoder_components + return get_bert_biencoder_components(args, **kwargs) + + +def init_hf_bert_reader(args, **kwargs): + if importlib.util.find_spec("transformers") is None: + raise RuntimeError('Please install transformers lib') + from .hf_models import get_bert_reader_components + return get_bert_reader_components(args, **kwargs) + + +def init_pytext_bert_biencoder(args, **kwargs): + if importlib.util.find_spec("pytext") is None: + raise RuntimeError('Please install pytext lib') + from .pytext_models import get_bert_biencoder_components + return get_bert_biencoder_components(args, **kwargs) + + +def init_fairseq_roberta_biencoder(args, **kwargs): + if importlib.util.find_spec("fairseq") is None: + raise RuntimeError('Please install fairseq lib') + from .fairseq_models import get_roberta_biencoder_components + return get_roberta_biencoder_components(args, **kwargs) + + +def init_hf_bert_tenzorizer(args, **kwargs): + if importlib.util.find_spec("transformers") is None: + raise RuntimeError('Please install transformers lib') + from .hf_models import get_bert_tensorizer + return get_bert_tensorizer(args) + + +def init_hf_roberta_tenzorizer(args, **kwargs): + if importlib.util.find_spec("transformers") is None: + raise RuntimeError('Please install transformers lib') + from .hf_models import get_roberta_tensorizer + return get_roberta_tensorizer(args) + + +BIENCODER_INITIALIZERS = { + 'hf_bert': init_hf_bert_biencoder, + 'pytext_bert': init_pytext_bert_biencoder, + 'fairseq_roberta': init_fairseq_roberta_biencoder, +} + +READER_INITIALIZERS = { + 'hf_bert': init_hf_bert_reader, +} + +TENSORIZER_INITIALIZERS = { + 'hf_bert': init_hf_bert_tenzorizer, + 'hf_roberta': init_hf_roberta_tenzorizer, + 'pytext_bert': init_hf_bert_tenzorizer, # using HF's code as of now + 'fairseq_roberta': init_hf_roberta_tenzorizer, # using HF's code as of now +} + + +def init_comp(initializers_dict, type, args, **kwargs): + if type in initializers_dict: + return initializers_dict[type](args, **kwargs) + else: + raise RuntimeError('unsupported model type: {}'.format(type)) + + +def init_biencoder_components(encoder_type: str, args, **kwargs): + return init_comp(BIENCODER_INITIALIZERS, encoder_type, args, **kwargs) + + +def init_reader_components(encoder_type: str, args, **kwargs): + return init_comp(READER_INITIALIZERS, encoder_type, args, **kwargs) + + +def init_tenzorizer(encoder_type: str, args, **kwargs): + return init_comp(TENSORIZER_INITIALIZERS, encoder_type, args, **kwargs) diff --git a/dpr/models/biencoder.py b/dpr/models/biencoder.py new file mode 100644 index 00000000..cc4f50e3 --- /dev/null +++ b/dpr/models/biencoder.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
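A hedged sketch of how the component routers above are typically driven; the Namespace below is a stand-in whose fields mirror the options registered in dpr/options.py, and it assumes the transformers package (and access to the bert-base-uncased weights).

from argparse import Namespace
from dpr.models import init_biencoder_components

# Illustrative only: the minimal fields read by the HF BERT initializer
args = Namespace(
    encoder_model_type='hf_bert',
    pretrained_model_cfg='bert-base-uncased',
    projection_dim=0,
    sequence_length=256,
    do_lower_case=True,
    dropout=0.1,
)
tensorizer, biencoder, optimizer = init_biencoder_components(
    args.encoder_model_type, args, inference_only=True)

With inference_only=True the optimizer slot comes back as None, which matches how dense_retriever.py initializes its question encoder.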
+ +""" +BiEncoder component + loss function for 'all-in-batch' training +""" + +import collections +import logging +import random +from typing import Tuple, List + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor as T +from torch import nn + +from dpr.utils.data_utils import Tensorizer +from dpr.utils.data_utils import normalize_question + +logger = logging.getLogger(__name__) + +BiEncoderBatch = collections.namedtuple('BiENcoderInput', + ['question_ids', 'question_segments', 'context_ids', 'ctx_segments', + 'is_positive', 'hard_negatives']) + + +def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T: + """ + calculates q->ctx scores for every row in ctx_vector + :param q_vector: + :param ctx_vector: + :return: + """ + # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 + r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1)) + return r + + +def cosine_scores(q_vector: T, ctx_vectors: T): + # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 + return F.cosine_similarity(q_vector, ctx_vectors, dim=1) + + +class BiEncoder(nn.Module): + """ Bi-Encoder model component. Encapsulates query/question and context/passage encoders. + """ + + def __init__(self, question_model: nn.Module, ctx_model: nn.Module, fix_q_encoder: bool = False, + fix_ctx_encoder: bool = False): + super(BiEncoder, self).__init__() + self.question_model = question_model + self.ctx_model = ctx_model + self.fix_q_encoder = fix_q_encoder + self.fix_ctx_encoder = fix_ctx_encoder + + @staticmethod + def get_representation(sub_model: nn.Module, ids: T, segments: T, attn_mask: T, fix_encoder: bool = False) -> ( + T, T, T): + sequence_output = None + pooled_output = None + hidden_states = None + if ids is not None: + if fix_encoder: + with torch.no_grad(): + sequence_output, pooled_output, hidden_states = sub_model(ids, segments, attn_mask) + + if sub_model.training: + sequence_output.requires_grad_(requires_grad=True) + pooled_output.requires_grad_(requires_grad=True) + else: + sequence_output, pooled_output, hidden_states = sub_model(ids, segments, attn_mask) + + return sequence_output, pooled_output, hidden_states + + def forward(self, question_ids: T, question_segments: T, question_attn_mask: T, context_ids: T, ctx_segments: T, + ctx_attn_mask: T) -> Tuple[T, T]: + + _q_seq, q_pooled_out, _q_hidden = self.get_representation(self.question_model, question_ids, question_segments, + question_attn_mask, self.fix_q_encoder) + _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation(self.ctx_model, context_ids, ctx_segments, + ctx_attn_mask, self.fix_ctx_encoder) + + return q_pooled_out, ctx_pooled_out + + @classmethod + def create_biencoder_input(cls, + samples: List, + tensorizer: Tensorizer, + insert_title: bool, + num_hard_negatives: int = 0, + num_other_negatives: int = 0, + shuffle: bool = True, + shuffle_positives: bool = False, + ) -> BiEncoderBatch: + """ + Creates a batch of the biencoder training tuple. 
+ :param samples: list of data items (from json) to create the batch for + :param tensorizer: components to create model input tensors from a text sequence + :param insert_title: enables title insertion at the beginning of the context sequences + :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools) + :param num_other_negatives: amount of other negatives per question (taken from samples' pools) + :param shuffle: shuffles negative passages pools + :param shuffle_positives: shuffles positive passages pools + :return: BiEncoderBatch tuple + """ + question_tensors = [] + ctx_tensors = [] + positive_ctx_indices = [] + hard_neg_ctx_indices = [] + + for sample in samples: + # ctx+ & [ctx-] composition + # as of now, take the first(gold) ctx+ only + if shuffle and shuffle_positives: + positive_ctxs = sample['positive_ctxs'] + positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))] + else: + positive_ctx = sample['positive_ctxs'][0] + + neg_ctxs = sample['negative_ctxs'] + hard_neg_ctxs = sample['hard_negative_ctxs'] + question = normalize_question(sample['question']) + + if shuffle: + random.shuffle(neg_ctxs) + random.shuffle(hard_neg_ctxs) + + neg_ctxs = neg_ctxs[0:num_other_negatives] + hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives] + + all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs + hard_negatives_start_idx = 1 + hard_negatives_end_idx = 1 + len(hard_neg_ctxs) + + current_ctxs_len = len(ctx_tensors) + + sample_ctxs_tensors = [tensorizer.text_to_tensor(ctx['text'], title=ctx['title'] if insert_title else None) + for + ctx in all_ctxs] + + ctx_tensors.extend(sample_ctxs_tensors) + positive_ctx_indices.append(current_ctxs_len) + hard_neg_ctx_indices.append( + [i for i in + range(current_ctxs_len + hard_negatives_start_idx, current_ctxs_len + hard_negatives_end_idx)]) + + question_tensors.append(tensorizer.text_to_tensor(question)) + + ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0) + questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0) + + ctx_segments = torch.zeros_like(ctxs_tensor) + question_segments = torch.zeros_like(questions_tensor) + + return BiEncoderBatch(questions_tensor, question_segments, ctxs_tensor, ctx_segments, positive_ctx_indices, + hard_neg_ctx_indices) + + +class BiEncoderNllLoss(object): + + def calc(self, q_vectors: T, ctx_vectors: T, positive_idx_per_question: list, + hard_negatice_idx_per_question: list = None) -> Tuple[T, int]: + """ + Computes nll loss for the given lists of question and ctx vectors. + Note that although hard_negatice_idx_per_question in not currently in use, one can use it for the + loss modifications. For example - weighted NLL with different factors for hard vs regular negatives. 
+ :return: a tuple of loss value and amount of correct predictions per batch + """ + scores = self.get_scores(q_vectors, ctx_vectors) + + if len(q_vectors.size()) > 1: + q_num = q_vectors.size(0) + scores = scores.view(q_num, -1) + + softmax_scores = F.log_softmax(scores, dim=1) + + loss = F.nll_loss(softmax_scores, torch.tensor(positive_idx_per_question).to(softmax_scores.device), + reduction='mean') + + max_score, max_idxs = torch.max(softmax_scores, 1) + correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum() + return loss, correct_predictions_count + + @staticmethod + def get_scores(q_vector: T, ctx_vectors: T) -> T: + f = BiEncoderNllLoss.get_similarity_function() + return f(q_vector, ctx_vectors) + + @staticmethod + def get_similarity_function(): + return dot_product_scores diff --git a/dpr/models/fairseq_models.py b/dpr/models/fairseq_models.py new file mode 100644 index 00000000..dd8a6513 --- /dev/null +++ b/dpr/models/fairseq_models.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Encoder model wrappers based on Fairseq code +""" + +import logging +from typing import Tuple + +from fairseq.models.roberta.hub_interface import RobertaHubInterface +from fairseq.models.roberta.model import RobertaModel as FaiseqRobertaModel +from fairseq.optim.adam import FairseqAdam +from torch import Tensor as T +from torch import nn + +from dpr.models.hf_models import get_roberta_tensorizer +from .biencoder import BiEncoder + +logger = logging.getLogger(__name__) + + +def get_roberta_biencoder_components(args, inference_only: bool = False, **kwargs): + question_encoder = RobertaEncoder.from_pretrained(args.pretrained_file) + ctx_encoder = RobertaEncoder.from_pretrained(args.pretrained_file) + biencoder = BiEncoder(question_encoder, ctx_encoder) + optimizer = get_fairseq_adamw_optimizer(biencoder, args) if not inference_only else None + + tensorizer = get_roberta_tensorizer(args) + + return tensorizer, biencoder, optimizer + + +def get_fairseq_adamw_optimizer(model: nn.Module, args): + setattr(args, 'lr', [args.learning_rate]) + return FairseqAdam(args, model.parameters()).optimizer + + +class RobertaEncoder(nn.Module): + + def __init__(self, fairseq_roberta_hub: RobertaHubInterface): + super(RobertaEncoder, self).__init__() + self.fairseq_roberta = fairseq_roberta_hub + + @classmethod + def from_pretrained(cls, pretrained_dir_path: str): + model = FaiseqRobertaModel.from_pretrained(pretrained_dir_path) + return cls(model) + + def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> Tuple[T, ...]: + roberta_out = self.fairseq_roberta.extract_features(input_ids) + cls_out = roberta_out[:, 0, :] + return roberta_out, cls_out, None + + def get_out_size(self): + raise NotImplementedError diff --git a/dpr/models/hf_models.py b/dpr/models/hf_models.py new file mode 100644 index 00000000..cc1350f3 --- /dev/null +++ b/dpr/models/hf_models.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
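A small worked example (illustrative) of the in-batch NLL computed by BiEncoderNllLoss above: with two questions and four flattened context vectors whose positives sit at indices 0 and 2, the loss is the mean negative log-softmax of each question's positive column.

import torch
import torch.nn.functional as F

q_vectors = torch.randn(2, 8)    # 2 questions
ctx_vectors = torch.randn(4, 8)  # 4 contexts: each question's positive plus in-batch negatives
positive_idx_per_question = [0, 2]

scores = torch.matmul(q_vectors, ctx_vectors.t())   # same computation as dot_product_scores
softmax_scores = F.log_softmax(scores, dim=1)
loss = F.nll_loss(softmax_scores, torch.tensor(positive_idx_per_question), reduction='mean')
correct = (softmax_scores.argmax(dim=1) == torch.tensor(positive_idx_per_question)).sum()
print(loss.item(), correct.item())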
+ +""" +Encoder model wrappers based on HuggingFace code +""" + +import logging +from typing import Tuple + +import torch +from torch import Tensor as T +from torch import nn +from transformers.modeling_bert import BertConfig, BertModel +from transformers.optimization import AdamW +from transformers.tokenization_bert import BertTokenizer +from transformers.tokenization_roberta import RobertaTokenizer + +from dpr.utils.data_utils import Tensorizer +from .biencoder import BiEncoder +from .reader import Reader + +logger = logging.getLogger(__name__) + + +def get_bert_biencoder_components(args, inference_only: bool = False, **kwargs): + dropout = args.dropout if hasattr(args, 'dropout') else 0.0 + question_encoder = HFBertEncoder.init_encoder(args.pretrained_model_cfg, + projection_dim=args.projection_dim, dropout=dropout, **kwargs) + ctx_encoder = HFBertEncoder.init_encoder(args.pretrained_model_cfg, + projection_dim=args.projection_dim, dropout=dropout, **kwargs) + + fix_ctx_encoder = args.fix_ctx_encoder if hasattr(args, 'fix_ctx_encoder') else False + biencoder = BiEncoder(question_encoder, ctx_encoder, fix_ctx_encoder=fix_ctx_encoder) + + optimizer = get_optimizer(biencoder, + learning_rate=args.learning_rate, + adam_eps=args.adam_eps, weight_decay=args.weight_decay, + ) if not inference_only else None + + tensorizer = get_bert_tensorizer(args) + + return tensorizer, biencoder, optimizer + + +def get_bert_reader_components(args, inference_only: bool = False, **kwargs): + dropout = args.dropout if hasattr(args, 'dropout') else 0.0 + encoder = HFBertEncoder.init_encoder(args.pretrained_model_cfg, + projection_dim=args.projection_dim, dropout=dropout) + + hidden_size = encoder.config.hidden_size + reader = Reader(encoder, hidden_size) + + optimizer = get_optimizer(reader, + learning_rate=args.learning_rate, + adam_eps=args.adam_eps, weight_decay=args.weight_decay, + ) if not inference_only else None + + tensorizer = get_bert_tensorizer(args) + return tensorizer, reader, optimizer + + +def get_bert_tensorizer(args, tokenizer=None): + if not tokenizer: + tokenizer = get_bert_tokenizer(args.pretrained_model_cfg, do_lower_case=args.do_lower_case) + return BertTensorizer(tokenizer, args.sequence_length) + + +def get_roberta_tensorizer(args, tokenizer=None): + if not tokenizer: + tokenizer = get_roberta_tokenizer(args.pretrained_model_cfg, do_lower_case=args.do_lower_case) + return RobertaTensorizer(tokenizer, args.sequence_length) + + +def get_optimizer(model: nn.Module, learning_rate: float = 1e-5, adam_eps: float = 1e-8, + weight_decay: float = 0.0, ) -> torch.optim.Optimizer: + no_decay = ['bias', 'LayerNorm.weight'] + + optimizer_grouped_parameters = [ + {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + 'weight_decay': weight_decay}, + {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_eps) + return optimizer + + +def get_bert_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True): + return BertTokenizer.from_pretrained(pretrained_cfg_name, do_lower_case=do_lower_case) + + +def get_roberta_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True): + # still uses HF code for tokenizer since they are the same + return RobertaTokenizer.from_pretrained(pretrained_cfg_name, do_lower_case=do_lower_case) + + +class HFBertEncoder(BertModel): + + def __init__(self, config, project_dim: int = 0): + 
BertModel.__init__(self, config) + assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero' + self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None + self.init_weights() + + @classmethod + def init_encoder(cls, cfg_name: str, projection_dim: int = 0, dropout: float = 0.1, **kwargs) -> BertModel: + cfg = BertConfig.from_pretrained(cfg_name if cfg_name else 'bert-base-uncased') + if dropout != 0: + cfg.attention_probs_dropout_prob = dropout + cfg.hidden_dropout_prob = dropout + return cls.from_pretrained(cfg_name, config=cfg, project_dim=projection_dim, **kwargs) + + def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> Tuple[T, ...]: + if self.config.output_hidden_states: + sequence_output, pooled_output, hidden_states = super().forward(input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask) + else: + hidden_states = None + sequence_output, pooled_output = super().forward(input_ids=input_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask) + + pooled_output = sequence_output[:, 0, :] + if self.encode_proj: + pooled_output = self.encode_proj(pooled_output) + return sequence_output, pooled_output, hidden_states + + def get_out_size(self): + if self.encode_proj: + return self.encode_proj.out_features + return self.config.hidden_size + + +class BertTensorizer(Tensorizer): + def __init__(self, tokenizer: BertTokenizer, max_length: int, pad_to_max: bool = True): + self.tokenizer = tokenizer + self.max_length = max_length + self.pad_to_max = pad_to_max + + def text_to_tensor(self, text: str, title: str = None, add_special_tokens: bool = True): + text = text.strip() + + # tokenizer automatic padding is explicitly disabled since its inconsistent behavior + if title: + token_ids = self.tokenizer.encode(title, text_pair=text, add_special_tokens=add_special_tokens, + max_length=self.max_length, + pad_to_max_length=False) + else: + token_ids = self.tokenizer.encode(text, add_special_tokens=add_special_tokens, max_length=self.max_length, + pad_to_max_length=False) + + seq_len = self.max_length + if self.pad_to_max and len(token_ids) < seq_len: + token_ids = token_ids + [self.tokenizer.pad_token_id] * (seq_len - len(token_ids)) + if len(token_ids) > seq_len: + token_ids = token_ids[0:seq_len] + token_ids[-1] = self.tokenizer.sep_token_id + + return torch.tensor(token_ids) + + def get_pair_separator_ids(self) -> T: + return torch.tensor([self.tokenizer.sep_token_id]) + + def get_pad_id(self) -> int: + return self.tokenizer.pad_token_type_id + + def get_attn_mask(self, tokens_tensor: T) -> T: + return tokens_tensor != self.get_pad_id() + + def is_sub_word_id(self, token_id: int): + token = self.tokenizer.convert_ids_to_tokens([token_id])[0] + return token.startswith("##") or token.startswith(" ##") + + def to_string(self, token_ids, skip_special_tokens=True): + return self.tokenizer.decode(token_ids, skip_special_tokens=True) + + def set_pad_to_max(self, do_pad: bool): + self.pad_to_max = do_pad + + +class RobertaTensorizer(BertTensorizer): + def __init__(self, tokenizer, max_length: int, pad_to_max: bool = True): + super(RobertaTensorizer, self).__init__(tokenizer, max_length, pad_to_max=pad_to_max) diff --git a/dpr/models/pytext_models.py b/dpr/models/pytext_models.py new file mode 100644 index 00000000..97ccc920 --- /dev/null +++ b/dpr/models/pytext_models.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Encoder model wrappers based on HuggingFace code +""" + +import logging +from typing import Tuple + +import torch +from pytext.models.representations.transformer_sentence_encoder import TransformerSentenceEncoder +from pytext.optimizer.optimizers import AdamW +from torch import Tensor as T +from torch import nn + +from .biencoder import BiEncoder + +logger = logging.getLogger(__name__) + + +def get_bert_biencoder_components(args, inference_only: bool = False): + # since bert tokenizer is the same in HF and pytext/fairseq, just use HF's implementation here for now + from .hf_models import get_tokenizer, BertTensorizer + + tokenizer = get_tokenizer(args.pretrained_model_cfg, do_lower_case=args.do_lower_case) + + question_encoder = PytextBertEncoder.init_encoder(args.pretrained_file, + projection_dim=args.projection_dim, dropout=args.dropout, + vocab_size=tokenizer.vocab_size, + padding_idx=tokenizer.pad_token_type_id + ) + + ctx_encoder = PytextBertEncoder.init_encoder(args.pretrained_file, + projection_dim=args.projection_dim, dropout=args.dropout, + vocab_size=tokenizer.vocab_size, + padding_idx=tokenizer.pad_token_type_id + ) + + biencoder = BiEncoder(question_encoder, ctx_encoder) + + optimizer = get_optimizer(biencoder, + learning_rate=args.learning_rate, + adam_eps=args.adam_eps, weight_decay=args.weight_decay, + ) if not inference_only else None + + tensorizer = BertTensorizer(tokenizer, args.sequence_length) + return tensorizer, biencoder, optimizer + + +def get_optimizer(model: nn.Module, learning_rate: float = 1e-5, adam_eps: float = 1e-8, + weight_decay: float = 0.0) -> torch.optim.Optimizer: + cfg = AdamW.Config() + cfg.lr = learning_rate + cfg.weight_decay = weight_decay + cfg.eps = adam_eps + optimizer = AdamW.from_config(cfg, model) + return optimizer + + +def get_pytext_bert_base_cfg(): + cfg = TransformerSentenceEncoder.Config() + cfg.embedding_dim = 768 + cfg.ffn_embedding_dim = 3072 + cfg.num_encoder_layers = 12 + cfg.num_attention_heads = 12 + cfg.num_segments = 2 + cfg.use_position_embeddings = True + cfg.offset_positions_by_padding = True + cfg.apply_bert_init = True + cfg.encoder_normalize_before = True + cfg.activation_fn = "gelu" + cfg.projection_dim = 0 + cfg.max_seq_len = 512 + cfg.multilingual = False + cfg.freeze_embeddings = False + cfg.n_trans_layers_to_freeze = 0 + cfg.use_torchscript = False + return cfg + + +class PytextBertEncoder(TransformerSentenceEncoder): + + def __init__(self, config: TransformerSentenceEncoder.Config, + padding_idx: int, + vocab_size: int, + projection_dim: int = 0, + *args, + **kwarg + ): + + TransformerSentenceEncoder.__init__(self, config, False, padding_idx, vocab_size, *args, **kwarg) + + assert config.embedding_dim > 0, 'Encoder hidden_size can\'t be zero' + self.encode_proj = nn.Linear(config.embedding_dim, projection_dim) if projection_dim != 0 else None + + @classmethod + def init_encoder(cls, pretrained_file: str = None, projection_dim: int = 0, dropout: float = 0.1, + vocab_size: int = 0, + padding_idx: int = 0, **kwargs): + cfg = get_pytext_bert_base_cfg() + + if dropout != 0: + cfg.dropout = dropout + cfg.attention_dropout = dropout + cfg.activation_dropout = dropout + + encoder = cls(cfg, padding_idx, vocab_size, projection_dim, **kwargs) + + if pretrained_file: + logger.info('Loading pre-trained pytext encoder state from %s', pretrained_file) + state = torch.load(pretrained_file) + 
encoder.load_state_dict(state) + return encoder + + def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> Tuple[T, ...]: + pooled_output = super().forward((input_ids, attention_mask, token_type_ids, None))[0] + if self.encode_proj: + pooled_output = self.encode_proj(pooled_output) + + return None, pooled_output, None + + def get_out_size(self): + if self.encode_proj: + return self.encode_proj.out_features + return self.representation_dim diff --git a/dpr/models/reader.py b/dpr/models/reader.py new file mode 100644 index 00000000..761efc0b --- /dev/null +++ b/dpr/models/reader.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +The reader model code + its utilities (loss computation and input batch tensor generator) +""" + +import collections +import logging +from typing import List + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor as T +from torch.nn import CrossEntropyLoss + +from dpr.data.reader_data import ReaderSample, ReaderPassage +from dpr.utils.model_utils import init_weights + +logger = logging.getLogger() + +ReaderBatch = collections.namedtuple('ReaderBatch', ['input_ids', 'start_positions', 'end_positions', 'answers_mask']) + + +class Reader(nn.Module): + + def __init__(self, encoder: nn.Module, hidden_size): + super(Reader, self).__init__() + self.encoder = encoder + self.qa_outputs = nn.Linear(hidden_size, 2) + self.qa_classifier = nn.Linear(hidden_size, 1) + init_weights([self.qa_outputs, self.qa_classifier]) + + def forward(self, input_ids: T, attention_mask: T, start_positions=None, end_positions=None, answer_mask=None): + # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length + N, M, L = input_ids.size() + start_logits, end_logits, relevance_logits = self._forward(input_ids.view(N * M, L), + attention_mask.view(N * M, L)) + if self.training: + return compute_loss(start_positions, end_positions, answer_mask, start_logits, end_logits, relevance_logits, + N, M) + + return start_logits.view(N, M, L), end_logits.view(N, M, L), relevance_logits.view(N, M) + + def _forward(self, input_ids, attention_mask): + # TODO: provide segment values + sequence_output, _pooled_output, _hidden_states = self.encoder(input_ids, None, attention_mask) + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + rank_logits = self.qa_classifier(sequence_output[:, 0, :]) + return start_logits, end_logits, rank_logits + + +def compute_loss(start_positions, end_positions, answer_mask, start_logits, end_logits, relevance_logits, N, M): + start_positions = start_positions.view(N * M, -1) + end_positions = end_positions.view(N * M, -1) + answer_mask = answer_mask.view(N * M, -1) + + start_logits = start_logits.view(N * M, -1) + end_logits = end_logits.view(N * M, -1) + relevance_logits = relevance_logits.view(N * M) + + answer_mask = answer_mask.type(torch.FloatTensor).cuda() + + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + loss_fct = CrossEntropyLoss(reduce=False, ignore_index=ignored_index) + + # compute switch loss + relevance_logits = relevance_logits.view(N, M) + switch_labels = torch.zeros(N, 
dtype=torch.long).cuda() + switch_loss = torch.sum(loss_fct(relevance_logits, switch_labels)) + + # compute span loss + start_losses = [(loss_fct(start_logits, _start_positions) * _span_mask) + for (_start_positions, _span_mask) + in zip(torch.unbind(start_positions, dim=1), torch.unbind(answer_mask, dim=1))] + + end_losses = [(loss_fct(end_logits, _end_positions) * _span_mask) + for (_end_positions, _span_mask) + in zip(torch.unbind(end_positions, dim=1), torch.unbind(answer_mask, dim=1))] + loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + \ + torch.cat([t.unsqueeze(1) for t in end_losses], dim=1) + + loss_tensor = loss_tensor.view(N, M, -1).max(dim=1)[0] + span_loss = _calc_mml(loss_tensor) + return span_loss + switch_loss + + +def create_reader_input(pad_token_id: int, + samples: List[ReaderSample], + passages_per_question: int, + max_length: int, + max_n_answers: int, + is_train: bool, + shuffle: bool, + ) -> ReaderBatch: + """ + Creates a reader batch instance out of a list of ReaderSample-s + :param pad_token_id: id of the padding token + :param samples: list of samples to create the batch for + :param passages_per_question: amount of passages for every question in a batch + :param max_length: max model input sequence length + :param max_n_answers: max num of answers per single question + :param is_train: if the samples are for a train set + :param shuffle: should passages selection be randomized + :return: ReaderBatch instance + """ + input_ids = [] + start_positions = [] + end_positions = [] + answers_masks = [] + empty_sequence = torch.Tensor().new_full((max_length,), pad_token_id, dtype=torch.long) + + for sample in samples: + positive_ctxs = sample.positive_passages + negative_ctxs = sample.negative_passages if is_train else sample.passages + + sample_tensors = _create_question_passages_tensors(positive_ctxs, + negative_ctxs, + passages_per_question, + empty_sequence, + max_n_answers, + pad_token_id, + is_train, + is_random=shuffle) + if not sample_tensors: + logger.warning('No valid passages combination for question=%s ', sample.question) + continue + sample_input_ids, starts_tensor, ends_tensor, answer_mask = sample_tensors + input_ids.append(sample_input_ids) + if is_train: + start_positions.append(starts_tensor) + end_positions.append(ends_tensor) + answers_masks.append(answer_mask) + input_ids = torch.cat([ids.unsqueeze(0) for ids in input_ids], dim=0) + + if is_train: + start_positions = torch.stack(start_positions, dim=0) + end_positions = torch.stack(end_positions, dim=0) + answers_masks = torch.stack(answers_masks, dim=0) + + return ReaderBatch(input_ids, start_positions, end_positions, answers_masks) + + +def _calc_mml(loss_tensor): + marginal_likelihood = torch.sum(torch.exp( + - loss_tensor - 1e10 * (loss_tensor == 0).float()), 1) + return -torch.sum(torch.log(marginal_likelihood + + torch.ones(loss_tensor.size(0)).cuda() * (marginal_likelihood == 0).float())) + + +def _pad_to_len(seq: T, pad_id: int, max_len: int): + s_len = seq.size(0) + if s_len > max_len: + return seq[0: max_len] + return torch.cat([seq, torch.Tensor().new_full((max_len - s_len,), pad_id, dtype=torch.long)], dim=0) + + +def _get_answer_spans(idx, positives: List[ReaderPassage], max_len: int): + positive_a_spans = positives[idx].answers_spans + return [span for span in positive_a_spans if (span[0] < max_len and span[1] < max_len)] + + +def _get_positive_idx(positives: List[ReaderPassage], max_len: int, is_random: bool): + # select just one positive + positive_idx = 
np.random.choice(len(positives)) if is_random else 0 + + if not _get_answer_spans(positive_idx, positives, max_len): + # question may be too long, find the first positive with at least one valid span + positive_idx = next((i for i in range(len(positives)) if _get_answer_spans(i, positives, max_len)), + None) + return positive_idx + + +def _create_question_passages_tensors(positives: List[ReaderPassage], negatives: List[ReaderPassage], total_size: int, + empty_ids: T, + max_n_answers: int, + pad_token_id: int, + is_train: bool, + is_random: bool = True): + max_len = empty_ids.size(0) + if is_train: + # select just one positive + positive_idx = _get_positive_idx(positives, max_len, is_random) + if positive_idx is None: + return None + + positive_a_spans = _get_answer_spans(positive_idx, positives, max_len)[0: max_n_answers] + + answer_starts = [span[0] for span in positive_a_spans] + answer_ends = [span[1] for span in positive_a_spans] + + assert all(s < max_len for s in answer_starts) + assert all(e < max_len for e in answer_ends) + + positive_input_ids = _pad_to_len(positives[positive_idx].sequence_ids, pad_token_id, max_len) + + answer_starts_tensor = torch.zeros((total_size, max_n_answers)).long() + answer_starts_tensor[0, 0:len(answer_starts)] = torch.tensor(answer_starts) + + answer_ends_tensor = torch.zeros((total_size, max_n_answers)).long() + answer_ends_tensor[0, 0:len(answer_ends)] = torch.tensor(answer_ends) + + answer_mask = torch.zeros((total_size, max_n_answers), dtype=torch.long) + answer_mask[0, 0:len(answer_starts)] = torch.tensor([1 for _ in range(len(answer_starts))]) + + positives_selected = [positive_input_ids] + + else: + positives_selected = [] + answer_starts_tensor = None + answer_ends_tensor = None + answer_mask = None + + positives_num = len(positives_selected) + negative_idxs = np.random.permutation(range(len(negatives))) if is_random else range( + len(negatives) - positives_num) + + negative_idxs = negative_idxs[:total_size - positives_num] + + negatives_selected = [_pad_to_len(negatives[i].sequence_ids, pad_token_id, max_len) for i in negative_idxs] + + while len(negatives_selected) < total_size - positives_num: + negatives_selected.append(empty_ids.clone()) + + input_ids = torch.stack([t for t in positives_selected + negatives_selected], dim=0) + return input_ids, answer_starts_tensor, answer_ends_tensor, answer_mask diff --git a/dpr/options.py b/dpr/options.py new file mode 100644 index 00000000..714dc5a4 --- /dev/null +++ b/dpr/options.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Command line arguments utils +""" + +import argparse +import logging +import os +import random +import socket + +import numpy as np +import torch + +logger = logging.getLogger() + + +def add_tokenizer_params(parser: argparse.ArgumentParser): + parser.add_argument("--do_lower_case", action='store_true', + help="Whether to lower case the input text. True for uncased models, False for cased models.") + + +def add_encoder_params(parser: argparse.ArgumentParser): + """ + Common parameters to initialize an encoder-based model + """ + parser.add_argument("--pretrained_model_cfg", default=None, type=str, help="config name for model initialization") + parser.add_argument("--encoder_model_type", default=None, type=str, + help="model type. 
One of [hf_bert, pytext_bert, fairseq_roberta]") + parser.add_argument('--pretrained_file', type=str, help="Some encoders need to be initialized from a file") + parser.add_argument("--model_file", default=None, type=str, + help="Saved bi-encoder checkpoint file to initialize the model") + parser.add_argument("--projection_dim", default=0, type=int, + help="Extra linear layer on top of standard bert/roberta encoder") + parser.add_argument("--sequence_length", type=int, default=512, help="Max length of the encoder input sequence") + + +def add_training_params(parser: argparse.ArgumentParser): + """ + Common parameters for training + """ + add_cuda_params(parser) + parser.add_argument("--train_file", default=None, type=str, help="File pattern for the train set") + parser.add_argument("--dev_file", default=None, type=str, help="") + + parser.add_argument("--batch_size", default=2, type=int, help="Amount of questions per batch") + parser.add_argument("--dev_batch_size", type=int, default=4, + help="amount of questions per batch for dev set validation") + parser.add_argument('--seed', type=int, default=0, help="random seed for initialization and dataset shuffling") + + parser.add_argument("--adam_eps", default=1e-8, type=float, help="Epsilon for Adam optimizer.") + parser.add_argument("--adam_betas", default='(0.9, 0.999)', type=str, help="Betas for Adam optimizer.") + + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--log_batch_step", default=100, type=int, help="") + parser.add_argument("--train_rolling_loss_step", default=100, type=int, help="") + parser.add_argument("--weight_decay", default=0.0, type=float, help="") + parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.") + + parser.add_argument("--warmup_steps", default=100, type=int, help="Linear warmup over warmup_steps.") + parser.add_argument("--dropout", default=0.1, type=float, help="") + + parser.add_argument('--gradient_accumulation_steps', type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument("--num_train_epochs", default=3.0, type=float, + help="Total number of training epochs to perform.") + + +def add_cuda_params(parser: argparse.ArgumentParser): + parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit float precision instead of 32-bit") + + parser.add_argument('--fp16_opt_level', type=str, default='O1', + help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." + "See details at https://nvidia.github.io/apex/amp.html") + + +def add_reader_preprocessing_params(parser: argparse.ArgumentParser): + parser.add_argument("--gold_passages_src", type=str, + help="File with the original dataset passages (json format). Required for train set") + parser.add_argument("--gold_passages_src_dev", type=str, + help="File with the original dataset passages (json format). 
Required for dev set") + parser.add_argument("--num_workers", type=int, default=16, + help="number of parallel processes to binarize reader data") + + +def get_encoder_checkpoint_params_names(): + return ['do_lower_case', 'pretrained_model_cfg', 'encoder_model_type', + 'pretrained_file', + 'projection_dim', 'sequence_length'] + + +def get_encoder_params_state(args): + """ + Selects the param values to be saved in a checkpoint, so that a trained model faile can be used for downstream + tasks without the need to specify these parameter again + :return: Dict of params to memorize in a checkpoint + """ + params_to_save = get_encoder_checkpoint_params_names() + + r = {} + for param in params_to_save: + r[param] = getattr(args, param) + return r + + +def set_encoder_params_from_state(state, args): + if not state: + return + params_to_save = get_encoder_checkpoint_params_names() + + override_params = [(param, state[param]) for param in params_to_save if param in state and state[param]] + for param, value in override_params: + if hasattr(args, param): + logger.warning('Overriding args parameter value from checkpoint state. Param = %s, value = %s', param, + value) + setattr(args, param, value) + return args + + +def set_seed(args): + seed = args.seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(seed) + + +def setup_args_gpu(args): + """ + Setup arguments CUDA, GPU & distributed training + """ + + if args.local_rank == -1 or args.no_cuda: # single-node multi-gpu (or cpu) mode + device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = torch.cuda.device_count() + else: # distributed mode + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend="nccl") + args.n_gpu = 1 + args.device = device + ws = os.environ.get('WORLD_SIZE') + + args.distributed_world_size = int(ws) if ws else 1 + + logger.info( + 'Initialized host %s as d.rank %d on device=%s, n_gpu=%d, world size=%d', socket.gethostname(), + args.local_rank, device, + args.n_gpu, + args.distributed_world_size) + logger.info("16-bits training: %s ", args.fp16) + + +def print_args(args): + logger.info(" **************** CONFIGURATION **************** ") + for key, val in sorted(vars(args).items()): + keystr = "{}".format(key) + (" " * (30 - len(key))) + logger.info("%s --> %s", keystr, val) + logger.info(" **************** CONFIGURATION **************** ") diff --git a/dpr/utils/__init__.py b/dpr/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dpr/utils/data_utils.py b/dpr/utils/data_utils.py new file mode 100644 index 00000000..ea9fccaa --- /dev/null +++ b/dpr/utils/data_utils.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
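To make the checkpoint round trip implemented in dpr/options.py above concrete: get_encoder_params_state() collects the encoder configuration into a dict that is stored inside the checkpoint, and set_encoder_params_from_state() later pushes those values back onto freshly parsed args so downstream tools do not have to re-specify them. A minimal sketch of that flow, assuming the dpr package is importable; the Namespace values ('bert-base-uncased', sequence lengths, etc.) are hypothetical:

```python
# Illustrative sketch: encoder params saved at train time override args on load.
import argparse

from dpr.options import get_encoder_params_state, set_encoder_params_from_state

# args as they looked at training time (hypothetical values)
train_args = argparse.Namespace(do_lower_case=True, pretrained_model_cfg='bert-base-uncased',
                                encoder_model_type='hf_bert', pretrained_file=None,
                                projection_dim=0, sequence_length=256)
encoder_params = get_encoder_params_state(train_args)  # dict stored inside the checkpoint

# later, e.g. at inference time: restore the saved params onto the new args
infer_args = argparse.Namespace(do_lower_case=False, pretrained_model_cfg=None,
                                encoder_model_type=None, pretrained_file=None,
                                projection_dim=0, sequence_length=512)
set_encoder_params_from_state(encoder_params, infer_args)
assert infer_args.pretrained_model_cfg == 'bert-base-uncased'
assert infer_args.sequence_length == 256  # truthy saved values win over the new args
```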
+ +""" +Utilities for general purpose data processing +""" + +import json +import logging +import math +import pickle +import random +from typing import List, Iterator, Callable + +from torch import Tensor as T + +logger = logging.getLogger() + + +def read_serialized_data_from_files(paths: List[str]) -> List: + results = [] + for i, path in enumerate(paths): + with open(path, "rb") as reader: + logger.info('Reading file %s', path) + data = pickle.load(reader) + results.extend(data) + logger.info('Aggregated data size: {}'.format(len(results))) + logger.info('Total data size: {}'.format(len(results))) + return results + + +def read_data_from_json_files(paths: List[str], upsample_rates: List = None) -> List: + results = [] + if upsample_rates is None: + upsample_rates = [1] * len(paths) + + assert len(upsample_rates) == len(paths), 'up-sample rates parameter doesn\'t match input files amount' + + for i, path in enumerate(paths): + with open(path, 'r', encoding="utf-8") as f: + logger.info('Reading file %s' % path) + data = json.load(f) + upsample_factor = int(upsample_rates[i]) + data = data * upsample_factor + results.extend(data) + logger.info('Aggregated data size: {}'.format(len(results))) + return results + + +class ShardedDataIterator(object): + """ + General purpose data iterator to be used for Pytorch's DDP mode where every node should handle its own part of + the data. + Instead of cutting data shards by their min size, it sets the amount of iterations by the maximum shard size. + It fills the extra sample by just taking first samples in a shard. + It can also optionally enforce identical batch size for all iterations (might be useful for DP mode). + """ + def __init__(self, data: list, shard_id: int = 0, num_shards: int = 1, batch_size: int = 1, shuffle=True, + shuffle_seed: int = 0, offset: int = 0, + strict_batch_size: bool = False + ): + + self.data = data + total_size = len(data) + + self.shards_num = max(num_shards, 1) + self.shard_id = max(shard_id, 0) + + samples_per_shard = math.ceil(total_size / self.shards_num) + + self.shard_start_idx = self.shard_id * samples_per_shard + + self.shard_end_idx = min(self.shard_start_idx + samples_per_shard, total_size) + + if strict_batch_size: + self.max_iterations = math.ceil(samples_per_shard / batch_size) + else: + self.max_iterations = int(samples_per_shard / batch_size) + + logger.debug( + 'samples_per_shard=%d, shard_start_idx=%d, shard_end_idx=%d, max_iterations=%d', samples_per_shard, + self.shard_start_idx, + self.shard_end_idx, + self.max_iterations) + + self.iteration = offset # to track in-shard iteration status + self.shuffle = shuffle + self.batch_size = batch_size + self.shuffle_seed = shuffle_seed + self.strict_batch_size = strict_batch_size + + def total_data_len(self) -> int: + return len(self.data) + + def iterate_data(self, epoch: int = 0) -> Iterator[List]: + if self.shuffle: + # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration + epoch_rnd = random.Random(self.shuffle_seed + epoch) + epoch_rnd.shuffle(self.data) + + # if resuming iteration somewhere in the middle of epoch, one needs to adjust max_iterations + + max_iterations = self.max_iterations - self.iteration + + shard_samples = self.data[self.shard_start_idx:self.shard_end_idx] + for i in range(self.iteration * self.batch_size, len(shard_samples), self.batch_size): + items = shard_samples[i:i + self.batch_size] + if self.strict_batch_size and len(items) < self.batch_size: + logger.debug('Extending batch to max 
size') + items.extend(shard_samples[0:self.batch_size - len(items)]) + self.iteration += 1 + yield items + + # some shards may be done iterating while the others are at the last batch. Just return the first batch + while self.iteration < max_iterations: + logger.debug('Fulfilling non-complete shard={}'.format(self.shard_id)) + self.iteration += 1 + batch = shard_samples[0:self.batch_size] + yield batch + + logger.debug('Finished iterating, iteration={}, shard={}'.format(self.iteration, self.shard_id)) + # reset the iteration status + self.iteration = 0 + + def get_iteration(self) -> int: + return self.iteration + + def apply(self, visitor_func: Callable): + for sample in self.data[self.shard_start_idx:self.shard_end_idx]: + visitor_func(sample) + + +def normalize_question(question: str) -> str: + if question[-1] == '?': + question = question[:-1] + return question + + +class Tensorizer(object): + """ + Component for all text to model input data conversions and related utility methods + """ + + # Note: title, if present, is supposed to be put before text (i.e. optional title + document body) + def text_to_tensor(self, text: str, title: str = None, add_special_tokens: bool = True): + raise NotImplementedError + + def get_pair_separator_ids(self) -> T: + raise NotImplementedError + + def get_pad_id(self) -> int: + raise NotImplementedError + + def get_attn_mask(self, tokens_tensor: T): + raise NotImplementedError + + def is_sub_word_id(self, token_id: int): + raise NotImplementedError + + def to_string(self, token_ids, skip_special_tokens=True): + raise NotImplementedError + + def set_pad_to_max(self, pad: bool): + raise NotImplementedError diff --git a/dpr/utils/dist_utils.py b/dpr/utils/dist_utils.py new file mode 100644 index 00000000..3b0bf85c --- /dev/null +++ b/dpr/utils/dist_utils.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for distributed model training +""" + +import pickle + +import torch +import torch.distributed as dist + + +def get_rank(): + return dist.get_rank() + + +def get_world_size(): + return dist.get_world_size() + + +def get_default_group(): + return dist.group.WORLD + + +def all_reduce(tensor, group=None): + if group is None: + group = get_default_group() + return dist.all_reduce(tensor, group=group) + + +def all_gather_list(data, group=None, max_size=16384): + """Gathers arbitrary data from all nodes into a list. + Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python + data. Note that *data* must be picklable. 
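+    Example (illustrative): all_gather_list([loss, n_samples]) returns a list with one [loss, n_samples] entry per worker, in rank order.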
+ Args: + data (Any): data from the local worker to be gathered on other workers + group (optional): group of the collective + """ + SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size + + enc = pickle.dumps(data) + enc_size = len(enc) + + if enc_size + SIZE_STORAGE_BYTES > max_size: + raise ValueError( + 'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size)) + + rank = get_rank() + world_size = get_world_size() + buffer_size = max_size * world_size + + if not hasattr(all_gather_list, '_buffer') or \ + all_gather_list._buffer.numel() < buffer_size: + all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) + all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() + + buffer = all_gather_list._buffer + buffer.zero_() + cpu_buffer = all_gather_list._cpu_buffer + + assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format( + 256 ** SIZE_STORAGE_BYTES) + + size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big') + + cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) + cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) + + start = rank * max_size + size = enc_size + SIZE_STORAGE_BYTES + buffer[start: start + size].copy_(cpu_buffer[:size]) + + all_reduce(buffer, group=group) + + try: + result = [] + for i in range(world_size): + out_buffer = buffer[i * max_size: (i + 1) * max_size] + size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big') + if size > 0: + result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist()))) + return result + except pickle.UnpicklingError: + raise Exception( + 'Unable to unpickle data from other workers. all_gather_list requires all ' + 'workers to enter the function together, so this error usually indicates ' + 'that the workers have fallen out of sync somehow. Workers can fall out of ' + 'sync if one of them runs out of memory, or if there are other conditions ' + 'in your training script that can cause one worker to finish an epoch ' + 'while other workers are still iterating over their portions of the data.' + ) diff --git a/dpr/utils/model_utils.py b/dpr/utils/model_utils.py new file mode 100644 index 00000000..a25533ea --- /dev/null +++ b/dpr/utils/model_utils.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
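The all_gather_list helper above packs each worker's pickled payload into a fixed-size per-rank slot of a shared byte buffer, prefixed with a 4-byte big-endian length header; because all other slots are zero, a single all_reduce (sum) effectively concatenates every worker's data, and each slot is then decoded by reading the header and unpickling exactly that many bytes. A minimal single-process sketch of just that encode/decode step, illustrative only and not part of the diff:

```python
import pickle

SIZE_STORAGE_BYTES = 4  # same 4-byte big-endian length header as all_gather_list
MAX_SIZE = 16384        # per-rank slot size; header + payload must fit inside it

payload = pickle.dumps({'rank': 0, 'loss': 1.23})
assert len(payload) + SIZE_STORAGE_BYTES <= MAX_SIZE

# encode: [4-byte length][pickled bytes], zero-padded up to the slot size
slot = len(payload).to_bytes(SIZE_STORAGE_BYTES, byteorder='big') + payload
slot = slot.ljust(MAX_SIZE, b'\x00')

# decode: read the header, then unpickle exactly that many bytes
size = int.from_bytes(slot[:SIZE_STORAGE_BYTES], byteorder='big')
restored = pickle.loads(slot[SIZE_STORAGE_BYTES:SIZE_STORAGE_BYTES + size])
assert restored == {'rank': 0, 'loss': 1.23}
```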
+ +import collections +import glob +import logging +import os +from typing import List + +import torch +from torch import nn +from torch.optim.lr_scheduler import LambdaLR +from torch.serialization import default_restore_location + +logger = logging.getLogger() + +CheckpointState = collections.namedtuple("CheckpointState", + ['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch', + 'encoder_params']) + + +def setup_for_distributed_mode(model: nn.Module, optimizer: torch.optim.Optimizer, device: object, n_gpu: int = 1, + local_rank: int = -1, + fp16: bool = False, + fp16_opt_level: str = "O1") -> (nn.Module, torch.optim.Optimizer): + model.to(device) + if fp16: + try: + import apex + from apex import amp + apex.amp.register_half_function(torch, "einsum") + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + + model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) + + if n_gpu > 1: + model = torch.nn.DataParallel(model) + + if local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], + output_device=local_rank, + find_unused_parameters=True) + return model, optimizer + + +def move_to_cuda(sample): + if len(sample) == 0: + return {} + + def _move_to_cuda(maybe_tensor): + if torch.is_tensor(maybe_tensor): + return maybe_tensor.cuda() + elif isinstance(maybe_tensor, dict): + return { + key: _move_to_cuda(value) + for key, value in maybe_tensor.items() + } + elif isinstance(maybe_tensor, list): + return [_move_to_cuda(x) for x in maybe_tensor] + elif isinstance(maybe_tensor, tuple): + return [_move_to_cuda(x) for x in maybe_tensor] + else: + return maybe_tensor + + return _move_to_cuda(sample) + + +def move_to_device(sample, device): + if len(sample) == 0: + return {} + + def _move_to_device(maybe_tensor, device): + if torch.is_tensor(maybe_tensor): + return maybe_tensor.to(device) + elif isinstance(maybe_tensor, dict): + return { + key: _move_to_device(value, device) + for key, value in maybe_tensor.items() + } + elif isinstance(maybe_tensor, list): + return [_move_to_device(x, device) for x in maybe_tensor] + elif isinstance(maybe_tensor, tuple): + return [_move_to_device(x, device) for x in maybe_tensor] + else: + return maybe_tensor + + return _move_to_device(sample, device) + + +def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1): + """ Create a schedule with a learning rate that decreases linearly after + linearly increasing during a warmup period. 
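+    For example, with warmup_steps=100 and training_steps=1000 the multiplier is 0.5 at step 50, reaches 1.0 at step 100, and then decays linearly (0.5 at step 550) to 0.0 at step 1000.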
+ """ + + def lr_lambda(current_step): + if current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + return max( + 0.0, float(training_steps - current_step) / float(max(1, training_steps - warmup_steps)) + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def init_weights(modules: List): + for module in modules: + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +def get_model_obj(model: nn.Module): + return model.module if hasattr(model, 'module') else model + + +def get_model_file(args, file_prefix) -> str: + out_cp_files = glob.glob(os.path.join(args.output_dir, file_prefix + '*')) if args.output_dir else [] + logger.info('Checkpoint files %s', out_cp_files) + model_file = None + + if args.model_file and os.path.exists(args.model_file): + model_file = args.model_file + elif len(out_cp_files) > 0: + model_file = max(out_cp_files, key=os.path.getctime) + return model_file + + +def load_states_from_checkpoint(model_file: str) -> CheckpointState: + logger.info('Reading saved model from %s', model_file) + state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu')) + logger.info('model_state_dict keys %s', state_dict.keys()) + return CheckpointState(**state_dict) diff --git a/dpr/utils/tokenizers.py b/dpr/utils/tokenizers.py new file mode 100644 index 00000000..a5234a52 --- /dev/null +++ b/dpr/utils/tokenizers.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +""" +Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency +""" + +import copy +import logging + +import regex +import spacy + +logger = logging.getLogger(__name__) + + +class Tokens(object): + """A class to represent a list of tokenized text.""" + TEXT = 0 + TEXT_WS = 1 + SPAN = 2 + POS = 3 + LEMMA = 4 + NER = 5 + + def __init__(self, data, annotators, opts=None): + self.data = data + self.annotators = annotators + self.opts = opts or {} + + def __len__(self): + """The number of tokens.""" + return len(self.data) + + def slice(self, i=None, j=None): + """Return a view of the list of tokens from [i, j).""" + new_tokens = copy.copy(self) + new_tokens.data = self.data[i: j] + return new_tokens + + def untokenize(self): + """Returns the original text (with whitespace reinserted).""" + return ''.join([t[self.TEXT_WS] for t in self.data]).strip() + + def words(self, uncased=False): + """Returns a list of the text of each token + + Args: + uncased: lower cases text + """ + if uncased: + return [t[self.TEXT].lower() for t in self.data] + else: + return [t[self.TEXT] for t in self.data] + + def offsets(self): + """Returns a list of [start, end) character offsets of each token.""" + return [t[self.SPAN] for t in self.data] + + def pos(self): + """Returns a list of part-of-speech tags of each token. + Returns None if this annotation was not included. + """ + if 'pos' not in self.annotators: + return None + return [t[self.POS] for t in self.data] + + def lemmas(self): + """Returns a list of the lemmatized text of each token. + Returns None if this annotation was not included. 
+ """ + if 'lemma' not in self.annotators: + return None + return [t[self.LEMMA] for t in self.data] + + def entities(self): + """Returns a list of named-entity-recognition tags of each token. + Returns None if this annotation was not included. + """ + if 'ner' not in self.annotators: + return None + return [t[self.NER] for t in self.data] + + def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): + """Returns a list of all ngrams from length 1 to n. + + Args: + n: upper limit of ngram length + uncased: lower cases text + filter_fn: user function that takes in an ngram list and returns + True or False to keep or not keep the ngram + as_string: return the ngram as a string vs list + """ + + def _skip(gram): + if not filter_fn: + return False + return filter_fn(gram) + + words = self.words(uncased) + ngrams = [(s, e + 1) + for s in range(len(words)) + for e in range(s, min(s + n, len(words))) + if not _skip(words[s:e + 1])] + + # Concatenate into strings + if as_strings: + ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] + + return ngrams + + def entity_groups(self): + """Group consecutive entity tokens with the same NER tag.""" + entities = self.entities() + if not entities: + return None + non_ent = self.opts.get('non_ent', 'O') + groups = [] + idx = 0 + while idx < len(entities): + ner_tag = entities[idx] + # Check for entity tag + if ner_tag != non_ent: + # Chomp the sequence + start = idx + while (idx < len(entities) and entities[idx] == ner_tag): + idx += 1 + groups.append((self.slice(start, idx).untokenize(), ner_tag)) + else: + idx += 1 + return groups + + +class Tokenizer(object): + """Base tokenizer class. + Tokenizers implement tokenize, which should return a Tokens class. + """ + + def tokenize(self, text): + raise NotImplementedError + + def shutdown(self): + pass + + def __del__(self): + self.shutdown() + + +class SimpleTokenizer(Tokenizer): + ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' + NON_WS = r'[^\p{Z}\p{C}]' + + def __init__(self, **kwargs): + """ + Args: + annotators: None or empty set (only tokenizes). + """ + self._regexp = regex.compile( + '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), + flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE + ) + if len(kwargs.get('annotators', {})) > 0: + logger.warning('%s only tokenizes! Skipping annotators: %s' % + (type(self).__name__, kwargs.get('annotators'))) + self.annotators = set() + + def tokenize(self, text): + data = [] + matches = [m for m in self._regexp.finditer(text)] + for i in range(len(matches)): + # Get text + token = matches[i].group() + + # Get whitespace + span = matches[i].span() + start_ws = span[0] + if i + 1 < len(matches): + end_ws = matches[i + 1].span()[0] + else: + end_ws = span[1] + + # Format data + data.append(( + token, + text[start_ws: end_ws], + span, + )) + return Tokens(data, self.annotators) + + +class SpacyTokenizer(Tokenizer): + + def __init__(self, **kwargs): + """ + Args: + annotators: set that can include pos, lemma, and ner. + model: spaCy model to use (either path, or keyword like 'en'). + """ + model = kwargs.get('model', 'en') + self.annotators = copy.deepcopy(kwargs.get('annotators', set())) + nlp_kwargs = {'parser': False} + if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): + nlp_kwargs['tagger'] = False + if 'ner' not in self.annotators: + nlp_kwargs['entity'] = False + self.nlp = spacy.load(model, **nlp_kwargs) + + def tokenize(self, text): + # We don't treat new lines as tokens. 
+ clean_text = text.replace('\n', ' ') + tokens = self.nlp.tokenizer(clean_text) + if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): + self.nlp.tagger(tokens) + if 'ner' in self.annotators: + self.nlp.entity(tokens) + + data = [] + for i in range(len(tokens)): + # Get whitespace + start_ws = tokens[i].idx + if i + 1 < len(tokens): + end_ws = tokens[i + 1].idx + else: + end_ws = tokens[i].idx + len(tokens[i].text) + + data.append(( + tokens[i].text, + text[start_ws: end_ws], + (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), + tokens[i].tag_, + tokens[i].lemma_, + tokens[i].ent_type_, + )) + + # Set special option for non-entity tag: '' vs 'O' in spaCy + return Tokens(data, self.annotators, opts={'non_ent': ''}) diff --git a/generate_dense_embeddings.py b/generate_dense_embeddings.py new file mode 100644 index 00000000..8210f113 --- /dev/null +++ b/generate_dense_embeddings.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" + Command line tool that produces embeddings for a large documents base based on the pretrained ctx & question encoders + Supposed to be used in a 'sharded' way to speed up the process. +""" +import argparse +import csv +import logging +import pickle +from typing import List, Tuple + +import numpy as np +import torch +from torch import nn + +from dpr.models import init_biencoder_components +from dpr.options import add_encoder_params, setup_args_gpu, print_args, set_encoder_params_from_state, \ + add_tokenizer_params, add_cuda_params +from dpr.utils.data_utils import Tensorizer +from dpr.utils.model_utils import setup_for_distributed_mode, get_model_obj, load_states_from_checkpoint + +logger = logging.getLogger() + + +def gen_ctx_vectors(ctx_rows: List[Tuple[object, str, str]], model: nn.Module, tensorizer: Tensorizer, + insert_title: bool = True) -> List[Tuple[object, np.array]]: + n = len(ctx_rows) + bsz = args.batch_size + total = 0 + results = [] + for j, batch_start in enumerate(range(0, n, bsz)): + + batch_token_tensors = [tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None) for ctx in + ctx_rows[batch_start:batch_start + bsz]] + + ctx_ids_batch = torch.stack(batch_token_tensors, dim=0) + ctx_seg_batch = torch.zeros_like(ctx_ids_batch) + ctx_attn_mask = tensorizer.get_attn_mask(ctx_ids_batch) + with torch.no_grad(): + _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask) + out = out.cpu() + + ctx_ids = [r[0] for r in ctx_rows[batch_start:batch_start + bsz]] + + assert len(ctx_ids) == out.size(0) + + total += len(ctx_ids) + + results.extend([ + (ctx_ids[i], out[i].view(-1).numpy()) + for i in range(out.size(0)) + ]) + + if total % 10 == 0: + logger.info('Encoded passages %d', total) + + return results + + +def main(args): + saved_state = load_states_from_checkpoint(args.model_file) + set_encoder_params_from_state(saved_state.encoder_params, args) + + tensorizer, encoder, _ = init_biencoder_components(args.encoder_model_type, args, inference_only=True) + + encoder = encoder.ctx_model + + encoder, _ = setup_for_distributed_mode(encoder, None, args.device, args.n_gpu, + args.local_rank, + args.fp16, + args.fp16_opt_level) + encoder.eval() + + # load weights from the model file + model_to_load = get_model_obj(encoder) + logger.info('Loading saved model state ...') + logger.debug('saved model keys =%s', 
saved_state.model_dict.keys()) + + prefix_len = len('ctx_model.') + ctx_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if + key.startswith('ctx_model.')} + model_to_load.load_state_dict(ctx_state) + + logger.info('reading data from file=%s', args.ctx_file) + + rows = [] + with open(args.ctx_file) as tsvfile: + reader = csv.reader(tsvfile, delimiter='\t') + # file format: doc_id, doc_text, title + rows.extend([(row[0], row[1], row[2]) for row in reader if row[0] != 'id']) + + shard_size = int(len(rows) / args.num_shards) + start_idx = args.shard_id * shard_size + end_idx = start_idx + shard_size + + logger.info('Producing encodings for passages range: %d to %d (out of total %d)', start_idx, end_idx, len(rows)) + rows = rows[start_idx:end_idx] + + data = gen_ctx_vectors(rows, encoder, tensorizer, True) + + file = args.out_file + '_' + str(args.shard_id) + logger.info('Writing results to %s' % file) + with open(file, mode='wb') as f: + pickle.dump(data, f) + + logger.info('Total passages processed %d. Written to %s', len(data), file) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + add_encoder_params(parser) + add_tokenizer_params(parser) + add_cuda_params(parser) + + parser.add_argument('--ctx_file', type=str, default=None, help='Path to passages set .tsv file') + parser.add_argument('--out_file', required=True, type=str, default=None, + help='output .tsv file path to write results to ') + parser.add_argument('--shard_id', type=int, default=0, help="Number(0-based) of data shard to process") + parser.add_argument('--num_shards', type=int, default=1, help="Total amount of data shards") + parser.add_argument('--batch_size', type=int, default=32, help="Batch size for the passage encoder forward pass") + args = parser.parse_args() + + assert args.model_file, 'Please specify --model_file checkpoint to init model weights' + + setup_args_gpu(args) + + print_args(args) + main(args) diff --git a/preprocess_reader_data.py b/preprocess_reader_data.py new file mode 100644 index 00000000..3f44e00a --- /dev/null +++ b/preprocess_reader_data.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
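generate_dense_embeddings.py above writes one pickle file per shard, named out_file + '_' + shard_id, each containing a list of (passage_id, vector) tuples. A small sketch of how downstream code could merge those shard files back into a single id-to-vector map; the helper name and the example path are hypothetical, only the file-name pattern and pickle payload follow the script above:

```python
import glob
import pickle

import numpy as np


def load_shard_embeddings(out_file_prefix: str):
    """Merge the per-shard pickles written by generate_dense_embeddings.py (illustrative)."""
    id_to_vector = {}
    for path in sorted(glob.glob(out_file_prefix + '_*')):
        with open(path, 'rb') as f:
            shard = pickle.load(f)  # list of (passage_id, np.array) tuples
        for passage_id, vector in shard:
            id_to_vector[passage_id] = np.asarray(vector)
    return id_to_vector


# e.g. embeddings = load_shard_embeddings('/tmp/dpr_ctx')  # hypothetical output prefix
```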
+ +""" + Reader data preprocessor command line tool +""" +import argparse +import logging + +from dpr.data.reader_data import convert_retriever_results +from dpr.models import init_tenzorizer +from dpr.options import print_args, add_encoder_params, add_reader_preprocessing_params, add_tokenizer_params + +logger = logging.getLogger() + + +def main(args): + tensorizer = init_tenzorizer(args.encoder_model_type, args) + + # disable auto-padding to save disk space of serialized files + tensorizer.set_pad_to_max(False) + + convert_retriever_results(args.is_train_set, args.retriever_results, args.out_file, args.gold_passages_src, + tensorizer, args.num_workers) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_encoder_params(parser) + add_tokenizer_params(parser) + add_reader_preprocessing_params(parser) + + parser.add_argument("--is_train_set", action='store_true', + help="If true, the data will be binarised for train model usage (split into ctx+ and ctx- \ + and with answer spans selected)") + parser.add_argument("--retriever_results", required=True, type=str, + help="File with retriever results file(json format)") + parser.add_argument("--out_file", required=True, type=str, help="The file to write serialized results to") + + args = parser.parse_args() + + print_args(args) + + main(args) diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..f4b9bde3 --- /dev/null +++ b/setup.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from setuptools import setup + +with open('README.md') as f: + readme = f.read() + +setup( + name='dpr', + version='0.1.0', + description='Facebook AI Research Open Domain Q&A Toolkit', + url='', # TODO + classifiers=[ + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3.6', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + ], + long_description=readme, + long_description_content_type='text/markdown', + setup_requires=[ + 'setuptools>=18.0', + ], + install_requires=[ + 'cython', + 'faiss-cpu>=1.6.1', + 'filelock', + 'numpy', + 'regex', + 'torch>=1.2.0', + 'transformers>=2.2.2', + 'tqdm>=4.27', + 'wget', + 'spacy>=2.1.8', + ], +) diff --git a/train_dense_encoder.py b/train_dense_encoder.py new file mode 100644 index 00000000..95162a00 --- /dev/null +++ b/train_dense_encoder.py @@ -0,0 +1,564 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ + +""" + Pipeline to train DPR Biencoder +""" + +import argparse +import glob +import logging +import math +import os +import random +import time + + +import torch + +from typing import Tuple +from torch import nn +from torch import Tensor as T + +from dpr.models import init_biencoder_components +from dpr.models.biencoder import BiEncoder, BiEncoderNllLoss, BiEncoderBatch +from dpr.options import add_encoder_params, add_training_params, setup_args_gpu, set_seed, print_args, \ + get_encoder_params_state, add_tokenizer_params, set_encoder_params_from_state +from dpr.utils.data_utils import ShardedDataIterator, read_data_from_json_files, Tensorizer +from dpr.utils.dist_utils import all_gather_list +from dpr.utils.model_utils import setup_for_distributed_mode, move_to_device, get_schedule_linear, CheckpointState, \ + get_model_file, get_model_obj, load_states_from_checkpoint + +logger = logging.getLogger() +logger.setLevel(logging.INFO) +if (logger.hasHandlers()): + logger.handlers.clear() +console = logging.StreamHandler() +logger.addHandler(console) + + +class BiEncoderTrainer(object): + """ + BiEncoder training pipeline component. Can be used to initiate or resume training and validate the trained model + using either binary classification's NLL loss or average rank of the question's gold passages across dataset + provided pools of negative passages. For full IR accuracy evaluation, please see generate_dense_embeddings.py + and dense_retriever.py CLI tools. + """ + + def __init__(self, args): + self.args = args + self.shard_id = args.local_rank if args.local_rank != -1 else 0 + self.distributed_factor = args.distributed_world_size or 1 + + logger.info("***** Initializing components for training *****") + + # if model file is specified, encoder parameters from saved state should be used for initialization + model_file = get_model_file(self.args, self.args.checkpoint_file_name) + saved_state = None + if model_file: + saved_state = load_states_from_checkpoint(model_file) + set_encoder_params_from_state(saved_state.encoder_params, args) + + tensorizer, model, optimizer = init_biencoder_components(args.encoder_model_type, args) + + model, optimizer = setup_for_distributed_mode(model, optimizer, args.device, args.n_gpu, + args.local_rank, + args.fp16, + args.fp16_opt_level) + self.biencoder = model + self.optimizer = optimizer + self.tensorizer = tensorizer + self.start_epoch = 0 + self.start_batch = 0 + self.scheduler_state = None + self.best_validation_result = None + self.best_cp_name = None + if saved_state: + self._load_saved_state(saved_state) + + def get_data_iterator(self, path: str, batch_size: int, shuffle=True, + shuffle_seed: int = 0, + offset: int = 0, upsample_rates: list = None) -> ShardedDataIterator: + data_files = glob.glob(path) + data = read_data_from_json_files(data_files, upsample_rates) + + # filter those without positive ctx + data = [r for r in data if len(r['positive_ctxs']) > 0] + logger.info('Total cleaned data size: {}'.format(len(data))) + + return ShardedDataIterator(data, shard_id=self.shard_id, + num_shards=self.distributed_factor, + batch_size=batch_size, shuffle=shuffle, shuffle_seed=shuffle_seed, offset=offset, + strict_batch_size=True, # this is not really necessary, one can probably disable it + ) + + def run_train(self, ): + args = self.args + upsample_rates = None + if args.train_files_upsample_rates is not None: + upsample_rates = eval(args.train_files_upsample_rates) + + train_iterator = self.get_data_iterator(args.train_file, args.batch_size, + 
shuffle=True, + shuffle_seed=args.seed, offset=self.start_batch, + upsample_rates=upsample_rates) + + logger.info(" Total iterations per epoch=%d", train_iterator.max_iterations) + updates_per_epoch = train_iterator.max_iterations // args.gradient_accumulation_steps + total_updates = max(updates_per_epoch * (args.num_train_epochs - self.start_epoch - 1), 0) + \ + (train_iterator.max_iterations - self.start_batch) // args.gradient_accumulation_steps + logger.info(" Total updates=%d", total_updates) + warmup_steps = args.warmup_steps + scheduler = get_schedule_linear(self.optimizer, warmup_steps, total_updates) + + if self.scheduler_state: + logger.info("Loading scheduler state %s", self.scheduler_state) + scheduler.load_state_dict(self.scheduler_state) + + eval_step = math.ceil(updates_per_epoch / args.eval_per_epoch) + logger.info(" Eval step = %d", eval_step) + logger.info("***** Training *****") + + for epoch in range(self.start_epoch, int(args.num_train_epochs)): + logger.info("***** Epoch %d *****", epoch) + self._train_epoch(scheduler, epoch, eval_step, train_iterator) + + if args.local_rank in [-1, 0]: + logger.info('Training finished. Best validation checkpoint %s', self.best_cp_name) + + def validate_and_save(self, epoch: int, iteration: int, scheduler): + args = self.args + # for distributed mode, save checkpoint for only one process + save_cp = args.local_rank in [-1, 0] + + if epoch == args.val_av_rank_start_epoch: + self.best_validation_result = None + + if epoch >= args.val_av_rank_start_epoch: + validation_loss = self.validate_average_rank() + else: + validation_loss = self.validate_nll() + + if save_cp: + cp_name = self._save_checkpoint(scheduler, epoch, iteration) + logger.info('Saved checkpoint to %s', cp_name) + + if validation_loss < (self.best_validation_result or validation_loss + 1): + self.best_validation_result = validation_loss + self.best_cp_name = cp_name + logger.info('New Best validation checkpoint %s', cp_name) + + def validate_nll(self) -> float: + logger.info('NLL validation ...') + args = self.args + self.biencoder.eval() + data_iterator = self.get_data_iterator(args.dev_file, args.dev_batch_size, shuffle=False) + + total_loss = 0.0 + start_time = time.time() + total_correct_predictions = 0 + num_hard_negatives = args.hard_negatives + num_other_negatives = args.other_negatives + log_result_step = args.log_batch_step + batches = 0 + for i, samples_batch in enumerate(data_iterator.iterate_data()): + biencoder_input = BiEncoder.create_biencoder_input(samples_batch, self.tensorizer, + True, + num_hard_negatives, num_other_negatives, shuffle=False) + + loss, correct_cnt = _do_biencoder_fwd_pass(self.biencoder, biencoder_input, self.tensorizer, args) + total_loss += loss.item() + total_correct_predictions += correct_cnt + batches += 1 + if (i + 1) % log_result_step == 0: + logger.info('Eval step: %d , used_time=%f sec., loss=%f ', i, time.time() - start_time, loss.item()) + + total_loss = total_loss / batches + total_samples = batches * args.dev_batch_size * self.distributed_factor + correct_ratio = float(total_correct_predictions / total_samples) + logger.info('NLL Validation: loss = %f. correct prediction ratio %d/%d ~ %f', total_loss, + total_correct_predictions, + total_samples, + correct_ratio + ) + return total_loss + + def validate_average_rank(self) -> float: + """ + Validates biencoder model using each question's gold passage's rank across the set of passages from the dataset. 
+ It generates vectors for specified amount of negative passages from each question (see --val_av_rank_xxx params) + and stores them in RAM as well as question vectors. + Then the similarity scores are calculted for the entire + num_questions x (num_questions x num_passages_per_question) matrix and sorted per quesrtion. + Each question's gold passage rank in that sorted list of scores is averaged across all the questions. + :return: averaged rank number + """ + logger.info('Average rank validation ...') + + args = self.args + self.biencoder.eval() + distributed_factor = self.distributed_factor + + data_iterator = self.get_data_iterator(args.dev_file, args.dev_batch_size, shuffle=False) + + sub_batch_size = args.val_av_rank_bsz + sim_score_f = BiEncoderNllLoss.get_similarity_function() + q_represenations = [] + ctx_represenations = [] + positive_idx_per_question = [] + + num_hard_negatives = args.val_av_rank_hard_neg + num_other_negatives = args.val_av_rank_other_neg + + log_result_step = args.log_batch_step + + for i, samples_batch in enumerate(data_iterator.iterate_data()): + # samples += 1 + if len(q_represenations) > args.val_av_rank_max_qs / distributed_factor: + break + + biencoder_input = BiEncoder.create_biencoder_input(samples_batch, self.tensorizer, + True, + num_hard_negatives, num_other_negatives, shuffle=False) + total_ctxs = len(ctx_represenations) + ctxs_ids = biencoder_input.context_ids + ctxs_segments = biencoder_input.ctx_segments + bsz = ctxs_ids.size(0) + + # split contexts batch into sub batches since it is supposed to be too large to be processed in one batch + for j, batch_start in enumerate(range(0, bsz, sub_batch_size)): + + q_ids, q_segments = (biencoder_input.question_ids, biencoder_input.question_segments) if j == 0 \ + else (None, None) + + if j == 0 and args.n_gpu > 1 and q_ids.size(0) == 1: + # if we are in DP (but not in DDP) mode, all model input tensors should have batch size >1 or 0, + # otherwise the other input tensors will be split but only the first split will be called + continue + + ctx_ids_batch = ctxs_ids[batch_start:batch_start + sub_batch_size] + ctx_seg_batch = ctxs_segments[batch_start:batch_start + sub_batch_size] + + q_attn_mask = self.tensorizer.get_attn_mask(q_ids) + ctx_attn_mask = self.tensorizer.get_attn_mask(ctx_ids_batch) + with torch.no_grad(): + q_dense, ctx_dense = self.biencoder(q_ids, q_segments, q_attn_mask, ctx_ids_batch, ctx_seg_batch, + ctx_attn_mask) + + if q_dense is not None: + q_represenations.extend(q_dense.cpu().split(1, dim=0)) + + ctx_represenations.extend(ctx_dense.cpu().split(1, dim=0)) + + batch_positive_idxs = biencoder_input.is_positive + positive_idx_per_question.extend([total_ctxs + v for v in batch_positive_idxs]) + + if (i + 1) % log_result_step == 0: + logger.info('Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d', i, + len(ctx_represenations), len(q_represenations)) + + ctx_represenations = torch.cat(ctx_represenations, dim=0) + q_represenations = torch.cat(q_represenations, dim=0) + + logger.info('Av.rank validation: total q_vectors size=%s', q_represenations.size()) + logger.info('Av.rank validation: total ctx_vectors size=%s', ctx_represenations.size()) + + q_num = q_represenations.size(0) + assert q_num == len(positive_idx_per_question) + + scores = sim_score_f(q_represenations, ctx_represenations) + values, indices = torch.sort(scores, dim=1, descending=True) + + rank = 0 + for i, idx in enumerate(positive_idx_per_question): + # aggregate the rank of the known gold passage in the 
sorted results for each question + gold_idx = (indices[i] == idx).nonzero() + rank += gold_idx.item() + + if distributed_factor > 1: + # each node calcuated its own rank, exchange the information between node and calculate the "global" average rank + # NOTE: the set of passages is still unique for every node + eval_stats = all_gather_list([rank, q_num], max_size=100) + for i, item in enumerate(eval_stats): + remote_rank, remote_q_num = item + if i != args.local_rank: + rank += remote_rank + q_num += remote_q_num + + av_rank = float(rank / q_num) + logger.info('Av.rank validation: average rank %s, total questions=%d', av_rank, q_num) + return av_rank + + def _train_epoch(self, scheduler, epoch: int, eval_step: int, + train_data_iterator: ShardedDataIterator, ): + + args = self.args + rolling_train_loss = 0.0 + epoch_loss = 0 + epoch_correct_predictions = 0 + + log_result_step = args.log_batch_step + rolling_loss_step = args.train_rolling_loss_step + num_hard_negatives = args.hard_negatives + num_other_negatives = args.other_negatives + seed = args.seed + self.biencoder.train() + epoch_batches = train_data_iterator.max_iterations + data_iteration = 0 + for i, samples_batch in enumerate(train_data_iterator.iterate_data(epoch=epoch)): + + # to be able to resume shuffled ctx- pools + data_iteration = train_data_iterator.get_iteration() + random.seed(seed + epoch + data_iteration) + biencoder_batch = BiEncoder.create_biencoder_input(samples_batch, self.tensorizer, + True, + num_hard_negatives, num_other_negatives, shuffle=True, + shuffle_positives=args.shuffle_positive_ctx + ) + + loss, correct_cnt = _do_biencoder_fwd_pass(self.biencoder, biencoder_batch, self.tensorizer, args) + + epoch_correct_predictions += correct_cnt + epoch_loss += loss.item() + rolling_train_loss += loss.item() + + if args.fp16: + from apex import amp + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + if args.max_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), args.max_grad_norm) + else: + loss.backward() + if args.max_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(self.biencoder.parameters(), args.max_grad_norm) + + if (i + 1) % args.gradient_accumulation_steps == 0: + self.optimizer.step() + scheduler.step() + self.biencoder.zero_grad() + + if i % log_result_step == 0: + lr = self.optimizer.param_groups[0]['lr'] + logger.info( + 'Epoch: %d: Step: %d/%d, loss=%f, lr=%f', epoch, data_iteration, epoch_batches, loss.item(), lr) + + if (i + 1) % rolling_loss_step == 0: + logger.info('Train batch %d', data_iteration) + latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step + logger.info('Avg. loss per last %d batches: %f', rolling_loss_step, latest_rolling_train_av_loss) + rolling_train_loss = 0.0 + + if data_iteration % eval_step == 0: + logger.info('Validation: Epoch: %d Step: %d/%d', epoch, data_iteration, epoch_batches) + self.validate_and_save(epoch, train_data_iterator.get_iteration(), scheduler) + self.biencoder.train() + + self.validate_and_save(epoch, data_iteration, scheduler) + + epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0 + logger.info('Av Loss per epoch=%f', epoch_loss) + logger.info('epoch total correct predictions=%d', epoch_correct_predictions) + + def _save_checkpoint(self, scheduler, epoch: int, offset: int) -> str: + args = self.args + model_to_save = get_model_obj(self.biencoder) + cp = os.path.join(args.output_dir, + args.checkpoint_file_name + '.' + str(epoch) + ('.' 
+ str(offset) if offset > 0 else '')) + + meta_params = get_encoder_params_state(args) + + state = CheckpointState(model_to_save.state_dict(), + self.optimizer.state_dict(), + scheduler.state_dict(), + offset, + epoch, meta_params + ) + torch.save(state._asdict(), cp) + logger.info('Saved checkpoint at %s', cp) + return cp + + def _load_saved_state(self, saved_state: CheckpointState): + epoch = saved_state.epoch + offset = saved_state.offset + if offset == 0: # epoch has been completed + epoch += 1 + logger.info('Loading checkpoint @ batch=%s and epoch=%s', offset, epoch) + + self.start_epoch = epoch + self.start_batch = offset + + model_to_load = get_model_obj(self.biencoder) + logger.info('Loading saved model state ...') + model_to_load.load_state_dict(saved_state.model_dict) # set strict=False if you use extra projection + + if saved_state.optimizer_dict: + logger.info('Loading saved optimizer state ...') + self.optimizer.load_state_dict(saved_state.optimizer_dict) + + if saved_state.scheduler_dict: + self.scheduler_state = saved_state.scheduler_dict + + +def _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors, local_positive_idxs, + local_hard_negatives_idxs: list = None, + ) -> Tuple[T, bool]: + """ + Calculates In-batch negatives schema loss and supports to run it in DDP mode by exchanging the representations + across all the nodes. + """ + distributed_world_size = args.distributed_world_size or 1 + if distributed_world_size > 1: + q_vector_to_send = torch.empty_like(local_q_vector).cpu().copy_(local_q_vector).detach_() + ctx_vector_to_send = torch.empty_like(local_ctx_vectors).cpu().copy_(local_ctx_vectors).detach_() + + global_question_ctx_vectors = all_gather_list( + [q_vector_to_send, ctx_vector_to_send, local_positive_idxs, local_hard_negatives_idxs], + max_size=args.global_loss_buf_sz) + + global_q_vector = [] + global_ctxs_vector = [] + + # ctxs_per_question = local_ctx_vectors.size(0) + positive_idx_per_question = [] + hard_negatives_per_question = [] + + total_ctxs = 0 + + for i, item in enumerate(global_question_ctx_vectors): + q_vector, ctx_vectors, positive_idx, hard_negatives_idxs = item + + if i != args.local_rank: + global_q_vector.append(q_vector.to(local_q_vector.device)) + global_ctxs_vector.append(ctx_vectors.to(local_q_vector.device)) + positive_idx_per_question.extend([v + total_ctxs for v in positive_idx]) + hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in hard_negatives_idxs]) + else: + global_q_vector.append(local_q_vector) + global_ctxs_vector.append(local_ctx_vectors) + positive_idx_per_question.extend([v + total_ctxs for v in local_positive_idxs]) + hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in local_hard_negatives_idxs]) + total_ctxs += ctx_vectors.size(0) + + global_q_vector = torch.cat(global_q_vector, dim=0) + global_ctxs_vector = torch.cat(global_ctxs_vector, dim=0) + + else: + global_q_vector = local_q_vector + global_ctxs_vector = local_ctx_vectors + positive_idx_per_question = local_positive_idxs + hard_negatives_per_question = local_hard_negatives_idxs + + loss, is_correct = loss_function.calc(global_q_vector, global_ctxs_vector, positive_idx_per_question, + hard_negatives_per_question) + + return loss, is_correct + + +def _do_biencoder_fwd_pass(model: nn.Module, input: BiEncoderBatch, tensorizer: Tensorizer, args) -> ( + torch.Tensor, int): + input = BiEncoderBatch(**move_to_device(input._asdict(), args.device)) + + q_attn_mask = tensorizer.get_attn_mask(input.question_ids) + 
ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids) + + if model.training: + model_out = model(input.question_ids, input.question_segments, q_attn_mask, input.context_ids, + input.ctx_segments, ctx_attn_mask) + else: + with torch.no_grad(): + model_out = model(input.question_ids, input.question_segments, q_attn_mask, input.context_ids, + input.ctx_segments, ctx_attn_mask) + + local_q_vector, local_ctx_vectors = model_out + + loss_function = BiEncoderNllLoss() + + loss, is_correct = _calc_loss(args, loss_function, local_q_vector, local_ctx_vectors, input.is_positive, + input.hard_negatives) + + is_correct = is_correct.sum().item() + + if args.n_gpu > 1: + loss = loss.mean() + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + return loss, is_correct + + +def main(): + parser = argparse.ArgumentParser() + + add_encoder_params(parser) + add_training_params(parser) + add_tokenizer_params(parser) + + # biencoder specific training features + parser.add_argument("--eval_per_epoch", default=1, type=int, + help="How many times it evaluates on dev set per epoch and saves a checkpoint") + + parser.add_argument("--global_loss_buf_sz", type=int, default=150000, + help='Buffer size for distributed mode representations al gather operation. \ + Increase this if you see errors like "encoded data exceeds max_size ..."') + + parser.add_argument("--fix_ctx_encoder", action='store_true') + parser.add_argument("--shuffle_positive_ctx", action='store_true') + + # input/output src params + parser.add_argument("--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be written or resumed from") + + # data handling parameters + parser.add_argument("--hard_negatives", default=1, type=int, + help="amount of hard negative ctx per question") + parser.add_argument("--other_negatives", default=0, type=int, + help="amount of 'other' negative ctx per question") + parser.add_argument("--train_files_upsample_rates", type=str, + help="list of up-sample rates per each train file. Example: [1,2,1]") + + # parameters for Av.rank validation method + parser.add_argument("--val_av_rank_start_epoch", type=int, default=10000, + help="Av.rank validation: the epoch from which to enable this validation") + parser.add_argument("--val_av_rank_hard_neg", type=int, default=30, + help="Av.rank validation: how many hard negatives to take from each question pool") + parser.add_argument("--val_av_rank_other_neg", type=int, default=30, + help="Av.rank validation: how many 'other' negatives to take from each question pool") + parser.add_argument("--val_av_rank_bsz", type=int, default=128, + help="Av.rank validation: batch size to process passages") + parser.add_argument("--val_av_rank_max_qs", type=int, default=10000, + help="Av.rank validation: max num of questions") + parser.add_argument('--checkpoint_file_name', type=str, default='dpr_biencoder', help="Checkpoints file prefix") + + args = parser.parse_args() + + if args.gradient_accumulation_steps < 1: + raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( + args.gradient_accumulation_steps)) + + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + setup_args_gpu(args) + set_seed(args) + print_args(args) + + trainer = BiEncoderTrainer(args) + + if args.train_file is not None: + trainer.run_train() + elif args.model_file and args.dev_file: + logger.info("No train files are specified. 
Run 2 types of validation for specified model file") + trainer.validate_nll() + trainer.validate_average_rank() + else: + logger.warning("Neither train_file or (model_file & dev_file) parameters are specified. Nothing to do.") + + +if __name__ == "__main__": + main() diff --git a/train_reader.py b/train_reader.py new file mode 100644 index 00000000..36389410 --- /dev/null +++ b/train_reader.py @@ -0,0 +1,507 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +""" + Pipeline to train the reader model on top of the retriever results +""" + +import argparse +import collections +import glob +import json +import logging +import os +from collections import defaultdict +from typing import List + +import numpy as np +import torch + +from dpr.data.qa_validation import exact_match_score +from dpr.data.reader_data import ReaderSample, get_best_spans, SpanPrediction, convert_retriever_results +from dpr.models import init_reader_components +from dpr.models.reader import create_reader_input, ReaderBatch, compute_loss +from dpr.options import add_encoder_params, setup_args_gpu, set_seed, add_training_params, \ + add_reader_preprocessing_params, set_encoder_params_from_state, get_encoder_params_state, add_tokenizer_params, \ + print_args +from dpr.utils.data_utils import ShardedDataIterator, read_serialized_data_from_files, Tensorizer +from dpr.utils.model_utils import get_schedule_linear, load_states_from_checkpoint, move_to_device, CheckpointState, \ + get_model_file, setup_for_distributed_mode, get_model_obj + +logger = logging.getLogger() +logger.setLevel(logging.INFO) +if (logger.hasHandlers()): + logger.handlers.clear() +console = logging.StreamHandler() +logger.addHandler(console) + +ReaderQuestionPredictions = collections.namedtuple('ReaderQuestionPredictions', ['id', 'predictions', 'gold_answers']) + + +class ReaderTrainer(object): + def __init__(self, args): + self.args = args + + self.shard_id = args.local_rank if args.local_rank != -1 else 0 + self.distributed_factor = args.distributed_world_size or 1 + + logger.info("***** Initializing components for training *****") + + model_file = get_model_file(self.args, self.args.checkpoint_file_name) + saved_state = None + if model_file: + saved_state = load_states_from_checkpoint(model_file) + set_encoder_params_from_state(saved_state.encoder_params, args) + + tensorizer, reader, optimizer = init_reader_components(args.encoder_model_type, args) + + reader, optimizer = setup_for_distributed_mode(reader, optimizer, args.device, args.n_gpu, + args.local_rank, + args.fp16, + args.fp16_opt_level) + self.reader = reader + self.optimizer = optimizer + self.tensorizer = tensorizer + self.start_epoch = 0 + self.start_batch = 0 + self.scheduler_state = None + self.best_validation_result = None + self.best_cp_name = None + if saved_state: + self._load_saved_state(saved_state) + + def get_data_iterator(self, path: str, batch_size: int, is_train: bool, shuffle=True, + shuffle_seed: int = 0, + offset: int = 0) -> ShardedDataIterator: + data_files = glob.glob(path) + logger.info("Data files: %s", data_files) + if not data_files: + raise RuntimeError('No Data files found') + preprocessed_data_files = self._get_preprocessed_files(data_files, is_train) + data = read_serialized_data_from_files(preprocessed_data_files) + + iterator = ShardedDataIterator(data, shard_id=self.shard_id, + 
num_shards=self.distributed_factor, + batch_size=batch_size, shuffle=shuffle, shuffle_seed=shuffle_seed, offset=offset) + + # apply deserialization hook + iterator.apply(lambda sample: sample.on_deserialize()) + return iterator + + def run_train(self): + args = self.args + + train_iterator = self.get_data_iterator(args.train_file, args.batch_size, + True, + shuffle=True, + shuffle_seed=args.seed, offset=self.start_batch) + + num_train_epochs = args.num_train_epochs - self.start_epoch + + logger.info("Total iterations per epoch=%d", train_iterator.max_iterations) + updates_per_epoch = train_iterator.max_iterations // args.gradient_accumulation_steps + total_updates = updates_per_epoch * num_train_epochs - self.start_batch + logger.info(" Total updates=%d", total_updates) + + warmup_steps = args.warmup_steps + scheduler = get_schedule_linear(self.optimizer, warmup_steps=warmup_steps, + training_steps=total_updates) + if self.scheduler_state: + logger.info("Loading scheduler state %s", self.scheduler_state) + scheduler.load_state_dict(self.scheduler_state) + + eval_step = args.eval_step + logger.info(" Eval step = %d", eval_step) + logger.info("***** Training *****") + + global_step = self.start_epoch * updates_per_epoch + self.start_batch + + for epoch in range(self.start_epoch, int(args.num_train_epochs)): + logger.info("***** Epoch %d *****", epoch) + global_step = self._train_epoch(scheduler, epoch, eval_step, train_iterator, global_step) + + if args.local_rank in [-1, 0]: + logger.info('Training finished. Best validation checkpoint %s', self.best_cp_name) + + return + + def validate_and_save(self, epoch: int, iteration: int, scheduler): + args = self.args + # in distributed DDP mode, save checkpoint for only one process + save_cp = args.local_rank in [-1, 0] + reader_validation_score = self.validate() + + if save_cp: + cp_name = self._save_checkpoint(scheduler, epoch, iteration) + logger.info('Saved checkpoint to %s', cp_name) + + if reader_validation_score < (self.best_validation_result or 0): + self.best_validation_result = reader_validation_score + self.best_cp_name = cp_name + logger.info('New Best validation checkpoint %s', cp_name) + + def validate(self): + logger.info('Validation ...') + args = self.args + self.reader.eval() + data_iterator = self.get_data_iterator(args.dev_file, args.dev_batch_size, False, shuffle=False) + + log_result_step = args.log_batch_step + all_results = [] + + eval_top_docs = args.eval_top_docs + for i, samples_batch in enumerate(data_iterator.iterate_data()): + input = create_reader_input(self.tensorizer.get_pad_id(), + samples_batch, + args.passages_per_question_predict, + args.sequence_length, + args.max_n_answers, + is_train=False, shuffle=False) + + input = ReaderBatch(**move_to_device(input._asdict(), args.device)) + attn_mask = self.tensorizer.get_attn_mask(input.input_ids) + + with torch.no_grad(): + start_logits, end_logits, relevance_logits = self.reader(input.input_ids, attn_mask) + + batch_predictions = self._get_best_prediction(start_logits, end_logits, relevance_logits, samples_batch, + passage_thresholds=eval_top_docs) + + all_results.extend(batch_predictions) + + if (i + 1) % log_result_step == 0: + logger.info('Eval step: %d ', i) + + ems = defaultdict(list) + + for q_predictions in all_results: + gold_answers = q_predictions.gold_answers + span_predictions = q_predictions.predictions # {top docs threshold -> SpanPrediction()} + for (n, span_prediction) in span_predictions.items(): + em_hit = 
max([exact_match_score(span_prediction.prediction_text, ga) for ga in gold_answers]) + ems[n].append(em_hit) + em = 0 + for n in sorted(ems.keys()): + em = np.mean(ems[n]) + logger.info("n=%d\tEM %.2f" % (n, em * 100)) + + if args.prediction_results_file: + self._save_predictions(args.prediction_results_file, all_results) + + return em + + def _train_epoch(self, scheduler, epoch: int, eval_step: int, + train_data_iterator: ShardedDataIterator, global_step: int): + args = self.args + rolling_train_loss = 0.0 + epoch_loss = 0 + log_result_step = args.log_batch_step + rolling_loss_step = args.train_rolling_loss_step + + self.reader.train() + epoch_batches = train_data_iterator.max_iterations + + for i, samples_batch in enumerate(train_data_iterator.iterate_data(epoch=epoch)): + + data_iteration = train_data_iterator.get_iteration() + + # enables to resume to exactly same train state + if args.fully_resumable: + np.random.seed(args.seed + global_step) + torch.manual_seed(args.seed + global_step) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed + global_step) + + input = create_reader_input(self.tensorizer.get_pad_id(), + samples_batch, + args.passages_per_question, + args.sequence_length, + args.max_n_answers, + is_train=True, shuffle=True) + + loss = self._calc_loss(input) + + epoch_loss += loss.item() + rolling_train_loss += loss.item() + + if args.fp16: + from apex import amp + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + if args.max_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), args.max_grad_norm) + else: + loss.backward() + if args.max_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(self.reader.parameters(), args.max_grad_norm) + + global_step += 1 + + if (i + 1) % args.gradient_accumulation_steps == 0: + self.optimizer.step() + scheduler.step() + self.reader.zero_grad() + + if global_step % log_result_step == 0: + lr = self.optimizer.param_groups[0]['lr'] + logger.info( + 'Epoch: %d: Step: %d/%d, global_step=%d, lr=%f', epoch, data_iteration, epoch_batches, global_step, + lr) + + if (i + 1) % rolling_loss_step == 0: + logger.info('Train batch %d', data_iteration) + latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step + logger.info('Avg. loss per last %d batches: %f', rolling_loss_step, latest_rolling_train_av_loss) + rolling_train_loss = 0.0 + + if global_step % eval_step == 0: + logger.info('Validation: Epoch: %d Step: %d/%d', epoch, data_iteration, epoch_batches) + self.validate_and_save(epoch, train_data_iterator.get_iteration(), scheduler) + self.reader.train() + + epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0 + logger.info('Av Loss per epoch=%f', epoch_loss) + return global_step + + def _save_checkpoint(self, scheduler, epoch: int, offset: int) -> str: + args = self.args + model_to_save = get_model_obj(self.reader) + cp = os.path.join(args.output_dir, + args.checkpoint_file_name + '.' + str(epoch) + ('.' 
+ str(offset) if offset > 0 else '')) + + meta_params = get_encoder_params_state(args) + + state = CheckpointState(model_to_save.state_dict(), self.optimizer.state_dict(), scheduler.state_dict(), offset, + epoch, meta_params + ) + torch.save(state._asdict(), cp) + return cp + + def _load_saved_state(self, saved_state: CheckpointState): + epoch = saved_state.epoch + offset = saved_state.offset + if offset == 0: # epoch has been completed + epoch += 1 + logger.info('Loading checkpoint @ batch=%s and epoch=%s', offset, epoch) + self.start_epoch = epoch + self.start_batch = offset + + model_to_load = get_model_obj(self.reader) + if saved_state.model_dict: + logger.info('Loading model weights from saved state ...') + model_to_load.load_state_dict(saved_state.model_dict) + + logger.info('Loading saved optimizer state ...') + if saved_state.optimizer_dict: + self.optimizer.load_state_dict(saved_state.optimizer_dict) + self.scheduler_state = saved_state.scheduler_dict + + def _get_best_prediction(self, start_logits, end_logits, relevance_logits, + samples_batch: List[ReaderSample], passage_thresholds: List[int] = None) \ + -> List[ReaderQuestionPredictions]: + + args = self.args + max_answer_length = args.max_answer_length + questions_num, passages_per_question = relevance_logits.size() + + _, idxs = torch.sort(relevance_logits, dim=1, descending=True, ) + + batch_results = [] + for q in range(questions_num): + sample = samples_batch[q] + + non_empty_passages_num = len(sample.passages) + nbest = [] + for p in range(passages_per_question): + passage_idx = idxs[q, p].item() + if passage_idx >= non_empty_passages_num: # empty passage selected, skip + continue + reader_passage = sample.passages[passage_idx] + sequence_ids = reader_passage.sequence_ids + sequence_len = sequence_ids.size(0) + # assuming question & title information is at the beginning of the sequence + passage_offset = reader_passage.passage_offset + + p_start_logits = start_logits[q, passage_idx].tolist()[passage_offset:sequence_len] + p_end_logits = end_logits[q, passage_idx].tolist()[passage_offset:sequence_len] + + ctx_ids = sequence_ids.tolist()[passage_offset:] + best_spans = get_best_spans(self.tensorizer, p_start_logits, p_end_logits, ctx_ids, max_answer_length, + passage_idx, relevance_logits[q, passage_idx].item(), top_spans=10) + nbest.extend(best_spans) + if len(nbest) > 0 and not passage_thresholds: + break + + if passage_thresholds: + passage_rank_matches = {} + for n in passage_thresholds: + curr_nbest = [pred for pred in nbest if pred.passage_index < n] + passage_rank_matches[n] = curr_nbest[0] + predictions = passage_rank_matches + else: + if len(nbest) == 0: + predictions = {passages_per_question: SpanPrediction('', -1, -1, -1, '')} + else: + predictions = {passages_per_question: nbest[0]} + batch_results.append(ReaderQuestionPredictions(sample.question, predictions, sample.answers)) + return batch_results + + def _calc_loss(self, input: ReaderBatch) -> torch.Tensor: + args = self.args + input = ReaderBatch(**move_to_device(input._asdict(), args.device)) + attn_mask = self.tensorizer.get_attn_mask(input.input_ids) + questions_num, passages_per_question, _ = input.input_ids.size() + + if self.reader.training: + # start_logits, end_logits, rank_logits = self.reader(input.input_ids, attn_mask) + loss = self.reader(input.input_ids, attn_mask, input.start_positions, input.end_positions, + input.answers_mask) + + else: + # TODO: remove? 
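+ # inference path: compute span and relevance logits without gradients, then derive the loss explicitly via compute_loss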
+ with torch.no_grad(): + start_logits, end_logits, rank_logits = self.reader(input.input_ids, attn_mask) + + loss = compute_loss(input.start_positions, input.end_positions, input.answers_mask, start_logits, + end_logits, + rank_logits, + questions_num, passages_per_question) + if args.n_gpu > 1: + loss = loss.mean() + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + return loss + + def _get_preprocessed_files(self, data_files: List, is_train: bool, ): + + serialized_files = [file for file in data_files if file.endswith('.pkl')] + if serialized_files: + return serialized_files + assert len(data_files) == 1, 'Only 1 source file pre-processing is supported.' + + # data may have been serialized and cached before, try to find ones from same dir + def _find_cached_files(path: str): + dir_path, base_name = os.path.split(path) + base_name = base_name.replace('.json', '') + out_file_prefix = os.path.join(dir_path, base_name) + out_file_pattern = out_file_prefix + '*.pkl' + return glob.glob(out_file_pattern), out_file_prefix + + serialized_files, out_file_prefix = _find_cached_files(data_files[0]) + if serialized_files: + logger.info('Found preprocessed files. %s', serialized_files) + return serialized_files + + gold_passages_src = None + if self.args.gold_passages_src: + gold_passages_src = self.args.gold_passages_src if is_train else self.args.gold_passages_src_dev + assert os.path.exists(gold_passages_src), 'Please specify valid gold_passages_src/gold_passages_src_dev' + logger.info('Data are not preprocessed for reader training. Start pre-processing ...') + + # start pre-processing and save results + def _run_preprocessing(tensorizer: Tensorizer): + # temporarily disable auto-padding to save disk space usage of serialized files + tensorizer.set_pad_to_max(False) + serialized_files = convert_retriever_results(is_train, data_files[0], out_file_prefix, + gold_passages_src, + self.tensorizer, + num_workers=self.args.num_workers) + tensorizer.set_pad_to_max(True) + return serialized_files + + if self.distributed_factor > 1: + # only one node in DDP model will do pre-processing + if self.args.local_rank in [-1, 0]: + serialized_files = _run_preprocessing(self.tensorizer) + torch.distributed.barrier() + else: + torch.distributed.barrier() + serialized_files = _find_cached_files(data_files[0]) + else: + serialized_files = _run_preprocessing(self.tensorizer) + + return serialized_files + + def _save_predictions(self, out_file: str, prediction_results: List[ReaderQuestionPredictions]): + logger.info('Saving prediction results to %s', out_file) + with open(out_file, 'w', encoding="utf-8") as output: + save_results = [] + for r in prediction_results: + save_results.append({ + 'question': r.id, + 'gold_answers': r.gold_answers, + 'predictions': [{ + 'top_k': top_k, + 'prediction': { + 'text': span_pred.prediction_text, + 'score': span_pred.span_score, + 'relevance_score': span_pred.relevance_score, + 'passage_idx': span_pred.passage_index, + 'passage': self.tensorizer.to_string(span_pred.passage_token_ids) + } + } for top_k, span_pred in r.predictions.items()] + }) + output.write(json.dumps(save_results, indent=4) + "\n") + + +def main(): + parser = argparse.ArgumentParser() + + add_encoder_params(parser) + add_training_params(parser) + add_tokenizer_params(parser) + add_reader_preprocessing_params(parser) + + # reader specific params + parser.add_argument("--max_n_answers", default=10, type=int, + help="Max amount of answer spans to marginalize per singe 
passage") + parser.add_argument('--passages_per_question', type=int, default=2, + help="Total amount of positive and negative passages per question") + parser.add_argument('--passages_per_question_predict', type=int, default=50, + help="Total amount of positive and negative passages per question for evaluation") + parser.add_argument("--max_answer_length", default=10, type=int, + help="The maximum length of an answer that can be generated. This is needed because the start " + "and end predictions are not conditioned on one another.") + parser.add_argument('--eval_top_docs', nargs='+', type=int, + help="top retrival passages thresholds to analyze prediction results for") + parser.add_argument('--checkpoint_file_name', type=str, default='dpr_reader') + parser.add_argument('--prediction_results_file', type=str, help='path to a file to write prediction results to') + + # training parameters + parser.add_argument("--eval_step", default=2000, type=int, + help="batch steps to run validation and save checkpoint") + parser.add_argument("--output_dir", default=None, type=str, + help="The output directory where the model checkpoints will be written to") + + parser.add_argument('--fully_resumable', action='store_true', + help="Enables resumable mode by specifying global step dependent random seed before shuffling " + "in-batch data") + + args = parser.parse_args() + + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + setup_args_gpu(args) + set_seed(args) + print_args(args) + + trainer = ReaderTrainer(args) + + if args.train_file is not None: + trainer.run_train() + elif args.dev_file: + logger.info("No train files are specified. Run validation.") + trainer.validate() + else: + logger.warning("Neither train_file or (model_file & dev_file) parameters are specified. Nothing to do.") + + +if __name__ == "__main__": + main()