Skip to content

Commit

Permalink
feat: wait all floating tasks at gateway shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jul 12, 2022
1 parent d1b60d4 commit 86c6ef5
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
1 change: 1 addition & 0 deletions jina/serve/runtimes/gateway/grpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ async def async_teardown(self):
# usually async_cancel should already have been called, but then its a noop
# if the runtime is stopped without a sigterm (e.g. as a context manager, this can happen)
self._health_servicer.enter_graceful_shutdown()
await self.streamer.wait_floating_requests_end()
await self.async_cancel()
await self._connection_pool.close()

Expand Down
5 changes: 3 additions & 2 deletions jina/serve/runtimes/gateway/http/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_fastapi_app(

@app.on_event('shutdown')
async def _shutdown():
await streamer.wait_floating_requests_end()
await connection_pool.close()

openapi_tags = []
Expand Down Expand Up @@ -106,7 +107,6 @@ async def _gateway_health():
return {}

from docarray import DocumentArray

from jina.proto import jina_pb2
from jina.serve.executors import __dry_run_endpoint__
from jina.serve.runtimes.gateway.http.models import (
Expand Down Expand Up @@ -311,14 +311,15 @@ async def foo(body: JinaRequestModel):
from dataclasses import asdict

import strawberry
from docarray import DocumentArray
from docarray.document.strawberry_type import (
JSONScalar,
StrawberryDocument,
StrawberryDocumentInput,
)
from strawberry.fastapi import GraphQLRouter

from docarray import DocumentArray

async def get_docs_from_endpoint(
data, target_executor, parameters, exec_endpoint
):
Expand Down
2 changes: 1 addition & 1 deletion jina/serve/runtimes/gateway/websocket/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ async def _status():

@app.on_event('shutdown')
async def _shutdown():
await streamer.wait_floating_requests_end()
await connection_pool.close()

@app.websocket('/')
Expand Down Expand Up @@ -233,7 +234,6 @@ async def _get_singleton_result(request_iterator) -> Dict:
return request_dict

from docarray import DocumentArray

from jina.proto import jina_pb2
from jina.serve.executors import __dry_run_endpoint__
from jina.serve.runtimes.gateway.http.models import PROTO_TO_PYDANTIC_MODELS
Expand Down
14 changes: 13 additions & 1 deletion jina/serve/stream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
self._request_handler = request_handler
self._result_handler = result_handler
self._end_of_iter_handler = end_of_iter_handler
self.total_num_floating_tasks_alive = 0

async def stream(
self, request_iterator, context=None, *args
Expand Down Expand Up @@ -187,6 +188,12 @@ async def handle_floating_responses():

asyncio.create_task(iterate_requests())
handle_floating_task = asyncio.create_task(handle_floating_responses())
self.total_num_floating_tasks_alive += 1

def floating_task_done(*args):
self.total_num_floating_tasks_alive -= 1

handle_floating_task.add_done_callback(floating_task_done)

while not all_requests_handled.is_set():
future = await result_queue.get()
Expand All @@ -198,4 +205,9 @@ async def handle_floating_responses():
except self._EndOfStreaming:
pass

await handle_floating_task
async def wait_floating_requests_end(self):
"""
Await this coroutine to make sure that all the floating tasks that the request handler may bring are properly consumed
"""
while self.total_num_floating_tasks_alive > 0:
await asyncio.sleep(0)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from jina import DocumentArray, Executor, Flow, __default_endpoint__, requests

TIME_SLEEP_FLOATING = 5
TIME_SLEEP_FLOATING = 2


class FloatingTestExecutor(Executor):
Expand Down Expand Up @@ -87,7 +87,6 @@ def test_multiple_floating_points(tmpdir, protocol):
start_time = time.time()
ret = f.post(on=__default_endpoint__, inputs=DocumentArray.empty(1))
end_time = time.time()
print(f' reply took {end_time - start_time}s')
assert (
end_time - start_time
) < TIME_SLEEP_FLOATING # check that the response arrives before the
Expand Down Expand Up @@ -142,7 +141,6 @@ def test_complex_flow(tmpdir, protocol):
start_time = time.time()
ret = f.post(on=__default_endpoint__, inputs=DocumentArray.empty(1))
end_time = time.time()
print(f' reply took {end_time - start_time}s')
assert (
end_time - start_time
) < TIME_SLEEP_FLOATING # check that the response arrives before the
Expand Down

0 comments on commit 86c6ef5

Please sign in to comment.