Skip to content
Permalink
Browse files

fix(base): fix gnes_config mixed in kwargs

  • Loading branch information...
hanxiao committed Jul 25, 2019
1 parent 68c15fa commit c52c2cc69239a89b5aeba473ef8a6c0fc48ea744
Showing with 16 additions and 9 deletions.
  1. +12 −7 gnes/base/__init__.py
  2. +4 −2 tests/test_load_dump_pipeline.py
@@ -77,12 +77,17 @@ def __call__(cls, *args, **kwargs):
# do _preload_package
getattr(cls, '_pre_init', lambda *x: None)()

if 'gnes_config' in kwargs:
gnes_config = kwargs.pop('gnes_config')
else:
gnes_config = {}

obj = type.__call__(cls, *args, **kwargs)

# set attribute
for k, v in TrainableType.default_gnes_config.items():
if k in kwargs:
v = kwargs[k]
if k in gnes_config:
v = gnes_config[k]
if not hasattr(obj, k):
setattr(obj, k, v)

@@ -295,9 +300,6 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
constructor, node, deep=True)

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1'

dump_path = cls._get_dump_path_from_config(data.get('gnes_config', {}))
load_from_dump = False
if dump_path:
@@ -314,14 +316,17 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
# maybe there are some hanging kwargs in "parameter"
tmp_a = (cls._convert_env_var(v) for v in a)
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in {**k, **p}.items()}
obj = cls(*tmp_a, **tmp_p, **data.get('gnes_config', {}))
obj = cls(*tmp_a, **tmp_p, gnes_config=data.get('gnes_config', {}))
else:
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()}
obj = cls(**tmp_p, **data.get('gnes_config', {}))
obj = cls(**tmp_p, gnes_config=data.get('gnes_config', {}))

obj.logger.info('initialize %s from a yaml config' % cls.__name__)
cls.init_from_yaml = False

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1'

return obj, data, load_from_dump

@staticmethod
@@ -44,10 +44,12 @@ def test_name_warning(self):
d2.name = ''
d3 = PipelineEncoder()
d3.component = lambda: [d1, d2]
d3.name = 'aa'
d3.name = 'dummy-pipeline'
d3.work_dir = './'
d3.dump()
d3.dump_yaml()
print('there should not be any warning after this line')
d31 = BaseEncoder.load_yaml(d3.yaml_full_path)
BaseEncoder.load_yaml(d3.yaml_full_path)

def test_dummytf(self):
d1 = DummyTFEncoder()

0 comments on commit c52c2cc

Please sign in to comment.
You can’t perform that action at this time.