From e5df803e0f272e3e202c5f99a3b80a6ea9247cca Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Sat, 9 Jan 2021 22:55:55 +0530 Subject: [PATCH] feat(gateway): websocket based streaming and client (#1608) * feat(gateway): websocket streaming in gateway server * feat(gateway): websocket streaming client * feat(gateway): enabling flow with websocket client * feat(gateway): enabling flow with websocket client * test(restful): added restful to flow unit tests * test(restful): added restful to eval flow integration tests * test(restful): added restful to evaluation integration tests * test(restful): added restful to gateway non blocking integration tests * test(restful): added restful to other integration tests * ci: change tests timeout to 40 mins * fix(gateway): allow all encodings * fix(gateway): send response based on received encodings * style: fix coding style * style: fix coding style * style: fix coding style * style: fix coding style * style: fix coding style * fix(indexer): fix indexer keys assert * style: fix coding style Co-authored-by: Han Xiao --- .github/workflows/cd.yml | 19 +- .github/workflows/ci.yml | 8 +- .github/workflows/latency-tracking.yml | 2 +- daemon/api/endpoints/flow.py | 2 +- daemon/api/endpoints/pea.py | 4 +- daemon/api/endpoints/pod.py | 4 +- daemon/store.py | 10 +- extra-requirements.txt | 3 +- jina/clients/__init__.py | 64 +++++-- jina/clients/asyncio.py | 44 +++++ jina/clients/base.py | 3 +- jina/clients/helper.py | 2 +- jina/clients/request.py | 2 +- jina/clients/websockets.py | 89 ++++++++++ jina/docker/hubapi.py | 4 +- jina/docker/hubio.py | 6 +- jina/executors/__init__.py | 2 +- jina/executors/indexers/cache.py | 2 +- jina/flow/asyncio.py | 6 +- jina/flow/base.py | 13 +- .../peapods/runtimes/asyncio/rest/__init__.py | 57 +----- jina/peapods/runtimes/asyncio/rest/app.py | 162 ++++++++++++++++++ jina/peapods/runtimes/jinad/api.py | 10 +- jina/peapods/zmq/__init__.py | 4 +- jina/types/message/__init__.py | 2 +- 
.../eval_flow/test_flow_eval_pod.py | 36 ++-- .../flow-evaluate-from-file-parallel.yml | 2 + .../evaluation/flow-evaluate-from-file.yml | 2 + .../evaluation/flow-index-gt-parallel.yml | 2 + .../integration/evaluation/flow-index-gt.yml | 2 + ...w-parallel-evaluate-from-file-parallel.yml | 2 + .../evaluation/test_evaluation_from_file.py | 7 +- .../integration/gateway_non_blocking/flow.yml | 2 + .../test_gateway_non_blocking.py | 6 +- .../high_order_matches/test_adjacency.py | 7 +- .../test_incremental_indexing.py | 59 ++++--- .../test_unique_indexing.py | 22 ++- tests/integration/level_depth/flow-index.yml | 2 + tests/integration/level_depth/flow-query.yml | 2 + .../test_search_different_depths.py | 11 +- tests/integration/mime/test_mime.py | 36 ++-- tests/integration/mime/test_segmenter.py | 17 +- .../flow-embedding-multimodal-parallel.yml | 2 + .../flow-multimodal-all-types-parallel.yml | 2 + .../multimodal/test_multimodal_parallel.py | 12 +- .../jinad/integration/distributed/helpers.py | 2 +- .../test_index_query/test_integration.py | 1 - tests/unit/clients/python/test_client.py | 5 +- tests/unit/clients/python/test_on_err.py | 7 +- tests/unit/flow/test_asyncflow.py | 44 +++-- tests/unit/flow/test_flow.py | 102 ++++++----- tests/unit/flow/test_flow_before_after.py | 22 ++- tests/unit/flow/test_flow_except.py | 42 +++-- tests/unit/flow/test_flow_index.py | 27 +-- tests/unit/flow/test_flow_merge.py | 13 +- tests/unit/flow/test_flow_multimode.py | 17 +- tests/unit/flow/test_flow_skip.py | 26 ++- 57 files changed, 773 insertions(+), 292 deletions(-) create mode 100644 jina/clients/websockets.py create mode 100644 jina/peapods/runtimes/asyncio/rest/app.py diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 07594f2d391be..17331086cab2d 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -162,7 +162,7 @@ jobs: uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} - - name: Run test + - name: Prepare 
environment run: | docker login docker.pkg.github.com -u $GITHUB_ACTOR -p $GITHUB_TOKEN docker pull docker.pkg.github.com/jina-ai/jina/jina:test-pip @@ -171,6 +171,10 @@ jobs: pip install ".[cicd,test]" --no-cache-dir jina check export JINA_LOG_VERBOSITY="ERROR" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Test + run: | pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml -n 1 --timeout=120 -v --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/${{ matrix.test-path }} timeout-minutes: 20 env: @@ -207,20 +211,25 @@ jobs: uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies + - name: Prepare environment run: | docker login docker.pkg.github.com -u $GITHUB_ACTOR -p $GITHUB_TOKEN docker pull docker.pkg.github.com/jina-ai/jina/jina:test-pip docker tag docker.pkg.github.com/jina-ai/jina/jina:test-pip jinaai/jina:test-pip python -m pip install --upgrade pip - pip install ".[cicd,daemon,test]" --no-cache-dir - docker build --build-arg PIP_TAG="[devel]" -f Dockerfiles/pip.Dockerfile -t jinaai/jina:test-pip . 
+ pip install ".[cicd,test,daemon]" --no-cache-dir + jina check export JINA_LOG_VERBOSITY="ERROR" - pytest --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml -n 1 --timeout=120 -v tests/jinad/${{ matrix.test-path }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Test + run: | + pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml -n 1 --timeout=120 -v tests/jinad/${{ matrix.test-path }} timeout-minutes: 20 env: JINAHUB_USERNAME: ${{ secrets.JINAHUB_USERNAME }} JINAHUB_PASSWORD: ${{ secrets.JINAHUB_PASSWORD }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Upload coverage from test to Codecov uses: codecov/codecov-action@v1 if: ${{ matrix.python-version }} == 3.7 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5279a7b4add8e..a38bfbd898d57 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -125,7 +125,7 @@ jobs: uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies + - name: Prepare environment run: | docker login docker.pkg.github.com -u $GITHUB_ACTOR -p $GITHUB_TOKEN docker pull docker.pkg.github.com/jina-ai/jina/jina:test-pip @@ -134,13 +134,17 @@ jobs: pip install ".[cicd,test]" --no-cache-dir jina check export JINA_LOG_VERBOSITY="ERROR" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Test + run: | pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml -n 1 --timeout=120 -v --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/${{ matrix.test-path }} timeout-minutes: 20 env: JINAHUB_USERNAME: ${{ secrets.JINAHUB_USERNAME }} JINAHUB_PASSWORD: ${{ secrets.JINAHUB_PASSWORD }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Check file existence + - name: Check codecov file id: check_files uses: andstor/file-existence-action@v1 with: diff --git a/.github/workflows/latency-tracking.yml 
b/.github/workflows/latency-tracking.yml index 7e5e31025bdbc..95fd73ab76f9a 100644 --- a/.github/workflows/latency-tracking.yml +++ b/.github/workflows/latency-tracking.yml @@ -30,7 +30,7 @@ jobs: cd latency docker build --build-arg JINA_VER=master . -t latency-tracking docker run -v $(pwd)/output:/workspace/output -v $(pwd)/original:/workspace/original latency-tracking - bash batch.sh 3 + bash batch.sh 2 pip install prettytable python ppstat.py > comment.txt - id: get-comment-body diff --git a/daemon/api/endpoints/flow.py b/daemon/api/endpoints/flow.py index 8fe04fb7380b0..f9dab5adfb59c 100644 --- a/daemon/api/endpoints/flow.py +++ b/daemon/api/endpoints/flow.py @@ -145,7 +145,7 @@ def __init__(self): detail=f'Invalid yaml file.') except FlowStartException as e: raise HTTPException(status_code=404, - detail=f'Flow couldn\'t get started: {repr(e)}') + detail=f'Flow couldn\'t get started: {e!r}') return { 'status_code': status.HTTP_200_OK, diff --git a/daemon/api/endpoints/pea.py b/daemon/api/endpoints/pea.py index 57ad7c7a40ed2..d03a6978398b9 100644 --- a/daemon/api/endpoints/pea.py +++ b/daemon/api/endpoints/pea.py @@ -56,9 +56,9 @@ async def _create( pea_id = pea_store._create(pea_arguments=pea_arguments) except PeaStartException as e: raise HTTPException(status_code=404, - detail=f'Pea couldn\'t get started: {repr(e)}') + detail=f'Pea couldn\'t get started: {e!r}') except Exception as e: - daemon_logger.error(f'Got an error while creating a pea {repr(e)}') + daemon_logger.error(f'Got an error while creating a pea {e!r}') raise HTTPException(status_code=404, detail=f'Something went wrong') return { diff --git a/daemon/api/endpoints/pod.py b/daemon/api/endpoints/pod.py index 61047106e6178..5d573f9bdc7d0 100644 --- a/daemon/api/endpoints/pod.py +++ b/daemon/api/endpoints/pod.py @@ -57,9 +57,9 @@ async def _create( pod_id = pod_store._create(pod_arguments=pod_arguments) except PodStartException as e: raise HTTPException(status_code=404, - detail=f'Pod couldn\'t get 
started: {repr(e)}') + detail=f'Pod couldn\'t get started: {e!r}') except Exception as e: - daemon_logger.error(f'Got an error while creating a pod {repr(e)}') + daemon_logger.error(f'Got an error while creating a pod {e!r}') raise HTTPException(status_code=404, detail=f'Something went wrong') return { diff --git a/daemon/store.py b/daemon/store.py index 4202045aa6444..3dd2f8cc2dd95 100644 --- a/daemon/store.py +++ b/daemon/store.py @@ -82,13 +82,13 @@ def _create(self, JAML.register(Flow) flow = JAML.load(yamlspec) except Exception as e: - self.logger.error(f'Got error while loading from yaml {repr(e)}') + self.logger.error(f'Got error while loading from yaml {e!r}') raise FlowYamlParseException elif isinstance(config, list): try: flow = self._build_with_pods(pod_args=config) except Exception as e: - self.logger.error(f'Got error while creating flows via pods: {repr(e)}') + self.logger.error(f'Got error while creating flows via pods: {e!r}') raise FlowCreationException else: raise FlowBadInputException(f'Not valid Flow config input {type(config)}') @@ -97,7 +97,7 @@ def _create(self, flow_id = uuid.UUID(flow.args.log_id) flow = self._start(context=flow) except Exception as e: - self.logger.critical(f'Got following error while starting the flow: {repr(e)}') + self.logger.critical(f'Got following error while starting the flow: {e!r}') raise FlowStartException(repr(e)) self._store[flow_id] = {} @@ -160,7 +160,7 @@ def _create(self, pod_arguments: Union[Dict, Namespace]): pod = Pod(pod_arguments) pod = self._start(context=pod) except Exception as e: - self.logger.critical(f'Got following error while starting the pod: {repr(e)}') + self.logger.critical(f'Got following error while starting the pod: {e!r}') raise PodStartException(repr(e)) self._store[pod_id] = {} @@ -195,7 +195,7 @@ def _create(self, pea_arguments: Union[Dict, Namespace]): pea = Pea(pea_arguments) pea = self._start(context=pea) except Exception as e: - self.logger.critical(f'Got following error while 
starting the pea: {repr(e)}') + self.logger.critical(f'Got following error while starting the pea: {e!r}') raise PeaStartException(repr(e)) self._store[pea_id] = {} diff --git a/extra-requirements.txt b/extra-requirements.txt index 626fc24f000a6..c52e0c5908847 100644 --- a/extra-requirements.txt +++ b/extra-requirements.txt @@ -66,7 +66,8 @@ mock: test requests: http, devel, test, daemon prettytable: devel, test sseclient-py: test -websockets: http, devel, test, daemon +websockets: http, devel, test, daemon, ws +wsproto: http, devel, test, ws, daemon pydantic: http, devel, test, daemon python-multipart: http, devel, test, daemon pytest-custom_exit_code: cicd, test \ No newline at end of file diff --git a/jina/clients/__init__.py b/jina/clients/__init__.py index 4b86f5c03ef76..19a7ef0d907a0 100644 --- a/jina/clients/__init__.py +++ b/jina/clients/__init__.py @@ -3,6 +3,7 @@ from . import request from .base import BaseClient, CallbackFnType, InputFnType +from .websockets import WebSocketClientMixin from .helper import callback_exec from .request import GeneratorSourceType from ..enums import RequestType @@ -10,8 +11,8 @@ class Client(BaseClient): - """A simple Python client for connecting to the gateway. - It manges the asyncio eventloop internally, so all interfaces are synchronous from the outside. + """A simple Python client for connecting to the gRPC gateway. + It manages the asyncio eventloop internally, so all interfaces are synchronous from the outside. 
""" @deprecated_alias(buffer='input_fn', callback='on_done', output_fn='on_done') @@ -70,10 +71,10 @@ def index(self, input_fn: InputFnType = None, @deprecated_alias(buffer='input_fn', callback='on_done', output_fn='on_done') def update(self, input_fn: InputFnType = None, - on_done: CallbackFnType = None, - on_error: CallbackFnType = None, - on_always: CallbackFnType = None, - **kwargs) -> None: + on_done: CallbackFnType = None, + on_error: CallbackFnType = None, + on_always: CallbackFnType = None, + **kwargs) -> None: """ :param input_fn: the input function that generates the content @@ -88,10 +89,10 @@ def update(self, input_fn: InputFnType = None, @deprecated_alias(buffer='input_fn', callback='on_done', output_fn='on_done') def delete(self, input_fn: InputFnType = None, - on_done: CallbackFnType = None, - on_error: CallbackFnType = None, - on_always: CallbackFnType = None, - **kwargs) -> None: + on_done: CallbackFnType = None, + on_error: CallbackFnType = None, + on_always: CallbackFnType = None, + **kwargs) -> None: """ :param input_fn: the input function that generates the content @@ -102,4 +103,45 @@ def delete(self, input_fn: InputFnType = None, :return: """ self.mode = RequestType.DELETE - return run_async(self._get_results, input_fn, on_done, on_error, on_always, **kwargs) \ No newline at end of file + return run_async(self._get_results, input_fn, on_done, on_error, on_always, **kwargs) + + +class WebSocketClient(Client, WebSocketClientMixin): + """A Python Client to stream requests from a Flow with a RESTGateway + :class:`WebSocketClient` shares the same interface as :class:`Client` and provides methods like + :meth:`index`, "meth:`search`, :meth:`train`, :meth:`update` & :meth:`delete`. + + It is used by default while running operations when we create a `Flow` with `rest_api=True` + + .. highlight:: python + .. 
code-block:: python + + from jina.flow import Flow + f = Flow(rest_api=True).add().add() + + with f: + f.index(['abc']) + + + :class:`WebSocketClient` can also be used to run operations for a remote Flow + + .. highlight:: python + .. code-block:: python + + # A Flow running on remote + from jina.flow import Flow + f = Flow(rest_api=True, port_expose=34567).add().add() + + with f: + f.block() + + # Local WebSocketClient running index & search + from jina.clients import WebSocketClient + + client = WebSocketClient(...) + client.index(...) + client.search(...) + + + :class:`WebSocketClient` internally handles an event loop to run operations asynchronously. + """ diff --git a/jina/clients/asyncio.py b/jina/clients/asyncio.py index 0063d9e100829..fc97979550c48 100644 --- a/jina/clients/asyncio.py +++ b/jina/clients/asyncio.py @@ -1,4 +1,5 @@ from .base import InputFnType, BaseClient, CallbackFnType +from .websockets import WebSocketClientMixin from ..enums import RequestType from ..helper import deprecated_alias @@ -97,3 +98,46 @@ async def index(self, input_fn: InputFnType = None, """ self.mode = RequestType.INDEX return await self._get_results(input_fn, on_done, on_error, on_always, **kwargs) + + +class AsyncWebSocketClient(AsyncClient, WebSocketClientMixin): + """ + :class:`AsyncWebSocketClient` is the asynchronous version of the :class:`WebSocketClient`. + They share the same interface, except in :class:`AsyncWebSocketClient` :meth:`train`, :meth:`index`, :meth:`search` + methods are coroutines (i.e. declared with the async/await syntax), simply calling them will not schedule them to be executed. + To actually run a coroutine, users need to put them in an eventloop, e.g. via ``asyncio.run()``, + ``asyncio.create_task()``. + + :class:`AsyncWebSocketClient` can be very useful in the integration settings, where Jina/Flow/Client is NOT the + main logic, but rather serves as a part of another program. 
In this case, users often do not want to let Jina control + the ``asyncio.eventloop``. On the contrary, :class:`WebSocketClient` is controlling and wrapping the eventloop + internally, making the Client look synchronous from the outside. + + For example, say you have the Flow running remotely. You want to use Client to connect to it to do + some indexing and searching, but meanwhile you have some other IO-bound jobs and want to do them concurrently. + You can use :class:`AsyncWebSocketClient`, + + .. highlight:: python + .. code-block:: python + + from jina.clients.asyncio import AsyncWebSocketClient + + ac = AsyncWebSocketClient(...) + + async def jina_client_query(): + await ac.search(...) + + async def heavylifting(): + await other_library.download_big_files(...) + + async def concurrent_main(): + await asyncio.gather(jina_client_query(), heavylifting()) + + + if __name__ == '__main__': + # under python + asyncio.run(concurrent_main()) + + One can think of :class:`WebSocketClient` as Jina-managed eventloop, + whereas :class:`AsyncWebSocketClient` is self-managed eventloop. 
+ """ diff --git a/jina/clients/base.py b/jina/clients/base.py index 8c64b86fd2f3f..330f7a99fb493 100644 --- a/jina/clients/base.py +++ b/jina/clients/base.py @@ -110,7 +110,8 @@ def input_fn(self, bytes_gen: InputFnType) -> None: else: self._input_fn = bytes_gen - async def _get_results(self, input_fn: Callable, + async def _get_results(self, + input_fn: Callable, on_done: Callable, on_error: Callable = None, on_always: Callable = None, **kwargs): diff --git a/jina/clients/helper.py b/jina/clients/helper.py index 33b1723e5a001..00001b6cc5055 100644 --- a/jina/clients/helper.py +++ b/jina/clients/helper.py @@ -61,7 +61,7 @@ def arg_wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as ex: - err_msg = f'uncaught exception in callback {func.__name__}(): {repr(ex)}' + err_msg = f'uncaught exception in callback {func.__name__}(): {ex!r}' if continue_on_error: logger.error(err_msg) else: diff --git a/jina/clients/request.py b/jina/clients/request.py index 4c15653779ec6..581f0c329bff5 100644 --- a/jina/clients/request.py +++ b/jina/clients/request.py @@ -82,7 +82,7 @@ def _generate(data: GeneratorSourceType, yield req except Exception as ex: # must be handled here, as grpc channel wont handle Python exception - default_logger.critical(f'input_fn is not valid! {repr(ex)}', exc_info=True) + default_logger.critical(f'input_fn is not valid! 
{ex!r}', exc_info=True) def index(*args, **kwargs): diff --git a/jina/clients/websockets.py b/jina/clients/websockets.py new file mode 100644 index 0000000000000..16639b7527439 --- /dev/null +++ b/jina/clients/websockets.py @@ -0,0 +1,89 @@ +import asyncio +from typing import Callable, List + +from .base import BaseClient +from .helper import callback_exec +from ..importer import ImportExtensions +from ..logging.profile import TimeContext, ProgressBar +from ..types.request import Request, Response + + +class WebSocketClientMixin(BaseClient): + async def _get_results(self, + input_fn: Callable, + on_done: Callable, + on_error: Callable = None, + on_always: Callable = None, **kwargs): + """ + :meth:`send_requests()` + Traverses through the request iterator + Sends each request & awaits :meth:`websocket.send()` + Sends & awaits `byte(True)` to acknowledge request iterator is empty + Traversal logic: + Starts an independent task :meth:`send_requests()` + Awaits on each response from :meth:`websocket.recv()` (done in an async loop) + This makes sure client makes concurrent invocations + Await exit strategy: + :meth:`send_requests()` keeps track of num_requests sent + Async recv loop keeps track of num_responses received + Client exits out of await when num_requests == num_responses + """ + with ImportExtensions(required=True): + import websockets + + result = [] # type: List['Response'] + self.input_fn = input_fn + req_iter, tname = self._get_requests(**kwargs) + try: + client_info = f'{self.args.host}:{self.args.port_expose}' + # setting `max_size` as None to avoid connection closure due to size of message + # https://websockets.readthedocs.io/en/stable/api.html?highlight=1009#module-websockets.protocol + + async with websockets.connect(f'ws://{client_info}/stream', max_size=None) as websocket: + # To enable websockets debug logs + # https://websockets.readthedocs.io/en/stable/cheatsheet.html#debugging + self.logger.success(f'Connected to the gateway at {client_info}') 
+ self.num_requests = 0 + self.num_responses = 0 + + async def send_requests(request_iterator): + for next_request in request_iterator: + await websocket.send(next_request.SerializeToString()) + self.num_requests += 1 + # Server has no way of knowing when to stop the await on sending response back to the client + # We send one last message to say `request_iterator` is completed. + # On the client side, this :meth:`send` doesn't need to be awaited with a :meth:`recv` + await websocket.send(bytes(True)) + + with ProgressBar(task_name=tname) as p_bar, TimeContext(tname): + # Unlike gRPC, any arbitrary function (generator) cannot be passed via websockets. + # Simply iterating through the `req_iter` makes the request-response sequential. + # To make client unblocking, :func:`send_requests` and `recv_responses` are separate tasks + + asyncio.create_task(send_requests(request_iterator=req_iter)) + async for response_bytes in websocket: + # When we have a stream of responses, instead of doing `await websocket.recv()`, + # we need to traverse through the websocket to recv messages. 
+ # https://websockets.readthedocs.io/en/stable/faq.html#why-does-the-server-close-the-connection-after-processing-one-message + + response = Request(response_bytes).to_response() + callback_exec(response=response, + on_error=on_error, + on_done=on_done, + on_always=on_always, + continue_on_error=self.args.continue_on_error, + logger=self.logger) + p_bar.update(self.args.batch_size) + if self.args.return_results: + result.append(response) + self.num_responses += 1 + if self.num_requests == self.num_responses: + break + + except websockets.exceptions.ConnectionClosedOK: + self.logger.warning(f'Client got disconnected from the websocket server') + except websockets.exceptions.WebSocketException as e: + self.logger.error(f'Got following error while streaming requests via websocket: {e!r}') + finally: + if self.args.return_results: + return result diff --git a/jina/docker/hubapi.py b/jina/docker/hubapi.py index 8e83c81fad0f8..6b96eb637c283 100644 --- a/jina/docker/hubapi.py +++ b/jina/docker/hubapi.py @@ -156,7 +156,7 @@ def _docker_auth(logger) -> Optional[Dict[str, str]]: else: logger.error(f'failed to fetch docker credentials') except Exception as exp: - logger.error(f'got an exception while fetching docker credentials {repr(exp)}') + logger.error(f'got an exception while fetching docker credentials {exp!r}') def _make_hub_table_with_local(manifests, local_manifests): @@ -240,7 +240,7 @@ def _register_to_mongodb(logger, summary: Dict = None): f'please login using command: {colored("jina hub login", attrs=["bold"])}') logger.error(f'got an error from the API: {response.text}') except Exception as exp: - logger.error(f'got an exception while invoking hubapi for push {repr(exp)}') + logger.error(f'got an exception while invoking hubapi for push {exp!r}') def _fetch_access_token(logger): diff --git a/jina/docker/hubio.py b/jina/docker/hubio.py index f3adcf7fcab8a..ed517f3bc3558 100644 --- a/jina/docker/hubio.py +++ b/jina/docker/hubio.py @@ -206,7 +206,7 @@ def 
push(self, name: str = None, readme_path: str = None, build_result: Dict = N if build_result.get('details', None) and build_result.get('build_history', None): self._write_slack_message(build_result, build_result['details'], build_result['build_history']) except Exception as e: - self.logger.error(f'Error when trying to push image {name}: {repr(e)}') + self.logger.error(f'Error when trying to push image {name}: {e!r}') if isinstance(e, ImageAlreadyExists): raise e @@ -263,7 +263,7 @@ def pull(self) -> None: self.logger.success( f'🎉 pulled {image_tag} ({image.short_id}) uncompressed size: {get_readable_size(image.attrs["Size"])}') except Exception as ex: - self.logger.error(f'can not pull image {self.args.name} from {self.args.registry} due to {repr(ex)}') + self.logger.error(f'can not pull image {self.args.name} from {self.args.registry} due to {ex!r}') def _check_docker_image(self, name: str) -> None: # check local image @@ -383,7 +383,7 @@ def build(self) -> Dict: self.logger.warning( f'Build successful. Tests failed at : {str(failed_test_levels)} levels. 
This could be due to the fact that the executor has non-installed external dependencies') except Exception as ex: - self.logger.error(f'something wrong while testing the build: {repr(ex)}') + self.logger.error(f'something wrong while testing the build: {ex!r}') ex = HubBuilderTestError(ex) _except_strs.append(repr(ex)) _excepts.append(ex) diff --git a/jina/executors/__init__.py b/jina/executors/__init__.py index 16264e3ec5b8c..151ff1479dd46 100644 --- a/jina/executors/__init__.py +++ b/jina/executors/__init__.py @@ -293,7 +293,7 @@ def __setstate__(self, d): self._post_init_wrapper(fill_in_metas=False) except ModuleNotFoundError as ex: self.logger.warning(f'{typename(ex)} is often caused by a missing component, ' - f'which often can be solved by "pip install" relevant package: {repr(ex)}', + f'which often can be solved by "pip install" relevant package: {ex!r}', exc_info=True) def train(self, *args, **kwargs) -> None: diff --git a/jina/executors/indexers/cache.py b/jina/executors/indexers/cache.py index 9509a68ba9a5d..52b24aabc0ec1 100644 --- a/jina/executors/indexers/cache.py +++ b/jina/executors/indexers/cache.py @@ -42,7 +42,7 @@ def __init__(self, path, logger): self.content_hash = pickle.load(open(path + '.cache', 'rb')) except FileNotFoundError as e: logger.warning( - f'File path did not exist : {path}.ids or {path}.cache: {repr(e)}. Creating new CacheHandler...') + f'File path did not exist : {path}.ids or {path}.cache: {e!r}. 
Creating new CacheHandler...') self.ids = [] self.content_hash = [] diff --git a/jina/flow/asyncio.py b/jina/flow/asyncio.py index 94839107b193d..21823b711667a 100644 --- a/jina/flow/asyncio.py +++ b/jina/flow/asyncio.py @@ -1,7 +1,7 @@ from typing import Union, List, Iterator from .base import BaseFlow -from ..clients.asyncio import AsyncClient +from ..clients.asyncio import AsyncClient, AsyncWebSocketClient from ..clients.base import InputFnType, CallbackFnType from ..enums import DataInputType from ..helper import deprecated_alias @@ -71,6 +71,10 @@ async def concurrent_main(): """ _cls_client = AsyncClient #: the type of the Client, can be changed to other class + def _update_client(self): + if self._pod_nodes['gateway'].args.restful: + self._cls_client = AsyncWebSocketClient + @deprecated_alias(buffer='input_fn', callback='on_done', output_fn='on_done') async def train(self, input_fn: InputFnType = None, on_done: CallbackFnType = None, diff --git a/jina/flow/base.py b/jina/flow/base.py index c4e275bcafff9..d8551cbabf583 100644 --- a/jina/flow/base.py +++ b/jina/flow/base.py @@ -14,7 +14,7 @@ from .builder import build_required, _build_flow, _optimize_flow, _hanging_pods from .. 
import JINA_GLOBAL -from ..clients import Client +from ..clients import Client, WebSocketClient from ..enums import FlowBuildLevel, PodRoleType, FlowInspectType from ..excepts import FlowTopologyError, FlowMissingPodError from ..helper import colored, \ @@ -379,6 +379,7 @@ def build(self, copy_flow: bool = False) -> FlowLike: self.logger.warning(f'{hanging_pods} are hanging in this flow with no pod receiving from them, ' f'you may want to double check if it is intentional or some mistake') op_flow._build_level = FlowBuildLevel.GRAPH + self._update_client() return op_flow def __call__(self, *args, **kwargs): @@ -408,7 +409,7 @@ def _stop_log_server(self): # it may have been shutdown from the outside urllib.request.urlopen(JINA_GLOBAL.logserver.shutdown, timeout=5) except Exception as ex: - self.logger.info(f'Failed to connect to shutdown log sse server: {repr(ex)}') + self.logger.info(f'Failed to connect to shutdown log sse server: {ex!r}') def _start_log_server(self): try: @@ -428,13 +429,13 @@ def _start_log_server(self): if response.status == 200: self.logger.success(f'logserver is started and available at {JINA_GLOBAL.logserver.address}') except Exception as ex: - self.logger.error(f'Could not start logserver because of {repr(ex)}') + self.logger.error(f'Could not start logserver because of {ex!r}') except ModuleNotFoundError: self.logger.error( f'sse logserver can not start because of "flask" and "flask_cors" are missing, ' f'use pip install "jina[http]" (with double quotes) to install the dependencies') except Exception as ex: - self.logger.error(f'logserver fails to start: {repr(ex)}') + self.logger.error(f'logserver fails to start: {ex!r}') def start(self): """Start to run all Pods in this Flow. 
@@ -755,6 +756,10 @@ def __getitem__(self, item): else: raise TypeError(f'{typename(item)} is not supported') + def _update_client(self): + if self._pod_nodes['gateway'].args.restful: + self._cls_client = WebSocketClient + def index(self): raise NotImplementedError diff --git a/jina/peapods/runtimes/asyncio/rest/__init__.py b/jina/peapods/runtimes/asyncio/rest/__init__.py index 39da7f26e5de5..bf67664e3b178 100644 --- a/jina/peapods/runtimes/asyncio/rest/__init__.py +++ b/jina/peapods/runtimes/asyncio/rest/__init__.py @@ -1,13 +1,8 @@ -from typing import Any - -from google.protobuf.json_format import MessageToDict - +from .app import get_fastapi_app from ..base import AsyncZMQRuntime -from ..grpc.async_call import AsyncPrefetchCall -from ..... import clients -from .....enums import RequestType from .....importer import ImportExtensions + __all__ = ['RESTRuntime'] @@ -18,9 +13,14 @@ def setup(self): from uvicorn import Config, Server # change log_level for REST server debugging - self._server = Server(config=Config(app=self._get_fastapi_app(), + # TODO(Deepankar): The default `websockets` implementation needs the max_size to be set. + # But uvicorn doesn't expose a config for max_size of a ws message, hence falling back to `ws='wsproto'` + # Change to 'auto' once https://github.com/encode/uvicorn/pull/538 gets merged, + # as 'wsproto' is less performant and adds another dependency. 
+ self._server = Server(config=Config(app=get_fastapi_app(self.args, self.logger), host=self.args.host, port=self.args.port_expose, + ws='wsproto', log_level='critical')) self.logger.success(f'{self.__class__.__name__} is listening at: {self.args.host}:{self.args.port_expose}') @@ -29,44 +29,3 @@ async def async_run_forever(self): async def async_cancel(self): self._server.should_exit = True - - def _get_fastapi_app(self): - with ImportExtensions(required=True): - from fastapi import FastAPI, Body - from fastapi.responses import JSONResponse - from fastapi.middleware.cors import CORSMiddleware - - app = FastAPI(title=self.__class__.__name__) - app.add_middleware( - CORSMiddleware, - allow_origins=['*'], - allow_credentials=True, - allow_methods=['*'], - allow_headers=['*'], - ) - servicer = AsyncPrefetchCall(self.args) - - def error(reason, status_code): - return JSONResponse(content={'reason': reason}, status_code=status_code) - - @app.get('/ready') - async def is_ready(): - return JSONResponse(status_code=200) - - @app.post(path='/api/{mode}') - async def api(mode: str, body: Any = Body(...)): - if mode.upper() not in RequestType.__members__: - return error(reason=f'unsupported mode {mode}', status_code=405) - - if 'data' not in body: - return error('"data" field is empty', 406) - - body['mode'] = RequestType.from_string(mode) - req_iter = getattr(clients.request, mode)(**body) - results = await get_result_in_json(req_iter=req_iter) - return JSONResponse(content=results[0], status_code=200) - - async def get_result_in_json(req_iter): - return [MessageToDict(k) async for k in servicer.Call(request_iterator=req_iter, context=None)] - - return app diff --git a/jina/peapods/runtimes/asyncio/rest/app.py b/jina/peapods/runtimes/asyncio/rest/app.py new file mode 100644 index 0000000000000..8cc9cf34d7091 --- /dev/null +++ b/jina/peapods/runtimes/asyncio/rest/app.py @@ -0,0 +1,162 @@ +import argparse +import asyncio +from typing import Any + +from 
google.protobuf.json_format import MessageToDict + +from ..grpc.async_call import AsyncPrefetchCall +from ....zmq import AsyncZmqlet +from ..... import clients +from .....enums import RequestType +from .....importer import ImportExtensions +from .....logging import JinaLogger +from .....types.message import Message +from .....types.request import Request + + +def get_fastapi_app(args: 'argparse.Namespace', logger: 'JinaLogger'): + with ImportExtensions(required=True): + from fastapi import FastAPI, WebSocket, Body + from fastapi.responses import JSONResponse + from fastapi.middleware.cors import CORSMiddleware + from starlette.endpoints import WebSocketEndpoint + from starlette import status + if False: + from starlette.types import Receive, Scope, Send + + app = FastAPI(title='RESTRuntime') + app.add_middleware( + CORSMiddleware, + allow_origins=['*'], + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], + ) + servicer = AsyncPrefetchCall(args) + + def error(reason, status_code): + return JSONResponse(content={'reason': reason}, status_code=status_code) + + @app.get('/ready') + async def is_ready(): + return JSONResponse(status_code=200) + + @app.post(path='/api/{mode}') + async def api(mode: str, body: Any = Body(...)): + if mode.upper() not in RequestType.__members__: + return error(reason=f'unsupported mode {mode}', status_code=405) + + if 'data' not in body: + return error('"data" field is empty', 406) + + body['mode'] = RequestType.from_string(mode) + req_iter = getattr(clients.request, mode)(**body) + results = await get_result_in_json(req_iter=req_iter) + return JSONResponse(content=results[0], status_code=200) + + async def get_result_in_json(req_iter): + return [MessageToDict(k) async for k in servicer.Call(request_iterator=req_iter, context=None)] + + @app.websocket_route(path='/stream') + class StreamingEndpoint(WebSocketEndpoint): + """ + :meth:`handle_receive()` + Await a message on :meth:`websocket.receive()` + Send the message to 
zmqlet via :meth:`zmqlet.send_message()` and await + :meth:`handle_send()` + Await a message on :meth:`zmqlet.recv_message()` + Send the message back to client via :meth:`websocket.send()` and await + :meth:`dispatch()` + Awaits on concurrent tasks :meth:`handle_receive()` & :meth:`handle_send()` + This makes sure gateway is nonblocking + Await exit strategy: + :meth:`handle_receive()` keeps track of num_requests received + :meth:`handle_send()` keeps track of num_responses sent + Client sends a final message: `bytes(True)` to indicate request iterator is empty + Server exits out of await when `(num_requests == num_responses != 0 and is_req_empty)` + """ + + encoding = None + is_req_empty = False + num_requests = 0 + num_responses = 0 + + def __init__(self, scope: 'Scope', receive: 'Receive', send: 'Send') -> None: + super().__init__(scope, receive, send) + self.args = args + self.name = args.name or self.__class__.__name__ + self.client_encoding = None + + async def dispatch(self) -> None: + websocket = WebSocket(self.scope, receive=self.receive, send=self.send) + await self.on_connect(websocket) + close_code = status.WS_1000_NORMAL_CLOSURE + + await asyncio.gather( + self.handle_receive(websocket=websocket, close_code=close_code), + self.handle_send(websocket=websocket) + ) + + async def on_connect(self, websocket: WebSocket) -> None: + # TODO(Deepankar): To enable multiple concurrent clients, + # Register each client - https://fastapi.tiangolo.com/advanced/websockets/#handling-disconnections-and-multiple-clients + # And move class variables to instance variable + await websocket.accept() + self.client_info = f'{websocket.client.host}:{websocket.client.port}' + logger.success(f'Client {self.client_info} connected to stream requests via websockets') + self.zmqlet = AsyncZmqlet(args, logger) + + async def handle_receive(self, websocket: WebSocket, close_code: int) -> None: + try: + while True: + message = await websocket.receive() + if message['type'] == 
'websocket.receive': + data = await self.decode(websocket, message) + if data == bytes(True): + self.is_req_empty = True + continue + await self.zmqlet.send_message( + Message(None, Request(data), 'gateway', **vars(self.args)) + ) + self.num_requests += 1 + elif message['type'] == 'websocket.disconnect': + close_code = int(message.get('code', status.WS_1000_NORMAL_CLOSURE)) + break + except Exception as exc: + close_code = status.WS_1011_INTERNAL_ERROR + logger.error(f'Got an exception in handle_receive: {exc!r}') + raise exc from None + finally: + await self.on_disconnect(websocket, close_code) + + async def handle_send(self, websocket: WebSocket) -> None: + + def handle_route(msg: 'Message') -> 'Request': + msg.add_route(self.name, self.args.identity) + return msg.response + + try: + while not (self.num_requests == self.num_responses != 0 and self.is_req_empty): + response = await self.zmqlet.recv_message(callback=handle_route) + if self.client_encoding == 'bytes': + await websocket.send_bytes(response.SerializeToString()) + else: + await websocket.send_json(response.to_json()) + self.num_responses += 1 + except Exception as e: + logger.error(f'Got an exception in handle_send: {e!r}') + + async def decode(self, websocket: WebSocket, message: Message) -> Any: + if 'text' in message or 'json' in message: + self.client_encoding = 'text' + + if 'bytes' in message: + self.client_encoding = 'bytes' + + return await super().decode(websocket, message) + + async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: + self.zmqlet.close() + logger.info(f'Client {self.client_info} got disconnected!') + + return app diff --git a/jina/peapods/runtimes/jinad/api.py b/jina/peapods/runtimes/jinad/api.py index 9917fd3c9635b..a620f333d6f6b 100644 --- a/jina/peapods/runtimes/jinad/api.py +++ b/jina/peapods/runtimes/jinad/api.py @@ -113,7 +113,7 @@ def is_alive(self) -> bool: r = requests.get(url=self.alive_url, timeout=self.timeout) return r.status_code == 
requests.codes.ok except requests.exceptions.RequestException as ex: - self.logger.error(f'something wrong on remote: {repr(ex)}') + self.logger.error(f'something wrong on remote: {ex!r}') return False def upload(self, args: Dict, **kwargs) -> bool: @@ -143,7 +143,7 @@ def upload(self, args: Dict, **kwargs) -> bool: self.logger.success(f'Got status {r.json()["status"]} from remote') return True except requests.exceptions.RequestException as ex: - self.logger.error(f'something wrong on remote: {repr(ex)}') + self.logger.error(f'something wrong on remote: {ex!r}') def create(self, args: Dict, **kwargs) -> Optional[str]: """ Create a remote pea/pod @@ -161,7 +161,7 @@ def create(self, args: Dict, **kwargs) -> Optional[str]: return r.json()[f'{self.kind}_id'] self.logger.error(f'couldn\'t create pod with remote jinad {r.json()}') except requests.exceptions.RequestException as ex: - self.logger.error(f'couldn\'t create pod with remote jinad {repr(ex)}') + self.logger.error(f'couldn\'t create pod with remote jinad {ex!r}') async def logstream(self, remote_id: 'str', log_id: 'str'): """ websocket log stream from remote pea/pod @@ -203,7 +203,7 @@ async def logstream(self, remote_id: 'str', log_id: 'str'): except websockets.exceptions.ConnectionClosedOK: self.logger.error(f'🌏 Client got disconnected from server') except websockets.exceptions.WebSocketException as e: - self.logger.error(f'🌏 Got following error while streaming logs via websocket {repr(e)}') + self.logger.error(f'🌏 Got following error while streaming logs via websocket {e!r}') except asyncio.CancelledError: self.logger.info(f'🌏 Logging task cancelled successfully') finally: @@ -226,7 +226,7 @@ def delete(self, remote_id: 'str', **kwargs) -> bool: r = requests.delete(url=url, timeout=self.timeout) return r.status_code == requests.codes.ok except requests.exceptions.RequestException as ex: - self.logger.error(f'couldn\'t connect with remote jinad url {repr(ex)}') + self.logger.error(f'couldn\'t connect with 
remote jinad url {ex!r}') return False diff --git a/jina/peapods/zmq/__init__.py b/jina/peapods/zmq/__init__.py index 4ee86b0eaafe6..54121edb05b3d 100644 --- a/jina/peapods/zmq/__init__.py +++ b/jina/peapods/zmq/__init__.py @@ -272,7 +272,7 @@ async def send_message(self, msg: 'Message', sleep: float = 0, **kwargs): self.bytes_sent += num_bytes self.msg_sent += 1 except (asyncio.CancelledError, TypeError) as ex: - self.logger.error(f'sending message error: {repr(ex)}, gateway cancelled?') + self.logger.error(f'sending message error: {ex!r}, gateway cancelled?') async def recv_message(self, callback: Callable[['Message'], Union['Message', 'Request']] = None) -> 'Message': try: @@ -282,7 +282,7 @@ async def recv_message(self, callback: Callable[['Message'], Union['Message', 'R if callback: return callback(msg) except (asyncio.CancelledError, TypeError) as ex: - self.logger.error(f'receiving message error: {repr(ex)}, gateway cancelled?') + self.logger.error(f'receiving message error: {ex!r}, gateway cancelled?') def __enter__(self): time.sleep(.2) # sleep a bit until handshake is done diff --git a/jina/types/message/__init__.py b/jina/types/message/__init__.py index cb50796e70d92..a737f8d2e1a45 100644 --- a/jina/types/message/__init__.py +++ b/jina/types/message/__init__.py @@ -205,7 +205,7 @@ def _compress(self, data: bytes) -> bytes: self.envelope.compression.algorithm = 'NONE' except Exception as ex: default_logger.error( - f'compression={str(ctag)} failed, fallback to compression="NONE". reason: {repr(ex)}') + f'compression={str(ctag)} failed, fallback to compression="NONE". 
reason: {ex!r}') self.envelope.compression.algorithm = 'NONE' return data diff --git a/tests/integration/eval_flow/test_flow_eval_pod.py b/tests/integration/eval_flow/test_flow_eval_pod.py index d11ee0f96646e..614866c736aa2 100644 --- a/tests/integration/eval_flow/test_flow_eval_pod.py +++ b/tests/integration/eval_flow/test_flow_eval_pod.py @@ -40,16 +40,18 @@ def validate(ids, expect): @pytest.mark.parametrize('inspect', params) -def test_flow1(inspect): - f = Flow(inspect=inspect).add() +@pytest.mark.parametrize('restful', [False, True]) +def test_flow1(inspect, restful): + f = Flow(restful=restful, inspect=inspect).add() with f: f.index(docs) @pytest.mark.parametrize('inspect', params) -def test_flow2(inspect): - f = Flow(inspect=inspect).add().inspect(uses='DummyEvaluator1') +@pytest.mark.parametrize('restful', [False, True]) +def test_flow2(inspect, restful): + f = Flow(restful=restful, inspect=inspect).add().inspect(uses='DummyEvaluator1') with f: f.index(docs) @@ -57,10 +59,16 @@ def test_flow2(inspect): validate([1], expect=f.args.inspect.is_keep) +# TODO(Deepankar): Gets stuck when `restful: True` - issues with `needs='gateway'` @pytest.mark.parametrize('inspect', params) -def test_flow3(inspect): - f = Flow(inspect=inspect).add(name='p1').inspect(uses='DummyEvaluator1') \ - .add(name='p2', needs='gateway').needs(['p1', 'p2']).inspect(uses='DummyEvaluator2') +@pytest.mark.parametrize('restful', [False]) +def test_flow3(inspect, restful): + f = (Flow(restful=restful, inspect=inspect) + .add(name='p1') + .inspect(uses='DummyEvaluator1') + .add(name='p2', needs='gateway') + .needs(['p1', 'p2']) + .inspect(uses='DummyEvaluator2')) with f: f.index(docs) @@ -69,10 +77,16 @@ def test_flow3(inspect): @pytest.mark.parametrize('inspect', params) -def test_flow5(inspect): - f = Flow(inspect=inspect).add().inspect(uses='DummyEvaluator1').add().inspect( - uses='DummyEvaluator2').add().inspect( - uses='DummyEvaluator3').plot(build=True) 
+@pytest.mark.parametrize('restful', [False, True]) +def test_flow5(inspect, restful): + f = (Flow(restful=restful, inspect=inspect) + .add() + .inspect(uses='DummyEvaluator1') + .add() + .inspect(uses='DummyEvaluator2') + .add() + .inspect(uses='DummyEvaluator3') + .plot(build=True)) with f: f.index(docs) diff --git a/tests/integration/evaluation/flow-evaluate-from-file-parallel.yml b/tests/integration/evaluation/flow-evaluate-from-file-parallel.yml index 1dfdc987878bf..5f41c2e66bb49 100644 --- a/tests/integration/evaluation/flow-evaluate-from-file-parallel.yml +++ b/tests/integration/evaluation/flow-evaluate-from-file-parallel.yml @@ -1,4 +1,6 @@ !Flow +with: + restful: $RESTFUL pods: gt_indexer: uses: yaml/index-gt.yml diff --git a/tests/integration/evaluation/flow-evaluate-from-file.yml b/tests/integration/evaluation/flow-evaluate-from-file.yml index ccefa81b609fd..72423f5b2008f 100644 --- a/tests/integration/evaluation/flow-evaluate-from-file.yml +++ b/tests/integration/evaluation/flow-evaluate-from-file.yml @@ -1,4 +1,6 @@ !Flow +with: + restful: $RESTFUL pods: evaluate_from_file: uses: yaml/evaluate-from-file.yml diff --git a/tests/integration/evaluation/flow-index-gt-parallel.yml b/tests/integration/evaluation/flow-index-gt-parallel.yml index 9f35ea5ef5e7b..af7a6ac150341 100644 --- a/tests/integration/evaluation/flow-index-gt-parallel.yml +++ b/tests/integration/evaluation/flow-index-gt-parallel.yml @@ -1,4 +1,6 @@ !Flow +with: + restful: $RESTFUL pods: gt_indexer: uses: yaml/index-gt.yml diff --git a/tests/integration/evaluation/flow-index-gt.yml b/tests/integration/evaluation/flow-index-gt.yml index ea93ad9e3dcac..4e8427b6c85a8 100644 --- a/tests/integration/evaluation/flow-index-gt.yml +++ b/tests/integration/evaluation/flow-index-gt.yml @@ -1,4 +1,6 @@ !Flow +with: + restful: $RESTFUL pods: gt_indexer: uses: yaml/index-gt.yml diff --git a/tests/integration/evaluation/flow-parallel-evaluate-from-file-parallel.yml 
b/tests/integration/evaluation/flow-parallel-evaluate-from-file-parallel.yml index 2d6b32e9f6987..708ca42e5fb3e 100644 --- a/tests/integration/evaluation/flow-parallel-evaluate-from-file-parallel.yml +++ b/tests/integration/evaluation/flow-parallel-evaluate-from-file-parallel.yml @@ -1,4 +1,6 @@ !Flow +with: + restful: $RESTFUL pods: gt_indexer: uses: yaml/index-gt.yml diff --git a/tests/integration/evaluation/test_evaluation_from_file.py b/tests/integration/evaluation/test_evaluation_from_file.py index ff2cd5a62821a..d978b007e945d 100644 --- a/tests/integration/evaluation/test_evaluation_from_file.py +++ b/tests/integration/evaluation/test_evaluation_from_file.py @@ -67,11 +67,16 @@ def random_workspace(tmpdir): ('flow-index-gt-parallel.yml', 'flow-evaluate-from-file-parallel.yml'), ('flow-index-gt-parallel.yml', 'flow-parallel-evaluate-from-file-parallel.yml') ]) -def test_evaluation_from_file(random_workspace, index_groundtruth, evaluate_docs, index_yaml, search_yaml, mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_evaluation_from_file(random_workspace, index_groundtruth, evaluate_docs, index_yaml, search_yaml, + restful, mocker, monkeypatch): + monkeypatch.setenv("RESTFUL", restful) + with Flow.load_config(index_yaml) as index_gt_flow: index_gt_flow.index(input_fn=index_groundtruth, batch_size=10) m = mocker.Mock() + def validate_evaluation_response(resp): m() assert len(resp.docs) == 97 diff --git a/tests/integration/gateway_non_blocking/flow.yml b/tests/integration/gateway_non_blocking/flow.yml index 1e9943dbe01fd..ecc357f05020f 100644 --- a/tests/integration/gateway_non_blocking/flow.yml +++ b/tests/integration/gateway_non_blocking/flow.yml @@ -1,4 +1,6 @@ !Flow +with: + restful: $RESTFUL pods: pod: uses: yaml/fast_slow.yml diff --git a/tests/integration/gateway_non_blocking/test_gateway_non_blocking.py b/tests/integration/gateway_non_blocking/test_gateway_non_blocking.py index d0d4e9b0f91b2..488479bfb5708 100644 --- 
a/tests/integration/gateway_non_blocking/test_gateway_non_blocking.py +++ b/tests/integration/gateway_non_blocking/test_gateway_non_blocking.py @@ -7,8 +7,10 @@ @pytest.mark.parametrize('parallel, expected_response', [(1, ['slow', 'fast']), (2, ['fast', 'slow'])]) -def test_non_blocking_gateway(parallel, expected_response, mocker): - os.environ['JINA_NON_BLOCKING_PARALLEL'] = str(parallel) +@pytest.mark.parametrize('restful', [False, True]) +def test_non_blocking_gateway(parallel, expected_response, restful, mocker, monkeypatch): + monkeypatch.setenv("JINA_NON_BLOCKING_PARALLEL", str(parallel)) + monkeypatch.setenv("RESTFUL", restful) response = [] def fill_responses(resp): diff --git a/tests/integration/high_order_matches/test_adjacency.py b/tests/integration/high_order_matches/test_adjacency.py index d41eb09e3482c..57af287cd2597 100644 --- a/tests/integration/high_order_matches/test_adjacency.py +++ b/tests/integration/high_order_matches/test_adjacency.py @@ -1,6 +1,8 @@ import os import shutil +import pytest + from jina.flow import Flow from tests import random_docs @@ -17,7 +19,8 @@ # f.search(random_docs(1, chunks_per_doc=0, embed_dim=2), on_done=validate) -def test_high_order_matches_integrated(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_high_order_matches_integrated(mocker, restful): def validate(req): assert len(req.docs) == 1 assert len(req.docs[0].matches) == 5 @@ -28,7 +31,7 @@ def validate(req): response_mock = mocker.Mock(wrap=validate) # this is equivalent to the last test but with simplified YAML spec. 
- f = Flow(callback_on='body').add(uses=os.path.join(cur_dir, 'test-adjacency-integrated.yml')) + f = Flow(restful=restful, callback_on='body').add(uses=os.path.join(cur_dir, 'test-adjacency-integrated.yml')) with f: f.index(random_docs(100, chunks_per_doc=0, embed_dim=2)) diff --git a/tests/integration/incremental_indexing/test_incremental_indexing.py b/tests/integration/incremental_indexing/test_incremental_indexing.py index 4fc0cb1aea682..59df762a58c71 100644 --- a/tests/integration/incremental_indexing/test_incremental_indexing.py +++ b/tests/integration/incremental_indexing/test_incremental_indexing.py @@ -1,5 +1,7 @@ import os +import pytest + from jina.clients import Client from jina.executors import BaseExecutor from jina.executors.indexers.keyvalue import BinaryPbIndexer @@ -13,11 +15,12 @@ print(random_workspace) -def test_incremental_indexing_sequential_indexers(random_workspace): +@pytest.mark.parametrize('restful', [False, True]) +def test_incremental_indexing_sequential_indexers(random_workspace, restful): total_docs = 20 duplicate_docs, num_uniq_docs = get_duplicate_docs(num_docs=total_docs) - f = (Flow() + f = (Flow(restful=restful) .add(uses=os.path.join(cur_dir, 'uniq_vectorindexer.yml')) .add(uses=os.path.join(cur_dir, 'uniq_docindexer.yml'))) @@ -39,13 +42,14 @@ def test_incremental_indexing_sequential_indexers(random_workspace): assert doc_indexer._size == num_uniq_docs -def test_incremental_indexing_sequential_indexers_content_hash_same_content(random_workspace): +@pytest.mark.parametrize('restful', [False, True]) +def test_incremental_indexing_sequential_indexers_content_hash_same_content(random_workspace, restful): total_docs = 20 duplicate_docs, _ = get_duplicate_docs(num_docs=total_docs, same_content=True) # because they all have the same content num_uniq_docs = 1 - f = (Flow() + f = (Flow(restful=restful) .add(uses=os.path.join(cur_dir, 'uniq_vectorindexer_content_hash.yml')) .add(uses=os.path.join(cur_dir, 
'uniq_docindexer_content_hash.yml'))) @@ -67,13 +71,14 @@ def test_incremental_indexing_sequential_indexers_content_hash_same_content(rand assert doc_indexer._size == num_uniq_docs -def test_incremental_indexing_sequential_indexers_content_hash(random_workspace): +@pytest.mark.parametrize('restful', [False, True]) +def test_incremental_indexing_sequential_indexers_content_hash(random_workspace, restful): total_docs = 20 duplicate_docs, _ = get_duplicate_docs(num_docs=total_docs, same_content=False) # because the content is % 2 num_uniq_docs = 10 - f = (Flow() + f = (Flow(restful=restful) .add(uses=os.path.join(cur_dir, 'uniq_vectorindexer_content_hash.yml')) .add(uses=os.path.join(cur_dir, 'uniq_docindexer_content_hash.yml'))) @@ -95,11 +100,13 @@ def test_incremental_indexing_sequential_indexers_content_hash(random_workspace) assert doc_indexer._size == num_uniq_docs -def test_incremental_indexing_parallel_indexers(random_workspace): +# TODO(Deepankar): Gets stuck when `restful: True` - issues with `needs='gateway'` +@pytest.mark.parametrize('restful', [False]) +def test_incremental_indexing_parallel_indexers(random_workspace, restful): total_docs = 1000 duplicate_docs, num_uniq_docs = get_duplicate_docs(num_docs=total_docs) - f = (Flow() + f = (Flow(restful=restful) .add(uses=os.path.join(cur_dir, 'uniq_vectorindexer.yml'), name='inc_vec') .add(uses=os.path.join(cur_dir, 'uniq_docindexer.yml'), @@ -121,13 +128,14 @@ def test_incremental_indexing_parallel_indexers(random_workspace): assert doc_indexer._size == num_uniq_docs -def test_incremental_indexing_sequential_indexers_with_shards(random_workspace): +@pytest.mark.parametrize('restful', [False, True]) +def test_incremental_indexing_sequential_indexers_with_shards(random_workspace, restful): total_docs = 1000 duplicate_docs, num_uniq_docs = get_duplicate_docs(num_docs=total_docs) num_shards = 4 # can't use plain _unique in uses_before because workspace will conflict with other - f = (Flow() + f = 
(Flow(restful=restful) .add(uses=os.path.join(cur_dir, 'vectorindexer.yml'), uses_before=os.path.join(cur_dir, '_unique_vec.yml'), shards=num_shards, @@ -160,27 +168,28 @@ def test_incremental_indexing_sequential_indexers_with_shards(random_workspace): assert doc_idx_size == num_uniq_docs -def test_incremental_indexing_parallel_indexers_with_shards(random_workspace): +# TODO(Deepankar): Gets stuck when `restful: True` - issues with `needs='gateway'` +@pytest.mark.parametrize('restful', [False]) +def test_incremental_indexing_parallel_indexers_with_shards(random_workspace, restful): total_docs = 1000 duplicate_docs, num_uniq_docs = get_duplicate_docs(num_docs=total_docs) num_shards = 4 # can't use plain _unique in uses_before because workspace will conflict with other - f = (Flow() - .add(uses=os.path.join(cur_dir, 'vectorindexer.yml'), - uses_before=os.path.join(cur_dir, '_unique_vec.yml'), - shards=num_shards, - name='inc_vec', - separated_workspace=True) - .add(uses=os.path.join(cur_dir, 'docindexer.yml'), - uses_before=os.path.join(cur_dir, '_unique_doc.yml'), - shards=num_shards, - name='inc_doc', - needs=['gateway'], - separated_workspace=True) - .add( - needs=['inc_vec', 'inc_doc'])) + f = (Flow(restful=restful) + .add(uses=os.path.join(cur_dir, 'vectorindexer.yml'), + uses_before=os.path.join(cur_dir, '_unique_vec.yml'), + shards=num_shards, + name='inc_vec', + separated_workspace=True) + .add(uses=os.path.join(cur_dir, 'docindexer.yml'), + uses_before=os.path.join(cur_dir, '_unique_doc.yml'), + shards=num_shards, + name='inc_doc', + needs=['gateway'], + separated_workspace=True) + .add(needs=['inc_vec', 'inc_doc'])) with f: f.index(duplicate_docs[:500]) diff --git a/tests/integration/incremental_indexing/test_unique_indexing.py b/tests/integration/incremental_indexing/test_unique_indexing.py index b64e2482e4841..b12e6e0c2f19c 100644 --- a/tests/integration/incremental_indexing/test_unique_indexing.py +++ 
b/tests/integration/incremental_indexing/test_unique_indexing.py @@ -1,5 +1,7 @@ import os +import pytest + from jina.executors import BaseExecutor from jina.executors.indexers.keyvalue import BinaryPbIndexer from jina.executors.indexers.vector import NumpyIndexer @@ -12,11 +14,12 @@ print(random_workspace) -def test_unique_indexing_vecindexers(random_workspace): +@pytest.mark.parametrize('restful', [False, True]) +def test_unique_indexing_vecindexers(random_workspace, restful): total_docs = 10 duplicate_docs, num_uniq_docs = get_duplicate_docs(num_docs=total_docs) - f = (Flow() + f = (Flow(restful=restful) .add(uses=os.path.join(cur_dir, 'uniq_vectorindexer.yml'), name='vec_idx')) with f: @@ -27,11 +30,12 @@ def test_unique_indexing_vecindexers(random_workspace): assert vector_indexer.size == num_uniq_docs -def test_unique_indexing_docindexers(random_workspace): +@pytest.mark.parametrize('restful', [False, True]) +def test_unique_indexing_docindexers(random_workspace, restful): total_docs = 10 duplicate_docs, num_uniq_docs = get_duplicate_docs(num_docs=total_docs) - f = (Flow() + f = (Flow(restful=restful) .add(uses=os.path.join(cur_dir, 'uniq_docindexer.yml'), shards=1)) with f: @@ -42,12 +46,13 @@ def test_unique_indexing_docindexers(random_workspace): assert doc_indexer.size == num_uniq_docs -def test_unique_indexing_vecindexers_before(random_workspace): +@pytest.mark.parametrize('restful', [False, True]) +def test_unique_indexing_vecindexers_before(random_workspace, restful): total_docs = 10 duplicate_docs, num_uniq_docs = get_duplicate_docs(num_docs=total_docs) # can't use plain _unique because workspace will conflict with other tests - f = (Flow() + f = (Flow(restful=restful) .add(uses=os.path.join(cur_dir, 'vectorindexer.yml'), uses_before=os.path.join(cur_dir, '_unique_vec.yml'))) @@ -59,12 +64,13 @@ def test_unique_indexing_vecindexers_before(random_workspace): assert vector_indexer.size == num_uniq_docs -def 
test_unique_indexing_docindexers_before(random_workspace): +@pytest.mark.parametrize('restful', [False, True]) +def test_unique_indexing_docindexers_before(random_workspace, restful): total_docs = 10 duplicate_docs, num_uniq_docs = get_duplicate_docs(num_docs=total_docs) # can't use plain _unique because workspace will conflict with other tests - f = (Flow() + f = (Flow(restful=restful) .add(uses=os.path.join(cur_dir, 'docindexer.yml'), uses_before=os.path.join(cur_dir, '_unique_doc.yml'))) diff --git a/tests/integration/level_depth/flow-index.yml b/tests/integration/level_depth/flow-index.yml index 0e210020edab4..bd8f9cade138b 100644 --- a/tests/integration/level_depth/flow-index.yml +++ b/tests/integration/level_depth/flow-index.yml @@ -1,4 +1,6 @@ !Flow +with: + restful: $RESTFUL pods: segmenter: uses: yaml/segmenter.yml diff --git a/tests/integration/level_depth/flow-query.yml b/tests/integration/level_depth/flow-query.yml index 7a8ae27061ae0..b51dd91462b84 100644 --- a/tests/integration/level_depth/flow-query.yml +++ b/tests/integration/level_depth/flow-query.yml @@ -1,4 +1,6 @@ !Flow +with: + restful: $RESTFUL pods: encoder: uses: yaml/encoder.yml diff --git a/tests/integration/level_depth/test_search_different_depths.py b/tests/integration/level_depth/test_search_different_depths.py index db89b5deeaca5..ded81d26544c5 100644 --- a/tests/integration/level_depth/test_search_different_depths.py +++ b/tests/integration/level_depth/test_search_different_depths.py @@ -1,10 +1,16 @@ import os +import pytest + from jina.flow import Flow -def test_index_depth_0_search_depth_1(tmpdir, mocker): - os.environ['JINA_TEST_LEVEL_DEPTH_WORKSPACE'] = str(tmpdir) +# TODO(Deepankar): Gets stuck when `restful: True` - issues with `needs='gateway'` +@pytest.mark.parametrize('restful', [False]) +def test_index_depth_0_search_depth_1(tmpdir, mocker, monkeypatch, restful): + monkeypatch.setenv("RESTFUL", restful) + monkeypatch.setenv("JINA_TEST_LEVEL_DEPTH_WORKSPACE", str(tmpdir)) + 
index_data = [ 'I am chunk 0 of doc 1, I am chunk 1 of doc 1, I am chunk 2 of doc 1', 'I am chunk 0 of doc 2, I am chunk 1 of doc 2', @@ -55,5 +61,4 @@ def validate_granularity_1(resp): callback_on='body', ) - del os.environ['JINA_TEST_LEVEL_DEPTH_WORKSPACE'] mock.assert_called_once() diff --git a/tests/integration/mime/test_mime.py b/tests/integration/mime/test_mime.py index 23df4dbe49371..eedc5955368da 100644 --- a/tests/integration/mime/test_mime.py +++ b/tests/integration/mime/test_mime.py @@ -1,6 +1,8 @@ import glob import os +import pytest + from jina.flow import Flow num_docs = 100 @@ -30,31 +32,35 @@ def input_fn3(): yield g -def test_dummy_seg(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_dummy_seg(mocker, restful): response_mock = mocker.Mock() - f = Flow().add(uses='- !Buffer2URI | {mimetype: png}') + f = Flow(restful=restful).add(uses='- !Buffer2URI | {mimetype: png}') with f: f.index(input_fn=input_fn, on_done=response_mock) response_mock.assert_called() response_mock_2 = mocker.Mock() - f = Flow().add(uses='- !Buffer2URI | {mimetype: png, base64: true}') + f = Flow(restful=restful).add(uses='- !Buffer2URI | {mimetype: png, base64: true}') with f: f.index(input_fn=input_fn, on_done=response_mock_2) response_mock_2.assert_called() -def test_any_file(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_any_file(mocker, restful): response_mock = mocker.Mock() - f = Flow().add(uses='- !URI2DataURI | {base64: true}') + f = Flow(restful=restful).add(uses='- !URI2DataURI | {base64: true}') with f: f.index(input_fn=input_fn2, on_done=response_mock) response_mock.assert_called() -def test_aba(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_aba(mocker, restful): response_mock = mocker.Mock() - f = (Flow().add(uses='- !Buffer2URI | {mimetype: png}') + f = (Flow(restful=restful) + .add(uses='- !Buffer2URI | {mimetype: png}') .add(uses='- !URI2Buffer {}') .add(uses='- !Buffer2URI | {mimetype: png}')) 
@@ -63,9 +69,11 @@ def test_aba(mocker): response_mock.assert_called() -def test_pathURI2Buffer(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_pathURI2Buffer(mocker, restful): response_mock = mocker.Mock() - f = (Flow().add(uses='- !URI2Buffer {}') + f = (Flow(restful=restful) + .add(uses='- !URI2Buffer {}') .add(uses='- !Buffer2URI {}')) with f: @@ -73,18 +81,20 @@ def test_pathURI2Buffer(mocker): response_mock.assert_called() -def test_text2datauri(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_text2datauri(mocker, restful): response_mock = mocker.Mock() - f = (Flow().add(uses='- !Text2URI {}')) + f = (Flow(restful=restful).add(uses='- !Text2URI {}')) with f: f.index_lines(lines=['abc', '123', 'hello, world'], on_done=response_mock) response_mock.assert_called() -def test_gateway_dataui(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_gateway_dataui(mocker, restful): response_mock = mocker.Mock() - f = (Flow().add()) + f = (Flow(restful=restful).add()) with f: f.index_lines(lines=['abc', '123', 'hello, world'], on_done=response_mock) diff --git a/tests/integration/mime/test_segmenter.py b/tests/integration/mime/test_segmenter.py index da728ac4680d3..662c766fef2c7 100644 --- a/tests/integration/mime/test_segmenter.py +++ b/tests/integration/mime/test_segmenter.py @@ -1,5 +1,7 @@ import os +import pytest + from jina.executors.crafters import BaseSegmenter from jina.flow import Flow from tests import random_docs @@ -22,25 +24,28 @@ def validate(req): return validate -def test_dummy_seg(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_dummy_seg(mocker, restful): mock = mocker.Mock() - f = Flow().add(uses='DummySegment') + f = Flow(restful=restful).add(uses='DummySegment') with f: f.index(input_fn=random_docs(10, chunks_per_doc=0), on_done=validate_factory(mock)) mock.assert_called_once() -def test_dummy_seg_random(mocker): +@pytest.mark.parametrize('restful', [False, True]) 
+def test_dummy_seg_random(mocker, restful): mock = mocker.Mock() - f = Flow().add(uses=os.path.join(cur_dir, '../../unit/yaml/dummy-seg-random.yml')) + f = Flow(restful=restful).add(uses=os.path.join(cur_dir, '../../unit/yaml/dummy-seg-random.yml')) with f: f.index(input_fn=random_docs(10, chunks_per_doc=0), on_done=validate_factory(mock)) mock.assert_called_once() -def test_dummy_seg_not_random(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_dummy_seg_not_random(mocker, restful): mock = mocker.Mock() - f = Flow().add(uses=os.path.join(cur_dir, '../../unit/yaml/dummy-seg-not-random.yml')) + f = Flow(restful=restful).add(uses=os.path.join(cur_dir, '../../unit/yaml/dummy-seg-not-random.yml')) with f: f.index(input_fn=random_docs(10, chunks_per_doc=0), on_done=validate_factory(mock)) mock.assert_called_once() diff --git a/tests/integration/multimodal/flow-embedding-multimodal-parallel.yml b/tests/integration/multimodal/flow-embedding-multimodal-parallel.yml index 31507423b1c6c..63d274b254428 100644 --- a/tests/integration/multimodal/flow-embedding-multimodal-parallel.yml +++ b/tests/integration/multimodal/flow-embedding-multimodal-parallel.yml @@ -1,4 +1,6 @@ !Flow +with: + restful: $RESTFUL pods: encoder_modality_1: uses_before: '- !FilterQL | {lookups: {modality__in: [modality1]}, traversal_paths: [c]}' diff --git a/tests/integration/multimodal/flow-multimodal-all-types-parallel.yml b/tests/integration/multimodal/flow-multimodal-all-types-parallel.yml index b5ae2fb3f4feb..96d04ca6d0429 100644 --- a/tests/integration/multimodal/flow-multimodal-all-types-parallel.yml +++ b/tests/integration/multimodal/flow-multimodal-all-types-parallel.yml @@ -1,4 +1,6 @@ !Flow +with: + restful: $RESTFUL pods: pass_modality_1: uses: '- !FilterQL | {lookups: {modality__in: [modality1]}, traversal_paths: [c]}' diff --git a/tests/integration/multimodal/test_multimodal_parallel.py b/tests/integration/multimodal/test_multimodal_parallel.py index 
423a0743d0365..e391130606e5f 100644 --- a/tests/integration/multimodal/test_multimodal_parallel.py +++ b/tests/integration/multimodal/test_multimodal_parallel.py @@ -37,7 +37,11 @@ def multimodal_documents(): return docs -def test_multimodal_embedding_parallel(multimodal_documents, mocker): +# TODO(Deepankar): Gets stuck when `restful: True` - issues with `needs='gateway'` +@pytest.mark.parametrize('restful', [False]) +def test_multimodal_embedding_parallel(multimodal_documents, mocker, monkeypatch, restful): + monkeypatch.setenv("RESTFUL", restful) + def validate_response(resp): mock() assert len(resp.index.docs) == NUM_DOCS @@ -83,7 +87,11 @@ def multimodal_all_types_documents(): return docs -def test_multimodal_all_types_parallel(multimodal_all_types_documents, mocker): +# TODO(Deepankar): Gets stuck when `restful: True` - issues with `needs='gateway'` +@pytest.mark.parametrize('restful', [False]) +def test_multimodal_all_types_parallel(multimodal_all_types_documents, mocker, monkeypatch, restful): + monkeypatch.setenv("RESTFUL", restful) + def validate_response(resp): mock() assert len(resp.index.docs) == NUM_DOCS diff --git a/tests/jinad/integration/distributed/helpers.py b/tests/jinad/integration/distributed/helpers.py index 05da4cf373512..631014f57836a 100644 --- a/tests/jinad/integration/distributed/helpers.py +++ b/tests/jinad/integration/distributed/helpers.py @@ -15,7 +15,7 @@ def invoke_requests(method: str, url, data=json.dumps(payload), headers=headers) return response.json() except requests.exceptions.RequestException as e: - print(f'got an exception while invoking request {repr(e)}') + print(f'got an exception while invoking request {e!r}') return def get_results(query: str, diff --git a/tests/jinad/integration/distributed/test_index_query/test_integration.py b/tests/jinad/integration/distributed/test_index_query/test_integration.py index fa20930c83336..69dcbaaff1da6 100644 --- 
a/tests/jinad/integration/distributed/test_index_query/test_integration.py +++ b/tests/jinad/integration/distributed/test_index_query/test_integration.py @@ -4,7 +4,6 @@ from ..helpers import create_flow, invoke_requests, get_results - cur_dir = os.path.dirname(os.path.abspath(__file__)) compose_yml = os.path.join(cur_dir, 'docker-compose.yml') flow_yml = os.path.join(cur_dir, 'flow.yml') diff --git a/tests/unit/clients/python/test_client.py b/tests/unit/clients/python/test_client.py index 09527f7c7b828..02483dd288c67 100644 --- a/tests/unit/clients/python/test_client.py +++ b/tests/unit/clients/python/test_client.py @@ -75,8 +75,9 @@ def test_gateway_index(flow_with_rest_api_enabled, test_img_1, test_img_2): assert resp['index']['docs'][0]['uri'] == test_img_1 -def test_mime_type(): - f = Flow().add(uses='- !URI2Buffer {}') +@pytest.mark.parametrize('restful', [False, True]) +def test_mime_type(restful): + f = Flow(restful=restful).add(uses='- !URI2Buffer {}') def validate_mime_type(req): for d in req.index.docs: diff --git a/tests/unit/clients/python/test_on_err.py b/tests/unit/clients/python/test_on_err.py index 11ad4133a5485..ba13a43025f0f 100644 --- a/tests/unit/clients/python/test_on_err.py +++ b/tests/unit/clients/python/test_on_err.py @@ -1,3 +1,4 @@ +import pytest import numpy as np from jina.excepts import BadClientCallback @@ -8,7 +9,8 @@ def validate(x): raise NotImplementedError -def test_client_on_error(): +@pytest.mark.parametrize('restful', [False, True]) +def test_client_on_error(restful): # In this particular test, when you write two tests in a row, you are testing the following case: # # You are testing exception in client's callback, not error in client's request generator @@ -19,7 +21,7 @@ def test_client_on_error(): def validate(x): raise NotImplementedError - with Flow().add() as f: + with Flow(restful=restful).add() as f: t = 0 try: f.index_ndarray(np.random.random([5, 4]), on_done=validate, continue_on_error=False) @@ -29,4 +31,3 @@ def 
validate(x): # now query the gateway again, make sure gateway's channel is still usable f.index_ndarray(np.random.random([5, 4]), on_done=validate, continue_on_error=True) assert t == 1 - diff --git a/tests/unit/flow/test_asyncflow.py b/tests/unit/flow/test_asyncflow.py index a36f767b2f23b..62a1f2506bd3b 100644 --- a/tests/unit/flow/test_asyncflow.py +++ b/tests/unit/flow/test_asyncflow.py @@ -4,6 +4,7 @@ import pytest from jina.flow.asyncio import AsyncFlow +from jina.types.request import Response from jina.logging.profile import TimeContext @@ -12,15 +13,19 @@ def validate(req): assert req.docs[0].blob.ndim == 1 +# TODO(Deepankar): with `restful: True` few of the asyncio tests are flaky. +# Though it runs fine locally, results in - `RuntimeError - Event loop closed` in CI (Disabling for now) + @pytest.mark.asyncio -async def test_run_async_flow(): - with AsyncFlow().add() as f: +@pytest.mark.parametrize('restful', [False]) +async def test_run_async_flow(restful): + with AsyncFlow(restful=restful).add() as f: await f.index_ndarray(np.random.random([5, 4]), on_done=validate) -async def run_async_flow_5s(): +async def run_async_flow_5s(restful): # WaitDriver pause 5s makes total roundtrip ~5s - with AsyncFlow().add(uses='- !WaitDriver {}') as f: + with AsyncFlow(restful=restful).add(uses='- !WaitDriver {}') as f: await f.index_ndarray(np.random.random([5, 4]), on_done=validate) @@ -31,29 +36,44 @@ async def sleep_print(): print('heavylifting done after 5s') -async def concurrent_main(): +async def concurrent_main(restful): # about 5s; but some dispatch cost, can't be just 5s, usually at <7s - await asyncio.gather(run_async_flow_5s(), sleep_print()) + await asyncio.gather(run_async_flow_5s(restful), sleep_print()) -async def sequential_main(): +async def sequential_main(restful): # about 10s; with some dispatch cost , usually at <12s - await run_async_flow_5s() + await run_async_flow_5s(restful) await sleep_print() @pytest.mark.asyncio -async def 
test_run_async_flow_other_task_sequential(): +@pytest.mark.parametrize('restful', [False]) +async def test_run_async_flow_other_task_sequential(restful): with TimeContext('sequential await') as t: - await sequential_main() + await sequential_main(restful) assert t.duration >= 10 @pytest.mark.asyncio -async def test_run_async_flow_other_task_concurrent(): +@pytest.mark.parametrize('restful', [False]) +async def test_run_async_flow_other_task_concurrent(restful): with TimeContext('concurrent await') as t: - await concurrent_main() + await concurrent_main(restful) # some dispatch cost, can't be just 5s, usually at 7~8s, but must <10s assert t.duration < 10 + + +@pytest.mark.asyncio +@pytest.mark.parametrize('return_results', [False, True]) +@pytest.mark.parametrize('restful', [False]) +async def test_return_results_async_flow(return_results, restful): + with AsyncFlow(restful=restful, return_results=return_results).add() as f: + r = await f.index_ndarray(np.random.random([10, 2])) + if return_results: + assert isinstance(r, list) + assert isinstance(r[0], Response) + else: + assert r is None diff --git a/tests/unit/flow/test_flow.py b/tests/unit/flow/test_flow.py index e64fa9f5b7b45..bf015e7fa2574 100644 --- a/tests/unit/flow/test_flow.py +++ b/tests/unit/flow/test_flow.py @@ -4,7 +4,7 @@ import pytest import requests -from jina import JINA_GLOBAL, Request, AsyncFlow +from jina import JINA_GLOBAL from jina.enums import SocketType from jina.executors import BaseExecutor from jina.flow import Flow @@ -73,15 +73,15 @@ def _validate(f): rm_files(['tmp.yml']) -def test_simple_flow(): +@pytest.mark.parametrize('restful', [False, True]) +def test_simple_flow(restful): bytes_gen = (b'aaa' for _ in range(10)) def bytes_fn(): for _ in range(100): yield b'aaa' - f = (Flow() - .add()) + f = Flow(restful=restful).add() with f: f.index(input_fn=bytes_gen) @@ -158,9 +158,10 @@ def test_flow_identical(): rm_files(['test2.yml']) -def test_flow_no_container(): 
+@pytest.mark.parametrize('restful', [False, True]) +def test_flow_no_container(restful): - f = (Flow() + f = (Flow(restful=restful) .add(name='dummyEncoder', uses=os.path.join(cur_dir, '../mwu-encoder/mwu_encoder.yml'))) with f: @@ -210,9 +211,10 @@ def test_flow_log_server(): timeout=5) -def test_shards(): - f = Flow().add(name='doc_pb', uses=os.path.join(cur_dir, '../yaml/test-docpb.yml'), parallel=3, - separated_workspace=True) +@pytest.mark.parametrize('restful', [False, True]) +def test_shards(restful): + f = (Flow(restful=restful) + .add(name='doc_pb', uses=os.path.join(cur_dir, '../yaml/test-docpb.yml'), parallel=3, separated_workspace=True)) with f: f.index(input_fn=random_docs(1000), random_doc_id=False) with f: @@ -278,7 +280,8 @@ def test_py_client(): def test_dry_run_with_two_pathways_diverging_at_gateway(): - f = (Flow().add(name='r2') + f = (Flow() + .add(name='r2') .add(name='r3', needs='gateway') .join(['r2', 'r3'])) @@ -301,7 +304,8 @@ def test_dry_run_with_two_pathways_diverging_at_gateway(): def test_dry_run_with_two_pathways_diverging_at_non_gateway(): - f = (Flow().add(name='r1') + f = (Flow() + .add(name='r1') .add(name='r2') .add(name='r3', needs='r1') .join(['r2', 'r3'])) @@ -329,7 +333,8 @@ def test_dry_run_with_two_pathways_diverging_at_non_gateway(): def test_refactor_num_part(): - f = (Flow().add(name='r1', uses='_logforward', needs='gateway') + f = (Flow() + .add(name='r1', uses='_logforward', needs='gateway') .add(name='r2', uses='_logforward', needs='gateway') .join(['r1', 'r2'])) @@ -352,7 +357,8 @@ def test_refactor_num_part(): def test_refactor_num_part_proxy(): - f = (Flow().add(name='r1', uses='_logforward') + f = (Flow() + .add(name='r1', uses='_logforward') .add(name='r2', uses='_logforward', needs='r1') .add(name='r3', uses='_logforward', needs='r1') .join(['r2', 'r3'])) @@ -379,8 +385,10 @@ def test_refactor_num_part_proxy(): assert node.peas_args['peas'][0] == node.tail_args -def test_refactor_num_part_proxy_2(): - f = 
(Flow().add(name='r1', uses='_logforward') +@pytest.mark.parametrize('restful', [False, True]) +def test_refactor_num_part_proxy_2(restful): + f = (Flow(restful=restful) + .add(name='r1', uses='_logforward') .add(name='r2', uses='_logforward', needs='r1', parallel=2) .add(name='r3', uses='_logforward', needs='r1', parallel=3, polling='ALL') .needs(['r2', 'r3'])) @@ -389,21 +397,23 @@ def test_refactor_num_part_proxy_2(): f.index_lines(lines=['abbcs', 'efgh']) -def test_refactor_num_part_2(): - f = (Flow() +@pytest.mark.parametrize('restful', [False, True]) +def test_refactor_num_part_2(restful): + f = (Flow(restful=restful) .add(name='r1', uses='_logforward', needs='gateway', parallel=3, polling='ALL')) with f: f.index_lines(lines=['abbcs', 'efgh']) - f = (Flow() + f = (Flow(restful=restful) .add(name='r1', uses='_logforward', needs='gateway', parallel=3)) with f: f.index_lines(lines=['abbcs', 'efgh']) -def test_index_text_files(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_index_text_files(mocker, restful): def validate(req): assert len(req.docs) > 0 for d in req.docs: @@ -411,7 +421,8 @@ def validate(req): response_mock = mocker.Mock(wrap=validate) - f = (Flow(read_only=True).add(uses=os.path.join(cur_dir, '../yaml/datauriindex.yml'), timeout_ready=-1)) + f = (Flow(restful=restful, read_only=True) + .add(uses=os.path.join(cur_dir, '../yaml/datauriindex.yml'), timeout_ready=-1)) with f: f.index_files('*.py', on_done=response_mock, callback_on='body') @@ -420,14 +431,16 @@ def validate(req): response_mock.assert_called() -def test_flow_with_publish_driver(mocker): +# TODO(Deepankar): Gets stuck when `restful: True` - issues with `needs='gateway'` +@pytest.mark.parametrize('restful', [False]) +def test_flow_with_publish_driver(mocker, restful): def validate(req): for d in req.docs: assert d.embedding is not None response_mock = mocker.Mock(wrap=validate) - f = (Flow() + f = (Flow(restful=restful) .add(name='r2', uses='!OneHotTextEncoder') 
.add(name='r3', uses='!OneHotTextEncoder', needs='gateway') .join(needs=['r2', 'r3'])) @@ -438,7 +451,8 @@ def validate(req): response_mock.assert_called() -def test_flow_with_modalitys_simple(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_flow_with_modalitys_simple(mocker, restful): def validate(req): for d in req.index.docs: assert d.modality in ['mode1', 'mode2'] @@ -454,9 +468,10 @@ def input_fn(): response_mock = mocker.Mock(wrap=validate) - flow = Flow().add(name='chunk_seg', parallel=3). \ - add(name='encoder12', parallel=2, - uses='- !FilterQL | {lookups: {modality__in: [mode1, mode2]}, traversal_paths: [c]}') + flow = (Flow(restful=restful) + .add(name='chunk_seg', parallel=3) + .add(name='encoder12', parallel=2, + uses='- !FilterQL | {lookups: {modality__in: [mode1, mode2]}, traversal_paths: [c]}')) with flow: flow.index(input_fn=input_fn, on_done=response_mock) @@ -471,8 +486,10 @@ def test_flow_arguments_priorities(): assert f._pod_nodes['test'].args.port_expose == 12345 -def test_flow_arbitrary_needs(): - f = (Flow().add(name='p1').add(name='p2', needs='gateway') +@pytest.mark.parametrize('restful', [False]) +def test_flow_arbitrary_needs(restful): + f = (Flow(restful=restful) + .add(name='p1').add(name='p2', needs='gateway') .add(name='p3', needs='gateway') .add(name='p4', needs='gateway') .add(name='p5', needs='gateway') @@ -485,12 +502,15 @@ def test_flow_arbitrary_needs(): f.index_lines(['abc', 'def']) -def test_flow_needs_all(): - f = (Flow().add(name='p1', needs='gateway') +@pytest.mark.parametrize('restful', [False]) +def test_flow_needs_all(restful): + f = (Flow(restful=restful) + .add(name='p1', needs='gateway') .needs_all(name='r1')) assert f._pod_nodes['r1'].needs == {'p1'} - f = (Flow().add(name='p1', needs='gateway') + f = (Flow(restful=restful) + .add(name='p1', needs='gateway') .add(name='p2', needs='gateway') .add(name='p3', needs='gateway') .needs(needs=['p1', 'p2'], name='r1') @@ -500,7 +520,8 @@ def 
test_flow_needs_all(): with f: f.index_ndarray(np.random.random([10, 10])) - f = (Flow().add(name='p1', needs='gateway') + f = (Flow(restful=restful) + .add(name='p1', needs='gateway') .add(name='p2', needs='gateway') .add(name='p3', needs='gateway') .needs(needs=['p1', 'p2'], name='r1') @@ -543,23 +564,12 @@ def __init__(self, *args, **kwargs): @pytest.mark.parametrize('return_results', [False, True]) -def test_return_results_sync_flow(return_results): - with Flow(return_results=return_results).add() as f: +@pytest.mark.parametrize('restful', [False, True]) +def test_return_results_sync_flow(return_results, restful): + with Flow(restful=restful, return_results=return_results).add() as f: r = f.index_ndarray(np.random.random([10, 2])) if return_results: assert isinstance(r, list) assert isinstance(r[0], Response) else: assert r is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize('return_results', [False, True]) -async def test_return_results_async_flow(return_results): - with AsyncFlow(return_results=return_results).add() as f: - r = await f.index_ndarray(np.random.random([10, 2])) - if return_results: - assert isinstance(r, list) - assert isinstance(r[0], Response) - else: - assert r is None diff --git a/tests/unit/flow/test_flow_before_after.py b/tests/unit/flow/test_flow_before_after.py index a315fa709c6ec..2cbbeca70ae5c 100644 --- a/tests/unit/flow/test_flow_before_after.py +++ b/tests/unit/flow/test_flow_before_after.py @@ -1,10 +1,13 @@ +import pytest + from jina.flow import Flow from tests import random_docs -def test_flow(): +@pytest.mark.parametrize('restful', [False, True]) +def test_flow(restful): docs = random_docs(10) - f = Flow().add(name='p1') + f = Flow(restful=restful).add(name='p1') with f: f.index(docs) @@ -13,9 +16,10 @@ def test_flow(): assert f.num_peas == 2 -def test_flow_before(): +@pytest.mark.parametrize('restful', [False, True]) +def test_flow_before(restful): docs = random_docs(10) - f = Flow().add(uses_before='_pass', name='p1') 
+ f = Flow(restful=restful).add(uses_before='_pass', name='p1') with f: f.index(docs) @@ -24,9 +28,10 @@ def test_flow_before(): assert f.num_peas == 3 -def test_flow_after(): +@pytest.mark.parametrize('restful', [False, True]) +def test_flow_after(restful): docs = random_docs(10) - f = Flow().add(uses_after='_pass', name='p1') + f = Flow(restful=restful).add(uses_after='_pass', name='p1') with f: f.index(docs) @@ -35,9 +40,10 @@ def test_flow_after(): assert f.num_peas == 3 -def test_flow_before_after(): +@pytest.mark.parametrize('restful', [False, True]) +def test_flow_before_after(restful): docs = random_docs(10) - f = Flow().add(uses_before='_pass', uses_after='_pass', name='p1') + f = Flow(restful=restful).add(uses_before='_pass', uses_after='_pass', name='p1') with f: f.index(docs) diff --git a/tests/unit/flow/test_flow_except.py b/tests/unit/flow/test_flow_except.py index 4a964360d5092..6c799aab838a0 100644 --- a/tests/unit/flow/test_flow_except.py +++ b/tests/unit/flow/test_flow_except.py @@ -1,3 +1,5 @@ +import pytest + import numpy as np from jina.executors.crafters import BaseCrafter @@ -10,13 +12,15 @@ def craft(self, *args, **kwargs): return 1 / 0 -def test_bad_flow(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_bad_flow(mocker, restful): def validate(req): bad_routes = [r for r in req.routes if r.status.code == jina_pb2.StatusProto.ERROR] assert req.status.code == jina_pb2.StatusProto.ERROR assert bad_routes[0].pod == 'r1' - f = (Flow().add(name='r1', uses='!BaseCrafter') + f = (Flow(restful=restful) + .add(name='r1', uses='!BaseCrafter') .add(name='r2', uses='!BaseEncoder') .add(name='r3', uses='!BaseEncoder')) @@ -32,14 +36,16 @@ def validate(req): on_error_mock_2.assert_called() -def test_bad_flow_customized(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_bad_flow_customized(mocker, restful): def validate(req): bad_routes = [r for r in req.routes if r.status.code == jina_pb2.StatusProto.ERROR] assert 
req.status.code == jina_pb2.StatusProto.ERROR assert bad_routes[0].pod == 'r2' assert bad_routes[0].status.exception.name == 'ZeroDivisionError' - f = (Flow().add(name='r1') + f = (Flow(restful=restful) + .add(name='r1') .add(name='r2', uses='!DummyCrafter') .add(name='r3', uses='!BaseEncoder')) @@ -58,7 +64,8 @@ def validate(req): on_error_mock_2.assert_called() -def test_except_with_parallel(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_except_with_parallel(mocker, restful): def validate(req): assert req.status.code == jina_pb2.StatusProto.ERROR err_routes = [r.status for r in req.routes if r.status.code == jina_pb2.StatusProto.ERROR] @@ -68,7 +75,8 @@ def validate(req): assert err_routes[0].exception.name == 'ZeroDivisionError' assert err_routes[1].exception.name == 'NotImplementedError' - f = (Flow().add(name='r1') + f = (Flow(restful=restful) + .add(name='r1') .add(name='r2', uses='!DummyCrafter', parallel=3) .add(name='r3', uses='!BaseEncoder')) @@ -87,7 +95,8 @@ def validate(req): on_error_mock_2.assert_called() -def test_on_error_callback(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_on_error_callback(mocker, restful): def validate1(): raise NotImplementedError @@ -97,7 +106,8 @@ def validate2(x, *args): badones = [r for r in x if r.status.code == jina_pb2.StatusProto.ERROR] assert badones[0].pod == 'r3' - f = (Flow().add(name='r1') + f = (Flow(restful=restful) + .add(name='r1') .add(name='r3', uses='!BaseEncoder')) on_error_mock = mocker.Mock(wrap=validate2) @@ -108,14 +118,16 @@ def validate2(x, *args): on_error_mock.assert_called() -def test_no_error_callback(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_no_error_callback(mocker, restful): def validate2(): raise NotImplementedError def validate1(x, *args): pass - f = (Flow().add(name='r1') + f = (Flow(restful=restful) + .add(name='r1') .add(name='r3')) response_mock = mocker.Mock(wrap=validate1) @@ -128,8 +140,9 @@ def 
validate1(x, *args): on_error_mock.assert_not_called() -def test_flow_on_callback(): - f = Flow().add() +@pytest.mark.parametrize('restful', [False, True]) +def test_flow_on_callback(restful): + f = Flow(restful=restful).add() hit = [] def f1(*args): @@ -150,13 +163,14 @@ def f3(*args): hit.clear() -def test_flow_on_error_callback(): +@pytest.mark.parametrize('restful', [False, True]) +def test_flow_on_error_callback(restful): class DummyCrafter(BaseCrafter): def craft(self, *args, **kwargs): raise NotImplementedError - f = Flow().add(uses='DummyCrafter') + f = Flow(restful=restful).add(uses='DummyCrafter') hit = [] def f1(*args): diff --git a/tests/unit/flow/test_flow_index.py b/tests/unit/flow/test_flow_index.py index 3f5c5e5ee1aa5..ed9c07013d406 100644 --- a/tests/unit/flow/test_flow_index.py +++ b/tests/unit/flow/test_flow_index.py @@ -22,7 +22,8 @@ def random_queries(num_docs, chunks_per_doc=5): @pytest.mark.skipif('GITHUB_WORKFLOW' in os.environ, reason='skip the network test on github workflow') -def test_shards_insufficient_data(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_shards_insufficient_data(mocker, restful): """THIS IS SUPER IMPORTANT FOR TESTING SHARDS IF THIS FAILED, DONT IGNORE IT, DEBUG IT @@ -31,6 +32,7 @@ def test_shards_insufficient_data(mocker): parallel = 4 mock = mocker.Mock() + def validate(req): mock() assert len(req.docs) == 1 @@ -41,10 +43,11 @@ def validate(req): assert d.weight assert d.meta_info == b'hello world' - f = Flow().add(name='doc_pb', - uses=os.path.join(cur_dir, '../yaml/test-docpb.yml'), - parallel=parallel, - separated_workspace=True) + f = (Flow(restful=restful) + .add(name='doc_pb', + uses=os.path.join(cur_dir, '../yaml/test-docpb.yml'), + parallel=parallel, + separated_workspace=True)) with f: f.index(input_fn=random_docs(index_docs)) @@ -52,15 +55,17 @@ def validate(req): with f: pass time.sleep(2) - f = Flow().add(name='doc_pb', - uses=os.path.join(cur_dir, '../yaml/test-docpb.yml'), - 
parallel=parallel, - separated_workspace=True, polling='all', uses_after='_merge_chunks') + f = (Flow(restful=restful) + .add(name='doc_pb', + uses=os.path.join(cur_dir, '../yaml/test-docpb.yml'), + parallel=parallel, + separated_workspace=True, + polling='all', + uses_after='_merge_chunks')) with f: f.search(input_fn=random_queries(1, index_docs), callback_on='body', - on_done=validate - ) + on_done=validate) time.sleep(2) rm_files(['test-docshard-tmp']) mock.assert_called_once() diff --git a/tests/unit/flow/test_flow_merge.py b/tests/unit/flow/test_flow_merge.py index 4a7f9d77e455a..1ea36966ce993 100644 --- a/tests/unit/flow/test_flow_merge.py +++ b/tests/unit/flow/test_flow_merge.py @@ -22,9 +22,12 @@ def validate(req): assert len(chunk_ids) == 80 +# TODO(Deepankar): Gets stuck when `restful: True` - issues with `needs='gateway'` @pytest.mark.skip('this should fail as explained in https://github.com/jina-ai/jina/pull/730') -def test_this_will_fail(mocker): - f = (Flow().add(name='a11', uses='DummySegment') +@pytest.mark.parametrize('restful', [False]) +def test_this_will_fail(mocker, restful): + f = (Flow(restful=restful) + .add(name='a11', uses='DummySegment') .add(name='a12', uses='DummySegment', needs='gateway') .add(name='r1', uses='_merge_chunks', needs=['a11', 'a12']) .add(name='a21', uses='DummySegment', needs='gateway') @@ -40,9 +43,11 @@ def test_this_will_fail(mocker): response_mock.assert_called() +# TODO(Deepankar): Gets stuck when `restful: True` - issues with `needs='gateway'` @pytest.mark.timeout(180) -def test_this_should_work(mocker): - f = (Flow() +@pytest.mark.parametrize('restful', [False]) +def test_this_should_work(mocker, restful): + f = (Flow(restful=restful) .add(name='a1') .add(name='a11', uses='DummySegment', needs='a1') .add(name='a12', uses='DummySegment', needs='a1') diff --git a/tests/unit/flow/test_flow_multimode.py b/tests/unit/flow/test_flow_multimode.py index aa38c2c0c7a31..5ff6ea09e2383 100644 --- 
a/tests/unit/flow/test_flow_multimode.py +++ b/tests/unit/flow/test_flow_multimode.py @@ -1,6 +1,7 @@ import os from typing import List, Dict +import pytest import numpy as np from jina.executors.crafters import BaseSegmenter @@ -35,7 +36,8 @@ def encode(self, data: str, *args, **kwargs) -> 'np.ndarray': return np.array(output) -def test_flow_with_modalities(tmpdir): +@pytest.mark.parametrize('restful', [False, True]) +def test_flow_with_modalities(tmpdir, restful): os.environ['JINA_TEST_FLOW_MULTIMODE_WORKSPACE'] = str(tmpdir) def input_fn(): @@ -53,12 +55,13 @@ def input_fn(): return [doc1, doc2, doc3] - flow = Flow().add(name='crafter', uses='!MockSegmenter'). \ - add(name='encoder1', uses=os.path.join(cur_dir, 'yaml/mockencoder-mode1.yml')). \ - add(name='indexer1', uses=os.path.join(cur_dir, 'yaml/numpy-indexer-1.yml'), needs=['encoder1']). \ - add(name='encoder2', uses=os.path.join(cur_dir, 'yaml/mockencoder-mode2.yml'), needs=['crafter']). \ - add(name='indexer2', uses=os.path.join(cur_dir, 'yaml/numpy-indexer-2.yml')). 
\ - join(['indexer1', 'indexer2']) + flow = (Flow(restful=restful) + .add(name='crafter', uses='!MockSegmenter') + .add(name='encoder1', uses=os.path.join(cur_dir, 'yaml/mockencoder-mode1.yml')) + .add(name='indexer1', uses=os.path.join(cur_dir, 'yaml/numpy-indexer-1.yml'), needs=['encoder1']) + .add(name='encoder2', uses=os.path.join(cur_dir, 'yaml/mockencoder-mode2.yml'), needs=['crafter']) + .add(name='indexer2', uses=os.path.join(cur_dir, 'yaml/numpy-indexer-2.yml')) + .join(['indexer1', 'indexer2'])) with flow: flow.index(input_fn=input_fn) diff --git a/tests/unit/flow/test_flow_skip.py b/tests/unit/flow/test_flow_skip.py index 2cf1473d016de..23d054bc5d055 100644 --- a/tests/unit/flow/test_flow_skip.py +++ b/tests/unit/flow/test_flow_skip.py @@ -1,3 +1,5 @@ +import pytest + from jina.enums import SkipOnErrorType from jina.executors.crafters import BaseCrafter from jina.flow import Flow @@ -9,7 +11,8 @@ def craft(self, *args, **kwargs): return 1 / 0 -def test_bad_flow_skip_handle(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_bad_flow_skip_handle(mocker, restful): def validate(req): bad_routes = [r for r in req.routes if r.status.code >= jina_pb2.StatusProto.ERROR] assert len(bad_routes) == 3 @@ -20,7 +23,8 @@ def validate(req): assert bad_routes[2].pod == 'r3' assert bad_routes[2].status.code == jina_pb2.StatusProto.ERROR_CHAINED - f = (Flow(skip_on_error=SkipOnErrorType.HANDLE).add(name='r1', uses='DummyCrafter') + f = (Flow(restful=restful, skip_on_error=SkipOnErrorType.HANDLE) + .add(name='r1', uses='DummyCrafter') .add(name='r2') .add(name='r3')) @@ -33,7 +37,8 @@ def validate(req): on_error_mock.assert_called() -def test_bad_flow_skip_handle_join(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_bad_flow_skip_handle_join(mocker, restful): """When skipmode is set to handle, reduce driver wont work anymore""" def validate(req): @@ -53,7 +58,8 @@ def validate(req): assert bad_routes[-1].status.code == 
jina_pb2.StatusProto.ERROR assert bad_routes[-1].status.exception.name == 'GatewayPartialMessage' - f = (Flow(skip_on_error=SkipOnErrorType.HANDLE).add(name='r1', uses='DummyCrafter') + f = (Flow(restful=restful, skip_on_error=SkipOnErrorType.HANDLE) + .add(name='r1', uses='DummyCrafter') .add(name='r2') .add(name='r3', needs='r1') .needs(['r3', 'r2'])) @@ -67,14 +73,16 @@ def validate(req): on_error_mock.assert_called() -def test_bad_flow_skip_exec(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_bad_flow_skip_exec(mocker, restful): def validate(req): bad_routes = [r for r in req.routes if r.status.code >= jina_pb2.StatusProto.ERROR] assert len(bad_routes) == 1 assert req.status.code == jina_pb2.StatusProto.ERROR assert bad_routes[0].pod == 'r1' - f = (Flow(skip_on_error=SkipOnErrorType.EXECUTOR).add(name='r1', uses='DummyCrafter') + f = (Flow(restful=restful, skip_on_error=SkipOnErrorType.EXECUTOR) + .add(name='r1', uses='DummyCrafter') .add(name='r2') .add(name='r3')) @@ -87,7 +95,8 @@ def validate(req): on_error_mock.assert_called() -def test_bad_flow_skip_exec_join(mocker): +@pytest.mark.parametrize('restful', [False, True]) +def test_bad_flow_skip_exec_join(mocker, restful): """Make sure the exception wont affect the gather/reduce ops""" def validate(req): @@ -96,7 +105,8 @@ def validate(req): assert req.status.code == jina_pb2.StatusProto.ERROR assert bad_routes[0].pod == 'r1' - f = (Flow(skip_on_error=SkipOnErrorType.EXECUTOR).add(name='r1', uses='DummyCrafter') + f = (Flow(restful=restful, skip_on_error=SkipOnErrorType.EXECUTOR) + .add(name='r1', uses='DummyCrafter') .add(name='r2') .add(name='r3', needs='r1') .needs(['r3', 'r2']))