-
Notifications
You must be signed in to change notification settings - Fork 25.5k
/
modeling_tf_idefics.py
1812 lines (1544 loc) · 78.3 KB
/
modeling_tf_idefics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 Idefics model. """
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import tensorflow as tf
from ... import TFPreTrainedModel
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import ModelOutput
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFModelInputType,
keras_serializable,
shape_list,
unpack_inputs,
)
from ...tf_utils import invert_attention_mask, scaled_dot_product_attention
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_idefics import IdeficsConfig
from .perceiver_tf import TFIdeficsPerceiverResampler
from .vision_tf import TFIdeficsVisionTransformer
# Module-level logger, obtained from the library's own `logging` utility imported above.
logger = logging.get_logger(__name__)

# Name of the configuration class as a string.
# NOTE(review): presumably consumed by the docstring decorators imported above; usage is outside this view.
_CONFIG_FOR_DOC = "IdeficsConfig"
@dataclass
class TFIdeficsBaseModelOutputWithPast(ModelOutput):
    """
    Base class for Idefics model's outputs that may also contain past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally, if
            `config.is_encoder_decoder=True`, in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(tf.Tensor)`, *optional*):
            Tuple of `tf.Tensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
            sequence_length, hidden_size)`.

            Image hidden-states of the model produced by the vision encoder, and optionally by the perceiver.
    """

    # Field order is part of the interface: keep it in sync with the `Args` section above.
    last_hidden_state: tf.Tensor = None
    past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    image_hidden_states: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFIdeficsCausalLMOutputWithPast(ModelOutput):
    """
    Base class for Idefics causal language model (or autoregressive) outputs.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(tf.Tensor)`, *optional*):
            Tuple of `tf.Tensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
            sequence_length, hidden_size)`.

            Image hidden-states of the model produced by the vision encoder, and optionally by the perceiver.
    """

    # Field order is part of the interface: keep it in sync with the `Args` section above.
    loss: Optional[tf.Tensor] = None
    logits: tf.Tensor = None
    # NOTE(review): annotated as a list here while the docstring describes a tuple of tuples — confirm which is intended.
    past_key_values: Optional[List[tf.Tensor]] = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    image_hidden_states: Optional[Tuple[tf.Tensor]] = None
def expand_inputs_for_generation(
    input_ids,
    expand_size=1,
    is_encoder_decoder=False,
    attention_mask=None,
    encoder_outputs=None,
    **model_kwargs,
):
    """
    Repeat every batch row of the generation inputs `expand_size` times.

    Each row index `i` of `input_ids` is repeated `expand_size` consecutive times and the same row selection is
    applied to `token_type_ids`, `attention_mask`, `image_attention_mask` and to the first available image input
    among `pixel_values`, `image_encoder_embeddings` and `perceiver_embeddings` (in that priority order).

    `is_encoder_decoder` and `encoder_outputs` are accepted but not used by this function.

    Returns:
        A tuple `(input_ids, model_kwargs)` with the expanded tensors. The four image-related keys are always
        present in `model_kwargs` (set to `None` when they were not supplied).
    """
    num_rows = tf.shape(input_ids)[0]
    # e.g. 2 rows and expand_size=3 -> [0, 0, 0, 1, 1, 1]
    row_index = tf.reshape(tf.repeat(tf.range(num_rows), expand_size), [-1])

    def _repeat_rows(tensor):
        return tf.gather(tensor, row_index)

    input_ids = _repeat_rows(input_ids)

    # Make sure the image-related entries exist so they can be tested against None below.
    for key in ("pixel_values", "image_encoder_embeddings", "perceiver_embeddings", "image_attention_mask"):
        model_kwargs.setdefault(key, None)

    if "token_type_ids" in model_kwargs:
        model_kwargs["token_type_ids"] = _repeat_rows(model_kwargs["token_type_ids"])

    if attention_mask is not None:
        model_kwargs["attention_mask"] = _repeat_rows(attention_mask)

    if model_kwargs["image_attention_mask"] is not None:
        model_kwargs["image_attention_mask"] = _repeat_rows(model_kwargs["image_attention_mask"])

    # Only the first image input that is provided gets expanded; the others are left untouched.
    for key in ("pixel_values", "image_encoder_embeddings", "perceiver_embeddings"):
        if model_kwargs[key] is not None:
            model_kwargs[key] = _repeat_rows(model_kwargs[key])
            break

    return input_ids, model_kwargs
def update_model_kwargs_for_generation(outputs, model_kwargs):
    """
    Update `model_kwargs` in place after a generation step and return it.

    Args:
        outputs: Output of the last forward pass. It must support `"past_key_values" in outputs` and expose
            `past_key_values` / `image_hidden_states` as attributes.
        model_kwargs (`dict`): Keyword arguments that will be fed to the next forward pass.

    Returns:
        `dict`: The same `model_kwargs` object, where

        - `past_key_values` is taken from `outputs` (or set to `None` when absent),
        - `token_type_ids` is extended by repeating its last entry,
        - `attention_mask` is extended with ones,
        - `image_attention_mask` is reduced to its last entry along axis 1,
        - `image_hidden_states` is copied from `outputs`.

    Entries that are missing from `model_kwargs` *or explicitly set to `None`* are left untouched
    (`expand_inputs_for_generation` stores `None` for an absent `image_attention_mask`, which previously made the
    slicing below fail).
    """
    # must have this key set to at least None
    model_kwargs["past_key_values"] = outputs.past_key_values if "past_key_values" in outputs else None

    # update token_type_ids with last value
    token_type_ids = model_kwargs.get("token_type_ids")
    if token_type_ids is not None:
        model_kwargs["token_type_ids"] = tf.concat([token_type_ids, token_type_ids[:, -1:, ...]], axis=-1)

    # update attention masks
    attention_mask = model_kwargs.get("attention_mask")
    if attention_mask is not None:
        model_kwargs["attention_mask"] = tf.concat(
            [attention_mask, tf.ones_like(attention_mask[:, -1:, ...])], axis=-1
        )

    image_attention_mask = model_kwargs.get("image_attention_mask")
    if image_attention_mask is not None:
        # Keep only the last entry along axis 1 (the axis is kept with size 1).
        model_kwargs["image_attention_mask"] = image_attention_mask[:, -1:, ...]

    # Get the precomputed image_hidden_states
    model_kwargs["image_hidden_states"] = outputs.image_hidden_states
    return model_kwargs
def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
    """
    Assemble the keyword arguments for the next forward pass during generation.

    When `past_key_values` is given, `input_ids` and `token_type_ids` are cut down to their last position. If an
    `attention_mask` is supplied without `position_ids`, position ids are derived from the cumulative sum of the
    mask (positions where the mask is 0 are set to 1) and, with a cache, reduced to the last position as well.

    Returns:
        `dict` with the keys `input_ids`, `past_key_values`, `use_cache`, `position_ids`, `attention_mask`,
        `token_type_ids`, `pixel_values`, `image_encoder_embeddings`, `perceiver_embeddings`,
        `image_attention_mask` and `interpolate_pos_encoding`.
    """
    has_cache = past_key_values is not None

    token_type_ids = kwargs.get("token_type_ids")
    attention_mask = kwargs.get("attention_mask")
    position_ids = kwargs.get("position_ids")

    # only last token for inputs_ids if past is defined in kwargs
    if has_cache:
        input_ids = input_ids[:, -1:]
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1:]

    if position_ids is None and attention_mask is not None:
        # create position_ids on the fly for batch generation
        attended_count = tf.math.cumsum(tf.cast(attention_mask, dtype=tf.int64), axis=-1)
        position_ids = tf.where(attention_mask == 0, 1, attended_count - 1)
        if has_cache:
            position_ids = position_ids[:, -1:]

    model_inputs = {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }
    # Image-related inputs are forwarded unchanged.
    for key in ("pixel_values", "image_encoder_embeddings", "perceiver_embeddings", "image_attention_mask"):
        model_inputs[key] = kwargs.get(key)
    model_inputs["interpolate_pos_encoding"] = kwargs.get("interpolate_pos_encoding", False)
    return model_inputs
def freeze_model(model, module_exceptions=()):
    """
    Mark a model (or a single layer) as non-trainable, optionally sparing some layer types.

    Args:
        model: Object to freeze. If it has no `layers` attribute it is treated as a single layer and its
            `trainable` flag is set to `False`.
        module_exceptions (sequence of `str`, *optional*):
            Names of layer types that must stay trainable. Supported names are `"LayerNorm"`, `"Dense"` and
            `"Embedding"`.

    Returns:
        The same `model` object, modified in place.

    Raises:
        KeyError: If `module_exceptions` contains a name that is not one of the supported names.
    """
    mapping = {
        "LayerNorm": tf.keras.layers.LayerNormalization,
        "Dense": tf.keras.layers.Dense,
        "Embedding": tf.keras.layers.Embedding,
    }
    # A tuple of classes can be handed to `isinstance` directly; an empty tuple never matches.
    exempt_types = tuple(mapping[name] for name in module_exceptions)

    if not hasattr(model, "layers"):
        model.trainable = False  # It is just a layer
        return model

    for layer in model.layers:
        # Explicitly set both states so that no layer keeps a stale flag.
        layer.trainable = isinstance(layer, exempt_types)
    return model
class TFIdeficsDecoupledEmbedding(tf.keras.layers.Embedding):
    """
    Embedding layer whose vocabulary is split into a regular part and an optional additional part.

    The regular table (the one owned by the `tf.keras.layers.Embedding` base class) is marked non-trainable when
    `partially_freeze=True`. When `num_additional_embeddings > 0`, a second embedding table of that size is created
    for ids `>= num_embeddings`. With `num_additional_embeddings=0` the layer behaves like a plain
    `tf.keras.layers.Embedding`.
    """

    def __init__(
        self,
        num_embeddings,
        num_additional_embeddings,
        embedding_dim,
        partially_freeze: Optional[bool] = False,
        dtype=None,
        **kwargs,
    ) -> None:
        """
        Args:
            num_embeddings (`int`):
                Size of the regular dictionary of embeddings.
            num_additional_embeddings (`int`):
                Number of additional embeddings. Only useful when `partially_freeze=True`.
            embedding_dim (`int`):
                The size of each embedding vector.
            partially_freeze (`bool`, *optional*, defaults to `False`):
                If `True`, this layer's `trainable` flag is set to `False`.

        Note: `tf.keras.layers.Embedding` has other parameters such as `mask_zero`, `input_length` or
        `embeddings_initializer`. We are not supporting these.
        """
        super().__init__(input_dim=num_embeddings, output_dim=embedding_dim, dtype=dtype, **kwargs)

        self.num_embeddings = num_embeddings
        self.num_additional_embeddings = num_additional_embeddings
        self.partially_freeze = partially_freeze

        if partially_freeze:
            self.trainable = False

        if self.num_additional_embeddings > 0:
            self.additional_embedding = tf.keras.layers.Embedding(
                input_dim=self.num_additional_embeddings,
                output_dim=embedding_dim,
                dtype=dtype,
                name="additional_embedding",
            )

    def call(self, input_ids):
        """
        Look up `input_ids` across the regular and the additional embedding tables.

        Steps:

        1. locate the entries whose id belongs to the additional table (`id >= num_embeddings`)
        2. shift those ids by `num_embeddings`, since the additional table is indexed from 0
        3. look them up in the additional table
        4. replace those ids with 0 so that the regular lookup stays in range
        5. look up every position in the regular table
        6. overwrite the regular results at the located positions with the additional vectors

        Looking up only the low ids in the regular table would avoid step 4, but the two result sets would then
        have to be interleaved into a new tensor rather than simply concatenated. This alternative has not been
        benchmarked.
        """
        if self.num_additional_embeddings == 0:
            return super().call(input_ids)

        # Work on a copy so the caller's tensor is never affected
        input_ids = tf.identity(input_ids)

        high_id_positions = tf.where(input_ids >= self.num_embeddings)
        shifted_high_ids = tf.gather_nd(input_ids, high_id_positions) - self.num_embeddings
        additional_vectors = self.additional_embedding(shifted_high_ids)

        # one zero per located position; the vectors looked up for them are discarded below
        zero_ids = tf.zeros(tf.shape(high_id_positions)[0], dtype=input_ids.dtype)
        regular_ids = tf.tensor_scatter_nd_update(input_ids, high_id_positions, zero_ids)
        regular_vectors = super().call(regular_ids)

        # overwrite the records with high indices
        return tf.tensor_scatter_nd_update(regular_vectors, high_id_positions, additional_vectors)

    def extra_repr(self) -> str:
        """Return a one-line summary of the layer's configuration."""
        return (
            f"num_embeddings={self.num_embeddings}, "
            f"num_additional_embeddings={self.num_additional_embeddings}, "
            f"embedding_dim={self.output_dim}, "
            f"partially_freeze={self.partially_freeze}"
        )
class TFIdeficsDecoupledLinear(tf.keras.layers.Layer):
    """
    Linear projection whose output features are split into a regular part and an optional additional part.

    The regular `weight` (and `bias`) are created as non-trainable when `partially_freeze=True`. If
    `out_additional_features > 0`, an extra `tf.keras.layers.Dense` producing `out_additional_features` outputs is
    created and its result is concatenated to the regular output along the last axis. With
    `out_additional_features=0` the layer computes a plain `inputs @ weight^T (+ bias)`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        out_additional_features: int = 0,
        bias: bool = True,
        partially_freeze: bool = True,
        **kwargs,
    ) -> None:
        """
        Args:
            in_features (`int`): Size of the last axis of the input.
            out_features (`int`): Number of regular output features.
            out_additional_features (`int`, *optional*, defaults to 0):
                Number of additional output features. Only makes sense when `partially_freeze=True`.
            bias (`bool`, *optional*, defaults to `True`):
                Whether the regular projection and the additional projection use a bias.
            partially_freeze (`bool`, *optional*, defaults to `True`):
                If `True`, the regular `weight` and `bias` are created with `trainable=False`.
        """
        super().__init__(**kwargs)
        self.out_additional_features = out_additional_features
        self.partially_freeze = partially_freeze

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = bias

        if out_additional_features > 0:
            self.additional_fc = tf.keras.layers.Dense(
                units=out_additional_features, use_bias=bias, name="additional_fc"
            )

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """Project `inputs` and, if configured, append the additional features on the last axis."""
        # `weight` is stored as (out_features, in_features), hence the transpose.
        output = tf.linalg.matmul(a=inputs, b=self.weight, transpose_b=True)
        if self.bias is not None:
            output = tf.nn.bias_add(output, self.bias)

        if self.out_additional_features > 0:
            additional_features = self.additional_fc(inputs)
            output = tf.concat([output, additional_features], axis=-1)

        return output

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created with `from_config`."""
        config = super().get_config()
        config.update(
            {
                "in_features": self.in_features,
                "out_features": self.out_features,
                "out_additional_features": self.out_additional_features,
                # `self.bias` only exists once `build` has run; `use_bias` is always available.
                "bias": self.use_bias,
                "partially_freeze": self.partially_freeze,
            }
        )
        return config

    def extra_repr(self) -> str:
        """Return a one-line summary of the layer's configuration."""
        return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
            self.in_features,
            self.out_features,
            self.out_additional_features,
            # `self.bias` only exists once `build` has run; `use_bias` is always available.
            self.use_bias,
            self.partially_freeze,
        )

    @classmethod
    def from_config(cls, config):
        """Instantiate the layer from a dictionary produced by `get_config`."""
        return cls(**config)

    def build(self, input_shape=None):
        """Create `weight`, the optional `bias` and build the optional additional projection."""
        if self.built:
            return
        self.built = True
        self.weight = self.add_weight(
            shape=(self.out_features, self.in_features), trainable=not self.partially_freeze, name="weight"
        )
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.out_features,), trainable=not self.partially_freeze, name="bias")
        else:
            self.bias = None
        if getattr(self, "additional_fc", None) is not None:
            with tf.name_scope(self.additional_fc.name):
                self.additional_fc.build(self.in_features)
def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0):
    """
    Make the additive causal mask used for uni-directional self-attention.

    Args:
        input_ids_shape: Pair `(bsz, tgt_len)`. `bsz` may be `None` when the batch size is not known statically.
        dtype: Floating point dtype of the mask. Masked positions hold the minimum value of this dtype, visible
            positions hold 0.
        past_key_values_length (`int`, *optional*, defaults to 0):
            Number of cached positions. They are prepended as fully visible columns.

    Returns:
        `tf.Tensor` of shape `(bsz, 1, tgt_len, tgt_len + past_key_values_length)`. When `bsz` is `None` the batch
        axis has size 1, so the mask has to be broadcast against batched tensors by the caller.
    """
    bsz, tgt_len = input_ids_shape

    # Start fully masked, then zero out the lower triangle and the diagonal (positions a query may attend to).
    mask = tf.fill((tgt_len, tgt_len), tf.dtypes.as_dtype(dtype).min)
    mask_cond = tf.range(tgt_len)
    mask = tf.where(mask_cond[:, None] >= mask_cond[None, :], 0.0, mask)

    if past_key_values_length > 0:
        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1)

    mask = mask[None, None, :, :]  # shape: (1, 1, tgt_len, tgt_len + past_key_values_length)

    if bsz is None:
        # The batch size is unknown, so there is nothing to tile by (`tf.tile` rejects a `None` multiple).
        # Keep the singleton batch axis and rely on broadcasting instead.
        # NOTE(review): confirm that every caller combines this mask through a broadcasting op.
        return mask

    # When batch size is static, directly use broadcast_to
    return tf.broadcast_to(mask, (bsz, 1, tgt_len, tgt_len + past_key_values_length))
def _expand_mask(mask, dtype, tgt_len=None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    Args:
        mask: 2D attention mask where 1 marks positions to attend to and 0 marks positions to ignore.
        dtype: Floating point dtype of the returned additive mask.
        tgt_len (`int`, *optional*): Target sequence length. Defaults to the source length of `mask`.

    Returns:
        `tf.Tensor` of dtype `dtype` holding 0 where `mask` is 1 and the minimum value of `dtype` elsewhere.
    """
    bsz, src_len = shape_list(mask)
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = tf.expand_dims(tf.expand_dims(mask, 1), 1)
    expanded_mask = tf.broadcast_to(expanded_mask, [bsz, 1, tgt_len, src_len])

    # 1 -> 0 (keep), 0 -> 1 (to be masked)
    inverted_mask = 1.0 - tf.cast(expanded_mask, dtype)

    # Use the minimum of the requested dtype: both branches of `tf.where` must share one dtype, so a hard-coded
    # float32 minimum fails for float16 / bfloat16 masks.
    masked_value = tf.constant(tf.dtypes.as_dtype(dtype).min, dtype=dtype)
    return tf.where(tf.cast(inverted_mask, bool), masked_value, inverted_mask)
class TFIdeficsRMSNorm(tf.keras.layers.Layer):
    """Root-mean-square layer normalisation over the last axis, with a learned scale and no bias."""

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        """
        TFIdeficsRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size (`int`): Size of the normalised (last) axis; also the size of the scale vector.
            eps (`float`, *optional*, defaults to 1e-6): Value added to the variance before taking `rsqrt`.
        """
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.variance_epsilon = eps

    def build(self, input_shape):
        """Create the scale vector `weight`, initialised to ones."""
        if self.built:
            return
        self.built = True
        self.weight = self.add_weight(name="weight", shape=[self.hidden_size], initializer="ones")
        super().build(input_shape)

    def call(self, hidden_states):
        """Scale `hidden_states` by the inverse RMS of its last axis, then by `weight`."""
        # Normalise in float32: TensorFlow does not promote dtypes implicitly, so multiplying a half-precision
        # input with the float32 statistics would raise.
        hidden_states = tf.cast(hidden_states, tf.float32)
        variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon)

        # Cast to the dtype the weight is actually read in (half precision if necessary) so the final product is
        # well-typed; this is a no-op in the all-float32 case.
        weight = tf.convert_to_tensor(self.weight)
        hidden_states = tf.cast(hidden_states, weight.dtype)

        return weight * hidden_states
class TFIdeficsEmbedding(tf.keras.layers.Layer):
    """Produces the cosine / sine tables used for rotary position embeddings."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, **kwargs):
        """
        Args:
            dim (`int`): Size of the rotary dimension; the returned tables have `dim` columns for an even `dim`.
            max_position_embeddings (`int`, *optional*, defaults to 2048): Stored on the layer; not used when
                computing the tables.
            base (`int`, *optional*, defaults to 10000): Base of the geometric progression of frequencies.
        """
        super().__init__(**kwargs)

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # One inverse frequency per pair of channels: base ** (-2i / dim)
        exponents = tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim
        self.inv_freq = tf.constant(1.0 / (self.base**exponents))

    def _compute_cos_sin(self, seq_len):
        """Return `(cos, sin)` tables with one row per position in `range(seq_len)`."""
        positions = tf.range(seq_len, dtype=self.inv_freq.dtype)
        angles = tf.einsum("i, j -> ij", positions, self.inv_freq)  # Outer product
        # The frequencies are duplicated so the table spans both halves of the rotary dimension.
        table = tf.concat((angles, angles), axis=-1)
        return tf.cos(table), tf.sin(table)

    def call(self, x, seq_len=None):
        """Return the `(cos, sin)` tables for `seq_len` positions (default: axis 2 of `x`)."""
        # x: [bs, num_attention_heads, seq_len, head_size]
        length = shape_list(x)[2] if seq_len is None else seq_len
        return self._compute_cos_sin(seq_len=length)
def rotate_half(x):
    """Split the last axis of `x` in two halves `(a, b)` and return `(-b, a)` concatenated on that axis."""
    midpoint = x.shape[-1] // 2
    first_half, second_half = x[..., :midpoint], x[..., midpoint:]
    return tf.concat((-second_half, first_half), axis=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embeddings to the query and key tensors.

    The rows of `cos` / `sin` are selected with `position_ids`, a broadcast axis is inserted at position 1, and
    both `q` and `k` are transformed as `t * cos + rotate_half(t) * sin`.

    Returns:
        Tuple `(q_embed, k_embed)`.
    """
    # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] (assumes `position_ids` is [batch_size, seq_len])
    cos = tf.expand_dims(tf.gather(cos, position_ids), 1)
    sin = tf.expand_dims(tf.gather(sin, position_ids), 1)

    def _rotate(states):
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
class TFIdeficsMLP(tf.keras.layers.Layer):
    """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all without bias."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        **kwargs,
    ):
        """
        Args:
            hidden_size (`int`): Output size of `down_proj` and input size used to build `gate_proj` / `up_proj`.
            intermediate_size (`int`): Output size of `gate_proj` / `up_proj`.
            hidden_act (`str`): Name of the activation, resolved through `get_tf_activation`.
        """
        super().__init__(**kwargs)

        def _bias_free_dense(units, name):
            return tf.keras.layers.Dense(units, use_bias=False, name=name)

        self.gate_proj = _bias_free_dense(intermediate_size, "gate_proj")
        self.down_proj = _bias_free_dense(hidden_size, "down_proj")
        self.up_proj = _bias_free_dense(intermediate_size, "up_proj")
        self.act_fn = get_tf_activation(hidden_act)
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size

    def call(self, x):
        """Apply the gated projection to `x`."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

    def build(self, input_shape=None):
        """Build the three projections under their own name scopes."""
        if self.built:
            return
        self.built = True

        # (attribute name, input dimension) in the order the projections are built
        projections = (
            ("gate_proj", self.hidden_size),
            ("down_proj", self.intermediate_size),
            ("up_proj", self.hidden_size),
        )
        for attr_name, input_dim in projections:
            projection = getattr(self, attr_name, None)
            if projection is None:
                continue
            with tf.name_scope(projection.name):
                projection.build(input_dim)
class TFIdeficsAttention(tf.keras.layers.Layer):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout: float = 0.0,
        is_cross_attention: bool = False,
        config: IdeficsConfig = None,
        qk_layer_norms: bool = False,
        **kwargs,
    ):
        """
        Args:
            hidden_size (`int`): Model dimension; must be divisible by `num_heads`.
            num_heads (`int`): Number of attention heads.
            dropout (`float`, *optional*, defaults to 0.0): Stored on the layer; `call` does not apply it.
            is_cross_attention (`bool`, *optional*, defaults to `False`):
                Whether keys/values come from `key_value_states` rather than `hidden_states`.
            config (`IdeficsConfig`, *optional*):
                Read for `rms_norm_eps` when `qk_layer_norms=True` and for `vision_config` when building a
                cross-attention layer.
            qk_layer_norms (`bool`, *optional*, defaults to `False`):
                Whether to apply an RMS norm to the query and key states.

        Raises:
            ValueError: If `hidden_size` is not divisible by `num_heads`.
        """
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.dropout = dropout
        self.config = config
        self.is_causal = True

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )

        self.is_cross_attention = is_cross_attention

        self.q_proj = tf.keras.layers.Dense(
            num_heads * self.head_dim,
            use_bias=False,
            name="q_proj",
        )
        self.k_proj = tf.keras.layers.Dense(
            num_heads * self.head_dim,
            use_bias=False,
            name="k_proj",
        )
        self.v_proj = tf.keras.layers.Dense(
            num_heads * self.head_dim,
            use_bias=False,
            name="v_proj",
        )
        self.o_proj = tf.keras.layers.Dense(
            hidden_size,
            use_bias=False,
            name="o_proj",
        )
        self.rotary_emb = TFIdeficsEmbedding(self.head_dim, name="rotary_emb")

        self.qk_layer_norms = qk_layer_norms
        if self.qk_layer_norms:
            self.q_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="q_layer_norm")
            self.k_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="k_layer_norm")

    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, hidden)` into per-head layout `(bsz, num_heads, seq_len, head_dim)`."""
        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        key_value_states: Optional[tf.Tensor] = None,
        attention_mask: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        past_key_value: Optional[Tuple[tf.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[Tuple[tf.Tensor]]]:
        """
        Run self- or cross-attention.

        Args:
            hidden_states (`tf.Tensor`): Query input of shape `(bsz, q_len, hidden_size)`.
            key_value_states (`tf.Tensor`, *optional*): Source of keys/values; when given, the layer acts as
                cross-attention and no rotary embedding is applied.
            attention_mask (`tf.Tensor`, *optional*): Additive mask of shape `(bsz, 1, q_len, kv_seq_len)`.
            position_ids (`tf.Tensor`, *optional*): Positions used to select the rotary embedding rows.
            past_key_value (`Tuple(tf.Tensor)`, *optional*): Cached `(key, value)` states to prepend.
            output_attentions (`bool`, *optional*): Attention weights cannot be extracted; a warning is logged
                and `None` is returned for them.
            use_cache (`bool`, *optional*): Whether to return the updated `(key, value)` states.

        Returns:
            Tuple `(attn_output, attn_weights, past_key_value)` where `attn_weights` is always `None`.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = self.is_cross_attention or key_value_states is not None

        bsz, q_len, _ = shape_list(hidden_states)

        query_states = self._shape(self.q_proj(hidden_states), q_len, bsz)
        if not is_cross_attention:
            key_states = self._shape(self.k_proj(hidden_states), q_len, bsz)
            value_states = self._shape(self.v_proj(hidden_states), q_len, bsz)
        else:
            _, kv_len, _ = shape_list(key_value_states)  # Note that, in this case, `kv_len` == `kv_seq_len`
            key_states = self._shape(self.k_proj(key_value_states), kv_len, bsz)
            value_states = self._shape(self.v_proj(key_value_states), kv_len, bsz)

        kv_seq_len = shape_list(key_states)[-2]
        if past_key_value is not None:
            kv_seq_len += shape_list(past_key_value[0])[-2]

        if not is_cross_attention:
            # Below is to allow symbolic tensors compilation
            if tf.is_tensor(kv_seq_len) or tf.is_tensor(q_len):
                # Element-wise maximum of the two lengths (`tf.reduce_max` would treat `q_len` as an axis).
                seq_len = tf.maximum(kv_seq_len, q_len)
            else:
                seq_len = max(kv_seq_len, q_len)
            cos, sin = self.rotary_emb(value_states, seq_len)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = tf.concat([past_key_value[0], key_states], axis=2)
            value_states = tf.concat([past_key_value[1], value_states], axis=2)

        past_key_value = (key_states, value_states) if use_cache else None

        if self.qk_layer_norms:
            query_states = self.q_layer_norm(query_states)
            key_states = self.k_layer_norm(key_states)

        # The mask is optional (see `is_causal` below), so only validate its shape when one is given.
        if attention_mask is not None:
            tf.debugging.assert_equal(
                tf.shape(attention_mask),
                [bsz, 1, q_len, kv_seq_len],
                message=f"Attention mask should be of size {[bsz, 1, q_len, kv_seq_len]}, but is {tf.shape(attention_mask)}",
            )

        attn_output = scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        tf.debugging.assert_equal(
            tf.shape(attn_output),
            [bsz, self.num_heads, q_len, self.head_dim],
            message=f"Attention output should be of size {[bsz, self.num_heads, q_len, self.head_dim]}, but is {tf.shape(attn_output)}",
        )

        # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = tf.reshape(tf.transpose(attn_output, perm=[0, 2, 1, 3]), (bsz, q_len, self.hidden_size))

        attn_output = self.o_proj(attn_output)

        attn_weights = None
        if output_attentions:
            logger.warning_once(
                "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
            )

        return attn_output, attn_weights, past_key_value

    def build(self, input_shape=None):
        """Build the projections and the rotary embedding under their own name scopes."""
        if self.built:
            return
        self.built = True
        # In cross-attention, keys/values are projected from the vision embedding size when the vision config
        # declares one; otherwise they use the model hidden size.
        if self.is_cross_attention:
            kv_input_dim = (
                self.hidden_size
                if not hasattr(self.config.vision_config, "embed_dim")
                else self.config.vision_config.embed_dim
            )
        else:
            kv_input_dim = self.hidden_size
        if getattr(self, "o_proj", None) is not None:
            with tf.name_scope(self.o_proj.name):
                self.o_proj.build(self.num_heads * self.head_dim)
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build(self.hidden_size)
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build(kv_input_dim)
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build(kv_input_dim)
        if getattr(self, "rotary_emb", None) is not None:
            with tf.name_scope(self.rotary_emb.name):
                self.rotary_emb.build(None)
class TFIdeficsDecoderLayer(tf.keras.layers.Layer):
    """Decoder block: layer-normed self-attention, then a layer-normed MLP, each wrapped in a residual connection."""

    def __init__(self, config: IdeficsConfig, **kwargs):
        """
        Args:
            config (`IdeficsConfig`): model configuration; `hidden_size`, `num_attention_heads`, `dropout`,
                `intermediate_size`, `hidden_act` and `rms_norm_eps` are read from it.
            **kwargs: forwarded to `tf.keras.layers.Layer`.
        """
        super().__init__(**kwargs)
        self.hidden_size = config.hidden_size
        self.self_attn = TFIdeficsAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.dropout,
            config=config,
            name="self_attn",
        )
        self.mlp = TFIdeficsMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            name="mlp",
        )
        self.input_layernorm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm")
        self.post_attention_layernorm = TFIdeficsRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm"
        )
        # Dropout rate applied to the attention output and to the MLP output (training only).
        self.dropout = config.dropout

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        past_key_value: Optional[Tuple[tf.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        training: Optional[bool] = False,
    ) -> Tuple[tf.Tensor, Optional[Tuple[tf.Tensor, tf.Tensor]]]:
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`tf.Tensor`, *optional*): forwarded unchanged to the self-attention layer.
            past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            training (`bool`, *optional*, defaults to `False`):
                Whether the layer runs in training mode. Dropout is only applied when this is truthy.

        Returns:
            A tuple whose first element is the output hidden states, followed by the self-attention weights if
            `output_attentions` is set, followed by the present key/value state if `use_cache` is set.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        # `tf.nn.dropout` has no training switch and always drops units, so it has to be gated
        # explicitly; otherwise inference would be stochastic whenever the dropout rate is > 0.
        if training:
            hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if training:
            hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

    def build(self, input_shape=None):
        """Build every sub-layer under its own name scope; does nothing if the layer is already built."""
        if self.built:
            return
        self.built = True
        for attr_name in ("self_attn", "mlp", "input_layernorm", "post_attention_layernorm"):
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
class TFIdeficsGatedCrossAttentionLayer(tf.keras.layers.Layer):
def __init__(self, config: IdeficsConfig, **kwargs):
    """Create the cross-attention, MLP and normalization sub-layers and record the alpha gate settings.

    Args:
        config (`IdeficsConfig`): model configuration the layer hyper-parameters are read from.
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)

    hidden_size = config.hidden_size
    norm_eps = config.rms_norm_eps
    self.hidden_size = hidden_size

    self.cross_attn = TFIdeficsAttention(
        hidden_size=hidden_size,
        num_heads=config.num_attention_heads,
        is_cross_attention=True,
        dropout=config.dropout,
        config=config,
        qk_layer_norms=config.qk_layer_norms,
        name="cross_attn",
    )
    self.mlp = TFIdeficsMLP(
        hidden_size=hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        name="mlp",
    )
    self.input_layernorm = TFIdeficsRMSNorm(hidden_size, eps=norm_eps, name="input_layernorm")
    self.post_attention_layernorm = TFIdeficsRMSNorm(hidden_size, eps=norm_eps, name="post_attention_layernorm")

    # NOTE(review): despite its name, `self.config` stores the value of `config.dropout`,
    # not the configuration object itself.
    self.config = config.dropout

    # Both gates use tanh as their activation.
    self.act_cross_attn = tf.keras.activations.tanh
    self.act_dense = tf.keras.activations.tanh

    # Settings consumed by `build` when creating the alpha weights.
    self.alpha_initializer = config.alpha_initializer
    self.alpha_type = config.alpha_type
    self.alphas_initializer_range = config.alphas_initializer_range
def build(self, input_shape):
    """Create the `alpha_cross_attn` / `alpha_dense` gate weights and build the sub-layers.

    The initializer is chosen from `self.alpha_initializer` ("zeros", "ones", or one of
    "normal" / "gaussian" / "random", which draws from a normal distribution with mean 0 and
    stddev `self.alphas_initializer_range`). The weight shape is chosen from `self.alpha_type`:
    `(1, 1, hidden_size)` for "vector", `(1,)` for "float".

    Args:
        input_shape: forwarded to the parent class `build`.

    Raises:
        NotImplementedError: if `self.alpha_initializer` is not a supported scheme.
        ValueError: if `self.alpha_type` is neither "vector" nor "float".
    """
    if self.built:
        return
    self.built = True

    # The initializer scheme is validated before the alpha type so that, when both are
    # invalid, NotImplementedError takes precedence.
    if self.alpha_initializer in ("zeros", "ones"):

        def make_initializer():
            return self.alpha_initializer

    elif self.alpha_initializer in {"normal", "gaussian", "random"}:

        def make_initializer():
            # A fresh instance per weight: the two alphas must not share one unseeded
            # initializer object.
            return tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range)

    else:
        raise NotImplementedError(f"Alpha initialization scheme {self.alpha_initializer} not yet implemented!")

    if self.alpha_type == "vector":
        alpha_shape = (1, 1, self.hidden_size)
    elif self.alpha_type == "float":
        alpha_shape = (1,)
    else:
        raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})")

    # Both weights are always registered under their own attribute name. Previously the
    # random-normal "float" case registered `alpha_cross_attn` under the name "alpha_type".
    self.alpha_cross_attn = self.add_weight(
        shape=alpha_shape, initializer=make_initializer(), trainable=True, name="alpha_cross_attn"
    )
    self.alpha_dense = self.add_weight(
        shape=alpha_shape, initializer=make_initializer(), trainable=True, name="alpha_dense"
    )

    for sublayer in (self.cross_attn, self.mlp, self.input_layernorm, self.post_attention_layernorm):
        with tf.name_scope(sublayer.name):
            sublayer.build(None)

    super().build(input_shape)
def call(
self,
hidden_states: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
image_hidden_states: Optional[tf.Tensor] = None,
image_attention_mask: Optional[tf.Tensor] = None,
cross_attention_gate: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
) -> Tuple[tf.Tensor, Optional[Tuple[tf.Tensor, tf.Tensor]]]: