-
Notifications
You must be signed in to change notification settings - Fork 26.4k
/
modeling_tf_sam.py
1465 lines (1227 loc) · 65.3 KB
/
modeling_tf_sam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
discrepancy, the original file should be regarded as the 'reference' version.
"""
from __future__ import annotations
import collections
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import ACT2FN
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, shape_list, unpack_inputs
from ...tf_utils import flatten, functional_layernorm
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SamConfig"
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/sam-vit-huge",
"facebook/sam-vit-large",
"facebook/sam-vit-base",
# See all SAM models at https://huggingface.co/models?filter=sam
]
@dataclass
class TFSamVisionEncoderOutput(ModelOutput):
    """
    Base class for the SAM vision model's outputs. In addition to the usual hidden states and attentions, it can
    carry image embeddings obtained by applying the projection layer to the pooler_output.

    Args:
        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when the model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Every field defaults to `None`, so an instance can be created with any subset of them.
    image_embeds: tf.Tensor | None = None
    last_hidden_state: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None
@dataclass
class TFSamImageSegmentationOutput(ModelOutput):
    """
    Base class for the Segment-Anything model's output.

    Args:
        iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
            The IoU scores of the predicted masks.
        pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
            The predicted low-resolution masks. They need to be post-processed by the processor.
        vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
        vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Every field defaults to `None`, so an instance can be created with any subset of them.
    iou_scores: tf.Tensor = None
    pred_masks: tf.Tensor = None
    vision_hidden_states: Tuple[tf.Tensor] | None = None
    vision_attentions: Tuple[tf.Tensor] | None = None
    mask_decoder_attentions: Tuple[tf.Tensor] | None = None
class TFSamPatchEmbeddings(tf.keras.layers.Layer):
    """
    Patch-embedding layer for the SAM vision encoder.

    Takes `pixel_values` laid out as `(batch_size, num_channels, height, width)`, validates the channel count and
    spatial size against the configuration, and runs a strided 2D convolution over a channels-last view of the
    input. The convolution output is returned as-is; no flattening happens inside this layer.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        # A scalar size is promoted to a square (height, width) pair; an iterable is kept as given.
        image_size = config.image_size
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)

        patch_size = config.patch_size
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)

        patches_along_width = image_size[1] // patch_size[1]
        patches_along_height = image_size[0] // patch_size[0]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = patches_along_width * patches_along_height

        # Kernel size == stride, so every patch is projected independently of its neighbours.
        self.projection = tf.keras.layers.Conv2D(
            config.hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
        )

    def call(self, pixel_values):
        """
        Project `pixel_values` (channels-first) to patch embeddings.

        Raises:
            ValueError: if the channel dimension or the spatial size differs from the configured values.
        """
        _, channels, rows, cols = shape_list(pixel_values)

        if channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        expected_rows = self.image_size[0]
        expected_cols = self.image_size[1]
        if rows != expected_rows or cols != expected_cols:
            raise ValueError(
                f"Input image size ({rows}*{cols}) doesn't match model ({expected_rows}*{expected_cols})."
            )

        # The convolution is applied on a channels-last view of the input: NCHW -> NHWC.
        channels_last_pixels = tf.transpose(pixel_values, perm=[0, 2, 3, 1])
        return self.projection(channels_last_pixels)
class TFSamMLPBlock(tf.keras.layers.Layer):
    """Two-layer feed-forward block: `lin1` -> activation -> `lin2`."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Expand to `config.mlp_dim`, then project back to `config.hidden_size`.
        self.lin1 = tf.keras.layers.Dense(config.mlp_dim, name="lin1")
        self.lin2 = tf.keras.layers.Dense(config.hidden_size, name="lin2")
        self.act = ACT2FN[config.hidden_act]

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        """Apply the MLP to `hidden_states` and return the result."""
        expanded = self.lin1(hidden_states)
        activated = self.act(expanded)
        return self.lin2(activated)
class TFSamLayerNorm(tf.keras.layers.Layer):
    r"""Layer normalization usable with either `channels_last` (default) or `channels_first` inputs.

    With `channels_last` the normalization runs over the last axis (inputs such as `(batch_size, height, width,
    channels)`); with `channels_first` it runs over axis 1 (inputs such as `(batch_size, channels, height, width)`).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = normalized_shape
        if self.data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")

    def build(self, input_shape):
        # Learnable scale (initialised to ones) and shift (initialised to zeros).
        self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight")
        self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias")
        super().build(input_shape)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Normalize `x` along the channel axis selected by `data_format`."""
        if self.data_format == "channels_last":
            channel_axis = -1
        elif self.data_format == "channels_first":
            channel_axis = 1
        else:
            # Not reachable through the constructor, which rejects other formats; the input passes through.
            return x
        return functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=channel_axis)
class TFSamAttention(tf.keras.layers.Layer):
    """
    Multi-head attention used by SAM. The query/key/value projections map to an internal width of
    `hidden_size // downsample_rate`, and the output projection maps back to `hidden_size`.
    """

    def __init__(self, config, downsample_rate=None, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = config.hidden_size

        # Fall back to the configured rate when the caller does not override it.
        if downsample_rate is None:
            downsample_rate = config.attention_downsample_rate

        self.internal_dim = config.hidden_size // downsample_rate
        self.num_attention_heads = config.num_attention_heads
        if self.internal_dim % config.num_attention_heads != 0:
            raise ValueError("num_attention_heads must divide hidden_size.")

        self.q_proj = tf.keras.layers.Dense(self.internal_dim, name="q_proj")
        self.k_proj = tf.keras.layers.Dense(self.internal_dim, name="k_proj")
        self.v_proj = tf.keras.layers.Dense(self.internal_dim, name="v_proj")
        self.out_proj = tf.keras.layers.Dense(self.hidden_size, name="out_proj")

    def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor:
        """
        Fold the point-batch axis into the batch axis and split channels into heads:
        `(batch, point_batch, tokens, channels)` -> `(batch * point_batch, heads, tokens, channels // heads)`.
        """
        batch, point_batch_size, n_tokens, channel = shape_list(hidden_states)
        per_head = channel // num_attention_heads
        split_shape = (batch * point_batch_size, n_tokens, num_attention_heads, per_head)
        split = tf.reshape(hidden_states, split_shape)
        return tf.transpose(split, perm=[0, 2, 1, 3])

    def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor:
        """
        Inverse of `_separate_heads`: merge the heads back into one channel axis and restore the point-batch axis.
        """
        folded_batch, n_heads, n_tokens, per_head = shape_list(hidden_states)
        tokens_first = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
        # The divisor is clamped to at least 1 so a point batch of size 0 does not divide by zero.
        merged_shape = (
            folded_batch // tf.reduce_max([1, point_batch_size]),
            point_batch_size,
            n_tokens,
            n_heads * per_head,
        )
        return tf.reshape(tokens_first, merged_shape)

    def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
        """Scaled dot-product attention over 4D `(batch, point_batch, tokens, channels)` inputs."""
        # Input projections
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        # Remembered before the point-batch axis is folded away, so it can be restored afterwards.
        point_batch_size = shape_list(query)[1]

        # Separate into heads
        heads = self.num_attention_heads
        query = self._separate_heads(query, heads)
        key = self._separate_heads(key, heads)
        value = self._separate_heads(value, heads)

        # Attention scores: batch_size * point_batch_size x N_heads x N_tokens x N_tokens
        _, _, _, per_head = shape_list(query)
        scores = tf.matmul(query, tf.transpose(key, perm=[0, 1, 3, 2]))
        scores = scores / tf.math.sqrt(float(per_head))
        weights = tf.nn.softmax(scores, axis=-1)

        # Weighted sum of the values, heads merged, then projected back to `hidden_size`.
        context = tf.matmul(weights, value)
        context = self._recombine_heads(context, point_batch_size)
        return self.out_proj(context)
class TFSamTwoWayAttentionBlock(tf.keras.layers.Layer):
    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
        """
        A transformer block made of four stages:

        (1) self-attention of sparse inputs, (2) cross attention of sparse inputs -> dense inputs, (3) an MLP block
        on sparse inputs, (4) cross attention of dense inputs -> sparse inputs.

        Arguments:
            config (`SamMaskDecoderConfig`):
                The configuration used to instantiate the block.
            attention_downsample_rate (*optional*, int, defaults to 2):
                The downsample ratio used to reduce the inner dim of the two cross-attention layers.
            skip_first_layer_pe (*optional*, bool, defaults to `False`):
                Whether or not to skip the addition of the query_point_embedding in the self-attention stage.
        """
        super().__init__(**kwargs)
        self.hidden_size = config.hidden_size
        self.layer_norm_eps = config.layer_norm_eps

        def make_layer_norm(layer_name):
            return tf.keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name=layer_name)

        # Self-attention keeps the full width (downsample_rate=1); the cross-attentions use the given rate.
        self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn")
        self.layer_norm1 = make_layer_norm("layer_norm1")

        self.cross_attn_token_to_image = TFSamAttention(
            config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image"
        )
        self.layer_norm2 = make_layer_norm("layer_norm2")

        self.mlp = TFSamMLPBlock(config, name="mlp")
        self.layer_norm3 = make_layer_norm("layer_norm3")

        self.layer_norm4 = make_layer_norm("layer_norm4")
        self.cross_attn_image_to_token = TFSamAttention(
            config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token"
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def call(
        self,
        queries: tf.Tensor,
        keys: tf.Tensor,
        query_point_embedding: tf.Tensor,
        key_point_embedding: tf.Tensor,
        output_attentions: bool = False,
    ):
        """
        Run the four stages and return `(queries, keys, attn)`.

        `attn` is the output of the final image-to-token cross-attention layer when `output_attentions` is truthy
        (the layer's projected output, not its softmax weights), otherwise `None`.
        """
        # (1) Self-attention over the sparse tokens. When the positional embedding is skipped, the attention
        # output replaces the tokens instead of being added as a residual.
        if self.skip_first_layer_pe:
            queries = self.self_attn(query=queries, key=queries, value=queries)
        else:
            tokens_with_pe = queries + query_point_embedding
            self_attn_out = self.self_attn(query=tokens_with_pe, key=tokens_with_pe, value=queries)
            queries = queries + self_attn_out
        queries = self.layer_norm1(queries)

        # (2) Cross attention: tokens attend to the image embedding.
        tokens_with_pe = queries + query_point_embedding
        image_with_pe = keys + key_point_embedding
        token_update = self.cross_attn_token_to_image(query=tokens_with_pe, key=image_with_pe, value=keys)
        queries = self.layer_norm2(queries + token_update)

        # (3) MLP on the tokens, with a residual connection.
        queries = self.layer_norm3(queries + self.mlp(queries))

        # (4) Cross attention: the image embedding attends to the tokens.
        tokens_with_pe = queries + query_point_embedding
        image_with_pe = keys + key_point_embedding
        image_update = self.cross_attn_image_to_token(query=image_with_pe, key=tokens_with_pe, value=queries)
        keys = self.layer_norm4(keys + image_update)

        return queries, keys, (image_update if output_attentions else None)
class TFSamTwoWayTransformer(tf.keras.layers.Layer):
    """
    Stack of `TFSamTwoWayAttentionBlock`s followed by a final token-to-image attention and a layer norm.
    Only the first block is built with `skip_first_layer_pe=True`.
    """

    def __init__(self, config: SamMaskDecoderConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.num_hidden_layers = config.num_hidden_layers

        self.layers = [
            TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(idx == 0), name=f"layers_._{idx}")
            for idx in range(self.num_hidden_layers)
        ]

        self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image")
        self.layer_norm_final_attn = tf.keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="layer_norm_final_attn"
        )

    def call(
        self,
        point_embeddings: tf.Tensor,
        image_embeddings: tf.Tensor,
        image_positional_embeddings: tf.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TFBaseModelOutput]:
        """
        Run the two-way transformer.

        Returns a tuple `(queries, keys, all_attentions)`, where `all_attentions` is a tuple holding one entry per
        block when `output_attentions` is truthy and an empty tuple otherwise.

        Raises:
            ValueError: if `image_embeddings` is `None`.
        """
        # Unset flags fall back to the config. Only `output_attentions` is consulted below; the other two are
        # resolved but not used afterwards in this method.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")

        # Flatten from axis 2 onwards, swap the last two axes, and insert a new axis at position 1.
        image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None]
        image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None]

        # Prepare queries
        queries, keys = point_embeddings, image_embeddings

        # Apply the transformer blocks
        all_attentions = ()
        for block in self.layers:
            queries, keys, block_attention = block(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                output_attentions=output_attentions,
            )
            if output_attentions:
                all_attentions += (block_attention,)

        # Apply the final attention layer from the points to the image, then the final layer norm
        final_query = queries + point_embeddings
        final_key = keys + image_positional_embeddings
        final_update = self.final_attn_token_to_image(query=final_query, key=final_key, value=keys)
        queries = self.layer_norm_final_attn(queries + final_update)

        return queries, keys, all_attentions
class TFSamFeedForward(tf.keras.layers.Layer):
    """
    Small MLP: `proj_in`, then `num_layers - 2` hidden dense layers, then `proj_out`. A ReLU follows every layer
    except `proj_out`; an optional sigmoid is applied to the final output.
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs
    ):
        super().__init__(**kwargs)
        self.num_layers = num_layers
        self.activation = tf.keras.layers.ReLU()
        self.proj_in = tf.keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in")
        self.proj_out = tf.keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out")

        # The input and output projections account for two of the `num_layers` layers.
        hidden_layers = []
        for idx in range(num_layers - 2):
            hidden_layers.append(
                tf.keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{idx}")
            )
        self.layers = hidden_layers

        self.sigmoid_output = sigmoid_output

    def call(self, hidden_states):
        """Apply the MLP to `hidden_states`."""
        features = self.activation(self.proj_in(hidden_states))
        for dense in self.layers:
            features = self.activation(dense(features))
        features = self.proj_out(features)
        return tf.sigmoid(features) if self.sigmoid_output else features
class TFSamMaskDecoder(tf.keras.layers.Layer):
    """
    Mask decoder: combines image embeddings with prompt embeddings through a two-way transformer, upscales the
    resulting image embedding with two transposed convolutions, and produces masks plus one IoU prediction per
    mask token.
    """

    def __init__(self, config: SamMaskDecoderConfig, **kwargs):
        super().__init__(**kwargs)

        self.hidden_size = config.hidden_size

        self.num_multimask_outputs = config.num_multimask_outputs
        # One extra mask token on top of the multimask outputs; `call` selects it alone when
        # `multimask_output` is False.
        self.num_mask_tokens = config.num_multimask_outputs + 1

        self.transformer = TFSamTwoWayTransformer(config, name="transformer")

        # Two stride-2 transposed convolutions, operating on channels-first tensors.
        self.upscale_conv1 = tf.keras.layers.Conv2DTranspose(
            self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first"
        )
        self.upscale_conv2 = tf.keras.layers.Conv2DTranspose(
            self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first"
        )
        self.upscale_layer_norm = TFSamLayerNorm(
            self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm"
        )
        self.activation = tf.nn.gelu

        # One hypernetwork MLP per mask token; each maps a token to `hidden_size // 8` values, matching the
        # channel count produced by `upscale_conv2`.
        mlps_list = []
        for i in range(self.num_mask_tokens):
            mlps_list += [
                TFSamFeedForward(
                    self.hidden_size,
                    self.hidden_size,
                    self.hidden_size // 8,
                    3,
                    name=f"output_hypernetworks_mlps_._{i}",
                )
            ]
        self.output_hypernetworks_mlps = mlps_list

        self.iou_prediction_head = TFSamFeedForward(
            self.hidden_size,
            config.iou_head_hidden_dim,
            self.num_mask_tokens,
            config.iou_head_depth,
            name="iou_prediction_head",
        )

    def build(self, input_shape):
        # Learned output tokens: a single IoU token and `num_mask_tokens` mask tokens.
        self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True)
        self.mask_tokens = self.add_weight(
            shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True
        )
        super().build(input_shape)

    def call(
        self,
        image_embeddings: tf.Tensor,
        image_positional_embeddings: tf.Tensor,
        sparse_prompt_embeddings: tf.Tensor,
        dense_prompt_embeddings: tf.Tensor,
        multimask_output: bool,
        output_attentions: Optional[bool] = None,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Predict masks and IoU scores.

        Args:
            image_embeddings (`tf.Tensor`):
                Image embeddings; unpacked below as `(batch_size, num_channels, height, width)`.
            image_positional_embeddings (`tf.Tensor`):
                Positional embeddings for the image, forwarded to the transformer.
            sparse_prompt_embeddings (`tf.Tensor`):
                Sparse prompt embeddings; axis 1 is used as the point-batch size and the tensor is concatenated
                to the output tokens along axis 2.
            dense_prompt_embeddings (`tf.Tensor`):
                Dense prompt embeddings, added element-wise to `image_embeddings`.
            multimask_output (`bool`):
                If truthy, every mask except the first is returned; otherwise only the first mask is returned.
            output_attentions (`bool`, *optional*):
                Whether to include the transformer's attention outputs in the result.

        Returns:
            A 3-tuple `(masks, iou_pred, attentions)`; `attentions` is `None` unless `output_attentions` is truthy.
            Note that the return annotation only lists the first two elements.
        """
        batch_size, num_channels, height, width = shape_list(image_embeddings)
        # Clamped to at least 1 so tiling and repeating still work when there are no sparse prompts.
        point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1])

        output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0)  # Should be (1, 32) + (4, 32) = (5, 32)
        output_tokens = tf.tile(
            output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1]
        )  # Should be (batch_size, point_size, 5, 32)

        # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
        #       happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
        #       it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
        if shape_list(sparse_prompt_embeddings)[1] != 0:
            tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
        else:
            tokens = output_tokens
        point_embeddings = tf.cast(tokens, self.iou_token.dtype)

        image_embeddings = image_embeddings + dense_prompt_embeddings
        # Repeat the image tensors once per point batch along the batch axis.
        image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0)
        image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0)

        point_embedding, image_embeddings, attentions = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            output_attentions=output_attentions,
        )
        # Token 0 along axis 2 is the IoU token; the next `num_mask_tokens` entries are the mask tokens
        # (same order as the concat above).
        iou_token_out = point_embedding[:, :, 0, :]
        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]

        # Swap the last two axes and reshape back to a channels-first image for the upscaling convolutions.
        image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2))
        image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width])

        upscaled_embedding = self.upscale_conv1(image_embeddings)
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))

        # Each mask token goes through its own hypernetwork MLP.
        hyper_in_list = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        hyper_in = tf.stack(hyper_in_list, axis=2)

        # `num_channels`, `height` and `width` are rebound here to the upscaled tensor's dimensions.
        _, num_channels, height, width = shape_list(upscaled_embedding)
        upscaled_embedding = tf.reshape(
            upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width]
        )
        # Matrix product of the hypernetwork outputs with the flattened upscaled embedding, reshaped to images.
        masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width])

        iou_pred = self.iou_prediction_head(iou_token_out)

        # Multi-mask mode drops the first mask; single-mask mode keeps only the first one.
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, :, mask_slice, :, :]
        iou_pred = iou_pred[:, :, mask_slice]

        outputs = (masks, iou_pred)

        if output_attentions:
            outputs = outputs + (attentions,)
        else:
            outputs = outputs + (None,)

        return outputs
class TFSamPositionalEmbedding(tf.keras.layers.Layer):
    """
    Positional encoding of 2D coordinates: the coordinates are projected through a fixed (non-trainable) random
    matrix and the sine and cosine of the result are concatenated.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.scale = config.hidden_size // 2
        self.config = config

    def build(self, input_shape):
        # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized?
        gaussian_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.scale)
        self.positional_embedding = self.add_weight(
            name="positional_embedding",
            shape=(2, self.config.num_pos_feats),
            initializer=gaussian_init,
            trainable=False,
        )
        super().build(input_shape)

    def call(self, input_coords, input_shape=None):
        """Positionally encode points that are normalized to [0,1]."""
        points = tf.identity(input_coords)

        if input_shape is not None:
            # Divide component 0 by `input_shape[1]` and component 1 by `input_shape[0]`.
            first_component = tf.cast(points[:, :, :, 0], tf.float32) / input_shape[1]
            second_component = tf.cast(points[:, :, :, 1], tf.float32) / input_shape[0]
            points = tf.stack([first_component, second_component], axis=-1)

        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape;
        # map them to [-1, 1] before the projection
        points = 2 * points - 1
        points = tf.cast(points, self.positional_embedding.dtype)
        projected = tf.matmul(points, self.positional_embedding)
        projected = 2 * np.pi * projected

        # outputs d_1 x ... x d_n x channel shape
        return tf.concat([tf.sin(projected), tf.cos(projected)], axis=-1)
class TFSamMaskEmbedding(tf.keras.layers.Layer):
    """
    Embeds input masks with two stride-2 convolutions (each followed by a layer norm and an activation) and a
    final 1x1 convolution to `config.hidden_size` channels.
    """

    def __init__(self, config: SamPromptEncoderConfig, **kwargs):
        super().__init__(**kwargs)
        # Channel count after the first convolution.
        self.mask_input_channels = config.mask_input_channels // 4
        self.activation = ACT2FN[config.hidden_act]
        self.conv1 = tf.keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1")
        self.conv2 = tf.keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2")
        self.conv3 = tf.keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3")
        self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1")
        self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2")

    def call(self, masks):
        """Embed channels-first `masks` and return channels-first dense embeddings."""
        # Convert to channels-last for the convolutions
        features = tf.transpose(masks, perm=(0, 2, 3, 1))

        features = self.activation(self.layer_norm1(self.conv1(features)))
        features = self.activation(self.layer_norm2(self.conv2(features)))
        features = self.conv3(features)

        # Convert back to channels-first
        return tf.transpose(features, perm=(0, 3, 1, 2))

    def build(self, input_shape):
        # This class needs an explicit build method because it isn't called with the standard dummy inputs
        reduced_channels = self.mask_input_channels
        expanded_channels = self.mask_input_channels * 4

        # (name scope, sublayer, channels-last build shape), in build order.
        build_plan = (
            ("conv1", self.conv1, [None, None, None, 1]),
            ("conv2", self.conv2, [None, None, None, reduced_channels]),
            ("conv3", self.conv3, [None, None, None, expanded_channels]),
            ("layer_norm1", self.layer_norm1, [None, None, None, reduced_channels]),
            ("layer_norm2", self.layer_norm2, [None, None, None, expanded_channels]),
        )
        for scope_name, sublayer, build_shape in build_plan:
            with tf.name_scope(scope_name):
                sublayer.build(build_shape)

        super().build(input_shape)
class TFSamPromptEncoder(tf.keras.layers.Layer):
def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs):
super().__init__(**kwargs)
self.shared_embedding = shared_patch_embedding
self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed")
self.no_mask_embed = None
self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
self.input_image_size = config.image_size
self.point_embed = []
self.hidden_size = config.hidden_size
self.not_a_point_embed = None
self.config = config
def build(self, input_shape):
self.no_mask_embed = self.add_weight(
name="no_mask_embed.weight",
shape=(1, self.hidden_size),
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
trainable=True,
)
self.point_embed = [
self.add_weight(
name=f"point_embed_._{i}.weight",
shape=(1, self.hidden_size),
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
trainable=True,
)
for i in range(self.config.num_point_embeddings)
]
self.not_a_point_embed = self.add_weight(
name="not_a_point_embed.weight",
shape=(1, self.hidden_size),
initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
trainable=True,
)
with tf.name_scope("mask_embed"):
# We must explicitly build the mask embed because it isn't touched by the standard dummy inputs
self.mask_embed.build(
(None, self.config.mask_input_channels, self.config.image_size, self.config.image_size)
)
super().build(input_shape)
def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
points = tf.concat([points, padding_point], axis=2)
labels = tf.concat([labels, padding_label], axis=2)
input_shape = (self.input_image_size, self.input_image_size)
point_embedding = self.shared_embedding(points, input_shape)
point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding)
point_embedding = tf.where(
labels[..., None] != -10,
point_embedding,
tf.zeros_like(point_embedding),
)
point_embedding = tf.where(
(labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding
)
point_embedding = tf.where(
(labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding
)
return point_embedding
def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
batch_size, nb_boxes = shape_list(boxes)[:2]
coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
input_shape = (self.input_image_size, self.input_image_size)
corner_embedding = self.shared_embedding(coords, input_shape)
corner_embedding += tf.where(
tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
self.point_embed[2][0],
self.point_embed[3][0],
)
return corner_embedding
def call(
    self,
    batch_size: Optional[int],
    input_points: tf.Tensor | None,
    input_labels: tf.Tensor | None,
    input_boxes: tf.Tensor | None,
    input_masks: tf.Tensor | None,
) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Embeds different types of prompts, returning both sparse and dense embeddings.

    Args:
        batch_size (`int`, *optional*):
            Batch size used when it cannot be read from `input_points` or `input_boxes`. It is
            overwritten by the leading dimension of whichever of those two is provided.
        input_points (`tf.Tensor`, *optional*):
            Point coordinates to embed. Requires `input_labels`.
        input_labels (`tf.Tensor`, *optional*):
            Labels of the points in `input_points`.
        input_boxes (`tf.Tensor`, *optional*):
            Boxes to embed.
        input_masks (`tf.Tensor`, *optional*):
            Masks to embed. When omitted, the learned "no mask" embedding is tiled over the image
            embedding grid instead.

    Returns:
        `Tuple[tf.Tensor, tf.Tensor]`: `(sparse_embeddings, dense_embeddings)`.

    Raises:
        ValueError: If points are given without labels, or if no batch size can be determined.
    """
    sparse_embeddings = None
    if input_points is not None:
        batch_size, point_batch_size = shape_list(input_points)[:2]
        if input_labels is None:
            raise ValueError("If points are provided, labels must also be provided.")
        # A padding point is only appended when no box accompanies the points.
        point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
        sparse_embeddings = tf.zeros(
            (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype
        )
        sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)

    if input_boxes is not None:
        batch_size = shape_list(input_boxes)[0]
        box_embeddings = self._embed_boxes(input_boxes)
        if sparse_embeddings is None:
            sparse_embeddings = box_embeddings
        else:
            sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2)

    if batch_size is None:
        # Without points or boxes the batch size has to come from the caller; fail with a clear
        # message instead of an opaque shape-conversion error further down.
        raise ValueError("`batch_size` must be provided when neither points nor boxes are given.")

    if input_masks is not None:
        dense_embeddings = self.mask_embed(input_masks)
    else:
        # Broadcast the learned "no mask" vector over the batch and the image embedding grid.
        dense_embeddings = self.no_mask_embed[0]
        dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1))
        dense_embeddings = tf.tile(
            dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])
        )

    if sparse_embeddings is None:
        # No sparse prompt at all: return an empty placeholder.
        sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype)
    return sparse_embeddings, dense_embeddings
class TFSamVisionAttention(tf.keras.layers.Layer):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(self, config, window_size, **kwargs):
        """
        Args:
            config:
                Model configuration. Reads `image_size`, `patch_size`, `num_attention_heads`,
                `hidden_size`, `attention_dropout`, `qkv_bias` and `use_rel_pos`.
            window_size (`int`):
                Side length of the attention window. `0` means the attention spans the full patch
                grid (`image_size // patch_size` per side).
        """
        super().__init__(**kwargs)
        input_size = (
            (config.image_size // config.patch_size, config.image_size // config.patch_size)
            if window_size == 0
            else (window_size, window_size)
        )
        self.input_size = input_size

        self.num_attention_heads = config.num_attention_heads
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv = tf.keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv")
        self.proj = tf.keras.layers.Dense(config.hidden_size, name="proj")

        self.use_rel_pos = config.use_rel_pos
        if self.use_rel_pos:
            if input_size is None:
                raise ValueError("Input size must be provided if using relative positional encoding.")
        self.config = config

    def build(self, input_shape):
        """Create the relative position tables, then defer to the Keras `build`."""
        # NOTE(review): `input_size` is always set to a tuple in `__init__`, so these tables are
        # created for every instance; they are only read in `call` when `use_rel_pos` is enabled.
        if self.input_size is not None:
            # initialize relative positional embeddings
            self.rel_pos_h = self.add_weight(
                shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h"
            )
            self.rel_pos_w = self.add_weight(
                shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w"
            )
        super().build(input_shape)

    def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
        query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`tf.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions, with shape
            (q_size, k_size, channel).
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos if needed.
        if rel_pos.shape[0] != max_rel_dist:
            # `tf.image.resize` interprets a 3-D input as (height, width, channels), so the table is
            # laid out as (L, 1, channel) and only the height (length) axis is resized. The channel
            # axis is left untouched and the result stays (max_rel_dist, channel).
            rel_pos_resized = tf.image.resize(
                tf.expand_dims(rel_pos, axis=1),
                size=(max_rel_dist, 1),
                method="bilinear",
            )
            rel_pos_resized = tf.reshape(rel_pos_resized, (max_rel_dist, -1))
            # Resizing may change the dtype; restore the table's original one.
            rel_pos_resized = tf.cast(rel_pos_resized, rel_pos.dtype)
        else:
            rel_pos_resized = rel_pos

        # Scale the coords with short length if shapes for q and k are different.
        q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0)
        k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0)
        # Offset so that the smallest relative coordinate maps to index 0 of the table.
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

        return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))

    def add_decomposed_rel_pos(
        self,
        attn: tf.Tensor,
        query: tf.Tensor,
        rel_pos_h: tf.Tensor,
        rel_pos_w: tf.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> tf.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            attn (`tf.Tensor`):
                attention map.
            query (`tf.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`tf.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`tf.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            attn (`tf.Tensor`):
                attention map with added relative positional embeddings.
        """
        query_height, query_width = q_size
        key_height, key_width = k_size
        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)

        batch_size, _, dim = shape_list(query)
        reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
        # Per-axis bias: contract the channel axis of the query with each relative position table.
        rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
        attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width))
        # The height bias broadcasts over key width, the width bias over key height.
        attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2)
        attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width))
        return attn

    def call(
        self, hidden_states: tf.Tensor, output_attentions=False, training=False
    ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
        """
        Apply multi-head self-attention over a 2-D grid of tokens.

        Args:
            hidden_states (`tf.Tensor`):
                Input of shape (batch_size, height, width, channel).
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether to return the post-softmax attention weights as second output.
            training (`bool`, *optional*, defaults to `False`):
                Whether to apply attention dropout.

        Returns:
            `(attn_output, attn_weights)`; `attn_weights` is `None` unless `output_attentions` is set.
        """
        batch_size, height, width, _ = shape_list(hidden_states)
        # qkv with shape (3, batch_size, nHead, height * width, channel)
        qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1))
        qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4))
        # q, k, v with shape (batch_size * nHead, height * width, channel)
        query, key, value = tf.unstack(
            tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0
        )
        attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        if training:
            attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
        else:
            attn_probs = attn_weights

        # Un-fold the heads and merge them back into the channel axis.
        attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1))
        attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4))
        attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size))

        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs
class TFSamVisionLayer(tf.keras.layers.Layer):
def __init__(self, config, window_size, **kwargs):
    """
    Set up one vision layer: two layer norms, an attention block and an MLP block.

    Args:
        config:
            Model configuration; `layer_norm_eps` is read here, the rest is forwarded to the sub-layers.
        window_size (`int`):
            Attention window size, forwarded to the attention block and stored on the layer.
    """
    super().__init__(**kwargs)
    norm_eps = config.layer_norm_eps
    # Sub-layers are created in the same order as before so Keras tracks them identically.
    self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=norm_eps, name="layer_norm1")
    self.attn = TFSamVisionAttention(config, window_size, name="attn")
    self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=norm_eps, name="layer_norm2")
    self.mlp = TFSamMLPBlock(config, name="mlp")
    self.window_size = window_size
def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]:
    """
    Split a (batch_size, height, width, channel) tensor into non-overlapping square windows.

    The bottom and right edges are zero-padded when height or width is not a multiple of `window_size`.

    Args:
        hidden_states (`tf.Tensor`):
            Input of shape (batch_size, height, width, channel).
        window_size (`int`):
            Side length of each window.

    Returns:
        `windows` of shape (batch_size * num_windows, window_size, window_size, channel) and the
        padded spatial size `(pad_height, pad_width)`.
    """
    batch_size, height, width, channel = shape_list(hidden_states)

    # Amount needed to reach the next multiple of the window size (0 when already aligned).
    extra_rows = (window_size - height % window_size) % window_size
    extra_cols = (window_size - width % window_size) % window_size
    if extra_rows > 0 or extra_cols > 0:
        paddings = [[0, 0], [0, extra_rows], [0, extra_cols], [0, 0]]
        hidden_states = tf.pad(hidden_states, paddings)
    pad_height = height + extra_rows
    pad_width = width + extra_cols

    # Expose the window grid as separate axes, then bring the two grid axes next to each other.
    grid_shape = [
        batch_size,
        pad_height // window_size,
        window_size,
        pad_width // window_size,
        window_size,
        channel,
    ]
    gridded = tf.reshape(hidden_states, grid_shape)
    gridded = tf.transpose(gridded, perm=[0, 1, 3, 2, 4, 5])
    windows = tf.reshape(gridded, [-1, window_size, window_size, channel])
    return windows, (pad_height, pad_width)
def window_unpartition(
    self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
) -> tf.Tensor:
    """
    Reassemble windows into a (batch_size, height, width, channel) tensor and drop any padding.

    Args:
        windows (`tf.Tensor`):
            Windows with shape (batch_size * num_windows, window_size, window_size, channel).
        window_size (`int`):
            Side length of each window.
        padding_shape (`Tuple[int, int]`):
            Padded spatial size `(pad_height, pad_width)` the windows were cut from.
        original_shape (`Tuple[int, int]`):
            Spatial size `(height, width)` before padding.

    Returns:
        `tf.Tensor`: The reassembled tensor, cropped to `original_shape`.
    """
    pad_height, pad_width = padding_shape
    height, width = original_shape

    # Recover the batch size from the total window count and the number of windows per image.
    windows_per_image = pad_height * pad_width // window_size // window_size
    batch_size = shape_list(windows)[0] // windows_per_image

    grid_shape = [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1]
    gridded = tf.reshape(windows, grid_shape)
    # Interleave grid and in-window axes back into row-major spatial order.
    gridded = tf.transpose(gridded, perm=[0, 1, 3, 2, 4, 5])
    hidden_states = tf.reshape(gridded, [batch_size, pad_height, pad_width, -1])

    if pad_height > height or pad_width > width:
        # Crop away the rows/columns that were only added as padding.
        hidden_states = hidden_states[:, :height, :width, :]
    return hidden_states
def call(
self,
hidden_states: tf.Tensor,
output_attentions: Optional[bool] = False,
training: Optional[bool] = False,
) -> Tuple[tf.Tensor]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
if self.window_size > 0:
height, width = hidden_states.shape[1], hidden_states.shape[2]
hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)