Skip to content

Commit 03a11d5

Browse files
hellovai and aaronvg
authored
add test cases which show collector fails in a non-recoverable way whenever baml-runtime fails (bad type_builder, bad client_registry, bad args) (#1633)
<!-- ELLIPSIS_HIDDEN --> > [!IMPORTANT] > Add test cases to ensure collector fails non-recoverably on baml-runtime errors, with async handling in `runtime_interface.rs`. > > - **Behavior**: > - Refactor `call_function_impl` in `runtime_interface.rs` to handle async execution and error handling. > - Ensure function execution logs errors when BAML function does not exist or argument validation fails. > - **Tests**: > - Add `test_collector_failures_arg_type` to check for `BamlInvalidArgumentError` on bad argument types. > - Add `test_collector_failures_client_registry` to check for `BamlError` when client registry is incorrect. > - Add streaming variants `test_collector_failures_arg_type_streaming` and `test_collector_failures_client_registry_streaming` for similar checks. > - **Misc**: > - Format adjustments in `test_collector.py` for readability. > > <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup> for 816aec8. It will automatically update as commits are pushed.</sup> <!-- ELLIPSIS_HIDDEN --> --------- Co-authored-by: Aaron Villalpando <aaron@boundaryml.com>
1 parent 4584322 commit 03a11d5

File tree

2 files changed

+138
-59
lines changed

2 files changed

+138
-59
lines changed

engine/baml-runtime/src/runtime/runtime_interface.rs

Lines changed: 40 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -372,47 +372,8 @@ impl RuntimeInterface for InternalBamlRuntime {
372372
params: &BamlMap<String, BamlValue>,
373373
ctx: RuntimeContext,
374374
) -> Result<crate::FunctionResult> {
375-
let func = match self.get_function(&function_name, &ctx) {
376-
Ok(func) => func,
377-
Err(e) => {
378-
return Ok(FunctionResult::new(
379-
OrchestrationScope::default(),
380-
LLMResponse::UserFailure(format!(
381-
"BAML function {function_name} does not exist in baml_src/ (did you typo it?): {:?}",
382-
e
383-
)),
384-
None,
385-
))
386-
}
387-
};
388-
let baml_args = self.ir().check_function_params(
389-
&func,
390-
params,
391-
ArgCoercer {
392-
span_path: None,
393-
allow_implicit_cast_to_string: false,
394-
},
395-
)?;
396-
// let baml_args = match self.ir().check_function_params(
397-
// &func,
398-
// &params,
399-
// ArgCoercer {
400-
// span_path: None,
401-
// allow_implicit_cast_to_string: false,
402-
// },
403-
// ) {
404-
// Ok(args) => args,
405-
// Err(e) => {
406-
// return Ok(FunctionResult::new(
407-
// OrchestrationScope::default(),
408-
// LLMResponse::UserFailure(format!(
409-
// "Failed while validating args for {function_name}: {:?}",
410-
// e
411-
// )),
412-
// None,
413-
// ))
414-
// }
415-
// };
375+
let local_span_id = ctx.span_id.clone();
376+
let local_function_name = function_name.clone();
416377

417378
if let Some(span_id) = ctx.span_id {
418379
let trace_event = TraceEvent {
@@ -443,15 +404,44 @@ impl RuntimeInterface for InternalBamlRuntime {
443404
);
444405
}
445406

446-
let renderer = PromptRenderer::from_function(&func, self.ir(), &ctx)?;
447-
let orchestrator = self.orchestration_graph(renderer.client_spec(), &ctx)?;
407+
let future = async {
408+
let func = match self.get_function(&function_name, &ctx) {
409+
Ok(func) => func,
410+
Err(e) => {
411+
return Ok(FunctionResult::new(
412+
OrchestrationScope::default(),
413+
LLMResponse::UserFailure(format!(
414+
"BAML function {function_name} does not exist in baml_src/ (did you typo it?): {:?}",
415+
e
416+
)),
417+
None,
418+
))
419+
}
420+
};
448421

449-
// Now actually execute the code.
450-
let (history, _) =
451-
orchestrate_call(orchestrator, self.ir(), &ctx, &renderer, &baml_args, |s| {
452-
renderer.parse(self.ir(), &ctx, s, false)
453-
})
454-
.await;
422+
let baml_args = self.ir().check_function_params(
423+
&func,
424+
params,
425+
ArgCoercer {
426+
span_path: None,
427+
allow_implicit_cast_to_string: false,
428+
},
429+
)?;
430+
431+
let renderer = PromptRenderer::from_function(&func, self.ir(), &ctx)?;
432+
let orchestrator = self.orchestration_graph(renderer.client_spec(), &ctx)?;
433+
434+
// Now actually execute the code.
435+
let (history, _) =
436+
orchestrate_call(orchestrator, self.ir(), &ctx, &renderer, &baml_args, |s| {
437+
renderer.parse(self.ir(), &ctx, s, false)
438+
})
439+
.await;
440+
441+
FunctionResult::new_chain(history)
442+
};
443+
444+
let result = future.await;
455445

456446
let end_time = web_time::SystemTime::now();
457447
if let Some(span_id) = ctx.span_id {
@@ -475,7 +465,7 @@ impl RuntimeInterface for InternalBamlRuntime {
475465
);
476466
}
477467

478-
FunctionResult::new_chain(history)
468+
result
479469
}
480470

481471
// Note that this only returns a FunctionResultStream object,

integ-tests/python/tests/test_collector.py

Lines changed: 98 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1+
from baml_py.errors import BamlInvalidArgumentError, BamlError, BamlClientError
12
import pytest
23
import dotenv
34
from openai.types.chat import ChatCompletion
45

56
from ..baml_client import b
67
from ..baml_client.sync_client import b as b_sync
7-
from baml_py import Collector
8+
from baml_py import ClientRegistry, Collector
89
import gc
910
import sys
1011
import asyncio
1112

1213
dotenv.load_dotenv()
1314

15+
1416
def function_span_count():
1517
return Collector.__function_span_count() # type: ignore
1618

19+
1720
@pytest.fixture(autouse=True)
1821
def ensure_collector_is_empty():
1922
assert function_span_count() == 0
@@ -240,8 +243,12 @@ async def test_collector_async_multiple_calls_usage():
240243

241244
# Capture usage after second call and verify it's the sum of both calls
242245
second_call_usage = function_logs[1].usage
243-
total_input = (first_call_usage.input_tokens or 0) + (second_call_usage.input_tokens or 0)
244-
total_output = (first_call_usage.output_tokens or 0) + (second_call_usage.output_tokens or 0)
246+
total_input = (first_call_usage.input_tokens or 0) + (
247+
second_call_usage.input_tokens or 0
248+
)
249+
total_output = (first_call_usage.output_tokens or 0) + (
250+
second_call_usage.output_tokens or 0
251+
)
245252
assert collector.usage.input_tokens == total_input
246253
assert collector.usage.output_tokens == total_output
247254

@@ -286,11 +293,11 @@ async def test_collector_multiple_collectors():
286293

287294
# Verify coll1 usage is now the sum of both calls
288295
usage_second_call_coll1 = logs1[1].usage
289-
total_input = (
290-
(usage_first_call_coll1.input_tokens or 0) + (usage_second_call_coll1.input_tokens or 0)
296+
total_input = (usage_first_call_coll1.input_tokens or 0) + (
297+
usage_second_call_coll1.input_tokens or 0
291298
)
292-
total_output = (
293-
(usage_first_call_coll1.output_tokens or 0) + (usage_second_call_coll1.output_tokens or 0)
299+
total_output = (usage_first_call_coll1.output_tokens or 0) + (
300+
usage_second_call_coll1.output_tokens or 0
294301
)
295302
assert coll1.usage.input_tokens == total_input
296303
assert coll1.usage.output_tokens == total_output
@@ -322,8 +329,12 @@ async def test_collector_mixed_async_sync_calls():
322329
# Verify the second call's usage
323330
usage_second_call = logs[1].usage
324331
assert logs[1].timing.start_time_utc_ms > logs[0].timing.start_time_utc_ms
325-
total_input = (usage_first_call.input_tokens or 0) + (usage_second_call.input_tokens or 0)
326-
total_output = (usage_first_call.output_tokens or 0) + (usage_second_call.output_tokens or 0)
332+
total_input = (usage_first_call.input_tokens or 0) + (
333+
usage_second_call.input_tokens or 0
334+
)
335+
total_output = (usage_first_call.output_tokens or 0) + (
336+
usage_second_call.output_tokens or 0
337+
)
327338
assert collector.usage.input_tokens == total_input
328339
assert collector.usage.output_tokens == total_output
329340

@@ -361,3 +372,81 @@ async def test_collector_parallel_async_calls():
361372
# total_output = usage_call1.output_tokens + usage_call2.output_tokens
362373
# assert collector.usage.input_tokens == total_input
363374
# assert collector.usage.output_tokens == total_output
375+
376+
377+
@pytest.mark.asyncio
378+
async def test_collector_failures_arg_type():
379+
collector = Collector(name="my-collector")
380+
with pytest.raises(BamlInvalidArgumentError):
381+
value: str = 124 # type: ignore (We want to test the error)
382+
await b.TestOpenAIGPT4oMini(value, baml_options={"collector": collector})
383+
384+
assert len(collector.logs) == 1
385+
last_log = collector.last
386+
print("------------------------- last_log", last_log)
387+
assert last_log is not None
388+
assert last_log.function_name == "TestOpenAIGPT4oMini"
389+
390+
391+
@pytest.mark.asyncio
392+
async def test_collector_failures_client_registry():
393+
collector = Collector(name="my-collector")
394+
client_registry = ClientRegistry()
395+
client_registry.set_primary("DoesNotExist")
396+
with pytest.raises(BamlError):
397+
await b.TestOpenAIGPT4oMini(
398+
"hi there",
399+
baml_options={"collector": collector, "client_registry": client_registry},
400+
)
401+
assert len(collector.logs) == 1
402+
last_log = collector.last
403+
assert last_log is not None
404+
assert last_log.function_name == "TestOpenAIGPT4oMini"
405+
406+
407+
@pytest.mark.asyncio
408+
async def test_collector_failures_arg_type_streaming():
409+
collector = Collector(name="my-collector")
410+
with pytest.raises(BamlInvalidArgumentError):
411+
value: str = 124 # type: ignore (We want to test the error)
412+
async for _ in b.stream.TestOpenAIGPT4oMini(
413+
value, baml_options={"collector": collector}
414+
):
415+
pass
416+
417+
# Fails before the stream is even started
418+
# We don't have a state for streams that were "registered" but not started
419+
assert len(collector.logs) == 0
420+
421+
422+
@pytest.mark.asyncio
423+
async def test_collector_failures_client_registry_streaming():
424+
collector = Collector(name="my-collector")
425+
client_registry = ClientRegistry()
426+
client_registry.add_llm_client(
427+
"TestClient",
428+
"openai",
429+
{"model": "gpt-4o-mini", "base_url": "https://does-not-exist.com"},
430+
)
431+
client_registry.set_primary("TestClient")
432+
with pytest.raises(BamlClientError):
433+
try:
434+
stream = b.stream.TestOpenAIGPT4oMini(
435+
"hi there",
436+
baml_options={
437+
"collector": collector,
438+
"client_registry": client_registry,
439+
},
440+
)
441+
# TODO: baml doesnt yet throw if theres a connection error during the stream..
442+
async for _ in stream:
443+
pass
444+
# So we try to call get final response to make sure it fails
445+
await stream.get_final_response()
446+
except Exception as e:
447+
print(f"Error occurred: {e}")
448+
raise
449+
assert len(collector.logs) == 1
450+
last_log = collector.last
451+
assert last_log is not None
452+
assert last_log.function_name == "TestOpenAIGPT4oMini"

0 commit comments

Comments
 (0)