From 5b3eefa58c1afe1de9f5cdad469df649c5fe81e0 Mon Sep 17 00:00:00 2001 From: Judah Rand <17158624+judahrand@users.noreply.github.com> Date: Wed, 16 Aug 2023 14:57:25 +0100 Subject: [PATCH] Don't use `async` code where there is not `await`ing --- src/bentoml/_internal/client/grpc.py | 10 ++-- src/bentoml/_internal/client/http.py | 8 +-- src/bentoml/_internal/io_descriptors/base.py | 17 +++++-- src/bentoml/_internal/io_descriptors/file.py | 6 +-- src/bentoml/_internal/io_descriptors/image.py | 6 +-- src/bentoml/_internal/io_descriptors/json.py | 6 +-- .../_internal/io_descriptors/multipart.py | 50 ++++++++----------- src/bentoml/_internal/io_descriptors/numpy.py | 16 +++--- .../_internal/io_descriptors/pandas.py | 16 +++--- src/bentoml/_internal/io_descriptors/text.py | 6 +-- src/bentoml/_internal/utils/formparser.py | 2 +- tests/unit/_internal/io/test_base.py | 6 +-- 12 files changed, 76 insertions(+), 73 deletions(-) diff --git a/src/bentoml/_internal/client/grpc.py b/src/bentoml/_internal/client/grpc.py index 75fed0ac465..947be5d6c75 100644 --- a/src/bentoml/_internal/client/grpc.py +++ b/src/bentoml/_internal/client/grpc.py @@ -306,9 +306,9 @@ async def _call( raise BentoMLException( f"'{_bentoml_api.name}' takes multiple inputs; all inputs must be passed as keyword arguments." ) - serialized_req = await _bentoml_api.input.to_proto(attrs) + serialized_req = _bentoml_api.input.sync_to_proto(attrs) else: - serialized_req = await _bentoml_api.input.to_proto(inp) + serialized_req = _bentoml_api.input.sync_to_proto(inp) # A call includes api_name and given proto_fields api_fn = {v: k for k, v in self._svc.apis.items()} @@ -328,7 +328,7 @@ async def _call( channel_kwargs=channel_kwargs, method_kwargs=kwargs, ) - return await _bentoml_api.output.from_proto( + return _bentoml_api.output.sync_from_proto( getattr(proto, proto.WhichOneof("content")) ) @@ -658,9 +658,9 @@ def _call( raise BentoMLException( f"'{_bentoml_api.name}' takes multiple inputs; all inputs must be passed as keyword arguments." ) - serialized_req = asyncio.run(_bentoml_api.input.to_proto(attrs)) + serialized_req = _bentoml_api.input.sync_to_proto(attrs) else: - serialized_req = asyncio.run(_bentoml_api.input.to_proto(inp)) + serialized_req = _bentoml_api.input.sync_to_proto(inp) # A call includes api_name and given proto_fields api_fn = {v: k for k, v in self._svc.apis.items()} diff --git a/src/bentoml/_internal/client/http.py b/src/bentoml/_internal/client/http.py index de7b31ed33a..0fddbd4b786 100644 --- a/src/bentoml/_internal/client/http.py +++ b/src/bentoml/_internal/client/http.py @@ -153,9 +153,9 @@ async def _call( raise BentoMLException( f"'{api.name}' takes multiple inputs; all inputs must be passed as keyword arguments." ) - fake_resp = await api.input.to_http_response(kwargs, None) + fake_resp = api.input.sync_to_http_response(kwargs, None) else: - fake_resp = await api.input.to_http_response(inp, None) + fake_resp = api.input.sync_to_http_response(inp, None) req_body = fake_resp.body resp = await self.client.post( @@ -299,9 +299,9 @@ def _call( raise BentoMLException( f"'{api.name}' takes multiple inputs; all inputs must be passed as keyword arguments." ) - fake_resp = asyncio.run(api.input.to_http_response(kwargs, None)) + fake_resp = api.input.sync_to_http_response(kwargs, None) else: - fake_resp = asyncio.run(api.input.to_http_response(inp, None)) + fake_resp = api.input.sync_to_http_response(inp, None) req_body = fake_resp.body resp = self.client.post( diff --git a/src/bentoml/_internal/io_descriptors/base.py b/src/bentoml/_internal/io_descriptors/base.py index 657cc47f06c..5d95981fb6c 100644 --- a/src/bentoml/_internal/io_descriptors/base.py +++ b/src/bentoml/_internal/io_descriptors/base.py @@ -175,18 +175,29 @@ def input_type(self) -> InputType: async def from_http_request(self, request: Request) -> IOType: raise NotImplementedError - @abstractmethod async def to_http_response( self, obj: IOType, ctx: Context | None = None ) -> Response: - raise NotImplementedError + return self.sync_to_http_response(obj, ctx) @abstractmethod - async def from_proto(self, field: t.Any) -> IOType: + def sync_to_http_response( + self, obj: IOType, ctx: Context | None = None + ) -> Response: raise NotImplementedError + async def from_proto(self, field: t.Any) -> IOType: + return self.sync_from_proto(field) + @abstractmethod + def sync_from_proto(self, field: t.Any) -> IOType: + raise NotImplementedError + async def to_proto(self, obj: IOType) -> t.Any: + return self.sync_to_proto(obj) + + @abstractmethod + def sync_to_proto(self, obj: IOType) -> t.Any: raise NotImplementedError def from_arrow(self, batch: pyarrow.RecordBatch) -> IOType: diff --git a/src/bentoml/_internal/io_descriptors/file.py b/src/bentoml/_internal/io_descriptors/file.py index d503762d571..2fa7a68a44a 100644 --- a/src/bentoml/_internal/io_descriptors/file.py +++ b/src/bentoml/_internal/io_descriptors/file.py @@ -214,7 +214,7 @@ def openapi_responses(self) -> OpenAPIResponse: "x-bentoml-io-descriptor": self.to_spec(), } - async def to_http_response(self, obj: FileType, ctx: Context | None = None): + def sync_to_http_response(self, obj: FileType, ctx: Context | None = None): if isinstance(obj, bytes): body = obj else: @@ -238,7 +238,7 @@ async def to_http_response(self, obj: FileType, ctx: Context | None = None): ) return res - async def to_proto(self, obj: FileType) -> pb.File: + def sync_to_proto(self, obj: FileType) -> pb.File: if isinstance(obj, bytes): body = obj else: @@ -270,7 +270,7 @@ async def to_proto_v1alpha1(self, obj: FileType) -> pb_v1alpha1.File: return pb_v1alpha1.File(kind=kind, content=body) - async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]: + def sync_from_proto(self, field: pb.File | bytes) -> FileLike[bytes]: raise NotImplementedError async def from_http_request(self, request: Request) -> FileLike[bytes]: diff --git a/src/bentoml/_internal/io_descriptors/image.py b/src/bentoml/_internal/io_descriptors/image.py index 482552d1464..c2f8ae2f541 100644 --- a/src/bentoml/_internal/io_descriptors/image.py +++ b/src/bentoml/_internal/io_descriptors/image.py @@ -378,7 +378,7 @@ async def from_http_request(self, request: Request) -> ImageType: except PIL.UnidentifiedImageError as err: raise BadInput(f"Failed to parse uploaded image file: {err}") from None - async def to_http_response( + def sync_to_http_response( self, obj: ImageType, ctx: Context | None = None ) -> Response: if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj): @@ -421,7 +421,7 @@ async def to_http_response( headers={"content-disposition": content_disposition}, ) - async def from_proto(self, field: pb.File | pb_v1alpha1.File | bytes) -> ImageType: + def sync_from_proto(self, field: pb.File | pb_v1alpha1.File | bytes) -> ImageType: if isinstance(field, bytes): content = field elif isinstance(field, pb_v1alpha1.File): @@ -477,7 +477,7 @@ async def to_proto_v1alpha1(self, obj: ImageType) -> pb_v1alpha1.File: return pb_v1alpha1.File(kind=kind, content=ret.getvalue()) - async def to_proto(self, obj: ImageType) -> pb.File: + def sync_to_proto(self, obj: ImageType) -> pb.File: if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj): image = PIL.Image.fromarray(obj, mode=self._pilmode) elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(obj): diff --git a/src/bentoml/_internal/io_descriptors/json.py b/src/bentoml/_internal/io_descriptors/json.py index 1c20a43dc9e..af9a7cbe84f 100644 --- a/src/bentoml/_internal/io_descriptors/json.py +++ b/src/bentoml/_internal/io_descriptors/json.py @@ -396,7 +396,7 @@ async def from_http_request(self, request: Request) -> JSONType: else: return json_obj - async def to_http_response( + def sync_to_http_response( self, obj: JSONType | pydantic.BaseModel, ctx: Context | None = None ): # This is to prevent cases where custom JSON encoder is used. @@ -431,7 +431,7 @@ async def to_http_response( else: return Response(json_str, media_type=self._mime_type) - async def from_proto(self, field: struct_pb2.Value | bytes) -> JSONType: + def sync_from_proto(self, field: struct_pb2.Value | bytes) -> JSONType: from google.protobuf.json_format import MessageToDict if isinstance(field, bytes): @@ -464,7 +464,7 @@ async def from_proto(self, field: struct_pb2.Value | bytes) -> JSONType: raise BadInput(f"Invalid JSON input received: {e}") from None return parsed - async def to_proto(self, obj: JSONType) -> struct_pb2.Value: + def sync_to_proto(self, obj: JSONType) -> struct_pb2.Value: if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(obj): if pkg_version_info("pydantic")[0] >= 2: obj = obj.model_dump() diff --git a/src/bentoml/_internal/io_descriptors/multipart.py b/src/bentoml/_internal/io_descriptors/multipart.py index a403529f63e..ea0440bb49c 100644 --- a/src/bentoml/_internal/io_descriptors/multipart.py +++ b/src/bentoml/_internal/io_descriptors/multipart.py @@ -283,15 +283,11 @@ async def from_http_request(self, request: Request) -> dict[str, t.Any]: return res - async def to_http_response( + def sync_to_http_response( self, obj: dict[str, t.Any], ctx: Context | None = None ) -> Response: - resps = await asyncio.gather( - *tuple( - io_.to_http_response(obj[key], ctx) for key, io_ in self._inputs.items() - ) - ) - return await concat_to_multipart_response(dict(zip(self._inputs, resps)), ctx) + resps = (io_.to_http_response(obj[key], ctx) for key, io_ in self._inputs.items()) + return concat_to_multipart_response(dict(zip(self._inputs, resps)), ctx) def validate_input_mapping(self, field: t.MutableMapping[str, t.Any]) -> None: if len(set(field) - set(self._inputs)) != 0: @@ -299,7 +295,7 @@ def validate_input_mapping(self, field: t.MutableMapping[str, t.Any]) -> None: f"'{self!r}' accepts the following keys: {set(self._inputs)}. Given {field.__class__.__qualname__} has invalid fields: {set(field) - set(self._inputs)}", ) from None - async def from_proto(self, field: pb.Multipart) -> dict[str, t.Any]: + def sync_from_proto(self, field: pb.Multipart) -> dict[str, t.Any]: from bentoml.grpc.utils import validate_proto_fields if isinstance(field, bytes): @@ -309,44 +305,40 @@ async def from_proto(self, field: pb.Multipart) -> dict[str, t.Any]: message = field.fields self.validate_input_mapping(message) to_populate = {self._inputs[k]: message[k] for k in self._inputs} - reqs = await asyncio.gather( - *tuple( - descriptor.from_proto( - getattr( - part, - validate_proto_fields( - part.WhichOneof("representation"), descriptor - ), - ) + reqs = ( + descriptor.from_proto( + getattr( + part, + validate_proto_fields( + part.WhichOneof("representation"), descriptor + ), ) - for descriptor, part in to_populate.items() ) + for descriptor, part in to_populate.items() ) return dict(zip(self._inputs.keys(), reqs)) @t.overload - async def _to_proto_impl( + def _to_proto_impl( self, obj: dict[str, t.Any], *, version: t.Literal["v1"] ) -> pb.Multipart: ... @t.overload - async def _to_proto_impl( + def _to_proto_impl( self, obj: dict[str, t.Any], *, version: t.Literal["v1alpha1"] ) -> pb_v1alpha1.Multipart: ... - async def _to_proto_impl( + def _to_proto_impl( self, obj: dict[str, t.Any], *, version: str ) -> _message.Message: pb, _ = import_generated_stubs(version) self.validate_input_mapping(obj) - resps = await asyncio.gather( - *tuple( - io_.to_proto(data) - for io_, data in zip(self._inputs.values(), obj.values()) - ) + resps = ( + io_.to_proto(data) + for io_, data in zip(self._inputs.values(), obj.values()) ) return pb.Multipart( fields=dict( @@ -361,8 +353,8 @@ async def _to_proto_impl( ) ) - async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart: - return await self._to_proto_impl(obj, version="v1") + def sync_to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart: + return self._to_proto_impl(obj, version="v1") async def to_proto_v1alpha1(self, obj: dict[str, t.Any]) -> pb_v1alpha1.Multipart: - return await self._to_proto_impl(obj, version="v1alpha1") + return self._to_proto_impl(obj, version="v1alpha1") diff --git a/src/bentoml/_internal/io_descriptors/numpy.py b/src/bentoml/_internal/io_descriptors/numpy.py index 02b53b51424..af306df650e 100644 --- a/src/bentoml/_internal/io_descriptors/numpy.py +++ b/src/bentoml/_internal/io_descriptors/numpy.py @@ -406,7 +406,7 @@ async def from_http_request(self, request: Request) -> ext.NpNDArray: return self.validate_array(res) - async def to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None): + def sync_to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None): """ Process given objects and convert it to HTTP response. @@ -492,7 +492,7 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]: self._shape = sample.shape return sample - async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray: + def sync_from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray: """ Process incoming protobuf request and convert it to ``numpy.ndarray`` @@ -591,18 +591,18 @@ async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray: return self.validate_array(array) @t.overload - async def _to_proto_impl( + def _to_proto_impl( self, obj: ext.NpNDArray, *, version: t.Literal["v1"] ) -> pb.NDArray: ... @t.overload - async def _to_proto_impl( + def _to_proto_impl( self, obj: ext.NpNDArray, *, version: t.Literal["v1alpha1"] ) -> pb_v1alpha1.NDArray: ... - async def _to_proto_impl( + def _to_proto_impl( self, obj: ext.NpNDArray, *, version: str ) -> _message.Message: """ @@ -671,8 +671,8 @@ def spark_schema(self) -> pyspark.sql.types.StructType: ) return StructType([StructField("out", out_spark_type, nullable=False)]) - async def to_proto(self, obj: ext.NpNDArray) -> pb.NDArray: - return await self._to_proto_impl(obj, version="v1") + def sync_to_proto(self, obj: ext.NpNDArray) -> pb.NDArray: + return self._to_proto_impl(obj, version="v1") async def to_proto_v1alpha1(self, obj: ext.NpNDArray) -> pb_v1alpha1.NDArray: - return await self._to_proto_impl(obj, version="v1alpha1") + return self._to_proto_impl(obj, version="v1alpha1") diff --git a/src/bentoml/_internal/io_descriptors/pandas.py b/src/bentoml/_internal/io_descriptors/pandas.py index b6efe7a70e1..314b04575f1 100644 --- a/src/bentoml/_internal/io_descriptors/pandas.py +++ b/src/bentoml/_internal/io_descriptors/pandas.py @@ -1035,7 +1035,7 @@ async def from_http_request(self, request: Request) -> ext.PdSeries: ) return self.validate_series(res) - async def to_http_response( + def sync_to_http_response( self, obj: t.Any, ctx: Context | None = None ) -> Response: """ @@ -1094,7 +1094,7 @@ def validate_series( return series - async def from_proto(self, field: pb.Series | bytes) -> ext.PdSeries: + def sync_from_proto(self, field: pb.Series | bytes) -> ext.PdSeries: """ Process incoming protobuf request and convert it to ``pandas.Series`` @@ -1144,18 +1144,18 @@ async def from_proto(self, field: pb.Series | bytes) -> ext.PdSeries: return self.validate_series(series) @t.overload - async def _to_proto_impl( + def _to_proto_impl( self, obj: ext.PdSeries, *, version: t.Literal["v1"] ) -> pb.Series: ... @t.overload - async def _to_proto_impl( + def _to_proto_impl( self, obj: ext.PdSeries, *, version: t.Literal["v1alpha1"] ) -> pb_v1alpha1.Series: ... - async def _to_proto_impl( + def _to_proto_impl( self, obj: ext.PdSeries, *, version: str ) -> _message.Message: """ @@ -1229,8 +1229,8 @@ def spark_schema(self) -> pyspark.sql.types.StructType: return StructType([StructField("out", out_spark_type)]) - async def to_proto(self, obj: ext.PdSeries) -> pb.Series: - return await self._to_proto_impl(obj, version="v1") + def sync_to_proto(self, obj: ext.PdSeries) -> pb.Series: + return self._to_proto_impl(obj, version="v1") async def to_proto_v1alpha1(self, obj: ext.PdSeries) -> pb_v1alpha1.Series: - return await self._to_proto_impl(obj, version="v1alpha1") + return self._to_proto_impl(obj, version="v1alpha1") diff --git a/src/bentoml/_internal/io_descriptors/text.py b/src/bentoml/_internal/io_descriptors/text.py index 5bfa1ae3ced..9060621a921 100644 --- a/src/bentoml/_internal/io_descriptors/text.py +++ b/src/bentoml/_internal/io_descriptors/text.py @@ -163,7 +163,7 @@ async def from_http_request(self, request: Request) -> str: obj = await request.body() return str(obj.decode("utf-8")) - async def to_http_response(self, obj: str, ctx: Context | None = None) -> Response: + def sync_to_http_response(self, obj: str, ctx: Context | None = None) -> Response: if ctx is not None: res = Response( obj, @@ -176,12 +176,12 @@ async def to_http_response(self, obj: str, ctx: Context | None = None) -> Respon else: return Response(obj, media_type=self._mime_type) - async def from_proto(self, field: wrappers_pb2.StringValue | bytes) -> str: + def sync_from_proto(self, field: wrappers_pb2.StringValue | bytes) -> str: if isinstance(field, bytes): return field.decode("utf-8") else: assert isinstance(field, wrappers_pb2.StringValue) return field.value - async def to_proto(self, obj: str) -> wrappers_pb2.StringValue: + def sync_to_proto(self, obj: str) -> wrappers_pb2.StringValue: return wrappers_pb2.StringValue(value=obj) diff --git a/src/bentoml/_internal/utils/formparser.py b/src/bentoml/_internal/utils/formparser.py index 65b9f3126d1..3d047daed8f 100644 --- a/src/bentoml/_internal/utils/formparser.py +++ b/src/bentoml/_internal/utils/formparser.py @@ -338,7 +338,7 @@ def _get_disp_filename(headers: MutableHeaders) -> t.Optional[bytes]: return None -async def concat_to_multipart_response( +def concat_to_multipart_response( responses: t.Mapping[str, Response], ctx: Context | None ) -> Response: boundary = uuid.uuid4().hex diff --git a/tests/unit/_internal/io/test_base.py b/tests/unit/_internal/io/test_base.py index 2e4033ec527..6baf4fe9b52 100644 --- a/tests/unit/_internal/io/test_base.py +++ b/tests/unit/_internal/io/test_base.py @@ -46,16 +46,16 @@ def from_spec(cls, spec: dict[str, t.Any]) -> t.Self: def input_type(self) -> t.Any: return str - async def from_http_request(self, request: t.Any) -> t.Any: + def sync_from_http_request(self, request: t.Any) -> t.Any: return request async def to_http_response(self, obj: t.Any, ctx: Context | None = None) -> t.Any: return obj, ctx - async def from_proto(self, field: t.Any) -> t.Any: + def sync_from_proto(self, field: t.Any) -> t.Any: return field - async def to_proto(self, obj: t.Any) -> t.Any: + def sync_to_proto(self, obj: t.Any) -> t.Any: return obj def _from_sample(self, sample: t.Any):