From 48d1828c9f79b4c93639a33c093b13b965465c3c Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Tue, 15 Oct 2019 12:20:30 +0800 Subject: [PATCH] refactor(flow): reorganize file structure for flow --- gnes/flow/__init__.py | 145 ++++++++++-------------------------------- gnes/flow/helper.py | 88 +++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 113 deletions(-) create mode 100644 gnes/flow/helper.py diff --git a/gnes/flow/__init__.py b/gnes/flow/__init__.py index 0bcdd601..4dceff59 100644 --- a/gnes/flow/__init__.py +++ b/gnes/flow/__init__.py @@ -1,64 +1,12 @@ import copy from collections import OrderedDict, defaultdict from contextlib import ExitStack -from functools import wraps from typing import Union, Tuple, List, Optional, Iterator +from .helper import * from ..base import TrainableBase -from ..cli.parser import set_router_parser, set_indexer_parser, \ - set_frontend_parser, set_preprocessor_parser, \ - set_encoder_parser, set_client_cli_parser -from ..client.cli import CLIClient from ..helper import set_logger -from ..service.base import SocketType, BaseService, BetterEnum, ServiceManager -from ..service.encoder import EncoderService -from ..service.frontend import FrontendService -from ..service.indexer import IndexerService -from ..service.preprocessor import PreprocessorService -from ..service.router import RouterService - - -class Service(BetterEnum): - Frontend = 0 - Encoder = 1 - Router = 2 - Indexer = 3 - Preprocessor = 4 - - -class FlowImcompleteError(ValueError): - """Exception when the flow missing some important component to run""" - - -class FlowTopologyError(ValueError): - """Exception when the topology is ambiguous""" - - -class FlowMissingNode(ValueError): - """Exception when the topology is ambiguous""" - - -class FlowBuildLevelMismatch(ValueError): - """Exception when required level is higher than the current build level""" - - -def _build_level(required_level: 'Flow.BuildLevel'): - def __build_level(func): - @wraps(func) - def arg_wrapper(self, *args, **kwargs): - if hasattr(self, '_build_level'): - if self._build_level.value >= required_level.value: - return func(self, *args, **kwargs) - else: - raise FlowBuildLevelMismatch( - 'build_level check failed for %r, required level: %s, actual level: %s' % ( - func, required_level, self._build_level)) - else: - raise AttributeError('%r has no attribute "_build_level"' % self) - - return arg_wrapper - - return __build_level +from ..service.base import SocketType, BaseService class Flow(TrainableBase): @@ -92,38 +40,6 @@ class Flow(TrainableBase): """ - _service_map = { - Service.Encoder: { - 'parser': set_encoder_parser, - 'builder': lambda x: ServiceManager(EncoderService, x), - 'cmd': 'encode'}, - Service.Router: { - 'parser': set_router_parser, - 'builder': lambda x: ServiceManager(RouterService, x), - 'cmd': 'route', - }, - Service.Indexer: { - 'parser': set_indexer_parser, - 'builder': lambda x: ServiceManager(IndexerService, x), - 'cmd': 'index' - }, - Service.Frontend: { - 'parser': set_frontend_parser, - 'builder': FrontendService, - 'cmd': 'frontend' - }, - Service.Preprocessor: { - 'parser': set_preprocessor_parser, - 'builder': lambda x: ServiceManager(PreprocessorService, x), - 'cmd': 'preprocess' - } - } - - class BuildLevel(BetterEnum): - EMPTY = 0 - GRAPH = 1 - RUNTIME = 2 - def __init__(self, with_frontend: bool = True, is_trained: bool = True, *args, **kwargs): """ Create a new Flow object. @@ -137,13 +53,13 @@ def __init__(self, with_frontend: bool = True, is_trained: bool = True, *args, * self.logger = set_logger(self.__class__.__name__) self._service_nodes = OrderedDict() self._service_edges = {} - self._service_name_counter = {k: 0 for k in Flow._service_map.keys()} + self._service_name_counter = {k: 0 for k in service_map.keys()} self._service_contexts = [] self._last_changed_service = [] self._common_kwargs = kwargs self._frontend = None self._client = None - self._build_level = Flow.BuildLevel.EMPTY + self._build_level = BuildLevel.EMPTY self._backend = None self._init_with_frontend = False self.is_trained = is_trained @@ -153,15 +69,15 @@ def __init__(self, with_frontend: bool = True, is_trained: bool = True, *args, * else: self.logger.warning('with_frontend is set to False, you need to add_frontend() by yourself') - @_build_level(BuildLevel.GRAPH) + @build_required(BuildLevel.GRAPH) def to_k8s_yaml(self) -> str: raise NotImplementedError - @_build_level(BuildLevel.GRAPH) + @build_required(BuildLevel.GRAPH) def to_shell_script(self) -> str: raise NotImplementedError - @_build_level(BuildLevel.GRAPH) + @build_required(BuildLevel.GRAPH) def to_swarm_yaml(self, image: str = 'gnes/gnes:latest-alpine') -> str: """ Generate the docker swarm YAML compose file @@ -175,7 +91,7 @@ def to_swarm_yaml(self, image: str = 'gnes/gnes:latest-alpine') -> str: 'services': {}} for k, v in self._service_nodes.items(): - defaults_kwargs, _ = Flow._service_map[v['service']]['parser']().parse_known_args( + defaults_kwargs, _ = service_map[v['service']]['parser']().parse_known_args( ['--yaml_path', 'TrainableBase']) non_default_kwargs = {k: v for k, v in vars(v['parsed_args']).items() if getattr(defaults_kwargs, k) != v} if not isinstance(non_default_kwargs.get('yaml_path', ''), str): @@ -188,7 +104,7 @@ def to_swarm_yaml(self, image: str = 'gnes/gnes:latest-alpine') -> str: swarm_yml['services'][k] = { 'image': v['kwargs'].get('image', image), 'command': '%s %s' % ( - Flow._service_map[v['service']]['cmd'], + service_map[v['service']]['cmd'], ' '.join(['--%s %s' % (k, v) for k, v in non_default_kwargs.items()])) } if num_replicas and num_replicas > 1: @@ -254,7 +170,7 @@ def to_python_code(self, indent: int = 4) -> str: return '\n'.join(py_code) - @_build_level(BuildLevel.GRAPH) + @build_required(BuildLevel.GRAPH) def to_mermaid(self, left_right: bool = True) -> str: """ Output the mermaid graph for visualization @@ -342,7 +258,7 @@ def to_mermaid(self, left_right: bool = True) -> str: return mermaid_str - @_build_level(BuildLevel.GRAPH) + @build_required(BuildLevel.GRAPH) def to_url(self, **kwargs) -> str: """ Rendering the current flow as a url points to a SVG, it needs internet connection @@ -355,7 +271,7 @@ def to_url(self, **kwargs) -> str: encoded_str = base64.b64encode(bytes(mermaid_str, 'utf-8')).decode('utf-8') return 'https://mermaidjs.github.io/mermaid-live-editor/#/view/%s' % encoded_str - @_build_level(BuildLevel.GRAPH) + @build_required(BuildLevel.GRAPH) def to_jpg(self, path: str = 'flow.jpg', **kwargs) -> None: """ Rendering the current flow as a jpg image, this will call :py:meth:`to_mermaid` and it needs internet connection @@ -405,8 +321,11 @@ def query(self, bytes_gen: Iterator[bytes] = None, **kwargs): """ self._call_client(bytes_gen, mode='query', **kwargs) - @_build_level(BuildLevel.RUNTIME) + @build_required(BuildLevel.RUNTIME) def _call_client(self, bytes_gen: Iterator[bytes] = None, **kwargs): + from ..cli.parser import set_client_cli_parser + from ..client.cli import CLIClient + args, p_args, unk_args = self._get_parsed_args(self, set_client_cli_parser, kwargs) p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host @@ -458,7 +377,7 @@ def set_last_service(self, name: str, copy_flow: bool = True) -> 'Flow': # graph is now changed so we need to # reset the build level to the lowest - op_flow._build_level = Flow.BuildLevel.EMPTY + op_flow._build_level = BuildLevel.EMPTY return op_flow @@ -525,7 +444,7 @@ def set(self, name: str, recv_from: Union[str, Tuple[str], List[str], 'Service'] if not clear_old_attr: node['kwargs'].update(kwargs) kwargs = node['kwargs'] - args, p_args, unk_args = op_flow._get_parsed_args(op_flow, Flow._service_map[service]['parser'], kwargs) + args, p_args, unk_args = op_flow._get_parsed_args(op_flow, service_map[service]['parser'], kwargs) node.update({ 'args': args, 'parsed_args': p_args, @@ -538,7 +457,7 @@ def set(self, name: str, recv_from: Union[str, Tuple[str], List[str], 'Service'] # graph is now changed so we need to # reset the build level to the lowest - op_flow._build_level = Flow.BuildLevel.EMPTY + op_flow._build_level = BuildLevel.EMPTY return op_flow @@ -572,7 +491,7 @@ def remove(self, name: str = None, copy_flow: bool = True) -> 'Flow': # graph is now changed so we need to # reset the build level to the lowest - op_flow._build_level = Flow.BuildLevel.EMPTY + op_flow._build_level = BuildLevel.EMPTY return op_flow @@ -603,8 +522,8 @@ def add(self, service: Union['Service', str], if isinstance(service, str): service = Service.from_string(service) - if service not in Flow._service_map: - raise ValueError('service: %s is not supported, should be one of %s' % (service, Flow._service_map.keys())) + if service not in service_map: + raise ValueError('service: %s is not supported, should be one of %s' % (service, service_map.keys())) if name in op_flow._service_nodes: raise FlowTopologyError('name: %s is used in this Flow already!' % name) @@ -622,7 +541,7 @@ def add(self, service: Union['Service', str], recv_from = op_flow._parse_service_endpoints(op_flow, name, recv_from, connect_to_last_service=True) send_to = op_flow._parse_service_endpoints(op_flow, name, send_to, connect_to_last_service=False) - args, p_args, unk_args = op_flow._get_parsed_args(op_flow, Flow._service_map[service]['parser'], kwargs) + args, p_args, unk_args = op_flow._get_parsed_args(op_flow, service_map[service]['parser'], kwargs) op_flow._service_nodes[name] = { 'service': service, @@ -644,7 +563,7 @@ def add(self, service: Union['Service', str], # graph is now changed so we need to # reset the build level to the lowest - op_flow._build_level = Flow.BuildLevel.EMPTY + op_flow._build_level = BuildLevel.EMPTY return op_flow @@ -703,7 +622,7 @@ def _build_graph(self, copy_flow: bool) -> 'Flow': op_flow._service_edges.clear() if not op_flow._frontend: - raise FlowImcompleteError('frontend does not exist, you may need to add_frontend()') + raise FlowIncompleteError('frontend does not exist, you may need to add_frontend()') if not op_flow._last_changed_service or not op_flow._service_nodes: raise FlowTopologyError('flow is empty?') @@ -780,7 +699,7 @@ def _build_graph(self, copy_flow: bool) -> 'Flow': 'i can not determine the socket type' % ( len(edges_with_same_start), start_node, len(edges_with_same_end), end_node)) - op_flow._build_level = Flow.BuildLevel.GRAPH + op_flow._build_level = BuildLevel.GRAPH return op_flow def build(self, backend: Optional[str] = 'thread', copy_flow: bool = False, *args, **kwargs) -> 'Flow': @@ -805,8 +724,8 @@ def build(self, backend: Optional[str] = 'thread', copy_flow: bool = False, *arg # for thread and process backend which runs locally, host_in and host_out should not be set p_args.host_in = BaseService.default_host p_args.host_out = BaseService.default_host - op_flow._service_contexts.append((Flow._service_map[v['service']]['builder'], p_args)) - op_flow._build_level = Flow.BuildLevel.RUNTIME + op_flow._service_contexts.append((service_map[v['service']]['builder'], p_args)) + op_flow._build_level = BuildLevel.RUNTIME else: raise NotImplementedError('backend=%s is not supported yet' % backend) @@ -816,7 +735,7 @@ def __call__(self, *args, **kwargs): return self.build(*args, **kwargs) def __enter__(self): - if self._build_level.value < Flow.BuildLevel.RUNTIME.value: + if self._build_level.value < BuildLevel.RUNTIME.value: self.logger.warning( 'current build_level=%s, lower than required. ' 'build the flow now via build() with default parameters' % self._build_level) @@ -831,7 +750,7 @@ def __enter__(self): def close(self): if hasattr(self, '_service_stack'): self._service_stack.close() - self._build_level = Flow.BuildLevel.EMPTY + self._build_level = BuildLevel.EMPTY self.logger.critical( 'flow is closed and all resources should be released already, current build level is %s' % self._build_level) @@ -844,12 +763,12 @@ def __eq__(self, other): :return: """ - if self._build_level.value < Flow.BuildLevel.GRAPH.value: + if self._build_level.value < BuildLevel.GRAPH.value: a = self.build(backend=None, copy_flow=True) else: a = self - if other._build_level.value < Flow.BuildLevel.GRAPH.value: + if other._build_level.value < BuildLevel.GRAPH.value: b = other.build(backend=None, copy_flow=True) else: b = other diff --git a/gnes/flow/helper.py b/gnes/flow/helper.py new file mode 100644 index 00000000..3bef25e2 --- /dev/null +++ b/gnes/flow/helper.py @@ -0,0 +1,88 @@ +from functools import wraps + +from ..cli.parser import set_router_parser, set_indexer_parser, \ + set_frontend_parser, set_preprocessor_parser, \ + set_encoder_parser +from ..service.base import BetterEnum, ServiceManager +from ..service.encoder import EncoderService +from ..service.frontend import FrontendService +from ..service.indexer import IndexerService +from ..service.preprocessor import PreprocessorService +from ..service.router import RouterService + + +class BuildLevel(BetterEnum): + EMPTY = 0 + GRAPH = 1 + RUNTIME = 2 + + +class Service(BetterEnum): + Frontend = 0 + Encoder = 1 + Router = 2 + Indexer = 3 + Preprocessor = 4 + + +class FlowIncompleteError(ValueError): + """Exception when the flow missing some important component to run""" + + +class FlowTopologyError(ValueError): + """Exception when the topology is ambiguous""" + + +class FlowMissingNode(ValueError): + """Exception when the topology is ambiguous""" + + +class FlowBuildLevelMismatch(ValueError): + """Exception when required level is higher than the current build level""" + + +def build_required(required_level: 'BuildLevel'): + def __build_level(func): + @wraps(func) + def arg_wrapper(self, *args, **kwargs): + if hasattr(self, '_build_level'): + if self._build_level.value >= required_level.value: + return func(self, *args, **kwargs) + else: + raise FlowBuildLevelMismatch( + 'build_level check failed for %r, required level: %s, actual level: %s' % ( + func, required_level, self._build_level)) + else: + raise AttributeError('%r has no attribute "_build_level"' % self) + + return arg_wrapper + + return __build_level + + +service_map = { + Service.Encoder: { + 'parser': set_encoder_parser, + 'builder': lambda x: ServiceManager(EncoderService, x), + 'cmd': 'encode'}, + Service.Router: { + 'parser': set_router_parser, + 'builder': lambda x: ServiceManager(RouterService, x), + 'cmd': 'route', + }, + Service.Indexer: { + 'parser': set_indexer_parser, + 'builder': lambda x: ServiceManager(IndexerService, x), + 'cmd': 'index' + }, + Service.Frontend: { + 'parser': set_frontend_parser, + 'builder': FrontendService, + 'cmd': 'frontend' + }, + Service.Preprocessor: { + 'parser': set_preprocessor_parser, + 'builder': lambda x: ServiceManager(PreprocessorService, x), + 'cmd': 'preprocess' + } +}