
Commit 3914574

Passage retrieval compression (#297)
* adding IR elastic stuff
* adding data download and modified ES dense ranking
* adding doc2query
* adding DPR code
* updating doc2query code
* adding MS MARCO eval script
* making dataset HF compatible
* making dataset HF compatible
* running doc2query T5
* model running
* working on integrating
* done with YAML recipe for all prunable layers
* fixing config spacing for pruning YAML
* work on dataset making
* updated the download data script and model training
* running doc2query but missing the work for pruning
* fixing issues in pruning
* moving around DPR
* added optimal lobotomizing project
* adding to readme for baseline
* new structures
* cleaning up structure and pushing baseline numbers
* moving sparse_ml_utils.py to src

Co-authored-by: Mark Kurtz <mark@neuralmagic.com>
1 parent 942ed6e commit 3914574


70 files changed: 12705 additions & 0 deletions
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
# Compressing DPR

Author: @spacemanidol

## Methods

1. Varying models
2. Structured Pruning
3. Unstructured Pruning
4. Dimensionality Reduction

## Usage

batch_size: 4
dev_batch_size: 16
adam_eps: 1e-8
adam_betas: (0.9, 0.999)
max_grad_norm: 2.0
log_batch_step: 1
train_rolling_loss_step: 100
weight_decay: 0.0
learning_rate: 2e-5
# Linear warmup over warmup_steps.
warmup_steps: 1237

# Number of update steps to accumulate before performing a backward/update pass.
gradient_accumulation_steps: 1

# Total number of training epochs to perform.
num_train_epochs: 40
eval_per_epoch: 1
hard_negatives: 1
other_negatives: 0
val_av_rank_hard_neg: 30
val_av_rank_other_neg: 30
val_av_rank_bsz: 128
val_av_rank_max_qs: 10000
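These training hyperparameters map onto a fairly standard PyTorch/HuggingFace optimizer and scheduler setup. The sketch below only illustrates that mapping and is not DPR's actual training loop; total_steps is a placeholder that depends on the dataset size.

```python
# Sketch: wiring the Adam/warmup/clipping values above into an optimizer and
# scheduler; illustrative only.
import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(768, 768)  # stand-in for the bi-encoder

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,              # learning_rate
    eps=1e-8,             # adam_eps
    betas=(0.9, 0.999),   # adam_betas
    weight_decay=0.0,     # weight_decay
)

total_steps = 10_000  # placeholder: num_train_epochs * batches_per_epoch
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=1237, num_training_steps=total_steps
)

# max_grad_norm: 2.0 is applied before each optimizer step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
```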
Data downloads:

https://www.dropbox.com/s/lvvpsx0cjk4vemv/collection.tar.gz?dl=1
https://www.dropbox.com/s/hq6xjhswiz60siu/queries.dev.small.tsv?dl=1
https://www.dropbox.com/s/khsplt2fhqwjs0v/qrels.dev.small.tsv?dl=1
https://www.dropbox.com/s/uzkvv4gpj3a596a/predicted_queries_topk_sampling.zip?dl=1
https://www.dropbox.com/s/nc1drdkjpxxsngg/run.dev.small.tsv?dl=1
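A minimal way to fetch these files, assuming nothing beyond the Python standard library (the repository's own download script may differ):

```python
# Sketch: download the files listed above into a local data/ directory.
import urllib.request
from pathlib import Path

URLS = [
    "https://www.dropbox.com/s/lvvpsx0cjk4vemv/collection.tar.gz?dl=1",
    "https://www.dropbox.com/s/hq6xjhswiz60siu/queries.dev.small.tsv?dl=1",
    "https://www.dropbox.com/s/khsplt2fhqwjs0v/qrels.dev.small.tsv?dl=1",
    "https://www.dropbox.com/s/uzkvv4gpj3a596a/predicted_queries_topk_sampling.zip?dl=1",
    "https://www.dropbox.com/s/nc1drdkjpxxsngg/run.dev.small.tsv?dl=1",
]

out_dir = Path("data")
out_dir.mkdir(exist_ok=True)
for url in URLS:
    name = url.rsplit("/", 1)[-1].split("?")[0]  # strip the ?dl=1 suffix
    print(f"downloading {name} ...")
    urllib.request.urlretrieve(url, out_dir / name)
```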
## Results

Top-k retrieval accuracy (%):

| Top-k passages | Original DPR NQ model | New DPR model |
| ------------- |:-------------:| -----:|
| 1 | 45.87 | 52.47 |
| 5 | 68.14 | 72.24 |
| 20 | 79.97 | 81.33 |
| 100 | 85.87 | 87.29 |

### requirements.txt
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
## Hydra

[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python framework that simplifies the development of research and other complex applications. Its key feature is the ability to dynamically create a hierarchical configuration by composition and to override it through config files and the command line.

## DPR configuration

All DPR tools' configuration parameters are now split between different config groups; you can either modify them in the config files or override them from the command line.

Each tool's main method (train_dense_encoder.py, generate_dense_embeddings.py, dense_retriever.py and train_reader.py) now has a @hydra.main decorator that names its configuration file in the conf/ dir. For example, dense_retriever.py takes all its parameters from the conf/dense_retriever.yaml file. Every tool's configuration file refers to other configuration files via the "defaults:" parameter; such a reference is called a [configuration group](https://hydra.cc/docs/tutorials/structured_config/config_groups) in Hydra.
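As a rough sketch of how this wiring looks, a Hydra entry point can be as small as the snippet below. my_tool.py and conf/my_tool.yaml are hypothetical names, not files from this commit, and the snippet assumes Hydra >= 1.0:

```python
# my_tool.py -- minimal illustration of the @hydra.main pattern described above.
# Assumes a config file conf/my_tool.yaml exists next to this script.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="conf", config_name="my_tool")
def main(cfg: DictConfig) -> None:
    # Hydra composes conf/my_tool.yaml with its "defaults:" groups and any
    # command-line overrides, e.g. `python my_tool.py n_docs=50`.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
```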
Let's take a look at dense_retriever.py's configuration:

```yaml
defaults:
  - encoder: hf_bert
  - datasets: retriever_default
  - ctx_sources: default_sources

indexers:
  flat:
    _target_: dpr.indexer.faiss_indexers.DenseFlatIndexer

  hnsw:
    _target_: dpr.indexer.faiss_indexers.DenseHNSWFlatIndexer

  hnsw_sq:
    _target_: dpr.indexer.faiss_indexers.DenseHNSWSQIndexer

...
qa_dataset:
...
ctx_datatsets:
...
indexer: flat
...
```
"- encoder:" - a configuration group that contains all parameters needed to instantiate the encoder. The actual parameters are located in the conf/encoder/hf_bert.yaml file.
If you want to override some of them, you can either:
- modify that config file,
- create a new config group file under the conf/encoder/ folder and enable it by providing encoder={your file name} as a command line argument, or
- override a specific parameter from the command line, for example: encoder.sequence_length=300.

"- datasets:" - a configuration group that contains a list of all possible sources of queries for evaluation. One can find them in the conf/datasets/retriever_default.yaml file.
Specify the dataset to use for evaluation by providing the qa_dataset parameter. For example, if you want to run the retriever on the NQ test set, set qa_dataset=nq_test as a command line parameter.

It is now much easier to use custom datasets, without the need to convert them to the DPR format: just define your own class that provides the relevant __getitem__(), __len__() and load_data() methods and inherits from QASrc, as in the sketch below.
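The following is an illustrative sketch only; all names other than QASrc and the three required methods are assumptions, and the QASrc parent is stubbed out so the example runs on its own:

```python
# custom_qa_source.py -- sketch of a custom query source in the spirit of
# dpr.data.retriever_data.QASrc; the TSV layout and the class/field names are
# assumptions, not part of this commit.
import csv
from typing import List, NamedTuple


class QASample(NamedTuple):
    """Stand-in for DPR's per-question record."""
    query: str
    id: int
    answers: List[str]


class MyTsvQASrc:  # in DPR this would inherit from QASrc
    def __init__(self, file: str):
        self.file = file
        self.data: List[QASample] = []

    def load_data(self) -> None:
        # Assumed layout: one "<question>\t<answer1|answer2|...>" row per line.
        with open(self.file, encoding="utf-8") as f:
            reader = csv.reader(f, delimiter="\t")
            self.data = [
                QASample(query=row[0], id=i, answers=row[1].split("|"))
                for i, row in enumerate(reader)
            ]

    def __getitem__(self, index: int) -> QASample:
        return self.data[index]

    def __len__(self) -> int:
        return len(self.data)
```

Registering such a class is then just another entry with a _target_: line in the conf/datasets/*.yaml config group.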
"- ctx_sources:" - a configuration group that contains a list of all possible passage sources. One can find them in the conf/ctx_sources/default_sources.yaml file.
Specify a list of names of passage datasets as the ctx_datatsets parameter. For example, if you want to use DPR's old Wikipedia passages, set ctx_datatsets=[dpr_wiki].
Please note that this parameter is a list, so you can effectively concatenate different passage sources into one. In order to use multiple sources at once, you also need to provide the relevant embeddings files in the encoded_ctx_files parameter, which is also a list.

"indexers:" - a parameter map that defines the available indexes. The actual index is selected by the indexer parameter, which is 'flat' by default, but you can use lossy index types by setting indexer=hnsw or indexer=hnsw_sq on the command line.
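Conceptually, each _target_: entry is a standard Hydra object config. A minimal sketch of how the selected indexer becomes an object (assuming Hydra >= 1.0 and that the dpr package is importable; this is not the literal dense_retriever.py code):

```python
# Sketch: turning the "indexers:" map and the "indexer:" selector into an object.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "indexers": {
            "flat": {"_target_": "dpr.indexer.faiss_indexers.DenseFlatIndexer"},
        },
        "indexer": "flat",  # overriding with indexer=hnsw / hnsw_sq swaps the class
    }
)

index = instantiate(cfg.indexers[cfg.indexer])  # equivalent to DenseFlatIndexer()
print(type(index))
```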
Please refer to the comments in the configuration files for the description of every parameter.
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# configuration groups
defaults:
  - encoder: hf_bert
  - train: biencoder_default
  - datasets: encoder_train_default

train_datasets:
dev_datasets:
output_dir:
train_sampling_rates:
loss_scale_factors:

# Whether to lower case the input text. Set True for uncased models, False for the cased ones.
do_lower_case: True

fix_ctx_encoder: False
val_av_rank_start_epoch: 30
seed: 12345
checkpoint_file_name: dpr_biencoder

# A trained bi-encoder checkpoint file to initialize the model
model_file:

# TODO: move to a conf group
# local_rank for distributed training on gpus
local_rank: -1
global_loss_buf_sz: 592000
device:
distributed_world_size:
distributed_port:
no_cuda: False
n_gpu:
fp16: True

# For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3'].
# See details at https://nvidia.github.io/apex/amp.html
fp16_opt_level: O1

# tokens which won't be split by the tokenizer
special_tokens:

ignore_checkpoint_offset: False
ignore_checkpoint_optimizer: False

# set to True to enable multiple query encoders
multi_q_encoder: False
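The fp16 / fp16_opt_level pair above refers to NVIDIA Apex AMP. The sketch below shows what that opt_level selects; it assumes Apex is installed and a CUDA device is available, and it is not DPR's actual setup code:

```python
# Sketch: Apex AMP initialization corresponding to fp16: True, fp16_opt_level: O1.
import torch
from apex import amp

model = torch.nn.Linear(768, 768).cuda()  # stand-in for the bi-encoder
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# O0 = pure fp32, O1 = mixed precision with automatic casting, O3 = pure fp16.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
```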
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
# @package _group_

dpr_wiki:
  _target_: dpr.data.retriever_data.CsvCtxSrc
  file: data.wikipedia_split.psgs_w100
  id_prefix: 'wiki:'
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# @package _group_

nq_train:
  _target_: dpr.data.biencoder_data.JsonQADataset
  file: data.retriever.nq-train

nq_train_hn1:
  _target_: dpr.data.biencoder_data.JsonQADataset
  file: data.retriever.nq-adv-hn-train

nq_dev:
  _target_: dpr.data.biencoder_data.JsonQADataset
  file: data.retriever.nq-dev

trivia_train:
  _target_: dpr.data.biencoder_data.JsonQADataset
  file: data.retriever.trivia-train

trivia_dev:
  _target_: dpr.data.biencoder_data.JsonQADataset
  file: data.retriever.trivia-dev

squad1_train:
  _target_: dpr.data.biencoder_data.JsonQADataset
  file: data.retriever.squad1-train

squad1_dev:
  _target_: dpr.data.biencoder_data.JsonQADataset
  file: data.retriever.squad1-dev

webq_train:
  _target_: dpr.data.biencoder_data.JsonQADataset
  file: data.retriever.webq-train

webq_dev:
  _target_: dpr.data.biencoder_data.JsonQADataset
  file: data.retriever.webq-dev

curatedtrec_train:
  _target_: dpr.data.biencoder_data.JsonQADataset
  file: data.retriever.curatedtrec-train

curatedtrec_dev:
  _target_: dpr.data.biencoder_data.JsonQADataset
  file: data.retriever.curatedtrec-dev
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
# @package _group_

nq_test:
  _target_: dpr.data.retriever_data.CsvQASrc
  file: data.retriever.qas.nq-test

nq_train:
  _target_: dpr.data.retriever_data.CsvQASrc
  file: data.retriever.qas.nq-train

nq_dev:
  _target_: dpr.data.retriever_data.CsvQASrc
  file: data.retriever.qas.nq-dev

trivia_test:
  _target_: dpr.data.retriever_data.CsvQASrc
  file: data.retriever.qas.trivia-test

trivia_train:
  _target_: dpr.data.retriever_data.CsvQASrc
  file: data.retriever.qas.trivia-train

trivia_dev:
  _target_: dpr.data.retriever_data.CsvQASrc
  file: data.retriever.qas.trivia-dev

webq_test:
  _target_: dpr.data.retriever_data.CsvQASrc
  file: data.retriever.qas.webq-test

curatedtrec_test:
  _target_: dpr.data.retriever_data.CsvQASrc
  file: data.retriever.qas.curatedtrec-test
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
defaults:
  - encoder: hf_bert # defines encoder initialization parameters
  - datasets: retriever_default # contains a list of all possible sources of queries for evaluation. The specific set is selected by the qa_dataset parameter
  - ctx_sources: default_sources # contains a list of all possible passage sources. Specific passage sources are selected by the ctx_datatsets parameter

indexers:
  flat:
    _target_: dpr.indexer.faiss_indexers.DenseFlatIndexer

  hnsw:
    _target_: dpr.indexer.faiss_indexers.DenseHNSWFlatIndexer

  hnsw_sq:
    _target_: dpr.indexer.faiss_indexers.DenseHNSWSQIndexer

# the name of the queries dataset from the 'datasets' config group
qa_dataset:

# a list of names of the passages datasets from the 'ctx_sources' config group
ctx_datatsets:

# Glob paths to encoded passages (from the generate_dense_embeddings tool)
encoded_ctx_files: []

out_file:
# "regex" or "string"
match: string
n_docs: 100
validation_workers: 16

# Batch size to generate query embeddings
batch_size: 128

# Whether to lower case the input text. Set True for uncased models, False for the cased ones.
do_lower_case: True

# The attribute name of the encoder to use for queries. Options for the BiEncoder model: question_model, ctx_model
# question_model is used if this param is empty
encoder_path:

# path to the FAISS index location - it is only needed if you want to serialize the faiss index to files or read from them
# (instead of using encoded_ctx_files)
# it should point to either a directory or a common index files prefix name
# if there is no index at the specified location, the index will be created from encoded_ctx_files
index_path:

kilt_out_file:

# A trained bi-encoder checkpoint file to initialize the model
model_file:

validate_as_tables: False
rpc_retriever_cfg_file:
indexer: flat

# tokens which won't be split by the tokenizer
special_tokens:

# TODO: move to a conf group
# local_rank for distributed training on gpus
local_rank: -1
global_loss_buf_sz: 150000
device:
distributed_world_size:
no_cuda: False
n_gpu:
fp16: False

# For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3'].
# See details at https://nvidia.github.io/apex/amp.html
fp16_opt_level: O1
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
# @package _group_

# model type. One of [hf_bert, pytext_bert, fairseq_roberta]
encoder_model_type: hf_bert

# HuggingFace's config name for model initialization
pretrained_model_cfg: bert-base-uncased

# Some encoders need to be initialized from a file
pretrained_file:

# Extra linear layer on top of the standard bert/roberta encoder
projection_dim: 0

# Max length of the encoder input sequence
sequence_length: 256

dropout: 0.1

# whether to fix (not update) the context encoder during training
fix_ctx_encoder: False

# if False, the model won't load pre-trained BERT weights
pretrained: True
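For reference, pretrained_model_cfg: bert-base-uncased and sequence_length: 256 correspond to a standard HuggingFace load along these lines (a sketch, not DPR's own encoder wrapper):

```python
# Sketch: what the hf_bert settings resolve to in plain HuggingFace terms.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # uncased -> do_lower_case: True
model = AutoModel.from_pretrained("bert-base-uncased")

enc = tokenizer(
    "who wrote the declaration of independence",
    max_length=256,      # sequence_length: 256
    truncation=True,
    return_tensors="pt",
)
out = model(**enc)
print(out.last_hidden_state.shape)  # (1, seq_len, 768); projection_dim: 0 means no extra layer
```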
