Skip to content

Commit

Permalink
Clean up pacakge specific logic from config repository
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Dec 15, 2020
1 parent b4a6e38 commit a148fd8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 41 deletions.
24 changes: 21 additions & 3 deletions hydra/_internal/config_loader_impl.py
Expand Up @@ -2,6 +2,7 @@
"""
Configuration loader
"""
import copy
import os
import re
import sys
Expand Down Expand Up @@ -351,10 +352,9 @@ def _load_single_config(
self, default: ResultDefault, repo: IConfigRepository, is_primary: bool
) -> Tuple[ConfigResult, LoadTrace]:
config_path = default.config_path
package = default.package

assert config_path is not None
ret = repo.load_config(config_path=config_path, package_override=package)
ret = repo.load_config(config_path=config_path)
assert ret is not None

if not isinstance(ret.config, DictConfig):
Expand Down Expand Up @@ -385,7 +385,6 @@ def _load_single_config(
if "hydra" in ret.config and not hydra_config_group:
hydra = ret.config.pop("hydra")

schema = repo._embed_result_config(schema, package)
merged = OmegaConf.merge(schema.config, ret.config)
assert isinstance(merged, DictConfig)

Expand All @@ -408,8 +407,27 @@ def _load_single_config(
search_path=ret.path,
provider=ret.provider,
)

ret = self._embed_result_config(ret, default.package)

return ret, trace

@staticmethod
def _embed_result_config(
ret: ConfigResult, package_override: Optional[str]
) -> ConfigResult:
package = ret.header["package"]
if package_override is not None:
package = package_override

if package is not None and package != "":
cfg = OmegaConf.create()
OmegaConf.update(cfg, package, ret.config, merge=False)
ret = copy.copy(ret)
ret.config = cfg

return ret

def list_groups(self, parent_name: str) -> List[str]:
return self.get_group_options(
group_name=parent_name, results_filter=ObjectType.GROUP
Expand Down
45 changes: 7 additions & 38 deletions hydra/_internal/config_repository.py
@@ -1,5 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
Expand Down Expand Up @@ -29,9 +28,7 @@ def get_schema_source(self) -> ConfigSource:
...

@abstractmethod
def load_config(
self, config_path: str, package_override: Optional[str] = None
) -> Optional[ConfigResult]:
def load_config(self, config_path: str) -> Optional[ConfigResult]:
...

@abstractmethod
Expand All @@ -52,28 +49,6 @@ def get_group_options(
def get_sources(self) -> List[ConfigSource]:
...

@staticmethod
def _embed_config(node: Container, package: str) -> Container:
if package == "_global_":
package = ""

if package is not None and package != "":
cfg = OmegaConf.create()
OmegaConf.update(cfg, package, node, merge=False)
else:
cfg = OmegaConf.structured(node)
return cfg

@staticmethod
def _embed_result_config(ret: ConfigResult, package_override: str) -> ConfigResult:
package = ret.header["package"]
if package_override is not None:
package = package_override

ret = copy.copy(ret)
ret.config = ConfigRepository._embed_config(ret.config, package)
return ret


class ConfigRepository(IConfigRepository):

Expand All @@ -98,9 +73,7 @@ def get_schema_source(self) -> ConfigSource:
)
return source

def load_config(
self, config_path: str, package_override: Optional[str] = None
) -> Optional[ConfigResult]:
def load_config(self, config_path: str) -> Optional[ConfigResult]:
source = self._find_object_source(
config_path=config_path, object_type=ObjectType.CONFIG
)
Expand All @@ -117,8 +90,8 @@ def load_config(
raw_defaults = self._extract_defaults_list(config_path, ret.config)
ret.defaults_list = self._create_defaults_list(config_path, raw_defaults)

# TODO: push to a higher level?
ret = self._embed_result_config(ret, package_override)
# # TODO: push to a higher level?
# ret = self._embed_result_config(ret, package_override)

return ret

Expand Down Expand Up @@ -305,16 +278,12 @@ def __init__(self, delegeate: IConfigRepository):
def get_schema_source(self) -> ConfigSource:
return self.delegate.get_schema_source()

def load_config(
self, config_path: str, package_override: Optional[str] = None
) -> Optional[ConfigResult]:
cache_key = f"config_path={config_path},package_override={package_override}"
def load_config(self, config_path: str) -> Optional[ConfigResult]:
cache_key = f"config_path={config_path}"
if cache_key in self.cache:
return self.cache[cache_key]
else:
ret = self.delegate.load_config(
config_path=config_path, package_override=package_override
)
ret = self.delegate.load_config(config_path=config_path)
self.cache[cache_key] = ret
return ret

Expand Down

0 comments on commit a148fd8

Please sign in to comment.