Skip to content
Permalink
Browse files

fix(base): fix redundant warning in pipeline encoder

  • Loading branch information...
hanxiao committed Jul 25, 2019
1 parent aadeeef commit 68c15fac3d7d32cb9f32de620bc930206b18b2f7
Showing with 19 additions and 5 deletions.
  1. +7 −5 gnes/base/__init__.py
  2. +12 −0 tests/test_load_dump_pipeline.py
@@ -33,7 +33,6 @@
T = TypeVar('T', bound='TrainableBase')



def register_all_class(cls2file_map: Dict, module_name: str):
import importlib
for k, v in cls2file_map.items():
@@ -164,7 +163,7 @@ def __init__(self, *args, **kwargs):
self._post_init_vars = set()

def _post_init_wrapper(self):
if not getattr(self, 'name', None):
if not getattr(self, 'name', None) and os.environ.get('GNES_WARN_UNNAMED_COMPONENT', '1') == '1':
_id = str(uuid.uuid4()).split('-')[0]
_name = '%s-%s' % (self.__class__.__name__, _id)
self.logger.warning(
@@ -290,9 +289,15 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
if stop_on_import_error:
raise RuntimeError('Cannot import module, pip install may required') from ex

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '0'

data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
constructor, node, deep=True)

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1'

dump_path = cls._get_dump_path_from_config(data.get('gnes_config', {}))
load_from_dump = False
if dump_path:
@@ -344,6 +349,3 @@ def _dump_instance_to_yaml(data):
if p:
r['gnes_config'] = p
return r



@@ -37,6 +37,18 @@ def test_base(self):
b = BaseEncoder.load_yaml(self.yaml_path)
self.assertTrue(b.is_trained)

def test_name_warning(self):
d1 = DummyTFEncoder()
d2 = DummyTFEncoder()
d1.name = ''
d2.name = ''
d3 = PipelineEncoder()
d3.component = lambda: [d1, d2]
d3.name = 'aa'
d3.dump_yaml()
print('there should not be any warning after this line')
d31 = BaseEncoder.load_yaml(d3.yaml_full_path)

def test_dummytf(self):
d1 = DummyTFEncoder()
self.assertEqual(d1.encode(1), 2)

0 comments on commit 68c15fa

Please sign in to comment.
You can’t perform that action at this time.