-
Notifications
You must be signed in to change notification settings - Fork 26.4k
/
modeling_tf_resnet.py
501 lines (418 loc) · 20.1 KB
/
modeling_tf_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
# coding=utf-8
# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow ResNet model."""
from typing import Dict, Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import ACT2FN
from ...modeling_tf_outputs import (
TFBaseModelOutputWithNoAttention,
TFBaseModelOutputWithPoolingAndNoAttention,
TFImageClassifierOutputWithNoAttention,
)
from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs
from ...tf_utils import shape_list
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_resnet import ResNetConfig
# Module-level logger named after this module.
logger = logging.get_logger(__name__)
# General docstring
# Passed as `config_class` / `processor_class` to `add_code_sample_docstrings` further down in this file.
_CONFIG_FOR_DOC = "ResNetConfig"
_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
# Base docstring
# Checkpoint and expected output used for the generated code sample of `TFResNetModel.call`.
_CHECKPOINT_FOR_DOC = "microsoft/resnet-50"
_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7]
# Image classification docstring
# Checkpoint and expected output used for the generated code sample of `TFResNetForImageClassification.call`.
_IMAGE_CLASS_CHECKPOINT = "microsoft/resnet-50"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat"
# Checkpoint identifiers listed for this architecture (not referenced elsewhere in the visible part of the file).
TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/resnet-50",
# See all resnet models at https://huggingface.co/models?filter=resnet
]
class TFResNetConvLayer(tf.keras.layers.Layer):
    """
    Convolution followed by batch normalization and an activation, operating on channels-last tensors.

    The input is zero-padded explicitly by `kernel_size // 2` on both spatial axes and the convolution itself uses
    `"valid"` padding, so the padding is always symmetric (as in the PyTorch `Conv2d` this layer is modelled on).

    Args:
        out_channels (`int`):
            Number of filters of the convolution.
        kernel_size (`int`, *optional*, defaults to 3):
            Side of the square convolution kernel.
        stride (`int`, *optional*, defaults to 1):
            Stride of the convolution.
        activation (`str`, *optional*, defaults to `"relu"`):
            Key into `ACT2FN` selecting the activation. `None` applies the identity instead.
    """

    def __init__(
        self,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        activation: Optional[str] = "relu",
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.pad_value = kernel_size // 2
        self.conv = tf.keras.layers.Conv2D(
            out_channels, kernel_size=kernel_size, strides=stride, padding="valid", use_bias=False, name="convolution"
        )
        # Match the defaults of the PyTorch equivalent. The two frameworks define `momentum` in opposite directions:
        #   PyTorch: running = (1 - momentum) * running + momentum * batch       (default momentum=0.1)
        #   Keras:   moving  = momentum * moving + (1 - momentum) * batch
        # so PyTorch's 0.1 corresponds to 0.9 here.
        self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
        self.activation = ACT2FN[activation] if activation is not None else tf.keras.layers.Activation("linear")

    def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor:
        """Zero-pad the two spatial axes by `self.pad_value` on each side, then apply the convolution."""
        # Pad to match that done in the PyTorch Conv2D model
        height_pad = width_pad = (self.pad_value, self.pad_value)
        hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)])
        hidden_state = self.conv(hidden_state)
        return hidden_state

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        """
        Apply padding + convolution, batch normalization and the activation.

        Args:
            hidden_state (`tf.Tensor`): Input feature map; axes 1 and 2 are treated as the spatial axes.
            training (`bool`, *optional*, defaults to `False`): Forwarded to the batch normalization layer.

        Returns:
            `tf.Tensor`: The transformed feature map.
        """
        hidden_state = self.convolution(hidden_state)
        hidden_state = self.normalization(hidden_state, training=training)
        hidden_state = self.activation(hidden_state)
        return hidden_state
class TFResNetEmbeddings(tf.keras.layers.Layer):
    """
    ResNet Embeddings (stem) composed of a single aggressive convolution (`7x7`, stride 2) followed by a `3x3`,
    stride 2 max pooling applied after padding each spatial side by one.
    """

    def __init__(self, config: ResNetConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self.embedder = TFResNetConvLayer(
            config.embedding_size,
            kernel_size=7,
            stride=2,
            activation=config.hidden_act,
            name="embedder",
        )
        self.pooler = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="valid", name="pooler")
        self.num_channels = config.num_channels

    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
        """
        Embed the input image.

        Args:
            pixel_values (`tf.Tensor`): 4D tensor whose last axis is the channel axis.
            training (`bool`, *optional*, defaults to `False`): Forwarded to the stem convolution layer.

        Returns:
            `tf.Tensor`: The embedded feature map.

        Raises:
            ValueError: In eager mode only, if the size of the last axis differs from `config.num_channels`.
        """
        _, _, _, num_channels = shape_list(pixel_values)
        # The check is skipped when tracing a graph (see the `tf.executing_eagerly()` guard).
        if tf.executing_eagerly() and num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        hidden_state = pixel_values
        # Forward `training` explicitly, like every other layer in this file, so that the batch normalization
        # inside the embedder does not depend on implicit propagation from the enclosing call.
        hidden_state = self.embedder(hidden_state, training=training)
        # Explicit padding of the spatial axes; the pooling layer itself uses "valid" padding.
        hidden_state = tf.pad(hidden_state, [[0, 0], [1, 1], [1, 1], [0, 0]])
        hidden_state = self.pooler(hidden_state)
        return hidden_state
class TFResNetShortCut(tf.keras.layers.Layer):
    """
    ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
    downsample the input using `stride=2`.

    Args:
        out_channels (`int`):
            Number of filters of the `1x1` projection convolution.
        stride (`int`, *optional*, defaults to 2):
            Stride of the projection convolution.
    """

    def __init__(self, out_channels: int, stride: int = 2, **kwargs) -> None:
        super().__init__(**kwargs)
        self.convolution = tf.keras.layers.Conv2D(
            out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
        )
        # Match the defaults of the PyTorch equivalent. The two frameworks define `momentum` in opposite directions:
        #   PyTorch: running = (1 - momentum) * running + momentum * batch       (default momentum=0.1)
        #   Keras:   moving  = momentum * moving + (1 - momentum) * batch
        # so PyTorch's 0.1 corresponds to 0.9 here.
        self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")

    def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
        """
        Project `x` with the `1x1` convolution and normalize the result.

        Args:
            x (`tf.Tensor`): Residual feature map.
            training (`bool`, *optional*, defaults to `False`): Forwarded to the batch normalization layer.

        Returns:
            `tf.Tensor`: The projected residual.
        """
        hidden_state = x
        hidden_state = self.convolution(hidden_state)
        hidden_state = self.normalization(hidden_state, training=training)
        return hidden_state
class TFResNetBasicLayer(tf.keras.layers.Layer):
    """
    Basic residual layer: two `3x3` convolution layers whose output is added to a shortcut branch before the final
    activation.

    The shortcut is a `TFResNetShortCut` projection when the channel count changes or `stride != 1`, and the
    identity otherwise. `activation` only selects the function applied after the residual addition; the first
    convolution layer keeps the default activation of `TFResNetConvLayer` and the second one has none.
    """

    def __init__(
        self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu", **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.conv1 = TFResNetConvLayer(out_channels, stride=stride, name="layer.0")
        self.conv2 = TFResNetConvLayer(out_channels, activation=None, name="layer.1")
        # A projection is only required when the residual and the main branch would otherwise differ in shape.
        if in_channels != out_channels or stride != 1:
            self.shortcut = TFResNetShortCut(out_channels, stride=stride, name="shortcut")
        else:
            self.shortcut = tf.keras.layers.Activation("linear", name="shortcut")
        self.activation = ACT2FN[activation]

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Run both convolution layers, add the (possibly projected) input and apply the activation."""
        shortcut_input = hidden_state
        features = self.conv2(self.conv1(hidden_state, training=training), training=training)
        projected = self.shortcut(shortcut_input, training=training)
        return self.activation(features + projected)
class TFResNetBottleNeckLayer(tf.keras.layers.Layer):
    """
    Bottleneck residual layer built from a `1x1`, a `3x3` and a `1x1` convolution layer.

    The first `1x1` convolution produces `out_channels // reduction` channels so that the `3x3` convolution (which
    carries the stride) runs on fewer channels; the last `1x1` convolution maps back to `out_channels`. The result is
    added to a shortcut branch (projection when the channel count changes or `stride != 1`, identity otherwise) and
    passed through `activation`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        activation: str = "relu",
        reduction: int = 4,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        bottleneck_channels = out_channels // reduction
        self.conv0 = TFResNetConvLayer(bottleneck_channels, kernel_size=1, name="layer.0")
        self.conv1 = TFResNetConvLayer(bottleneck_channels, stride=stride, name="layer.1")
        self.conv2 = TFResNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.2")
        # A projection is only required when the residual and the main branch would otherwise differ in shape.
        if in_channels != out_channels or stride != 1:
            self.shortcut = TFResNetShortCut(out_channels, stride=stride, name="shortcut")
        else:
            self.shortcut = tf.keras.layers.Activation("linear", name="shortcut")
        self.activation = ACT2FN[activation]

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Run the three convolution layers, add the (possibly projected) input and apply the activation."""
        shortcut_input = hidden_state
        features = hidden_state
        for conv in (self.conv0, self.conv1, self.conv2):
            features = conv(features, training=training)
        projected = self.shortcut(shortcut_input, training=training)
        return self.activation(features + projected)
class TFResNetStage(tf.keras.layers.Layer):
    """
    A ResNet stage: `depth` residual layers applied one after the other.

    The layer type (bottleneck or basic) is chosen from `config.layer_type`. Only the first layer receives `stride`
    and changes the channel count from `in_channels` to `out_channels`; the remaining ones keep `out_channels`.
    """

    def __init__(
        self, config: ResNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        if config.layer_type == "bottleneck":
            layer_cls = TFResNetBottleNeckLayer
        else:
            layer_cls = TFResNetBasicLayer

        blocks = [layer_cls(in_channels, out_channels, stride=stride, activation=config.hidden_act, name="layers.0")]
        for index in range(1, depth):
            blocks.append(
                layer_cls(out_channels, out_channels, activation=config.hidden_act, name=f"layers.{index}")
            )
        self.stage_layers = blocks

    def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Feed `hidden_state` through every layer of the stage in order and return the final output."""
        for block in self.stage_layers:
            hidden_state = block(hidden_state, training=training)
        return hidden_state
class TFResNetEncoder(tf.keras.layers.Layer):
    """
    Stack of `TFResNetStage`s built from `config.hidden_sizes` and `config.depths`.

    The first stage maps `config.embedding_size` to `config.hidden_sizes[0]`; each following stage maps one hidden
    size to the next.
    """

    def __init__(self, config: ResNetConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        # Whether the first stage downsamples is controlled by `downsample_in_first_stage`.
        first_stride = 2 if config.downsample_in_first_stage else 1
        stages = [
            TFResNetStage(
                config,
                config.embedding_size,
                config.hidden_sizes[0],
                stride=first_stride,
                depth=config.depths[0],
                name="stages.0",
            )
        ]
        # Consecutive pairs of hidden sizes give the (in, out) channels of the remaining stages.
        channel_pairs = zip(config.hidden_sizes, config.hidden_sizes[1:], config.depths[1:])
        for index, (in_channels, out_channels, depth) in enumerate(channel_pairs, start=1):
            stages.append(TFResNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{index}"))
        self.stages = stages

    def call(
        self,
        hidden_state: tf.Tensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        training: bool = False,
    ) -> TFBaseModelOutputWithNoAttention:
        """
        Run all stages.

        When `output_hidden_states` is set, the input of every stage and the final output are collected in order.
        With `return_dict=False` a plain tuple of the non-`None` values `(last_hidden_state, hidden_states)` is
        returned instead of a `TFBaseModelOutputWithNoAttention`.
        """
        collected = () if output_hidden_states else None

        for stage in self.stages:
            if output_hidden_states:
                collected = collected + (hidden_state,)
            hidden_state = stage(hidden_state, training=training)

        if output_hidden_states:
            collected = collected + (hidden_state,)

        if return_dict:
            return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=collected)
        return tuple(value for value in (hidden_state, collected) if value is not None)
class TFResNetPreTrainedModel(TFPreTrainedModel):
    """
    Abstract base class shared by the ResNet models: it binds the configuration class, the base model prefix and the
    main input name, and provides the dummy inputs and the serving signature.
    """

    config_class = ResNetConfig
    base_model_prefix = "resnet"
    main_input_name = "pixel_values"

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            `Dict[str, tf.Tensor]`: A random `float32` batch of shape `(3, config.num_channels, 224, 224)` stored
            under the `"pixel_values"` key.
        """
        dummy_shape = (3, self.config.num_channels, 224, 224)
        random_pixels = tf.random.uniform(shape=dummy_shape, dtype=tf.float32)
        return {"pixel_values": tf.constant(random_pixels)}

    @tf.function(
        input_signature=[
            {
                "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
            }
        ]
    )
    def serving(self, inputs):
        """Serving entry point: run `call` on `inputs` and post-process the result with `serving_output`."""
        return self.serving_output(self.call(inputs))
# Shared class-level documentation, attached to the model classes below through `add_start_docstrings`.
RESNET_START_DOCSTRING = r"""
This model is a TensorFlow
[tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a
regular TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and
behavior.
Parameters:
config ([`ResNetConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared documentation of the `call` arguments, attached through `add_start_docstrings_to_model_forward`.
RESNET_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
[`AutoFeatureExtractor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@keras_serializable
class TFResNetMainLayer(tf.keras.layers.Layer):
    """
    Core ResNet computation: stem embeddings, encoder stages and global average pooling.

    Inputs are accepted in `(batch_size, num_channels, height, width)` layout, processed internally in channels-last
    layout, and every returned feature map is transposed back to the input layout.
    """

    config_class = ResNetConfig

    def __init__(self, config: ResNetConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self.config = config
        self.embedder = TFResNetEmbeddings(config, name="embedder")
        self.encoder = TFResNetEncoder(config, name="encoder")
        self.pooler = tf.keras.layers.GlobalAveragePooling2D(keepdims=True)

    @unpack_inputs
    def call(
        self,
        pixel_values: tf.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]:
        """
        Run the full forward pass.

        `output_hidden_states` and `return_dict` fall back to the values stored in the configuration when `None`.
        With `return_dict=False` the result is the tuple `(last_hidden_state, pooled_output)` followed directly by
        the individual hidden states, if any were requested.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # TF 2.0 image layers can't use NCHW format when running on CPU, so the whole forward pass runs in NHWC:
        # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
        nhwc_pixels = tf.transpose(pixel_values, perm=[0, 2, 3, 1])

        embedding_output = self.embedder(nhwc_pixels, training=training)
        encoder_outputs = self.encoder(
            embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)

        # Bring every output back to NCHW:
        # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
        to_nchw = (0, 3, 1, 2)
        last_hidden_state = tf.transpose(last_hidden_state, to_nchw)
        pooled_output = tf.transpose(pooled_output, to_nchw)
        hidden_states = tuple(
            tf.transpose(state, to_nchw) for state_group in encoder_outputs[1:] for state in state_group
        )

        if not return_dict:
            return (last_hidden_state, pooled_output) + hidden_states

        return TFBaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=hidden_states if output_hidden_states else None,
        )
@add_start_docstrings(
    "The bare ResNet model outputting raw features without any specific head on top.",
    RESNET_START_DOCSTRING,
)
class TFResNetModel(TFResNetPreTrainedModel):
    def __init__(self, config: ResNetConfig, **kwargs) -> None:
        super().__init__(config, **kwargs)
        self.resnet = TFResNetMainLayer(config=config, name="resnet")

    @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    @unpack_inputs
    def call(
        self,
        pixel_values: tf.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]:
        # Unset flags fall back to the configuration before delegating to the main layer.
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        return self.resnet(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def serving_output(
        self, output: TFBaseModelOutputWithPoolingAndNoAttention
    ) -> TFBaseModelOutputWithPoolingAndNoAttention:
        # The hidden states are passed through untouched rather than converted with tf.convert_to_tensor,
        # because they all have different dimensions.
        return TFBaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=output.last_hidden_state,
            pooler_output=output.pooler_output,
            hidden_states=output.hidden_states,
        )
@add_start_docstrings(
    """
    ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    RESNET_START_DOCSTRING,
)
class TFResNetForImageClassification(TFResNetPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: ResNetConfig, **kwargs) -> None:
        super().__init__(config, **kwargs)
        self.num_labels = config.num_labels
        self.resnet = TFResNetMainLayer(config, name="resnet")
        # Classification head: a dense projection, or the identity when the configuration declares no labels.
        if config.num_labels > 0:
            self.classifier_layer = tf.keras.layers.Dense(config.num_labels, name="classifier.1")
        else:
            self.classifier_layer = tf.keras.layers.Activation("linear", name="classifier.1")

    def classifier(self, x: tf.Tensor) -> tf.Tensor:
        """Flatten the pooled features and feed them to the classification head."""
        flattened = tf.keras.layers.Flatten()(x)
        return self.classifier_layer(flattened)

    @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    @unpack_inputs
    def call(
        self,
        pixel_values: tf.Tensor = None,
        labels: tf.Tensor = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        training: bool = False,
    ) -> Union[Tuple[tf.Tensor], TFImageClassifierOutputWithNoAttention]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.resnet(
            pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
        )

        # The pooled features sit at index 1 of the tuple output.
        if return_dict:
            pooled_output = outputs.pooler_output
        else:
            pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self.hf_compute_loss(labels, logits)

        if return_dict:
            return TFImageClassifierOutputWithNoAttention(
                loss=loss, logits=logits, hidden_states=outputs.hidden_states
            )

        # Tuple output: the loss (when computed), the logits, then whatever followed the pooled output.
        output = (logits,) + outputs[2:]
        if loss is not None:
            return (loss,) + output
        return output

    def serving_output(self, output: TFImageClassifierOutputWithNoAttention) -> TFImageClassifierOutputWithNoAttention:
        # The hidden states are passed through untouched rather than converted with tf.convert_to_tensor,
        # because they all have different dimensions.
        return TFImageClassifierOutputWithNoAttention(logits=output.logits, hidden_states=output.hidden_states)