Skip to content

Commit

Permalink
[Enhance] Add support of TorchVision's Model Registration API (#793)
Browse files Browse the repository at this point in the history
* enhance get_torchvision_model

* remove mmcv
  • Loading branch information
HAOCHENYE committed Dec 11, 2022
1 parent be0bc3a commit d876d4e
Show file tree
Hide file tree
Showing 2 changed files with 464 additions and 9 deletions.
62 changes: 53 additions & 9 deletions mmengine/runner/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mmengine.fileio import load as load_file
from mmengine.logging import print_log
from mmengine.model import BaseTTAModel, is_model_wrapper
from mmengine.utils import deprecated_function, mkdir_or_exist
from mmengine.utils import deprecated_function, digit_version, mkdir_or_exist
from mmengine.utils.dl_utils import load_url

# `MMENGINE_HOME` is the highest priority directory to save checkpoints
Expand Down Expand Up @@ -113,14 +113,58 @@ def load(module, prefix=''):


def get_torchvision_models():
model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
if digit_version(torchvision.__version__) < digit_version('0.13.0a0'):
model_urls = dict()
# When the version of torchvision is lower than 0.13, the model url is
# not declared in `torchvision.model.__init__.py`, so we need to
# iterate through `torchvision.models.__path__` to get the url for each
# model.
for _, name, ispkg in pkgutil.walk_packages(
torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
else:
# Since torchvision bumps to v0.13, the weight loading logic,
# model keys and model urls have been changed. Here the URLs of old
# version is loaded to avoid breaking back compatibility. If the
# torchvision version>=0.13.0, new URLs will be added. Users can get
# the resnet50 checkpoint by setting 'resnet50.imagent1k_v1',
# 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config.
json_path = osp.join(mmengine.__path__[0], 'hub/torchvision_0.12.json')
model_urls = mmengine.load(json_path)
if digit_version(torchvision.__version__) < digit_version('0.14.0a0'):
weights_list = [
cls for cls_name, cls in torchvision.models.__dict__.items()
if cls_name.endswith('_Weights')
]
else:
weights_list = [
torchvision.models.get_model_weights(model)
for model in torchvision.models.list_models(torchvision.models)
]

for cls in weights_list:
# The name of torchvision model weights classes ends with
# `_Weights` such as `ResNet18_Weights`. However, some model weight
# classes, such as `MNASNet0_75_Weights` does not have any urls in
# torchvision 0.13.0 and cannot be iterated. Here we simply check
# `DEFAULT` attribute to ensure the class is not empty.
if not hasattr(cls, 'DEFAULT'):
continue
# Since `cls.DEFAULT` can not be accessed by iterating cls, we set
# default urls explicitly.
cls_name = cls.__name__
cls_key = cls_name.replace('_Weights', '').lower()
model_urls[f'{cls_key}.default'] = cls.DEFAULT.url
for weight_enum in cls:
cls_key = cls_name.replace('_Weights', '').lower()
cls_key = f'{cls_key}.{weight_enum.name.lower()}'
model_urls[cls_key] = weight_enum.url

return model_urls


Expand Down

0 comments on commit d876d4e

Please sign in to comment.