Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/wise-results-mate.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@openai/agents-realtime': patch
---

fix: avoid realtime guardrail race condition and detect ongoing response
18 changes: 18 additions & 0 deletions examples/realtime-next/src/app/websocket/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
TransportEvent,
RealtimeItem,
OutputGuardrailTripwireTriggered,
RealtimeOutputGuardrail,
} from '@openai/agents/realtime';
import { useEffect, useRef, useState } from 'react';
import { z } from 'zod';
Expand All @@ -26,6 +27,21 @@ const refundBackchannel = tool({
},
});

const guardrails: RealtimeOutputGuardrail[] = [
{
name: 'No mention of Dom',
execute: async ({ agentOutput }) => {
const domInOutput = agentOutput.includes('Dom');
return {
tripwireTriggered: domInOutput,
outputInfo: {
domInOutput,
},
};
},
},
];

const agent = new RealtimeAgent({
name: 'Greeter',
instructions:
Expand All @@ -48,6 +64,7 @@ export default function Home() {
useEffect(() => {
session.current = new RealtimeSession(agent, {
transport: 'websocket',
outputGuardrails: guardrails,
});
recorder.current = new WavRecorder({ sampleRate: 24000 });
player.current = new WavStreamPlayer({ sampleRate: 24000 });
Expand Down Expand Up @@ -87,6 +104,7 @@ export default function Home() {
async function connect() {
if (isConnected) {
await session.current?.close();
await player.current?.interrupt();
await recorder.current?.end();
setIsConnected(false);
} else {
Expand Down
1 change: 1 addition & 0 deletions packages/agents-realtime/src/openaiRealtimeBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ export abstract class OpenAIRealtimeBase
type: 'transcript_delta',
delta: parsed.delta,
itemId: parsed.item_id,
responseId: parsed.response_id,
});
}
// no support for partial transcripts yet.
Expand Down
25 changes: 14 additions & 11 deletions packages/agents-realtime/src/openaiRealtimeEvents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ import type { MessageEvent as WebSocketMessageEvent } from 'ws';
// provide better runtime validation when parsing events from the server.

export const realtimeResponse = z.object({
id: z.string().optional(),
conversation_id: z.string().optional(),
max_output_tokens: z.number().or(z.literal('inf')).optional(),
id: z.string().optional().nullable(),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

had to mark all of these as nullable since optional() in Zod now explicitly requires nullable()

conversation_id: z.string().optional().nullable(),
max_output_tokens: z.number().or(z.literal('inf')).optional().nullable(),
metadata: z.record(z.string(), z.any()).optional().nullable(),
modalities: z.array(z.string()).optional(),
object: z.literal('realtime.response').optional(),
output: z.array(z.any()).optional(),
output_audio_format: z.string().optional(),
status: z.enum(['completed', 'incomplete', 'failed', 'cancelled']).optional(),
modalities: z.array(z.string()).optional().nullable(),
object: z.literal('realtime.response').optional().nullable(),
output: z.array(z.any()).optional().nullable(),
output_audio_format: z.string().optional().nullable(),
status: z
.enum(['completed', 'incomplete', 'failed', 'cancelled', 'in_progress'])
.optional()
.nullable(),
status_details: z.record(z.string(), z.any()).optional().nullable(),
usage: z
.object({
Expand All @@ -26,8 +29,9 @@ export const realtimeResponse = z.object({
.optional()
.nullable(),
})
.optional(),
voice: z.string().optional(),
.optional()
.nullable(),
voice: z.string().optional().nullable(),
});

// Basic content schema used by ConversationItem.
Expand Down Expand Up @@ -315,7 +319,6 @@ export const responseDoneEventSchema = z.object({
type: z.literal('response.done'),
event_id: z.string(),
response: realtimeResponse,
test: z.boolean(),
});

export const responseFunctionCallArgumentsDeltaEventSchema = z.object({
Expand Down
2 changes: 2 additions & 0 deletions packages/agents-realtime/src/openaiRealtimeWebRtc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ export class OpenAIRealtimeWebRTC
if (!parsed || isGeneric) {
return;
}

if (parsed.type === 'response.created') {
this.#ongoingResponse = true;
} else if (parsed.type === 'response.done') {
Expand Down Expand Up @@ -334,6 +335,7 @@ export class OpenAIRealtimeWebRTC
this.sendEvent({
type: 'response.cancel',
});
this.#ongoingResponse = false;
}

this.sendEvent({
Expand Down
3 changes: 1 addition & 2 deletions packages/agents-realtime/src/openaiRealtimeWebsocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ export class OpenAIRealtimeWebSocket
this.sendEvent({
type: 'response.cancel',
});
this.#ongoingResponse = false;
}
}

Expand Down Expand Up @@ -367,8 +368,6 @@ export class OpenAIRealtimeWebSocket
this._cancelResponse();

const elapsedTime = Date.now() - this._firstAudioTimestamp;
console.log(`Interrupting response after ${elapsedTime}ms`);
console.log(`Audio length: ${this._audioLengthMs}ms`);
if (elapsedTime >= 0 && elapsedTime < this._audioLengthMs) {
this._interrupt(elapsedTime);
}
Expand Down
47 changes: 27 additions & 20 deletions packages/agents-realtime/src/realtimeSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ export class RealtimeSession<
#transcribedTextDeltas: Record<string, string> = {};
#history: RealtimeItem[] = [];
#shouldIncludeAudioData: boolean;
#interruptedByGuardrail: Record<string, boolean> = {};

constructor(
public readonly initialAgent:
Expand Down Expand Up @@ -446,7 +447,7 @@ export class RealtimeSession<
}
}

async #runOutputGuardrails(output: string) {
async #runOutputGuardrails(output: string, responseId: string) {
if (this.#outputGuardrails.length === 0) {
return;
}
Expand All @@ -460,24 +461,28 @@ export class RealtimeSession<
this.#outputGuardrails.map((guardrail) => guardrail.run(guardrailArgs)),
);

for (const result of results) {
if (result.output.tripwireTriggered) {
const error = new OutputGuardrailTripwireTriggered(
`Output guardrail triggered: ${JSON.stringify(result.output.outputInfo)}`,
result,
);
this.emit(
'guardrail_tripped',
this.#context,
this.#currentAgent,
error,
);
this.interrupt();

const feedbackText = getRealtimeGuardrailFeedbackMessage(result);
this.sendMessage(feedbackText);
break;
const firstTripwireTriggered = results.find(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be the same behavior as before. Just makes it more explicit that we trip on the first guardrail

(result) => result.output.tripwireTriggered,
);
if (firstTripwireTriggered) {
// this ensures that if one guardrail already trips and we are in the middle of another
// guardrail run, we don't trip again
if (this.#interruptedByGuardrail[responseId]) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this there might be jankyness since other guardrail checks might have been kicked off already. So this makes sure that we are only tripping once per response

return;
}
this.#interruptedByGuardrail[responseId] = true;
const error = new OutputGuardrailTripwireTriggered(
`Output guardrail triggered: ${JSON.stringify(firstTripwireTriggered.output.outputInfo)}`,
firstTripwireTriggered,
);
this.emit('guardrail_tripped', this.#context, this.#currentAgent, error);
this.interrupt();

const feedbackText = getRealtimeGuardrailFeedbackMessage(
firstTripwireTriggered,
);
this.sendMessage(feedbackText);
return;
}
}

Expand All @@ -498,7 +503,7 @@ export class RealtimeSession<
this.emit('agent_end', this.#context, this.#currentAgent, textOutput);
this.#currentAgent.emit('agent_end', this.#context, textOutput);

this.#runOutputGuardrails(textOutput);
this.#runOutputGuardrails(textOutput, event.response.id);
});

this.#transport.on('audio_done', () => {
Expand All @@ -511,6 +516,7 @@ export class RealtimeSession<
try {
const delta = event.delta;
const itemId = event.itemId;
const responseId = event.responseId;
if (lastItemId !== itemId) {
lastItemId = itemId;
lastRunIndex = 0;
Expand All @@ -531,7 +537,7 @@ export class RealtimeSession<
// We don't cancel existing runs because we want the first one to fail to fail
// The transport layer should upon failure handle the interruption and stop the model
// from generating further
this.#runOutputGuardrails(newText);
this.#runOutputGuardrails(newText, responseId);
}
} catch (err) {
this.emit('error', {
Expand Down Expand Up @@ -672,6 +678,7 @@ export class RealtimeSession<
* Disconnect from the session.
*/
close() {
this.#interruptedByGuardrail = {};
this.#transport.close();
}

Expand Down
1 change: 1 addition & 0 deletions packages/agents-realtime/src/transportLayerEvents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export type TransportLayerTranscriptDelta = {
type: 'transcript_delta';
itemId: string;
delta: string;
responseId: string;
};

export type TransportLayerResponseCompleted =
Expand Down
12 changes: 10 additions & 2 deletions packages/agents-realtime/test/realtimeSession.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,16 @@ describe('RealtimeSession', () => {
outputGuardrailSettings: { debounceTextLength: 1 },
});
await s.connect({ apiKey: 'test' });
t.emit('audio_transcript_delta', { delta: 'a', itemId: '1' } as any);
t.emit('audio_transcript_delta', { delta: 'a', itemId: '2' } as any);
t.emit('audio_transcript_delta', {
delta: 'a',
itemId: '1',
responseId: 'z',
} as any);
t.emit('audio_transcript_delta', {
delta: 'a',
itemId: '2',
responseId: 'z',
} as any);
await vi.waitFor(() => expect(runMock).toHaveBeenCalledTimes(2));
vi.restoreAllMocks();
});
Expand Down