Feat/0.2.1.2 #234

Merged · 2 commits · Jan 3, 2024
2 changes: 1 addition & 1 deletion src/backend/pyproject.toml
@@ -18,7 +18,7 @@ include = ["./bisheng/*", "bisheng/**/*"]
bisheng = "bisheng.__main__:main"

[tool.poetry.dependencies]
bisheng_langchain = "0.2.1"
bisheng_langchain = "0.2.1.2"
bisheng_pyautogen = "0.1.18"
minio = "^7.2.0"
fastapi_jwt_auth = "^0.5.0"
242 changes: 95 additions & 147 deletions src/bisheng-langchain/bisheng_langchain/chat_models/host_llm.py
@@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import requests
import sseclient
from bisheng_langchain.utils.requests import Requests
from langchain.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema import ChatGeneration, ChatResult
@@ -153,16 +153,28 @@ def validate_environment(cls, values: Dict) -> Dict:
values['host_base_url'] = get_from_dict_or_env(values, 'host_base_url', 'HostBaseUrl')
model = values['model_name']
try:
url = values['host_base_url'].split('/')[2]
config_ep = f'http://{url}/v2/models/{model}/config'
config = requests.get(url=config_ep, json={}, timeout=5).json()
policy = config.get('model_transaction_policy', {})
values['decoupled'] = policy.get('decoupled', False)
if cls != CustomLLMChat:
url = values['host_base_url'].split('/')[2]
config_ep = f'http://{url}/v2/models/{model}/config'
config = requests.get(url=config_ep, json={}, timeout=5).json()
policy = config.get('model_transaction_policy', {})
values['decoupled'] = policy.get('decoupled', False)
# The host class should set the endpoint URL as below
if values['decoupled']:
values[
'host_base_url'] = f"{values['host_base_url']}/{values['model_name']}/generate_stream"
else:
values[
'host_base_url'] = f"{values['host_base_url']}/{values['model_name']}/infer"
except Exception:
raise Exception(f'Update decoupled status failed for model {model}')

try:
values['client'] = requests.post
if values['headers']:
headers = values['headers']
else:
headers = {'Content-Type': 'application/json'}
values['client'] = Requests(headers=headers, request_timeout=values['request_timeout'])
except AttributeError:
raise ValueError('Try upgrading it with `pip install --upgrade requests`.')
return values
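
As a reading aid, the endpoint-resolution behavior introduced in this hunk can be summarized as a small standalone helper. This is only a sketch under assumptions: resolve_host_endpoint is a hypothetical name, the real code mutates the values dict inside validate_environment rather than returning a URL, and the /v2/models/{model}/config route with its model_transaction_policy.decoupled flag is taken from the diff above.

import requests

def resolve_host_endpoint(host_base_url: str, model_name: str, timeout: int = 5) -> str:
    # Hypothetical helper mirroring the validate_environment hunk above.
    # Ask the serving host for the model config and read its transaction policy.
    host = host_base_url.split('/')[2]  # netloc of e.g. 'http://10.0.0.1:9001/v2.1/models' (illustrative URL)
    config_ep = f'http://{host}/v2/models/{model_name}/config'
    config = requests.get(url=config_ep, json={}, timeout=timeout).json()
    decoupled = config.get('model_transaction_policy', {}).get('decoupled', False)
    # Decoupled (streaming-capable) models get the generate_stream route, others get infer.
    suffix = 'generate_stream' if decoupled else 'infer'
    return f'{host_base_url}/{model_name}/{suffix}'
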
@@ -185,6 +197,7 @@ def completion_with_retry(self, **kwargs: Any) -> Any:

@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
self.client.headers = self.headers
messages = kwargs.get('messages')
temperature = kwargs.get('temperature')
top_p = kwargs.get('top_p')
@@ -204,17 +217,17 @@ def _completion_with_retry(**kwargs: Any) -> Any:
# print('messages:', messages)
# print('functions:', kwargs.get('functions', []))
if self.verbose:
print('payload', params)

method_name = 'infer' if not self.decoupled else 'generate'
url = f'{self.host_base_url}/{self.model_name}/{method_name}'
logger.info(f'payload={params}')
try:
resp = self.client(
url=url, json=params, timeout=self.request_timeout).json()
except requests.exceptions.Timeout:
raise Exception(f'timeout in host llm infer, url=[{url}]')
resp = self.client.post(url=self.host_base_url, json=params)
if resp.text.startswith('data:'):
resp = json.loads(resp.text.replace('data:', ''))
else:
resp = resp.json()
except requests.exceptions.Timeout as exc:
raise ValueError(f'timeout in host llm infer, url=[{self.host_base_url}]') from exc
except Exception as e:
raise Exception(f'exception in host llm infer: [{e}]')
raise ValueError(f'exception in host llm infer: [{e}]') from e

if not resp.get('choices', []):
logger.info(resp)
@@ -249,63 +262,46 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
'''Handle synchronous requests.'''
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)

def _stream(self, **kwargs: Any) -> Any:
async def acompletion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(self)

@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
if self.streaming:
if not self.decoupled:
raise Exception('Not supported stream protocol with non decoupled model')

headers = {'Accept': 'text/event-stream'}
url = f'{self.host_base_url}/{self.model_name}/generate_stream'
try:
res = requests.post(
url=url,
data=json.dumps(kwargs),
headers=headers,
stream=False)
except Exception as e:
raise Exception(f'exception in host llm sse infer: [{e}]')

res.raise_for_status()
try:
client = sseclient.SSEClient(res, timeout=self.request_timeout)
for event in client.events():
delta_data = json.loads(event.data)
yield delta_data
except requests.exceptions.Timeout:
raise Exception(f'timeout in host llm sse infer, url=[{url}]')
except Exception as e:
raise Exception(f'exception in host llm sse infer: [{e}]')
else:
method_name = 'infer' if not self.decoupled else 'generate'
url = f'{self.host_base_url}/{self.model_name}/{method_name}'
try:
res = requests.post(
url=url,
data=json.dumps(kwargs),
stream=False,
timeout=self.request_timeout)
return res.json()
except requests.exceptions.Timeout:
raise Exception(f'timeout in host llm infer, url=[{url}]')
except Exception as e:
raise Exception(f'exception in host llm infer: [{e}]')

if self.streaming:
for response in _completion_with_retry(**kwargs):
if response:
yield response
else:
return _completion_with_retry(**kwargs)
async def _acompletion_with_retry(**kwargs: Any) -> Any:
try:
async with self.client.apost(url=self.host_base_url, json=kwargs) as response:
if response.status != 200:
raise ValueError(f'Error: {response.status}')
async for txt in response.content.iter_any():
if b'\n' in txt:
for txt_ in txt.split(b'\n'):
yield txt_.decode('utf-8').strip()
else:
yield txt.decode('utf-8').strip()
except requests.exceptions.Timeout as exc:
raise ValueError(f'timeout in host llm infer, url=[{self.host_base_url}]') from exc
except Exception as e:
raise ValueError(f'exception in host llm infer: [{e}]') from e

async for response in _acompletion_with_retry(**kwargs):
is_error = False
if response:
if response.startswith('event:error'):
is_error = True
elif response.startswith('data:'):
yield (is_error, response[len('data:'):])
if is_error:
break
elif response.startswith('{'):
yield (is_error, response)
else:
continue

async def _agenerate(
self,
@@ -314,41 +310,49 @@ messages: List[BaseMessage],
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if not self.decoupled:
return self._generate(messages, stop, run_manager, **kwargs)

"""Generate chat completion with retry."""
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ''
role = 'assistant'
params['stream'] = True
function_call: Optional[dict] = None
for stream_resp in self._stream(
messages=message_dicts, **params
):
role = stream_resp['choices'][0]['delta'].get('role', role)
token = stream_resp['choices'][0]['delta'].get('content', '')
inner_completion += token or ''
_function_call = stream_resp['choices'][0]['delta'].get('function_call')
if _function_call:
if function_call is None:
function_call = _function_call
else:
function_call['arguments'] += _function_call['arguments']
if run_manager:
await run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
'content': inner_completion,
'role': role,
'function_call': function_call,
}
)
async for is_error, stream_resp in self.acompletion_with_retry(messages=message_dicts,
**params):
output = json.loads(stream_resp)
if is_error:
logger.error(stream_resp)
raise ValueError(stream_resp)

choices = output.get('choices')
if choices:
for choice in choices:
role = choice['delta'].get('role', role)
token = choice['delta'].get('content', '')
inner_completion += token or ''
_function_call = choice['delta'].get('function_call')
if run_manager:
await run_manager.on_llm_new_token(token)
if _function_call:
if function_call is None:
function_call = _function_call
else:
function_call['arguments'] += _function_call['arguments']
message = _convert_dict_to_message({
'content': inner_completion,
'role': role,
'function_call': function_call,
})
return ChatResult(generations=[ChatGeneration(message=message)])
else:
params['stream'] = False
response = self._stream(messages=message_dicts, **params)
response = [
response
async for _, response in self.acompletion_with_retry(messages=message_dicts,
**params)
]
response = json.loads(response[0])
return self._create_chat_result(response)

def _create_message_dicts(
@@ -373,7 +377,7 @@ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
gen = ChatGeneration(message=message)
generations.append(gen)

llm_output = {'token_usage': response['usage'], 'model_name': self.model_name}
llm_output = {'token_usage': response.get('usage'), 'model_name': self.model_name}
return ChatResult(generations=generations, llm_output=llm_output)

@property
@@ -525,65 +529,9 @@ class CustomLLMChat(BaseHostChatLLM):
temperature: float = 0.1
top_p: float = 0.1
max_tokens: int = 4096
host_base_url: str

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return 'custom_llm_chat'

def completion_with_retry(self, **kwargs: Any) -> Any:
retry_decorator = _create_retry_decorator(self)

@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
messages = kwargs.get('messages')
temperature = kwargs.get('temperature')
top_p = kwargs.get('top_p')
max_tokens = kwargs.get('max_tokens')
do_sample = kwargs.get('do_sample')
params = {
'messages': messages,
'model': self.model_name,
'top_p': top_p,
'temperature': temperature,
'max_tokens': max_tokens,
'do_sample': do_sample
}

if self.verbose:
print('payload', params)

resp = None
try:
resp = self.client(
url=self.host_base_url,
json=params,
timeout=self.request_timeout).json()
except requests.exceptions.Timeout:
raise Exception(
f'timeout in custom host llm infer, url=[{self.host_base_url}]')
except Exception as e:
raise Exception(f'exception in custom host llm infer: [{e}]')

return resp

return _completion_with_retry(**kwargs)

def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response['choices']:
message = _convert_dict_to_message(res['message'])
gen = ChatGeneration(message=message)
generations.append(gen)

llm_output = {'token_usage': response.get('usage', {}), 'model_name': self.model_name}
return ChatResult(generations=generations, llm_output=llm_output)

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
return self._generate(messages, stop, run_manager, **kwargs)
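
To make the new streaming contract in host_llm.py easier to follow, here is a self-contained sketch of how the async path appears to assemble a reply: raw chunks are split on newlines, an event:error line aborts the stream, and data:-prefixed lines carry JSON deltas whose content and function_call fields are accumulated. The chunk list below is invented purely for illustration; the real code iterates over self.client.apost(...) inside acompletion_with_retry and _agenerate.

import json

chunks = [
    b'data:{"choices":[{"delta":{"role":"assistant","content":"Hel"}}]}\n',
    b'data:{"choices":[{"delta":{"content":"lo"}}]}\n',
]

completion, function_call = '', None
for raw in chunks:
    for line in raw.split(b'\n'):
        text = line.decode('utf-8').strip()
        if not text:
            continue
        if text.startswith('event:error'):
            # The real code logs the error payload and raises.
            raise ValueError(text)
        if text.startswith('data:'):
            text = text[len('data:'):]
        delta = json.loads(text)['choices'][0]['delta']
        completion += delta.get('content') or ''
        fc = delta.get('function_call')
        if fc:
            if function_call is None:
                function_call = fc
            else:
                function_call['arguments'] += fc['arguments']

print(completion)  # -> 'Hello'
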