-
Notifications
You must be signed in to change notification settings - Fork 26.4k
/
modeling_tf_longformer.py
2686 lines (2274 loc) · 124 KB
/
modeling_tf_longformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Longformer model."""
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFModelInputType,
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from .configuration_longformer import LongformerConfig
# Module-level logger named after this module.
logger = logging.get_logger(__name__)

# Checkpoint and config names kept as module constants.
# NOTE(review): presumably consumed by the docstring decorators imported above — confirm at the call sites.
_CHECKPOINT_FOR_DOC = "allenai/longformer-base-4096"
_CONFIG_FOR_DOC = "LongformerConfig"

# Large negative float constant.
# NOTE(review): its use is outside this view — presumably an additive value for masking scores before softmax; confirm.
LARGE_NEGATIVE = -1e8

# Identifiers of the pretrained Longformer checkpoints listed for this module.
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "allenai/longformer-base-4096",
    "allenai/longformer-large-4096",
    "allenai/longformer-large-4096-finetuned-triviaqa",
    "allenai/longformer-base-4096-extra.pos.embd.only",
    "allenai/longformer-large-4096-extra.pos.embd.only",
    # See all Longformer models at https://huggingface.co/models?filter=longformer
]
@dataclass
class TFLongformerBaseModelOutput(ModelOutput):
    """
    Base class for Longformer's outputs, with potential hidden states, local and global attentions.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining
            `attention_window + 1` values). Note that the first `x` values refer to tokens with fixed positions in the
            text, but the remaining `attention_window + 1` values refer to tokens with relative positions: the
            attention weight of a token to itself is located at index `x + attention_window / 2` and the
            `attention_window / 2` preceding (succeeding) values are the attention weights to the
            `attention_window / 2` preceding (succeeding) tokens. If the attention window contains a token with global
            attention, the attention weight at the corresponding index is set to 0; the value should be accessed from
            the first `x` attention weights. If a token has global attention, the attention weights to all other
            tokens in `attentions` are set to 0; the values should be accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    last_hidden_state: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    global_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFLongformerBaseModelOutputWithPooling(ModelOutput):
    """
    Base class for Longformer's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
            Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
            prediction (classification) objective during pretraining.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining
            `attention_window + 1` values). Note that the first `x` values refer to tokens with fixed positions in the
            text, but the remaining `attention_window + 1` values refer to tokens with relative positions: the
            attention weight of a token to itself is located at index `x + attention_window / 2` and the
            `attention_window / 2` preceding (succeeding) values are the attention weights to the
            `attention_window / 2` preceding (succeeding) tokens. If the attention window contains a token with global
            attention, the attention weight at the corresponding index is set to 0; the value should be accessed from
            the first `x` attention weights. If a token has global attention, the attention weights to all other
            tokens in `attentions` are set to 0; the values should be accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    last_hidden_state: tf.Tensor = None
    pooler_output: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    global_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFLongformerMaskedLMOutput(ModelOutput):
    """
    Base class for masked language models outputs.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Masked language modeling (MLM) loss.
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining
            `attention_window + 1` values). Note that the first `x` values refer to tokens with fixed positions in the
            text, but the remaining `attention_window + 1` values refer to tokens with relative positions: the
            attention weight of a token to itself is located at index `x + attention_window / 2` and the
            `attention_window / 2` preceding (succeeding) values are the attention weights to the
            `attention_window / 2` preceding (succeeding) tokens. If the attention window contains a token with global
            attention, the attention weight at the corresponding index is set to 0; the value should be accessed from
            the first `x` attention weights. If a token has global attention, the attention weights to all other
            tokens in `attentions` are set to 0; the values should be accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    loss: Optional[tf.Tensor] = None
    logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    global_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFLongformerQuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of question answering Longformer models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining
            `attention_window + 1` values). Note that the first `x` values refer to tokens with fixed positions in the
            text, but the remaining `attention_window + 1` values refer to tokens with relative positions: the
            attention weight of a token to itself is located at index `x + attention_window / 2` and the
            `attention_window / 2` preceding (succeeding) values are the attention weights to the
            `attention_window / 2` preceding (succeeding) tokens. If the attention window contains a token with global
            attention, the attention weight at the corresponding index is set to 0; the value should be accessed from
            the first `x` attention weights. If a token has global attention, the attention weights to all other
            tokens in `attentions` are set to 0; the values should be accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    loss: Optional[tf.Tensor] = None
    start_logits: tf.Tensor = None
    end_logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    global_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFLongformerSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sentence classification models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining
            `attention_window + 1` values). Note that the first `x` values refer to tokens with fixed positions in the
            text, but the remaining `attention_window + 1` values refer to tokens with relative positions: the
            attention weight of a token to itself is located at index `x + attention_window / 2` and the
            `attention_window / 2` preceding (succeeding) values are the attention weights to the
            `attention_window / 2` preceding (succeeding) tokens. If the attention window contains a token with global
            attention, the attention weight at the corresponding index is set to 0; the value should be accessed from
            the first `x` attention weights. If a token has global attention, the attention weights to all other
            tokens in `attentions` are set to 0; the values should be accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    loss: Optional[tf.Tensor] = None
    logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    global_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFLongformerMultipleChoiceModelOutput(ModelOutput):
    """
    Base class for outputs of multiple choice models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
            *num_choices* is the second dimension of the input tensors. (see *input_ids* above).

            Classification scores (before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining
            `attention_window + 1` values). Note that the first `x` values refer to tokens with fixed positions in the
            text, but the remaining `attention_window + 1` values refer to tokens with relative positions: the
            attention weight of a token to itself is located at index `x + attention_window / 2` and the
            `attention_window / 2` preceding (succeeding) values are the attention weights to the
            `attention_window / 2` preceding (succeeding) tokens. If the attention window contains a token with global
            attention, the attention weight at the corresponding index is set to 0; the value should be accessed from
            the first `x` attention weights. If a token has global attention, the attention weights to all other
            tokens in `attentions` are set to 0; the values should be accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    loss: Optional[tf.Tensor] = None
    logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    global_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFLongformerTokenClassifierOutput(ModelOutput):
    """
    Base class for outputs of token classification models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining
            `attention_window + 1` values). Note that the first `x` values refer to tokens with fixed positions in the
            text, but the remaining `attention_window + 1` values refer to tokens with relative positions: the
            attention weight of a token to itself is located at index `x + attention_window / 2` and the
            `attention_window / 2` preceding (succeeding) values are the attention weights to the
            `attention_window / 2` preceding (succeeding) tokens. If the attention window contains a token with global
            attention, the attention weight at the corresponding index is set to 0; the value should be accessed from
            the first `x` attention weights. If a token has global attention, the attention weights to all other
            tokens in `attentions` are set to 0; the values should be accessed from `global_attentions`.
        global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
            is the number of tokens with global attention mask.

            Global attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    loss: Optional[tf.Tensor] = None
    logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    global_attentions: Optional[Tuple[tf.Tensor]] = None
def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True):
    """
    Build a 0/1 global attention mask relative to the first separator token of every row.

    With `before_sep_token is True`, positions strictly before the first separator are set to 1. Otherwise,
    positions strictly after `first separator + 1` are set to 1.

    Args:
        input_ids_shape: shape of the input ids; entry 0 is used as the batch size and entry 1 as the sequence length.
        sep_token_indices: index pairs of the separator tokens. They are reshaped to `(batch_size, 3, 2)`, so
            exactly three separators per row are expected.
        before_sep_token: selects the side of the separator that receives global attention. Only the literal `True`
            selects the "before" side.

    Returns:
        A `(batch_size, sequence_length)` tensor of zeros and ones with the dtype of `sep_token_indices`.
    """
    assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"

    batch_size, seq_len = input_ids_shape[0], input_ids_shape[1]

    # column index of the first separator of each row, kept as a `(batch_size, 1)` column
    per_row_separators = tf.reshape(sep_token_indices, (batch_size, 3, 2))
    first_sep_column = tf.expand_dims(per_row_separators[:, 0, 1], axis=1)

    # position index of every token, repeated for every row of the batch
    token_positions = tf.tile(tf.expand_dims(tf.range(seq_len, dtype=tf.int64), axis=0), (batch_size, 1))

    if before_sep_token is True:
        boundary = tf.tile(first_sep_column, (1, seq_len))
        return tf.cast(token_positions < boundary, dtype=boundary.dtype)

    # last token is separation token and should not be counted and in the middle are two separation tokens
    boundary = tf.tile(first_sep_column + 1, (1, seq_len))
    after_boundary = tf.cast(token_positions > boundary, dtype=boundary.dtype)
    inside_sequence = tf.cast(token_positions < input_ids_shape[-1], dtype=boundary.dtype)
    return after_boundary * inside_sequence
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->Longformer
class TFLongformerLMHead(tf.keras.layers.Layer):
    """
    Longformer Head for masked language modeling.

    Applies a dense layer, a GELU activation and a layer normalization, then projects onto the vocabulary with the
    weight matrix of `input_embeddings` plus a bias owned by this layer.
    """

    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.hidden_size = config.hidden_size
        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        self.act = get_tf_activation("gelu")

        # The decoder reuses the input embedding layer, so only the per-token
        # output bias (created in `build`) is specific to this head.
        self.decoder = input_embeddings

    def build(self, input_shape):
        """Create the output bias, one entry per vocabulary token, initialized to zeros."""
        self.bias = self.add_weight(
            shape=(self.config.vocab_size,),
            initializer="zeros",
            trainable=True,
            name="bias",
        )

        super().build(input_shape)

    def get_output_embeddings(self):
        """Return the layer whose `weight` is used for the output projection."""
        return self.decoder

    def set_output_embeddings(self, value):
        """Replace the projection weight and update the decoder's `vocab_size` from its first dimension."""
        self.decoder.weight = value
        self.decoder.vocab_size = shape_list(value)[0]

    def get_bias(self):
        """Return the output bias wrapped in a dict under the key `"bias"`."""
        return {"bias": self.bias}

    def set_bias(self, value):
        """Replace the output bias with `value["bias"]` and update `config.vocab_size` from its length."""
        new_bias = value["bias"]
        self.bias = new_bias
        self.config.vocab_size = shape_list(new_bias)[0]

    def call(self, hidden_states):
        """Map hidden states to vocabulary scores, reshaped to `[-1, seq_length, vocab_size]`."""
        transformed = self.layer_norm(self.act(self.dense(hidden_states)))

        # project back to size of vocabulary with bias
        seq_length = shape_list(tensor=transformed)[1]
        flattened = tf.reshape(tensor=transformed, shape=[-1, self.hidden_size])
        scores = tf.matmul(a=flattened, b=self.decoder.weight, transpose_b=True)
        scores = tf.reshape(tensor=scores, shape=[-1, seq_length, self.config.vocab_size])

        return tf.nn.bias_add(value=scores, bias=self.bias)
class TFLongformerEmbeddings(tf.keras.layers.Layer):
    """
    Sum of word, position and token type embeddings, followed by layer normalization and dropout.

    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing and some extra casting.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.padding_idx = 1
        self.hidden_size = config.hidden_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape: tf.TensorShape):
        """Create the word, token type and position embedding matrices, each under its own name scope."""
        # Each weight gets its own initializer instance, exactly one `get_initializer` call per matrix.
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        super().build(input_shape)

    def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
        """
        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
        symbols are ignored. This is modified from fairseq's `utils.make_positions`.

        Args:
            input_ids: tf.Tensor
            past_key_values_length: offset added to the running count of non-padding tokens.

        Returns: tf.Tensor
        """
        non_padding = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
        # running count of real tokens along axis 1; multiplying by the mask zeroes it again on padding positions
        running_count = tf.math.cumsum(non_padding, axis=1) + past_key_values_length
        masked_count = running_count * non_padding

        return masked_count + self.padding_idx

    def call(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
        training=False,
    ):
        """
        Applies embedding based on inputs tensor.

        At least one of `input_ids` and `inputs_embeds` must be given. When `input_ids` is given, the word embeddings
        are looked up from it and replace any `inputs_embeds` passed in.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        assert input_ids is not None or inputs_embeds is not None

        has_input_ids = input_ids is not None

        if has_input_ids:
            # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound
            # indices on GPU, returning zeros instead. This is a dangerous silent behavior.
            tf.debugging.assert_less(
                input_ids,
                tf.cast(self.config.vocab_size, dtype=input_ids.dtype),
                message=(
                    "input_ids must be smaller than the embedding layer's input dimension (got"
                    f" {tf.math.reduce_max(input_ids)} >= {self.config.vocab_size})"
                ),
            )
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        # every dimension of the embedded input except the trailing hidden one
        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.cast(tf.fill(dims=input_shape, value=0), tf.int64)

        if position_ids is None and has_input_ids:
            # Create the position ids from the input token ids. Any padded tokens remain padded.
            position_ids = self.create_position_ids_from_input_ids(
                input_ids=input_ids, past_key_values_length=past_key_values_length
            )
        elif position_ids is None:
            # Without token ids padding cannot be detected, so positions simply count up from padding_idx + 1.
            position_ids = tf.expand_dims(
                tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1, dtype=tf.int64),
                axis=0,
            )

        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)

        summed = inputs_embeds + position_embeds + token_type_embeds
        normalized = self.LayerNorm(inputs=summed)

        return self.dropout(inputs=normalized, training=training)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Longformer
class TFLongformerIntermediate(tf.keras.layers.Layer):
    """Dense projection to `config.intermediate_size` followed by the configured activation function."""

    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            units=config.intermediate_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )

        # `hidden_act` is either a name to look up or an object that is used as the activation directly.
        activation = config.hidden_act
        self.intermediate_act_fn = get_tf_activation(activation) if isinstance(activation, str) else activation

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        """Apply the dense projection and then the activation to `hidden_states`."""
        projected = self.dense(inputs=hidden_states)
        return self.intermediate_act_fn(projected)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Longformer
class TFLongformerOutput(tf.keras.layers.Layer):
    """Dense projection to `config.hidden_size`, dropout, then layer normalization of the residual sum."""

    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        """
        Project `hidden_states`, apply dropout, add `input_tensor` and layer-normalize the result.

        Args:
            hidden_states: tensor fed through the dense layer.
            input_tensor: tensor added to the projected states before normalization.
            training: forwarded to the dropout layer.
        """
        projected = self.dense(inputs=hidden_states)
        projected = self.dropout(inputs=projected, training=training)

        # The residual is added before normalization, not after.
        return self.LayerNorm(inputs=projected + input_tensor)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Longformer
class TFLongformerPooler(tf.keras.layers.Layer):
    """Pools a sequence by passing the hidden state of its first token through a tanh-activated dense layer."""

    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        kernel_init = get_initializer(config.initializer_range)
        self.dense = tf.keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=kernel_init,
            activation="tanh",
            name="dense",
        )

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        """Return the dense projection of `hidden_states[:, 0]`."""
        # "Pooling" here simply means keeping the hidden state that corresponds
        # to the first token.
        return self.dense(inputs=hidden_states[:, 0])
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Longformer
class TFLongformerSelfOutput(tf.keras.layers.Layer):
    """Dense projection to `config.hidden_size`, dropout, then layer normalization of the residual sum."""

    def __init__(self, config: LongformerConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        """
        Project `hidden_states`, apply dropout, add `input_tensor` and layer-normalize the result.

        Args:
            hidden_states: tensor fed through the dense layer.
            input_tensor: tensor added to the projected states before normalization.
            training: forwarded to the dropout layer.
        """
        dense_output = self.dense(inputs=hidden_states)
        dropped_output = self.dropout(inputs=dense_output, training=training)

        # Residual connection first, normalization second.
        residual_sum = dropped_output + input_tensor
        return self.LayerNorm(inputs=residual_sum)
class TFLongformerSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, layer_id, **kwargs):
    """
    Build the projection and dropout layers for one Longformer self-attention layer.

    Args:
        config: model configuration. Reads `hidden_size`, `num_attention_heads`, `initializer_range`,
            `attention_probs_dropout_prob` and `attention_window` (indexed with `layer_id`).
        layer_id: index of this layer, used to pick its entry from `config.attention_window`.
        **kwargs: forwarded to the parent layer constructor.

    Raises:
        ValueError: if `config.hidden_size` is not a multiple of `config.num_attention_heads`.
        AssertionError: if the attention window selected for this layer is not a positive even number.
    """
    super().__init__(**kwargs)

    if config.hidden_size % config.num_attention_heads != 0:
        raise ValueError(
            f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
            f"heads ({config.num_attention_heads})"
        )

    self.num_heads = config.num_attention_heads
    # Divisibility was checked above, so floor division is exact here.
    self.head_dim = config.hidden_size // config.num_attention_heads
    self.embed_dim = config.hidden_size

    def _projection(name):
        # Each projection gets its own freshly created initializer.
        return tf.keras.layers.Dense(
            self.embed_dim,
            kernel_initializer=get_initializer(config.initializer_range),
            name=name,
        )

    self.query = _projection("query")
    self.key = _projection("key")
    self.value = _projection("value")

    # separate projection layers for tokens with global attention
    self.query_global = _projection("query_global")
    self.key_global = _projection("key_global")
    self.value_global = _projection("value_global")

    self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
    self.global_dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

    self.layer_id = layer_id
    attention_window = config.attention_window[self.layer_id]

    assert (
        attention_window % 2 == 0
    ), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
    assert (
        attention_window > 0
    ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"

    # Half of the full window.
    self.one_sided_attn_window_size = attention_window // 2
def call(
    self,
    inputs,
    training=False,
):
    """
    Compute sliding-window self-attention, combined with global attention when `is_global_attn` is true.

    LongformerSelfAttention expects *len(hidden_states)* to be multiple of *attention_window*. Padding to
    *attention_window* happens in LongformerModel.forward to avoid redoing the padding on each layer.

    The *attention_mask* is changed in [`LongformerModel.forward`] from 0, 1, 2 to:

        - -10000: no attention
        - 0: local attention
        - +10000: global attention

    Args:
        inputs: 6-tuple `(hidden_states, attention_mask, layer_head_mask, is_index_masked,
            is_index_global_attn, is_global_attn)`. `layer_head_mask` may be None.
        training: forwarded to the attention dropout and to the global-attention computation.

    Returns:
        Tuple `(attn_output, attn_probs, global_attn_probs)`.
    """
    # unpack the bundled inputs
    (
        hidden_states,
        attention_mask,
        layer_head_mask,
        is_index_masked,
        is_index_global_attn,
        is_global_attn,
    ) = inputs

    window_overlap = self.one_sided_attn_window_size

    # linear projections for the local (sliding window) attention
    query_vectors = self.query(hidden_states)
    key_vectors = self.key(hidden_states)
    value_vectors = self.value(hidden_states)

    batch_size, seq_len, embed_dim = shape_list(hidden_states)
    tf.debugging.assert_equal(
        embed_dim,
        self.embed_dim,
        message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
    )
    per_head_shape = (batch_size, seq_len, self.num_heads, self.head_dim)

    # scale the queries by sqrt(head_dim), then split the embedding axis into heads
    query_scale = tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
    query_vectors = tf.reshape(query_vectors / query_scale, per_head_shape)
    key_vectors = tf.reshape(key_vectors, per_head_shape)

    # local scores = (batch_size, seq_len, num_heads, window*2+1)
    attn_scores = self._sliding_chunks_query_key_matmul(query_vectors, key_vectors, window_overlap)

    # positions with a non-zero mask entry get LARGE_NEGATIVE added to their local scores
    is_excluded_from_window = attention_mask != 0
    additive_mask = tf.cast(is_excluded_from_window, dtype=query_vectors.dtype) * LARGE_NEGATIVE

    # pushing the additive mask through the same chunked matmul (against ones) gives it
    # the same layout as the local scores
    diagonal_mask = self._sliding_chunks_query_key_matmul(
        tf.ones(shape_list(attention_mask)),
        additive_mask,
        window_overlap,
    )
    attn_scores = attn_scores + diagonal_mask

    tf.debugging.assert_equal(
        shape_list(attn_scores),
        [batch_size, seq_len, self.num_heads, window_overlap * 2 + 1],
        message=(
            f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
            f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
        ),
    )

    # index bookkeeping for global attention, reused by every global branch below
    (
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ) = self._get_global_attn_indices(is_index_global_attn)

    def _zero_out_tokens(probs, token_mask):
        # Expand a per-token boolean mask to the shape of the attention probabilities and
        # zero every selected entry. The last dimension depends on whether global attention is active:
        # if is_global_attn==True => window_overlap * 2 + max_num_global_attn_indices + 1
        # if is_global_attn==False => window_overlap * 2 + 1
        expanded_mask = tf.cond(
            is_global_attn,
            lambda: tf.tile(
                token_mask[:, :, None, None],
                (1, 1, self.num_heads, window_overlap * 2 + max_num_global_attn_indices + 1),
            ),
            lambda: tf.tile(
                token_mask[:, :, None, None],
                (1, 1, self.num_heads, window_overlap * 2 + 1),
            ),
        )
        return tf.where(
            expanded_mask,
            tf.zeros(shape_list(expanded_mask), dtype=probs.dtype),
            probs,
        )

    # only with global attention: extend the local scores with the global key scores
    attn_scores = tf.cond(
        is_global_attn,
        lambda: self._concat_with_global_key_attn_probs(
            attn_scores=attn_scores,
            query_vectors=query_vectors,
            key_vectors=key_vectors,
            max_num_global_attn_indices=max_num_global_attn_indices,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
        ),
        lambda: attn_scores,
    )

    attn_probs = stable_softmax(attn_scores, axis=-1)

    # softmax sometimes inserts NaN if all positions are masked, replace them with 0
    attn_probs = _zero_out_tokens(attn_probs, is_index_masked)

    if layer_head_mask is not None:
        tf.debugging.assert_equal(
            shape_list(layer_head_mask),
            [self.num_heads],
            message=(
                f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                f" {shape_list(layer_head_mask)}"
            ),
        )
        # broadcast the per-head mask over batch, sequence and attention columns
        attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

    # dropout on the attention probabilities
    attn_probs = self.dropout(attn_probs, training=training)

    value_vectors = tf.reshape(value_vectors, per_head_shape)

    # with global attention the output is computed from both global and local parts
    attn_output = tf.cond(
        is_global_attn,
        lambda: self._compute_attn_output_with_global_indices(
            value_vectors=value_vectors,
            attn_probs=attn_probs,
            max_num_global_attn_indices=max_num_global_attn_indices,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
        ),
        lambda: self._sliding_chunks_matmul_attn_probs_value(attn_probs, value_vectors, window_overlap),
    )

    tf.debugging.assert_equal(
        shape_list(attn_output),
        [batch_size, seq_len, self.num_heads, self.head_dim],
        message="Unexpected size",
    )

    # merge the heads back into a single embedding axis
    attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))

    # compute value for global attention and overwrite to attention output
    # TODO: remove the redundant computation
    attn_output, global_attn_probs = tf.cond(
        is_global_attn,
        lambda: self._compute_global_attn_output_from_hidden(
            attn_output=attn_output,
            hidden_states=hidden_states,
            max_num_global_attn_indices=max_num_global_attn_indices,
            layer_head_mask=layer_head_mask,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
            is_index_masked=is_index_masked,
            training=training,
        ),
        lambda: (attn_output, tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))),
    )

    # make sure that local attention probabilities are set to 0 for indices of global attn
    attn_probs = _zero_out_tokens(attn_probs, is_index_global_attn)

    return (attn_output, attn_probs, global_attn_probs)
def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
"""
Matrix multiplication of query and key tensors using with a sliding window attention pattern. This
implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an
overlap of size window_overlap
"""
batch_size, seq_len, num_heads, head_dim = shape_list(query)
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
)
tf.debugging.assert_equal(
shape_list(query),
shape_list(key),
message=(
f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
f" {shape_list(key)}"
),
)
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
query = tf.reshape(
tf.transpose(query, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len, head_dim),
)
key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
chunked_query = self._chunk(query, window_overlap)
chunked_key = self._chunk(key, window_overlap)
# matrix multiplication
# bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply
# convert diagonals into columns
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)
# allocate space for the overall attention matrix where the chunks are combined. The last dimension
# has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
# window_overlap previous words). The following column is attention score from each word to itself, then
# followed by window_overlap columns for the upper triangle.
# copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
# - copying the main diagonal and the upper triangle
# TODO: This code is most likely not very efficient and should be improved
diagonal_attn_scores_up_triang = tf.concat(
[
diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1],
diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1],
],
axis=1,
)
# - copying the lower triangle
diagonal_attn_scores_low_triang = tf.concat(
[
tf.zeros(
(batch_size * num_heads, 1, window_overlap, window_overlap),
dtype=diagonal_chunked_attention_scores.dtype,
),
diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :],
],
axis=1,
)
diagonal_attn_scores_first_chunk = tf.concat(
[
tf.roll(
diagonal_chunked_attention_scores,
shift=[1, window_overlap],
axis=[2, 3],
)[:, :, :window_overlap, :window_overlap],
tf.zeros(
(batch_size * num_heads, 1, window_overlap, window_overlap),
dtype=diagonal_chunked_attention_scores.dtype,
),
],
axis=1,
)
first_chunk_mask = (
tf.tile(