diff --git a/src/agents/realtime/model_inputs.py b/src/agents/realtime/model_inputs.py index 9d7ab143d..411177b7a 100644 --- a/src/agents/realtime/model_inputs.py +++ b/src/agents/realtime/model_inputs.py @@ -95,6 +95,9 @@ class RealtimeModelSendToolOutput: class RealtimeModelSendInterrupt: """Send an interrupt to the model.""" + force_response_cancel: bool = False + """Force sending a response.cancel event even if automatic cancellation is enabled.""" + @dataclass class RealtimeModelSendSessionUpdate: diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 04a227ac8..873062c1d 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -395,6 +395,7 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None: current_item_id = playback_state.get("current_item_id") current_item_content_index = playback_state.get("current_item_content_index") elapsed_ms = playback_state.get("elapsed_ms") + if current_item_id is None or elapsed_ms is None: logger.debug( "Skipping interrupt. " @@ -402,29 +403,28 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None: f"elapsed ms: {elapsed_ms}, " f"content index: {current_item_content_index}" ) - return - - current_item_content_index = current_item_content_index or 0 - if elapsed_ms > 0: - await self._emit_event( - RealtimeModelAudioInterruptedEvent( - item_id=current_item_id, - content_index=current_item_content_index, - ) - ) - converted = _ConversionHelper.convert_interrupt( - current_item_id, - current_item_content_index, - int(elapsed_ms), - ) - await self._send_raw_message(converted) else: - logger.debug( - "Didn't interrupt bc elapsed ms is < 0. " - f"Item id: {current_item_id}, " - f"elapsed ms: {elapsed_ms}, " - f"content index: {current_item_content_index}" - ) + current_item_content_index = current_item_content_index or 0 + if elapsed_ms > 0: + await self._emit_event( + RealtimeModelAudioInterruptedEvent( + item_id=current_item_id, + content_index=current_item_content_index, + ) + ) + converted = _ConversionHelper.convert_interrupt( + current_item_id, + current_item_content_index, + int(elapsed_ms), + ) + await self._send_raw_message(converted) + else: + logger.debug( + "Didn't interrupt bc elapsed ms is < 0. " + f"Item id: {current_item_id}, " + f"elapsed ms: {elapsed_ms}, " + f"content index: {current_item_content_index}" + ) session = self._created_session automatic_response_cancellation_enabled = ( @@ -434,12 +434,16 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None: and session.audio.input.turn_detection is not None and session.audio.input.turn_detection.interrupt_response is True ) - if not automatic_response_cancellation_enabled: + should_cancel_response = event.force_response_cancel or ( + not automatic_response_cancellation_enabled + ) + if should_cancel_response: await self._cancel_response() - self._audio_state_tracker.on_interrupted() - if self._playback_tracker: - self._playback_tracker.on_interrupted() + if current_item_id is not None and elapsed_ms is not None: + self._audio_state_tracker.on_interrupted() + if self._playback_tracker: + self._playback_tracker.on_interrupted() async def _send_session_update(self, event: RealtimeModelSendSessionUpdate) -> None: """Send a session update to the model.""" diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 42dcf531a..e10b48e53 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -704,7 +704,7 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool: ) # Interrupt the model - await self._model.send_event(RealtimeModelSendInterrupt()) + await self._model.send_event(RealtimeModelSendInterrupt(force_response_cancel=True)) # Send guardrail triggered message guardrail_names = [result.guardrail.get_name() for result in triggered_results] diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index 29b6fbd9a..2b9683456 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -1,4 +1,5 @@ import json +from types import SimpleNamespace from typing import Any, cast from unittest.mock import AsyncMock, Mock, patch @@ -509,6 +510,59 @@ async def test_send_event_dispatch(self, model, monkeypatch): # session update -> 1 assert send_raw.await_count == 8 + @pytest.mark.asyncio + async def test_interrupt_force_cancel_overrides_auto_cancellation(self, model, monkeypatch): + """Interrupt should send response.cancel even when auto cancel is enabled.""" + model._audio_state_tracker.set_audio_format("pcm16") + model._audio_state_tracker.on_audio_delta("item_1", 0, b"\x00" * 4800) + model._ongoing_response = True + model._created_session = SimpleNamespace( + audio=SimpleNamespace( + input=SimpleNamespace( + turn_detection=SimpleNamespace(interrupt_response=True) + ) + ) + ) + + send_raw = AsyncMock() + emit_event = AsyncMock() + monkeypatch.setattr(model, "_send_raw_message", send_raw) + monkeypatch.setattr(model, "_emit_event", emit_event) + + await model._send_interrupt(RealtimeModelSendInterrupt(force_response_cancel=True)) + + assert send_raw.await_count == 2 + payload_types = {call.args[0].type for call in send_raw.call_args_list} + assert payload_types == {"conversation.item.truncate", "response.cancel"} + assert model._ongoing_response is False + assert model._audio_state_tracker.get_last_audio_item() is None + + @pytest.mark.asyncio + async def test_interrupt_respects_auto_cancellation_when_not_forced(self, model, monkeypatch): + """Interrupt should avoid sending response.cancel when relying on automatic cancellation.""" + model._audio_state_tracker.set_audio_format("pcm16") + model._audio_state_tracker.on_audio_delta("item_1", 0, b"\x00" * 4800) + model._ongoing_response = True + model._created_session = SimpleNamespace( + audio=SimpleNamespace( + input=SimpleNamespace( + turn_detection=SimpleNamespace(interrupt_response=True) + ) + ) + ) + + send_raw = AsyncMock() + emit_event = AsyncMock() + monkeypatch.setattr(model, "_send_raw_message", send_raw) + monkeypatch.setattr(model, "_emit_event", emit_event) + + await model._send_interrupt(RealtimeModelSendInterrupt()) + + assert send_raw.await_count == 1 + assert send_raw.call_args_list[0].args[0].type == "conversation.item.truncate" + assert all(call.args[0].type != "response.cancel" for call in send_raw.call_args_list) + assert model._ongoing_response is True + def test_add_remove_listener_and_tools_conversion(self, model): listener = AsyncMock() model.add_listener(listener)