-
Notifications
You must be signed in to change notification settings - Fork 189
/
post_model.py
302 lines (261 loc) · 10.4 KB
/
post_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# -*- coding: utf-8 -*-
"""Model wrapper for post-based inference apis."""
import json
import time
from abc import ABC
from typing import Any, Union, Sequence, List
import requests
from loguru import logger
from .model import ModelWrapperBase, ModelResponse
from ..constants import _DEFAULT_MAX_RETRIES
from ..constants import _DEFAULT_MESSAGES_KEY
from ..constants import _DEFAULT_RETRY_INTERVAL
from ..message import MessageBase
from ..utils.tools import _convert_to_str
class PostAPIModelWrapperBase(ModelWrapperBase, ABC):
    """The base model wrapper for the model deployed on the POST API."""

    # Identifier of this wrapper type.
    model_type: str = "post_api"

    def __init__(
        self,
        config_name: str,
        api_url: str,
        headers: dict = None,
        max_length: int = 2048,
        timeout: int = 30,
        json_args: dict = None,
        post_args: dict = None,
        max_retries: int = _DEFAULT_MAX_RETRIES,
        messages_key: str = _DEFAULT_MESSAGES_KEY,
        retry_interval: int = _DEFAULT_RETRY_INTERVAL,
        **kwargs: Any,
    ) -> None:
        """Initialize the model wrapper.

        Args:
            config_name (`str`):
                The id of the model.
            api_url (`str`):
                The url of the post request api.
            headers (`dict`, defaults to `None`):
                The headers of the api. `None` is sent as an empty dict.
            max_length (`int`, defaults to `2048`):
                The maximum length of the model.
            timeout (`int`, defaults to `30`):
                The timeout passed to `requests.post`. A `timeout` entry
                in `post_args`, or given when calling the wrapper, takes
                precedence over this value.
            json_args (`dict`, defaults to `None`):
                Extra entries merged into the json body of the request.
            post_args (`dict`, defaults to `None`):
                Extra keyword arguments forwarded to `requests.post`.
            max_retries (`int`, defaults to `_DEFAULT_MAX_RETRIES`):
                The maximum number of request attempts made while the api
                keeps answering with a non-OK status code. Values below 1
                are treated as 1 so that one request is always sent.
            messages_key (`str`, defaults to `_DEFAULT_MESSAGES_KEY`):
                The key of the input messages in the json argument.
            retry_interval (`int`, defaults to `_DEFAULT_RETRY_INTERVAL`):
                The base interval between attempts; after the i-th failed
                attempt the wrapper sleeps `i * retry_interval`.
            **kwargs (`Any`):
                Accepted for compatibility; not used by this class.

        Note:
            When an object of `PostAPIModelWrapperBase` is called, the
            arguments of the post request will be used as follows:

            .. code-block:: python

                requests.post(
                    url=api_url,
                    headers=headers,
                    timeout=timeout,
                    json={
                        messages_key: messages,
                        **json_args
                    },
                    **post_args
                )
        """
        super().__init__(config_name=config_name)
        self.api_url = api_url
        self.headers = headers
        self.max_length = max_length
        self.timeout = timeout
        self.json_args = json_args or {}
        self.post_args = post_args or {}
        self.max_retries = max_retries
        self.messages_key = messages_key
        self.retry_interval = retry_interval

    def _parse_response(self, response: dict) -> ModelResponse:
        """Parse the response json data into ModelResponse"""
        return ModelResponse(raw=response)

    def __call__(self, input_: str, **kwargs: Any) -> ModelResponse:
        """Calling the model with requests.post.

        Args:
            input_ (`str`):
                The input to the model, placed in the json body under
                `messages_key`.
            **kwargs (`Any`):
                Extra keyword arguments forwarded to `requests.post`; they
                override entries of the same name in `post_args`.

        Returns:
            `ModelResponse`: The result of `_parse_response` applied to
            the decoded json body of a successful response.

        Raises:
            `RuntimeError`: If the api still answers with a non-OK status
            code after the last attempt. Exceptions raised by
            `requests.post` itself are not caught here.

        Note:
            `parse_func`, `fault_handler` and `max_retries` are reserved for
            `_response_parse_decorator` to parse and check the response
            generated by model wrapper. Their usages are listed as follows:
                - `parse_func` is a callable function used to parse and check
                the response generated by the model, which takes the response
                as input.
                - `max_retries` is the maximum number of retries when the
                `parse_func` raise an exception.
                - `fault_handler` is a callable function which is called
                when the response generated by the model is invalid after
                `max_retries` retries.
        """
        # step1: prepare keyword arguments
        post_args = {**self.post_args, **kwargs}
        request_kwargs = {
            "url": self.api_url,
            "json": {self.messages_key: input_, **self.json_args},
            "headers": self.headers or {},
            # Apply the configured timeout so a stalled server cannot block
            # the caller forever; `post_args` listed below may override it.
            "timeout": self.timeout,
            **post_args,
        }

        # step2: send the request, retrying on non-OK status codes.
        # Always try at least once, otherwise `response` would be unbound
        # when `max_retries` is configured as 0 or a negative number.
        max_attempts = max(1, self.max_retries)
        for i in range(1, max_attempts + 1):
            response = requests.post(**request_kwargs)
            if response.status_code == requests.codes.ok:
                break
            if i < max_attempts:
                logger.warning(
                    f"Failed to call the model with "
                    f"requests.codes == {response.status_code}, retry "
                    f"{i + 1}/{max_attempts} times",
                )
                # Wait longer after each consecutive failure.
                time.sleep(i * self.retry_interval)

        succeeded = response.status_code == requests.codes.ok
        if succeeded:
            # A malformed body on a successful response is still an error
            # and is left to propagate.
            payload = response.json()
        else:
            # Error responses are frequently not json (e.g. an HTML page
            # produced by a gateway); fall back to the raw text so that the
            # RuntimeError below is raised instead of a json decode error.
            try:
                payload = response.json()
            except ValueError:
                payload = response.text

        # step3: record model invocation
        # record the model api invocation, which will be skipped if
        # `FileManager.save_api_invocation` is `False`
        self._save_model_invocation(
            arguments=request_kwargs,
            response=payload,
        )

        # step4: parse the response
        if succeeded:
            return self._parse_response(payload)

        # NOTE(review): `request_kwargs` includes the request headers, which
        # may carry credentials; consider redacting them before logging.
        # `default=str` keeps logging from raising on post arguments that
        # are not json serializable, which would hide the real failure.
        logger.error(json.dumps(request_kwargs, indent=4, default=str))
        raise RuntimeError(
            f"Failed to call the model with {payload}",
        )
class PostAPIChatWrapper(PostAPIModelWrapperBase):
    """A post api model wrapper compatible with openai chat, e.g., vLLM,
    FastChat."""

    # Identifier of this wrapper type.
    model_type: str = "post_api_chat"

    def _parse_response(self, response: dict) -> ModelResponse:
        """Build a `ModelResponse` whose text is the message content of the
        first entry in `response["data"]["response"]["choices"]`."""
        first_choice = response["data"]["response"]["choices"][0]
        return ModelResponse(text=first_choice["message"]["content"])

    def format(
        self,
        *args: Union[MessageBase, Sequence[MessageBase]],
    ) -> Union[List[dict]]:
        """Convert the given messages into a flat list of dictionaries with
        the keys `role`, `name` and `content`, as used by the OpenAI Chat
        API.

        Args:
            args (`Union[MessageBase, Sequence[MessageBase]]`):
                Each argument is either a `Msg` object or a list of `Msg`
                objects; lists are flattened recursively and `None`
                arguments are skipped. In distribution, placeholder is
                also allowed.

        Returns:
            `Union[List[dict]]`:
                The formatted messages, in the order they were given.

        Raises:
            `TypeError`: If an argument is neither `None`, a `MessageBase`
            instance, nor a list.
        """
        formatted = []
        for item in args:
            # Missing messages are silently dropped.
            if item is None:
                continue

            if isinstance(item, MessageBase):
                entry = {
                    "role": item.role,
                    "name": item.name,
                    "content": _convert_to_str(item.content),
                }
                formatted.append(entry)
                continue

            if not isinstance(item, list):
                raise TypeError(
                    f"The input should be a Msg object or a list "
                    f"of Msg objects, got {type(item)}.",
                )

            # A nested list is expanded in place by formatting its items.
            formatted.extend(self.format(*item))
        return formatted
class PostAPIDALLEWrapper(PostAPIModelWrapperBase):
    """A post api model wrapper compatible with openai dall_e"""

    # Identifier of this wrapper type.
    model_type: str = "post_api_dall_e"
    # Former identifier of this wrapper type.
    deprecated_model_type: str = "post_api_dalle"

    def _parse_response(self, response: dict) -> ModelResponse:
        """Collect the image urls found in the api response.

        Args:
            response (`dict`):
                The decoded response; the images are read from
                `response["data"]["response"]["data"]`, each item
                providing a `url` entry.

        Returns:
            `ModelResponse`: A response carrying the list of image urls.

        Raises:
            `ValueError`: If the inner response has no `data` entry. The
            message holds the reported error message when an `error` entry
            exists, otherwise the inner response itself.
        """
        payload = response["data"]["response"]

        if "data" in payload:
            image_urls = [image["url"] for image in payload["data"]]
            return ModelResponse(image_urls=image_urls)

        # No images were returned: report whatever the api sent back.
        error_msg = payload["error"]["message"] if "error" in payload else payload
        report = f"Error in API call:\n{error_msg}"
        logger.error(report)
        raise ValueError(report)

    def format(
        self,
        *args: Union[MessageBase, Sequence[MessageBase]],
    ) -> Union[List[dict], str]:
        """Always raise, since this wrapper takes its input unformatted.

        Raises:
            `RuntimeError`: On every call.
        """
        wrapper_name = type(self).__name__
        raise RuntimeError(
            f"Model Wrapper [{wrapper_name}] doesn't "
            f"need to format the input. Please try to use the "
            f"model wrapper directly.",
        )
class PostAPIEmbeddingWrapper(PostAPIModelWrapperBase):
    """
    A post api model wrapper for embedding model
    """

    # Identifier of this wrapper type.
    model_type: str = "post_api_embedding"

    def _parse_response(self, response: dict) -> ModelResponse:
        """
        Extract the embeddings from the response json data.

        Args:
            response (`dict`):
                The response obtained from the API. The parsing assumes
                that the response is structured as follows:
                {
                    "code": 200,
                    "data": {
                        ...
                        "response": {
                            "data": [
                                {
                                    "embedding": [
                                        0.001,
                                        ...
                                    ],
                                    ...
                                }
                            ],
                            "model": "xxxx",
                            ...
                        },
                    },
                }

        Returns:
            `ModelResponse`: A response carrying the list of embeddings
            together with the unmodified response as `raw`.

        Raises:
            `ValueError`: If the inner response has no `data` entry. The
            message holds the reported error message when an `error` entry
            exists, otherwise the inner response itself.
        """
        payload = response["data"]["response"]

        if "data" in payload:
            vectors = [item["embedding"] for item in payload["data"]]
            return ModelResponse(embedding=vectors, raw=response)

        # No embeddings were returned: report whatever the api sent back.
        error_msg = payload["error"]["message"] if "error" in payload else payload
        report = f"Error in embedding API call:\n{error_msg}"
        logger.error(report)
        raise ValueError(report)

    def format(
        self,
        *args: Union[MessageBase, Sequence[MessageBase]],
    ) -> Union[List[dict], str]:
        """Always raise, since this wrapper takes its input unformatted.

        Raises:
            `RuntimeError`: On every call.
        """
        wrapper_name = type(self).__name__
        raise RuntimeError(
            f"Model Wrapper [{wrapper_name}] doesn't "
            f"need to format the input. Please try to use the "
            f"model wrapper directly.",
        )