-
Notifications
You must be signed in to change notification settings - Fork 226
/
config.py
159 lines (138 loc) · 5.95 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# pyre-strict
import logging
import os
from typing import Any, Callable, Dict, IO, List, Union
import yaml
from iopath.common.file_io import g_pathmgr
from yacs.config import CfgNode as _CfgNode
# Reserved key in a YAML config file. Its value is either a single base config
# path or a list of them; CfgNode.load_yaml_with_base loads those bases and
# overrides them with the file's own keys. It is rejected outside of files
# (merge_from_other_cfg / merge_from_list assert it is absent).
BASE_KEY = "_BASE_"
class CfgNode(_CfgNode):
    """
    Our own extended version of :class:`yacs.config.CfgNode`.
    It contains the following extra features:

    1. The :meth:`merge_from_file` method supports the "_BASE_" key,
       which allows the new CfgNode to inherit all the attributes from the
       base configuration file(s).
    2. Keys that start with "COMPUTED_" are treated as insertion-only
       "computed" attributes. They can be inserted regardless of whether
       the CfgNode is frozen or not.
    3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
       expressions in config. See examples in
       https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
       Note that this may lead to arbitrary code execution: you must not
       load a config file from untrusted sources before manually inspecting
       the content of the file.
    """

    @classmethod
    def _open_cfg(cls, filename: str) -> Union[IO[str], IO[bytes]]:
        """
        Defines how a config file is opened. May be overridden to support
        different file schemas.

        Args:
            filename: path or URI of the config file, passed to `g_pathmgr`.

        Returns:
            An open file object; callers use it as a context manager.
        """
        return g_pathmgr.open(filename, "r")

    @classmethod
    def load_yaml_with_base(
        cls, filename: str, allow_unsafe: bool = False
    ) -> Dict[str, Any]:
        """
        Just like `yaml.load(open(filename))`, but inherit attributes from its
        `_BASE_`.

        The `_BASE_` entry may be a single path or a list of paths. Each base
        is loaded recursively, the bases are merged in order (later ones
        override earlier ones), and finally the keys of this file override
        the merged bases.

        Args:
            filename (str): the file name of the current config. Relative
                `_BASE_` paths are resolved against its directory.
            allow_unsafe (bool): whether to allow loading the config file with
                `yaml.unsafe_load`.

        Returns:
            (dict): the loaded yaml. An empty file yields an empty dict.

        Raises:
            yaml.constructor.ConstructorError: if the file uses tags that
                `yaml.safe_load` rejects and `allow_unsafe` is False.
        """
        try:
            # Load from the stream (not a pre-read string) so PyYAML error
            # marks still report the file name.
            with cls._open_cfg(filename) as f:
                cfg = yaml.safe_load(f)
        except yaml.constructor.ConstructorError:
            if not allow_unsafe:
                raise
            logger = logging.getLogger(__name__)
            logger.warning(
                "Loading config {} with yaml.unsafe_load. Your machine may "
                "be at risk if the file contains malicious content.".format(
                    filename
                )
            )
            # The first handle is already closed by its `with`; reopen to
            # parse from the start with the unsafe loader.
            with cls._open_cfg(filename) as f:
                cfg = yaml.unsafe_load(f)

        if cfg is None:
            # An empty YAML document parses to None; without this the
            # `BASE_KEY in cfg` test below would raise TypeError.
            cfg = {}

        def merge_a_into_b(a: Dict[str, Any], b: Dict[str, Any]) -> None:
            # merge dict a into dict b. values in a will overwrite b;
            # dicts present on both sides are merged recursively.
            for k, v in a.items():
                if isinstance(v, dict) and k in b:
                    assert isinstance(
                        b[k], dict
                    ), "Cannot inherit key '{}' from base!".format(k)
                    merge_a_into_b(v, b[k])
                else:
                    b[k] = v

        def _load_with_base(base_cfg_file: str) -> Dict[str, Any]:
            # Resolve one `_BASE_` entry to a loadable path, then load it
            # (recursively honoring its own `_BASE_`).
            if base_cfg_file.startswith("~"):
                base_cfg_file = os.path.expanduser(base_cfg_file)
            if not base_cfg_file.startswith(("/", "https://", "http://")):
                # the path to base cfg is relative to the config file itself.
                base_cfg_file = os.path.join(os.path.dirname(filename), base_cfg_file)
            return cls.load_yaml_with_base(base_cfg_file, allow_unsafe=allow_unsafe)

        if BASE_KEY not in cfg:
            return cfg

        # Normalize the single-path form to a one-element list so both forms
        # share one merge loop.
        base_cfg_files = cfg.pop(BASE_KEY)
        if not isinstance(base_cfg_files, list):
            base_cfg_files = [base_cfg_files]
        base_cfg: Dict[str, Any] = {}
        for base_cfg_file in base_cfg_files:
            merge_a_into_b(_load_with_base(base_cfg_file), base_cfg)
        merge_a_into_b(cfg, base_cfg)
        return base_cfg

    def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False) -> None:
        """
        Merge configs from a given yaml file.

        Args:
            cfg_filename: the file name of the yaml config.
            allow_unsafe: whether to allow loading the config file with
                `yaml.unsafe_load`.
        """
        loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
        # Wrap the plain dict in this node type before merging.
        loaded_cfg = type(self)(loaded_cfg)
        self.merge_from_other_cfg(loaded_cfg)

    # Forward the following calls to base, but with a check on the BASE_KEY.
    def merge_from_other_cfg(self, cfg_other: "CfgNode") -> Callable[[], None]:
        """
        Args:
            cfg_other (CfgNode): configs to merge from.

        Raises:
            AssertionError: if `cfg_other` contains the reserved `_BASE_` key,
                which may only appear in config files.
        """
        assert (
            BASE_KEY not in cfg_other
        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
        return super().merge_from_other_cfg(cfg_other)

    def merge_from_list(self, cfg_list: List[str]) -> Callable[[], None]:
        """
        Args:
            cfg_list (list): list of configs to merge from, laid out as
                alternating keys and values (keys at even indices).

        Raises:
            AssertionError: if any key is the reserved `_BASE_` key, which
                may only appear in config files.
        """
        keys = set(cfg_list[0::2])
        assert (
            BASE_KEY not in keys
        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
        return super().merge_from_list(cfg_list)

    def __setattr__(self, name: str, val: Any) -> None:  # pyre-ignore
        """
        Set an attribute.

        Names starting with "COMPUTED_" are stored directly as items,
        bypassing the base class's `__setattr__`. They are insertion-only:
        setting an equal value again is a no-op, while a different value
        raises KeyError. All other names go through the base class.
        """
        if name.startswith("COMPUTED_"):
            if name in self:
                old_val = self[name]
                if old_val == val:
                    return
                raise KeyError(
                    "Computed attributed '{}' already exists "
                    "with a different value! old={}, new={}.".format(name, old_val, val)
                )
            self[name] = val
        else:
            super().__setattr__(name, val)