-
Notifications
You must be signed in to change notification settings - Fork 25.9k
/
retrieval_rag.py
674 lines (573 loc) · 29.2 KB
/
retrieval_rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RAG Retriever model implementation."""
import os
import pickle
import time
from typing import Iterable, List, Optional, Tuple
import numpy as np
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer
if is_datasets_available():
from datasets import Dataset, load_dataset, load_from_disk
if is_faiss_available():
import faiss
logger = logging.get_logger(__name__)

# Default remote directory holding the legacy DPR index and passage files; `RagRetriever._build_index` falls back to
# it for `index_name="legacy"` when `config.index_path` is not set.
LEGACY_INDEX_PATH = (
    "https://storage.googleapis.com/"
    "huggingface-nlp/datasets/wiki_dpr/"
)
class Index:
    """
    A base class for the Indices encapsulated by the [`RagRetriever`].

    Every method of this base class raises `NotImplementedError`; concrete indices must override all of them.
    """

    def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
        """
        Returns a list of dictionaries, containing titles and text of the retrieved documents.

        Args:
            doc_ids (`np.ndarray` of shape `(batch_size, n_docs)`):
                A tensor of document indices.

        Returns:
            `List[dict]`: one dictionary per query of the batch, holding the `"title"` and `"text"` of the documents
            retrieved for that query.
        """
        raise NotImplementedError

    def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
        """
        For each query in the batch, retrieves `n_docs` documents.

        Args:
            question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
                An array of query vectors.
            n_docs (`int`):
                The number of docs retrieved per query.

        Returns:
            `Tuple[np.ndarray, np.ndarray]`:
            - `np.ndarray` of shape `(batch_size, n_docs)`: A tensor of indices of retrieved documents.
            - `np.ndarray` of shape `(batch_size, n_docs, vector_size)`: A tensor of vector representations of
              retrieved documents.
        """
        raise NotImplementedError

    def is_initialized(self):
        """
        Returns `True` if index is already initialized.
        """
        raise NotImplementedError

    def init_index(self):
        """
        A function responsible for loading the index into memory. Should be called only once per training run of a RAG
        model. E.g. if the model is trained on multiple GPUs in a distributed setup, only one of the workers will load
        the index.
        """
        raise NotImplementedError
class LegacyIndex(Index):
    """
    An index which can be deserialized from the files built using https://github.com/facebookresearch/DPR. We use
    default faiss index parameters as specified in that repository.

    Both the passages and the index metadata are stored as pickles, so loading them requires the environment variable
    `TRUST_REMOTE_CODE` to be set to a true value; otherwise a `ValueError` is raised before anything is downloaded.

    Args:
        vector_size (`int`):
            The dimension of indexed vectors.
        index_path (`str`):
            A path to a *directory* containing index files compatible with [`~models.rag.retrieval_rag.LegacyIndex`]
    """

    INDEX_FILENAME = "hf_bert_base.hnswSQ8_correct_phi_128.c_index"
    PASSAGE_FILENAME = "psgs_w100.tsv.pkl"

    def __init__(self, vector_size, index_path):
        self.index_id_to_db_id = []
        self.index_path = index_path
        # Passages are loaded eagerly, at construction time.
        self.passages = self._load_passages()
        self.vector_size = vector_size
        self.index = None
        self._index_initialized = False

    @staticmethod
    def _check_trust_remote_code():
        """Raise `ValueError` unless the user opted in to unpickling through the `TRUST_REMOTE_CODE` env variable."""
        if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
            raise ValueError(
                "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
                "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
                "that could have been tampered with. If you already verified the pickle data and decided to use it, "
                "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
            )

    def _resolve_path(self, index_path, filename):
        """
        Resolve `filename` inside `index_path` (a local directory or a remote location) through `cached_file`.

        Raises:
            `EnvironmentError`: if the file cannot be found; the underlying error is chained as the cause.
        """
        is_local = os.path.isdir(index_path)
        try:
            # Load from URL or cache if already cached
            resolved_archive_file = cached_file(index_path, filename)
        except EnvironmentError as err:
            msg = (
                f"Can't load '{filename}'. Make sure that:\n\n"
                f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
                f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
            )
            raise EnvironmentError(msg) from err
        if is_local:
            logger.info(f"loading file {resolved_archive_file}")
        else:
            logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
        return resolved_archive_file

    def _load_passages(self):
        """Unpickle the passages file; refuses (before resolving/downloading it) unless `TRUST_REMOTE_CODE` is set."""
        logger.info(f"Loading passages from {self.index_path}")
        # Check the opt-in first so a refused load does not download the (large) passages file for nothing.
        self._check_trust_remote_code()
        passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
        with open(passages_path, "rb") as passages_file:
            # SECURITY: pickle.load executes arbitrary code; only reached after the explicit opt-in above.
            passages = pickle.load(passages_file)
        return passages

    def _deserialize_index(self):
        """
        Read the faiss index and its pickled index-id -> db-id mapping, replacing `self.index`.

        Raises:
            `ValueError`: if `TRUST_REMOTE_CODE` is not set, or if the mapping size does not match the index size.
        """
        logger.info(f"Loading index from {self.index_path}")
        # The metadata is a pickle: refuse before resolving or reading anything.
        self._check_trust_remote_code()
        resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
        self.index = faiss.read_index(resolved_index_path)
        resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
        with open(resolved_meta_path, "rb") as metadata_file:
            self.index_id_to_db_id = pickle.load(metadata_file)
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if len(self.index_id_to_db_id) != self.index.ntotal:
            raise ValueError("Deserialized index_id_to_db_id should match faiss index size")

    def is_initialized(self):
        return self._index_initialized

    def init_index(self):
        # The stored vectors carry one auxiliary dimension on top of `vector_size` (see `get_top_docs`).
        index = faiss.IndexHNSWFlat(self.vector_size + 1, 512)
        index.hnsw.efSearch = 128
        index.hnsw.efConstruction = 200
        self.index = index
        # NOTE(review): `_deserialize_index` replaces `self.index` with the index read from disk, so the HNSW
        # parameters set above only apply to this placeholder — confirm whether they were meant for the loaded index.
        self._deserialize_index()
        self._index_initialized = True

    def get_doc_dicts(self, doc_ids: np.array):
        """
        Look up the passages for each row of `doc_ids`.

        `self.passages` is keyed by the stringified integer doc id; each entry stores the text at position 0 and the
        title at position 1.
        """
        doc_dicts = []
        for doc_ids_i in doc_ids:
            docs = [self.passages[str(int(doc_id))] for doc_id in doc_ids_i]
            doc_dicts.append({"title": [doc[1] for doc in docs], "text": [doc[0] for doc in docs]})
        return doc_dicts

    def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
        # Queries get a zero in the extra auxiliary dimension the index was built with (`vector_size + 1`); the same
        # dimension is dropped (`[:-1]`) from the reconstructed document vectors.
        aux_dim = np.zeros(len(question_hidden_states), dtype="float32").reshape(-1, 1)
        query_hnsw_vectors = np.hstack((question_hidden_states, aux_dim))
        _, docs_ids = self.index.search(query_hnsw_vectors, n_docs)
        vectors = [[self.index.reconstruct(int(doc_id))[:-1] for doc_id in doc_ids] for doc_ids in docs_ids]
        # Map faiss row ids back to the database ids used as passage keys.
        ids = [[int(self.index_id_to_db_id[doc_id]) for doc_id in doc_ids] for doc_ids in docs_ids]
        return np.array(ids), np.array(vectors)
class HFIndexBase(Index):
    """
    Base class for indices backed by a `datasets.Dataset` with `title`, `text` and `embeddings` columns, searched
    through the dataset's index on the `embeddings` column.

    Args:
        vector_size (`int`): the dimension of the passage embeddings.
        dataset (`datasets.Dataset`): the passages dataset.
        index_initialized (`bool`, defaults to `False`):
            Whether `dataset` already carries its `embeddings` index (validated in `_check_dataset_format`).
    """

    def __init__(self, vector_size, dataset, index_initialized=False):
        self.vector_size = vector_size
        self.dataset = dataset
        self._index_initialized = index_initialized
        self._check_dataset_format(with_index=index_initialized)
        dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")

    def _check_dataset_format(self, with_index: bool):
        """
        Validate that `self.dataset` is a `Dataset` with the required columns (and the `embeddings` index if
        `with_index`).

        Raises:
            `ValueError`: if any requirement is not met.
        """
        if not isinstance(self.dataset, Dataset):
            raise ValueError(f"Dataset should be a datasets.Dataset object, but got {type(self.dataset)}")
        if len({"title", "text", "embeddings"} - set(self.dataset.column_names)) > 0:
            raise ValueError(
                "Dataset should be a dataset with the following columns: "
                "title (str), text (str) and embeddings (arrays of dimension vector_size), "
                f"but got columns {self.dataset.column_names}"
            )
        if with_index and "embeddings" not in self.dataset.list_indexes():
            raise ValueError(
                "Missing faiss index in the dataset. Make sure you called `dataset.add_faiss_index` to compute it "
                "or `dataset.load_faiss_index` to load one from the disk."
            )

    def init_index(self):
        raise NotImplementedError()

    def is_initialized(self):
        return self._index_initialized

    def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
        # One dataset batch lookup per query row.
        return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]

    def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
        """
        Search the `embeddings` index for the `n_docs` nearest documents of each query.

        Negative ids returned by the search mark missing results: those rows are not fetched, and the matching
        embeddings are zero-padded so every query yields exactly `n_docs` vectors.

        Returns:
            ids of shape `(batch_size, n_docs)` and vectors of shape `(batch_size, n_docs, vector_size)`.
        """
        _, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs)
        docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids]
        vectors = [doc["embeddings"] for doc in docs]
        for i, doc_vectors in enumerate(vectors):
            missing = n_docs - len(doc_vectors)
            if missing > 0:
                doc_vectors = np.asarray(doc_vectors)
                if doc_vectors.size == 0:
                    # No valid hit at all: give the empty selection a 2-D shape so it can be stacked.
                    doc_vectors = doc_vectors.reshape(0, self.vector_size)
                # Pad with the embeddings' own dtype so the output dtype does not depend on whether padding happened.
                padding = np.zeros((missing, self.vector_size), dtype=doc_vectors.dtype)
                vectors[i] = np.vstack([doc_vectors, padding])
        return np.array(ids), np.array(vectors)  # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
class CanonicalHFIndex(HFIndexBase):
    """
    A wrapper around an instance of [`~datasets.Dataset`]. If `index_path` is set to `None`, we load the pre-computed
    index available with the [`~datasets.arrow_dataset.Dataset`], otherwise, we load the index from the indicated path
    on disk. Exactly one of `index_name` and `index_path` must be provided.

    Args:
        vector_size (`int`): the dimension of the passages embeddings used by the index
        dataset_name (`str`, optional, defaults to `wiki_dpr`):
            A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids
            with `datasets.list_datasets()`).
        dataset_split (`str`, optional, defaults to `train`)
            Which split of the `dataset` to load.
        index_name (`str`, optional, defaults to `None`)
            The name of the pre-computed index to load together with the `dataset`; used when `index_path` is `None`.
        index_path (`str`, optional, defaults to `None`)
            The path to the serialized faiss index on disk.
        use_dummy_dataset (`bool`, optional, defaults to `False`):
            If True, use the dummy configuration of the dataset for tests.
        dataset_revision (`str`, optional, defaults to `None`):
            The dataset revision, forwarded to `load_dataset` as `revision`.

    Raises:
        `ValueError`: if both or neither of `index_name` and `index_path` are given.
    """

    def __init__(
        self,
        vector_size: int,
        dataset_name: str = "wiki_dpr",
        dataset_split: str = "train",
        index_name: Optional[str] = None,
        index_path: Optional[str] = None,
        use_dummy_dataset=False,
        dataset_revision=None,
    ):
        # Exactly one of the two must be set.
        if int(index_path is None) + int(index_name is None) != 1:
            raise ValueError("Please provide `index_name` or `index_path`.")
        self.dataset_name = dataset_name
        self.dataset_split = dataset_split
        self.index_name = index_name
        self.index_path = index_path
        self.use_dummy_dataset = use_dummy_dataset
        self.dataset_revision = dataset_revision
        logger.info(f"Loading passages from {self.dataset_name}")
        dataset = load_dataset(
            self.dataset_name,
            with_index=False,
            split=self.dataset_split,
            dummy=self.use_dummy_dataset,
            revision=dataset_revision,
        )
        super().__init__(vector_size, dataset, index_initialized=False)

    def init_index(self):
        """Load the faiss index, either from `index_path` or by reloading the dataset with its named index."""
        if self.index_path is not None:
            logger.info(f"Loading index from {self.index_path}")
            self.dataset.load_faiss_index("embeddings", file=self.index_path)
        else:
            logger.info(f"Loading index from {self.dataset_name} with index name {self.index_name}")
            self.dataset = load_dataset(
                self.dataset_name,
                with_embeddings=True,
                with_index=True,
                split=self.dataset_split,
                index_name=self.index_name,
                dummy=self.use_dummy_dataset,
                revision=self.dataset_revision,
            )
            # The reloaded dataset needs the same numpy/float32 formatting that `HFIndexBase.__init__` applies.
            self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")
        self._index_initialized = True
class CustomHFIndex(HFIndexBase):
    """
    Index over a user-provided [`~datasets.Dataset`] whose faiss index either is already attached to the dataset or is
    loaded later from a file on disk.

    Args:
        vector_size (`int`): the dimension of the passages embeddings used by the index
        dataset (`datasets.Dataset`):
            The passages dataset, with 3 columns: title (str), text (str) and embeddings (arrays of dimension
            vector_size).
        index_path (`str`, optional, defaults to `None`):
            The path to the serialized faiss index on disk. When `None`, the dataset must already carry its
            `embeddings` index and the index counts as initialized.
    """

    def __init__(self, vector_size: int, dataset, index_path=None):
        already_indexed = index_path is None
        super().__init__(vector_size, dataset, index_initialized=already_indexed)
        self.index_path = index_path

    @classmethod
    def load_from_disk(cls, vector_size, dataset_path, index_path):
        """
        Build the index from a dataset saved with `save_to_disk` and a faiss index file.

        Raises:
            `ValueError`: if `dataset_path` or `index_path` is `None`.
        """
        logger.info(f"Loading passages from {dataset_path}")
        if any(path is None for path in (dataset_path, index_path)):
            raise ValueError(
                "Please provide `dataset_path` and `index_path` after calling `dataset.save_to_disk(dataset_path)` "
                "and `dataset.get_index('embeddings').save(index_path)`."
            )
        return cls(vector_size=vector_size, dataset=load_from_disk(dataset_path), index_path=index_path)

    def init_index(self):
        # Nothing to do when the index is already attached (or was loaded by a previous call).
        if self.is_initialized():
            return
        logger.info(f"Loading index from {self.index_path}")
        self.dataset.load_faiss_index("embeddings", file=self.index_path)
        self._index_initialized = True
class RagRetriever:
    """
    Retriever used to get documents from vector queries. It retrieves the documents embeddings as well as the documents
    contents, and it formats them to be used with a RagModel.

    Args:
        config ([`RagConfig`]):
            The configuration of the RAG model this Retriever is used with. Contains parameters indicating which
            `Index` to build. You can load your own custom dataset with `config.index_name="custom"` or use a canonical
            one (default) from the datasets library with `config.index_name="wiki_dpr"` for example.
        question_encoder_tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer that was used to tokenize the question. It is used to decode the question and then use the
            generator_tokenizer.
        generator_tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used for the generator part of the RagModel.
        index ([`~models.rag.retrieval_rag.Index`], optional, defaults to the one defined by the configuration):
            If specified, use this index instead of the one built using the configuration
        init_retrieval (`bool`, optional, defaults to `True`):
            Whether to load the index into memory at construction time (see `init_retrieval`).

    Examples:

    ```python
    >>> # To load the default "wiki_dpr" dataset with 21M passages from wikipedia (index name is 'compressed' or 'exact')
    >>> from transformers import RagRetriever

    >>> retriever = RagRetriever.from_pretrained(
    ...     "facebook/dpr-ctx_encoder-single-nq-base", dataset="wiki_dpr", index_name="compressed"
    ... )

    >>> # To load your own indexed dataset built with the datasets library. More info on how to build the indexed dataset in examples/rag/use_own_knowledge_dataset.py
    >>> from transformers import RagRetriever

    >>> dataset = (
    ...     ...
    ... )  # dataset must be a datasets.Datasets object with columns "title", "text" and "embeddings", and it must have a faiss index
    >>> retriever = RagRetriever.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", indexed_dataset=dataset)

    >>> # To load your own indexed dataset built with the datasets library that was saved on disk. More info in examples/rag/use_own_knowledge_dataset.py
    >>> from transformers import RagRetriever

    >>> dataset_path = "path/to/my/dataset"  # dataset saved via *dataset.save_to_disk(...)*
    >>> index_path = "path/to/my/index.faiss"  # faiss index saved via *dataset.get_index("embeddings").save(...)*
    >>> retriever = RagRetriever.from_pretrained(
    ...     "facebook/dpr-ctx_encoder-single-nq-base",
    ...     index_name="custom",
    ...     passages_path=dataset_path,
    ...     index_path=index_path,
    ... )

    >>> # To load the legacy index built originally for Rag's paper
    >>> from transformers import RagRetriever

    >>> retriever = RagRetriever.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", index_name="legacy")
    ```"""

    def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True):
        self._init_retrieval = init_retrieval
        requires_backends(self, ["datasets", "faiss"])
        super().__init__()
        self.index = index or self._build_index(config)
        self.generator_tokenizer = generator_tokenizer
        self.question_encoder_tokenizer = question_encoder_tokenizer

        self.n_docs = config.n_docs
        self.batch_size = config.retrieval_batch_size

        self.config = config
        if self._init_retrieval:
            self.init_retrieval()

        # Set by `set_ctx_encoder_tokenizer` (end2end retriever training).
        self.ctx_encoder_tokenizer = None
        self.return_tokenized_docs = False

    @staticmethod
    def _build_index(config):
        """Instantiate the index selected by `config.index_name` ("legacy", "custom" or a canonical dataset index)."""
        if config.index_name == "legacy":
            return LegacyIndex(
                config.retrieval_vector_size,
                config.index_path or LEGACY_INDEX_PATH,
            )
        elif config.index_name == "custom":
            return CustomHFIndex.load_from_disk(
                vector_size=config.retrieval_vector_size,
                dataset_path=config.passages_path,
                index_path=config.index_path,
            )
        else:
            return CanonicalHFIndex(
                vector_size=config.retrieval_vector_size,
                dataset_name=config.dataset,
                dataset_split=config.dataset_split,
                index_name=config.index_name,
                index_path=config.index_path,
                use_dummy_dataset=config.use_dummy_dataset,
                dataset_revision=config.dataset_revision,
            )

    @classmethod
    def from_pretrained(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
        """
        Load the config and tokenizers from `retriever_name_or_path` and build the retriever.

        If `indexed_dataset` is given, it is wrapped in a [`CustomHFIndex`] and `config.index_name` is set to
        `"custom"`; otherwise the index is built from the configuration.
        """
        requires_backends(cls, ["datasets", "faiss"])
        config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
        rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
        question_encoder_tokenizer = rag_tokenizer.question_encoder
        generator_tokenizer = rag_tokenizer.generator
        if indexed_dataset is not None:
            config.index_name = "custom"
            index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
        else:
            index = cls._build_index(config)
        return cls(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
        )

    def save_pretrained(self, save_directory):
        """
        Save the configuration and the RAG tokenizer to `save_directory`.

        For a [`CustomHFIndex`], the faiss index and/or the passages dataset are first written into `save_directory`
        when `config.index_path` / `config.passages_path` are unset, and the config is updated to point at them.
        """
        if isinstance(self.index, CustomHFIndex):
            if self.config.index_path is None:
                index_path = os.path.join(save_directory, "hf_dataset_index.faiss")
                self.index.dataset.get_index("embeddings").save(index_path)
                self.config.index_path = index_path
            if self.config.passages_path is None:
                passages_path = os.path.join(save_directory, "hf_dataset")
                # datasets don't support save_to_disk with indexes right now: detach the index while saving and
                # always re-attach it, even if saving fails, so the retriever keeps a usable index.
                faiss_index = self.index.dataset._indexes.pop("embeddings")
                try:
                    self.index.dataset.save_to_disk(passages_path)
                finally:
                    self.index.dataset._indexes["embeddings"] = faiss_index
                self.config.passages_path = passages_path
        self.config.save_pretrained(save_directory)
        rag_tokenizer = RagTokenizer(
            question_encoder=self.question_encoder_tokenizer,
            generator=self.generator_tokenizer,
        )
        rag_tokenizer.save_pretrained(save_directory)

    def init_retrieval(self):
        """
        Retriever initialization function. It loads the index into memory.
        """

        logger.info("initializing retrieval")
        self.index.init_index()

    def postprocess_docs(self, docs, input_strings, prefix, n_docs, return_tensors=None):
        r"""
        Postprocessing retrieved `docs` and combining them with `input_strings`.

        Args:
            docs (`List[dict]`):
                Retrieved documents, one dict with `"title"` and `"text"` lists per input.
            input_strings (`List[str]`):
                Input strings decoded by `preprocess_query`.
            prefix (`str`):
                Prefix added at the beginning of each input, typically used with T5-based models.
            n_docs (`int`):
                The number of documents combined with each input.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                Forwarded to the generator tokenizer.

        Return:
            `tuple(tensors)`: a tuple consisting of two elements: contextualized `input_ids` and a compatible
            `attention_mask`.
        """

        def cat_input_and_doc(doc_title, doc_text, input_string, prefix):
            # TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation
            # TODO(piktus): better handling of truncation
            # Drop one surrounding double quote on each side of the title.
            if doc_title.startswith('"'):
                doc_title = doc_title[1:]
            if doc_title.endswith('"'):
                doc_title = doc_title[:-1]
            if prefix is None:
                prefix = ""
            # Collapse double spaces into single ones (the previous `.replace(" ", " ")` was a no-op).
            out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace(
                "  ", " "
            )
            return out

        rag_input_strings = [
            cat_input_and_doc(
                docs[i]["title"][j],
                docs[i]["text"][j],
                input_strings[i],
                prefix,
            )
            for i in range(len(docs))
            for j in range(n_docs)
        ]

        contextualized_inputs = self.generator_tokenizer.batch_encode_plus(
            rag_input_strings,
            max_length=self.config.max_combined_length,
            return_tensors=return_tensors,
            padding="max_length",
            truncation=True,
        )

        return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"]

    def _chunk_tensor(self, t: Iterable, chunk_size: int) -> List[Iterable]:
        """Split `t` into consecutive slices of at most `chunk_size` items."""
        return [t[i : i + chunk_size] for i in range(0, len(t), chunk_size)]

    def _main_retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray]:
        """Query the index in chunks of `self.batch_size` and concatenate the results."""
        question_hidden_states_batched = self._chunk_tensor(question_hidden_states, self.batch_size)
        ids_batched = []
        vectors_batched = []
        for batch_hidden_states in question_hidden_states_batched:
            start_time = time.time()
            ids, vectors = self.index.get_top_docs(batch_hidden_states, n_docs)
            logger.debug(
                f"index search time: {time.time() - start_time} sec, batch size {batch_hidden_states.shape}"
            )
            ids_batched.extend(ids)
            vectors_batched.extend(vectors)
        return (
            np.array(ids_batched),
            np.array(vectors_batched),
        )  # shapes (batch_size, n_docs) and (batch_size, n_docs, d)

    def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray, List[dict]]:
        """
        Retrieves documents for specified `question_hidden_states`.

        Args:
            question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (`int`):
                The number of docs retrieved per query.

        Return:
            `Tuple[np.ndarray, np.ndarray, List[dict]]`: A tuple with the following objects:

            - **retrieved_doc_embeds** (`np.ndarray` of shape `(batch_size, n_docs, dim)`) -- The retrieval embeddings
              of the retrieved docs per query.
            - **doc_ids** (`np.ndarray` of shape `(batch_size, n_docs)`) -- The ids of the documents in the index
            - **doc_dicts** (`List[dict]`): The `retrieved_doc_embeds` examples per query.
        """

        doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
        return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)

    def set_ctx_encoder_tokenizer(self, ctx_encoder_tokenizer: PreTrainedTokenizer):
        # used in end2end retriever training
        self.ctx_encoder_tokenizer = ctx_encoder_tokenizer
        self.return_tokenized_docs = True

    def __call__(
        self,
        question_input_ids: List[List[int]],
        question_hidden_states: np.ndarray,
        prefix=None,
        n_docs=None,
        return_tensors=None,
    ) -> BatchEncoding:
        """
        Retrieves documents for specified `question_hidden_states`.

        Args:
            question_input_ids (`List[List[int]]`) batch of input ids
            question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`:
                A batch of query vectors to retrieve with.
            prefix (`str`, *optional*):
                The prefix used by the generator's tokenizer. Defaults to `config.generator.prefix`.
            n_docs (`int`, *optional*):
                The number of docs retrieved per query. Defaults to `config.n_docs`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.

        Returns: [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:

            - **context_input_ids** -- List of token ids to be fed to a model.

              [What are input IDs?](../glossary#input-ids)

            - **context_attention_mask** -- List of indices specifying which tokens should be attended to by the model
            (when `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).

              [What are attention masks?](../glossary#attention-mask)

            - **retrieved_doc_embeds** -- List of embeddings of the retrieved documents
            - **doc_ids** -- List of ids of the retrieved documents
            - **tokenized_doc_ids** / **tokenized_doc_attention_mask** -- only after `set_ctx_encoder_tokenizer` was
              called: the retrieved (title, text) pairs tokenized with the context-encoder tokenizer.
        """

        n_docs = n_docs if n_docs is not None else self.n_docs
        prefix = prefix if prefix is not None else self.config.generator.prefix
        retrieved_doc_embeds, doc_ids, docs = self.retrieve(question_hidden_states, n_docs)

        input_strings = self.question_encoder_tokenizer.batch_decode(question_input_ids, skip_special_tokens=True)
        context_input_ids, context_attention_mask = self.postprocess_docs(
            docs, input_strings, prefix, n_docs, return_tensors=return_tensors
        )

        encoding = {
            "context_input_ids": context_input_ids,
            "context_attention_mask": context_attention_mask,
            "retrieved_doc_embeds": retrieved_doc_embeds,
            "doc_ids": doc_ids,
        }

        if self.return_tokenized_docs:
            # Flatten (batch, n_docs) into one list, in the same order as the contextualized inputs.
            retrieved_doc_title = [docs[b_idx]["title"][doc_idx] for b_idx in range(len(docs)) for doc_idx in range(n_docs)]
            retrieved_doc_text = [docs[b_idx]["text"][doc_idx] for b_idx in range(len(docs)) for doc_idx in range(n_docs)]
            tokenized_docs = self.ctx_encoder_tokenizer(
                retrieved_doc_title,
                retrieved_doc_text,
                truncation=True,
                padding="longest",
                return_tensors=return_tensors,
            )
            encoding["tokenized_doc_ids"] = tokenized_docs["input_ids"]
            encoding["tokenized_doc_attention_mask"] = tokenized_docs["attention_mask"]

        return BatchEncoding(encoding, tensor_type=return_tensors)