Skip to content

Commit

Permalink
Don't use async code where there is not awaiting
Browse files Browse the repository at this point in the history
  • Loading branch information
judahrand committed Aug 16, 2023
1 parent 21bb120 commit 5b3eefa
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 73 deletions.
10 changes: 5 additions & 5 deletions src/bentoml/_internal/client/grpc.py
Expand Up @@ -306,9 +306,9 @@ async def _call(
raise BentoMLException(
f"'{_bentoml_api.name}' takes multiple inputs; all inputs must be passed as keyword arguments."
)
serialized_req = await _bentoml_api.input.to_proto(attrs)
serialized_req = _bentoml_api.input.sync_to_proto(attrs)
else:
serialized_req = await _bentoml_api.input.to_proto(inp)
serialized_req = _bentoml_api.input.sync_to_proto(inp)

# A call includes api_name and given proto_fields
api_fn = {v: k for k, v in self._svc.apis.items()}
Expand All @@ -328,7 +328,7 @@ async def _call(
channel_kwargs=channel_kwargs,
method_kwargs=kwargs,
)
return await _bentoml_api.output.from_proto(
return _bentoml_api.output.sync_from_proto(
getattr(proto, proto.WhichOneof("content"))
)

Expand Down Expand Up @@ -658,9 +658,9 @@ def _call(
raise BentoMLException(
f"'{_bentoml_api.name}' takes multiple inputs; all inputs must be passed as keyword arguments."
)
serialized_req = asyncio.run(_bentoml_api.input.to_proto(attrs))
serialized_req = _bentoml_api.input.sync_to_proto(attrs)
else:
serialized_req = asyncio.run(_bentoml_api.input.to_proto(inp))
serialized_req = _bentoml_api.input.sync_to_proto(inp)

# A call includes api_name and given proto_fields
api_fn = {v: k for k, v in self._svc.apis.items()}
Expand Down
8 changes: 4 additions & 4 deletions src/bentoml/_internal/client/http.py
Expand Up @@ -153,9 +153,9 @@ async def _call(
raise BentoMLException(
f"'{api.name}' takes multiple inputs; all inputs must be passed as keyword arguments."
)
fake_resp = await api.input.to_http_response(kwargs, None)
fake_resp = api.input.sync_to_http_response(kwargs, None)
else:
fake_resp = await api.input.to_http_response(inp, None)
fake_resp = api.input.sync_to_http_response(inp, None)
req_body = fake_resp.body

resp = await self.client.post(
Expand Down Expand Up @@ -299,9 +299,9 @@ def _call(
raise BentoMLException(
f"'{api.name}' takes multiple inputs; all inputs must be passed as keyword arguments."
)
fake_resp = asyncio.run(api.input.to_http_response(kwargs, None))
fake_resp = api.input.sync_to_http_response(kwargs, None)
else:
fake_resp = asyncio.run(api.input.to_http_response(inp, None))
fake_resp = api.input.sync_to_http_response(inp, None)
req_body = fake_resp.body

resp = self.client.post(
Expand Down
17 changes: 14 additions & 3 deletions src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -175,18 +175,29 @@ def input_type(self) -> InputType:
async def from_http_request(self, request: Request) -> IOType:
raise NotImplementedError

@abstractmethod
async def to_http_response(
self, obj: IOType, ctx: Context | None = None
) -> Response:
raise NotImplementedError
return self.sync_to_http_response(obj, ctx)

@abstractmethod
async def from_proto(self, field: t.Any) -> IOType:
def sync_to_http_response(
self, obj: IOType, ctx: Context | None = None
) -> Response:
raise NotImplementedError

async def from_proto(self, field: t.Any) -> IOType:
return self.sync_from_proto(field)

@abstractmethod
def sync_from_proto(self, field: t.Any) -> IOType:
raise NotImplementedError

async def to_proto(self, obj: IOType) -> t.Any:
return self.sync_to_proto(obj)

@abstractmethod
def sync_to_proto(self, obj: IOType) -> t.Any:
raise NotImplementedError

def from_arrow(self, batch: pyarrow.RecordBatch) -> IOType:
Expand Down
6 changes: 3 additions & 3 deletions src/bentoml/_internal/io_descriptors/file.py
Expand Up @@ -214,7 +214,7 @@ def openapi_responses(self) -> OpenAPIResponse:
"x-bentoml-io-descriptor": self.to_spec(),
}

async def to_http_response(self, obj: FileType, ctx: Context | None = None):
def sync_to_http_response(self, obj: FileType, ctx: Context | None = None):
if isinstance(obj, bytes):
body = obj
else:
Expand All @@ -238,7 +238,7 @@ async def to_http_response(self, obj: FileType, ctx: Context | None = None):
)
return res

async def to_proto(self, obj: FileType) -> pb.File:
def sync_to_proto(self, obj: FileType) -> pb.File:
if isinstance(obj, bytes):
body = obj
else:
Expand Down Expand Up @@ -270,7 +270,7 @@ async def to_proto_v1alpha1(self, obj: FileType) -> pb_v1alpha1.File:

return pb_v1alpha1.File(kind=kind, content=body)

async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
def sync_from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
raise NotImplementedError

async def from_http_request(self, request: Request) -> FileLike[bytes]:
Expand Down
6 changes: 3 additions & 3 deletions src/bentoml/_internal/io_descriptors/image.py
Expand Up @@ -378,7 +378,7 @@ async def from_http_request(self, request: Request) -> ImageType:
except PIL.UnidentifiedImageError as err:
raise BadInput(f"Failed to parse uploaded image file: {err}") from None

async def to_http_response(
def sync_to_http_response(
self, obj: ImageType, ctx: Context | None = None
) -> Response:
if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj):
Expand Down Expand Up @@ -421,7 +421,7 @@ async def to_http_response(
headers={"content-disposition": content_disposition},
)

async def from_proto(self, field: pb.File | pb_v1alpha1.File | bytes) -> ImageType:
def sync_from_proto(self, field: pb.File | pb_v1alpha1.File | bytes) -> ImageType:
if isinstance(field, bytes):
content = field
elif isinstance(field, pb_v1alpha1.File):
Expand Down Expand Up @@ -477,7 +477,7 @@ async def to_proto_v1alpha1(self, obj: ImageType) -> pb_v1alpha1.File:

return pb_v1alpha1.File(kind=kind, content=ret.getvalue())

async def to_proto(self, obj: ImageType) -> pb.File:
def sync_to_proto(self, obj: ImageType) -> pb.File:
if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj):
image = PIL.Image.fromarray(obj, mode=self._pilmode)
elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(obj):
Expand Down
6 changes: 3 additions & 3 deletions src/bentoml/_internal/io_descriptors/json.py
Expand Up @@ -396,7 +396,7 @@ async def from_http_request(self, request: Request) -> JSONType:
else:
return json_obj

async def to_http_response(
def sync_to_http_response(
self, obj: JSONType | pydantic.BaseModel, ctx: Context | None = None
):
# This is to prevent cases where custom JSON encoder is used.
Expand Down Expand Up @@ -431,7 +431,7 @@ async def to_http_response(
else:
return Response(json_str, media_type=self._mime_type)

async def from_proto(self, field: struct_pb2.Value | bytes) -> JSONType:
def sync_from_proto(self, field: struct_pb2.Value | bytes) -> JSONType:
from google.protobuf.json_format import MessageToDict

if isinstance(field, bytes):
Expand Down Expand Up @@ -464,7 +464,7 @@ async def from_proto(self, field: struct_pb2.Value | bytes) -> JSONType:
raise BadInput(f"Invalid JSON input received: {e}") from None
return parsed

async def to_proto(self, obj: JSONType) -> struct_pb2.Value:
def sync_to_proto(self, obj: JSONType) -> struct_pb2.Value:
if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(obj):
if pkg_version_info("pydantic")[0] >= 2:
obj = obj.model_dump()
Expand Down
50 changes: 21 additions & 29 deletions src/bentoml/_internal/io_descriptors/multipart.py
Expand Up @@ -283,23 +283,19 @@ async def from_http_request(self, request: Request) -> dict[str, t.Any]:

return res

async def to_http_response(
def sync_to_http_response(
self, obj: dict[str, t.Any], ctx: Context | None = None
) -> Response:
resps = await asyncio.gather(
*tuple(
io_.to_http_response(obj[key], ctx) for key, io_ in self._inputs.items()
)
)
return await concat_to_multipart_response(dict(zip(self._inputs, resps)), ctx)
resps = (io_.to_http_response(obj[key], ctx) for key, io_ in self._inputs.items())
return concat_to_multipart_response(dict(zip(self._inputs, resps)), ctx)

def validate_input_mapping(self, field: t.MutableMapping[str, t.Any]) -> None:
if len(set(field) - set(self._inputs)) != 0:
raise InvalidArgument(
f"'{self!r}' accepts the following keys: {set(self._inputs)}. Given {field.__class__.__qualname__} has invalid fields: {set(field) - set(self._inputs)}",
) from None

async def from_proto(self, field: pb.Multipart) -> dict[str, t.Any]:
def sync_from_proto(self, field: pb.Multipart) -> dict[str, t.Any]:
from bentoml.grpc.utils import validate_proto_fields

if isinstance(field, bytes):
Expand All @@ -309,44 +305,40 @@ async def from_proto(self, field: pb.Multipart) -> dict[str, t.Any]:
message = field.fields
self.validate_input_mapping(message)
to_populate = {self._inputs[k]: message[k] for k in self._inputs}
reqs = await asyncio.gather(
*tuple(
descriptor.from_proto(
getattr(
part,
validate_proto_fields(
part.WhichOneof("representation"), descriptor
),
)
reqs = (
descriptor.from_proto(
getattr(
part,
validate_proto_fields(
part.WhichOneof("representation"), descriptor
),
)
for descriptor, part in to_populate.items()
)
for descriptor, part in to_populate.items()
)
return dict(zip(self._inputs.keys(), reqs))

@t.overload
async def _to_proto_impl(
def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: t.Literal["v1"]
) -> pb.Multipart:
...

@t.overload
async def _to_proto_impl(
def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: t.Literal["v1alpha1"]
) -> pb_v1alpha1.Multipart:
...

async def _to_proto_impl(
def _to_proto_impl(
self, obj: dict[str, t.Any], *, version: str
) -> _message.Message:
pb, _ = import_generated_stubs(version)

self.validate_input_mapping(obj)
resps = await asyncio.gather(
*tuple(
io_.to_proto(data)
for io_, data in zip(self._inputs.values(), obj.values())
)
resps = (
io_.to_proto(data)
for io_, data in zip(self._inputs.values(), obj.values())
)
return pb.Multipart(
fields=dict(
Expand All @@ -361,8 +353,8 @@ async def _to_proto_impl(
)
)

async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
return await self._to_proto_impl(obj, version="v1")
def sync_to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
return self._to_proto_impl(obj, version="v1")

async def to_proto_v1alpha1(self, obj: dict[str, t.Any]) -> pb_v1alpha1.Multipart:
return await self._to_proto_impl(obj, version="v1alpha1")
return self._to_proto_impl(obj, version="v1alpha1")
16 changes: 8 additions & 8 deletions src/bentoml/_internal/io_descriptors/numpy.py
Expand Up @@ -406,7 +406,7 @@ async def from_http_request(self, request: Request) -> ext.NpNDArray:

return self.validate_array(res)

async def to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None):
def sync_to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None):
"""
Process given objects and convert it to HTTP response.
Expand Down Expand Up @@ -492,7 +492,7 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]:
self._shape = sample.shape
return sample

async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray:
def sync_from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray:
"""
Process incoming protobuf request and convert it to ``numpy.ndarray``
Expand Down Expand Up @@ -591,18 +591,18 @@ async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray:
return self.validate_array(array)

@t.overload
async def _to_proto_impl(
def _to_proto_impl(
self, obj: ext.NpNDArray, *, version: t.Literal["v1"]
) -> pb.NDArray:
...

@t.overload
async def _to_proto_impl(
def _to_proto_impl(
self, obj: ext.NpNDArray, *, version: t.Literal["v1alpha1"]
) -> pb_v1alpha1.NDArray:
...

async def _to_proto_impl(
def _to_proto_impl(
self, obj: ext.NpNDArray, *, version: str
) -> _message.Message:
"""
Expand Down Expand Up @@ -671,8 +671,8 @@ def spark_schema(self) -> pyspark.sql.types.StructType:
)
return StructType([StructField("out", out_spark_type, nullable=False)])

async def to_proto(self, obj: ext.NpNDArray) -> pb.NDArray:
return await self._to_proto_impl(obj, version="v1")
def sync_to_proto(self, obj: ext.NpNDArray) -> pb.NDArray:
return self._to_proto_impl(obj, version="v1")

async def to_proto_v1alpha1(self, obj: ext.NpNDArray) -> pb_v1alpha1.NDArray:
return await self._to_proto_impl(obj, version="v1alpha1")
return self._to_proto_impl(obj, version="v1alpha1")
16 changes: 8 additions & 8 deletions src/bentoml/_internal/io_descriptors/pandas.py
Expand Up @@ -1035,7 +1035,7 @@ async def from_http_request(self, request: Request) -> ext.PdSeries:
)
return self.validate_series(res)

async def to_http_response(
def sync_to_http_response(
self, obj: t.Any, ctx: Context | None = None
) -> Response:
"""
Expand Down Expand Up @@ -1094,7 +1094,7 @@ def validate_series(

return series

async def from_proto(self, field: pb.Series | bytes) -> ext.PdSeries:
def sync_from_proto(self, field: pb.Series | bytes) -> ext.PdSeries:
"""
Process incoming protobuf request and convert it to ``pandas.Series``
Expand Down Expand Up @@ -1144,18 +1144,18 @@ async def from_proto(self, field: pb.Series | bytes) -> ext.PdSeries:
return self.validate_series(series)

@t.overload
async def _to_proto_impl(
def _to_proto_impl(
self, obj: ext.PdSeries, *, version: t.Literal["v1"]
) -> pb.Series:
...

@t.overload
async def _to_proto_impl(
def _to_proto_impl(
self, obj: ext.PdSeries, *, version: t.Literal["v1alpha1"]
) -> pb_v1alpha1.Series:
...

async def _to_proto_impl(
def _to_proto_impl(
self, obj: ext.PdSeries, *, version: str
) -> _message.Message:
"""
Expand Down Expand Up @@ -1229,8 +1229,8 @@ def spark_schema(self) -> pyspark.sql.types.StructType:

return StructType([StructField("out", out_spark_type)])

async def to_proto(self, obj: ext.PdSeries) -> pb.Series:
return await self._to_proto_impl(obj, version="v1")
def sync_to_proto(self, obj: ext.PdSeries) -> pb.Series:
return self._to_proto_impl(obj, version="v1")

async def to_proto_v1alpha1(self, obj: ext.PdSeries) -> pb_v1alpha1.Series:
return await self._to_proto_impl(obj, version="v1alpha1")
return self._to_proto_impl(obj, version="v1alpha1")
6 changes: 3 additions & 3 deletions src/bentoml/_internal/io_descriptors/text.py
Expand Up @@ -163,7 +163,7 @@ async def from_http_request(self, request: Request) -> str:
obj = await request.body()
return str(obj.decode("utf-8"))

async def to_http_response(self, obj: str, ctx: Context | None = None) -> Response:
def sync_to_http_response(self, obj: str, ctx: Context | None = None) -> Response:
if ctx is not None:
res = Response(
obj,
Expand All @@ -176,12 +176,12 @@ async def to_http_response(self, obj: str, ctx: Context | None = None) -> Respon
else:
return Response(obj, media_type=self._mime_type)

async def from_proto(self, field: wrappers_pb2.StringValue | bytes) -> str:
def sync_from_proto(self, field: wrappers_pb2.StringValue | bytes) -> str:
if isinstance(field, bytes):
return field.decode("utf-8")
else:
assert isinstance(field, wrappers_pb2.StringValue)
return field.value

async def to_proto(self, obj: str) -> wrappers_pb2.StringValue:
def sync_to_proto(self, obj: str) -> wrappers_pb2.StringValue:
return wrappers_pb2.StringValue(value=obj)

0 comments on commit 5b3eefa

Please sign in to comment.