/
modeling_tf_rag.py
1812 lines (1553 loc) · 88 KB
/
modeling_tf_rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFRAG model implementation."""
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
import tensorflow as tf
from ...configuration_utils import PretrainedConfig
from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, shape_list, unpack_inputs
from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_rag import RagConfig
from .retrieval_rag import RagRetriever
# Module-level logger obtained from the library's `logging` utilities (imported from `...utils` above).
logger = logging.get_logger(__name__)
# Config class name passed as `config_class` to `replace_return_docstrings` when generating the model docstrings.
_CONFIG_FOR_DOC = "RagConfig"
@dataclass
class TFRetrievAugLMMarginOutput(ModelOutput):
    """
    Base class for retriever augmented marginalized models outputs.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
            each vocabulary token.
        past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
            sequence_length, embed_size_per_head)`.

            Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
            (see `past_key_values` input) to speed up sequential decoding.
        doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`.
        retrieved_doc_embeds (`tf.Tensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
            Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
            the `doc_scores`.
        retrieved_doc_ids (`tf.Tensor` (int32) of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
            The indexes of the embedded documents retrieved by the retriever.
        context_input_ids (`tf.Tensor` (int32) of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
        context_attention_mask (`tf.Tensor` (int32) of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever.
        question_encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Hidden states at the output of the last layer of the question encoder (the pooled output of the
            model).
        question_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
        question_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the question encoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        generator_enc_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
        generator_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
        generator_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the generator encoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        generator_dec_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
        generator_dec_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the generator decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
    """

    loss: Optional[tf.Tensor] = None
    # Defaults to None like every other field, so the annotation is Optional as well.
    logits: Optional[tf.Tensor] = None
    past_key_values: Optional[List[tf.Tensor]] = None
    doc_scores: Optional[tf.Tensor] = None
    retrieved_doc_embeds: Optional[tf.Tensor] = None
    retrieved_doc_ids: Optional[tf.Tensor] = None
    context_input_ids: Optional[tf.Tensor] = None
    context_attention_mask: Optional[tf.Tensor] = None
    question_encoder_last_hidden_state: Optional[tf.Tensor] = None
    question_enc_hidden_states: Optional[Tuple[tf.Tensor]] = None
    question_enc_attentions: Optional[Tuple[tf.Tensor]] = None
    generator_enc_last_hidden_state: Optional[tf.Tensor] = None
    generator_enc_hidden_states: Optional[Tuple[tf.Tensor]] = None
    generator_enc_attentions: Optional[Tuple[tf.Tensor]] = None
    generator_dec_hidden_states: Optional[Tuple[tf.Tensor]] = None
    generator_dec_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFRetrievAugLMOutput(ModelOutput):
    """
    Base class for retriever augmented language model outputs.

    Args:
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
            each vocabulary token.
        past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
            sequence_length, embed_size_per_head)`.

            Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
            (see `past_key_values` input) to speed up sequential decoding.
        doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`.
        retrieved_doc_embeds (`tf.Tensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
            Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
            the `doc_scores`.
        retrieved_doc_ids (`tf.Tensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
            The indexes of the embedded documents retrieved by the retriever.
        context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
        context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever.
        question_encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Hidden states at the output of the last layer of the question encoder (the pooled output of the
            model).
        question_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
        question_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the question encoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        generator_enc_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
        generator_enc_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
        generator_enc_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the generator encoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        generator_dec_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings and one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
        generator_dec_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the generator decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
    """

    # Defaults to None like every other field, so the annotation is Optional as well.
    logits: Optional[tf.Tensor] = None
    past_key_values: Optional[List[tf.Tensor]] = None
    doc_scores: Optional[tf.Tensor] = None
    retrieved_doc_embeds: Optional[tf.Tensor] = None
    retrieved_doc_ids: Optional[tf.Tensor] = None
    context_input_ids: Optional[tf.Tensor] = None
    context_attention_mask: Optional[tf.Tensor] = None
    question_encoder_last_hidden_state: Optional[tf.Tensor] = None
    question_enc_hidden_states: Optional[Tuple[tf.Tensor]] = None
    question_enc_attentions: Optional[Tuple[tf.Tensor]] = None
    generator_enc_last_hidden_state: Optional[tf.Tensor] = None
    generator_enc_hidden_states: Optional[Tuple[tf.Tensor]] = None
    generator_enc_attentions: Optional[Tuple[tf.Tensor]] = None
    generator_dec_hidden_states: Optional[Tuple[tf.Tensor]] = None
    generator_dec_attentions: Optional[Tuple[tf.Tensor]] = None
class TFRagPreTrainedModel(TFPreTrainedModel):
    r"""
    RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP
    Tasks](https://arxiv.org/abs/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

    RAG is a retriever augmented model and encapsulates three components: a question encoder, a dataset retriever and
    a generator. The encoder and generator are trainable while the retriever is just an indexed dataset.
    """

    config_class = RagConfig
    base_model_prefix = "rag"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    @staticmethod
    def _pop_prefixed_kwargs(kwargs: dict, prefix: str) -> dict:
        """
        Remove every entry of `kwargs` whose key starts with `prefix` and return those entries with the prefix
        stripped from their keys. `kwargs` is modified in place; insertion order is preserved.
        """
        # Materialize the key list first: popping while iterating over the dict itself would raise.
        prefixed_keys = [key for key in kwargs if key.startswith(prefix)]
        return {key[len(prefix) :]: kwargs.pop(key) for key in prefixed_keys}

    @classmethod
    def from_pretrained_question_encoder_generator(
        cls,
        question_encoder_pretrained_model_name_or_path: Optional[str] = None,
        generator_pretrained_model_name_or_path: Optional[str] = None,
        retriever: Optional[RagRetriever] = None,
        *model_args,
        **kwargs
    ) -> TFPreTrainedModel:
        r"""
        Instantiates a question encoder and a generator from one or two base classes of the library from pretrained
        model checkpoints.

        Params:
            question_encoder_pretrained_model_name_or_path (`str`, *optional*):
                Information necessary to initiate the question encoder. Can be either:

                    - A string with the *shortcut name* of a pretrained model to load from cache or download, e.g.,
                      `bert-base-uncased`.
                    - A string with the *identifier name* of a pretrained model that was user-uploaded to our S3, e.g.,
                      `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing model weights saved using
                      [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *pytorch index checkpoint file* (e.g, `./pt_model/`). In this case,
                      `question_encoder_from_pt` should be set to `True`.

            generator_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
                Information necessary to initiate the generator. Can be either:

                    - A string with the *shortcut name* of a pretrained model to load from cache or download, e.g.,
                      `t5-small`.
                    - A string with the *identifier name* of a pretrained model that was user-uploaded to our S3, e.g.,
                      `facebook/bart-base`.
                    - A path to a *directory* containing model weights saved using
                      [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *pytorch checkpoint file* (e.g, `./pt_model/`). In this case,
                      `generator_from_pt` should be set to `True`.

            model_args (remaining positional arguments, *optional*):
                All remaining positional arguments are passed to the question encoder's `from_pretrained` call.
            retriever ([`RagRetriever`], *optional*):
                The retriever to use.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`).

                    - To update the question_encoder configuration, use the prefix *question_encoder_* for each
                      configuration parameter.
                    - To update the generator configuration, use the prefix *generator_* for each configuration
                      parameter.
                    - To update the parent model configuration, do not use a prefix for each configuration parameter.
                    - To use an already instantiated question encoder or generator, pass it as `question_encoder_model`
                      or `generator_model` respectively (`generator_generator` is also accepted for backward
                      compatibility).

                Behaves differently depending on whether a `config` is provided or automatically loaded.

        Example:

        ```python
        >>> from transformers import RagRetriever, TFRagModel

        >>> # initialize a RAG from two pretrained models.
        >>> model = TFRagModel.from_pretrained_question_encoder_generator(
        ...     "facebook/dpr-question_encoder-single-nq-base", "t5-small"
        ... )
        >>> # alternatively, initialize from pytorch pretrained models can also be done
        >>> model = TFRagModel.from_pretrained_question_encoder_generator(
        ...     "facebook/dpr-question_encoder-single-nq-base",
        ...     "facebook/bart-base",
        ...     generator_from_pt=True,
        ...     question_encoder_from_pt=True,
        ... )

        >>> # saving model after fine-tuning
        >>> model.save_pretrained("./rag")

        >>> # load retriever
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # load fine-tuned model with retriever
        >>> model = TFRagModel.from_pretrained("./rag", retriever=retriever)
        ```"""
        # Move the prefixed kwargs out of `kwargs`; whatever remains configures the parent RAG config.
        kwargs_question_encoder = cls._pop_prefixed_kwargs(kwargs, "question_encoder_")
        kwargs_generator = cls._pop_prefixed_kwargs(kwargs, "generator_")

        # Question encoder: reuse an instance passed as `question_encoder_model=...`, otherwise load it.
        question_encoder = kwargs_question_encoder.pop("model", None)
        if question_encoder is None:
            assert question_encoder_pretrained_model_name_or_path is not None, (
                "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to"
                " be defined"
            )
            from ..auto.modeling_tf_auto import TFAutoModel

            if "config" not in kwargs_question_encoder:
                from ..auto.configuration_auto import AutoConfig

                question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path)
                kwargs_question_encoder["config"] = question_encoder_config

            question_encoder = TFAutoModel.from_pretrained(
                question_encoder_pretrained_model_name_or_path,
                *model_args,
                name="question_encoder",
                load_weight_prefix=cls.load_weight_prefix,
                **kwargs_question_encoder,
            )

        # Generator: reuse an instance passed as `generator_model=...` (the same convention as the question
        # encoder, and the one the assertion message below names). The `generator_generator` key read by earlier
        # versions of this method is still honoured so existing callers keep working.
        generator = kwargs_generator.pop("model", None)
        legacy_generator = kwargs_generator.pop("generator", None)
        if generator is None:
            generator = legacy_generator
        if generator is None:
            assert generator_pretrained_model_name_or_path is not None, (
                "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has"
                " to be defined"
            )
            from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM

            if "config" not in kwargs_generator:
                from ..auto.configuration_auto import AutoConfig

                generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path)
                kwargs_generator["config"] = generator_config

            generator = TFAutoModelForSeq2SeqLM.from_pretrained(
                generator_pretrained_model_name_or_path,
                name="generator",
                load_weight_prefix=cls.load_weight_prefix,
                **kwargs_generator,
            )

        # Instantiate the RAG config from the sub-model configs unless one was supplied.
        # `pop` (not `get`) keeps an explicit `config=None` from being forwarded into the RagConfig kwargs.
        config = kwargs.pop("config", None)
        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )

        return cls(question_encoder=question_encoder, generator=generator, config=config, retriever=retriever)
# Introductory documentation shared by the TF RAG model classes (attached via the docstring decorators, e.g. on
# `TFRagModel`).
RAG_START_DOCSTRING = r"""

    RAG is a sequence-to-sequence model which encapsulates two core components: a question encoder and a generator.
    During a forward pass, we encode the input with the question encoder and pass it to the retriever to extract
    relevant context documents. The documents are then prepended to the input. Such contextualized inputs are passed to
    the generator.

    The question encoder can be any *autoencoding* model, preferably [`TFDPRQuestionEncoder`], and the generator can be
    any *seq2seq* model, preferably [`TFBartForConditionalGeneration`].

    The model can be initialized with a [`RagRetriever`] for end-to-end generation or used in combination with the
    outputs of a retriever in multiple steps---see examples for more details. The model is compatible with any
    *autoencoding* model as the `question_encoder` and any *seq2seq* model with language model head as the `generator`.
    It has been tested with [`TFDPRQuestionEncoder`] as the `question_encoder` and [`TFBartForConditionalGeneration`]
    as the `generator`.

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Tensorflow [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
    subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to
    general usage and behavior.

    The model is still in a developing state: it currently fully supports eager mode only, and may not be exported in
    SavedModel format.

    Args:
        config ([`RagConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
        question_encoder ([`TFPreTrainedModel`]):
            An encoder model compatible with the faiss index encapsulated by the `retriever`.
        generator ([`TFPreTrainedModel`]):
            A seq2seq model used as the generator in the RAG architecture.
        retriever ([`RagRetriever`]):
            A retriever class encapsulating a faiss index queried to obtain context documents for current inputs.
"""
# Argument documentation for the RAG forward pass (attached to `TFRagModel.call` via
# `add_start_docstrings_to_model_forward`).
RAG_FORWARD_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
            which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
            obtain the indices.
        attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        encoder_outputs (`tuple(tuple(tf.Tensor))`, *optional*):
            Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
            *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
            sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
            generator's encoder.

            Used by the ([`TFRagModel`]) model during decoding.
        decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
            you're using with your RAG instance.
        decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        past_key_values (`tuple(tuple(tf.Tensor))`):
            Tuple consists of two elements: `encoder_outputs` of the RAG model (see `encoder_outputs`) and
            `past_key_values` of the underlying generator. Can be used to speed up decoding. `past_key_values` are used
            in the ([`RagTokenForGeneration`]) model during decoding.
        doc_scores (`tf.Tensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever`, `doc_scores`
            has to be provided to the forward pass. `doc_scores` can be computed via
            `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
        context_input_ids (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever.

            If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to the
            forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
        context_attention_mask (`tf.Tensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever.

            If the model is not initialized with a `retriever`, `context_attention_mask` has to be provided to the
            forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_retrieved (`bool`, *optional*):
            Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask`. See returned tensors for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`TFRetrievAugLMOutput`] instead of a plain tuple.
        n_docs (`int`, *optional*, defaults to `config.n_docs`):
            Number of documents to retrieve and/or number of documents for which to generate an answer.
"""
@add_start_docstrings_to_model_forward(RAG_START_DOCSTRING)
class TFRagModel(TFRagPreTrainedModel):
    # Core RAG model: a question encoder, an optional retriever and a seq2seq generator.
    # (No class docstring on purpose: the decorator above builds the class documentation.)

    load_weight_prefix = "tf_rag_model_1"

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        question_encoder: Optional[TFPreTrainedModel] = None,
        generator: Optional[TFPreTrainedModel] = None,
        retriever: Optional = None,
        load_weight_prefix: Optional[str] = None,
        **kwargs,
    ):
        """
        Build the RAG model from a `RagConfig`, from explicit sub-models, or from both.

        Args:
            config: RAG configuration. When `None`, it is derived from `question_encoder.config` and
                `generator.config`, which must then both be given.
            question_encoder: Question encoder. Built from `config.question_encoder` with `TFAutoModel`
                when `None`.
            generator: Seq2seq generator. Built from `config.generator` with `TFAutoModelForSeq2SeqLM`
                when `None`.
            retriever: Optional `RagRetriever`. Without one, `call` expects `context_input_ids`,
                `context_attention_mask` and `doc_scores` from the caller.
            load_weight_prefix: Weight-name prefix for an auto-built generator (it is named
                `<prefix>/generator`). Defaults to the class attribute `load_weight_prefix`.
            kwargs: Forwarded to `RagConfig.from_question_encoder_generator_configs` (only when `config`
                is `None`) and to the parent constructor.
        """
        assert config is not None or (
            question_encoder is not None and generator is not None
        ), "Either a configuration or an question_encoder and a generator has to be provided."

        if config is None:
            config = RagConfig.from_question_encoder_generator_configs(
                question_encoder.config, generator.config, **kwargs
            )
        else:
            assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
        super().__init__(config, **kwargs)

        if question_encoder is None:
            from ..auto.modeling_tf_auto import TFAutoModel

            question_encoder = TFAutoModel.from_config(config.question_encoder, name="question_encoder")

        if generator is None:
            from ..auto.modeling_tf_auto import TFAutoModelForSeq2SeqLM

            load_weight_prefix = load_weight_prefix if load_weight_prefix is not None else self.load_weight_prefix
            generator = TFAutoModelForSeq2SeqLM.from_config(
                config.generator, name="generator", load_weight_prefix=load_weight_prefix + "/generator"
            )

        # Validate the retriever type once; the attribute is assigned a single time.
        self.retriever = retriever
        if self.retriever is not None:
            assert isinstance(
                retriever, RagRetriever
            ), f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`"

        self.question_encoder = question_encoder
        self.generator = generator

    def set_retriever(self, retriever: RagRetriever):
        """Attach (or replace) the retriever used by `call` when no retrieved context is passed in."""
        self.retriever = retriever

    @unpack_inputs
    @add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFRetrievAugLMOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        doc_scores=None,
        context_input_ids=None,
        context_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        output_retrieved=None,
        n_docs=None,
        return_dict=None,
        training=False,
        **kwargs
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import RagTokenizer, RagRetriever, TFRagModel

        >>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = TFRagModel.from_pretrained("facebook/rag-token-base", retriever=retriever, from_pt=True)

        >>> input_dict = tokenizer.prepare_seq2seq_batch(
        ...     "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf"
        ... )
        >>> input_ids = input_dict["input_ids"]
        >>> outputs = model(input_ids)
        ```"""
        assert (
            "decoder_cached_states" not in kwargs
        ), "Please use past_key_values to cache intermediate outputs"  # from modeling_tf_bart.py

        # aliasing to minimize code changing
        n_docs = n_docs if n_docs is not None else self.config.n_docs

        # Retrieve only when a retriever is set, the caller did not supply the full retrieved context
        # (ids, mask and scores), and no pre-computed encoder outputs were passed.
        has_to_retrieve = (
            self.retriever is not None
            and (context_input_ids is None or context_attention_mask is None or doc_scores is None)
            and encoder_outputs is None
        )

        # encoder_outputs are pre-computed during RAG-token generation
        if encoder_outputs is None:
            if has_to_retrieve:
                # Forward the output flags so `question_enc_hidden_states` / `question_enc_attentions`
                # can actually be populated when requested.
                question_enc_outputs = self.question_encoder(
                    input_ids,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=True,
                    training=training,
                )
                # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/dpr/modeling_tf_dpr.py#L91
                # hidden states of question encoder => pooler_output
                question_encoder_last_hidden_state = question_enc_outputs[0]

                # The retriever receives a NumPy copy of the question embeddings (`.numpy()` needs eager mode).
                retriever_outputs = self.retriever(
                    input_ids,
                    question_encoder_last_hidden_state.numpy(),
                    prefix=self.generator.config.prefix,
                    n_docs=n_docs,
                    return_tensors="tf",
                )
                context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
                    retriever_outputs["context_input_ids"],
                    retriever_outputs["context_attention_mask"],
                    retriever_outputs["retrieved_doc_embeds"],
                    retriever_outputs["doc_ids"],
                )

                context_input_ids = tf.cast(context_input_ids, tf.int32)
                context_attention_mask = tf.cast(context_attention_mask, tf.int32)
                retrieved_doc_embeds = tf.cast(retrieved_doc_embeds, tf.float32)
                retrieved_doc_ids = tf.cast(retrieved_doc_ids, tf.int32)

                # compute doc_scores: inner product between each question embedding and its retrieved
                # document embeddings.
                doc_scores = tf.squeeze(
                    tf.matmul(
                        tf.expand_dims(question_encoder_last_hidden_state, axis=1),
                        retrieved_doc_embeds,
                        transpose_b=True,
                    ),
                    axis=1,
                )

            else:
                assert context_input_ids is not None, (
                    "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
                    " set a retriever using the `set_retriever(...)` function."
                )
                assert context_attention_mask is not None, (
                    "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you"
                    " can set a retriever using the `set_retriever(...)` function."
                )
                assert doc_scores is not None, (
                    "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a"
                    " retriever using the `set_retriever(...)` function."
                )

        assert (
            doc_scores is not None
        ), "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."

        # The message reports the dimension that is actually checked. It must not touch
        # `context_input_ids`, which is `None` when only `encoder_outputs` are supplied.
        assert (doc_scores.shape[1] % n_docs) == 0, (
            f" The second dimension of `doc_scores` should be a multiple of `n_docs`={n_docs}, but is"
            f" {doc_scores.shape[1]}."
        )

        # Decoder input without context documents: repeat each row once per retrieved document.
        if decoder_input_ids is not None:
            decoder_input_ids = tf.repeat(decoder_input_ids, n_docs, axis=0)

        if decoder_attention_mask is not None:
            decoder_attention_mask = tf.repeat(decoder_attention_mask, n_docs, axis=0)

        gen_outputs = self.generator(
            context_input_ids,
            attention_mask=context_attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            training=training,
        )

        if not has_to_retrieve:
            question_encoder_last_hidden_state = None
            question_enc_hidden_states = None
            question_enc_attentions = None
            retrieved_doc_embeds = None
            retrieved_doc_ids = None
        else:
            question_enc_hidden_states = question_enc_outputs.hidden_states
            question_enc_attentions = question_enc_outputs.attentions

        if not has_to_retrieve or not output_retrieved:
            # don't output retrieved docs (plain `None`, not a tuple, like the other fields)
            context_input_ids = None
            context_attention_mask = None
            retrieved_doc_embeds = None
            retrieved_doc_ids = None

        return TFRetrievAugLMOutput(
            logits=gen_outputs.logits,
            doc_scores=doc_scores,
            past_key_values=gen_outputs.past_key_values,
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask,
            retrieved_doc_embeds=retrieved_doc_embeds,
            retrieved_doc_ids=retrieved_doc_ids,
            question_encoder_last_hidden_state=question_encoder_last_hidden_state,
            question_enc_hidden_states=question_enc_hidden_states,
            question_enc_attentions=question_enc_attentions,
            generator_enc_last_hidden_state=gen_outputs.encoder_last_hidden_state,
            generator_enc_hidden_states=gen_outputs.encoder_hidden_states,
            generator_enc_attentions=gen_outputs.encoder_attentions,
            generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
            generator_dec_attentions=gen_outputs.decoder_attentions,
        )
@add_start_docstrings_to_model_forward(
"""
A TF RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
""",
RAG_START_DOCSTRING,
)
class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss):
load_weight_prefix = "tf_rag_token_for_generation_1/rag"
def __init__(
    self,
    config: Optional[PretrainedConfig] = None,
    question_encoder: Optional[TFPreTrainedModel] = None,
    generator: Optional[TFPreTrainedModel] = None,
    retriever: Optional = None,
    **kwargs,
):
    """
    Wrap a `TFRagModel` (stored as `self.rag`) for RAG-token generation.

    Args:
        config: RAG configuration. When `None`, it is derived from the configs of `question_encoder` and
            `generator`, which must then both be provided.
        question_encoder: Optional question encoder, handed to the wrapped `TFRagModel`.
        generator: Optional seq2seq generator, handed to the wrapped `TFRagModel`.
        retriever: Optional retriever, handed to the wrapped `TFRagModel`.
        kwargs: Only used to build the config when `config` is `None`.
    """
    sub_models_given = question_encoder is not None and generator is not None
    assert config is not None or sub_models_given, "Either a configuration or an encoder and a generator has to be provided."

    if config is None:
        config = RagConfig.from_question_encoder_generator_configs(question_encoder.config, generator.config, **kwargs)
    super().__init__(config)

    # The class-level weight prefix is forwarded so that, when TFRagModel builds the generator from the
    # config, the generator is named `<prefix>/generator`.
    self.rag = TFRagModel(
        name="rag",
        config=config,
        question_encoder=question_encoder,
        generator=generator,
        retriever=retriever,
        load_weight_prefix=self.load_weight_prefix,
    )
def set_retriever(self, retriever: RagRetriever):
    """Attach (or replace) the retriever on the wrapped `TFRagModel`."""
    self.rag.set_retriever(retriever)
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_bart.py
def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    past=None,
    attention_mask=None,
    use_cache=None,
    encoder_outputs=None,
    doc_scores=None,
    n_docs=None,
    **kwargs
):
    """
    Assemble the keyword arguments for one decoding step of `generate()`.

    `input_ids` is always `None` because the pre-computed `encoder_outputs` are reused.
    `attention_mask` is passed on as `context_attention_mask`, and marginalization over documents is
    always switched on. Extra `kwargs` are ignored.
    """
    # With a cache (`past`), earlier positions are already encoded; feed only the newest token.
    step_decoder_ids = decoder_input_ids if past is None else decoder_input_ids[:, -1:]

    model_inputs = dict(
        input_ids=None,
        encoder_outputs=encoder_outputs,
        doc_scores=doc_scores,
        context_attention_mask=attention_mask,
        decoder_input_ids=step_decoder_ids,
        past_key_values=past,
        use_cache=use_cache,
        do_marginalize=True,
        n_docs=n_docs,
    )
    return model_inputs
@property
def retriever(self):
    """The retriever held by the wrapped `TFRagModel` (`None` if none was set)."""
    rag_model = self.rag
    return rag_model.retriever
@property
def generator(self):
    """The seq2seq generator held by the wrapped `TFRagModel`."""
    rag_model = self.rag
    return rag_model.generator
@property
def question_encoder(self):
    """The question encoder held by the wrapped `TFRagModel`."""
    rag_model = self.rag
    return rag_model.question_encoder
@staticmethod
def _reorder_cache(past, beam_idx):
    """
    Reorder cached states to follow the beams selected at this step.

    BART-inspired, but the leading axis of every cached tensor holds several consecutive rows per beam
    (one per document). Those rows are therefore moved as a group.
    """

    def _reorder_stacked(stacked_states, new_order):
        # Rows per beam, inferred from the ratio of leading-axis sizes.
        rows_per_beam = stacked_states.shape[0] // new_order.shape[0]
        grouped = tf.reshape(stacked_states, (-1, rows_per_beam, *stacked_states.shape[1:]))
        grouped = tf.gather(grouped, new_order, axis=0)
        return tf.reshape(grouped, (-1, *grouped.shape[2:]))

    # Same nesting as `past`: one tuple of reordered states per decoder layer.
    return tuple(
        tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past) for layer_past in past
    )
def marginalize(self, seq_logits, doc_scores, n_docs=None):
    """
    RAG-token marginalization of per-document token distributions.

    Adds the log-softmax of `doc_scores` (over axis 1) to the log-softmax of `seq_logits`, then takes
    `logsumexp` over the document axis.

    Args:
        seq_logits: Generator logits. Assumed to be (batch * n_docs, seq_len, vocab) — TODO confirm.
        doc_scores: Document scores, log-softmaxed over axis 1.
        n_docs: Documents per example; defaults to `config.n_docs`.
    """
    if n_docs is None:
        n_docs = self.config.n_docs

    batch_size = seq_logits.shape[0] // n_docs
    vocab_size = seq_logits.shape[-1]
    # Split the leading axis into (batch, n_docs); the remaining middle axis is inferred.
    token_logprobs = tf.reshape(tf.nn.log_softmax(seq_logits, axis=-1), [batch_size, n_docs, -1, vocab_size])

    # Two trailing unit axes let the document log-probs broadcast over the last two axes.
    doc_logprobs = tf.nn.log_softmax(doc_scores, axis=1)[..., tf.newaxis, tf.newaxis]

    return tf.reduce_logsumexp(token_logprobs + doc_logprobs, axis=1)
@unpack_inputs
@add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFRetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
def call(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    doc_scores=None,
    context_input_ids=None,
    context_attention_mask=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    output_retrieved=None,
    n_docs=None,
    do_marginalize=None,
    labels=None,
    reduce_loss=None,
    return_dict=None,
    training=False,
    **kwargs  # needs kwargs for generation
):
    r"""
    do_marginalize (`bool`, *optional*, defaults to `config.do_marginalize`):
        If `True`, the logits are marginalized over all documents by making use of
        `tf.nn.log_softmax` and `tf.reduce_logsumexp`.
    labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the cross entropy classification loss according to Rag-Token model formulation See
        https://arxiv.org/pdf/2005.11401.pdf Section 2.1 for details about Rag-Token formulation. Indices should be
        in `[0, ..., config.vocab_size - 1]`.
    reduce_loss (`bool`, *optional*, defaults to `config.reduce_loss`):
        Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `tf.reduce_sum`
        operation.
    kwargs (`Dict[str, any]`, optional, defaults to *{}*):
        Legacy dictionary, which is required so that model can use *generate()* function.

    Returns:

    Example:

    ```python
    >>> import tensorflow as tf
    >>> from transformers import RagTokenizer, RagRetriever, TFRagTokenForGeneration

    >>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    >>> retriever = RagRetriever.from_pretrained(
    ...     "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
    ... )
    >>> # initialize with RagRetriever to do everything in one forward call
    >>> model = TFRagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever, from_pt=True)

    >>> input_dict = tokenizer.prepare_seq2seq_batch(
    ...     "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf"
    ... )
    >>> outputs = model(input_dict, output_retrieved=True)

    >>> # or use retriever separately
    >>> # 1. Encode
    >>> input_ids = input_dict["input_ids"]
    >>> question_hidden_states = model.question_encoder(input_ids)[0]
    >>> # 2. Retrieve
    >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.numpy(), return_tensors="tf")
    >>> doc_scores = tf.squeeze(
    ...     tf.matmul(
    ...         tf.expand_dims(question_hidden_states, axis=1), docs_dict["retrieved_doc_embeds"], transpose_b=True
    ...     ),
    ...     axis=1,
    ... )
    >>> # 3. Forward to generator
    >>> outputs = model(
    ...     input_ids=None,
    ...     context_input_ids=docs_dict["context_input_ids"],
    ...     context_attention_mask=docs_dict["context_attention_mask"],
    ...     doc_scores=doc_scores,
    ...     decoder_input_ids=input_dict["labels"],
    ... )

    >>> # or directly generate
    >>> generated = model.generate(
    ...     context_input_ids=docs_dict["context_input_ids"],
    ...     context_attention_mask=docs_dict["context_attention_mask"],
    ...     doc_scores=doc_scores,
    ... )
    >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
    ```"""
    assert (
        "decoder_cached_states" not in kwargs
    ), "Please use past_key_values to cache intermediate outputs"  # from modeling_tf_bart.py

    # Fall back to the config only when the caller passed nothing: an explicit `False` must not be
    # overridden by a `True` config value.
    do_marginalize = do_marginalize if do_marginalize is not None else self.config.do_marginalize
    reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss

    if labels is not None:
        # Teacher forcing on the labels; caching is turned off when labels are given.
        if decoder_input_ids is None:
            decoder_input_ids = labels
        use_cache = False

    outputs = self.rag(
        input_ids,
        attention_mask=attention_mask,
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        context_input_ids=context_input_ids,
        context_attention_mask=context_attention_mask,
        doc_scores=doc_scores,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_retrieved=output_retrieved,
        n_docs=n_docs,
        training=training,
    )

    loss = None
    logits = outputs.logits
    if labels is not None:
        assert decoder_input_ids is not None
        # The loss uses the raw (per-document) logits, before any marginalization below.
        loss = self.get_nll(
            outputs.logits,
            outputs.doc_scores,
            labels,
            reduce_loss=reduce_loss,
            epsilon=self.config.label_smoothing,
            n_docs=n_docs,
        )

    if do_marginalize:
        logits = self.marginalize(logits, outputs.doc_scores, n_docs)

    return TFRetrievAugLMMarginOutput(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        doc_scores=outputs.doc_scores,
        context_input_ids=outputs.context_input_ids,
        context_attention_mask=outputs.context_attention_mask,
        retrieved_doc_embeds=outputs.retrieved_doc_embeds,
        retrieved_doc_ids=outputs.retrieved_doc_ids,
        question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
        question_enc_hidden_states=outputs.question_enc_hidden_states,
        question_enc_attentions=outputs.question_enc_attentions,
        generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state,
        generator_enc_hidden_states=outputs.generator_enc_hidden_states,
        generator_enc_attentions=outputs.generator_enc_attentions,
        generator_dec_hidden_states=outputs.generator_dec_hidden_states,
        generator_dec_attentions=outputs.generator_dec_attentions,
    )
def generate(
self,
input_ids: Optional[tf.Tensor] = None,
attention_mask: Optional[tf.Tensor] = None,
context_input_ids=None,
context_attention_mask=None,
doc_scores=None,
max_length=None,
min_length=None,
early_stopping=None,
use_cache=None,
num_beams=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
bad_words_ids=None,
num_return_sequences=None,
decoder_start_token_id=None,