Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
ConversationItemDeletedEvent,
ConversationItemDeleteEvent,
ConversationItemInputAudioTranscriptionCompletedEvent,
ConversationItemInputAudioTranscriptionDeltaEvent,
ConversationItemInputAudioTranscriptionFailedEvent,
ConversationItemTruncateEvent,
InputAudioBufferAppendEvent,
Expand Down Expand Up @@ -807,6 +808,9 @@ def __init__(self, realtime_model: RealtimeModel) -> None:
self._item_delete_future: dict[str, asyncio.Future] = {}
self._item_create_future: dict[str, asyncio.Future] = {}

# accumulates partial input-audio transcripts per (item_id, content_index)
self._input_transcript_accumulators: dict[str, dict[int, str]] = {}

self._current_generation: _ResponseGeneration | None = None
self._remote_chat_ctx = llm.remote_chat_context.RemoteChatContext()

Expand Down Expand Up @@ -854,6 +858,7 @@ async def _reconnect() -> None:
)
old_chat_ctx = self._remote_chat_ctx
self._remote_chat_ctx = llm.remote_chat_context.RemoteChatContext()
self._input_transcript_accumulators.clear()
events.extend(self._create_update_chat_ctx_events(chat_ctx))

try:
Expand Down Expand Up @@ -1064,10 +1069,9 @@ async def _recv_task() -> None:
ConversationItemDeletedEvent.construct(**event)
)
elif event["type"] == "conversation.item.input_audio_transcription.delta":
# currently incoming transcripts are transcribed only after the user stops speaking
# it's not very useful to emit these as the transcribe process takes place within ~100ms
# when they handle streaming transcriptions, we'll handle it then.
pass
self._handle_conversion_item_input_audio_transcription_delta(
ConversationItemInputAudioTranscriptionDeltaEvent.construct(**event)
)
elif event["type"] == "conversation.item.input_audio_transcription.completed":
self._handle_conversion_item_input_audio_transcription_completed(
ConversationItemInputAudioTranscriptionCompletedEvent.construct(**event)
Expand Down Expand Up @@ -1746,6 +1750,8 @@ def _handle_conversion_item_added(self, event: ConversationItemAdded) -> None:
def _handle_conversion_item_deleted(self, event: ConversationItemDeletedEvent) -> None:
assert event.item_id is not None, "item_id is None"

self._input_transcript_accumulators.pop(event.item_id, None)

try:
self._remote_chat_ctx.delete(event.item_id)
except ValueError as e:
Expand All @@ -1759,9 +1765,38 @@ def _handle_conversion_item_deleted(self, event: ConversationItemDeletedEvent) -
else:
fut.set_result(None)

def _handle_conversion_item_input_audio_transcription_delta(
self, event: ConversationItemInputAudioTranscriptionDeltaEvent
) -> None:
if not event.delta:
return

content_index = event.content_index or 0
by_index = self._input_transcript_accumulators.setdefault(event.item_id, {})
accumulated = by_index.get(content_index, "") + event.delta
by_index[content_index] = accumulated

self.emit(
"input_audio_transcription_completed",
llm.InputTranscriptionCompleted(
item_id=event.item_id, transcript=accumulated, is_final=False
),
)

def _clear_transcript_accumulator(self, item_id: str, content_index: int) -> str | None:
by_index = self._input_transcript_accumulators.get(item_id)
if by_index is None:
return None
partial = by_index.pop(content_index, None)
if not by_index:
self._input_transcript_accumulators.pop(item_id, None)
return partial

def _handle_conversion_item_input_audio_transcription_completed(
self, event: ConversationItemInputAudioTranscriptionCompletedEvent
) -> None:
self._clear_transcript_accumulator(event.item_id, event.content_index or 0)

confidence = calculate_confidence_from_logprobs(event.logprobs)

if remote_item := self._remote_chat_ctx.get(event.item_id):
Expand All @@ -1787,6 +1822,17 @@ def _handle_conversion_item_input_audio_transcription_failed(
extra={"error": event.error},
)

# close any open partial stream so consumers waiting for is_final don't hang
partial = self._clear_transcript_accumulator(event.item_id, event.content_index or 0)
if partial is None:
return
self.emit(
"input_audio_transcription_completed",
llm.InputTranscriptionCompleted(
item_id=event.item_id, transcript=partial, is_final=True
),
)

def _handle_response_text_delta(self, event: ResponseTextDeltaEvent) -> None:
assert self._current_generation is not None, "current_generation is None"
item_generation = self._current_generation.messages[event.item_id]
Expand Down
10 changes: 6 additions & 4 deletions tests/test_realtime/test_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,18 +226,20 @@ async def test_vad_speech_events(rt_session: llm.RealtimeSession):
@pytest.mark.parametrize("rt_session", REALTIME_MODELS, indirect=True)
async def test_input_audio_transcription(rt_session: llm.RealtimeSession):
transcripts: list[str] = []
transcript_received = asyncio.Event()
final_received = asyncio.Event()

def on_transcript(ev: llm.InputTranscriptionCompleted):
transcripts.append(ev.transcript)
transcript_received.set()
if ev.is_final:
final_received.set()

rt_session.on("input_audio_transcription_completed", on_transcript)
await _push_speech(rt_session, "weather_question")
rt_session.commit_audio()

await asyncio.wait_for(transcript_received.wait(), timeout=15)
full = " ".join(transcripts).lower()
await asyncio.wait_for(final_received.wait(), timeout=15)
assert transcripts, "no transcript received"
full = transcripts[-1].lower()
assert "weather" in full or "paris" in full


Expand Down
Loading