
Commit 1ee31c6

code formatting + misc

vlad-karpukhin committed Aug 31, 2021
2 parents 6b7e36d + 49e5838
Showing 11 changed files with 978 additions and 138 deletions.
306 changes: 173 additions & 133 deletions README.md

Large diffs are not rendered by default.

65 changes: 65 additions & 0 deletions conf/README.md
@@ -0,0 +1,65 @@
## Hydra

[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python
framework that simplifies the development of research and other complex
applications. The key feature is the ability to dynamically create a
hierarchical configuration by composition and override it through config files
and the command line.

## DPR configuration
All DPR tools' configuration parameters are now split into config groups; you can either modify them in the config files or override them from the command line.

Each tool's main method (train_dense_encoder.py, generate_dense_embeddings.py, dense_retriever.py and train_reader.py) now has a @hydra.main decorator that names the tool's configuration file in the conf/ directory.
For example, dense_retriever.py takes all its parameters from the conf/dense_retriever.yaml file.
Every tool's configuration file refers to other configuration files via the "defaults:" parameter; this is called a [configuration group](https://hydra.cc/docs/tutorials/structured_config/config_groups) in Hydra.
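
As a minimal sketch, a tool's entry point then looks roughly like this (the decorator arguments are assumed to follow the repo layout described above; the actual files contain far more logic):

```python
# Sketch of a DPR tool entry point. config_name is assumed to match the
# tool's yaml file under conf/, e.g. conf/dense_retriever.yaml.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="dense_retriever")
def main(cfg: DictConfig):
    ...  # the tool reads every parameter from cfg


if __name__ == "__main__":
    main()
```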

Let's take a look at dense_retriever.py's configuration:


```yaml
defaults:
- encoder: hf_bert
- datasets: retriever_default
- ctx_sources: default_sources

indexers:
flat:
_target_: dpr.indexer.faiss_indexers.DenseFlatIndexer

hnsw:
_target_: dpr.indexer.faiss_indexers.DenseHNSWFlatIndexer

hnsw_sq:
_target_: dpr.indexer.faiss_indexers.DenseHNSWSQIndexer

...
qa_dataset:
...
ctx_datatsets:
...
indexer: flat
...

```

" - encoder: " - a configuration group that contains all parameters to instantiate the encoder. The actual parameters are located in conf/encoder/hf_bert.yaml file.
If you want to override some of them, you can either
- Modify that config file
- Create a new config group file under conf/encoder/ folder and enable to use it by providing encoder={your file name} command line argument
- Override specific parameter from command line. For example: encoder.sequence_length=300
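
For the second option, a new config group file might look like this sketch (the file name and values are hypothetical; parameter names are assumed to mirror conf/encoder/hf_bert.yaml and should be checked there):

```yaml
# @package _group_
# conf/encoder/my_bert.yaml (hypothetical) - enable with encoder=my_bert.
# Parameter names are assumed to mirror conf/encoder/hf_bert.yaml.
encoder_model_type: hf_bert
pretrained_model_cfg: bert-large-uncased
sequence_length: 300
dropout: 0.1
```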

" - datasets:" - a configuration group that contains a list of all possible sources of queries for evaluation. One can find them in conf/datasets/retriever_default.yaml file.
One should specify the dataset to use by providing qa_dataset parameter in order to use one of them during evaluation. For example, if you want to run the retriever on NQ test set, set qa_dataset=nq_test as a command line parameter.
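
Registering a new query source works the same way; here is a hedged sketch (the entry name and file path are hypothetical, while the _target_ class is one already used in conf/datasets/retriever_default.yaml):

```yaml
# Hypothetical entry for conf/datasets/retriever_default.yaml;
# select it at run time with qa_dataset=my_test.
my_test:
  _target_: dpr.data.retriever_data.CsvQASrc
  file: /path/to/my-test.qa.csv  # question/answers CSV in DPR's qas format
```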

Custom datasets are now much easier to use: there is no need to convert them to the DPR format. Just define your own class that inherits from QASrc and provides the relevant __getitem__(), __len__() and load_data() methods, as in the sketch below.
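
A minimal sketch of such a class, assuming QASrc and the QASample tuple live in dpr.data.retriever_data and that QASample holds (query, id, answers) - verify the actual interfaces in the repo:

```python
# Hypothetical custom query source reading "question<TAB>answer" lines.
# The QASrc/QASample interfaces are assumed from dpr/data/retriever_data.py.
from dpr.data.retriever_data import QASample, QASrc


class TsvQASrc(QASrc):
    def load_data(self):
        # self.file is assumed to be a local path here; the built-in sources
        # first resolve resource keys such as data.retriever.qas.nq-test.
        data = []
        with open(self.file, encoding="utf-8") as f:
            for line in f:
                question, answer = line.rstrip("\n").split("\t")
                data.append(QASample(question, None, [answer]))
        self.data = data

    def __getitem__(self, index) -> QASample:
        return self.data[index]

    def __len__(self):
        return len(self.data)
```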

" - ctx_sources: " - a configuration group that contains a list of all possible passage sources. One can find them in conf/ctx_sources/default_sources.yaml file.
One should specify a list of names of the passages datasets as ctx_datatsets parameter. For example, if you want to use dpr's old wikipedia passages, set ctx_datatsets=[dpr_wiki].
Please note that this parameter is a list and you can effectively concatenate different passage source into one. In order to use multiple sources at once, one also needs to provide relevant embeddings files in encoded_ctx_files parameter, which is also a list.
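
Combining two passage sources with their matching embedding files might then look like this illustrative command (my_passages and the paths are hypothetical):

```bash
python dense_retriever.py \
  qa_dataset=nq_test \
  ctx_datatsets=[dpr_wiki,my_passages] \
  'encoded_ctx_files=[/embeddings/wiki_passages_*,/embeddings/my_passages_*]' \
  out_file=retriever_results.json
```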


"indexers:" - a parameters map that defines various indexes. The actual index is selected by indexer parameter which is 'flat' by default but you can use loss index types by setting indexer=hnsw or indexer=hnsw_sq in the command line.

Please refer to the comments in the configuration files for details on every parameter.
17 changes: 17 additions & 0 deletions conf/datasets/encoder_train_default.yaml
@@ -27,3 +27,20 @@ squad1_train:
squad1_dev:
_target_: dpr.data.biencoder_data.JsonQADataset
file: data.retriever.squad1-dev

webq_train:
_target_: dpr.data.biencoder_data.JsonQADataset
file: data.retriever.webq-train

webq_dev:
_target_: dpr.data.biencoder_data.JsonQADataset
file: data.retriever.webq-dev

curatedtrec_train:
_target_: dpr.data.biencoder_data.JsonQADataset
file: data.retriever.curatedtrec-train

curatedtrec_dev:
_target_: dpr.data.biencoder_data.JsonQADataset
file: data.retriever.curatedtrec-dev

8 changes: 8 additions & 0 deletions conf/datasets/retriever_default.yaml
@@ -23,3 +23,11 @@ trivia_train:
trivia_dev:
_target_: dpr.data.retriever_data.CsvQASrc
file: data.retriever.qas.trivia-dev

webq_test:
_target_: dpr.data.retriever_data.CsvQASrc
file: data.retriever.qas.webq-test

curatedtrec_test:
_target_: dpr.data.retriever_data.CsvQASrc
file: data.retriever.qas.curatedtrec-test
3 changes: 2 additions & 1 deletion dense_retriever.py
@@ -297,7 +297,8 @@ def main(cfg: DictConfig):
question_encoder_state = {
key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if key.startswith(encoder_prefix)
}
- model_to_load.load_state_dict(question_encoder_state)
+ # TODO: long term HF state compatibility fix
+ model_to_load.load_state_dict(question_encoder_state, strict=False)
vector_size = model_to_load.get_out_size()
logger.info("Encoder vector_size=%d", vector_size)

46 changes: 46 additions & 0 deletions dpr/data/download_data.py
@@ -79,6 +79,30 @@
"compressed": True,
"desc": "SQUAD 1.1 dev subset with passages pools for the Retriever train time validation",
},
"data.retriever.webq-train": {
"s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-webquestions-train.json.gz",
"original_ext": ".json",
"compressed": True,
"desc": "WebQuestions dev subset with passages pools for the Retriever train time validation",
},
"data.retriever.webq-dev": {
"s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-webquestions-dev.json.gz",
"original_ext": ".json",
"compressed": True,
"desc": "WebQuestions dev subset with passages pools for the Retriever train time validation",
},
"data.retriever.curatedtrec-train": {
"s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-curatedtrec-train.json.gz",
"original_ext": ".json",
"compressed": True,
"desc": "CuratedTrec dev subset with passages pools for the Retriever train time validation",
},
"data.retriever.curatedtrec-dev": {
"s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-curatedtrec-dev.json.gz",
"original_ext": ".json",
"compressed": True,
"desc": "CuratedTrec dev subset with passages pools for the Retriever train time validation",
},
"data.retriever.qas.nq-dev": {
"s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-dev.qa.csv",
"original_ext": ".csv",
@@ -125,6 +149,18 @@
"compressed": False,
"desc": "Trivia test subset for Retriever validation and IR results generation",
},
"data.retriever.qas.webq-test": {
"s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/webquestions-test.qa.csv",
"original_ext": ".csv",
"compressed": False,
"desc": "WebQuestions test subset for Retriever validation and IR results generation",
},
"data.retriever.qas.curatedtrec-test": {
"s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever/curatedtrec-test.qa.csv",
"original_ext": ".csv",
"compressed": False,
"desc": "CuratedTrec test subset for Retriever validation and IR results generation",
},
"data.gold_passages_info.nq_train": {
"s3_url": "https://dl.fbaipublicfiles.com/dpr/data/nq_gold_info/nq-train_gold_info.json.gz",
"original_ext": ".json",
@@ -175,6 +211,16 @@
"desc": "Encoded wikipedia files using a biencoder checkpoint("
"checkpoint.retriever.single.nq.bert-base-encoder) trained on NQ dataset ",
},
"data.retriever_results.nq.single-adv-hn.wikipedia_passages": {
"s3_url": [
"https://dl.fbaipublicfiles.com/dpr/data/wiki_encoded/single-adv-hn/nq/wiki_passages_{}".format(i)
for i in range(50)
],
"original_ext": ".pkl",
"compressed": False,
"desc": "Encoded wikipedia files using a biencoder checkpoint("
"checkpoint.retriever.single-adv-hn.nq.bert-base-encoder) trained on NQ dataset + adversarial hard negatives",
},
"data.retriever_results.nq.single.test": {
"s3_url": "https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-test.json.gz",
"original_ext": ".json",
