diff --git a/examples/poetry.lock b/examples/poetry.lock index e6dfe100..f48b73a3 100644 --- a/examples/poetry.lock +++ b/examples/poetry.lock @@ -494,13 +494,13 @@ files = [ [[package]] name = "gentrace-py" -version = "0.15.0" +version = "0.15.2" description = "Python SDK for the Gentrace API" category = "main" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "gentrace_py-0.15.0.tar.gz", hash = "sha256:988e21f3ee82d9b837041acbd9665c895accb9e6866be7f9ca5f1866e9c6f6ef"}, + {file = "gentrace_py-0.15.2.tar.gz", hash = "sha256:7b507a4c5a03eb48dc04df0262f774664c61c817e2b939190b1fe7ab3998605c"}, ] [package.dependencies] @@ -519,7 +519,7 @@ vectorstores = ["pinecone-client (>=2.2.1,<3.0.0)"] [package.source] type = "file" -url = "../package/dist/gentrace_py-0.15.0.tar.gz" +url = "../package/dist/gentrace_py-0.15.2.tar.gz" [[package]] name = "idna" @@ -1251,4 +1251,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "8f02727726ffaa7cc25183df7814f5ed3e9aedb4f3497e8e91ce4d031fc23424" +content-hash = "218debc7274b1e79739bf1270572cf036f8e059bebdfe6cbfcfd3e45fdee3f8c" diff --git a/examples/pyproject.toml b/examples/pyproject.toml index ccc42084..9cb40059 100644 --- a/examples/pyproject.toml +++ b/examples/pyproject.toml @@ -11,7 +11,7 @@ openai = "^0.27.4" pinecone-client = "^2.2.1" python = "^3.8.1" python-dotenv = "^1.0.0" -gentrace-py = {path = "../package/dist/gentrace_py-0.15.0.tar.gz", develop = true} +gentrace-py = {path = "../package/dist/gentrace_py-0.15.2.tar.gz", develop = true} [tool.poetry.group.lint.dependencies] black = "^23.3.0" diff --git a/package/gentrace/providers/evaluation.py b/package/gentrace/providers/evaluation.py index 78283a35..6b1e8fe3 100644 --- a/package/gentrace/providers/evaluation.py +++ b/package/gentrace/providers/evaluation.py @@ -9,6 +9,11 @@ from gentrace.providers.init import ( GENTRACE_CONFIG_STATE, ) +from gentrace.providers.utils import ( + decrement_test_counter, + get_test_counter, + increment_test_counter, +) class Run(TypedDict): @@ -293,32 +298,38 @@ def run_test(pipeline_slug: str, handler) -> Result: "resultId": "161c623d-ee92-417f-823a-cf9f7eccf557", } """ - config = GENTRACE_CONFIG_STATE["global_gentrace_config"] - if not config: - raise ValueError("Gentrace API key not initialized. Call init() first.") + increment_test_counter() - api_client = ApiClient(configuration=config) - api = CoreApi(api_client=api_client) + try: + config = GENTRACE_CONFIG_STATE["global_gentrace_config"] + if not config: + raise ValueError("Gentrace API key not initialized. Call init() first.") - all_pipelines = get_pipelines() + api_client = ApiClient(configuration=config) + api = CoreApi(api_client=api_client) - matching_pipeline = next( - (pipeline for pipeline in all_pipelines if pipeline["slug"] == pipeline_slug), - None, - ) + all_pipelines = get_pipelines() - if not matching_pipeline: - raise ValueError(f"Could not find the specified pipeline ({pipeline_slug})") + matching_pipeline = next( + ( + pipeline + for pipeline in all_pipelines + if pipeline["slug"] == pipeline_slug + ), + None, + ) - test_cases = get_test_cases(matching_pipeline["id"]) + if not matching_pipeline: + raise ValueError(f"Could not find the specified pipeline ({pipeline_slug})") - test_runs = [] + test_cases = get_test_cases(matching_pipeline["id"]) - for test_case in test_cases: - [output, pipeline_run] = handler(test_case) + test_runs = [] - test_runs.append( - { + for test_case in test_cases: + [output, pipeline_run] = handler(test_case) + + test_run = { "caseId": test_case["id"], "stepRuns": [ { @@ -336,30 +347,38 @@ def run_test(pipeline_slug: str, handler) -> Result: for step_run in pipeline_run.step_runs ], } - ) - params = { - "pipelineId": matching_pipeline["id"], - "testRuns": test_runs, - } + if pipeline_run.get_id(): + test_run["id"] = pipeline_run.get_id() - if GENTRACE_CONFIG_STATE["GENTRACE_RUN_NAME"]: - params["name"] = GENTRACE_CONFIG_STATE["GENTRACE_RUN_NAME"] + test_runs.append(test_run) - if os.getenv("GENTRACE_BRANCH") or GENTRACE_CONFIG_STATE["GENTRACE_BRANCH"]: - params["branch"] = GENTRACE_CONFIG_STATE["GENTRACE_BRANCH"] or os.getenv( - "GENTRACE_BRANCH" - ) + params = { + "pipelineId": matching_pipeline["id"], + "testRuns": test_runs, + } - if os.getenv("GENTRACE_COMMIT") or GENTRACE_CONFIG_STATE["GENTRACE_COMMIT"]: - params["commit"] = GENTRACE_CONFIG_STATE["GENTRACE_COMMIT"] or os.getenv( - "GENTRACE_COMMIT" - ) - - params["collectionMethod"] = "runner" + if GENTRACE_CONFIG_STATE["GENTRACE_RUN_NAME"]: + params["name"] = GENTRACE_CONFIG_STATE["GENTRACE_RUN_NAME"] - response = api.test_result_post(params) - return response.body + if os.getenv("GENTRACE_BRANCH") or GENTRACE_CONFIG_STATE["GENTRACE_BRANCH"]: + params["branch"] = GENTRACE_CONFIG_STATE["GENTRACE_BRANCH"] or os.getenv( + "GENTRACE_BRANCH" + ) + + if os.getenv("GENTRACE_COMMIT") or GENTRACE_CONFIG_STATE["GENTRACE_COMMIT"]: + params["commit"] = GENTRACE_CONFIG_STATE["GENTRACE_COMMIT"] or os.getenv( + "GENTRACE_COMMIT" + ) + + params["collectionMethod"] = "runner" + + response = api.test_result_post(params) + return response.body + except Exception as e: + raise e + finally: + decrement_test_counter() __all__ = [ diff --git a/package/gentrace/providers/pipeline_run.py b/package/gentrace/providers/pipeline_run.py index e2dca513..37e831da 100644 --- a/package/gentrace/providers/pipeline_run.py +++ b/package/gentrace/providers/pipeline_run.py @@ -14,6 +14,7 @@ from gentrace.providers.step_run import StepRun from gentrace.providers.utils import ( from_date_string, + get_test_counter, run_post_background, to_date_string, ) @@ -47,6 +48,9 @@ def __init__(self, pipeline, id: Optional[str] = None): self.pipeline_run_id: str = id or str(uuid.uuid4()) self.step_runs: List[StepRun] = [] + def get_id(self): + return self.pipeline_run_id + def get_pipeline(self): return self.pipeline @@ -250,6 +254,11 @@ def checkpoint(self, step_info): ) async def asubmit(self) -> Dict: + if get_test_counter() > 0: + return { + "pipelineRunId": self.get_id(), + } + configuration = Configuration(host=self.pipeline.config.get("host")) configuration.access_token = self.pipeline.config.get("api_key") api_client = ApiClient(configuration=configuration) @@ -271,13 +280,11 @@ async def asubmit(self) -> Dict: for step_run in self.step_runs ] - pipeline_run_id = str(uuid.uuid4()) - try: pipeline_post_response = await run_post_background( core_api, { - "id": pipeline_run_id, + "id": self.pipeline_run_id, "slug": self.pipeline.slug, "stepRuns": step_runs_data, }, @@ -292,6 +299,11 @@ async def asubmit(self) -> Dict: return {"pipelineRunId": None} def submit(self, wait_for_server=False) -> Dict: + if get_test_counter() > 0: + return { + "pipelineRunId": self.get_id(), + } + configuration = Configuration(host=self.pipeline.config.get("host")) configuration.access_token = self.pipeline.config.get("api_key") diff --git a/package/gentrace/providers/utils.py b/package/gentrace/providers/utils.py index 812b91c3..083d9527 100644 --- a/package/gentrace/providers/utils.py +++ b/package/gentrace/providers/utils.py @@ -16,6 +16,9 @@ "log_debug", "log_info", "log_warn", + "get_test_counter", + "increment_test_counter", + "decrement_test_counter", ] logger = logging.getLogger(__name__) @@ -127,3 +130,22 @@ def wrapped_api_invocation(): executor, wrapped_api_invocation ) return result + + +test_counter = 0 + + +def get_test_counter(): + return test_counter + + +def increment_test_counter(): + global test_counter + test_counter += 1 + return test_counter + + +def decrement_test_counter(): + global test_counter + test_counter -= 1 + return test_counter diff --git a/package/tests/test_evaluation.py b/package/tests/test_evaluation.py index 1532c1ee..93cdcc2a 100644 --- a/package/tests/test_evaluation.py +++ b/package/tests/test_evaluation.py @@ -14,6 +14,7 @@ import gentrace from gentrace.providers.evaluation import OutputStep, construct_submission_payload from gentrace.providers.init import GENTRACE_CONFIG_STATE +from gentrace.providers.utils import get_test_counter def test_evaluation_get_test_cases(mocker, test_cases, setup_teardown_openai): @@ -23,8 +24,6 @@ def test_evaluation_get_test_cases(mocker, test_cases, setup_teardown_openai): body = json.dumps(test_cases, ensure_ascii=False).encode("utf-8") - print("before this") - gentrace_response = HTTPResponse( body=body, headers=headers, @@ -35,8 +34,6 @@ def test_evaluation_get_test_cases(mocker, test_cases, setup_teardown_openai): enforce_content_length=True, ) - print("after this") - gentrace_request = mocker.patch.object(gentrace.api_client.ApiClient, "request") gentrace_request.return_value = gentrace_response @@ -376,3 +373,68 @@ def create_checkpoint_callback(test_case): for runner in runner_list: assert len(runner.step_runs) == 2 assert runner.step_runs[0].outputs == {"value": 1100} + + +def test_evaluation_counter_rest(mocker, setup_teardown_openai, test_result_response): + pipeline = gentrace.Pipeline( + "guess-the-year", + openai_config={ + "api_key": os.getenv("OPENAI_KEY"), + }, + ) + + pipeline.setup() + + runner_list = [] + + def create_checkpoint_callback(test_case): + runner = pipeline.start() + runner.checkpoint({"inputs": {"x": 100, "y": 1000}, "outputs": {"value": 1100}}) + runner.checkpoint({"inputs": {"x": 100, "y": 1000}, "outputs": {"value": 1100}}) + runner_list.append(runner) + + return ["something", runner] + + assert get_test_counter() == 0 + + response = gentrace.run_test("guess-the-year", create_checkpoint_callback) + + assert get_test_counter() == 0 + + assert response.get("resultId", None) is not None + + for runner in runner_list: + assert len(runner.step_runs) == 2 + assert runner.step_runs[0].outputs == {"value": 1100} + + +def test_evaluation_counter_rest_when_run_test_fails( + mocker, setup_teardown_openai, test_result_response +): + pipeline = gentrace.Pipeline( + "guess-the-year", + openai_config={ + "api_key": os.getenv("OPENAI_KEY"), + }, + ) + + pipeline.setup() + + runner_list = [] + + def create_checkpoint_callback(test_case): + runner = pipeline.start() + runner.checkpoint({"inputs": {"x": 100, "y": 1000}, "outputs": {"value": 1100}}) + runner.checkpoint({"inputs": {"x": 100, "y": 1000}, "outputs": {"value": 1100}}) + runner_list.append(runner) + + return ["something", runner] + + assert get_test_counter() == 0 + + try: + gentrace.run_test("random-slug-no-exist", create_checkpoint_callback) + except Exception as e: + pass + + assert get_test_counter() == 0