Skip to content

Commit

Permalink
fix: trainer run split_ratio (#548)
Browse files Browse the repository at this point in the history
* fix: trainer run split_ratio

* fix: langchain community version

* fix: langchain community version

* fix: langchain community version
  • Loading branch information
danielhjz committed May 24, 2024
1 parent 2ab4358 commit 1e4ef27
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "qianfan"
-version = "0.3.13"
+version = "0.3.13.1"
description = "文心千帆大模型平台 Python SDK"
authors = []
license = "Apache-2.0"
Expand Down
15 changes: 15 additions & 0 deletions python/qianfan/tests/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import pytest

from qianfan import errors
from qianfan.dataset import Dataset
from qianfan.dataset.data_source import QianfanDataSource
from qianfan.errors import InternalError, InvalidArgumentError
Expand Down Expand Up @@ -71,11 +72,25 @@ def test_train_action():
train_type="ERNIE-Speed", train_mode=console_consts.TrainMode.PostPretrain
)

with pytest.raises(errors.RequestError):
output = ta.exec(
input={
"datasets": {
"sourceType": (
console_consts.TrainDatasetSourceType.PrivateBos.value
),
"versions": [{"versionBosUri": "bos:/aaa/"}],
}
}
)

ta = TrainAction(train_type="ERNIE-Speed", train_mode=console_consts.TrainMode.SFT)
output = ta.exec(
input={
"datasets": {
"sourceType": console_consts.TrainDatasetSourceType.PrivateBos.value,
"versions": [{"versionBosUri": "bos:/aaa/"}],
"splitRatio": 20,
}
}
)
Expand Down
12 changes: 12 additions & 0 deletions python/qianfan/tests/utils/mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,18 @@ def finetune_v2_create_job(body):
def finetune_v2_create_task(body):
task_id = f"task-{generate_letter_num_random_id(12)}"
job_id = body["jobId"]
if not body.get("datasetConfig", {}).get("splitRatio"):
return json_response(
{
"requestId": "bfad9ba9-9fc2-406d-ae84-c9e1ea92140a",
"code": "InappropriateJSON",
"message": (
"The JSON you provided was well-formed and valid, but not"
" appropriate for this operation. param[splitRatio] invalid."
),
},
status_code=400,
)
if job_id == MockFailedJobId:
task_id = MockFailedTaskId
global finetune_task_call_times
Expand Down
3 changes: 2 additions & 1 deletion python/qianfan/trainer/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ def __init__(
self,
dataset: Optional[Union[DatasetConfig, Dataset, str]] = None,
dataset_template: Optional[console_consts.DataTemplateType] = None,
+eval_split_ratio: float = 20,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
+self.eval_split_ratio = eval_split_ratio
self.corpus_proportion = kwargs.get("corpus_proportion")
-self.eval_split_ratio = kwargs.get("eval_split_ratio")
self.sampling_rate = kwargs.get("sampling_rate")
if dataset is None:
raise InvalidArgumentError("dataset must be set")
Expand Down

0 comments on commit 1e4ef27

Please sign in to comment.