-
Notifications
You must be signed in to change notification settings - Fork 25.8k
/
modeling_tf_deberta_v2.py
1871 lines (1596 loc) · 79.2 KB
/
modeling_tf_deberta_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2021 Microsoft and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 DeBERTa-v2 model."""
from __future__ import annotations
from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFMaskedLMOutput,
TFMultipleChoiceModelOutput,
TFQuestionAnsweringModelOutput,
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFModelInputType,
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
keras,
unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_deberta_v2 import DebertaV2Config
# Module-level logger, keyed on this module's import name.
logger = logging.get_logger(__name__)
# Names used for documentation purposes; presumably consumed by the docstring decorators imported
# above (their use sites are outside this view -- confirm before renaming).
_CONFIG_FOR_DOC = "DebertaV2Config"
_CHECKPOINT_FOR_DOC = "kamalkraj/deberta-v2-xlarge"
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaContextPooler with Deberta->DebertaV2
class TFDebertaV2ContextPooler(keras.layers.Layer):
    """
    Pooling head that summarises a sequence through its first token.

    The hidden state at sequence index 0 goes through dropout (`config.pooler_dropout`), a dense projection of
    width `config.pooler_hidden_size` and the activation named by `config.pooler_hidden_act`.
    """

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(config.pooler_hidden_size, name="dense")
        self.dropout = TFDebertaV2StableDropout(config.pooler_dropout, name="dropout")
        self.config = config

    def call(self, hidden_states, training: bool = False):
        """Return the activated projection of the first token's hidden state."""
        # "Pooling" here only means selecting the hidden state of the first token.
        first_token = hidden_states[:, 0]
        first_token = self.dropout(first_token, training=training)
        projected = self.dense(first_token)
        activation_fn = get_tf_activation(self.config.pooler_hidden_act)
        return activation_fn(projected)

    @property
    def output_dim(self) -> int:
        """Size reported for the pooled output; reads `config.hidden_size`."""
        return self.config.hidden_size

    def build(self, input_shape=None):
        """Build the sub-layers once, each inside a name scope carrying its own name."""
        if self.built:
            return
        self.built = True
        projection = getattr(self, "dense", None)
        if projection is not None:
            with tf.name_scope(projection.name):
                projection.build([None, None, self.config.pooler_hidden_size])
        dropout_layer = getattr(self, "dropout", None)
        if dropout_layer is not None:
            with tf.name_scope(dropout_layer.name):
                dropout_layer.build(None)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaXSoftmax with Deberta->DebertaV2
class TFDebertaV2XSoftmax(keras.layers.Layer):
    """
    Masked softmax.

    Args:
        axis (int, defaults to -1): The axis along which the softmax is computed.

    Call args:
        inputs (`tf.Tensor`): The scores to normalise.
        mask (`tf.Tensor`): Cast to bool; positions where it is falsy (0) are set to `-inf` before the softmax and
            forced to exactly `0.0` in the result.
    """

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, inputs: tf.Tensor, mask: tf.Tensor):
        # True where the position must be ignored.
        ignored = tf.logical_not(tf.cast(mask, tf.bool))
        scores = tf.where(ignored, float("-inf"), inputs)
        probabilities = stable_softmax(scores, self.axis)
        # Overwrite ignored positions explicitly instead of relying on the softmax of -inf.
        return tf.where(ignored, 0.0, probabilities)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaStableDropout with Deberta->DebertaV2
class TFDebertaV2StableDropout(keras.layers.Layer):
    """
    Dropout layer with a hand-written gradient.

    Args:
        drop_prob (float): probability of zeroing an element
    """

    def __init__(self, drop_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    @tf.custom_gradient
    def xdropout(self, inputs):
        """
        Zero a random subset of `inputs` and scale the result by `1 / (1 - drop_prob)`.

        The same mask and scale are applied to the upstream gradient. When `drop_prob` is not positive, both the
        forward value and the gradient pass through unchanged (the mask is still sampled).
        """
        keep_draw = tf.compat.v1.distributions.Bernoulli(probs=1.0 - self.drop_prob).sample(
            sample_shape=shape_list(inputs)
        )
        # A Bernoulli draw of 1 means "keep", so the drop mask is its complement.
        drop_mask = tf.cast(1 - keep_draw, tf.bool)
        rescale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)

        def grad(upstream):
            if not self.drop_prob > 0:
                return upstream
            return tf.where(drop_mask, 0.0, upstream) * rescale

        if self.drop_prob > 0:
            inputs = tf.where(drop_mask, 0.0, inputs) * rescale
        return inputs, grad

    def call(self, inputs: tf.Tensor, training: tf.Tensor = False):
        """Apply `xdropout` when `training` is truthy; otherwise return `inputs` untouched."""
        return self.xdropout(inputs) if training else inputs
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaSelfOutput with Deberta->DebertaV2
class TFDebertaV2SelfOutput(keras.layers.Layer):
    """Attention output block: dense projection, dropout, residual add, then LayerNorm."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(config.hidden_size, name="dense")
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
        self.config = config

    def call(self, hidden_states, input_tensor, training: bool = False):
        """Project `hidden_states`, apply dropout and normalise the sum with `input_tensor`."""
        projected = self.dense(hidden_states)
        projected = self.dropout(projected, training=training)
        return self.LayerNorm(projected + input_tensor)

    def build(self, input_shape=None):
        """Build the sub-layers once, each inside a name scope carrying its own name."""
        if self.built:
            return
        self.built = True
        projection = getattr(self, "dense", None)
        if projection is not None:
            with tf.name_scope(projection.name):
                projection.build([None, None, self.config.hidden_size])
        norm = getattr(self, "LayerNorm", None)
        if norm is not None:
            with tf.name_scope(norm.name):
                norm.build([None, None, self.config.hidden_size])
        dropout_layer = getattr(self, "dropout", None)
        if dropout_layer is not None:
            with tf.name_scope(dropout_layer.name):
                dropout_layer.build(None)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaAttention with Deberta->DebertaV2
class TFDebertaV2Attention(keras.layers.Layer):
    """Disentangled self-attention followed by its output block (projection, residual, LayerNorm)."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)
        self.self = TFDebertaV2DisentangledSelfAttention(config, name="self")
        self.dense_output = TFDebertaV2SelfOutput(config, name="output")
        self.config = config

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        query_states: tf.Tensor = None,
        relative_pos: tf.Tensor = None,
        rel_embeddings: tf.Tensor = None,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """
        Run self-attention and the output block.

        Returns a tuple whose first element is the attention output; any further elements returned by the
        self-attention layer are appended unchanged.
        """
        attention_results = self.self(
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            query_states=query_states,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
            output_attentions=output_attentions,
            training=training,
        )
        # The residual branch uses the query states when given, otherwise the layer input.
        residual = input_tensor if query_states is None else query_states
        attention_output = self.dense_output(
            hidden_states=attention_results[0], input_tensor=residual, training=training
        )
        return (attention_output,) + attention_results[1:]

    def build(self, input_shape=None):
        """Build the sub-layers once, each inside a name scope carrying its own name."""
        if self.built:
            return
        self.built = True
        for attr_name in ("self", "dense_output"):
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaIntermediate with Deberta->DebertaV2
class TFDebertaV2Intermediate(keras.layers.Layer):
    """Feed-forward expansion: dense projection to `config.intermediate_size` followed by the hidden activation."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # `hidden_act` may be the name of an activation or the activation itself.
        hidden_act = config.hidden_act
        self.intermediate_act_fn = get_tf_activation(hidden_act) if isinstance(hidden_act, str) else hidden_act
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        """Return the activated dense projection of `hidden_states`."""
        projected = self.dense(inputs=hidden_states)
        return self.intermediate_act_fn(projected)

    def build(self, input_shape=None):
        """Build the dense sub-layer once, inside a name scope carrying its own name."""
        if self.built:
            return
        self.built = True
        projection = getattr(self, "dense", None)
        if projection is None:
            return
        with tf.name_scope(projection.name):
            projection.build([None, None, self.config.hidden_size])
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaOutput with Deberta->DebertaV2
class TFDebertaV2Output(keras.layers.Layer):
    """Feed-forward output block: dense projection to `config.hidden_size`, dropout, residual add, LayerNorm."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Project `hidden_states`, apply dropout and normalise the sum with `input_tensor`."""
        projected = self.dense(inputs=hidden_states)
        projected = self.dropout(projected, training=training)
        return self.LayerNorm(projected + input_tensor)

    def build(self, input_shape=None):
        """Build the sub-layers once, each inside a name scope carrying its own name."""
        if self.built:
            return
        self.built = True
        projection = getattr(self, "dense", None)
        if projection is not None:
            with tf.name_scope(projection.name):
                # The dense layer consumes the intermediate (expanded) representation.
                projection.build([None, None, self.config.intermediate_size])
        norm = getattr(self, "LayerNorm", None)
        if norm is not None:
            with tf.name_scope(norm.name):
                norm.build([None, None, self.config.hidden_size])
        dropout_layer = getattr(self, "dropout", None)
        if dropout_layer is not None:
            with tf.name_scope(dropout_layer.name):
                dropout_layer.build(None)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaLayer with Deberta->DebertaV2
class TFDebertaV2Layer(keras.layers.Layer):
    """One encoder layer: attention block, intermediate feed-forward block and output block."""

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)
        self.attention = TFDebertaV2Attention(config, name="attention")
        self.intermediate = TFDebertaV2Intermediate(config, name="intermediate")
        self.bert_output = TFDebertaV2Output(config, name="output")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        query_states: tf.Tensor = None,
        relative_pos: tf.Tensor = None,
        rel_embeddings: tf.Tensor = None,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """
        Run the layer.

        Returns a tuple whose first element is the layer output; any extra elements returned by the attention
        block (e.g. attention probabilities) follow it unchanged.
        """
        attention_results = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            query_states=query_states,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
            output_attentions=output_attentions,
            training=training,
        )
        attended = attention_results[0]
        expanded = self.intermediate(hidden_states=attended)
        # The attention output doubles as the residual input of the output block.
        layer_output = self.bert_output(hidden_states=expanded, input_tensor=attended, training=training)
        return (layer_output,) + attention_results[1:]

    def build(self, input_shape=None):
        """Build the sub-layers once, each inside a name scope carrying its own name."""
        if self.built:
            return
        self.built = True
        for attr_name in ("attention", "intermediate", "bert_output"):
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
class TFDebertaV2ConvLayer(keras.layers.Layer):
    """
    Convolution over the sequence axis, followed by dropout, an activation, a residual add and LayerNorm.

    The kernel width is `config.conv_kernel_size` (default 3) and the activation is named by `config.conv_act`
    (default "tanh"). The convolution maps `config.hidden_size` channels to `config.hidden_size` channels.
    """

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = getattr(config, "conv_kernel_size", 3)
        # NOTE: `config.conv_groups` is not read here; the convolution below is always ungrouped.
        self.conv_act = get_tf_activation(getattr(config, "conv_act", "tanh"))
        # Symmetric padding; with stride 1 an odd kernel width keeps the sequence length unchanged.
        self.padding = (self.kernel_size - 1) // 2
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
        self.config = config

    def build(self, input_shape=None):
        """Create the convolution weights under the "conv" name scope and build the sub-layers once."""
        if self.built:
            return
        self.built = True
        with tf.name_scope("conv"):
            self.conv_kernel = self.add_weight(
                name="kernel",
                shape=[self.kernel_size, self.config.hidden_size, self.config.hidden_size],
                initializer=get_initializer(self.config.initializer_range),
            )
            self.conv_bias = self.add_weight(
                name="bias", shape=[self.config.hidden_size], initializer=tf.zeros_initializer()
            )
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)

    def call(
        self,
        hidden_states: tf.Tensor,
        residual_states: tf.Tensor,
        input_mask: Optional[tf.Tensor],
        training: bool = False,
    ) -> tf.Tensor:
        """
        Convolve `hidden_states`, add `residual_states`, normalise and re-apply the mask.

        Args:
            hidden_states (`tf.Tensor`): Input of the convolution.
            residual_states (`tf.Tensor`): Added to the activated convolution output before LayerNorm.
            input_mask (`tf.Tensor`, optional): Positions where the mask is 0 are zeroed after the convolution and
                again after LayerNorm. When `None`, no masking is applied at either point.
            training (`bool`): Forwarded to the dropout layer.

        Returns:
            `tf.Tensor`: The normalised (and, if a mask was given, masked) states.
        """
        # The 1-D convolution is expressed as a 2-D one with a dummy height axis of size 1 on both the
        # input and the kernel; explicit padding is applied on the width (sequence) axis only.
        out = tf.nn.conv2d(
            tf.expand_dims(hidden_states, 1),
            tf.expand_dims(self.conv_kernel, 0),
            strides=1,
            padding=[[0, 0], [0, 0], [self.padding, self.padding], [0, 0]],
        )
        out = tf.squeeze(tf.nn.bias_add(out, self.conv_bias), 1)
        # Only mask when a mask was supplied: previously `1 - input_mask` ran unconditionally, so a
        # `None` mask raised before the `input_mask is None` branch below could ever be reached.
        if input_mask is not None:
            rmask = tf.cast(1 - input_mask, tf.bool)
            out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out)
        out = self.dropout(out, training=training)
        out = self.conv_act(out)

        layer_norm_input = residual_states + out
        output = self.LayerNorm(layer_norm_input)

        if input_mask is None:
            return output

        if len(shape_list(input_mask)) != len(shape_list(layer_norm_input)):
            if len(shape_list(input_mask)) == 4:
                input_mask = tf.squeeze(tf.squeeze(input_mask, axis=1), axis=1)
            # Add a trailing axis so the mask broadcasts over the hidden dimension.
            input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), tf.float32)
        return output * input_mask
class TFDebertaV2Encoder(keras.layers.Layer):
    """
    Stack of `config.num_hidden_layers` `TFDebertaV2Layer` blocks.

    Optional parts, all driven by `config`:
      - relative attention (`config.relative_attention`): owns the relative-position embedding table and
        builds relative position ids when the caller does not supply them;
      - LayerNorm over the relative embeddings when `config.norm_rel_ebd` contains "layer_norm";
      - a convolution layer applied after the first block when `config.conv_kernel_size > 0`.
    """

    def __init__(self, config: DebertaV2Config, **kwargs):
        super().__init__(**kwargs)

        self.layer = [TFDebertaV2Layer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
        self.relative_attention = getattr(config, "relative_attention", False)
        self.config = config

        if self.relative_attention:
            self.max_relative_positions = getattr(config, "max_relative_positions", -1)
            if self.max_relative_positions < 1:
                # Fall back to the maximum absolute position count.
                self.max_relative_positions = config.max_position_embeddings

            self.position_buckets = getattr(config, "position_buckets", -1)
            # Table size is twice the span; bucketing, when enabled, replaces the span.
            self.pos_ebd_size = self.max_relative_positions * 2
            if self.position_buckets > 0:
                self.pos_ebd_size = self.position_buckets * 2

        # `norm_rel_ebd` is a "|"-separated list of option names.
        self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]

        if "layer_norm" in self.norm_rel_ebd:
            self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")

        self.conv = TFDebertaV2ConvLayer(config, name="conv") if getattr(config, "conv_kernel_size", 0) > 0 else None

    def build(self, input_shape=None):
        """Create the relative embedding table (if used) and build every sub-layer once."""
        if self.built:
            return
        self.built = True
        if self.relative_attention:
            self.rel_embeddings = self.add_weight(
                name="rel_embeddings.weight",
                shape=[self.pos_ebd_size, self.config.hidden_size],
                initializer=get_initializer(self.config.initializer_range),
            )
        if getattr(self, "conv", None) is not None:
            with tf.name_scope(self.conv.name):
                self.conv.build(None)
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, self.config.hidden_size])
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)

    def get_rel_embedding(self):
        """Return the relative embedding table (layer-normalised if configured), or `None` without relative attention."""
        rel_embeddings = self.rel_embeddings if self.relative_attention else None
        if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
            rel_embeddings = self.LayerNorm(rel_embeddings)
        return rel_embeddings

    def get_attention_mask(self, attention_mask):
        """
        Expand `attention_mask` to rank 4.

        A mask of rank <= 2 is turned into a pairwise mask by multiplying its expanded form with its own
        transpose (for a `[B, S]` mask this yields `[B, 1, S, S]`) and cast to `uint8`. A rank-3 mask only gets
        an extra axis inserted at position 1. Any other rank is returned unchanged.
        """
        if len(shape_list(attention_mask)) <= 2:
            extended_attention_mask = tf.expand_dims(tf.expand_dims(attention_mask, 1), 2)
            attention_mask = extended_attention_mask * tf.expand_dims(tf.squeeze(extended_attention_mask, -2), -1)
            attention_mask = tf.cast(attention_mask, tf.uint8)
        elif len(shape_list(attention_mask)) == 3:
            attention_mask = tf.expand_dims(attention_mask, 1)

        return attention_mask

    def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
        """Return `relative_pos`, building it from the sequence lengths when relative attention is on and none was given."""
        if self.relative_attention and relative_pos is None:
            # Query length comes from `query_states` when present, otherwise from `hidden_states`.
            q = shape_list(query_states)[-2] if query_states is not None else shape_list(hidden_states)[-2]
            relative_pos = build_relative_position(
                q,
                shape_list(hidden_states)[-2],
                bucket_size=self.position_buckets,
                max_position=self.max_relative_positions,
            )
        return relative_pos

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        query_states: tf.Tensor = None,
        relative_pos: tf.Tensor = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        """
        Run all layers in order.

        Returns:
            A `TFBaseModelOutput` when `return_dict` is true; otherwise a tuple of the non-`None` values among
            (last hidden state, all hidden states, all attentions).
        """
        # Per-token mask for the conv layer: the mask itself when rank <= 2, otherwise "has any
        # non-zero entry along axis -2".
        if len(shape_list(attention_mask)) <= 2:
            input_mask = attention_mask
        else:
            input_mask = tf.cast(tf.math.reduce_sum(attention_mask, axis=-2) > 0, dtype=tf.uint8)

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        attention_mask = self.get_attention_mask(attention_mask)
        relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

        next_kv = hidden_states

        rel_embeddings = self.get_rel_embedding()
        output_states = next_kv
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (output_states,)

            layer_outputs = layer_module(
                hidden_states=next_kv,
                attention_mask=attention_mask,
                query_states=query_states,
                relative_pos=relative_pos,
                rel_embeddings=rel_embeddings,
                output_attentions=output_attentions,
                training=training,
            )
            output_states = layer_outputs[0]

            if i == 0 and self.conv is not None:
                # The conv layer sees the original encoder input and the first layer's output.
                # `training` is forwarded explicitly, like for every per-layer call above, instead
                # of depending on implicit propagation by the framework.
                output_states = self.conv(hidden_states, output_states, input_mask, training=training)

            next_kv = output_states

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (output_states,)

        if not return_dict:
            return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutput(
            last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
def make_log_bucket_position(relative_pos, bucket_size, max_position):
    """
    Map relative positions to bucket ids.

    With `mid = bucket_size // 2`, positions whose magnitude is at most `mid` keep their own value; larger
    magnitudes are mapped through a ceil-ed logarithmic scale (offset by `mid`) and given back their sign.
    The result is cast to `tf.int32`.
    """
    direction = tf.math.sign(relative_pos)
    half_bucket = bucket_size // 2
    # Inside the (-mid, mid) window the magnitude is pinned to mid - 1 so that the `<= mid` test
    # below selects the untouched relative position for those entries.
    in_linear_window = (relative_pos < half_bucket) & (relative_pos > -half_bucket)
    magnitude = tf.where(in_linear_window, half_bucket - 1, tf.math.abs(relative_pos))
    log_numerator = tf.cast(tf.math.log(magnitude / half_bucket), tf.float32)
    log_denominator = tf.math.log((max_position - 1) / half_bucket)
    log_bucket = tf.math.ceil(log_numerator / log_denominator * (half_bucket - 1)) + half_bucket
    signed_bucket = tf.where(
        magnitude <= half_bucket,
        tf.cast(relative_pos, tf.float32),
        log_bucket * tf.cast(direction, tf.float32),
    )
    return tf.cast(signed_bucket, tf.int32)
def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
    """
    Build the matrix of relative positions between every query and every key.

    Query positions run over `range(query_size)` and key positions over `range(key_size)`; entry `[q, k]` is
    `P_q - P_k`. When both `bucket_size` and `max_position` are positive, the values are additionally mapped to
    log buckets with `make_log_bucket_position`.

    Args:
        query_size (int): the length of query
        key_size (int): the length of key
        bucket_size (int): the size of position bucket
        max_position (int): the maximum allowed absolute position

    Return:
        `tf.Tensor`: An int64 tensor with shape [1, query_size, key_size]
    """
    query_positions = tf.range(query_size, dtype=tf.int32)
    key_positions = tf.range(key_size, dtype=tf.int32)
    # Repeat the key positions once per query row, then subtract from the query column vector.
    key_grid = tf.tile(tf.expand_dims(key_positions, axis=0), [shape_list(query_positions)[0], 1])
    relative = query_positions[:, None] - key_grid
    use_buckets = bucket_size > 0 and max_position > 0
    if use_buckets:
        relative = make_log_bucket_position(relative, bucket_size, max_position)
    relative = relative[:query_size, :]
    batched = tf.expand_dims(relative, axis=0)
    return tf.cast(batched, tf.int64)
def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
    """Broadcast `c2p_pos` to the first three dims of `query_layer` plus the last dim of `relative_pos`."""
    query_dims = shape_list(query_layer)
    last_rel_dim = shape_list(relative_pos)[-1]
    target_shape = [query_dims[0], query_dims[1], query_dims[2], last_rel_dim]
    return tf.broadcast_to(c2p_pos, target_shape)
def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
    """Broadcast `c2p_pos` to the first two dims of `query_layer` followed by the key length twice."""
    query_dims = shape_list(query_layer)
    key_length = shape_list(key_layer)[-2]
    target_shape = [query_dims[0], query_dims[1], key_length, key_length]
    return tf.broadcast_to(c2p_pos, target_shape)
def pos_dynamic_expand(pos_index, p2c_att, key_layer):
    """Broadcast `pos_index` to the first two dims of `p2c_att`, its own second-to-last dim and the key length."""
    leading_dims = shape_list(p2c_att)[:2]
    trailing_dims = [shape_list(pos_index)[-2], shape_list(key_layer)[-2]]
    return tf.broadcast_to(pos_index, leading_dims + trailing_dims)
def take_along_axis(x, indices):
    """
    Gather entries of `x` along its last axis using `indices`.

    Only a valid port of `np.take_along_axis` when the gather axis is -1. Under a `TPUStrategy` the gather is
    rewritten as a one-hot contraction; otherwise `tf.gather` with two batch dimensions is used.
    """
    running_on_tpu = isinstance(tf.distribute.get_strategy(), tf.distribute.TPUStrategy)
    if not running_on_tpu:
        # GPUs prefer gathers instead of large one-hot + matmuls.
        return tf.gather(x, indices, batch_dims=2)

    # TPU + gathers and reshapes don't go along well -- see
    # https://github.com/huggingface/transformers/issues/18239
    # [B, S, P] -> [B, S, P, D]
    selector = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)
    # Ignoring the first two dims, this multiplies a (one-hot) matrix by a vector (x):
    # [B, S, P, D] . [B, S, D] = [B, S, P]
    return tf.einsum("ijkl,ijl->ijk", selector, x)
class TFDebertaV2DisentangledSelfAttention(keras.layers.Layer):
"""
Disentangled self-attention module
Parameters:
config (`DebertaV2Config`):
A model config class instance with the configuration to build a new model. The schema is similar to
*BertConfig*, for more details, please refer [`DebertaV2Config`]
"""
def __init__(self, config: DebertaV2Config, **kwargs):
super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
_attention_head_size = config.hidden_size // config.num_attention_heads
self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query_proj = keras.layers.Dense(
self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range),
name="query_proj",
use_bias=True,
)
self.key_proj = keras.layers.Dense(
self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range),
name="key_proj",
use_bias=True,
)
self.value_proj = keras.layers.Dense(
self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range),
name="value_proj",
use_bias=True,
)
self.share_att_key = getattr(config, "share_att_key", False)
self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
self.relative_attention = getattr(config, "relative_attention", False)
if self.relative_attention:
self.position_buckets = getattr(config, "position_buckets", -1)
self.max_relative_positions = getattr(config, "max_relative_positions", -1)
if self.max_relative_positions < 1:
self.max_relative_positions = config.max_position_embeddings
self.pos_ebd_size = self.max_relative_positions
if self.position_buckets > 0:
self.pos_ebd_size = self.position_buckets
self.pos_dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="pos_dropout")
if not self.share_att_key:
if "c2p" in self.pos_att_type:
self.pos_key_proj = keras.layers.Dense(
self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range),
name="pos_proj",
use_bias=True,
)
if "p2c" in self.pos_att_type:
self.pos_query_proj = keras.layers.Dense(
self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range),
name="pos_q_proj",
)
self.softmax = TFDebertaV2XSoftmax(axis=-1)
self.dropout = TFDebertaV2StableDropout(config.attention_probs_dropout_prob, name="dropout")
self.config = config
def transpose_for_scores(self, tensor: tf.Tensor, attention_heads: int) -> tf.Tensor:
tensor_shape = shape_list(tensor)
# In graph mode mode, we can't reshape with -1 as the final dimension if the first dimension (batch size) is None
shape = tensor_shape[:-1] + [attention_heads, tensor_shape[-1] // attention_heads]
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=shape)
tensor = tf.transpose(tensor, perm=[0, 2, 1, 3])
x_shape = shape_list(tensor)
tensor = tf.reshape(tensor, shape=[-1, x_shape[-2], x_shape[-1]])
return tensor
def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
query_states: tf.Tensor = None,
relative_pos: tf.Tensor = None,
rel_embeddings: tf.Tensor = None,
output_attentions: bool = False,
training: bool = False,
) -> Tuple[tf.Tensor]:
"""
Call the module
Args:
hidden_states (`tf.Tensor`):
Input states to the module usually the output from previous layer, it will be the Q,K and V in
*Attention(Q,K,V)*
attention_mask (`tf.Tensor`):
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
th token.
return_att (`bool`, optional):
Whether return the attention matrix.
query_states (`tf.Tensor`, optional):
The *Q* state in *Attention(Q,K,V)*.
relative_pos (`tf.Tensor`):
The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
values ranging in [*-max_relative_positions*, *max_relative_positions*].
rel_embeddings (`tf.Tensor`):
The embedding of relative distances. It's a tensor of shape [\\(2 \\times
\\text{max_relative_positions}\\), *hidden_size*].
"""
if query_states is None:
query_states = hidden_states
query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
rel_att = None
# Take the dot product between "query" and "key" to get the raw attention scores.
scale_factor = 1
if "c2p" in self.pos_att_type:
scale_factor += 1
if "p2c" in self.pos_att_type:
scale_factor += 1
scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, tf.float32))
attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 2, 1]) / scale)
if self.relative_attention:
rel_embeddings = self.pos_dropout(rel_embeddings)
rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
if rel_att is not None:
attention_scores = attention_scores + rel_att
attention_scores = tf.reshape(
attention_scores,
(-1, self.num_attention_heads, shape_list(attention_scores)[-2], shape_list(attention_scores)[-1]),
)
# bsz x height x length x dimension
attention_probs = self.softmax(attention_scores, attention_mask)
attention_probs = self.dropout(attention_probs, training=training)
context_layer = tf.matmul(
tf.reshape(attention_probs, [-1, shape_list(attention_probs)[-2], shape_list(attention_probs)[-1]]),
value_layer,
)
context_layer = tf.transpose(
tf.reshape(
context_layer,
[-1, self.num_attention_heads, shape_list(context_layer)[-2], shape_list(context_layer)[-1]],
),
[0, 2, 1, 3],
)
# Set the final dimension here explicitly.
# Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing
# the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput
# requires final input dimension to be defined
context_layer_shape = shape_list(context_layer)
new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
context_layer = tf.reshape(context_layer, new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
    """
    Compute the relative-position bias that is added to the content attention scores.

    Up to two terms are summed, selected by `self.pos_att_type`: a content->position term ("c2p") and a
    position->content term ("p2c").

    Args:
        query_layer: projected queries. Assumed to be (batch * heads, query_len, head_dim), since the position
            projections are tiled by `shape[0] // num_attention_heads` -- TODO confirm against
            `transpose_for_scores`.
        key_layer: projected keys, presumably laid out like `query_layer`.
        relative_pos: relative position ids of rank 2, 3 or 4, or `None` to build them from the query and key
            lengths with `build_relative_position`.
        rel_embeddings: relative position embedding table; its rows `[0, 2 * self.pos_ebd_size)` are used.
        scale_factor: multiplier applied to the head dimension inside the square-root scaling of each term.

    Returns:
        The sum of the enabled bias terms, or the integer `0` when neither "c2p" nor "p2c" is enabled.

    Raises:
        ValueError: if `relative_pos` has a rank other than 2, 3 or 4.
    """
    num_heads = self.num_attention_heads
    use_c2p = "c2p" in self.pos_att_type
    use_p2c = "p2c" in self.pos_att_type

    if relative_pos is None:
        relative_pos = build_relative_position(
            shape_list(query_layer)[-2],
            shape_list(key_layer)[-2],
            bucket_size=self.position_buckets,
            max_position=self.max_relative_positions,
        )

    # Bring the position ids to rank 4: bsz x height x query x key
    pos_rank = len(shape_list(relative_pos))
    if pos_rank not in (2, 3, 4):
        raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {pos_rank}")
    if pos_rank == 2:
        relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0)
    elif pos_rank == 3:
        relative_pos = tf.expand_dims(relative_pos, 1)

    att_span = self.pos_ebd_size
    span_start = self.pos_ebd_size - att_span
    span_stop = self.pos_ebd_size + att_span
    rel_embeddings = tf.expand_dims(rel_embeddings[span_start:span_stop, :], 0)

    def project_positions(projection):
        # Project the position table, split it per head, and repeat it so that its leading
        # dimension matches the leading dimension of `query_layer`.
        per_head = self.transpose_for_scores(projection(rel_embeddings), num_heads)
        return tf.tile(per_head, [shape_list(query_layer)[0] // num_heads, 1, 1])

    if self.share_att_key:
        # The content projections are reused for the positions.
        pos_query_layer = project_positions(self.query_proj)
        pos_key_layer = project_positions(self.key_proj)
    else:
        # Dedicated position projections, each created only for the term that needs it.
        if use_c2p:
            pos_key_layer = project_positions(self.pos_key_proj)
        if use_p2c:
            pos_query_layer = project_positions(self.pos_query_proj)

    # Largest valid index into the sliced position table.
    last_index = att_span * 2 - 1
    score = 0

    # content->position
    if use_c2p:
        scale = tf.math.sqrt(tf.cast(shape_list(pos_key_layer)[-1] * scale_factor, tf.float32))
        c2p_att = tf.matmul(query_layer, tf.transpose(pos_key_layer, [0, 2, 1]))
        c2p_pos = tf.clip_by_value(relative_pos + att_span, 0, last_index)
        c2p_index = tf.broadcast_to(
            tf.squeeze(c2p_pos, 0),
            [shape_list(query_layer)[0], shape_list(query_layer)[1], shape_list(relative_pos)[-1]],
        )
        c2p_att = take_along_axis(c2p_att, c2p_index)
        score += c2p_att / scale

    # position->content
    if use_p2c:
        scale = tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, tf.float32))
        key_len = shape_list(key_layer)[-2]
        if key_len != shape_list(query_layer)[-2]:
            # Query and key lengths differ: build key-by-key positions for this term.
            r_pos = build_relative_position(
                key_len,
                key_len,
                bucket_size=self.position_buckets,
                max_position=self.max_relative_positions,
            )
            r_pos = tf.expand_dims(r_pos, 0)
        else:
            r_pos = relative_pos
        p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, last_index)
        p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 2, 1]))
        p2c_index = tf.broadcast_to(
            tf.squeeze(p2c_pos, 0),
            [shape_list(query_layer)[0], key_len, key_len],
        )
        # The gathered scores are swapped on their last two axes before being added.
        p2c_att = tf.transpose(take_along_axis(p2c_att, p2c_index), [0, 2, 1])
        score += p2c_att / scale

    return score
def build(self, input_shape=None):
    """
    Build every sub-layer that exists on this attention layer, each under its own name scope.

    Does nothing when the layer is already marked as built.

    Args:
        input_shape: Unused; the shapes passed to the sub-layers are derived from `self.config.hidden_size`.
    """
    if self.built:
        return
    self.built = True

    # (attribute name, whether the sub-layer is built for a hidden-size input), in build order.
    # Sub-layers flagged False are built with `None` as their input shape.
    build_plan = (
        ("query_proj", True),
        ("key_proj", True),
        ("value_proj", True),
        ("dropout", False),
        ("pos_dropout", False),
        ("pos_key_proj", True),
        ("pos_query_proj", True),
    )
    for attr_name, takes_hidden_input in build_plan:
        sublayer = getattr(self, attr_name, None)
        if sublayer is None:
            # Attribute is missing or unset; skip it.
            continue
        with tf.name_scope(sublayer.name):
            sublayer.build([None, None, self.config.hidden_size] if takes_hidden_input else None)
# Copied from transformers.models.deberta.modeling_tf_deberta.TFDebertaEmbeddings Deberta->DebertaV2
class TFDebertaV2Embeddings(keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config, **kwargs):
    """
    Record the embedding hyper-parameters from `config` and create the sub-layers.

    Args:
        config: model configuration. Read for `hidden_size`, `max_position_embeddings`, `initializer_range`,
            `layer_norm_eps`, `hidden_dropout_prob` and the optional `embedding_size` / `position_biased_input`.
        **kwargs: forwarded unchanged to the parent layer constructor.
    """
    super().__init__(**kwargs)

    hidden_size = config.hidden_size
    init_range = config.initializer_range

    self.config = config
    # The embedding width falls back to the hidden size when the config does not define one.
    self.embedding_size = getattr(config, "embedding_size", hidden_size)
    self.hidden_size = hidden_size
    self.max_position_embeddings = config.max_position_embeddings
    # Defaults to True when the config does not set `position_biased_input`.
    self.position_biased_input = getattr(config, "position_biased_input", True)
    self.initializer_range = init_range

    # A bias-free projection to the hidden size is only created when the two widths differ.
    needs_projection = self.embedding_size != hidden_size
    if needs_projection:
        self.embed_proj = keras.layers.Dense(
            hidden_size,
            kernel_initializer=get_initializer(init_range),
            name="embed_proj",
            use_bias=False,
        )

    self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
    self.dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="dropout")
def build(self, input_shape=None):
    """
    Create the embedding variables and build the sub-layers, at most once per layer instance.

    Args:
        input_shape: Unused; accepted so the signature matches the Keras `Layer.build` hook.
    """
    # Check the guard before creating anything. Previously the check came after the `add_weight` calls,
    # so every repeated `build()` re-created the embedding variables and rebound the attributes to them.
    # NOTE(review): this class sits under a `# Copied from ...TFDebertaEmbeddings` marker; the source
    # class presumably needs the same guard placement to keep the two in sync -- confirm.
    if self.built:
        return

    with tf.name_scope("word_embeddings"):
        self.weight = self.add_weight(
            name="weight",
            shape=[self.config.vocab_size, self.embedding_size],
            initializer=get_initializer(self.initializer_range),
        )

    with tf.name_scope("token_type_embeddings"):
        # No token-type table is created when the config declares no token types.
        if self.config.type_vocab_size > 0:
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )
        else:
            self.token_type_embeddings = None

    with tf.name_scope("position_embeddings"):
        # NOTE(review): this table is `hidden_size` wide while the two tables above are `embedding_size`
        # wide -- confirm this is intended for configs where the two sizes differ.
        if self.position_biased_input:
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )
        else:
            self.position_embeddings = None

    # Mark the layer as built only after its own variables exist, as the original statement order did.
    self.built = True

    if getattr(self, "LayerNorm", None) is not None:
        with tf.name_scope(self.LayerNorm.name):
            self.LayerNorm.build([None, None, self.config.hidden_size])
    if getattr(self, "dropout", None) is not None:
        with tf.name_scope(self.dropout.name):
            self.dropout.build(None)
    # `embed_proj` is absent when `__init__` did not create it, hence the `getattr` lookup.
    if getattr(self, "embed_proj", None) is not None:
        with tf.name_scope(self.embed_proj.name):
            self.embed_proj.build([None, None, self.embedding_size])
def call(
self,
input_ids: tf.Tensor = None,
position_ids: tf.Tensor = None,
token_type_ids: tf.Tensor = None,
inputs_embeds: tf.Tensor = None,
mask: tf.Tensor = None,
training: bool = False,
) -> tf.Tensor:
"""
Applies embedding based on inputs tensor.
Returns:
final_embeddings (`tf.Tensor`): output embedding tensor.