Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions packages/sdk/server-ai/tests/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,3 +909,111 @@ def test_client_create_tracker_fails_on_invalid_json():
result = ai_client.create_tracker(bad_token, context)
assert not result.is_success()
assert "Invalid resumption token" in result.error


# --- PR 10: LDAIMetrics enrichment + tracker integration ---


def test_ldai_metrics_to_dict_includes_tool_calls_and_duration_ms():
    """to_dict() serializes every populated field, using camelCase keys
    for tool_calls and duration_ms."""
    serialized = LDAIMetrics(
        success=True,
        usage=TokenUsage(total=10, input=4, output=6),
        tool_calls=["search", "lookup"],
        duration_ms=123,
    ).to_dict()

    assert serialized["success"] is True
    assert serialized["usage"] == {"total": 10, "input": 4, "output": 6}
    assert serialized["toolCalls"] == ["search", "lookup"]
    assert serialized["durationMs"] == 123


def test_ldai_metrics_to_dict_omits_optional_fields_when_none():
    """Optional fields left as None must not appear in the serialized dict."""
    serialized = LDAIMetrics(success=False).to_dict()
    # Only the mandatory success flag survives serialization.
    assert serialized == {"success": False}


def test_track_metrics_of_uses_metrics_duration_ms_when_set(client: LDClient):
    """An explicit duration_ms on the extracted metrics takes precedence
    over any internally measured duration."""
    context = Context.create("user-key")
    tracker = LDAIConfigTracker(
        ld_client=client, run_id="test-run-id", config_key="config-key",
        variation_key="variation-key", version=3, model_name="m",
        provider_name="p", context=context,
    )

    tracker.track_metrics_of(
        lambda _r: LDAIMetrics(success=True, duration_ms=999),
        lambda: "done",
    )
    assert tracker.get_summary().duration_ms == 999


@pytest.mark.asyncio
async def test_track_metrics_of_async_uses_metrics_duration_ms_when_set(client: LDClient):
    """Async variant: an explicit duration_ms on the extracted metrics
    takes precedence over any internally measured duration."""
    context = Context.create("user-key")
    tracker = LDAIConfigTracker(
        ld_client=client, run_id="test-run-id", config_key="config-key",
        variation_key="variation-key", version=3, model_name="m",
        provider_name="p", context=context,
    )

    async def produce_result():
        return "done"

    await tracker.track_metrics_of_async(
        lambda _r: LDAIMetrics(success=True, duration_ms=42),
        produce_result,
    )
    assert tracker.get_summary().duration_ms == 42


def test_track_metrics_of_calls_track_tool_calls_when_present(client: LDClient):
    """When the extracted metrics carry tool_calls, the tracker emits one
    $ld:ai:tool_call event per tool key and exposes them on the summary."""
    context = Context.create("user-key")
    tracker = LDAIConfigTracker(
        ld_client=client, run_id="test-run-id", config_key="config-key",
        variation_key="variation-key", version=3, model_name="m",
        provider_name="p", context=context,
    )

    def fn():
        return "done"

    def extract(_r):
        return LDAIMetrics(success=True, tool_calls=["foo", "bar"])

    tracker.track_metrics_of(extract, fn)
    summary = tracker.get_summary()
    assert summary.tool_calls == ["foo", "bar"]
    # One $ld:ai:tool_call event per tool key. Compare via a slice rather
    # than c.args[0]: mock_calls entries made with keyword-only arguments
    # have empty args, and indexing them would raise IndexError.
    tool_call_events = [
        c for c in client.track.mock_calls  # type: ignore
        if c.args[:1] == ("$ld:ai:tool_call",)
    ]
    assert len(tool_call_events) == 2


def test_track_metrics_of_skips_track_tool_calls_when_absent(client: LDClient):
    """When the extracted metrics carry no tool_calls, the tracker emits
    no $ld:ai:tool_call events and the summary field stays None."""
    context = Context.create("user-key")
    tracker = LDAIConfigTracker(
        ld_client=client, run_id="test-run-id", config_key="config-key",
        variation_key="variation-key", version=3, model_name="m",
        provider_name="p", context=context,
    )

    def fn():
        return "done"

    def extract(_r):
        return LDAIMetrics(success=True, usage=None)

    tracker.track_metrics_of(extract, fn)
    assert tracker.get_summary().tool_calls is None
    # Compare via a slice rather than c.args[0]: mock_calls entries made
    # with keyword-only arguments have empty args, and indexing them
    # would raise IndexError.
    tool_call_events = [
        c for c in client.track.mock_calls  # type: ignore
        if c.args[:1] == ("$ld:ai:tool_call",)
    ]
    assert tool_call_events == []
Loading