From e9e40b5d2b524cf30554ad84b541b87147bd197b Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Mon, 5 Feb 2024 12:37:24 +0800 Subject: [PATCH] fix(server): clean the request resources after the response is consumed (#4481) * fix(server): clean the request resources after the response is consumed Signed-off-by: Frost Ming * fix: remove unnecessary attr params Signed-off-by: Frost Ming * fix: set attrs minimum version Signed-off-by: Frost Ming --------- Signed-off-by: Frost Ming --- pdm.lock | 2 +- pyproject.toml | 2 +- src/_bentoml_impl/server/app.py | 105 ++++++++++++++++---------------- 3 files changed, 55 insertions(+), 54 deletions(-) diff --git a/pdm.lock b/pdm.lock index 11a56921e33..2589b3ad77b 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "all", "testing", "io", "grpc-channelz", "tracing-otlp", "monitor-otlp", "io-image", "aws", "io-file", "docs", "tracing-zipkin", "grpc-reflection", "grpc", "tracing", "tooling", "tracing-jaeger", "io-pandas", "triton"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:8a33d3fb6916aa6c768b53e3e032b624a35ecddcb2955b80957c9f5838478917" +content_hash = "sha256:6ab46492b995785e8120dbcb1596223087d8667dab036d4a49d405aaef373517" [[package]] name = "aiohttp" diff --git a/pyproject.toml b/pyproject.toml index fd6a3d899f9..b09e084b76a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "Jinja2>=3.0.1", "PyYAML>=5.0", "aiohttp", - "attrs>=21.1.0", + "attrs>=22.2.0", "cattrs>=22.1.0,<23.2.0", "circus>=0.17.0,!=0.17.2", "click>=7.0", diff --git a/src/_bentoml_impl/server/app.py b/src/_bentoml_impl/server/app.py index db6dba919eb..1120aa4512b 100644 --- a/src/_bentoml_impl/server/app.py +++ b/src/_bentoml_impl/server/app.py @@ -394,6 +394,8 @@ async def inner_infer( )(value) async def api_endpoint(self, name: str, request: Request) -> Response: + from starlette.background import BackgroundTask + from _bentoml_sdk.io_models import ARGS from _bentoml_sdk.io_models import KWARGS from bentoml._internal.container import BentoMLContainer @@ -409,58 +411,57 @@ async def api_endpoint(self, name: str, request: Request) -> Response: method = self.service.apis[name] func = getattr(self._service_instance, name) ctx = self.service.context - try: - serde = ALL_SERDE[media_type]() - input_data = await method.input_spec.from_http_request(request, serde) - input_args: tuple[t.Any, ...] = () - input_params = {k: getattr(input_data, k) for k in input_data.model_fields} - if method.ctx_param is not None: - input_params[method.ctx_param] = ctx - if ARGS in input_params: - input_args = tuple(input_params.pop(ARGS)) - if KWARGS in input_params: - input_params.update(input_params.pop(KWARGS)) - - original_func = get_original_func(func) - - if method.batchable: - output = await self.batch_infer(name, input_args, input_params) - elif inspect.iscoroutinefunction(original_func): - output = await func(*input_args, **input_params) - elif inspect.isasyncgenfunction(original_func): - output = func(*input_args, **input_params) - elif inspect.isgeneratorfunction(original_func): - - async def inner() -> t.AsyncGenerator[t.Any, None]: - gen = func(*input_args, **input_params) - while True: - try: - yield await self._to_thread(next, gen) - except StopIteration: + serde = ALL_SERDE[media_type]() + input_data = await method.input_spec.from_http_request(request, serde) + input_args: tuple[t.Any, ...] = () + input_params = {k: getattr(input_data, k) for k in input_data.model_fields} + if method.ctx_param is not None: + input_params[method.ctx_param] = ctx + if ARGS in input_params: + input_args = tuple(input_params.pop(ARGS)) + if KWARGS in input_params: + input_params.update(input_params.pop(KWARGS)) + + original_func = get_original_func(func) + + if method.batchable: + output = await self.batch_infer(name, input_args, input_params) + elif inspect.iscoroutinefunction(original_func): + output = await func(*input_args, **input_params) + elif inspect.isasyncgenfunction(original_func): + output = func(*input_args, **input_params) + elif inspect.isgeneratorfunction(original_func): + + async def inner() -> t.AsyncGenerator[t.Any, None]: + gen = func(*input_args, **input_params) + while True: + try: + yield await self._to_thread(next, gen) + except StopIteration: + break + except RuntimeError as e: + if "StopIteration" in str(e): break - except RuntimeError as e: - if "StopIteration" in str(e): - break - raise + raise - output = inner() - else: - output = await self._to_thread(func, *input_args, **input_params) - - response = await method.output_spec.to_http_response(output, serde) - response.headers.update({"Server": f"BentoML Service/{self.service.name}"}) - - if method.ctx_param is not None: - response.status_code = ctx.response.status_code - response.headers.update(ctx.response.metadata) - set_cookies(response, ctx.response.cookies) - if trace_context.request_id is not None: - response.headers["X-BentoML-Request-ID"] = str(trace_context.request_id) - if ( - BentoMLContainer.http.response.trace_id.get() - and trace_context.trace_id is not None - ): - response.headers["X-BentoML-Trace-ID"] = str(trace_context.trace_id) - finally: - await request.close() + output = inner() + else: + output = await self._to_thread(func, *input_args, **input_params) + + response = await method.output_spec.to_http_response(output, serde) + response.headers.update({"Server": f"BentoML Service/{self.service.name}"}) + + if method.ctx_param is not None: + response.status_code = ctx.response.status_code + response.headers.update(ctx.response.metadata) + set_cookies(response, ctx.response.cookies) + if trace_context.request_id is not None: + response.headers["X-BentoML-Request-ID"] = str(trace_context.request_id) + if ( + BentoMLContainer.http.response.trace_id.get() + and trace_context.trace_id is not None + ): + response.headers["X-BentoML-Trace-ID"] = str(trace_context.trace_id) + # clean the request resources after the response is consumed. + response.background = BackgroundTask(request.close) return response