diff --git a/.changeset/smooth-bottles-confess.md b/.changeset/smooth-bottles-confess.md new file mode 100644 index 00000000..50e49948 --- /dev/null +++ b/.changeset/smooth-bottles-confess.md @@ -0,0 +1,5 @@ +--- +"@openai/agents-realtime": patch +--- + +fix: #613 Listen to peerConnection state in `OpenAIRealtimeWebRTC` to detect disconnects diff --git a/packages/agents-realtime/src/openaiRealtimeWebRtc.ts b/packages/agents-realtime/src/openaiRealtimeWebRtc.ts index b267e037..19d5279b 100644 --- a/packages/agents-realtime/src/openaiRealtimeWebRtc.ts +++ b/packages/agents-realtime/src/openaiRealtimeWebRtc.ts @@ -181,6 +181,23 @@ export class OpenAIRealtimeWebRTC const dataChannel = peerConnection.createDataChannel('oai-events'); let callId: string | undefined = undefined; + const attachConnectionStateHandler = ( + connection: RTCPeerConnection, + ) => { + connection.onconnectionstatechange = () => { + switch (connection.connectionState) { + case 'disconnected': + case 'failed': + case 'closed': + this.close(); + break; + // 'connected' state is handled by dataChannel.onopen. So we don't need to handle it here. + // 'new' and 'connecting' are intermediate states and do not require action here. 
+ } + }; + }; + attachConnectionStateHandler(peerConnection); + this.#state = { status: 'connecting', peerConnection, @@ -249,8 +266,13 @@ export class OpenAIRealtimeWebRTC peerConnection.addTrack(stream.getAudioTracks()[0]); if (this.options.changePeerConnection) { + const originalPeerConnection = peerConnection; peerConnection = await this.options.changePeerConnection(peerConnection); + if (originalPeerConnection !== peerConnection) { + originalPeerConnection.onconnectionstatechange = null; + } + attachConnectionStateHandler(peerConnection); this.#state = { ...this.#state, peerConnection }; } @@ -332,6 +354,7 @@ export class OpenAIRealtimeWebRTC if (this.#state.peerConnection) { const peerConnection = this.#state.peerConnection; + peerConnection.onconnectionstatechange = null; peerConnection.getSenders().forEach((sender) => { sender.track?.stop(); }); diff --git a/packages/agents-realtime/test/openaiRealtimeWebRtc.test.ts b/packages/agents-realtime/test/openaiRealtimeWebRtc.test.ts index d53e857c..6db2da4a 100644 --- a/packages/agents-realtime/test/openaiRealtimeWebRtc.test.ts +++ b/packages/agents-realtime/test/openaiRealtimeWebRtc.test.ts @@ -16,22 +16,49 @@ let lastChannel: FakeRTCDataChannel | null = null; class FakeRTCPeerConnection { ontrack: ((ev: any) => void) | null = null; + onconnectionstatechange: (() => void) | null = null; + connectionState = 'new'; + createDataChannel(_name: string) { lastChannel = new FakeRTCDataChannel(); // simulate async open event - setTimeout(() => lastChannel?.dispatchEvent(new Event('open'))); + setTimeout(() => { + this._simulateStateChange('connected'); + lastChannel?.dispatchEvent(new Event('open')); + }, 0); return lastChannel as unknown as RTCDataChannel; } addTrack() {} async createOffer() { + this._simulateStateChange('connecting'); return { sdp: 'offer', type: 'offer' }; } async setLocalDescription(_desc: any) {} async setRemoteDescription(_desc: any) {} - close() {} + close() { + 
this._simulateStateChange('closed'); + } getSenders() { return [] as any; } + + _simulateStateChange( + state: + | 'new' + | 'connecting' + | 'connected' + | 'disconnected' + | 'failed' + | 'closed', + ) { + if (this.connectionState === state) return; + this.connectionState = state; + setTimeout(() => { + if (this.onconnectionstatechange) { + this.onconnectionstatechange(); + } + }, 0); + } } describe('OpenAIRealtimeWebRTC.interrupt', () => { @@ -219,6 +246,113 @@ describe('OpenAIRealtimeWebRTC.interrupt', () => { }); }); +describe('OpenAIRealtimeWebRTC.connectionState', () => { + const originals: Record<string, any> = {}; + + beforeEach(() => { + originals.RTCPeerConnection = (global as any).RTCPeerConnection; + originals.navigator = (global as any).navigator; + originals.document = (global as any).document; + originals.fetch = (global as any).fetch; + + (global as any).RTCPeerConnection = FakeRTCPeerConnection as any; + Object.defineProperty(globalThis, 'navigator', { + value: { + mediaDevices: { + getUserMedia: async () => ({ + getAudioTracks: () => [{ enabled: true }], + }), + }, + }, + configurable: true, + writable: true, + }); + Object.defineProperty(globalThis, 'document', { + value: { createElement: () => ({ autoplay: true }) }, + configurable: true, + writable: true, + }); + Object.defineProperty(globalThis, 'fetch', { + value: async () => ({ + text: async () => 'answer', + headers: { + get: (headerKey: string) => { + if (headerKey === 'Location') { + return 'https://api.openai.com/v1/calls/rtc_u1_1234567890'; + } + return null; + }, + }, + }), + configurable: true, + writable: true, + }); + }); + + afterEach(() => { + (global as any).RTCPeerConnection = originals.RTCPeerConnection; + Object.defineProperty(globalThis, 'navigator', { + value: originals.navigator, + configurable: true, + writable: true, + }); + Object.defineProperty(globalThis, 'document', { + value: originals.document, + configurable: true, + writable: true, + }); + Object.defineProperty(globalThis, 
'fetch', { + value: originals.fetch, + configurable: true, + writable: true, + }); + lastChannel = null; + }); + + it('fires connection_change and disconnects on peer connection failure', async () => { + const rtc = new OpenAIRealtimeWebRTC(); + const events: string[] = []; + rtc.on('connection_change', (status) => events.push(status)); + await rtc.connect({ apiKey: 'ek_test' }); + expect(rtc.status).toBe('connected'); + expect(events).toEqual(['connecting', 'connected']); + const pc = rtc.connectionState + .peerConnection as unknown as FakeRTCPeerConnection; + expect(pc).toBeInstanceOf(FakeRTCPeerConnection); + pc._simulateStateChange('failed'); + await new Promise((resolve) => setTimeout(resolve, 0)); + expect(rtc.status).toBe('disconnected'); + expect(events).toEqual(['connecting', 'connected', 'disconnected']); + }); + + it('migrates connection state handler when peer connection is replaced', async () => { + class CustomFakePeerConnection extends FakeRTCPeerConnection {} + const customPC = new CustomFakePeerConnection(); + + const rtc = new OpenAIRealtimeWebRTC({ + changePeerConnection: async () => customPC as any, + }); + + const closeSpy = vi.spyOn(rtc, 'close'); + const events: string[] = []; + rtc.on('connection_change', (status) => events.push(status)); + + await rtc.connect({ apiKey: 'ek_test' }); + + expect(rtc.status).toBe('connected'); + expect(rtc.connectionState.peerConnection).toBe(customPC as any); + expect(closeSpy).not.toHaveBeenCalled(); + expect(events).toEqual(['connecting', 'connected']); + + customPC._simulateStateChange('failed'); + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(closeSpy).toHaveBeenCalled(); + expect(rtc.status).toBe('disconnected'); + expect(events).toEqual(['connecting', 'connected', 'disconnected']); + }); +}); + describe('OpenAIRealtimeWebRTC.callId', () => { const originals: Record<string, any> = {}; const callId = 'rtc_u1_1234567890'; @@ -288,6 +422,7 @@ describe('OpenAIRealtimeWebRTC.callId', () => { await 
rtc.connect({ apiKey: 'ek_test' }); expect(rtc.callId).toBe(callId); rtc.close(); + await new Promise((resolve) => setTimeout(resolve, 0)); expect(rtc.callId).toBeUndefined(); }); });