-
Notifications
You must be signed in to change notification settings - Fork 26.4k
/
modeling_tf_t5.py
1691 lines (1448 loc) · 74.8 KB
/
modeling_tf_t5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2020 T5 Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 T5 model."""
import copy
import itertools
import math
import warnings
from typing import Tuple
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import (
DUMMY_INPUTS,
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPast,
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
)
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
input_processing,
keras_serializable,
shape_list,
)
from ...utils import logging
from .configuration_t5 import T5Config
# Module-level logger.
logger = logging.get_logger(__name__)

# Class names referenced by the documentation helpers (presumably consumed by the
# docstring decorators further down the file — verify).
_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"

# Identifiers of the pretrained T5 checkpoints listed for this TF implementation.
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]
####################################################
# TF 2.0 Models are constructed using Keras imperative API by sub-classing
# - tf.keras.layers.Layer for the layers and
# - TFPreTrainedModel for the models (itself a subclass of tf.keras.Model)
####################################################
class TFT5LayerNorm(tf.keras.layers.Layer):
    """
    T5-style layer normalisation (root-mean-square normalisation).

    Unlike standard layer normalisation there is no bias and the mean is not subtracted: the input is divided by
    the root mean square of its last axis and multiplied by a learned per-feature scale ``weight`` (initialised to
    ones).
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        """
        Construct a layernorm module in the T5 style: no bias and no subtraction of mean.

        Args:
            epsilon (float): added to the mean square before the reciprocal square root, for numerical stability.
        """
        super().__init__(**kwargs)
        self.variance_epsilon = epsilon

    def build(self, input_shape):
        """Create the per-feature scale ``weight`` of shape ``(input_shape[-1],)``, initialised to ones."""
        self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
        super().build(input_shape)

    def call(self, hidden_states):
        """Return ``weight * hidden_states / sqrt(mean(hidden_states ** 2, axis=-1) + epsilon)``."""
        # Mean of squares over the feature axis; no mean subtraction (unlike standard LayerNorm).
        variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states
class TFT5DenseReluDense(tf.keras.layers.Layer):
    """
    Feed-forward network ``wo(dropout(relu(wi(x))))`` built from bias-free projections.

    ``wi`` expands to ``config.d_ff`` units and ``wo`` projects back to ``config.d_model`` units. Both kernels are
    drawn from zero-mean normal distributions whose standard deviation is ``config.initializer_factor`` times
    ``d_model ** -0.5`` (for ``wi``) or ``d_ff ** -0.5`` (for ``wo``).
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        factor = config.initializer_factor
        expand_init = tf.keras.initializers.RandomNormal(mean=0, stddev=factor * (config.d_model ** -0.5))
        contract_init = tf.keras.initializers.RandomNormal(mean=0, stddev=factor * (config.d_ff ** -0.5))

        # Update init weights as in flax
        self.wi = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi", kernel_initializer=expand_init)
        self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo", kernel_initializer=contract_init)
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
        self.act = tf.keras.activations.relu

    def call(self, hidden_states, training=False):
        """Apply ``wi``, ReLU, dropout (active only when ``training``) and ``wo`` in sequence."""
        expanded = self.act(self.wi(hidden_states))
        expanded = self.dropout(expanded, training=training)
        return self.wo(expanded)
class TFT5GatedGeluDense(tf.keras.layers.Layer):
    """
    Gated feed-forward network ``wo(dropout(gelu_new(wi_0(x)) * wi_1(x)))``.

    ``wi_0`` and ``wi_1`` both expand to ``config.d_ff`` units and share the same initializer; ``wo`` projects
    back to ``config.d_model`` units. All projections are bias-free.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        factor = config.initializer_factor
        expand_init = tf.keras.initializers.RandomNormal(mean=0, stddev=factor * (config.d_model ** -0.5))
        contract_init = tf.keras.initializers.RandomNormal(mean=0, stddev=factor * (config.d_ff ** -0.5))

        # Update init weights as in flax
        self.wi_0 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_0", kernel_initializer=expand_init)
        self.wi_1 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_1", kernel_initializer=expand_init)
        self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo", kernel_initializer=contract_init)
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
        self.act = get_tf_activation("gelu_new")

    def call(self, hidden_states, training=False):
        """Gate the linear branch ``wi_1`` with the activated branch ``wi_0``, then apply dropout and ``wo``."""
        gate = self.act(self.wi_0(hidden_states))
        gated = gate * self.wi_1(hidden_states)
        gated = self.dropout(gated, training=training)
        return self.wo(gated)
class TFT5LayerFF(tf.keras.layers.Layer):
    """
    Pre-norm residual feed-forward sub-layer of a T5 block.

    Computes ``hidden_states + dropout(ff(layer_norm(hidden_states)))`` where ``ff`` is
    :class:`TFT5DenseReluDense` or :class:`TFT5GatedGeluDense`, selected by ``config.feed_forward_proj``.

    Raises:
        ValueError: if ``config.feed_forward_proj`` is neither ``"relu"`` nor ``"gated-gelu"``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Both variants are registered under the name "DenseReluDense".
        if config.feed_forward_proj == "relu":
            self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense")
        elif config.feed_forward_proj == "gated-gelu":
            self.DenseReluDense = TFT5GatedGeluDense(config, name="DenseReluDense")
        else:
            # Read the value from ``config``: this layer never stores a ``self.config`` attribute, so
            # referencing it here raised AttributeError instead of this error.
            raise ValueError(
                f"{config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
            )

        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def call(self, hidden_states, training=False):
        """Return the input plus the (dropped-out) feed-forward output of its normalised version."""
        normed_hidden_states = self.layer_norm(hidden_states)
        dense_output = self.DenseReluDense(normed_hidden_states, training=training)
        hidden_states = hidden_states + self.dropout(dense_output, training=training)
        return hidden_states
class TFT5Attention(tf.keras.layers.Layer):
    """
    Multi-head attention with optional T5 relative position bias.

    The same layer is used for self-attention (``key_value_states is None``) and for encoder-decoder
    cross-attention (keys and values are projected from ``key_value_states``). The q/k/v/o projections are
    bias-free ``Dense`` layers, and the raw scores ``q @ k^T`` are not divided by ``sqrt(d_kv)`` before the
    softmax (see the initializer comment in ``__init__``).

    Only a layer built with ``has_relative_attention_bias=True`` owns a bias table. Other layers either receive
    ``position_bias`` from their caller or fall back to an all-zero bias.
    """

    # Class-level counter shared by all instances; each instance takes the next value as its ``layer_id``.
    NEW_ID = itertools.count()

    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.layer_id = next(TFT5Attention.NEW_ID)
        self.is_decoder = config.is_decoder
        self.use_cache = config.use_cache
        self.has_relative_attention_bias = has_relative_attention_bias
        self.output_attentions = config.output_attentions

        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        # Width of all heads concatenated: q/k/v project to this size, o projects back to d_model.
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        # (the query stddev carries an extra key_value_proj_dim ** -0.5 factor compared with k/v/o).
        q_initializer = tf.keras.initializers.RandomNormal(
            mean=0, stddev=config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        )
        k_initializer = tf.keras.initializers.RandomNormal(
            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
        )
        v_initializer = tf.keras.initializers.RandomNormal(
            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
        )
        o_initializer = tf.keras.initializers.RandomNormal(
            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
        )
        # Stored on self because the bias table is only created later, in ``build``.
        self.relative_attention_bias_initializer = tf.keras.initializers.RandomNormal(
            mean=0, stddev=config.initializer_factor * (self.inner_dim ** -0.5)
        )

        self.q = tf.keras.layers.Dense(
            self.inner_dim, use_bias=False, name="q", kernel_initializer=q_initializer
        )  # Update init weights as in flax
        self.k = tf.keras.layers.Dense(
            self.inner_dim, use_bias=False, name="k", kernel_initializer=k_initializer
        )  # Update init weights as in flax
        self.v = tf.keras.layers.Dense(
            self.inner_dim, use_bias=False, name="v", kernel_initializer=v_initializer
        )  # Update init weights as in flax
        self.o = tf.keras.layers.Dense(
            self.d_model, use_bias=False, name="o", kernel_initializer=o_initializer
        )  # Update init weights as in flax
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

        # Never populated here: ``prune_heads`` raises NotImplementedError.
        self.pruned_heads = set()

    def build(self, input_shape):
        """Create the ``(relative_attention_num_buckets, n_heads)`` bias table when this layer owns one."""
        if self.has_relative_attention_bias:
            with tf.name_scope("relative_attention_bias"):
                self.relative_attention_bias = self.add_weight(
                    name="embeddings",
                    shape=[self.relative_attention_num_buckets, self.n_heads],
                    initializer=self.relative_attention_bias_initializer,  # Add initializer
                )

        return super().build(input_shape)

    def prune_heads(self, heads):
        """Head pruning is not supported by the TF implementation."""
        raise NotImplementedError

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape and dtype as relative_position, containing values in the range
            [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets are reserved for positive offsets (attended-to token after the query);
            # those get an offset of num_buckets (already halved) and the sign is then dropped.
            num_buckets //= 2
            relative_buckets += (
                tf.cast(tf.math.greater(relative_position, 0), dtype=relative_position.dtype) * num_buckets
            )
            relative_position = tf.math.abs(relative_position)
        else:
            # Unidirectional: positive offsets are clipped to 0, negative ones become their distance.
            relative_position = -tf.math.minimum(relative_position, 0)
        # now relative_position is in the range [0, inf)
        # The first max_exact distances each get their own bucket.
        max_exact = num_buckets // 2
        is_small = tf.math.less(relative_position, max_exact)
        # Larger distances are binned logarithmically up to max_distance. Entries where relative_position is 0
        # give log(0) here, but they are "small" and are discarded by the tf.where below.
        relative_position_if_large = max_exact + tf.cast(
            tf.math.log(relative_position / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact),
            dtype=relative_position.dtype,
        )
        # Everything beyond the last logarithmic bin shares the final bucket.
        relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1)
        relative_buckets += tf.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias of shape ``(1, n_heads, query_length, key_length)``."""
        context_position = tf.range(query_length)[:, None]
        memory_position = tf.range(key_length)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
        )
        values = tf.gather(
            self.relative_attention_bias, relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = tf.expand_dims(
            tf.transpose(values, [2, 0, 1]), axis=0
        )  # shape (1, num_heads, query_length, key_length)
        return values

    def call(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        training=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Args:
            hidden_states: query-side input, ``(batch_size, query_length, dim)``.
            mask: optional additive mask. It is only used when ``position_bias`` is ``None``, in which case it is
                added to the freshly computed bias; as prepared by the main layer it is 4D and broadcasts against
                ``(batch_size, n_heads, query_length, key_length)``.
            key_value_states: states to attend over for cross-attention; ``None`` for self-attention.
            position_bias: precomputed bias (already including any mask) to add to the scores; computed here
                when ``None``.
            past_key_value: optional ``(key, value)`` cache, each ``(batch_size, n_heads, past_length, dim_per_head)``.
            layer_head_mask: optional tensor of shape ``(n_heads,)`` multiplied into the attention weights.
            query_length: when a cache is given, used instead of the cached length to extend the query length.
            use_cache: if true and this is a decoder layer, return the key/value states.
            training: enables dropout on the attention weights.
            output_attentions: if true, append the attention weights to the outputs.

        Returns:
            ``(attn_output, present_key_value_state, position_bias)`` plus ``weights`` when ``output_attentions``.
            ``present_key_value_state`` is ``(key_states, value_states)`` or ``None``.
        """
        # Input is (batch_size, query_length, dim)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = shape_list(hidden_states)[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
            # With a cache, the length used for the position bias grows by the cached length
            # (or by query_length when the caller supplies it).
            real_seq_length += shape_list(past_key_value[0])[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1]

        def shape(hidden_states):
            """projection: (batch_size, seq, inner_dim) -> (batch_size, n_heads, seq, dim_per_head)"""
            return tf.transpose(
                tf.reshape(hidden_states, (batch_size, -1, self.n_heads, self.key_value_proj_dim)), perm=(0, 2, 1, 3)
            )

        def unshape(hidden_states):
            """compute context: (batch_size, n_heads, seq, dim_per_head) -> (batch_size, seq, inner_dim)"""
            return tf.reshape(tf.transpose(hidden_states, perm=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim))

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn: append the new keys/values after the cached ones
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = tf.concat([past_key_value, hidden_states], axis=2)
                else:
                    # cross-attn: the encoder projections are reused unchanged from the cache
                    hidden_states = past_key_value
            return hidden_states

        # get query
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, query_length, dim_per_head)

        # get key/value
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # to cope with keras serialization
        if self.is_decoder and use_cache:
            present_key_value_state = (key_states, value_states)
        else:
            present_key_value_state = None

        # Raw dot-product scores; no 1/sqrt(d_kv) scaling (see the initializers in __init__).
        scores = tf.einsum(
            "bnqd,bnkd->bnqk", query_states, key_states
        )  # (batch_size, n_heads, query_length, key_length)

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length))
            else:
                position_bias = self.compute_bias(real_seq_length, key_length)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -seq_length:, :]

            # The mask is folded into the bias, so a caller that reuses the returned
            # position_bias also reuses the mask.
            if mask is not None:
                position_bias = tf.cast(position_bias, dtype=mask.dtype)
                position_bias = position_bias + mask  # (batch_size, n_heads, query_length, key_length)

        scores += position_bias
        weights = tf.nn.softmax(scores, axis=-1)  # (batch_size, n_heads, query_length, key_length)
        weights = self.dropout(weights, training=training)  # (batch_size, n_heads, query_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.n_heads],
                message=f"Head mask for a single layer should be of size {(self.n_heads)}, but is {shape_list(layer_head_mask)}",
            )
            weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights

        attn_output = tf.matmul(weights, value_states)  # (batch_size, n_heads, query_length, dim_per_head)

        attn_output = self.o(unshape(attn_output))

        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (weights,)

        return outputs
class TFT5LayerSelfAttention(tf.keras.layers.Layer):
    """
    Pre-norm residual self-attention sub-layer.

    The input is layer-normalised, passed through ``SelfAttention``, and the attention output (after dropout) is
    added back onto the un-normalised input.
    """

    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.SelfAttention = TFT5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias, name="SelfAttention"
        )
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def call(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        training=False,
    ):
        """
        Returns:
            ``(hidden_states, present_key_value_state, position_bias[, attention_weights])``; every element after
            the first is forwarded unchanged from ``SelfAttention``.
        """
        normed = self.layer_norm(hidden_states)
        attn_results = self.SelfAttention(
            normed,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            training=training,
        )
        residual = hidden_states + self.dropout(attn_results[0], training=training)
        return (residual,) + tuple(attn_results[1:])
class TFT5LayerCrossAttention(tf.keras.layers.Layer):
    """
    Pre-norm residual encoder-decoder attention sub-layer.

    The decoder states are layer-normalised and attend over ``key_value_states``; the attention output (after
    dropout) is added back onto the un-normalised ``hidden_states``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Cross-attention never owns a relative attention bias table.
        self.EncDecAttention = TFT5Attention(config, has_relative_attention_bias=False, name="EncDecAttention")
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def call(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        training=False,
    ):
        """
        Returns:
            ``(hidden_states, present_key_value_state, position_bias[, attention_weights])``; every element after
            the first is forwarded unchanged from ``EncDecAttention``.
        """
        normed = self.layer_norm(hidden_states)
        attn_results = self.EncDecAttention(
            normed,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            query_length=query_length,
            use_cache=use_cache,
            output_attentions=output_attentions,
            training=training,
        )
        residual = hidden_states + self.dropout(attn_results[0], training=training)
        return (residual,) + tuple(attn_results[1:])
class TFT5Block(tf.keras.layers.Layer):
    """
    One T5 transformer block.

    Sub-layers are kept in ``self.layer`` in this order:

    * ``layer_._0``: self-attention (:class:`TFT5LayerSelfAttention`);
    * ``layer_._1``: cross-attention (:class:`TFT5LayerCrossAttention`), decoder only;
    * last entry: feed-forward (:class:`TFT5LayerFF`).
    """

    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.is_decoder = config.is_decoder
        self.layer = []
        self.layer.append(
            TFT5LayerSelfAttention(
                config,
                has_relative_attention_bias=has_relative_attention_bias,
                name="layer_._0",
            )
        )
        if self.is_decoder:
            self.layer.append(
                TFT5LayerCrossAttention(
                    config,
                    name="layer_._1",
                )
            )

        # The feed-forward sub-layer is always last: "layer_._1" in an encoder, "layer_._2" in a decoder.
        self.layer.append(TFT5LayerFF(config, name=f"layer_._{len(self.layer)}"))

    def call(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        encoder_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        training=False,
    ):
        """
        Run self-attention, then cross-attention (decoder with ``encoder_hidden_states`` only), then feed-forward.

        Args:
            past_key_value: optional cache; 2 tensors (self-attention key/value) or, when
                ``encoder_hidden_states`` is given, 4 tensors (self-attention then cross-attention key/value).

        Returns:
            A tuple ``(hidden_states, present_key_value_state, self_attn_position_bias[, self_attn_weights]
            [, cross_attn_position_bias[, cross_attn_weights]])``. The attention weights are present only when
            ``output_attentions`` is set; the cross-attention entries only when cross-attention ran.

        Raises:
            ValueError: if ``past_key_value`` does not contain the expected number of tensors.
        """
        if past_key_value is not None:
            assert self.is_decoder, "Only decoder can use `past_key_values`"
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                # Build the message piecewise so the optional middle sentence never leaves a stray
                # period or a missing space behind.
                error_message = f"There should be {expected_num_past_key_values} past states. "
                if expected_num_past_key_values == 4:
                    error_message += "2 (past / key) for cross attention. "
                error_message += f"Got {len(past_key_value)} past key / value states"
                raise ValueError(error_message)

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            training=training,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        # Self-attention position bias, followed by its weights when output_attentions is set.
        attention_outputs = self_attention_outputs[2:]

        if self.is_decoder and encoder_hidden_states is not None:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = shape_list(present_key_value_state[0])[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=encoder_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = cross_attention_outputs[0]
            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Append cross-attention position bias, followed by its weights when output_attentions is set.
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states, training=training)
        outputs = (hidden_states,)

        # Add the cache and the position biases / attention weights collected above.
        outputs = outputs + (present_key_value_state,) + attention_outputs
        # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
        # (cross-attention position bias), (cross-attention weights)
        return outputs
####################################################
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
####################################################
@keras_serializable
class TFT5MainLayer(tf.keras.layers.Layer):
    """
    Stack of :class:`TFT5Block` layers forming a T5 encoder or decoder.

    ``config.is_decoder`` selects the behaviour. A decoder applies a causal mask to self-attention, attends to
    ``encoder_hidden_states`` when they are given, and can return cached key/value states.

    Args:
        config (:class:`T5Config`): model configuration.
        embed_tokens: layer used to embed ``input_ids``. It may be ``None`` only if ``inputs_embeds`` is always
            passed to ``call``.
    """

    config_class = T5Config

    def __init__(self, config, embed_tokens=None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.use_cache = config.use_cache

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.num_hidden_layers = config.num_layers

        # Only the first block owns a relative attention bias; the position bias it returns is
        # passed on to the following blocks in ``call``.
        self.block = [
            TFT5Block(config, has_relative_attention_bias=bool(i == 0), name=f"block_._{i}")
            for i in range(config.num_layers)
        ]
        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models

    def call(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        encoder_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
        **kwargs,
    ) -> Tuple:
        """
        Embed the inputs (unless ``inputs_embeds`` is given), run every block, then the final layer norm and dropout.

        ``attention_mask`` may be a 2D padding mask ``(batch_size, mask_seq_length)`` or a 3D mask
        ``(batch_size, from_seq_length, to_seq_length)``; ``encoder_attention_mask`` likewise. A missing
        ``attention_mask`` defaults to all ones. A missing ``encoder_attention_mask`` defaults to all ones when this
        is a decoder and ``encoder_hidden_states`` is given.

        Returns:
            If ``return_dict`` is false, a tuple ``(last_hidden_state[, present_key_value_states][, all_hidden_states]
            [, all_attentions])``. The optional entries are present when ``use_cache`` (decoder only),
            ``output_hidden_states`` and ``output_attentions`` are set, respectively. Otherwise a
            :class:`TFBaseModelOutputWithPast` (decoder) or :class:`TFBaseModelOutput` (encoder).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given, or if an attention mask
                is neither 2D nor 3D.
        """
        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            encoder_head_mask=encoder_head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
            kwargs_call=kwargs,
        )

        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif inputs["input_ids"] is not None:
            input_shape = shape_list(inputs["input_ids"])
            inputs["input_ids"] = tf.reshape(inputs["input_ids"], (-1, input_shape[-1]))
        elif inputs["inputs_embeds"] is not None:
            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if inputs["inputs_embeds"] is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = (
            shape_list(inputs["past_key_values"][0][0])[2] + seq_length
            if inputs["past_key_values"] is not None
            else seq_length
        )

        if inputs["attention_mask"] is None:
            inputs["attention_mask"] = tf.fill((batch_size, mask_seq_length), 1)
        if (
            self.is_decoder
            and inputs["encoder_attention_mask"] is None
            and inputs["encoder_hidden_states"] is not None
        ):
            encoder_seq_length = shape_list(inputs["encoder_hidden_states"])[1]
            inputs["encoder_attention_mask"] = tf.fill((batch_size, encoder_seq_length), 1)

        # initialize past_key_values with `None` if past does not exist
        if inputs["past_key_values"] is None:
            inputs["past_key_values"] = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=inputs["inputs_embeds"].dtype)
        num_dims_attention_mask = len(shape_list(inputs["attention_mask"]))
        if num_dims_attention_mask == 3:
            extended_attention_mask = inputs["attention_mask"][:, None, :, :]
        elif num_dims_attention_mask == 2:
            # Provided a padding mask of dimensions [batch_size, mask_seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
            if self.is_decoder:
                seq_ids = tf.range(mask_seq_length)
                # causal_mask[b, i, j] is true when key position j <= query position i.
                causal_mask = tf.less_equal(
                    tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                    seq_ids[None, :, None],
                )
                causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
                extended_attention_mask = causal_mask[:, None, :, :] * inputs["attention_mask"][:, None, None, :]
                # With a cache only the new query positions are computed, so keep their mask rows.
                if inputs["past_key_values"][0] is not None:
                    extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
            else:
                extended_attention_mask = inputs["attention_mask"][:, None, None, :]
        else:
            # Previously fell through and failed later with an UnboundLocalError.
            raise ValueError(
                f"Wrong shape for attention_mask (shape {shape_list(inputs['attention_mask'])}): "
                "expected a 2D padding mask or a 3D attention mask"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -1e9 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.

        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
        # extended_attention_mask = tf.math.equal(extended_attention_mask,
        #                                         tf.transpose(extended_attention_mask, perm=(-1, -2)))

        extended_attention_mask = (1.0 - extended_attention_mask) * -1e9

        if self.is_decoder and inputs["encoder_attention_mask"] is not None:
            # If a 2D or 3D attention mask is provided for the cross-attention
            # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
            inputs["encoder_attention_mask"] = tf.cast(
                inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
            )
            num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
            if num_dims_encoder_attention_mask == 3:
                encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
            elif num_dims_encoder_attention_mask == 2:
                encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
            else:
                # Previously fell through and failed with an UnboundLocalError on the next line.
                raise ValueError(
                    f"Wrong shape for encoder_attention_mask (shape {shape_list(inputs['encoder_attention_mask'])}): "
                    "expected a 2D padding mask or a 3D attention mask"
                )

            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
            #                                         tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))

            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
        else:
            encoder_extended_attention_mask = None

        present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
        all_hidden_states = () if inputs["output_hidden_states"] else None
        all_attentions = () if inputs["output_attentions"] else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs["inputs_embeds"], training=inputs["training"])

        for idx, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
            if inputs["output_hidden_states"]:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=inputs["encoder_hidden_states"],
                encoder_attention_mask=encoder_extended_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
                encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
                if inputs["encoder_head_mask"] is not None
                else None,
                past_key_value=past_key_value,
                use_cache=inputs["use_cache"],
                output_attentions=inputs["output_attentions"],
                training=inputs["training"],
            )
            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them.
            # The weights entries exist only when output_attentions is set, which shifts the
            # cross-attention position bias from index 3 to index 4.
            position_bias = layer_outputs[2]
            if self.is_decoder and inputs["encoder_hidden_states"] is not None:
                encoder_decoder_position_bias = layer_outputs[4 if inputs["output_attentions"] else 3]
            # append next layer key value states
            if present_key_value_state is not None and inputs["use_cache"] and self.is_decoder:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if inputs["output_attentions"]:
                # Index 3 holds the self-attention weights.
                all_attentions = all_attentions + (layer_outputs[3],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states, training=inputs["training"])

        # Add last layer
        if inputs["output_hidden_states"]:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not inputs["return_dict"]:
            outputs = (hidden_states,)
            # need to check if is decoder here as well for special cases when using keras compile
            if inputs["use_cache"] and self.is_decoder:
                outputs = outputs + (present_key_value_states,)
            if inputs["output_hidden_states"]:
                outputs = outputs + (all_hidden_states,)
            if inputs["output_attentions"]:
                outputs = outputs + (all_attentions,)
            return outputs  # last-layer hidden state, (present key/value states), (all hidden states), (all attentions)

        if self.is_decoder:
            return TFBaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=present_key_value_states,
                hidden_states=all_hidden_states,
                attentions=all_attentions,
            )
        else:
            return TFBaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=all_attentions,
            )
####################################################
# TFT5PreTrainedModel is a subclass of tf.keras.Model
# which takes care of loading and saving pretrained weights
# and various common utilities.
# Here you just need to specify a few (self-explanatory)
# pointers for your model.
####################################################
class TFT5PreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config
    base_model_prefix = "transformer"
    # Regex patterns of weight names tolerated as "unexpected" when a TF model is loaded from a PT model
    # (`\W` matches any separator character, such as '.').
    _keys_to_ignore_on_load_unexpected = [r"decoder\Wblock[\W_0]+layer[\W_1]+EncDecAttention\Wrelative_attention_bias"]

    @property
    def dummy_inputs(self):
        """
        Small constant inputs used to run a forward pass, e.g. to build the layers before their weights are set.

        Returns:
            `dict`: `input_ids` and `decoder_input_ids` (both built from `DUMMY_INPUTS`) and
            `decoder_attention_mask` (built from `DUMMY_MASK`).
        """
        inputs = tf.constant(DUMMY_INPUTS)
        input_mask = tf.constant(DUMMY_MASK)
        dummy_inputs = {
            "input_ids": inputs,
            "decoder_input_ids": inputs,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    @tf.function(
        input_signature=[
            {
                "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
                "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
                "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
                "decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
            }
        ]
    )
    def serving(self, inputs):
        """
        Graph-compiled entry point with a fixed input signature.

        Args:
            inputs (`dict`): four rank-2 `tf.int32` tensors with dynamic dimensions, keyed `input_ids`,
                `attention_mask`, `decoder_input_ids` and `decoder_attention_mask`.

        Returns:
            The result of `self.call(inputs)` post-processed by `self.serving_output`.
        """
        output = self.call(inputs)
        return self.serving_output(output)

    def get_input_embeddings(self):
        """Return the shared token-embedding layer (`self.shared`)."""
        return self.shared

    def set_input_embeddings(self, value):
        """
        Replace the shared embedding weights and re-wire the encoder/decoder token embeddings.

        Args:
            value: the new embedding matrix; its first dimension becomes `self.shared.vocab_size`.
        """
        try:
            self.shared.weight = value
        except AttributeError:
            # NOTE(review): presumably the layer is not built yet; a dummy forward pass builds it, then we retry.
            self(self.dummy_inputs)
            self.shared.weight = value

        self.shared.vocab_size = shape_list(value)[0]
        # retrieve correct absolute scope for embed token wrapper
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
            pass
        # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
        embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
        self.encoder.embed_tokens = embed_tokens
        if hasattr(self, "decoder"):
            self.decoder.embed_tokens = embed_tokens

    def _shift_right(self, input_ids):
        """
        Shift `input_ids` one position right along the last axis and prepend `config.decoder_start_token_id`.

        Any `-100` in the shifted result is replaced by `config.pad_token_id`. A graph-safe check then requires
        every resulting id to be >= 0.

        Args:
            input_ids (`tf.Tensor`): integer ids, indexed as `input_ids[:, :-1]` (rank 2 assumed).

        Returns:
            `tf.Tensor` with the same shape and dtype as `input_ids`.

        Raises:
            ValueError: if `config.decoder_start_token_id` or `config.pad_token_id` is `None`.
        """
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        # Validate the config with real exceptions: `assert` is stripped under `python -O`, which would turn a
        # missing id into an obscure failure inside `tf.fill` further down.
        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the"
                " pad_token_id. See T5 docs for more information"
            )
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
        start_tokens = tf.cast(start_tokens, input_ids.dtype)  # Ensure compatible dtypes for concatenation
        shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)

        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids = tf.where(
            shifted_input_ids == -100,
            tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype),
            shifted_input_ids,
        )

        # Verify that the shifted ids are all non-negative (every -100 has been replaced above)
        assert_gte0 = tf.debugging.assert_greater_equal(
            shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype)
        )

        # Make sure the assertion op is called by wrapping the result in an identity no-op
        with tf.control_dependencies([assert_gte0]):
            shifted_input_ids = tf.identity(shifted_input_ids)

        return shifted_input_ids
# Shared introductory documentation text for the TF T5 model classes.
# NOTE(review): presumably injected into class docstrings by a docstring decorator elsewhere in this file -- confirm.
# NOTE(review): the text contains apparent typos ("TF 2.0 models accepts", "all its model",
# `model(inputs_ids)` -- likely `input_ids`); left unchanged here because this string is runtime content.
T5_START_DOCSTRING = r"""
The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a
text-to-text denoising generative setting.
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
<Tip>
TF 2.0 models accepts two formats as inputs:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional arguments.
This second option is useful when using [`tf.keras.Model.fit`] method which currently requires having all the
tensors in the first argument of the model call function: `model(inputs)`.
If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
first positional argument :
- a single Tensor with `input_ids` only and nothing else: `model(inputs_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
</Tip>
Parameters:
config ([`T5Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
T5_INPUTS_DOCSTRING = r"""
Args:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
should be able to pad the inputs on the right or the left.
Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
[`PreTrainedTokenizer.encode`] for details.
[What are input IDs?](../glossary#input-ids)
To know more on how to prepare `inputs` for pretraining take a look at [T5 Training](./t5#training).