Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(flow): join is an alias of needs #762

Merged
merged 1 commit into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 16 additions & 7 deletions jina/flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ def _add_gateway(self, needs, **kwargs):
kwargs['name'] = 'gateway'
self._pod_nodes[pod_name] = GatewayFlowPod(kwargs, needs)

def join(self, needs: Union[Tuple[str], List[str]], uses: str = '_merge', name: str = 'joiner', *args,
**kwargs) -> 'Flow':
def needs(self, needs: Union[Tuple[str], List[str]], uses: str = '_merge', name: str = 'joiner', *args,
copy_flow: bool = True, **kwargs) -> 'Flow':
"""
Add a blocker to the flow, wait until all peas defined in **needs** completed.

Expand All @@ -356,9 +356,11 @@ def join(self, needs: Union[Tuple[str], List[str]], uses: str = '_merge', name:
:param name: the name of this joiner, by default is ``joiner``
:return: the modified flow
"""
op_flow = copy.deepcopy(self) if copy_flow else self

if len(needs) <= 1:
raise FlowTopologyError('no need to wait for a single service, need len(needs) > 1')
return self.add(name=name, uses=uses, needs=needs, *args, **kwargs)
return op_flow.add(name=name, uses=uses, needs=needs, *args, **kwargs)

def add(self,
needs: Union[str, Tuple[str], List[str]] = None,
Expand Down Expand Up @@ -642,7 +644,8 @@ def index_lines(self, lines: Iterator[str] = None, filepath: str = None, size: i
:param kwargs: accepts all keyword arguments of `jina client` CLI
"""
from ..clients.python.io import input_lines
self._get_client(**kwargs).index(input_lines(lines, filepath, size, sampling_rate, read_mode), output_fn, **kwargs)
self._get_client(**kwargs).index(input_lines(lines, filepath, size, sampling_rate, read_mode), output_fn,
**kwargs)

def index_files(self, patterns: Union[str, List[str]], recursive: bool = True,
size: int = None, sampling_rate: float = None, read_mode: str = None,
Expand All @@ -661,7 +664,8 @@ def index_files(self, patterns: Union[str, List[str]], recursive: bool = True,
:param kwargs: accepts all keyword arguments of `jina client` CLI
"""
from ..clients.python.io import input_files
self._get_client(**kwargs).index(input_files(patterns, recursive, size, sampling_rate, read_mode), output_fn, **kwargs)
self._get_client(**kwargs).index(input_files(patterns, recursive, size, sampling_rate, read_mode), output_fn,
**kwargs)

def search_files(self, patterns: Union[str, List[str]], recursive: bool = True,
size: int = None, sampling_rate: float = None, read_mode: str = None,
Expand All @@ -680,7 +684,8 @@ def search_files(self, patterns: Union[str, List[str]], recursive: bool = True,
:param kwargs: accepts all keyword arguments of `jina client` CLI
"""
from ..clients.python.io import input_files
self._get_client(**kwargs).search(input_files(patterns, recursive, size, sampling_rate, read_mode), output_fn, **kwargs)
self._get_client(**kwargs).search(input_files(patterns, recursive, size, sampling_rate, read_mode), output_fn,
**kwargs)

def search_lines(self, filepath: str = None, lines: Iterator[str] = None, size: int = None,
sampling_rate: float = None, read_mode='r',
Expand All @@ -698,7 +703,8 @@ def search_lines(self, filepath: str = None, lines: Iterator[str] = None, size:
:param kwargs: accepts all keyword arguments of `jina client` CLI
"""
from ..clients.python.io import input_lines
self._get_client(**kwargs).search(input_lines(lines, filepath, size, sampling_rate, read_mode), output_fn, **kwargs)
self._get_client(**kwargs).search(input_lines(lines, filepath, size, sampling_rate, read_mode), output_fn,
**kwargs)

@deprecated_alias(buffer='input_fn', callback='output_fn')
def index(self, input_fn: Union[Iterator['jina_pb2.Document'], Iterator[bytes], Callable] = None,
Expand Down Expand Up @@ -834,3 +840,6 @@ def use_grpc_gateway(self):
def use_rest_gateway(self):
"""Change to use REST gateway for IO """
self._common_kwargs['rest_api'] = True

# for backward support
join = needs
8 changes: 2 additions & 6 deletions jina/logging/sse.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

import logging
import os

from . import default_logger
from .queue import __sse_queue__, __profile_queue__
from .. import JINA_GLOBAL, __version__
Expand Down Expand Up @@ -51,9 +48,9 @@ def start_sse_logger(server_config_path: str, flow_yaml: str = None):
JINA_GLOBAL.logserver.address = f'http://{_config["host"]}:{_config["port"]}'

JINA_GLOBAL.logserver.ready = JINA_GLOBAL.logserver.address + \
_config['endpoints']['ready']
_config['endpoints']['ready']
JINA_GLOBAL.logserver.shutdown = JINA_GLOBAL.logserver.address + \
_config['endpoints']['shutdown']
_config['endpoints']['shutdown']

app = Flask(__name__)
CORS(app)
Expand All @@ -79,7 +76,6 @@ def _profile_stream():
yield 'PROFILE ENDS\n\n'
break


@app.route(_config['endpoints']['log'])
def get_log():
"""Get the logs, endpoint `/log/stream` """
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/flow/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def test_refactor_num_part_proxy_2(self):
f = (Flow().add(name='r1', uses='_logforward')
.add(name='r2', uses='_logforward', needs='r1', parallel=2)
.add(name='r3', uses='_logforward', needs='r1', parallel=3, polling='ALL')
.join(['r2', 'r3']))
.needs(['r2', 'r3']))

with f:
f.index_lines(lines=['abbcs', 'efgh'])
Expand Down