Skip to content

Commit

Permalink
Clean up pacakge specific logic from config repository
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Dec 15, 2020
1 parent b4a6e38 commit ca20052
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 50 deletions.
Expand Up @@ -63,7 +63,7 @@ def load_config(
if name not in self.configs:
raise ConfigLoadError("Config not found : " + config_path)

res_header = {"package": None}
res_header: Dict[str, Optional[str]] = {"package": None}
if name in self.headers:
header = self.headers[name]
res_header["package"] = header["package"] if "package" in header else None
Expand Down
24 changes: 21 additions & 3 deletions hydra/_internal/config_loader_impl.py
Expand Up @@ -2,6 +2,7 @@
"""
Configuration loader
"""
import copy
import os
import re
import sys
Expand Down Expand Up @@ -351,10 +352,9 @@ def _load_single_config(
self, default: ResultDefault, repo: IConfigRepository, is_primary: bool
) -> Tuple[ConfigResult, LoadTrace]:
config_path = default.config_path
package = default.package

assert config_path is not None
ret = repo.load_config(config_path=config_path, package_override=package)
ret = repo.load_config(config_path=config_path)
assert ret is not None

if not isinstance(ret.config, DictConfig):
Expand Down Expand Up @@ -385,7 +385,6 @@ def _load_single_config(
if "hydra" in ret.config and not hydra_config_group:
hydra = ret.config.pop("hydra")

schema = repo._embed_result_config(schema, package)
merged = OmegaConf.merge(schema.config, ret.config)
assert isinstance(merged, DictConfig)

Expand All @@ -408,8 +407,27 @@ def _load_single_config(
search_path=ret.path,
provider=ret.provider,
)

ret = self._embed_result_config(ret, default.package)

return ret, trace

@staticmethod
def _embed_result_config(
ret: ConfigResult, package_override: Optional[str]
) -> ConfigResult:
package = ret.header["package"]
if package_override is not None:
package = package_override

if package is not None and package != "":
cfg = OmegaConf.create()
OmegaConf.update(cfg, package, ret.config, merge=False)
ret = copy.copy(ret)
ret.config = cfg

return ret

def list_groups(self, parent_name: str) -> List[str]:
return self.get_group_options(
group_name=parent_name, results_filter=ObjectType.GROUP
Expand Down
45 changes: 7 additions & 38 deletions hydra/_internal/config_repository.py
@@ -1,5 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
Expand Down Expand Up @@ -29,9 +28,7 @@ def get_schema_source(self) -> ConfigSource:
...

@abstractmethod
def load_config(
self, config_path: str, package_override: Optional[str] = None
) -> Optional[ConfigResult]:
def load_config(self, config_path: str) -> Optional[ConfigResult]:
...

@abstractmethod
Expand All @@ -52,28 +49,6 @@ def get_group_options(
def get_sources(self) -> List[ConfigSource]:
...

@staticmethod
def _embed_config(node: Container, package: str) -> Container:
if package == "_global_":
package = ""

if package is not None and package != "":
cfg = OmegaConf.create()
OmegaConf.update(cfg, package, node, merge=False)
else:
cfg = OmegaConf.structured(node)
return cfg

@staticmethod
def _embed_result_config(ret: ConfigResult, package_override: str) -> ConfigResult:
package = ret.header["package"]
if package_override is not None:
package = package_override

ret = copy.copy(ret)
ret.config = ConfigRepository._embed_config(ret.config, package)
return ret


class ConfigRepository(IConfigRepository):

Expand All @@ -98,9 +73,7 @@ def get_schema_source(self) -> ConfigSource:
)
return source

def load_config(
self, config_path: str, package_override: Optional[str] = None
) -> Optional[ConfigResult]:
def load_config(self, config_path: str) -> Optional[ConfigResult]:
source = self._find_object_source(
config_path=config_path, object_type=ObjectType.CONFIG
)
Expand All @@ -117,8 +90,8 @@ def load_config(
raw_defaults = self._extract_defaults_list(config_path, ret.config)
ret.defaults_list = self._create_defaults_list(config_path, raw_defaults)

# TODO: push to a higher level?
ret = self._embed_result_config(ret, package_override)
# # TODO: push to a higher level?
# ret = self._embed_result_config(ret, package_override)

return ret

Expand Down Expand Up @@ -305,16 +278,12 @@ def __init__(self, delegeate: IConfigRepository):
def get_schema_source(self) -> ConfigSource:
return self.delegate.get_schema_source()

def load_config(
self, config_path: str, package_override: Optional[str] = None
) -> Optional[ConfigResult]:
cache_key = f"config_path={config_path},package_override={package_override}"
def load_config(self, config_path: str) -> Optional[ConfigResult]:
cache_key = f"config_path={config_path}"
if cache_key in self.cache:
return self.cache[cache_key]
else:
ret = self.delegate.load_config(
config_path=config_path, package_override=package_override
)
ret = self.delegate.load_config(config_path=config_path)
self.cache[cache_key] = ret
return ret

Expand Down
2 changes: 2 additions & 0 deletions hydra/_internal/defaults_list.py
Expand Up @@ -436,6 +436,8 @@ def _create_defaults_tree_impl(

assert loaded is not None
defaults_list = copy.deepcopy(loaded.defaults_list)
if defaults_list is None:
defaults_list = []

if is_primary_config:
for gd in overrides.append_group_defaults:
Expand Down
2 changes: 1 addition & 1 deletion hydra/core/default_element.py
Expand Up @@ -106,7 +106,7 @@ def is_deleted(self) -> bool:
else:
return False

def set_package_header(self, package_header: str) -> None:
def set_package_header(self, package_header: Optional[str]) -> None:
assert self.__dict__["package_header"] is None
if (
package_header is None
Expand Down
8 changes: 4 additions & 4 deletions hydra/plugins/config_source.py
Expand Up @@ -18,8 +18,8 @@ class ConfigResult:
provider: str
path: str
config: Container
header: Dict[str, str]
defaults_list: List[InputDefault] = None
header: Dict[str, Optional[str]]
defaults_list: Optional[List[InputDefault]] = None
is_schema_source: bool = False


Expand Down Expand Up @@ -121,8 +121,8 @@ def _normalize_file_name(filename: str) -> str:
return filename

@staticmethod
def _get_header_dict(config_text: str) -> Dict[str, str]:
res = {}
def _get_header_dict(config_text: str) -> Dict[str, Optional[str]]:
res: Dict[str, Optional[str]] = {}
for line in config_text.splitlines():
line = line.strip()
if len(line) == 0:
Expand Down
2 changes: 1 addition & 1 deletion hydra/test_utils/config_source_common_tests.py
Expand Up @@ -251,7 +251,7 @@ def test_source_load_config(
path: str,
config_path: str,
expected_defaults_list: List[InputDefault],
expected_package,
expected_package: Any,
expected_config: Any,
recwarn: Any,
) -> None:
Expand Down
5 changes: 3 additions & 2 deletions tests/test_config_repository.py
Expand Up @@ -12,7 +12,7 @@
ImportlibResourcesConfigSource,
)
from hydra._internal.core_plugins.structured_config_source import StructuredConfigSource
from hydra.core.default_element import InputDefault, GroupDefault
from hydra.core.default_element import GroupDefault, InputDefault
from hydra.core.plugins import Plugins
from hydra.core.singleton import Singleton
from hydra.plugins.config_source import ConfigSource
Expand Down Expand Up @@ -126,6 +126,7 @@ def test_config_repository_list(
config_search_path = create_config_search_path(path)
repo = ConfigRepository(config_search_path=config_search_path)
ret = repo.load_config(config_path)
assert ret is not None
assert ret.defaults_list == expected


Expand Down Expand Up @@ -173,7 +174,7 @@ def test_get_config_header(cfg_text: str, expected: Any, sep: str) -> None:
ConfigSource._get_header_dict(cfg_text)


def test_restore_singleton_state_hack():
def test_restore_singleton_state_hack() -> None:
"""
This is a hack that allow us to undo changes to the ConfigStore.
During this test, the config store is being modified in Python imports.
Expand Down

0 comments on commit ca20052

Please sign in to comment.