-
Notifications
You must be signed in to change notification settings - Fork 571
/
grpc.py
100 lines (77 loc) · 3.42 KB
/
grpc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Copyright 2023 BentoML Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import logging
import typing as t
import orjson
import openllm
from .base import BaseAsyncClient, BaseClient
if t.TYPE_CHECKING:
import grpc_health.v1.health_pb2 as health_pb2
from bentoml.grpc.v1.service_pb2 import Response
logger = logging.getLogger(__name__)
class GrpcClientMixin:
    """Metadata accessors and response post-processing shared by the gRPC clients.

    The mixin reads service metadata from ``self._metadata`` (annotated as a
    BentoML gRPC ``Response``), specifically from the protobuf ``Struct``
    found at ``_metadata.json.struct_value``. The attribute itself is not
    assigned in this class; it is expected to be provided by the concrete
    client.
    """

    _metadata: Response

    def _metadata_field(self, key: str) -> t.Any:
        """Return the metadata value stored under *key*.

        Args:
            key: Name of the entry in ``_metadata.json.struct_value.fields``.

        Returns:
            The protobuf ``Value`` stored under *key*.

        Raises:
            RuntimeError: If the service metadata has no entry named *key*.
        """
        fields = self._metadata.json.struct_value.fields
        # Test membership before indexing: a protobuf message map does not
        # raise KeyError for an absent key, it inserts and returns a default
        # value. Indexing first would make missing metadata look like an
        # empty string / zero and would mutate the metadata as a side effect.
        if key not in fields:
            raise RuntimeError("Malformed service endpoint. (Possible malicious)")
        return fields[key]

    @property
    def model_name(self) -> str:
        """Model name advertised by the service metadata.

        Raises:
            RuntimeError: If the metadata has no ``model_name`` entry.
        """
        return self._metadata_field("model_name").string_value

    @property
    def framework(self) -> t.Literal["pt", "flax", "tf"]:
        """Framework identifier advertised by the service metadata.

        Raises:
            RuntimeError: If the metadata has no ``framework`` entry or its
                value is not one of ``"pt"``, ``"flax"`` or ``"tf"``.
        """
        value = self._metadata_field("framework").string_value
        if value not in ("pt", "flax", "tf"):
            raise RuntimeError("Malformed service endpoint. (Possible malicious)")
        return value

    @property
    def timeout(self) -> int:
        """Timeout advertised by the service metadata, truncated to an int.

        Raises:
            RuntimeError: If the metadata has no ``timeout`` entry.
        """
        # Struct numbers are carried as ``number_value``; the public contract
        # of this property is an int.
        return int(self._metadata_field("timeout").number_value)

    @property
    def model_id(self) -> str:
        """Model id advertised by the service metadata.

        Raises:
            RuntimeError: If the metadata has no ``model_id`` entry.
        """
        return self._metadata_field("model_id").string_value

    @property
    def configuration(self) -> dict[str, t.Any]:
        """Configuration decoded from the JSON string in the service metadata.

        Raises:
            RuntimeError: If the metadata has no ``configuration`` entry.

        Errors raised by ``orjson.loads`` for an undecodable string are
        propagated unchanged.
        """
        raw = self._metadata_field("configuration").string_value
        return orjson.loads(raw)

    def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
        """Convert a raw result into an ``openllm.GenerationOutput``.

        Args:
            result: Either a plain dict of keyword arguments for
                ``GenerationOutput``, or a response whose ``json`` payload
                is converted to such a dict.

        Returns:
            The ``GenerationOutput`` built from *result*.
        """
        if isinstance(result, dict):
            return openllm.GenerationOutput(**result)
        # Imported lazily so protobuf is only needed when a protobuf response
        # is actually handled.
        from google.protobuf.json_format import MessageToDict

        # Keep the proto field names as-is (no lowerCamelCase conversion) so
        # they can be passed straight through as keyword arguments.
        return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
class GrpcClient(GrpcClientMixin, BaseClient, client_type="grpc"):
    """Synchronous gRPC client, registered with ``client_type="grpc"``."""

    def __init__(self, address: str, timeout: int = 30):
        """Record the host and port of *address*, then initialise the base client.

        Args:
            address: Server address in ``host:port`` form.
            timeout: Forwarded unchanged to the base client.

        Raises:
            ValueError: If *address* contains no ``:`` separator.
        """
        # Split on the LAST colon so hosts that themselves contain colons
        # (e.g. a bracketed IPv6 literal such as "[::1]:3000") no longer fail
        # with an opaque tuple-unpacking error. Both parts stay strings.
        host, sep, port = address.rpartition(":")
        if not sep:
            raise ValueError(f"Invalid address {address!r}: expected the form 'host:port'.")
        self._host, self._port = host, port
        super().__init__(address, timeout)

    def health(self) -> health_pb2.HealthCheckResponse:
        """Run a health check for the ``bentoml.grpc.v1.BentoService`` service.

        Drives the coroutine returned by ``self._cached.health`` to completion
        with ``asyncio.run``, so this cannot be called from a thread that
        already has a running event loop. ``_cached`` is not defined in this
        class; presumably the base client provides it.
        """
        return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient, client_type="grpc"):
    """Asynchronous gRPC client, registered with ``client_type="grpc"``."""

    def __init__(self, address: str, timeout: int = 30):
        """Record the host and port of *address*, then initialise the base client.

        Args:
            address: Server address in ``host:port`` form.
            timeout: Forwarded unchanged to the base client.

        Raises:
            ValueError: If *address* contains no ``:`` separator.
        """
        # Split on the LAST colon so hosts that themselves contain colons
        # (e.g. a bracketed IPv6 literal such as "[::1]:3000") no longer fail
        # with an opaque tuple-unpacking error. Both parts stay strings.
        host, sep, port = address.rpartition(":")
        if not sep:
            raise ValueError(f"Invalid address {address!r}: expected the form 'host:port'.")
        self._host, self._port = host, port
        super().__init__(address, timeout)

    async def health(self) -> health_pb2.HealthCheckResponse:
        """Await a health check for the ``bentoml.grpc.v1.BentoService`` service.

        Delegates to ``self._cached.health``. ``_cached`` is not defined in
        this class; presumably the base client provides it.
        """
        return await self._cached.health("bentoml.grpc.v1.BentoService")