
refactor(base): remove dump path and reorganize work dir

hanxiao committed Jul 18, 2019
1 parent 2da1942 commit a4e153d7bddad52625a9ad5527024e0f2671f160
@@ -26,7 +26,7 @@

import ruamel.yaml.constructor

from ..helper import set_logger, profiling, yaml, parse_arg, touch_dir, FileLock
from ..helper import set_logger, profiling, yaml, parse_arg

__all__ = ['TrainableBase']

@@ -63,6 +63,8 @@ class TrainableType(type):
default_property = {
'is_trained': False,
'batch_size': None,
'work_dir': os.environ.get('GNES_VOLUME', os.getcwd()),
'name': None
}

def __new__(cls, *args, **kwargs):
@@ -149,14 +151,9 @@ def arg_wrapper(self, *args, **kwargs):

class TrainableBase(metaclass=TrainableType):
store_args_kwargs = False
lock_work_dir = False

def __init__(self, *args, **kwargs):
self.is_trained = False
self._obj_id = str(uuid.uuid4()).split('-')[0]
self._obj_pickle_name = '%s%s.bin' % (self.__class__.__name__, self._obj_id)
self._obj_yaml_name = '%s%s.yml' % (self.__class__.__name__, self._obj_id)
self._work_dir = os.getcwd()
self.verbose = 'verbose' in kwargs and kwargs['verbose']
self.logger = set_logger(self.__class__.__name__, self.verbose)
self._post_init_vars = set()
@@ -175,26 +172,11 @@ def pre_init(cls):

@property
def pickle_full_path(self):
return os.path.join(self.work_dir, self._obj_pickle_name)
return os.path.join(self.work_dir, '%s.bin' % self.name)

@property
def yaml_full_path(self):
return os.path.join(self.work_dir, self._obj_yaml_name)

@property
def work_dir(self):
return self._work_dir

@work_dir.setter
def work_dir(self, value: str):
touch_dir(value)
if self.lock_work_dir:
self._file_lock = FileLock(os.path.join(value, "LOCK"))
if self._file_lock.acquire() is None:
raise RuntimeError(
"this model\'s work_dir %r is used and locked by another model" %
value)
self._work_dir = value
return os.path.join(self.work_dir, '%s.yml' % self.name)

def __getstate__(self):
d = dict(self.__dict__)
@@ -208,9 +190,6 @@ def __getstate__(self):
def __setstate__(self, d):
self.__dict__.update(d)
self.logger = set_logger(self.__class__.__name__, self.verbose)
if self.lock_work_dir:
# trigger the lock again
self.work_dir = self._work_dir
try:
self._post_init_wrapper()
except ImportError:
@@ -224,16 +203,16 @@ def train(self, *args, **kwargs):
def dump(self, filename: str = None) -> None:
f = filename or self.pickle_full_path
if not f:
f = tempfile.NamedTemporaryFile('w', delete=False, dir=os.environ.get('NES_TEMP_DIR', None)).name
f = tempfile.NamedTemporaryFile('w', delete=False, dir=os.environ.get('GNES_VOLUME', None)).name
with open(f, 'wb') as fp:
pickle.dump(self, fp)
self.logger.info('model is pickled to %s' % f)
self.logger.info('model is stored to %s' % f)

@profiling
def dump_yaml(self, filename: str = None) -> None:
f = filename or self.yaml_full_path
if not f:
f = tempfile.NamedTemporaryFile('w', delete=False, dir=os.environ.get('NES_TEMP_DIR', None)).name
f = tempfile.NamedTemporaryFile('w', delete=False, dir=os.environ.get('GNES_VOLUME', None)).name
with open(f, 'w') as fp:
yaml.dump(self, fp)
self.logger.info('model\'s yaml config is dump to %s' % f)
@@ -248,16 +227,15 @@ def load_yaml(cls: Type[T], filename: Union[str, TextIO]) -> T:
with filename:
return yaml.load(filename)

@staticmethod
@profiling
def load(filename: str) -> T:
if not filename: raise FileNotFoundError
with open(filename, 'rb') as fp:
def load(self, filename: str = None) -> T:
f = filename or self.pickle_full_path
if not f: raise FileNotFoundError
with open(f, 'rb') as fp:
return pickle.load(fp)

def close(self):
if self.lock_work_dir:
self._file_lock.release()
pass

def __enter__(self):
return self
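
With the work-dir lock and the per-object random file names removed, persistence is keyed on the object's name inside work_dir (which now defaults to $GNES_VOLUME or the current directory). A minimal sketch of the new behaviour, assuming a hypothetical ToyEncoder subclass and throw-away paths:

from gnes.base import TrainableBase  # import path assumed from the relative imports above

class ToyEncoder(TrainableBase):  # hypothetical component, for illustration only
    pass

enc = ToyEncoder()
enc.name = 'toy_encoder'  # file names are now derived from .name
enc.work_dir = '/tmp'     # plain attribute now; no more touch_dir/FileLock
enc.dump()                # pickles the object to /tmp/toy_encoder.bin
enc.dump_yaml()           # writes the YAML config to /tmp/toy_encoder.yml
restored = enc.load()     # load() is now an instance method; defaults to pickle_full_path
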
@@ -313,11 +291,20 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()}
obj = cls(**tmp_p)

for k, v in data.get('property', {}).items():
for k, v in data.get('gnes_config', {}).items():
old = getattr(obj, k, None)
setattr(obj, k, v)
if old and old != v:
obj.logger.info('property: %r is replaced from %r to %r' % (k, old, v))
obj.logger.info('gnes_config: %r is replaced from %r to %r' % (k, old, v))

if not getattr(obj, 'name', None):
_id = str(uuid.uuid4()).split('-')[0]
_name = obj.__class__.__name__ + _id
obj.logger.warning(
'this object is not named ("- gnes_config: - name" is not found), i will call it as "%s". '
'However, naming the object is important especially when you need to '
'serialize/deserialize/store/load the object.' % _name)
setattr(obj, 'name', _name)

cls.init_from_yaml = False

@@ -339,5 +326,5 @@ def _dump_instance_to_yaml(data):
if a:
r['parameter'] = a
if p:
r['property'] = p
r['gnes_config'] = p
return r
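
On the YAML side, the top-level 'property' section is renamed to 'gnes_config', and naming the component there keeps the derived .bin/.yml file names stable (otherwise a random name is generated with a warning). A minimal sketch with a hypothetical component; the !ToyEncoder tag registration is assumed to be handled by the metaclass as for any other component:

from io import StringIO
from gnes.base import TrainableBase  # import path assumed

class ToyEncoder(TrainableBase):  # hypothetical component, YAML tag !ToyEncoder
    pass

conf = StringIO('''\
!ToyEncoder
gnes_config:            # formerly the 'property' section
  name: my_toy_encoder  # omit it and an auto-generated name is used, with a warning
  is_trained: true
''')
enc = ToyEncoder.load_yaml(conf)
print(enc.pickle_full_path)  # <work_dir>/my_toy_encoder.bin
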
@@ -56,6 +56,8 @@ def set_composer_parser(parser=None):
help='output path of the docker-compose file for Docker Swarm')
parser.add_argument('--k8s_path', type=argparse.FileType('w', encoding='utf8'),
help='output path of the docker-compose file for Docker Swarm')
parser.add_argument('--graph_path', type=argparse.FileType('w', encoding='utf8'),
help='output path of the mermaid graph file')
parser.add_argument('--shell_log_redirect', type=str,
help='the file path for redirecting shell output. '
'when not given, the output will be flushed to stdout')
@@ -90,10 +92,10 @@ def set_service_parser(parser=None):
parser.add_argument('--timeout', type=int, default=-1,
help='timeout (ms) of all communication, -1 for waiting forever')
parser.add_argument('--dump_interval', type=int, default=5,
help='dump the service every n seconds')
help='serialize the service to a file every n seconds')
parser.add_argument('--read_only', action='store_true', default=False,
help='do not allow the service to modify the model, '
'dump_path and dump_interval will be ignored')
'dump_interval will be ignored')
return parser


@@ -118,8 +120,6 @@ def set_loadable_service_parser(parser=None):
from ..service.base import SocketType
set_service_parser(parser)

parser.add_argument('--dump_path', type=str, default=None,
help='binary dump of the service')
parser.add_argument('--yaml_path', type=argparse.FileType('r'),
default=pkg_resources.resource_stream(
'gnes', '/'.join(('resources', 'config', 'encoder', 'default.yml'))),
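
With --dump_path gone, persisting a loadable service is driven entirely by --dump_interval (and suppressed by --read_only); the model's own name and work_dir decide where the binary ends up. A quick sanity check of the remaining flags, assuming the parser module lives at gnes.cli.parser:

from gnes.cli.parser import set_loadable_service_parser  # assumed import path

# prints all registered flags: --dump_path no longer appears,
# while --dump_interval and --read_only are still listed
set_loadable_service_parser().print_help()
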
@@ -29,7 +29,6 @@ class Layer:
default_values = {
'name': None,
'yaml_path': None,
'dump_path': None,
'replicas': 1,
'income': 'pull'
}
@@ -106,9 +105,9 @@ def check_fields(self, comp: Dict) -> bool:
if comp['name'] not in self.comp2file:
raise AttributeError(
'a component must be one of: %s, but given %s' % (', '.join(self.comp2file.keys()), comp['name']))
if 'yaml_path' not in comp and 'dump_path' not in comp:
if 'yaml_path' not in comp:
self.logger.warning(
'found empty "yaml_path" and "dump_path", '
'found empty "yaml_path", '
'i will use a default config and would probably result in an empty model')
for k in comp:
if k not in self.Layer.default_values:
@@ -138,7 +137,7 @@ def build_layers(self) -> List['YamlGraph.Layer']:
def build_dockerswarm(all_layers: List['YamlGraph.Layer'], docker_img: str = 'gnes/gnes:latest',
volumes: Dict = None) -> str:
swarm_lines = {'version': '3.4', 'services': {}}
taboo = {'name', 'replicas', 'yaml_path', 'dump_path'}
taboo = {'name', 'replicas', 'yaml_path'}
config_dict = {}
network_dict = {'gnes-network': {'driver': 'overlay', 'attachable': True}}
for l_idx, layer in enumerate(all_layers):
@@ -308,6 +307,7 @@ def std_or_print(f, content):
'timestamp': time.strftime("%a, %d %b %Y %H:%M:%S"),
'version': __version__
}
std_or_print(self.args.graph_path, cmds['mermaid'])
std_or_print(self.args.shell_path, cmds['shell'])
std_or_print(self.args.swarm_path, cmds['docker'])
std_or_print(self.args.k8s_path, cmds['k8s'])
@@ -397,8 +397,6 @@ def rule7():
router_layer = YamlGraph.Layer(layer_id=self._num_layer)
self._num_layer += 1
for c in layer.components:
income = self.Layer.get_value(c, 'income')

r = CommentedMap({'name': 'Router',
'yaml_path': None,
'socket_in': str(SocketType.SUB_CONNECT),
@@ -25,7 +25,7 @@


class Word2VecEncoder(BaseTextEncoder):
def __init__(self, model_dir,
def __init__(self, model_dir: str,
skiprows: int = 1,
batch_size: int = 64,
dimension: int = 300,
@@ -65,4 +65,3 @@ def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
pooled_data.append(pooling_simple(_layer_data, self.pooling_strategy))

return np.array(pooled_data).astype(np.float32)

@@ -24,7 +24,6 @@


class BaseIndexer(TrainableBase):
internal_index_path = 'int.indexer.bin' # this is used when pickle dump is not enough for storing all info

def add(self, keys: Any, docs: Any, weights: List[float], *args, **kwargs):
pass
@@ -27,7 +27,7 @@
class LVDBIndexer(BaseTextIndexer):

def __init__(self, data_path: str, keep_na_doc: bool = True, *args, **kwargs):
super().__init__()
super().__init__(*args, **kwargs)
self.data_path = data_path
self.keep_na_doc = keep_na_doc
self._NOT_FOUND = None
@@ -8,13 +8,11 @@


class AnnoyIndexer(BaseVectorIndexer):
lock_work_dir = True

def __init__(self, num_dim: int, data_path: str, metric: str = 'angular', n_trees=10, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_dim = num_dim
self.work_dir = data_path
self.indexer_file_path = os.path.join(self.work_dir, self.internal_index_path)
self.data_path = data_path
self.metric = metric
self.n_trees = n_trees
self._key_info_indexer = ListKeyIndexer()
@@ -23,7 +21,7 @@ def post_init(self):
from annoy import AnnoyIndex
self._index = AnnoyIndex(self.num_dim, self.metric)
try:
self._index.load(self.indexer_file_path)
self._index.load(self.data_path)
except:
self.logger.warning('fail to load model from %s, will create an empty one' % self.data_path)

@@ -58,5 +56,5 @@ def size(self):

def __getstate__(self):
d = super().__getstate__()
self._index.save(self.indexer_file_path)
self._index.save(self.data_path)
return d
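
The vector indexers follow the same reorganization: instead of locking a work_dir and writing an internal 'int.indexer.bin', each indexer now reads and writes its binary index directly at data_path, while dump() still pickles the wrapper object under work_dir. A minimal sketch for AnnoyIndexer (module path, dimension and file path are assumptions):

from gnes.indexer.annoy import AnnoyIndexer  # assumed module path

indexer = AnnoyIndexer(num_dim=8, data_path='/tmp/toy_annoy.idx')
# post_init() tries to load an existing Annoy index from data_path and falls back
# to an empty one; __getstate__ (triggered e.g. by dump()) saves it back to the
# same data_path
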
@@ -25,7 +25,6 @@


class BIndexer(BaseVectorIndexer):
lock_work_dir = True

def __init__(self,
num_bytes: int = None,
@@ -41,17 +40,15 @@ def __init__(self,
self.insert_iterations = insert_iterations
self.query_iterations = query_iterations

self.work_dir = data_path
self.indexer_bin_path = os.path.join(self.work_dir,
self.internal_index_path)
self.data_path = data_path
self._weight_norm = 2 ** 16 - 1

def post_init(self):
self.bindexer = IndexCore(self.num_bytes, 4, self.ef,
self.insert_iterations,
self.query_iterations)
if os.path.exists(self.indexer_bin_path):
self.bindexer.load(self.indexer_bin_path)
if os.path.exists(self.data_path):
self.bindexer.load(self.data_path)

def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args,
**kwargs):
@@ -115,6 +112,6 @@ def query(self,
return result

def __getstate__(self):
self.bindexer.save(self.indexer_bin_path)
self.bindexer.save(self.data_path)
d = super().__getstate__()
return d
@@ -26,22 +26,20 @@


class FaissIndexer(BaseVectorIndexer):
lock_work_dir = True

def __init__(self, num_dim: int, index_key: str, data_path: str, *args, **kwargs):
super().__init__()
self.work_dir = data_path
self.indexer_file_path = os.path.join(self.work_dir, self.internal_index_path)
super().__init__(*args, **kwargs)
self.data_path = data_path
self.num_dim = num_dim
self.index_key = index_key
self._key_info_indexer = ListKeyIndexer()

def post_init(self):
import faiss
try:
self._faiss_index = faiss.read_index(self.indexer_file_path)
self._faiss_index = faiss.read_index(self.data_path)
except RuntimeError:
self.logger.warning('fail to load model from %s, will init an empty one' % self.indexer_file_path)
self.logger.warning('fail to load model from %s, will init an empty one' % self.data_path)
self._faiss_index = faiss.index_factory(self.num_dim, self.index_key)

def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
@@ -76,5 +74,5 @@ def size(self):
def __getstate__(self):
import faiss
d = super().__getstate__()
faiss.write_index(self._faiss_index, self.indexer_file_path)
faiss.write_index(self._faiss_index, self.data_path)
return d
@@ -25,7 +25,6 @@


class HBIndexer(BaseVectorIndexer):
lock_work_dir = True

def __init__(self,
num_clusters: int = 100,
@@ -38,18 +37,15 @@ def __init__(self,
self.n_bytes = num_bytes
self.n_clusters = num_clusters
self.n_idx = n_idx

self.work_dir = data_path
self.indexer_bin_path = os.path.join(self.work_dir,
self.internal_index_path)
self.data_path = data_path
self._weight_norm = 2 ** 16 - 1
if self.n_idx <= 0:
raise ValueError('There should be at least 1 clustering slot')

def post_init(self):
self.hbindexer = IndexCore(self.n_clusters, self.n_bytes, self.n_idx)
if os.path.exists(self.indexer_bin_path):
self.hbindexer.load(self.indexer_bin_path)
if os.path.exists(self.data_path):
self.hbindexer.load(self.data_path)

def add(self, keys: List[Tuple[int, int]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
if len(vectors) != len(keys):
@@ -93,6 +89,6 @@ def query(self,
return [sorted(ret.items(), key=lambda x: -x[1])[:top_k] for ret in result]

def __getstate__(self):
self.hbindexer.save(self.indexer_bin_path)
self.hbindexer.save(self.data_path)
d = super().__getstate__()
return d
