Skip to content

Commit

Permalink
[Fix] Avoid creating a new logger in PretrainedInit (#791)
Browse files Browse the repository at this point in the history
* use current logger

* remove get_current_instance

* remove logger parameter at weight_init

* remove elif branch
  • Loading branch information
xiexinch committed Dec 12, 2022
1 parent d876d4e commit 504fdc3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
11 changes: 5 additions & 6 deletions mmengine/model/weight_init.py
Expand Up @@ -8,7 +8,7 @@
import torch.nn as nn
from torch import Tensor

from mmengine.logging import MMLogger, print_log
from mmengine.logging import print_log
from mmengine.registry import WEIGHT_INITIALIZERS, build_from_cfg


Expand Down Expand Up @@ -481,22 +481,21 @@ def __call__(self, module):
from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix,
load_checkpoint,
load_state_dict)
logger = MMLogger.get_instance('mmengine')
if self.prefix is None:
print_log(f'load model from: {self.checkpoint}', logger=logger)
print_log(f'load model from: {self.checkpoint}', logger='current')
load_checkpoint(
module,
self.checkpoint,
map_location=self.map_location,
strict=False,
logger=logger)
logger='current')
else:
print_log(
f'load {self.prefix} in model from: {self.checkpoint}',
logger=logger)
logger='current')
state_dict = _load_checkpoint_with_prefix(
self.prefix, self.checkpoint, map_location=self.map_location)
load_state_dict(module, state_dict, strict=False, logger=logger)
load_state_dict(module, state_dict, strict=False, logger='current')

if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())
Expand Down
5 changes: 2 additions & 3 deletions mmengine/runner/checkpoint.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import io
import logging
import os
import os.path as osp
import pkgutil
Expand Down Expand Up @@ -106,10 +107,8 @@ def load(module, prefix=''):
err_msg = '\n'.join(err_msg)
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warning(err_msg)
else:
print(err_msg)
print_log(err_msg, logger=logger, level=logging.WARNING)


def get_torchvision_models():
Expand Down

0 comments on commit 504fdc3

Please sign in to comment.