Skip to content

Commit

Permalink
Feat pyproxy direct (#692)
Browse files Browse the repository at this point in the history
* 支持透传,accesstoken

* feat-pyproxy_direct

兼容了--accesstoken和透传的各种情况,修复bug

* 修复无accesstoken也可鉴权的情况

* 修复了js众测中 流式不报错的问题

* Update proxy.py

* 删除冗余注释

* Update proxy.py

* fix[trainer]: fix client trainer info not work with --taks-id / function chat call (#693)

* bug: 升级了 function call 字段的返回结构体以适配 langchain function call (#687)

* 升级了 function call 字段的返回结构体

* 格式化

* update: openai_adapter.ipynb

* feat: 浏览器增加鉴权开关 & 流式数据切分优化 (#689)

* feat: 浏览器增加鉴权开关 & 流式数据切分优化

* fix: 鉴权开关字段修改& 更新版本号

* fix: 鉴权开关字段修改& 更新版本号

---------

Co-authored-by: wangting31 <wangting31@baidu.com>

* fix: remove ipynb useless content

* fix: replace ERNIE-Bot description

* fix: add trainer cli info docs & remove ernie-bot in ipynb

* fix: remove useless or ambiguous content (#690)

* fix: remove ipynb useless content

* fix: replace ERNIE-Bot description

* fix: add trainer cli info docs & remove ernie-bot in ipynb

* fix[trainer]: cli trainer

* fix[resources]: functions passthourgh req warning

* chore: add trainer_ppl_file_tmpl.json

---------

Co-authored-by: Dobiichi-Origami <56953648+Dobiichi-Origami@users.noreply.github.com>
Co-authored-by: stonekim <shikuan@baidu.com>
Co-authored-by: wangting829 <1940087162@qq.com>
Co-authored-by: wangting31 <wangting31@baidu.com>

* feat[stat]: add star_timestamp (#685)

* feat[requestor]: add start_timestamp statis in python requestor

* chore: update version->0.4.1rc0

* chore: bump version-> 0.4.3

---------

Co-authored-by: NuODaniel <zhonghanjun@baidu.com>
Co-authored-by: Dobiichi-Origami <56953648+Dobiichi-Origami@users.noreply.github.com>
Co-authored-by: stonekim <shikuan@baidu.com>
Co-authored-by: wangting829 <1940087162@qq.com>
Co-authored-by: wangting31 <wangting31@baidu.com>
  • Loading branch information
6 people authored Jul 26, 2024
1 parent ed96bfa commit 7c90786
Show file tree
Hide file tree
Showing 16 changed files with 163 additions and 59 deletions.
57 changes: 57 additions & 0 deletions docs/trainer_ppl_file_tmpl.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"actions": [
{
"type": "LoadDataSetAction",
"datasets": {
"sourceType": "Platform",
"versions": [
{
"versionId": "ds-syeynswz2yxav8ix",
"samplingRate": 1
}
]
},
"eval_split_ratio": 10,
"corpus_config": {
"copy_data": true,
"config": [
{
"proportion": "1:2",
"corpus_type": 3,
"labels": [
"文本创作"
]
}
]
}
},
{
"type": "TrainAction",
"init_params": {
"train_mode": "SFT",
"train_type": "ERNIE-Speed-8K",
"train_config": {
"peft_type": "LoRA",
"epoch": 1,
"learning_rate": 0.0003,
"max_seq_len": 4096,
"logging_steps": 1,
"warmup_ratio": 0.1,
"weight_decay": 0.01,
"lora_rank": 8,
"lora_all_linear": "True"
},
"is_incr": false,
"job_name": "speed_math02",
"task_description": "task_desc1",
"job_description": "job_desc1"
}
},
{
"type": "ModelPublishAction"
}
],
"case_init_params": {
"case_type": "Finetune"
}
}
1 change: 1 addition & 0 deletions javascript/src/Base/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ export class BaseClient {
let fetchOptions;
// 如果enableOauth开启, 则放开鉴权
if (getCurrentEnvironment() === 'node' || this.enableOauth) {

// 检查鉴权信息
if (!(this.qianfanAccessKey && this.qianfanSecretKey) && !(this.qianfanAk && this.qianfanSk)) {
throw new Error('请设置AK/SK或QIANFAN_ACCESS_KEY/QIANFAN_SECRET_KEY');
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "qianfan"
version = "0.4.2"
version = "0.4.3"
description = "文心千帆大模型平台 Python SDK"
authors = []
license = "Apache-2.0"
Expand Down
4 changes: 2 additions & 2 deletions python/qianfan/common/client/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
)
from qianfan.consts import DefaultLLMModel
from qianfan.errors import InternalError
from qianfan.resources.llm.base import BaseResourceV1
from qianfan.resources.llm.chat_completion import (
ChatCompletion,
_ChatCompletionV1,
_ChatCompletionV2,
)
from qianfan.resources.typing import Literal, QfMessages, QfResponse
Expand Down Expand Up @@ -152,7 +152,7 @@ def _markup(s: str) -> str:
name: str
_client = client._real
if self.version == "1":
assert isinstance(_client, _ChatCompletionV1)
assert isinstance(_client, BaseResourceV1)
if _client._model is not None:
name = f"Model {_markup(_client._model)}"
elif _client._endpoint is not None:
Expand Down
27 changes: 3 additions & 24 deletions python/qianfan/common/client/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
print_error_msg,
print_info_msg,
)
from qianfan.config import encoding, get_config
from qianfan.config import encoding
from qianfan.utils.utils import check_dependency

app = typer.Typer(
Expand All @@ -53,29 +53,6 @@
_enable_traceback = False


@app.command(name="cache")
def clear(
clear: Optional[bool] = typer.Option(
None,
"--clear",
help="clear qianfan cache",
),
) -> None:
"""
clear qianfan cache.
"""
import shutil

# 要删除的目录路径
dir_path = get_config().CACHE_DIR
# 删除目录
try:
shutil.rmtree(dir_path)
print_info_msg(f"目录 {dir_path} 已删除")
except OSError as e:
print_info_msg(f"删除目录 {dir_path} 失败: {e}")


@app.command(name="openai")
@credential_required
def openai(
Expand Down Expand Up @@ -206,6 +183,7 @@ def proxy(
help="Ciphers to use (see stdlib ssl module's) [default: TLSv1]",
),
access_token: str = typer.Option("", "--access-token", help="Access token"),
direct: bool = typer.Option(False, "--direct", help="Direct connection to server"),
) -> None:
"""
Create a proxy server.
Expand Down Expand Up @@ -236,6 +214,7 @@ def proxy(
mock_port=mock_port,
ssl_config=ssl_config,
access_token=access_token,
direct=direct,
)


Expand Down
29 changes: 23 additions & 6 deletions python/qianfan/common/client/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def get_stream(


@base_app.middleware("http")
async def base_iam(request: Request, callback: Callable) -> Response:
async def base_openapi(request: Request, callback: Callable) -> Response:
"""
用于向base请求中添加访问令牌。
Expand All @@ -54,8 +54,21 @@ async def base_iam(request: Request, callback: Callable) -> Response:
Returns:
Response: 处理后的响应对象。
"""
if "access_token" in request.url._url:
key = request.url._url.split("?access_token=")[1]
if not proxy.direct and proxy.access_token is not None:
try:
key = request.url._url.split("?access_token=")[1]
except Exception:
return JSONResponse(
{
"error": {
"message": "No ACCESS_TOKEN provided, please check",
"type": "invalid_request_error",
"param": None,
"code": "NO_ACCESS_TOKEN",
}
},
status_code=401,
)
if key != proxy.access_token:
return JSONResponse(
{
Expand All @@ -72,14 +85,15 @@ async def base_iam(request: Request, callback: Callable) -> Response:
)
else:
new_scope = dict(request.scope)
new_scope["query_string"] = b""
if proxy._config.ACCESS_KEY and proxy._config.SECRET_KEY:
new_scope["query_string"] = b""
else:
proxy._direct = True
request = StarletteRequest(scope=new_scope, receive=request.receive)

else:
pass

resp = await proxy.get_response(request, DefaultValue.BaseURL)

if isinstance(resp, AsyncIterator):
return StreamingResponse(get_stream(resp), media_type="text/event-stream")

Expand Down Expand Up @@ -115,6 +129,7 @@ def entry(
mock_port: int,
ssl_config: Dict[str, Any],
access_token: Optional[str],
direct: bool,
) -> None:
import os

Expand All @@ -131,6 +146,8 @@ def entry(

if access_token is not None:
proxy.access_token = access_token
if direct:
proxy._direct = True

proxy.mock_port = mock_port

Expand Down
3 changes: 2 additions & 1 deletion python/qianfan/common/client/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ def finetune(
if train_lora_all_linear is not None:
trainer.train_action.train_config.lora_all_linear = train_lora_all_linear

console.log(f"trainer id: {trainer.id} is created")
if daemon:
trainer.start()
console.print(
Expand Down Expand Up @@ -535,7 +536,7 @@ def info(
elif task_id:
trainers = Finetune.list()
for t in trainers:
for action in t.actions:
for action in t.actions.values():
if isinstance(action, TrainAction) and action.task_id == task_id:
trainer = t
break
Expand Down
4 changes: 3 additions & 1 deletion python/qianfan/common/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ def credential_required(func: Callable) -> Callable:

@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
check_credential()
direct = kwargs.get("direct", False)
if not direct:
check_credential()
return func(*args, **kwargs)

return wrapper
Expand Down
34 changes: 28 additions & 6 deletions python/qianfan/extensions/proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@ class ClientProxy(object):
_retry_base_config: Optional[RetryConfig] = None
_retry_console_config: Optional[RetryConfig] = None
_access_token: Optional[str] = None
_direct: bool = False

def __init__(self) -> None:
pass

@property
def direct(self) -> Optional[bool]:
return self._direct

@property
def access_token(self) -> Optional[str]:
return self._access_token
Expand Down Expand Up @@ -87,9 +92,14 @@ def _sign(self, request: QfRequest) -> None:
url, path = request.url, urlparse(request.url).path

request.url = path
iam_sign(str(self._config.ACCESS_KEY), str(self._config.SECRET_KEY), request)
if not (request.query.get("client_id") or request.query.get("client_secret")):
iam_sign(
str(self._config.ACCESS_KEY), str(self._config.SECRET_KEY), request
)
request.url = url
if not request.headers.get("Authorization", None):
self._auth._ak = request.query.get("client_id")
self._auth._sk = request.query.get("client_secret")
request.query["access_token"] = self._auth.access_token()

async def get_request(self, request: Request, url_route: str) -> QfRequest:
Expand All @@ -112,9 +122,9 @@ async def get_request(self, request: Request, url_route: str) -> QfRequest:
# 获取请求头
if self.mock_port != -1:
url_route = f"http://127.0.0.1:{self.mock_port}"

url = url_route + request.url.path
host = urlparse(url_route).netloc

headers = {
"Content-Type": "application/json",
"Host": host,
Expand All @@ -124,7 +134,7 @@ async def get_request(self, request: Request, url_route: str) -> QfRequest:
json_body = await request.json()
return QfRequest(
url=url,
headers=headers,
headers=headers if not self._direct else dict(request.headers),
method=request.method,
query=dict(request.query_params),
json_body=json_body,
Expand All @@ -149,11 +159,23 @@ async def get_response(
try:
async with self._rate_limiter:
qf_req = await self.get_request(request, url_route)
self._sign(qf_req)
if self._direct:
pass
else:
self._sign(qf_req)
logging.debug(f"request: {qf_req}")

if qf_req.json_body.get("stream", False):
return self._client.arequest_stream(qf_req)
resp, session = await self._client.arequest(qf_req)
if (
"Content-Type" in resp.headers
and "application/json" in resp.headers["Content-Type"]
): # 判断返回中是否有流式数据
resp, session = await self._client.arequest(qf_req)
async with session:
json_body = await resp.json()
return json_body
else:
return self._client.arequest_stream(qf_req)
else:
resp, session = await self._client.arequest(qf_req)
async with session:
Expand Down
6 changes: 0 additions & 6 deletions python/qianfan/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,6 @@ def __init__(
self.name = name
if id is None or set_id is None:
self.auto_complete_info()
if (
(id is None and set_id is None)
or self.task_id is None
or self.job_id is None
):
log_warn("set_id/id or job_id/task_id should be provided")

def exec(
self, input: Optional[Dict] = None, **kwargs: Dict
Expand Down
4 changes: 3 additions & 1 deletion python/qianfan/resources/console/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,12 @@ class TrainStatus(str, Enum):
"""训练完成"""
Running = "Running"
"""训练进行中"""
Fail = "Fail"
Fail = "Failed"
"""训练失败"""
Stop = "Stopped"
"""训练停止"""
Waiting = "Waiting"
"""排队中"""


class ModelState(str, Enum):
Expand Down
29 changes: 20 additions & 9 deletions python/qianfan/resources/llm/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _default_model(cls) -> str:
def do(
self,
messages: Union[List[Dict], QfMessages],
functions: List[Dict],
functions: List[Dict] = [],
**kwargs: Any,
) -> Union[QfResponse, Iterator[QfResponse]]:
"""
Expand Down Expand Up @@ -152,18 +152,29 @@ def do(
```
"""
if len(functions) <= 0:
raise errors.InvalidArgumentError(
"functions should be a list of functions, "
"each function is a dictionary with name and description."
)
if kwargs.get("stream") is True:
raise errors.InvalidArgumentError("Function does not support stream mode.")

if isinstance(messages, QfMessages):
temp_messages = messages._to_list()
else:
temp_messages = messages
for k in [
"auto_concat_truncate",
"truncated_continue_prompt",
"truncate_overlong_msgs",
]:
if k in kwargs:
del kwargs[k]

for k in ["request_id"]:
if k in kwargs and kwargs.get(k) is None:
del kwargs[k]

if not functions:
# 没有传入functions,不特殊处理,直接走普通的base_resource请求模式
kwargs["messages"] = temp_messages
return super()._do(**kwargs)

if kwargs.get("stream") is True:
raise errors.InvalidArgumentError("Function does not support stream mode.")

functions_schemas = self._render_functions_prompt(functions)
temp_messages[0] = self._render_user_query_msg(
Expand Down
Loading

0 comments on commit 7c90786

Please sign in to comment.