From 0cffb264b86871de0ee663cfbc9490fc0db45c1f Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Fri, 17 Oct 2025 14:25:40 -0700 Subject: [PATCH 01/12] Fix streaming trace ending before spans complete --- src/agents/run.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/agents/run.py b/src/agents/run.py index 85607e7dd..66fee7f75 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1138,6 +1138,8 @@ async def _start_streaming( streamed_result.is_complete = True finally: + if streamed_result._input_guardrails_task: + await streamed_result._input_guardrails_task if current_span: current_span.finish(reset_current=True) if streamed_result.trace: From 35fb6c108fa3930cfcdd4994c33a534f06733d6d Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Fri, 17 Oct 2025 14:52:39 -0700 Subject: [PATCH 02/12] error handling --- src/agents/run.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/agents/run.py b/src/agents/run.py index 66fee7f75..72fffef88 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1139,7 +1139,11 @@ async def _start_streaming( streamed_result.is_complete = True finally: if streamed_result._input_guardrails_task: - await streamed_result._input_guardrails_task + try: + streamed_result.input_guardrail_results = await streamed_result._input_guardrails_task + except Exception: + # Exceptions will be checked in the stream_events loop + output_guardrail_results = [] if current_span: current_span.finish(reset_current=True) if streamed_result.trace: From eb0b2b4d8f66efc166fc006390a6300ee903440a Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Fri, 17 Oct 2025 15:07:32 -0700 Subject: [PATCH 03/12] lint --- src/agents/run.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index 72fffef88..dc7555fe1 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1138,12 +1138,11 @@ async def _start_streaming( streamed_result.is_complete = True finally: - if streamed_result._input_guardrails_task: + if task := streamed_result._input_guardrails_task: try: - streamed_result.input_guardrail_results = await streamed_result._input_guardrails_task + streamed_result.input_guardrail_results = await task except Exception: - # Exceptions will be checked in the stream_events loop - output_guardrail_results = [] + streamed_result.input_guardrail_results = [] if current_span: current_span.finish(reset_current=True) if streamed_result.trace: From 5310c78be3b8967a0f1eff9ff8867bd26962e382 Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Fri, 17 Oct 2025 15:25:42 -0700 Subject: [PATCH 04/12] review comments --- src/agents/run.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index dc7555fe1..568342c5e 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -905,7 +905,10 @@ async def _run_input_guardrails_with_queue( t.cancel() raise + # Store the full set of input guardrail results on the streamed result + # and return them so callers awaiting this task can receive the list. streamed_result.input_guardrail_results = guardrail_results + return guardrail_results @classmethod async def _start_streaming( @@ -1138,11 +1141,13 @@ async def _start_streaming( streamed_result.is_complete = True finally: - if task := streamed_result._input_guardrails_task: + if streamed_result._input_guardrails_task: try: - streamed_result.input_guardrail_results = await task - except Exception: - streamed_result.input_guardrail_results = [] + await streamed_result._input_guardrails_task + except Exception as e: + logger.debug( + f"Error in streamed_result finalize for agent {current_agent.name} - {e}" + ) if current_span: current_span.finish(reset_current=True) if streamed_result.trace: From ef235200e7b48ca0e2200248e4dfe3fd05bbff3e Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Fri, 17 Oct 2025 15:33:32 -0700 Subject: [PATCH 05/12] await tripwire triggered instead of task --- src/agents/run.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/agents/run.py b/src/agents/run.py index 568342c5e..95617fc10 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1143,7 +1143,9 @@ async def _start_streaming( finally: if streamed_result._input_guardrails_task: try: - await streamed_result._input_guardrails_task + await AgentRunner._input_guardrail_tripwire_triggered_for_stream( + streamed_result + ) except Exception as e: logger.debug( f"Error in streamed_result finalize for agent {current_agent.name} - {e}" From ae843384ccc24f4c60d781b82b5c3e31a87fee20 Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Fri, 17 Oct 2025 15:37:17 -0700 Subject: [PATCH 06/12] raise error after logging --- src/agents/run.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/agents/run.py b/src/agents/run.py index 95617fc10..a07ed27f5 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1150,6 +1150,7 @@ async def _start_streaming( logger.debug( f"Error in streamed_result finalize for agent {current_agent.name} - {e}" ) + raise e if current_span: current_span.finish(reset_current=True) if streamed_result.trace: From 1989347eb907ba58be83d49b6f9479a4b7ec51b6 Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Mon, 20 Oct 2025 15:54:53 -0700 Subject: [PATCH 07/12] add unit tests --- tests/test_stream_input_guardrail_timing.py | 236 ++++++++++++++++++++ 1 file changed, 236 insertions(+) create mode 100644 tests/test_stream_input_guardrail_timing.py diff --git a/tests/test_stream_input_guardrail_timing.py b/tests/test_stream_input_guardrail_timing.py new file mode 100644 index 000000000..4a5f69932 --- /dev/null +++ b/tests/test_stream_input_guardrail_timing.py @@ -0,0 +1,236 @@ +import asyncio +from datetime import datetime +from typing import Any + +import pytest + +from agents import Agent, GuardrailFunctionOutput, InputGuardrail, Runner, RunContextWrapper +from agents.items import TResponseInputItem +from agents.exceptions import InputGuardrailTripwireTriggered + +from .fake_model import FakeModel +from openai.types.responses import ResponseCompletedEvent +from .test_responses import get_text_message + + +def make_input_guardrail(delay_seconds: float, *, trip: bool) -> InputGuardrail[Any]: + async def guardrail( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + # Simulate variable guardrail completion timing. + if delay_seconds > 0: + await asyncio.sleep(delay_seconds) + return GuardrailFunctionOutput( + output_info={"delay": delay_seconds}, tripwire_triggered=trip + ) + + name = "tripping_input_guardrail" if trip else "delayed_input_guardrail" + return InputGuardrail(guardrail_function=guardrail, name=name) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guardrail_delay", [0.0, 0.2]) +async def test_run_streamed_input_guardrail_timing_is_consistent(guardrail_delay: float): + """Ensure streaming behavior matches whether input guardrail finishes before or after LLM stream. + + We verify that: + - The sequence of streamed event types is identical. + - Final output matches. + - Exactly one input guardrail result is recorded and does not trigger. + """ + + # Arrange: Agent with a single text output and a delayed input guardrail + model = FakeModel() + model.set_next_output([get_text_message("Final response")]) + + agent = Agent( + name="TimingAgent", + model=model, + input_guardrails=[make_input_guardrail(guardrail_delay, trip=False)], + ) + + # Act: Run streamed and collect event types + result = Runner.run_streamed(agent, input="Hello") + event_types: list[str] = [] + + async for event in result.stream_events(): + event_types.append(event.type) + + # Assert: Guardrail results populated and identical behavioral outcome + assert len(result.input_guardrail_results) == 1, "Expected exactly one input guardrail result" + assert result.input_guardrail_results[0].guardrail.get_name() == "delayed_input_guardrail", ( + "Guardrail name mismatch" + ) + assert result.input_guardrail_results[0].output.tripwire_triggered is False, ( + "Guardrail should not trigger in this test" + ) + + # Final output should be the text from the model's single message + assert result.final_output == "Final response" + + # Minimal invariants on event sequence to ensure stability across timing + # Must start with agent update and include raw response events + assert len(event_types) >= 3, f"Unexpectedly few events: {event_types}" + assert event_types[0] == "agent_updated_stream_event" + # Ensure we observed raw response events in the stream irrespective of guardrail timing + assert any(t == "raw_response_event" for t in event_types) + + +@pytest.mark.asyncio +async def test_run_streamed_input_guardrail_sequences_match_between_fast_and_slow(): + """Run twice with fast vs slow input guardrail and compare event sequences exactly.""" + + async def run_once(delay: float) -> list[str]: + model = FakeModel() + model.set_next_output([get_text_message("Final response")]) + agent = Agent( + name="TimingAgent", + model=model, + input_guardrails=[make_input_guardrail(delay, trip=False)], + ) + result = Runner.run_streamed(agent, input="Hello") + events: list[str] = [] + async for ev in result.stream_events(): + events.append(ev.type) + return events + + events_fast = await run_once(0.0) + events_slow = await run_once(0.2) + + assert events_fast == events_slow, ( + f"Event sequences differ between guardrail timings:\nfast={events_fast}\nslow={events_slow}" + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guardrail_delay", [0.0, 0.2]) +async def test_run_streamed_input_guardrail_tripwire_raises(guardrail_delay: float): + """Guardrail tripwire must raise from stream_events regardless of timing.""" + + model = FakeModel() + model.set_next_output([get_text_message("Final response")]) + + agent = Agent( + name="TimingAgentTrip", + model=model, + input_guardrails=[make_input_guardrail(guardrail_delay, trip=True)], + ) + + result = Runner.run_streamed(agent, input="Hello") + + with pytest.raises(InputGuardrailTripwireTriggered) as excinfo: + async for _ in result.stream_events(): + pass + + # Exception contains the guardrail result and run data + exc = excinfo.value + assert exc.guardrail_result.output.tripwire_triggered is True + assert exc.run_data is not None + assert len(exc.run_data.input_guardrail_results) == 1 + assert ( + exc.run_data.input_guardrail_results[0].guardrail.get_name() == "tripping_input_guardrail" + ) + + +class SlowCompleteFakeModel(FakeModel): + """A FakeModel that delays just before emitting ResponseCompletedEvent in streaming.""" + + def __init__(self, delay_seconds: float, tracing_enabled: bool = True): + super().__init__(tracing_enabled=tracing_enabled) + self._delay_seconds = delay_seconds + + async def stream_response(self, *args, **kwargs): # type: ignore[override] + async for ev in super().stream_response(*args, **kwargs): + if isinstance(ev, ResponseCompletedEvent) and self._delay_seconds > 0: + await asyncio.sleep(self._delay_seconds) + yield ev + + +def _get_span_by_type(spans, span_type: str): + for s in spans: + exported = s.export() + if not exported: + continue + if exported.get("span_data", {}).get("type") == span_type: + return s + return None + + +def _iso(s: str | None) -> datetime: + assert s is not None + return datetime.fromisoformat(s) + + +@pytest.mark.asyncio +async def test_parent_span_and_trace_finish_after_slow_input_guardrail(): + """Agent span and trace finish after guardrail when guardrail completes last.""" + + model = FakeModel(tracing_enabled=True) + model.set_next_output([get_text_message("Final response")]) + agent = Agent( + name="TimingAgentTrace", + model=model, + input_guardrails=[make_input_guardrail(0.2, trip=False)], # guardrail slower than model + ) + + result = Runner.run_streamed(agent, input="Hello") + async for _ in result.stream_events(): + pass + + from .testing_processor import fetch_ordered_spans + + spans = fetch_ordered_spans() + agent_span = _get_span_by_type(spans, "agent") + guardrail_span = _get_span_by_type(spans, "guardrail") + generation_span = _get_span_by_type(spans, "generation") + + assert agent_span and guardrail_span and generation_span, ( + "Expected agent, guardrail, generation spans" + ) + + # Agent span must finish last + assert _iso(agent_span.ended_at) >= _iso(guardrail_span.ended_at) + assert _iso(agent_span.ended_at) >= _iso(generation_span.ended_at) + + # Trace should end after all spans end + from .testing_processor import fetch_events + + events = fetch_events() + assert events[-1] == "trace_end" + + +@pytest.mark.asyncio +async def test_parent_span_and_trace_finish_after_slow_model(): + """Agent span and trace finish after model when model completes last.""" + + model = SlowCompleteFakeModel(delay_seconds=0.2, tracing_enabled=True) + model.set_next_output([get_text_message("Final response")]) + agent = Agent( + name="TimingAgentTrace", + model=model, + input_guardrails=[make_input_guardrail(0.0, trip=False)], # guardrail faster than model + ) + + result = Runner.run_streamed(agent, input="Hello") + async for _ in result.stream_events(): + pass + + from .testing_processor import fetch_ordered_spans + + spans = fetch_ordered_spans() + agent_span = _get_span_by_type(spans, "agent") + guardrail_span = _get_span_by_type(spans, "guardrail") + generation_span = _get_span_by_type(spans, "generation") + + assert agent_span and guardrail_span and generation_span, ( + "Expected agent, guardrail, generation spans" + ) + + # Agent span must finish last + assert _iso(agent_span.ended_at) >= _iso(guardrail_span.ended_at) + assert _iso(agent_span.ended_at) >= _iso(generation_span.ended_at) + + from .testing_processor import fetch_events + + events = fetch_events() + assert events[-1] == "trace_end" From 6bff7d9b1e8e742ff2f542811f4368623c1c54e0 Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Mon, 20 Oct 2025 16:07:50 -0700 Subject: [PATCH 08/12] remove unnecessary raise --- src/agents/run.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/agents/run.py b/src/agents/run.py index a07ed27f5..95617fc10 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1150,7 +1150,6 @@ async def _start_streaming( logger.debug( f"Error in streamed_result finalize for agent {current_agent.name} - {e}" ) - raise e if current_span: current_span.finish(reset_current=True) if streamed_result.trace: From eacd839861228acfd2f53274f62590d0106f6ece Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Mon, 20 Oct 2025 16:12:50 -0700 Subject: [PATCH 09/12] remove relative imports --- tests/test_stream_input_guardrail_timing.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/test_stream_input_guardrail_timing.py b/tests/test_stream_input_guardrail_timing.py index 4a5f69932..18f65dcd5 100644 --- a/tests/test_stream_input_guardrail_timing.py +++ b/tests/test_stream_input_guardrail_timing.py @@ -8,9 +8,10 @@ from agents.items import TResponseInputItem from agents.exceptions import InputGuardrailTripwireTriggered -from .fake_model import FakeModel from openai.types.responses import ResponseCompletedEvent -from .test_responses import get_text_message +from tests.fake_model import FakeModel +from tests.test_responses import get_text_message +from tests.testing_processor import fetch_ordered_spans, fetch_events def make_input_guardrail(delay_seconds: float, *, trip: bool) -> InputGuardrail[Any]: @@ -177,8 +178,6 @@ async def test_parent_span_and_trace_finish_after_slow_input_guardrail(): async for _ in result.stream_events(): pass - from .testing_processor import fetch_ordered_spans - spans = fetch_ordered_spans() agent_span = _get_span_by_type(spans, "agent") guardrail_span = _get_span_by_type(spans, "guardrail") @@ -193,8 +192,6 @@ async def test_parent_span_and_trace_finish_after_slow_input_guardrail(): assert _iso(agent_span.ended_at) >= _iso(generation_span.ended_at) # Trace should end after all spans end - from .testing_processor import fetch_events - events = fetch_events() assert events[-1] == "trace_end" @@ -215,8 +212,6 @@ async def test_parent_span_and_trace_finish_after_slow_model(): async for _ in result.stream_events(): pass - from .testing_processor import fetch_ordered_spans - spans = fetch_ordered_spans() agent_span = _get_span_by_type(spans, "agent") guardrail_span = _get_span_by_type(spans, "guardrail") @@ -230,7 +225,5 @@ async def test_parent_span_and_trace_finish_after_slow_model(): assert _iso(agent_span.ended_at) >= _iso(guardrail_span.ended_at) assert _iso(agent_span.ended_at) >= _iso(generation_span.ended_at) - from .testing_processor import fetch_events - events = fetch_events() assert events[-1] == "trace_end" From d8709027108fbace3bcbb50567efaf70cd363bea Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Mon, 20 Oct 2025 22:05:01 -0700 Subject: [PATCH 10/12] remove unnecessary edit --- src/agents/run.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index 95617fc10..58eef335e 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -905,10 +905,7 @@ async def _run_input_guardrails_with_queue( t.cancel() raise - # Store the full set of input guardrail results on the streamed result - # and return them so callers awaiting this task can receive the list. streamed_result.input_guardrail_results = guardrail_results - return guardrail_results @classmethod async def _start_streaming( From daf42e4547edd143d844f5d77db7d03e4a1670bd Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Mon, 20 Oct 2025 22:07:01 -0700 Subject: [PATCH 11/12] lint --- tests/test_stream_input_guardrail_timing.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_stream_input_guardrail_timing.py b/tests/test_stream_input_guardrail_timing.py index 18f65dcd5..07e870c67 100644 --- a/tests/test_stream_input_guardrail_timing.py +++ b/tests/test_stream_input_guardrail_timing.py @@ -3,15 +3,14 @@ from typing import Any import pytest +from openai.types.responses import ResponseCompletedEvent -from agents import Agent, GuardrailFunctionOutput, InputGuardrail, Runner, RunContextWrapper -from agents.items import TResponseInputItem +from agents import Agent, GuardrailFunctionOutput, InputGuardrail, RunContextWrapper, Runner from agents.exceptions import InputGuardrailTripwireTriggered - -from openai.types.responses import ResponseCompletedEvent +from agents.items import TResponseInputItem from tests.fake_model import FakeModel from tests.test_responses import get_text_message -from tests.testing_processor import fetch_ordered_spans, fetch_events +from tests.testing_processor import fetch_events, fetch_ordered_spans def make_input_guardrail(delay_seconds: float, *, trip: bool) -> InputGuardrail[Any]: @@ -32,7 +31,7 @@ async def guardrail( @pytest.mark.asyncio @pytest.mark.parametrize("guardrail_delay", [0.0, 0.2]) async def test_run_streamed_input_guardrail_timing_is_consistent(guardrail_delay: float): - """Ensure streaming behavior matches whether input guardrail finishes before or after LLM stream. + """Ensure streaming behavior matches when input guardrail finishes before and after LLM stream. We verify that: - The sequence of streamed event types is identical. @@ -140,7 +139,7 @@ def __init__(self, delay_seconds: float, tracing_enabled: bool = True): super().__init__(tracing_enabled=tracing_enabled) self._delay_seconds = delay_seconds - async def stream_response(self, *args, **kwargs): # type: ignore[override] + async def stream_response(self, *args, **kwargs): async for ev in super().stream_response(*args, **kwargs): if isinstance(ev, ResponseCompletedEvent) and self._delay_seconds > 0: await asyncio.sleep(self._delay_seconds) From 93c03852a065934ae10cb3bdc05d9cf3838621d0 Mon Sep 17 00:00:00 2001 From: zbirenbaum Date: Tue, 21 Oct 2025 09:47:10 -0700 Subject: [PATCH 12/12] 3.9 compat --- tests/test_stream_input_guardrail_timing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_stream_input_guardrail_timing.py b/tests/test_stream_input_guardrail_timing.py index 07e870c67..3de8897aa 100644 --- a/tests/test_stream_input_guardrail_timing.py +++ b/tests/test_stream_input_guardrail_timing.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from datetime import datetime from typing import Any