Skip to content

Commit

Permalink
add LazyConfig
Browse files Browse the repository at this point in the history
Reviewed By: rbgirshick

Differential Revision: D27666493

fbshipit-source-id: e48a9e2d55a3d93babaa90c53f625a664f933c51
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Apr 10, 2021
1 parent 38ff141 commit 65faeb4
Show file tree
Hide file tree
Showing 11 changed files with 369 additions and 32 deletions.
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ exclude = build
per-file-ignores =
**/__init__.py:F401,F403,E402
**/configs/**.py:F401,E402
**/tests/config/**.py:F401,E402
tests/config/**.py:F401,E402
4 changes: 3 additions & 1 deletion detectron2/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .compat import downgrade_config, upgrade_config
from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable
from .instantiate import instantiate, LazyCall
from .instantiate import instantiate
from .lazy import LazyCall, LazyConfig

__all__ = [
"CfgNode",
Expand All @@ -13,6 +14,7 @@
"configurable",
"instantiate",
"LazyCall",
"LazyConfig",
]


Expand Down
28 changes: 0 additions & 28 deletions detectron2/config/instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
from collections import abc
from typing import Any
from omegaconf import DictConfig

from detectron2.utils.registry import _convert_target_to_string, locate

Expand Down Expand Up @@ -81,30 +80,3 @@ def instantiate(cfg):
logger.error(f"Error when instantiating {cls_name}!")
raise
return cfg # return as-is if don't know what to do


class LazyCall:
"""
Wrap a callable so that when it's called, the call will not be execued,
but returns a dict that describes the call.
LazyCall object has to be called with only keyword arguments. Positional
arguments are not yet supported.
Examples:
::
layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
layer_cfg.out_channels = 64
layer = instantiate(layer_cfg)
"""

def __init__(self, target):
if not (callable(target) or isinstance(target, (str, abc.Mapping))):
raise TypeError(
"target of LazyCall must be a callable or defines a callable! Got {target}"
)
self._target = target

def __call__(self, **kwargs):
kwargs["_target_"] = self._target
return DictConfig(content=kwargs, flags={"allow_objects": True})
290 changes: 290 additions & 0 deletions detectron2/config/lazy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import ast
import builtins
import importlib
import inspect
import logging
import os
import uuid
from collections import abc
from contextlib import contextmanager
from copy import deepcopy
from typing import List, Tuple, Union
import cloudpickle
import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf

from detectron2.utils.file_io import PathManager
from detectron2.utils.registry import _convert_target_to_string

__all__ = ["LazyCall", "LazyConfig"]


class LazyCall:
"""
Wrap a callable so that when it's called, the call will not be execued,
but returns a dict that describes the call.
LazyCall object has to be called with only keyword arguments. Positional
arguments are not yet supported.
Examples:
::
from detectron2.config import instantiate, LazyCall
layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
layer_cfg.out_channels = 64 # can edit it afterwards
layer = instantiate(layer_cfg)
"""

def __init__(self, target):
if not (callable(target) or isinstance(target, (str, abc.Mapping))):
raise TypeError(
"target of LazyCall must be a callable or defines a callable! Got {target}"
)
self._target = target

def __call__(self, **kwargs):
kwargs["_target_"] = self._target
return DictConfig(content=kwargs, flags={"allow_objects": True})


def _visit_dict_config(cfg, func):
"""
Apply func recursively to all DictConfig in cfg.
"""
if isinstance(cfg, DictConfig):
func(cfg)
for v in cfg.values():
_visit_dict_config(v, func)
elif isinstance(cfg, ListConfig):
for v in cfg:
_visit_dict_config(v, func)


def _validate_py_syntax(filename):
# see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
with PathManager.open(filename, "r") as f:
content = f.read()
try:
ast.parse(content)
except SyntaxError as e:
raise SyntaxError(f"Config file {filename} has syntax error!") from e


def _cast_to_config(obj):
# if given a dict, return DictConfig instead
if isinstance(obj, dict):
return DictConfig(obj, flags={"allow_objects": True})
return obj


_CFG_PACKAGE_NAME = "detectron2._cfg_loader"
"""
A namespace to put all imported config into.
"""


def _random_package_name(filename):
# generate a random package name when loading config files
return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename)


@contextmanager
def _patch_import():
"""
Enhance relative import statements in config files, so that they:
1. locate files purely based on relative location, regardless of packages.
e.g. you can import file without having __init__
2. do not cache modules globally; modifications of module states has no side effect
3. support other storage system through PathManager
4. imported dict are turned into omegaconf.DictConfig automatically
"""
old_import = builtins.__import__

def find_relative_file(original_file, relative_import_path, level):
cur_file = os.path.dirname(original_file)
for _ in range(level - 1):
cur_file = os.path.dirname(cur_file)
cur_name = relative_import_path.lstrip(".")
for part in cur_name.split("."):
cur_file = os.path.join(cur_file, part)
# NOTE: directory import is not handled. Because then it's unclear
# if such import should produce python module or DictConfig. This can
# be discussed further if needed.
if not cur_file.endswith(".py"):
cur_file += ".py"
if not PathManager.isfile(cur_file):
raise ImportError(
f"Cannot import name {relative_import_path} from "
f"{original_file}: {cur_file} has to exist."
)
return cur_file

def new_import(name, globals=None, locals=None, fromlist=(), level=0):
if (
# Only deal with relative imports inside config files
level != 0
and globals is not None
and globals.get("__package__", "").startswith(_CFG_PACKAGE_NAME)
):
cur_file = find_relative_file(globals["__file__"], name, level)
_validate_py_syntax(cur_file)
spec = importlib.machinery.ModuleSpec(
_random_package_name(cur_file), None, origin=cur_file
)
module = importlib.util.module_from_spec(spec)
module.__file__ = cur_file
with PathManager.open(cur_file) as f:
content = f.read()
exec(compile(content, cur_file, "exec"), module.__dict__)
for name in fromlist: # turn imported dict into DictConfig automatically
val = _cast_to_config(module.__dict__[name])
module.__dict__[name] = val
return module
return old_import(name, globals, locals, fromlist=fromlist, level=level)

builtins.__import__ = new_import
yield new_import
builtins.__import__ = old_import


class LazyConfig:
"""
Provid methods to save, load, and overrides an omegaconf config object
which may contain definition of lazily-constructed objects.
"""

@staticmethod
def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
"""
Similar to :meth:`load()`, but load path relative to the caller's
source file.
This has the same functionality as a relative import, except that this method
accepts filename as a string, so more characters are allowed in the filename.
"""
caller_frame = inspect.stack()[1]
caller_fname = caller_frame[0].f_code.co_filename
assert caller_fname != "<string>", "load_rel Unable to find caller"
caller_dir = os.path.dirname(caller_fname)
filename = os.path.join(caller_dir, filename)
return LazyConfig.load(filename, keys)

@staticmethod
def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
"""
Load a config file.
Args:
filename: absolute path or relative path w.r.t. the current working directory
keys: keys to load and return. If not given, return all keys
(whose values are config objects) in a dict.
"""
has_keys = keys is not None
filename = filename.replace("/./", "/") # redundant
if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
raise ValueError(f"Config file {filename} has to be a python or yaml file.")
if filename.endswith(".py"):
_validate_py_syntax(filename)

with _patch_import():
# Record the filename
module_namespace = {
"__file__": filename,
"__package__": _random_package_name(filename),
}
with PathManager.open(filename) as f:
content = f.read()
# Compile first with filename to:
# 1. make filename appears in stacktrace
# 2. make load_rel able to find its parent's (possibly remote) location
exec(compile(content, filename, "exec"), module_namespace)

ret = module_namespace
else:
with PathManager.open(filename) as f:
obj = yaml.unsafe_load(f)
ret = OmegaConf.create(obj, flags={"allow_objects": True})

if has_keys:
if isinstance(keys, str):
return _cast_to_config(ret[keys])
else:
return tuple(_cast_to_config(ret[a]) for a in keys)
else:
if filename.endswith(".py"):
# when not specified, only load those that are config objects
ret = DictConfig(
{
name: _cast_to_config(value)
for name, value in ret.items()
if isinstance(value, (DictConfig, ListConfig, dict))
and not name.startswith("_")
},
flags={"allow_objects": True},
)
return ret

@staticmethod
def save(cfg, filename: str):
"""
Args:
cfg: an omegaconf config object
filename: yaml file name to save the config file
"""
logger = logging.getLogger(__name__)
try:
cfg = deepcopy(cfg)
except Exception:
pass
else:
# if it's deep-copyable, then...
def _replace_type_by_name(x):
if "_target_" in x and callable(x._target_):
try:
x._target_ = _convert_target_to_string(x._target_)
except AttributeError:
pass

# not necessary, but makes yaml looks nicer
_visit_dict_config(cfg, _replace_type_by_name)

try:
OmegaConf.save(cfg, filename)
except Exception:
logger.exception("Unable to serialize the config to yaml. Error:")
new_filename = filename + ".pkl"
try:
# retry by pickle
with PathManager.open(new_filename, "wb") as f:
cloudpickle.dump(cfg, f)
logger.warning(f"Config saved using cloudpickle at {new_filename} ...")
except Exception:
pass

@staticmethod
def apply_overrides(cfg, overrides: List[str]):
"""
In-place override contents of cfg.
Args:
cfg: an omegaconf config object
overrides: list of strings in the format of "a=b" to override configs.
See https://hydra.cc/docs/next/advanced/override_grammar/basic/
for syntax.
Returns:
the cfg object
"""
from hydra.core.override_parser.overrides_parser import OverridesParser

parser = OverridesParser.create()
overrides = parser.parse_overrides(overrides)
for o in overrides:
key = o.key_or_group
value = o.value()
# TODO seems nice to support this
assert not o.is_delete(), "deletion is not a supported override"
OmegaConf.update(cfg, key, value, merge=True)
return cfg
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ multi_line_output=3
include_trailing_comma=True
known_standard_library=numpy,setuptools,mock
skip=./datasets,docs
skip_glob=*/__init__.py,**/configs/**
skip_glob=*/__init__.py,**/configs/**,tests/config/**
known_myself=detectron2
known_third_party=fvcore,matplotlib,cv2,torch,torchvision,PIL,pycocotools,yacs,termcolor,cityscapesscripts,tabulate,tqdm,scipy,lvis,psutil,pkg_resources,caffe2,onnx,panopticapi,black,isort,av,iopath,omegaconf,hydra,yaml,pydoc,submitit,cloudpickle
no_lines_before=STDLIB,THIRDPARTY
Expand Down
3 changes: 3 additions & 0 deletions tests/config/dir1/dir1_a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates.
dir1a_str = "base_a_1"
dir1a_dict = {"a": 1, "b": 2}
11 changes: 11 additions & 0 deletions tests/config/dir1/dir1_b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from detectron2.config import LazyConfig

# equivalent to relative import
dir1a_str, dir1a_dict = LazyConfig.load_rel("dir1_a.py", ("dir1a_str", "dir1a_dict"))

dir1b_str = dir1a_str + "_from_b"
dir1b_dict = dir1a_dict

# Every import is a reload: not modified by other config files
assert dir1a_dict.a == 1
14 changes: 14 additions & 0 deletions tests/config/root_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from itertools import count

from detectron2.config import LazyCall as L

from .dir1.dir1_a import dir1a_dict, dir1a_str

dir1a_dict.a = "modified"

# modification above won't affect future imports
from .dir1.dir1_b import dir1b_dict, dir1b_str


lazyobj = L(count)(x=dir1a_str, y=dir1b_str)
Loading

0 comments on commit 65faeb4

Please sign in to comment.