/
modeling_tf_swiftformer.py
870 lines (711 loc) 路 34.2 KB
/
modeling_tf_swiftformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
# coding=utf-8
# Copyright 2024 MBZUAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow SwiftFormer model."""
import collections.abc
from typing import Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
TFBaseModelOutputWithNoAttention,
TFImageClassifierOutputWithNoAttention,
)
from ...modeling_tf_utils import TFPreTrainedModel, keras, keras_serializable, unpack_inputs
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from .configuration_swiftformer import SwiftFormerConfig
# Module-level logger, keyed on this module's import name.
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "SwiftFormerConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "MBZUAI/swiftformer-xs"
_EXPECTED_OUTPUT_SHAPE = [1, 220, 7, 7]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "MBZUAI/swiftformer-xs"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
# Hub identifiers of the checkpoints listed for this architecture.
TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"MBZUAI/swiftformer-xs",
# See all SwiftFormer models at https://huggingface.co/models?filter=swiftformer
]
class TFSwiftFormerPatchEmbeddingSequential(keras.layers.Layer):
    """
    The sequential component of the patch embedding layer.

    Applies two zero-padded 3x3 convolutions with stride 2, each followed by batch normalization and a ReLU.

    Input: tensor of shape `[batch_size, height, width, in_channels]` (channels-last)

    Output: tensor of shape `[batch_size, height/4, width/4, out_channels]`
    """

    def __init__(self, config: SwiftFormerConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.out_chs = config.embed_dims[0]

        # A single padding layer is shared by both convolutions (it holds no weights).
        self.zero_padding = keras.layers.ZeroPadding2D(padding=(1, 1))
        # The numeric layer names end up in the weight names; they must stay exactly as they are.
        self.conv1 = keras.layers.Conv2D(self.out_chs // 2, kernel_size=3, strides=2, name="0")
        self.batch_norm1 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="1")
        self.conv2 = keras.layers.Conv2D(self.out_chs, kernel_size=3, strides=2, name="3")
        self.batch_norm2 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="4")

    def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Run both conv -> batch norm -> ReLU stages on `x`; `training` is forwarded to the batch norms."""
        # Resolve the activation once per call instead of once per stage.
        relu = get_tf_activation("relu")

        x = self.zero_padding(x)
        x = self.conv1(x)
        x = self.batch_norm1(x, training=training)
        x = relu(x)

        x = self.zero_padding(x)
        x = self.conv2(x)
        x = self.batch_norm2(x, training=training)
        x = relu(x)
        return x

    def build(self, input_shape=None):
        """Create the weights of every sub-layer under its own name scope (no-op once built)."""
        if self.built:
            return
        if getattr(self, "conv1", None) is not None:
            with tf.name_scope(self.conv1.name):
                # Pass a full channels-last shape (not a bare int), like every other sub-layer here.
                self.conv1.build((None, None, None, self.config.num_channels))
        if getattr(self, "batch_norm1", None) is not None:
            with tf.name_scope(self.batch_norm1.name):
                self.batch_norm1.build((None, None, None, self.out_chs // 2))
        if getattr(self, "conv2", None) is not None:
            with tf.name_scope(self.conv2.name):
                self.conv2.build((None, None, None, self.out_chs // 2))
        if getattr(self, "batch_norm2", None) is not None:
            with tf.name_scope(self.batch_norm2.name):
                self.batch_norm2.build((None, None, None, self.out_chs))
        self.built = True
class TFSwiftFormerPatchEmbedding(keras.layers.Layer):
    """
    Patch embedding layer: a thin wrapper around `TFSwiftFormerPatchEmbeddingSequential`.

    The wrapper exists so the inner layer sits under the `patch_embedding` name scope; inputs and outputs are
    exactly those of the wrapped layer.
    """

    def __init__(self, config: SwiftFormerConfig, **kwargs):
        super().__init__(**kwargs)
        self.patch_embedding = TFSwiftFormerPatchEmbeddingSequential(config, name="patch_embedding")

    def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Delegate to the wrapped sequential layer, forwarding `training`."""
        embedded = self.patch_embedding(x, training=training)
        return embedded

    def build(self, input_shape=None):
        """Build the wrapped layer under its own name scope (no-op once built)."""
        if self.built:
            return
        inner = getattr(self, "patch_embedding", None)
        if inner is not None:
            with tf.name_scope(inner.name):
                inner.build(None)
        self.built = True
class TFSwiftFormerDropPath(keras.layers.Layer):
    """
    Placeholder for drop path (stochastic depth) per sample.

    The TF port does not implement it: both constructing the layer and calling it raise `NotImplementedError`.
    """

    # Shared by the constructor and `call` so the two messages cannot drift apart.
    _NOT_IMPLEMENTED_MESSAGE = "Drop path is not implemented in TF port"

    def __init__(self, config: SwiftFormerConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        raise NotImplementedError(self._NOT_IMPLEMENTED_MESSAGE)
class TFSwiftFormerEmbeddings(keras.layers.Layer):
    """
    Embeddings layer consisting of a single 2D convolutional and batch normalization layer.

    Maps `config.embed_dims[index]` channels to `config.embed_dims[index + 1]` channels, using the kernel size,
    stride and zero padding from `config.down_patch_size`, `config.down_stride` and `config.down_pad`.

    Input: tensor of shape `[batch_size, height, width, embed_dims[index]]` (channels-last)

    Output: tensor of shape `[batch_size, new_height, new_width, embed_dims[index + 1]]`
    """

    def __init__(self, config: SwiftFormerConfig, index: int, **kwargs):
        super().__init__(**kwargs)

        def as_pair(value):
            # Scalars are duplicated for both spatial axes; iterables are passed through unchanged.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        embed_dims = config.embed_dims
        self.in_chans = embed_dims[index]
        self.embed_dim = embed_dims[index + 1]

        self.pad = keras.layers.ZeroPadding2D(padding=as_pair(config.down_pad))
        self.proj = keras.layers.Conv2D(
            self.embed_dim,
            kernel_size=as_pair(config.down_patch_size),
            strides=as_pair(config.down_stride),
            name="proj",
        )
        self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm")

    def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Zero-pad, project and batch-normalize `x`; `training` is forwarded to the batch norm."""
        x = self.pad(x)
        x = self.proj(x)
        x = self.norm(x, training=training)
        return x

    def build(self, input_shape=None):
        """Create the weights of the sub-layers under their own name scopes (no-op once built)."""
        if self.built:
            return
        if getattr(self, "proj", None) is not None:
            with tf.name_scope(self.proj.name):
                # Pass a full channels-last shape rather than a bare channel count.
                self.proj.build((None, None, None, self.in_chans))
        if getattr(self, "norm", None) is not None:
            with tf.name_scope(self.norm.name):
                self.norm.build((None, None, None, self.embed_dim))
        self.built = True
class TFSwiftFormerConvEncoder(keras.layers.Layer):
    """
    `SwiftFormerConvEncoder` with 3*3 and 1*1 convolutions.

    Computes `x + drop_path(layer_scale * f(x))`, where `f` is a zero-padded 3x3 depth-wise convolution, batch
    normalization, a 1x1 convolution to `int(config.mlp_ratio * dim)` channels, GELU and a 1x1 convolution back to
    `dim` channels. `drop_path` is a regular dropout layer with rate `config.drop_conv_encoder_rate`.

    Input: tensor of shape `[batch_size, height, width, channels]` (channels-last)

    Output: tensor of shape `[batch_size, height, width, channels]`
    """

    def __init__(self, config: SwiftFormerConfig, dim: int, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.hidden_dim = int(config.mlp_ratio * dim)

        self.pad = keras.layers.ZeroPadding2D(padding=(1, 1))
        self.depth_wise_conv = keras.layers.Conv2D(dim, kernel_size=3, groups=dim, name="depth_wise_conv")
        self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm")
        self.point_wise_conv1 = keras.layers.Conv2D(self.hidden_dim, kernel_size=1, name="point_wise_conv1")
        self.act = get_tf_activation("gelu")
        self.point_wise_conv2 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv2")
        self.drop_path = keras.layers.Dropout(name="drop_path", rate=config.drop_conv_encoder_rate)

    def build(self, input_shape=None):
        """Create `layer_scale` and the sub-layer weights under their own name scopes (no-op once built)."""
        if self.built:
            return
        # Per-channel scale of the residual branch; it broadcasts over the trailing (channel) axis.
        self.layer_scale = self.add_weight(
            name="layer_scale",
            shape=(self.dim,),
            initializer="ones",
            trainable=True,
        )

        # Sub-layers are built with explicit channels-last shapes rather than bare channel counts.
        if getattr(self, "depth_wise_conv", None) is not None:
            with tf.name_scope(self.depth_wise_conv.name):
                self.depth_wise_conv.build((None, None, None, self.dim))
        if getattr(self, "norm", None) is not None:
            with tf.name_scope(self.norm.name):
                self.norm.build((None, None, None, self.dim))
        if getattr(self, "point_wise_conv1", None) is not None:
            with tf.name_scope(self.point_wise_conv1.name):
                self.point_wise_conv1.build((None, None, None, self.dim))
        if getattr(self, "point_wise_conv2", None) is not None:
            with tf.name_scope(self.point_wise_conv2.name):
                self.point_wise_conv2.build((None, None, None, self.hidden_dim))
        if getattr(self, "drop_path", None) is not None:
            with tf.name_scope(self.drop_path.name):
                self.drop_path.build(None)
        self.built = True

    def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Apply the convolutional branch and add it back onto the input."""
        residual = x
        x = self.pad(x)
        x = self.depth_wise_conv(x)
        x = self.norm(x, training=training)
        x = self.point_wise_conv1(x)
        x = self.act(x)
        x = self.point_wise_conv2(x)
        # Pass `training` explicitly so dropout does not depend on the implicit Keras call context.
        return residual + self.drop_path(self.layer_scale * x, training=training)
class TFSwiftFormerMlp(keras.layers.Layer):
    """
    MLP layer with 1*1 convolutions.

    Batch normalization, a 1x1 convolution to `int(in_features * config.mlp_ratio)` channels, the activation named
    by `config.hidden_act`, dropout, a 1x1 convolution back to `in_features` channels, and dropout again. Both
    dropout steps use the same layer, with rate `config.drop_mlp_rate`.
    """

    def __init__(self, config: SwiftFormerConfig, in_features: int, **kwargs):
        super().__init__(**kwargs)
        self.in_features = in_features
        self.hidden_features = int(in_features * config.mlp_ratio)

        self.norm1 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm1")
        self.fc1 = keras.layers.Conv2D(self.hidden_features, 1, name="fc1")
        self.act = get_tf_activation(config.hidden_act)
        self.fc2 = keras.layers.Conv2D(in_features, 1, name="fc2")
        self.drop = keras.layers.Dropout(rate=config.drop_mlp_rate)

    def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Run the MLP on `x`; `training` is forwarded to the batch norm and the dropout."""
        normalized = self.norm1(x, training=training)
        expanded = self.drop(self.act(self.fc1(normalized)), training=training)
        projected = self.drop(self.fc2(expanded), training=training)
        return projected

    def build(self, input_shape=None):
        """Create the weights of `norm1`, `fc1` and `fc2` under their own name scopes (no-op once built)."""
        if self.built:
            return
        # (attribute name, number of channels that sub-layer receives)
        for attr_name, channels in (
            ("norm1", self.in_features),
            ("fc1", self.in_features),
            ("fc2", self.hidden_features),
        ):
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build((None, None, None, channels))
        self.built = True
class TFSwiftFormerEfficientAdditiveAttention(keras.layers.Layer):
    """
    Efficient Additive Attention module for SwiftFormer.

    Queries and keys are dense projections of the input, L2-normalized along the feature axis. The queries are
    scored against the learned vector `w_g`, pooled over the token axis into a single global query, multiplied
    with every key, projected, added back to the queries and passed through a final dense layer.

    Input: tensor of shape `[batch_size, num_tokens, dim]`

    Output: tensor of shape `[batch_size, num_tokens, dim]`
    """

    def __init__(self, config: SwiftFormerConfig, dim: int = 512, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim

        self.to_query = keras.layers.Dense(dim, name="to_query")
        self.to_key = keras.layers.Dense(dim, name="to_key")
        self.scale_factor = dim**-0.5
        self.proj = keras.layers.Dense(dim, name="proj")
        self.final = keras.layers.Dense(dim, name="final")

    def build(self, input_shape=None):
        """Create `w_g` and the dense-layer weights under their own name scopes (no-op once built)."""
        if self.built:
            return
        self.w_g = self.add_weight(
            name="w_g",
            shape=(self.dim, 1),
            initializer=keras.initializers.RandomNormal(mean=0, stddev=1),
            trainable=True,
        )

        # All four dense layers map `dim` features to `dim` features.
        token_shape = (None, None, self.dim)
        if getattr(self, "to_query", None) is not None:
            with tf.name_scope(self.to_query.name):
                self.to_query.build(token_shape)
        if getattr(self, "to_key", None) is not None:
            with tf.name_scope(self.to_key.name):
                self.to_key.build(token_shape)
        if getattr(self, "proj", None) is not None:
            with tf.name_scope(self.proj.name):
                self.proj.build(token_shape)
        if getattr(self, "final", None) is not None:
            with tf.name_scope(self.final.name):
                self.final.build(token_shape)
        self.built = True

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Apply efficient additive attention to a `[batch_size, num_tokens, dim]` tensor."""
        query = self.to_query(x)
        key = self.to_key(x)

        # `axis` replaces the deprecated `dim` keyword of `tf.math.l2_normalize`.
        query = tf.math.l2_normalize(query, axis=-1)
        key = tf.math.l2_normalize(key, axis=-1)

        query_weight = query @ self.w_g  # [batch_size, num_tokens, 1]
        scaled_query_weight = query_weight * self.scale_factor
        # NOTE(review): the last axis has size 1 here, so this softmax evaluates to 1.0 for every token. It is
        # kept unchanged so outputs stay identical; confirm the intended axis against the reference implementation.
        scaled_query_weight = tf.nn.softmax(scaled_query_weight, axis=-1)

        global_queries = tf.math.reduce_sum(scaled_query_weight * query, axis=1)
        # Broadcasting over the token axis yields the same values as tiling, without needing a static token count.
        global_queries = tf.expand_dims(global_queries, 1)

        out = self.proj(global_queries * key) + query
        out = self.final(out)
        return out
class TFSwiftFormerLocalRepresentation(keras.layers.Layer):
    """
    Local Representation module for SwiftFormer that is implemented by 3*3 depth-wise and point-wise convolutions.

    Computes `x + drop_path(layer_scale * f(x))`, where `f` is a zero-padded 3x3 depth-wise convolution, batch
    normalization, a 1x1 convolution, GELU and a second 1x1 convolution. `drop_path` is an identity layer here.

    Input: tensor of shape `[batch_size, height, width, channels]` (channels-last)

    Output: tensor of shape `[batch_size, height, width, channels]`
    """

    def __init__(self, config: SwiftFormerConfig, dim: int, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim

        self.pad = keras.layers.ZeroPadding2D(padding=(1, 1))
        self.depth_wise_conv = keras.layers.Conv2D(dim, kernel_size=3, groups=dim, name="depth_wise_conv")
        self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm")
        self.point_wise_conv1 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv1")
        self.act = get_tf_activation("gelu")
        self.point_wise_conv2 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv2")
        self.drop_path = keras.layers.Identity(name="drop_path")

    def build(self, input_shape=None):
        """Create `layer_scale` and the sub-layer weights under their own name scopes (no-op once built)."""
        if self.built:
            return
        # `(self.dim,)` is a real one-element tuple; `(self.dim)` was just a parenthesized int.
        self.layer_scale = self.add_weight(
            name="layer_scale",
            shape=(self.dim,),
            initializer="ones",
            trainable=True,
        )

        # Every sub-layer here sees `dim` channels, so one channels-last shape serves all of them.
        feature_shape = (None, None, None, self.dim)
        if getattr(self, "depth_wise_conv", None) is not None:
            with tf.name_scope(self.depth_wise_conv.name):
                self.depth_wise_conv.build(feature_shape)
        if getattr(self, "norm", None) is not None:
            with tf.name_scope(self.norm.name):
                self.norm.build(feature_shape)
        if getattr(self, "point_wise_conv1", None) is not None:
            with tf.name_scope(self.point_wise_conv1.name):
                self.point_wise_conv1.build(feature_shape)
        if getattr(self, "point_wise_conv2", None) is not None:
            with tf.name_scope(self.point_wise_conv2.name):
                self.point_wise_conv2.build(feature_shape)
        if getattr(self, "drop_path", None) is not None:
            with tf.name_scope(self.drop_path.name):
                self.drop_path.build(None)
        self.built = True

    def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Apply the convolutional branch and add it back onto the input."""
        residual = x
        x = self.pad(x)
        x = self.depth_wise_conv(x)
        x = self.norm(x, training=training)
        x = self.point_wise_conv1(x)
        x = self.act(x)
        x = self.point_wise_conv2(x)
        return residual + self.drop_path(self.layer_scale * x, training=training)
class TFSwiftFormerEncoderBlock(keras.layers.Layer):
    """
    SwiftFormer Encoder Block for SwiftFormer. It consists of (1) Local representation module, (2)
    SwiftFormerEfficientAdditiveAttention, and (3) MLP block.

    When `config.use_layer_scale` is set, the attention and MLP branches are multiplied by the learned per-channel
    weights `layer_scale_1` and `layer_scale_2` before being added to the residual stream.

    Input: tensor of shape `[batch_size, height, width, channels]` (channels-last)

    Output: tensor of shape `[batch_size, height, width, channels]`
    """

    def __init__(self, config: SwiftFormerConfig, dim: int, drop_path: float = 0.0, **kwargs):
        super().__init__(**kwargs)

        self.local_representation = TFSwiftFormerLocalRepresentation(config, dim=dim, name="local_representation")
        self.attn = TFSwiftFormerEfficientAdditiveAttention(config, dim=dim, name="attn")
        self.linear = TFSwiftFormerMlp(config, in_features=dim, name="linear")
        # Stochastic depth is not ported: a positive rate makes `TFSwiftFormerDropPath` raise on construction.
        self.drop_path = TFSwiftFormerDropPath(config) if drop_path > 0.0 else keras.layers.Identity()

        self.use_layer_scale = config.use_layer_scale
        # Stored unconditionally so `build` never hits a missing attribute.
        self.dim = dim
        self.layer_scale_init_value = config.layer_scale_init_value

    def build(self, input_shape=None):
        """Create the optional layer-scale weights and build the sub-layers (no-op once built)."""
        if self.built:
            return
        # The scales are only used (and only configured) when layer scaling is on. Creating them unconditionally
        # raised AttributeError for `use_layer_scale=False`.
        if self.use_layer_scale:
            self.layer_scale_1 = self.add_weight(
                name="layer_scale_1",
                shape=(self.dim,),
                initializer=keras.initializers.constant(self.layer_scale_init_value),
                trainable=True,
            )
            self.layer_scale_2 = self.add_weight(
                name="layer_scale_2",
                shape=(self.dim,),
                initializer=keras.initializers.constant(self.layer_scale_init_value),
                trainable=True,
            )

        if getattr(self, "local_representation", None) is not None:
            with tf.name_scope(self.local_representation.name):
                self.local_representation.build(None)
        if getattr(self, "attn", None) is not None:
            with tf.name_scope(self.attn.name):
                self.attn.build(None)
        if getattr(self, "linear", None) is not None:
            with tf.name_scope(self.linear.name):
                self.linear.build(None)
        self.built = True

    def call(self, x: tf.Tensor, training: bool = False):
        """Apply local representation, attention and MLP, each of the last two as a residual branch."""
        x = self.local_representation(x, training=training)

        # Attention works on a token sequence: flatten the spatial grid, attend, then restore the grid.
        _, height, width, channels = x.shape
        res = tf.reshape(x, [-1, height * width, channels])
        res = self.attn(res)
        res = tf.reshape(res, [-1, height, width, channels])

        if self.use_layer_scale:
            x = x + self.drop_path(self.layer_scale_1 * res, training=training)
            x = x + self.drop_path(self.layer_scale_2 * self.linear(x), training=training)
        else:
            x = x + self.drop_path(res, training=training)
            x = x + self.drop_path(self.linear(x), training=training)
        return x
class TFSwiftFormerStage(keras.layers.Layer):
    """
    A Swiftformer stage consisting of a series of `SwiftFormerConvEncoder` blocks and a final
    `SwiftFormerEncoderBlock`.

    Stage `index` holds `config.depths[index]` blocks operating on `config.embed_dims[index]` channels.

    Input: tensor in shape `[batch_size, height, width, channels]` (channels-last)

    Output: tensor in shape `[batch_size, height, width, channels]`
    """

    def __init__(self, config: SwiftFormerConfig, index: int, **kwargs) -> None:
        super().__init__(**kwargs)

        layer_depths = config.depths
        dim = config.embed_dims[index]
        depth = layer_depths[index]

        # Number of blocks in all earlier stages; it offsets the drop-path schedule.
        blocks_before = sum(layer_depths[:index])
        # The schedule spans the blocks of the whole network; clamp so a single-block network does not divide by 0.
        schedule_span = max(sum(layer_depths) - 1, 1)

        self.blocks = []
        for block_idx in range(depth):
            block_name = f"blocks_._{block_idx}"
            if block_idx == depth - 1:
                # Only the last block of a stage is an attention block.
                block_dpr = config.drop_path_rate * (block_idx + blocks_before) / schedule_span
                block = TFSwiftFormerEncoderBlock(config, dim=dim, drop_path=block_dpr, name=block_name)
            else:
                block = TFSwiftFormerConvEncoder(config, dim=dim, name=block_name)
            self.blocks.append(block)

    def call(self, input: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Feed `input` through every block in order, forwarding `training`."""
        hidden_states = input
        for block in self.blocks:
            hidden_states = block(hidden_states, training=training)
        return hidden_states

    def build(self, input_shape=None):
        """Build every block under its own name scope (no-op once built)."""
        if self.built:
            return
        for layer in self.blocks:
            with tf.name_scope(layer.name):
                layer.build(None)
        self.built = True
class TFSwiftFormerEncoder(keras.layers.Layer):
    """
    Stack of SwiftFormer stages with downsampling embedding layers between them.

    A `TFSwiftFormerEmbeddings` layer is inserted after stage `i` (except the last one) when
    `config.downsamples[i]` is truthy or the embedding width changes between stage `i` and stage `i + 1`.

    The layers run channels-last; the returned tensors are transposed to `[batch_size, channels, height, width]`.
    """

    def __init__(self, config: SwiftFormerConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self.config = config

        embed_dims = config.embed_dims
        downsamples = config.downsamples
        num_stages = len(config.depths)

        # Transformer model
        self.network = []
        for i in range(num_stages):
            # Layers are numbered by their position in `self.network`, which is what the names encode.
            self.network.append(TFSwiftFormerStage(config, index=i, name=f"network_._{len(self.network)}"))
            if i >= num_stages - 1:
                break
            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
                # downsampling between two stages
                self.network.append(
                    TFSwiftFormerEmbeddings(config, index=i, name=f"network_._{len(self.network)}")
                )

        self.gradient_checkpointing = False

    def call(
        self,
        hidden_states: tf.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[tuple, TFBaseModelOutputWithNoAttention]:
        """
        Run all stages and downsampling layers on `hidden_states`.

        Args:
            hidden_states: channels-last feature map fed to the first stage.
            output_hidden_states: whether to also return the input and the output of every layer in
                `self.network`; defaults to `config.output_hidden_states`.
            return_dict: whether to return a `TFBaseModelOutputWithNoAttention` instead of a tuple; defaults to
                `config.use_return_dict`.
            training: forwarded to every layer.

        Returns:
            The last hidden state (and optionally all hidden states), transposed to channels-first.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for block in self.network:
            hidden_states = block(hidden_states, training=training)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        # NHWC -> NCHW for everything handed back to the caller.
        hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2])
        if all_hidden_states:
            all_hidden_states = tuple(tf.transpose(s, perm=[0, 3, 1, 2]) for s in all_hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return TFBaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )

    def build(self, input_shape=None):
        """Build every stage and downsampling layer under its own name scope (no-op once built)."""
        if self.built:
            return
        for layer in self.network:
            with tf.name_scope(layer.name):
                layer.build(None)
        self.built = True
class TFSwiftFormerPreTrainedModel(TFPreTrainedModel):
    """
    Abstract base class for the SwiftFormer TF models.

    It only pins the class-level settings below; weight initialization and the interface for downloading and
    loading pretrained models come from `TFPreTrainedModel`.
    """

    # Configuration class associated with this model family.
    config_class = SwiftFormerConfig
    # Attribute name under which the concrete models in this file store the main layer.
    base_model_prefix = "swiftformer"
    # Name of the primary input argument of `call`.
    main_input_name = "pixel_values"
# Shared introduction attached to the model classes below through `add_start_docstrings`.
TFSWIFTFORMER_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
<Tip>
TF 2.0 models accepts two formats as inputs:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional arguments.
This second option is useful when using [`keras.Model.fit`] method which currently requires having all the
tensors in the first argument of the model call function: `model(inputs)`.
If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
first positional argument :
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
</Tip>
Parameters:
config ([`SwiftFormerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Description of the `call` arguments, attached to the model `call` methods below through
# `add_start_docstrings_to_model_forward`.
TFSWIFTFORMER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
training (`bool`, *optional*, defaults to `False`):
Whether or not to run the model in training mode.
"""
@keras_serializable
class TFSwiftFormerMainLayer(keras.layers.Layer):
    """
    Core SwiftFormer network: patch embedding followed by the encoder.

    `pixel_values` arrive as `[batch_size, num_channels, height, width]`, are transposed to channels-last for the
    sub-layers, and the encoder transposes its outputs back to channels-first.
    """

    config_class = SwiftFormerConfig

    def __init__(self, config: SwiftFormerConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        self.patch_embed = TFSwiftFormerPatchEmbedding(config, name="patch_embed")
        self.encoder = TFSwiftFormerEncoder(config, name="encoder")

    @unpack_inputs
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[Tuple, TFBaseModelOutputWithNoAttention]:
        r"""
        Embed `pixel_values` and run the encoder.

        Raises:
            ValueError: if `pixel_values` is `None`.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Validate before touching the tensor: transposing `None` would fail with an unrelated TensorFlow error.
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # TF 2.0 image layers can't use NCHW format when running on CPU.
        # We transpose to NHWC format and then transpose back after the full forward pass.
        # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
        pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1])

        embedding_output = self.patch_embed(pixel_values, training=training)
        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        if not return_dict:
            return tuple(v for v in encoder_outputs if v is not None)

        return TFBaseModelOutputWithNoAttention(
            last_hidden_state=encoder_outputs.last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
        )

    def build(self, input_shape=None):
        """Build the patch embedding and the encoder under their own name scopes (no-op once built)."""
        if self.built:
            return
        if getattr(self, "patch_embed", None) is not None:
            with tf.name_scope(self.patch_embed.name):
                self.patch_embed.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        self.built = True
@add_start_docstrings(
    "The bare TFSwiftFormer Model transformer outputting raw hidden-states without any specific head on top.",
    TFSWIFTFORMER_START_DOCSTRING,
)
class TFSwiftFormerModel(TFSwiftFormerPreTrainedModel):
    # Documentation for this class and its `call` is supplied by the decorators, which concatenate onto
    # `__doc__`; explanatory notes therefore live in comments so the generated docs stay unchanged.

    def __init__(self, config: SwiftFormerConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # All computation is delegated to the main layer.
        self.swiftformer = TFSwiftFormerMainLayer(config, name="swiftformer")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING)
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithNoAttention, Tuple[tf.Tensor]]:
        # Pure pass-through: whatever the main layer returns is returned unchanged.
        return self.swiftformer(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def build(self, input_shape=None):
        # Build the main layer under its own name scope; no-op once built.
        if self.built:
            return
        main_layer = getattr(self, "swiftformer", None)
        if main_layer is not None:
            with tf.name_scope(main_layer.name):
                main_layer.build(None)
        self.built = True
@add_start_docstrings(
    """
    TFSwiftFormer Model transformer with an image classification head on top (e.g. for ImageNet).
    """,
    TFSWIFTFORMER_START_DOCSTRING,
)
class TFSwiftFormerForImageClassification(TFSwiftFormerPreTrainedModel):
    def __init__(self, config: SwiftFormerConfig, **kwargs) -> None:
        super().__init__(config, **kwargs)

        self.num_labels = config.num_labels
        self.swiftformer = TFSwiftFormerMainLayer(config, name="swiftformer")

        # Classifier head
        self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm")
        # With no labels configured, both heads degrade to identity layers.
        self.head = (
            keras.layers.Dense(self.num_labels, name="head")
            if self.num_labels > 0
            else keras.layers.Identity(name="head")
        )
        self.dist_head = (
            keras.layers.Dense(self.num_labels, name="dist_head")
            if self.num_labels > 0
            else keras.layers.Identity(name="dist_head")
        )

    def hf_compute_loss(self, labels, logits):
        """
        Compute the loss matching `config.problem_type`.

        When `config.problem_type` is unset it is inferred once and stored on the config: `"regression"` for a
        single label, `"single_label_classification"` for several labels with `int32`/`int64` targets, and
        `"multi_label_classification"` otherwise.

        Returns:
            The loss tensor, or `None` for an unrecognized problem type.
        """
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and (labels.dtype == tf.int64 or labels.dtype == tf.int32):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        if self.config.problem_type == "regression":
            loss_fct = keras.losses.MSE
            if self.num_labels == 1:
                # `tf.Tensor` has no `.squeeze()` method; `tf.squeeze` accepts tensors and arrays alike.
                loss = loss_fct(tf.squeeze(labels), tf.squeeze(logits))
            else:
                loss = loss_fct(labels, logits)
        elif self.config.problem_type == "single_label_classification":
            loss_fct = keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=keras.losses.Reduction.NONE
            )
            loss = loss_fct(labels, logits)
        elif self.config.problem_type == "multi_label_classification":
            # Multi-label targets are per-class indicators with the same shape as the logits, so each class is
            # scored independently with a sigmoid cross-entropy. Sparse categorical cross-entropy expects a single
            # class index per sample.
            loss_fct = keras.losses.BinaryCrossentropy(
                from_logits=True,
                reduction=keras.losses.Reduction.NONE,
            )
            loss = loss_fct(labels, logits)
        else:
            loss = None

        return loss

    @unpack_inputs
    @add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING)
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        labels: Optional[tf.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[tuple, TFImageClassifierOutputWithNoAttention]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # run base model
        outputs = self.swiftformer(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs.last_hidden_state if return_dict else outputs[0]
        # The base model returns channels-first tensors; batch norm runs channels-last.
        sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1])

        # run classification head
        sequence_output = self.norm(sequence_output, training=training)
        # Global average pooling over the spatial axes -> (batch_size, num_channels). Reducing over the axes
        # directly does not need static spatial dimensions.
        sequence_output = tf.reduce_mean(sequence_output, axis=[1, 2])

        cls_out = self.head(sequence_output)
        distillation_out = self.dist_head(sequence_output)
        # The prediction is the average of the classification and distillation heads.
        logits = (cls_out + distillation_out) / 2

        # calculate loss
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )

    def build(self, input_shape=None):
        """Build the base model, the batch norm and both heads under their own name scopes (no-op once built)."""
        if self.built:
            return
        final_dim = self.config.embed_dims[-1]
        if getattr(self, "swiftformer", None) is not None:
            with tf.name_scope(self.swiftformer.name):
                self.swiftformer.build(None)
        if getattr(self, "norm", None) is not None:
            with tf.name_scope(self.norm.name):
                self.norm.build((None, None, None, final_dim))
        # The heads receive the pooled `(batch_size, final_dim)` features.
        if getattr(self, "head", None) is not None:
            with tf.name_scope(self.head.name):
                self.head.build((None, final_dim))
        if getattr(self, "dist_head", None) is not None:
            with tf.name_scope(self.dist_head.name):
                self.dist_head.build((None, final_dim))
        self.built = True