Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/agents/realtime/model_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ class RealtimeModelSendToolOutput:
class RealtimeModelSendInterrupt:
"""Send an interrupt to the model."""

force_response_cancel: bool = False
"""Force sending a response.cancel event even if automatic cancellation is enabled."""


@dataclass
class RealtimeModelSendSessionUpdate:
Expand Down
56 changes: 30 additions & 26 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,36 +395,36 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
current_item_id = playback_state.get("current_item_id")
current_item_content_index = playback_state.get("current_item_content_index")
elapsed_ms = playback_state.get("elapsed_ms")

if current_item_id is None or elapsed_ms is None:
logger.debug(
"Skipping interrupt. "
f"Item id: {current_item_id}, "
f"elapsed ms: {elapsed_ms}, "
f"content index: {current_item_content_index}"
)
return

current_item_content_index = current_item_content_index or 0
if elapsed_ms > 0:
await self._emit_event(
RealtimeModelAudioInterruptedEvent(
item_id=current_item_id,
content_index=current_item_content_index,
)
)
converted = _ConversionHelper.convert_interrupt(
current_item_id,
current_item_content_index,
int(elapsed_ms),
)
await self._send_raw_message(converted)
else:
logger.debug(
"Didn't interrupt bc elapsed ms is < 0. "
f"Item id: {current_item_id}, "
f"elapsed ms: {elapsed_ms}, "
f"content index: {current_item_content_index}"
)
current_item_content_index = current_item_content_index or 0
if elapsed_ms > 0:
await self._emit_event(
RealtimeModelAudioInterruptedEvent(
item_id=current_item_id,
content_index=current_item_content_index,
)
)
converted = _ConversionHelper.convert_interrupt(
current_item_id,
current_item_content_index,
int(elapsed_ms),
)
await self._send_raw_message(converted)
else:
logger.debug(
"Didn't interrupt bc elapsed ms is < 0. "
f"Item id: {current_item_id}, "
f"elapsed ms: {elapsed_ms}, "
f"content index: {current_item_content_index}"
)

session = self._created_session
automatic_response_cancellation_enabled = (
Expand All @@ -434,12 +434,16 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
and session.audio.input.turn_detection is not None
and session.audio.input.turn_detection.interrupt_response is True
)
if not automatic_response_cancellation_enabled:
should_cancel_response = event.force_response_cancel or (
not automatic_response_cancellation_enabled
)
if should_cancel_response:
await self._cancel_response()

self._audio_state_tracker.on_interrupted()
if self._playback_tracker:
self._playback_tracker.on_interrupted()
if current_item_id is not None and elapsed_ms is not None:
self._audio_state_tracker.on_interrupted()
if self._playback_tracker:
self._playback_tracker.on_interrupted()

async def _send_session_update(self, event: RealtimeModelSendSessionUpdate) -> None:
"""Send a session update to the model."""
Expand Down
2 changes: 1 addition & 1 deletion src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool:
)

# Interrupt the model
await self._model.send_event(RealtimeModelSendInterrupt())
await self._model.send_event(RealtimeModelSendInterrupt(force_response_cancel=True))

# Send guardrail triggered message
guardrail_names = [result.guardrail.get_name() for result in triggered_results]
Expand Down
54 changes: 54 additions & 0 deletions tests/realtime/test_openai_realtime.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import AsyncMock, Mock, patch

Expand Down Expand Up @@ -509,6 +510,59 @@ async def test_send_event_dispatch(self, model, monkeypatch):
# session update -> 1
assert send_raw.await_count == 8

@pytest.mark.asyncio
async def test_interrupt_force_cancel_overrides_auto_cancellation(self, model, monkeypatch):
"""Interrupt should send response.cancel even when auto cancel is enabled."""
model._audio_state_tracker.set_audio_format("pcm16")
model._audio_state_tracker.on_audio_delta("item_1", 0, b"\x00" * 4800)
model._ongoing_response = True
model._created_session = SimpleNamespace(
audio=SimpleNamespace(
input=SimpleNamespace(
turn_detection=SimpleNamespace(interrupt_response=True)
)
)
)

send_raw = AsyncMock()
emit_event = AsyncMock()
monkeypatch.setattr(model, "_send_raw_message", send_raw)
monkeypatch.setattr(model, "_emit_event", emit_event)

await model._send_interrupt(RealtimeModelSendInterrupt(force_response_cancel=True))

assert send_raw.await_count == 2
payload_types = {call.args[0].type for call in send_raw.call_args_list}
assert payload_types == {"conversation.item.truncate", "response.cancel"}
assert model._ongoing_response is False
assert model._audio_state_tracker.get_last_audio_item() is None

@pytest.mark.asyncio
async def test_interrupt_respects_auto_cancellation_when_not_forced(self, model, monkeypatch):
"""Interrupt should avoid sending response.cancel when relying on automatic cancellation."""
model._audio_state_tracker.set_audio_format("pcm16")
model._audio_state_tracker.on_audio_delta("item_1", 0, b"\x00" * 4800)
model._ongoing_response = True
model._created_session = SimpleNamespace(
audio=SimpleNamespace(
input=SimpleNamespace(
turn_detection=SimpleNamespace(interrupt_response=True)
)
)
)

send_raw = AsyncMock()
emit_event = AsyncMock()
monkeypatch.setattr(model, "_send_raw_message", send_raw)
monkeypatch.setattr(model, "_emit_event", emit_event)

await model._send_interrupt(RealtimeModelSendInterrupt())

assert send_raw.await_count == 1
assert send_raw.call_args_list[0].args[0].type == "conversation.item.truncate"
assert all(call.args[0].type != "response.cancel" for call in send_raw.call_args_list)
assert model._ongoing_response is True

def test_add_remove_listener_and_tools_conversion(self, model):
listener = AsyncMock()
model.add_listener(listener)
Expand Down