Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(flow): add support to replicas plot
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Oct 12, 2019
1 parent fce94d9 commit 4055ad8
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 18 deletions.
98 changes: 80 additions & 18 deletions gnes/flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,30 +142,81 @@ def to_mermaid(self, left_right: bool = True):
Output the mermaid graph for visualization
:param left_right: render the flow in left-to-right manner, otherwise top-down manner.
:return:
:return: a mermaid-formatted string
"""

# fill, stroke
service_color = {
Service.Frontend: ('#FFE0E0', '#000'),
Service.Router: ('#C9E8D2', '#000'),
Service.Encoder: ('#FFDAAF', '#000'),
Service.Preprocessor: ('#CED7EF', '#000'),
Service.Indexer: ('#FFFBC1', '#000'),
}

mermaid_graph = OrderedDict()
for k in self._service_nodes.keys():
mermaid_graph[k] = []
cls_dict = defaultdict(set)
replicas_dict = {}

for k, v in self._service_nodes.items():
mermaid_graph[k] = []
num_replicas = getattr(v['parsed_args'], 'num_parallel', 1)
if num_replicas > 1:
head_router = k + '_HEAD'
tail_router = k + '_TAIL'
replicas_dict[k] = (head_router, tail_router)
cls_dict[Service.Router].add(head_router)
cls_dict[Service.Router].add(tail_router)
p_r = '((%s))'
k_service = v['service']
p_e = '((%s))' if k_service == Service.Router else '(%s)'

mermaid_graph[k].append('subgraph %s["%s (replias=%d)"]' % (k, k, num_replicas))
for j in range(num_replicas):
r = k + '_%d' % j
cls_dict[k_service].add(r)
mermaid_graph[k].append('\t%s%s-->%s%s' % (head_router, p_r % 'router', r, p_e % r))
mermaid_graph[k].append('\t%s%s-->%s%s' % (r, p_e % r, tail_router, p_r % 'router'))
mermaid_graph[k].append('end')
mermaid_graph[k].append(
'style %s fill:%s,stroke:%s,stroke-width:2px,stroke-dasharray:5,stroke-opacity:0.3,fill-opacity:0.5' % (
k, service_color[k_service][0], service_color[k_service][1]))

for k, ed_type in self._service_edges.items():
start_node, end_node = k.split('-')
cur_node = mermaid_graph[start_node]

s_service = self._service_nodes[start_node]['service']
e_service = self._service_nodes[end_node]['service']

start_node_text = start_node
end_node_text = end_node

# check if is in replicas
if start_node in replicas_dict:
start_node = replicas_dict[start_node][1] # outgoing
s_service = Service.Router
start_node_text = 'router'
if end_node in replicas_dict:
end_node = replicas_dict[end_node][0] # incoming
e_service = Service.Router
end_node_text = 'router'

# always plot frontend at the start and the end
if e_service == Service.Frontend:
end_node_text = end_node
end_node += '_END'

cls_dict[s_service].add(start_node)
cls_dict[e_service].add(end_node)
p_s = '((%s))' if s_service == Service.Router else '(%s)'
p_e = '((%s))' if e_service == Service.Router else '(%s)'
mermaid_graph[start_node].append('\t%s%s-- %s -->%s%s' % (
start_node, p_s % start_node, ed_type,
end_node, p_e % end_node))

style = ['classDef FrontendCLS fill:#FFE0E0,stroke:#FFE0E0,stroke-width:1px;',
'classDef EncoderCLS fill:#FFDAAF,stroke:#FFDAAF,stroke-width:1px;',
'classDef IndexerCLS fill:#FFFBC1,stroke:#FFFBC1,stroke-width:1px;',
'classDef RouterCLS fill:#C9E8D2,stroke:#C9E8D2,stroke-width:1px;',
'classDef PreprocessorCLS fill:#CEEEEF,stroke:#CEEEEF,stroke-width:1px;']
cur_node.append('\t%s%s-- %s -->%s%s' % (
start_node, p_s % start_node_text, ed_type,
end_node, p_e % end_node_text))

style = ['classDef %sCLS fill:%s,stroke:%s,stroke-width:1px;' % (k, v[0], v[1]) for k, v in
service_color.items()]
class_def = ['class %s %sCLS;' % (','.join(v), k) for k, v in cls_dict.items()]
mermaid_str = '\n'.join(
['graph %s' % ('LR' if left_right else 'TD')] + [ss for s in mermaid_graph.values() for ss in
Expand All @@ -174,19 +225,30 @@ def to_mermaid(self, left_right: bool = True):
return mermaid_str

@_build_level(BuildLevel.GRAPH)
def to_jpg(self, path: str = 'flow.jpg', left_right: bool = True):
def to_url(self, **kwargs) -> str:
"""
Rendering the current flow as a url points to a SVG, it needs internet connection
:param kwargs: keyword arguments of :py:meth:`to_mermaid`
:return: the url points to a SVG
"""
import base64
mermaid_str = self.to_mermaid(**kwargs)
encoded_str = base64.b64encode(bytes(mermaid_str, 'utf-8')).decode('utf-8')
return 'https://mermaidjs.github.io/mermaid-live-editor/#/view/%s' % encoded_str

@_build_level(BuildLevel.GRAPH)
def to_jpg(self, path: str = 'flow.jpg', **kwargs):
"""
Rendering the current flow as a jpg image, this will call :py:meth:`to_mermaid` and it needs internet connection
:param path: the file path of the image
:param left_right: render the flow in left-to-right manner, otherwise top-down manner.
:param kwargs: keyword arguments of :py:meth:`to_mermaid`
:return:
"""
import base64

from urllib.request import Request, urlopen
mermaid_str = self.to_mermaid(left_right)
encoded_str = base64.b64encode(bytes(mermaid_str, 'utf-8')).decode('utf-8')
print('https://mermaidjs.github.io/mermaid-live-editor/#/view/%s' % encoded_str)
encoded_str = self.to_url().replace('https://mermaidjs.github.io/mermaid-live-editor/#/view/', '')
self.logger.info('saving jpg...')
req = Request('https://mermaid.ink/img/%s' % encoded_str, headers={'User-Agent': 'Mozilla/5.0'})
with open(path, 'wb') as fp:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_gnes_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,19 @@ def test_flow5(self):
print(f.to_mermaid())
f.to_jpg()

def test_flow_replica_pot(self):
f = (Flow(check_version=False, route_table=True)
.add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor', replicas=4)
.add(gfs.Encoder, yaml_path='PyTorchTransformers', replicas=3)
.add(gfs.Indexer, name='vec_idx', yaml_path='NumpyIndexer', replicas=2)
.add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', service_in='prep', replicas=2)
.add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter',
num_part=2, service_in=['vec_idx', 'doc_idx'])
.build(backend=None))
print(f.to_mermaid())
print(f.to_url(left_right=False))
print(f.to_url(left_right=True))

def _test_index_flow(self, backend):
for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]:
self.assertFalse(os.path.exists(k))
Expand Down Expand Up @@ -153,3 +166,13 @@ def test_index_query_flow(self):
def test_indexe_query_flow_proc(self):
self._test_index_flow('process')
self._test_query_flow('process')

def test_query_flow_plot(self):
flow = (Flow(check_version=False, route_table=False)
.add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor', replicas=2)
.add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'), replicas=3)
.add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml'),
replicas=4)
.add(gfs.Router, name='scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml'))
.add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml')))
print(flow.build(backend=None).to_url())

0 comments on commit 4055ad8

Please sign in to comment.