Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
refactor(flow): reorganize file structure for flow
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Oct 15, 2019
1 parent 8a60c26 commit 48d1828
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 113 deletions.
145 changes: 32 additions & 113 deletions gnes/flow/__init__.py
@@ -1,64 +1,12 @@
import copy
from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from functools import wraps
from typing import Union, Tuple, List, Optional, Iterator

from .helper import *
from ..base import TrainableBase
from ..cli.parser import set_router_parser, set_indexer_parser, \
set_frontend_parser, set_preprocessor_parser, \
set_encoder_parser, set_client_cli_parser
from ..client.cli import CLIClient
from ..helper import set_logger
from ..service.base import SocketType, BaseService, BetterEnum, ServiceManager
from ..service.encoder import EncoderService
from ..service.frontend import FrontendService
from ..service.indexer import IndexerService
from ..service.preprocessor import PreprocessorService
from ..service.router import RouterService


class Service(BetterEnum):
Frontend = 0
Encoder = 1
Router = 2
Indexer = 3
Preprocessor = 4


class FlowImcompleteError(ValueError):
"""Exception when the flow missing some important component to run"""


class FlowTopologyError(ValueError):
"""Exception when the topology is ambiguous"""


class FlowMissingNode(ValueError):
"""Exception when the topology is ambiguous"""


class FlowBuildLevelMismatch(ValueError):
"""Exception when required level is higher than the current build level"""


def _build_level(required_level: 'Flow.BuildLevel'):
def __build_level(func):
@wraps(func)
def arg_wrapper(self, *args, **kwargs):
if hasattr(self, '_build_level'):
if self._build_level.value >= required_level.value:
return func(self, *args, **kwargs)
else:
raise FlowBuildLevelMismatch(
'build_level check failed for %r, required level: %s, actual level: %s' % (
func, required_level, self._build_level))
else:
raise AttributeError('%r has no attribute "_build_level"' % self)

return arg_wrapper

return __build_level
from ..service.base import SocketType, BaseService


class Flow(TrainableBase):
Expand Down Expand Up @@ -92,38 +40,6 @@ class Flow(TrainableBase):
"""

_service_map = {
Service.Encoder: {
'parser': set_encoder_parser,
'builder': lambda x: ServiceManager(EncoderService, x),
'cmd': 'encode'},
Service.Router: {
'parser': set_router_parser,
'builder': lambda x: ServiceManager(RouterService, x),
'cmd': 'route',
},
Service.Indexer: {
'parser': set_indexer_parser,
'builder': lambda x: ServiceManager(IndexerService, x),
'cmd': 'index'
},
Service.Frontend: {
'parser': set_frontend_parser,
'builder': FrontendService,
'cmd': 'frontend'
},
Service.Preprocessor: {
'parser': set_preprocessor_parser,
'builder': lambda x: ServiceManager(PreprocessorService, x),
'cmd': 'preprocess'
}
}

class BuildLevel(BetterEnum):
EMPTY = 0
GRAPH = 1
RUNTIME = 2

def __init__(self, with_frontend: bool = True, is_trained: bool = True, *args, **kwargs):
"""
Create a new Flow object.
Expand All @@ -137,13 +53,13 @@ def __init__(self, with_frontend: bool = True, is_trained: bool = True, *args, *
self.logger = set_logger(self.__class__.__name__)
self._service_nodes = OrderedDict()
self._service_edges = {}
self._service_name_counter = {k: 0 for k in Flow._service_map.keys()}
self._service_name_counter = {k: 0 for k in service_map.keys()}
self._service_contexts = []
self._last_changed_service = []
self._common_kwargs = kwargs
self._frontend = None
self._client = None
self._build_level = Flow.BuildLevel.EMPTY
self._build_level = BuildLevel.EMPTY
self._backend = None
self._init_with_frontend = False
self.is_trained = is_trained
Expand All @@ -153,15 +69,15 @@ def __init__(self, with_frontend: bool = True, is_trained: bool = True, *args, *
else:
self.logger.warning('with_frontend is set to False, you need to add_frontend() by yourself')

@_build_level(BuildLevel.GRAPH)
@build_required(BuildLevel.GRAPH)
def to_k8s_yaml(self) -> str:
raise NotImplementedError

@_build_level(BuildLevel.GRAPH)
@build_required(BuildLevel.GRAPH)
def to_shell_script(self) -> str:
raise NotImplementedError

@_build_level(BuildLevel.GRAPH)
@build_required(BuildLevel.GRAPH)
def to_swarm_yaml(self, image: str = 'gnes/gnes:latest-alpine') -> str:
"""
Generate the docker swarm YAML compose file
Expand All @@ -175,7 +91,7 @@ def to_swarm_yaml(self, image: str = 'gnes/gnes:latest-alpine') -> str:
'services': {}}

for k, v in self._service_nodes.items():
defaults_kwargs, _ = Flow._service_map[v['service']]['parser']().parse_known_args(
defaults_kwargs, _ = service_map[v['service']]['parser']().parse_known_args(
['--yaml_path', 'TrainableBase'])
non_default_kwargs = {k: v for k, v in vars(v['parsed_args']).items() if getattr(defaults_kwargs, k) != v}
if not isinstance(non_default_kwargs.get('yaml_path', ''), str):
Expand All @@ -188,7 +104,7 @@ def to_swarm_yaml(self, image: str = 'gnes/gnes:latest-alpine') -> str:
swarm_yml['services'][k] = {
'image': v['kwargs'].get('image', image),
'command': '%s %s' % (
Flow._service_map[v['service']]['cmd'],
service_map[v['service']]['cmd'],
' '.join(['--%s %s' % (k, v) for k, v in non_default_kwargs.items()]))
}
if num_replicas and num_replicas > 1:
Expand Down Expand Up @@ -254,7 +170,7 @@ def to_python_code(self, indent: int = 4) -> str:

return '\n'.join(py_code)

@_build_level(BuildLevel.GRAPH)
@build_required(BuildLevel.GRAPH)
def to_mermaid(self, left_right: bool = True) -> str:
"""
Output the mermaid graph for visualization
Expand Down Expand Up @@ -342,7 +258,7 @@ def to_mermaid(self, left_right: bool = True) -> str:

return mermaid_str

@_build_level(BuildLevel.GRAPH)
@build_required(BuildLevel.GRAPH)
def to_url(self, **kwargs) -> str:
"""
Rendering the current flow as a url points to a SVG, it needs internet connection
Expand All @@ -355,7 +271,7 @@ def to_url(self, **kwargs) -> str:
encoded_str = base64.b64encode(bytes(mermaid_str, 'utf-8')).decode('utf-8')
return 'https://mermaidjs.github.io/mermaid-live-editor/#/view/%s' % encoded_str

@_build_level(BuildLevel.GRAPH)
@build_required(BuildLevel.GRAPH)
def to_jpg(self, path: str = 'flow.jpg', **kwargs) -> None:
"""
Rendering the current flow as a jpg image, this will call :py:meth:`to_mermaid` and it needs internet connection
Expand Down Expand Up @@ -405,8 +321,11 @@ def query(self, bytes_gen: Iterator[bytes] = None, **kwargs):
"""
self._call_client(bytes_gen, mode='query', **kwargs)

@_build_level(BuildLevel.RUNTIME)
@build_required(BuildLevel.RUNTIME)
def _call_client(self, bytes_gen: Iterator[bytes] = None, **kwargs):
from ..cli.parser import set_client_cli_parser
from ..client.cli import CLIClient

args, p_args, unk_args = self._get_parsed_args(self, set_client_cli_parser, kwargs)
p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port
p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host
Expand Down Expand Up @@ -458,7 +377,7 @@ def set_last_service(self, name: str, copy_flow: bool = True) -> 'Flow':

# graph is now changed so we need to
# reset the build level to the lowest
op_flow._build_level = Flow.BuildLevel.EMPTY
op_flow._build_level = BuildLevel.EMPTY

return op_flow

Expand Down Expand Up @@ -525,7 +444,7 @@ def set(self, name: str, recv_from: Union[str, Tuple[str], List[str], 'Service']
if not clear_old_attr:
node['kwargs'].update(kwargs)
kwargs = node['kwargs']
args, p_args, unk_args = op_flow._get_parsed_args(op_flow, Flow._service_map[service]['parser'], kwargs)
args, p_args, unk_args = op_flow._get_parsed_args(op_flow, service_map[service]['parser'], kwargs)
node.update({
'args': args,
'parsed_args': p_args,
Expand All @@ -538,7 +457,7 @@ def set(self, name: str, recv_from: Union[str, Tuple[str], List[str], 'Service']

# graph is now changed so we need to
# reset the build level to the lowest
op_flow._build_level = Flow.BuildLevel.EMPTY
op_flow._build_level = BuildLevel.EMPTY

return op_flow

Expand Down Expand Up @@ -572,7 +491,7 @@ def remove(self, name: str = None, copy_flow: bool = True) -> 'Flow':

# graph is now changed so we need to
# reset the build level to the lowest
op_flow._build_level = Flow.BuildLevel.EMPTY
op_flow._build_level = BuildLevel.EMPTY

return op_flow

Expand Down Expand Up @@ -603,8 +522,8 @@ def add(self, service: Union['Service', str],
if isinstance(service, str):
service = Service.from_string(service)

if service not in Flow._service_map:
raise ValueError('service: %s is not supported, should be one of %s' % (service, Flow._service_map.keys()))
if service not in service_map:
raise ValueError('service: %s is not supported, should be one of %s' % (service, service_map.keys()))

if name in op_flow._service_nodes:
raise FlowTopologyError('name: %s is used in this Flow already!' % name)
Expand All @@ -622,7 +541,7 @@ def add(self, service: Union['Service', str],
recv_from = op_flow._parse_service_endpoints(op_flow, name, recv_from, connect_to_last_service=True)
send_to = op_flow._parse_service_endpoints(op_flow, name, send_to, connect_to_last_service=False)

args, p_args, unk_args = op_flow._get_parsed_args(op_flow, Flow._service_map[service]['parser'], kwargs)
args, p_args, unk_args = op_flow._get_parsed_args(op_flow, service_map[service]['parser'], kwargs)

op_flow._service_nodes[name] = {
'service': service,
Expand All @@ -644,7 +563,7 @@ def add(self, service: Union['Service', str],

# graph is now changed so we need to
# reset the build level to the lowest
op_flow._build_level = Flow.BuildLevel.EMPTY
op_flow._build_level = BuildLevel.EMPTY

return op_flow

Expand Down Expand Up @@ -703,7 +622,7 @@ def _build_graph(self, copy_flow: bool) -> 'Flow':
op_flow._service_edges.clear()

if not op_flow._frontend:
raise FlowImcompleteError('frontend does not exist, you may need to add_frontend()')
raise FlowIncompleteError('frontend does not exist, you may need to add_frontend()')

if not op_flow._last_changed_service or not op_flow._service_nodes:
raise FlowTopologyError('flow is empty?')
Expand Down Expand Up @@ -780,7 +699,7 @@ def _build_graph(self, copy_flow: bool) -> 'Flow':
'i can not determine the socket type' % (
len(edges_with_same_start), start_node, len(edges_with_same_end), end_node))

op_flow._build_level = Flow.BuildLevel.GRAPH
op_flow._build_level = BuildLevel.GRAPH
return op_flow

def build(self, backend: Optional[str] = 'thread', copy_flow: bool = False, *args, **kwargs) -> 'Flow':
Expand All @@ -805,8 +724,8 @@ def build(self, backend: Optional[str] = 'thread', copy_flow: bool = False, *arg
# for thread and process backend which runs locally, host_in and host_out should not be set
p_args.host_in = BaseService.default_host
p_args.host_out = BaseService.default_host
op_flow._service_contexts.append((Flow._service_map[v['service']]['builder'], p_args))
op_flow._build_level = Flow.BuildLevel.RUNTIME
op_flow._service_contexts.append((service_map[v['service']]['builder'], p_args))
op_flow._build_level = BuildLevel.RUNTIME
else:
raise NotImplementedError('backend=%s is not supported yet' % backend)

Expand All @@ -816,7 +735,7 @@ def __call__(self, *args, **kwargs):
return self.build(*args, **kwargs)

def __enter__(self):
if self._build_level.value < Flow.BuildLevel.RUNTIME.value:
if self._build_level.value < BuildLevel.RUNTIME.value:
self.logger.warning(
'current build_level=%s, lower than required. '
'build the flow now via build() with default parameters' % self._build_level)
Expand All @@ -831,7 +750,7 @@ def __enter__(self):
def close(self):
if hasattr(self, '_service_stack'):
self._service_stack.close()
self._build_level = Flow.BuildLevel.EMPTY
self._build_level = BuildLevel.EMPTY
self.logger.critical(
'flow is closed and all resources should be released already, current build level is %s' % self._build_level)

Expand All @@ -844,12 +763,12 @@ def __eq__(self, other):
:return:
"""

if self._build_level.value < Flow.BuildLevel.GRAPH.value:
if self._build_level.value < BuildLevel.GRAPH.value:
a = self.build(backend=None, copy_flow=True)
else:
a = self

if other._build_level.value < Flow.BuildLevel.GRAPH.value:
if other._build_level.value < BuildLevel.GRAPH.value:
b = other.build(backend=None, copy_flow=True)
else:
b = other
Expand Down

0 comments on commit 48d1828

Please sign in to comment.