Skip to content
Permalink
Browse files

fix(base): fix duplicate load and init from yaml

  • Loading branch information...
hanxiao committed Jul 19, 2019
1 parent 69a486e commit 991e4425ce1d650d0b2602df8abaab85f07c9b5f
Showing with 48 additions and 32 deletions.
  1. +47 −26 gnes/base/__init__.py
  2. +1 −6 gnes/service/base.py
@@ -60,7 +60,7 @@ def _import(module_name, class_name):


class TrainableType(type):
default_property = {
default_gnes_config = {
'is_trained': False,
'batch_size': None,
'work_dir': os.environ.get('GNES_VOLUME', os.getcwd()),
@@ -77,12 +77,8 @@ def __call__(cls, *args, **kwargs):

obj = type.__call__(cls, *args, **kwargs)

# set attribute
for k, v in TrainableType.default_property.items():
if not hasattr(obj, k):
setattr(obj, k, v)
obj._set_gnes_config(**kwargs)

# do _post_init()
getattr(obj, '_post_init_wrapper', lambda *x: None)()
return obj

@@ -179,6 +175,23 @@ def post_init(self):
def pre_init(cls):
pass

def _set_gnes_config(self, **kwargs):
# set attribute
for k, v in TrainableType.default_gnes_config.items():
if k in kwargs:
v = kwargs[k]
setattr(self, k, v)

if not getattr(self, 'name', None):
_id = str(uuid.uuid4()).split('-')[0]
_name = '%s-%s' % (self.__class__.__name__, _id)
self.logger.warning(
'this object is not named ("- gnes_config: - name" is not found in YAML config), '
'i will call it as "%s". '
'However, naming the object is important especially when you need to '
'serialize/deserialize/store/load the object.' % _name)
setattr(self, 'name', _name)

@property
def dump_full_path(self):
return os.path.join(self.work_dir, '%s.bin' % self.name)
@@ -286,30 +299,38 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):

data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
constructor, node, deep=True)
cls.init_from_yaml = True

if cls.store_args_kwargs:
p = data.get('parameter', {}) # type: Dict[str, Any]
a = p.pop('args') if 'args' in p else ()
k = p.pop('kwargs') if 'kwargs' in p else {}
# maybe there are some hanging kwargs in "parameter"
tmp_a = (cls._convert_env_var(v) for v in a)
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in {**k, **p}.items()}
obj = cls(*tmp_a, **tmp_p)
else:
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()}
obj = cls(**tmp_p)

for k, v in data.get('gnes_config', {}).items():
old = getattr(obj, k, None)
setattr(obj, k, v)
if old and old != v:
obj.logger.info('gnes_config: %r is replaced from %r to %r' % (k, old, v))
dump_path = cls._get_dump_path_from_config(data)
if dump_path:
obj = cls.load(dump_path)
obj.logger.info('restore %s from %s' % (cls.__name__, dump_path))
else:
cls.init_from_yaml = True

if cls.store_args_kwargs:
p = data.get('parameter', {}) # type: Dict[str, Any]
a = p.pop('args') if 'args' in p else ()
k = p.pop('kwargs') if 'kwargs' in p else {}
# maybe there are some hanging kwargs in "parameter"
tmp_a = (cls._convert_env_var(v) for v in a)
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in {**k, **p}.items()}
obj = cls(*tmp_a, **tmp_p, **data.get('gnes_config', {}))
else:
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()}
obj = cls(**tmp_p, **data.get('gnes_config', {}))

cls.init_from_yaml = False
obj.logger.info('initialize %s from a yaml config' % cls.__name__)
cls.init_from_yaml = False

return obj, data

@staticmethod
def _get_dump_path_from_config(gnes_config: Dict):
if 'work_dir' in gnes_config and 'name' in gnes_config:
dump_path = os.path.join(gnes_config['work_dir'], '%s.bin' % gnes_config['name'])
if os.path.exists(dump_path):
return dump_path

@staticmethod
def _convert_env_var(v):
if isinstance(v, str):
@@ -320,7 +341,7 @@ def _convert_env_var(v):
@staticmethod
def _dump_instance_to_yaml(data):
# note: we only dump non-default property for the sake of clarity
p = {k: getattr(data, k) for k, v in TrainableType.default_property.items() if getattr(data, k) != v}
p = {k: getattr(data, k) for k, v in TrainableType.default_gnes_config.items() if getattr(data, k) != v}
a = {k: v for k, v in data._init_kwargs_dict.items()}
r = {}
if a:
@@ -304,14 +304,9 @@ def post_init(self):

def load_model(self, base_class: Type[TrainableBase]) -> T:
try:
model = base_class.load_yaml(self.args.yaml_path)
return base_class.load_yaml(self.args.yaml_path)
except FileNotFoundError:
raise ComponentNotLoad
try:
model = model.__class__.load(model.dump_full_path)
except FileNotFoundError:
self.logger.warning('load an empty %s from %s' % (model.__class__.__name__, self.args.yaml_path))
return model

@handler.register(NotImplementedError)
def _handler_default(self, msg: 'gnes_pb2.Message'):

0 comments on commit 991e442

Please sign in to comment.
You can’t perform that action at this time.