Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ openai = "^0.27.4"
pinecone-client = "^2.2.1"
python = "^3.8.1"
python-dotenv = "^1.0.0"
gentrace-py = {path = "../package/dist/gentrace_py-0.15.0.tar.gz", develop = true}
gentrace-py = {path = "../package/dist/gentrace_py-0.15.2.tar.gz", develop = true}

[tool.poetry.group.lint.dependencies]
black = "^23.3.0"
Expand Down
93 changes: 56 additions & 37 deletions package/gentrace/providers/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from gentrace.providers.init import (
GENTRACE_CONFIG_STATE,
)
from gentrace.providers.utils import (
decrement_test_counter,
get_test_counter,
increment_test_counter,
)


class Run(TypedDict):
Expand Down Expand Up @@ -293,32 +298,38 @@ def run_test(pipeline_slug: str, handler) -> Result:
"resultId": "161c623d-ee92-417f-823a-cf9f7eccf557",
}
"""
config = GENTRACE_CONFIG_STATE["global_gentrace_config"]
if not config:
raise ValueError("Gentrace API key not initialized. Call init() first.")
increment_test_counter()

api_client = ApiClient(configuration=config)
api = CoreApi(api_client=api_client)
try:
config = GENTRACE_CONFIG_STATE["global_gentrace_config"]
if not config:
raise ValueError("Gentrace API key not initialized. Call init() first.")

all_pipelines = get_pipelines()
api_client = ApiClient(configuration=config)
api = CoreApi(api_client=api_client)

matching_pipeline = next(
(pipeline for pipeline in all_pipelines if pipeline["slug"] == pipeline_slug),
None,
)
all_pipelines = get_pipelines()

if not matching_pipeline:
raise ValueError(f"Could not find the specified pipeline ({pipeline_slug})")
matching_pipeline = next(
(
pipeline
for pipeline in all_pipelines
if pipeline["slug"] == pipeline_slug
),
None,
)

test_cases = get_test_cases(matching_pipeline["id"])
if not matching_pipeline:
raise ValueError(f"Could not find the specified pipeline ({pipeline_slug})")

test_runs = []
test_cases = get_test_cases(matching_pipeline["id"])

for test_case in test_cases:
[output, pipeline_run] = handler(test_case)
test_runs = []

test_runs.append(
{
for test_case in test_cases:
[output, pipeline_run] = handler(test_case)

test_run = {
"caseId": test_case["id"],
"stepRuns": [
{
Expand All @@ -336,30 +347,38 @@ def run_test(pipeline_slug: str, handler) -> Result:
for step_run in pipeline_run.step_runs
],
}
)

params = {
"pipelineId": matching_pipeline["id"],
"testRuns": test_runs,
}
if pipeline_run.get_id():
test_run["id"] = pipeline_run.get_id()

if GENTRACE_CONFIG_STATE["GENTRACE_RUN_NAME"]:
params["name"] = GENTRACE_CONFIG_STATE["GENTRACE_RUN_NAME"]
test_runs.append(test_run)

if os.getenv("GENTRACE_BRANCH") or GENTRACE_CONFIG_STATE["GENTRACE_BRANCH"]:
params["branch"] = GENTRACE_CONFIG_STATE["GENTRACE_BRANCH"] or os.getenv(
"GENTRACE_BRANCH"
)
params = {
"pipelineId": matching_pipeline["id"],
"testRuns": test_runs,
}

if os.getenv("GENTRACE_COMMIT") or GENTRACE_CONFIG_STATE["GENTRACE_COMMIT"]:
params["commit"] = GENTRACE_CONFIG_STATE["GENTRACE_COMMIT"] or os.getenv(
"GENTRACE_COMMIT"
)

params["collectionMethod"] = "runner"
if GENTRACE_CONFIG_STATE["GENTRACE_RUN_NAME"]:
params["name"] = GENTRACE_CONFIG_STATE["GENTRACE_RUN_NAME"]

response = api.test_result_post(params)
return response.body
if os.getenv("GENTRACE_BRANCH") or GENTRACE_CONFIG_STATE["GENTRACE_BRANCH"]:
params["branch"] = GENTRACE_CONFIG_STATE["GENTRACE_BRANCH"] or os.getenv(
"GENTRACE_BRANCH"
)

if os.getenv("GENTRACE_COMMIT") or GENTRACE_CONFIG_STATE["GENTRACE_COMMIT"]:
params["commit"] = GENTRACE_CONFIG_STATE["GENTRACE_COMMIT"] or os.getenv(
"GENTRACE_COMMIT"
)

params["collectionMethod"] = "runner"

response = api.test_result_post(params)
return response.body
except Exception as e:
raise e
finally:
decrement_test_counter()


__all__ = [
Expand Down
18 changes: 15 additions & 3 deletions package/gentrace/providers/pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gentrace.providers.step_run import StepRun
from gentrace.providers.utils import (
from_date_string,
get_test_counter,
run_post_background,
to_date_string,
)
Expand Down Expand Up @@ -47,6 +48,9 @@ def __init__(self, pipeline, id: Optional[str] = None):
self.pipeline_run_id: str = id or str(uuid.uuid4())
self.step_runs: List[StepRun] = []

def get_id(self):
return self.pipeline_run_id

def get_pipeline(self):
return self.pipeline

Expand Down Expand Up @@ -250,6 +254,11 @@ def checkpoint(self, step_info):
)

async def asubmit(self) -> Dict:
if get_test_counter() > 0:
return {
"pipelineRunId": self.get_id(),
}

configuration = Configuration(host=self.pipeline.config.get("host"))
configuration.access_token = self.pipeline.config.get("api_key")
api_client = ApiClient(configuration=configuration)
Expand All @@ -271,13 +280,11 @@ async def asubmit(self) -> Dict:
for step_run in self.step_runs
]

pipeline_run_id = str(uuid.uuid4())

try:
pipeline_post_response = await run_post_background(
core_api,
{
"id": pipeline_run_id,
"id": self.pipeline_run_id,
"slug": self.pipeline.slug,
"stepRuns": step_runs_data,
},
Expand All @@ -292,6 +299,11 @@ async def asubmit(self) -> Dict:
return {"pipelineRunId": None}

def submit(self, wait_for_server=False) -> Dict:
if get_test_counter() > 0:
return {
"pipelineRunId": self.get_id(),
}

configuration = Configuration(host=self.pipeline.config.get("host"))
configuration.access_token = self.pipeline.config.get("api_key")

Expand Down
22 changes: 22 additions & 0 deletions package/gentrace/providers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
"log_debug",
"log_info",
"log_warn",
"get_test_counter",
"increment_test_counter",
"decrement_test_counter",
]

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -127,3 +130,22 @@ def wrapped_api_invocation():
executor, wrapped_api_invocation
)
return result


# Module-level count of currently active run_test() invocations. While it is
# positive, pipeline-run submission is skipped (the runner collects results
# itself instead of posting each run individually).
test_counter = 0


def get_test_counter():
    """Return the number of run_test() calls currently in progress."""
    return test_counter


def increment_test_counter():
    """Mark the start of a test run and return the updated counter."""
    global test_counter
    test_counter = test_counter + 1
    return test_counter


def decrement_test_counter():
    """Mark the end of a test run and return the updated counter."""
    global test_counter
    test_counter = test_counter - 1
    return test_counter
70 changes: 66 additions & 4 deletions package/tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import gentrace
from gentrace.providers.evaluation import OutputStep, construct_submission_payload
from gentrace.providers.init import GENTRACE_CONFIG_STATE
from gentrace.providers.utils import get_test_counter


def test_evaluation_get_test_cases(mocker, test_cases, setup_teardown_openai):
Expand All @@ -23,8 +24,6 @@ def test_evaluation_get_test_cases(mocker, test_cases, setup_teardown_openai):

body = json.dumps(test_cases, ensure_ascii=False).encode("utf-8")

print("before this")

gentrace_response = HTTPResponse(
body=body,
headers=headers,
Expand All @@ -35,8 +34,6 @@ def test_evaluation_get_test_cases(mocker, test_cases, setup_teardown_openai):
enforce_content_length=True,
)

print("after this")

gentrace_request = mocker.patch.object(gentrace.api_client.ApiClient, "request")
gentrace_request.return_value = gentrace_response

Expand Down Expand Up @@ -376,3 +373,68 @@ def create_checkpoint_callback(test_case):
for runner in runner_list:
assert len(runner.step_runs) == 2
assert runner.step_runs[0].outputs == {"value": 1100}


def test_evaluation_counter_rest(mocker, setup_teardown_openai, test_result_response):
    """run_test should leave the global test counter back at zero on success."""
    pipeline = gentrace.Pipeline(
        "guess-the-year",
        openai_config={
            "api_key": os.getenv("OPENAI_KEY"),
        },
    )

    pipeline.setup()

    collected_runners = []

    def record_checkpoints(test_case):
        # Two checkpoints per case so the step-run assertions below have data.
        runner = pipeline.start()
        runner.checkpoint({"inputs": {"x": 100, "y": 1000}, "outputs": {"value": 1100}})
        runner.checkpoint({"inputs": {"x": 100, "y": 1000}, "outputs": {"value": 1100}})
        collected_runners.append(runner)

        return ["something", runner]

    # Counter is zero before, and must be restored to zero afterwards.
    assert get_test_counter() == 0

    response = gentrace.run_test("guess-the-year", record_checkpoints)

    assert get_test_counter() == 0

    assert response.get("resultId", None) is not None

    for runner in collected_runners:
        assert len(runner.step_runs) == 2
        assert runner.step_runs[0].outputs == {"value": 1100}


def test_evaluation_counter_rest_when_run_test_fails(
    mocker, setup_teardown_openai, test_result_response
):
    """The global test counter must return to zero even when run_test raises.

    Fixes over the previous version: the caught exception was bound to an
    unused name (``except Exception as e: pass``), and the silent ``pass``
    meant the test would still succeed if ``run_test`` unexpectedly did NOT
    fail — defeating the point of a failure-path test.
    """
    pipeline = gentrace.Pipeline(
        "guess-the-year",
        openai_config={
            "api_key": os.getenv("OPENAI_KEY"),
        },
    )

    pipeline.setup()

    runner_list = []

    def create_checkpoint_callback(test_case):
        runner = pipeline.start()
        runner.checkpoint({"inputs": {"x": 100, "y": 1000}, "outputs": {"value": 1100}})
        runner.checkpoint({"inputs": {"x": 100, "y": 1000}, "outputs": {"value": 1100}})
        runner_list.append(runner)

        return ["something", runner]

    assert get_test_counter() == 0

    # The slug does not exist, so run_test is expected to raise; the counter
    # must still be decremented back to zero on the failure path.
    raised = False
    try:
        gentrace.run_test("random-slug-no-exist", create_checkpoint_callback)
    except Exception:
        raised = True

    assert raised, "run_test should fail for a nonexistent pipeline slug"
    assert get_test_counter() == 0