-
Notifications
You must be signed in to change notification settings - Fork 140
/
registry.py
346 lines (308 loc) · 12.5 KB
/
registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""
Code related to the PyTorch model registry for easily creating models.
"""
from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
from torch.nn import Module
from merge_args import merge_args
from sparseml.pytorch.utils import load_model
from sparseml.utils import parse_optimization_str, wrapper_decorator
from sparseml.utils.frameworks import PYTORCH_FRAMEWORK
from sparsezoo import Zoo
from sparsezoo.objects import Model
# Only the registry class is part of this module's public API.
__all__ = ["ModelRegistry"]
"""
Simple named tuple object to store model info
"""
_ModelAttributes = NamedTuple(
"_ModelAttributes",
[
("input_shape", Any),
("domain", str),
("sub_domain", str),
("architecture", str),
("sub_architecture", str),
("default_dataset", str),
("default_desc", str),
("repo_source", str),
("ignore_error_tensors", List[str]),
("args", Dict[str, Tuple[str, Any]]),
],
)
class ModelRegistry(object):
    """
    Registry class for creating models.

    Model constructors are registered under one or more string keys, usually
    through the ``register`` decorator, together with metadata describing the
    model and where its pretrained weights live in the sparsezoo. ``create``
    then instantiates a registered model and optionally loads weights into it.
    """

    _CONSTRUCTORS = {}  # type: Dict[str, Callable]
    _ATTRIBUTES = {}  # type: Dict[str, _ModelAttributes]

    @staticmethod
    def available_keys() -> List[str]:
        """
        :return: the keys (models) currently available in the registry
        """
        return list(ModelRegistry._CONSTRUCTORS.keys())

    @staticmethod
    def _validate_key(key: str):
        """
        Check that a model has been registered under the given key.

        :param key: the model key (name) to check
        :raise ValueError: if key is not in the registry; the message lists
            the keys that are available
        """
        if key not in ModelRegistry._CONSTRUCTORS:
            raise ValueError(
                "key {} is not in the model registry; available: {}".format(
                    key, ModelRegistry.available_keys()
                )
            )

    @staticmethod
    def create(
        key: str,
        pretrained: Union[bool, str] = False,
        pretrained_path: str = None,
        pretrained_dataset: str = None,
        load_strict: bool = True,
        ignore_error_tensors: List[str] = None,
        **kwargs,
    ) -> Module:
        """
        Create a new model for the given key

        :param key: the model key (name) to create
        :param pretrained: True to load pretrained weights; to load a specific version
            give a string with the name of the version (pruned-moderate, base),
            default False
        :param pretrained_path: A model file path to load into the created model
        :param pretrained_dataset: The dataset to load for the model
        :param load_strict: True to make sure all states are found in and
            loaded in model, False otherwise; default True
        :param ignore_error_tensors: tensors to ignore if there are errors in loading
        :param kwargs: any keyword args to supply to the model constructor
        :return: the instantiated model
        :raise ValueError: if key is not in the registry
        """
        ModelRegistry._validate_key(key)

        return ModelRegistry._CONSTRUCTORS[key](
            pretrained=pretrained,
            pretrained_path=pretrained_path,
            pretrained_dataset=pretrained_dataset,
            load_strict=load_strict,
            ignore_error_tensors=ignore_error_tensors,
            **kwargs,
        )

    @staticmethod
    def create_zoo_model(
        key: str,
        pretrained: Union[bool, str] = True,
        pretrained_dataset: str = None,
    ) -> Model:
        """
        Create a sparsezoo Model for the desired model in the zoo

        :param key: the model key (name) to retrieve
        :param pretrained: True to load pretrained weights; to load a specific version
            give a string with the name of the version (optim, optim-perf),
            default True
        :param pretrained_dataset: The dataset to load for the model; defaults
            to the dataset registered for the model
        :return: the sparsezoo Model reference for the given model
        :raise ValueError: if key is not in the registry
        """
        ModelRegistry._validate_key(key)

        attributes = ModelRegistry._ATTRIBUTES[key]
        # a string selects a specific version; anything else uses the default
        optim_name, optim_category, optim_target = parse_optimization_str(
            pretrained if isinstance(pretrained, str) else attributes.default_desc
        )
        dataset = (
            attributes.default_dataset
            if pretrained_dataset is None
            else pretrained_dataset
        )

        return Zoo.load_model(
            attributes.domain,
            attributes.sub_domain,
            attributes.architecture,
            attributes.sub_architecture,
            PYTORCH_FRAMEWORK,
            attributes.repo_source,
            dataset,
            None,
            optim_name,
            optim_category,
            optim_target,
        )

    @staticmethod
    def input_shape(key: str) -> Any:
        """
        :param key: the model key (name) to create
        :return: the specified input shape for the model
        :raise ValueError: if key is not in the registry
        """
        ModelRegistry._validate_key(key)

        return ModelRegistry._ATTRIBUTES[key].input_shape

    @staticmethod
    def register(
        key: Union[str, List[str]],
        input_shape: Any,
        domain: str,
        sub_domain: str,
        architecture: str,
        sub_architecture: str,
        default_dataset: str,
        default_desc: str,
        repo_source: str = "sparseml",
        def_ignore_error_tensors: List[str] = None,
        desc_args: Dict[str, Tuple[str, Any]] = None,
    ):
        """
        Register a model with the registry. Should be used as a decorator

        :param key: the model key (name) to create, or a list of keys that
            should all refer to the same model
        :param input_shape: the specified input shape for the model
        :param domain: the domain the model belongs to; ex: cv, nlp, etc
        :param sub_domain: the sub domain the model belongs to;
            ex: classification, detection, etc
        :param architecture: the architecture the model belongs to;
            ex: resnet, mobilenet, etc
        :param sub_architecture: the sub architecture the model belongs to;
            ex: 50, 101, etc
        :param default_dataset: the dataset to use by default for loading
            pretrained if not supplied
        :param default_desc: the description to use by default for loading
            pretrained if not supplied
        :param repo_source: the source repo for the model, default is sparseml
        :param def_ignore_error_tensors: tensors to ignore if there are
            errors in loading
        :param desc_args: args that should be changed based on the description
        :return: the decorator
        """
        if not isinstance(key, list):
            key = [key]

        def decorator(const_func):
            # the wrapper reads its attributes through the first key; every
            # key in the list is registered with the same attributes
            wrapped_constructor = ModelRegistry._registered_wrapper(key[0], const_func)
            ModelRegistry.register_wrapped_model_constructor(
                wrapped_constructor,
                key,
                input_shape,
                domain,
                sub_domain,
                architecture,
                sub_architecture,
                default_dataset,
                default_desc,
                repo_source,
                def_ignore_error_tensors,
                desc_args,
            )
            return wrapped_constructor

        return decorator

    @staticmethod
    def register_wrapped_model_constructor(
        wrapped_constructor: Callable,
        key: Union[str, List[str]],
        input_shape: Any,
        domain: str,
        sub_domain: str,
        architecture: str,
        sub_architecture: str,
        default_dataset: str,
        default_desc: str,
        repo_source: str,
        def_ignore_error_tensors: List[str] = None,
        desc_args: Dict[str, Tuple[str, Any]] = None,
    ):
        """
        Register a model with the registry from a model constructor or provider function

        :param wrapped_constructor: Model constructor wrapped to be compatible
            by call from ModelRegistry.create should have pretrained, pretrained_path,
            pretrained_dataset, load_strict, ignore_error_tensors, and **kwargs as
            arguments
        :param key: the model key (name) to create, or a list of keys that
            should all refer to the same model
        :param input_shape: the specified input shape for the model
        :param domain: the domain the model belongs to; ex: cv, nlp, etc
        :param sub_domain: the sub domain the model belongs to;
            ex: classification, detection, etc
        :param architecture: the architecture the model belongs to;
            ex: resnet, mobilenet, etc
        :param sub_architecture: the sub architecture the model belongs to;
            ex: 50, 101, etc
        :param default_dataset: the dataset to use by default for loading
            pretrained if not supplied
        :param default_desc: the description to use by default for loading
            pretrained if not supplied
        :param repo_source: the source repo for the model; ex: sparseml, torchvision
        :param def_ignore_error_tensors: tensors to ignore if there are
            errors in loading
        :param desc_args: args that should be changed based on the description
        :return: The constructor wrapper registered with the registry
        :raise ValueError: if any key is already registered or repeated in
            the given list; in that case nothing is registered
        """
        if not isinstance(key, list):
            key = [key]

        # validate every key before registering any of them so a duplicate
        # cannot leave the registry partially updated
        seen = set()
        for r_key in key:
            if r_key in ModelRegistry._CONSTRUCTORS or r_key in seen:
                raise ValueError("key {} is already registered".format(r_key))
            seen.add(r_key)

        # the record is an immutable named tuple, so one instance can be
        # shared by all keys
        attributes = _ModelAttributes(
            input_shape,
            domain,
            sub_domain,
            architecture,
            sub_architecture,
            default_dataset,
            default_desc,
            repo_source,
            def_ignore_error_tensors,
            desc_args,
        )
        for r_key in key:
            ModelRegistry._CONSTRUCTORS[r_key] = wrapped_constructor
            ModelRegistry._ATTRIBUTES[r_key] = attributes

        return wrapped_constructor

    @staticmethod
    def _registered_wrapper(
        key: str,
        const_func: Callable,
    ):
        """
        Wrap a model constructor so it accepts the pretrained-loading
        arguments passed by ModelRegistry.create.

        :param key: the registry key whose attributes the wrapper reads
            each time it is called
        :param const_func: the original model constructor
        :return: the wrapped constructor
        """

        @merge_args(const_func)
        @wrapper_decorator(const_func)
        def wrapper(
            pretrained_path: str = None,
            pretrained: Union[bool, str] = False,
            pretrained_dataset: str = None,
            load_strict: bool = True,
            ignore_error_tensors: List[str] = None,
            *args,
            **kwargs,
        ):
            """
            :param pretrained_path: A path to the pretrained weights to load,
                if provided will override the pretrained param
            :param pretrained: True to load the default pretrained weights,
                a string to load a specific pretrained weight
                (ex: base, optim, optim-perf),
                or False to not load any pretrained weights
            :param pretrained_dataset: The dataset to load pretrained weights for
                (ex: imagenet, mnist, etc).
                If not supplied will default to the one preconfigured for the model.
            :param load_strict: True to raise an error on issues with state dict
                loading from pretrained_path or pretrained, False to ignore
            :param ignore_error_tensors: Tensors to ignore while checking the state dict
                for weights loaded from pretrained_path or pretrained
            """
            attributes = ModelRegistry._ATTRIBUTES[key]

            # desc_args: a registered description can override one
            # constructor kwarg when it is requested through `pretrained`
            if attributes.args and pretrained in attributes.args:
                kwargs[attributes.args[pretrained][0]] = attributes.args[pretrained][1]

            model = const_func(*args, **kwargs)

            # an explicit ignore list takes precedence over the registered default
            ignore = []
            if ignore_error_tensors:
                ignore.extend(ignore_error_tensors)
            elif attributes.ignore_error_tensors:
                ignore.extend(attributes.ignore_error_tensors)

            # accept string spellings of booleans for `pretrained`
            if isinstance(pretrained, str):
                if pretrained.lower() == "true":
                    pretrained = True
                elif pretrained.lower() in ["false", "none"]:
                    pretrained = False

            if pretrained_path:
                load_model(pretrained_path, model, load_strict, ignore)
            elif pretrained:
                zoo_model = ModelRegistry.create_zoo_model(
                    key, pretrained, pretrained_dataset
                )
                try:
                    paths = zoo_model.download_framework_files(extensions=[".pth"])
                    load_model(paths[0], model, load_strict, ignore)
                except Exception:
                    # deliberate best-effort retry: download once more with
                    # overwrite on in case the cached file was corrupted; a
                    # second failure propagates to the caller
                    paths = zoo_model.download_framework_files(
                        overwrite=True, extensions=[".pth"]
                    )
                    load_model(paths[0], model, load_strict, ignore)

            return model

        return wrapper