/
modeling_timm_backbone.py
158 lines (126 loc) · 6.46 KB
/
modeling_timm_backbone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...utils import is_timm_available, is_torch_available, requires_backends
from ...utils.backbone_utils import BackboneMixin
from .configuration_timm_backbone import TimmBackboneConfig
if is_timm_available():
import timm
if is_torch_available():
from torch import Tensor
class TimmBackbone(PreTrainedModel, BackboneMixin):
    """
    Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the
    other models in the library keeping the same API.

    The wrapped timm model is built with `timm.create_model` from the settings of a `TimmBackboneConfig` and stored as
    `self._backbone`. Calling the wrapper returns a `BackboneOutput` (or a plain tuple when `return_dict=False`).
    """

    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False
    config_class = TimmBackboneConfig

    def __init__(self, config, **kwargs):
        """
        Build the wrapped timm model described by `config`.

        Args:
            config (`TimmBackboneConfig`):
                `backbone` must be a name returned by `timm.list_models()` and `use_pretrained_backbone` must not be
                None. `out_features` is not supported; select stages with `out_indices` instead.
            kwargs:
                Forwarded unchanged to `timm.create_model`.

        Raises:
            ValueError: If `backbone` is unset or not listed by timm, if `out_features` is set, or if
                `use_pretrained_backbone` is unset.
        """
        requires_backends(self, "timm")
        super().__init__(config)
        self.config = config

        if config.backbone is None:
            raise ValueError("backbone is not set in the config. Please set it to a timm model name.")

        if config.backbone not in timm.list_models():
            raise ValueError(f"backbone {config.backbone} is not supported by timm.")

        if hasattr(config, "out_features") and config.out_features is not None:
            raise ValueError("out_features is not supported by TimmBackbone. Please use out_indices instead.")

        pretrained = getattr(config, "use_pretrained_backbone", None)
        if pretrained is None:
            raise ValueError("use_pretrained_backbone is not set in the config. Please set it to True or False.")

        # We just take the final layer by default. This matches the default for the transformers models.
        out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)

        self._backbone = timm.create_model(
            config.backbone,
            pretrained=pretrained,
            # This is currently not possible for transformer architectures.
            features_only=config.features_only,
            in_chans=config.num_channels,
            out_indices=out_indices,
            **kwargs,
        )

        # Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers of the
        # wrapped model into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively.
        if getattr(config, "freeze_batch_norm_2d", False):
            self.freeze_batch_norm_2d()

        # These are used to control the output of the model when called. If output_hidden_states is True, then
        # return_layers is temporarily replaced by `_all_layers` so that every stage is returned.
        self._return_layers = self._backbone.return_layers
        # Maps each stage's module name (from timm's feature_info) to its position, as a string.
        self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
        super()._init_backbone(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Instantiate a `TimmBackbone` for the timm model named `pretrained_model_name_or_path`.

        A fresh `TimmBackboneConfig` is built with `backbone=pretrained_model_name_or_path`. Each of `num_channels`,
        `features_only`, `use_pretrained_backbone`, `out_indices` and `freeze_batch_norm_2d` is taken from the
        matching keyword argument when given, otherwise from `config` (a default `TimmBackboneConfig` when no
        config, or `config=None`, is passed). `*model_args` is accepted for signature compatibility but unused.
        Remaining `kwargs` are forwarded to `_from_config`.

        Raises:
            ValueError: If `use_timm_backbone=False` is passed.
        """
        requires_backends(cls, ["vision", "timm"])

        # Only build a default config when none was supplied (also tolerates an explicit `config=None`).
        config = kwargs.pop("config", None)
        if config is None:
            config = TimmBackboneConfig()

        use_timm = kwargs.pop("use_timm_backbone", True)
        if not use_timm:
            raise ValueError("use_timm_backbone must be True for timm backbones")

        num_channels = kwargs.pop("num_channels", config.num_channels)
        features_only = kwargs.pop("features_only", config.features_only)
        use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
        out_indices = kwargs.pop("out_indices", config.out_indices)
        # `__init__` reads this flag from the config. Carry it over so it is not silently dropped when the config is
        # rebuilt below, and pop it so it is not forwarded to `timm.create_model` via `**kwargs`.
        freeze_bn = kwargs.pop("freeze_batch_norm_2d", getattr(config, "freeze_batch_norm_2d", False))
        config = TimmBackboneConfig(
            backbone=pretrained_model_name_or_path,
            num_channels=num_channels,
            features_only=features_only,
            use_pretrained_backbone=use_pretrained_backbone,
            out_indices=out_indices,
            freeze_batch_norm_2d=freeze_bn,
        )
        return super()._from_config(config, **kwargs)

    def freeze_batch_norm_2d(self):
        """Replace the batch-norm layers of the wrapped timm model with frozen variants via timm's helper."""
        timm.layers.freeze_batch_norm_2d(self._backbone)

    def unfreeze_batch_norm_2d(self):
        """Undo `freeze_batch_norm_2d` on the wrapped timm model via timm's helper."""
        timm.layers.unfreeze_batch_norm_2d(self._backbone)

    def _init_weights(self, module):
        """
        Empty init weights function to ensure compatibility of the class in the library.
        """
        pass

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[BackboneOutput, Tuple[Tuple[torch.Tensor, ...], ...]]:
        """
        Run the wrapped timm model on `pixel_values`.

        Args:
            pixel_values (`torch.FloatTensor`):
                Passed unchanged to the wrapped timm model.
            output_attentions (`bool`, *optional*):
                Defaults to `config.output_attentions`. Must resolve to False; attentions are not supported.
            output_hidden_states (`bool`, *optional*):
                Defaults to `config.output_hidden_states`. When True, every stage listed in the backbone's
                `feature_info` is returned as `hidden_states`, and `feature_maps` is selected from them with
                `self.out_indices`.
            return_dict (`bool`, *optional*):
                Defaults to `config.use_return_dict`.
            kwargs:
                Forwarded to the wrapped timm model's call.

        Returns:
            A `BackboneOutput` (with `attentions=None`) when `return_dict` is True. Otherwise the tuple
            `(feature_maps,)`, or `(feature_maps, hidden_states)` when `output_hidden_states` is True.

        Raises:
            ValueError: If `output_attentions` resolves to True.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if output_attentions:
            raise ValueError("Cannot output attentions for timm backbones at the moment")

        if output_hidden_states:
            # Temporarily widen return_layers so the backbone returns every stage. The original value is restored in
            # `finally` so that an exception during the call does not leave the backbone returning all stages.
            self._backbone.return_layers = self._all_layers
            try:
                hidden_states = self._backbone(pixel_values, **kwargs)
            finally:
                self._backbone.return_layers = self._return_layers
            feature_maps = tuple(hidden_states[i] for i in self.out_indices)
        else:
            feature_maps = self._backbone(pixel_values, **kwargs)
            hidden_states = None

        feature_maps = tuple(feature_maps)
        hidden_states = tuple(hidden_states) if hidden_states is not None else None

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output = output + (hidden_states,)
            return output

        return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None)