Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/smooth-bottles-confess.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@openai/agents-realtime": patch
---

fix: #613 Listen to peerConnection state in `OpenAIRealtimeWebRTC` to detect disconnects
23 changes: 23 additions & 0 deletions packages/agents-realtime/src/openaiRealtimeWebRtc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,23 @@ export class OpenAIRealtimeWebRTC
const dataChannel = peerConnection.createDataChannel('oai-events');
let callId: string | undefined = undefined;

const attachConnectionStateHandler = (
connection: RTCPeerConnection,
) => {
connection.onconnectionstatechange = () => {
switch (connection.connectionState) {
case 'disconnected':
case 'failed':
case 'closed':
this.close();
break;
// 'connected' state is handled by dataChannel.onopen. So we don't need to handle it here.
// 'new' and 'connecting' are intermediate states and do not require action here.
}
};
};
attachConnectionStateHandler(peerConnection);

this.#state = {
status: 'connecting',
peerConnection,
Expand Down Expand Up @@ -249,8 +266,13 @@ export class OpenAIRealtimeWebRTC
peerConnection.addTrack(stream.getAudioTracks()[0]);

if (this.options.changePeerConnection) {
const originalPeerConnection = peerConnection;
peerConnection =
await this.options.changePeerConnection(peerConnection);
if (originalPeerConnection !== peerConnection) {
originalPeerConnection.onconnectionstatechange = null;
}
attachConnectionStateHandler(peerConnection);
this.#state = { ...this.#state, peerConnection };
}

Expand Down Expand Up @@ -332,6 +354,7 @@ export class OpenAIRealtimeWebRTC

if (this.#state.peerConnection) {
const peerConnection = this.#state.peerConnection;
peerConnection.onconnectionstatechange = null;
peerConnection.getSenders().forEach((sender) => {
sender.track?.stop();
});
Expand Down
139 changes: 137 additions & 2 deletions packages/agents-realtime/test/openaiRealtimeWebRtc.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,49 @@ let lastChannel: FakeRTCDataChannel | null = null;

class FakeRTCPeerConnection {
ontrack: ((ev: any) => void) | null = null;
onconnectionstatechange: (() => void) | null = null;
connectionState = 'new';

createDataChannel(_name: string) {
lastChannel = new FakeRTCDataChannel();
// simulate async open event
setTimeout(() => lastChannel?.dispatchEvent(new Event('open')));
setTimeout(() => {
this._simulateStateChange('connected');
lastChannel?.dispatchEvent(new Event('open'));
}, 0);
return lastChannel as unknown as RTCDataChannel;
}
addTrack() {}
async createOffer() {
this._simulateStateChange('connecting');
return { sdp: 'offer', type: 'offer' };
}
async setLocalDescription(_desc: any) {}
async setRemoteDescription(_desc: any) {}
close() {}
close() {
this._simulateStateChange('closed');
}
getSenders() {
return [] as any;
}

_simulateStateChange(
state:
| 'new'
| 'connecting'
| 'connected'
| 'disconnected'
| 'failed'
| 'closed',
) {
if (this.connectionState === state) return;
this.connectionState = state;
setTimeout(() => {
if (this.onconnectionstatechange) {
this.onconnectionstatechange();
}
}, 0);
}
}

describe('OpenAIRealtimeWebRTC.interrupt', () => {
Expand Down Expand Up @@ -219,6 +246,113 @@ describe('OpenAIRealtimeWebRTC.interrupt', () => {
});
});

describe('OpenAIRealtimeWebRTC.connectionState', () => {
const originals: Record<string, any> = {};

beforeEach(() => {
originals.RTCPeerConnection = (global as any).RTCPeerConnection;
originals.navigator = (global as any).navigator;
originals.document = (global as any).document;
originals.fetch = (global as any).fetch;

(global as any).RTCPeerConnection = FakeRTCPeerConnection as any;
Object.defineProperty(globalThis, 'navigator', {
value: {
mediaDevices: {
getUserMedia: async () => ({
getAudioTracks: () => [{ enabled: true }],
}),
},
},
configurable: true,
writable: true,
});
Object.defineProperty(globalThis, 'document', {
value: { createElement: () => ({ autoplay: true }) },
configurable: true,
writable: true,
});
Object.defineProperty(globalThis, 'fetch', {
value: async () => ({
text: async () => 'answer',
headers: {
get: (headerKey: string) => {
if (headerKey === 'Location') {
return 'https://api.openai.com/v1/calls/rtc_u1_1234567890';
}
return null;
},
},
}),
configurable: true,
writable: true,
});
});

afterEach(() => {
(global as any).RTCPeerConnection = originals.RTCPeerConnection;
Object.defineProperty(globalThis, 'navigator', {
value: originals.navigator,
configurable: true,
writable: true,
});
Object.defineProperty(globalThis, 'document', {
value: originals.document,
configurable: true,
writable: true,
});
Object.defineProperty(globalThis, 'fetch', {
value: originals.fetch,
configurable: true,
writable: true,
});
lastChannel = null;
});

it('fires connection_change and disconnects on peer connection failure', async () => {
const rtc = new OpenAIRealtimeWebRTC();
const events: string[] = [];
rtc.on('connection_change', (status) => events.push(status));
await rtc.connect({ apiKey: 'ek_test' });
expect(rtc.status).toBe('connected');
expect(events).toEqual(['connecting', 'connected']);
const pc = rtc.connectionState
.peerConnection as unknown as FakeRTCPeerConnection;
expect(pc).toBeInstanceOf(FakeRTCPeerConnection);
pc._simulateStateChange('failed');
await new Promise((resolve) => setTimeout(resolve, 0));
expect(rtc.status).toBe('disconnected');
expect(events).toEqual(['connecting', 'connected', 'disconnected']);
});

it('migrates connection state handler when peer connection is replaced', async () => {
class CustomFakePeerConnection extends FakeRTCPeerConnection {}
const customPC = new CustomFakePeerConnection();

const rtc = new OpenAIRealtimeWebRTC({
changePeerConnection: async () => customPC as any,
});

const closeSpy = vi.spyOn(rtc, 'close');
const events: string[] = [];
rtc.on('connection_change', (status) => events.push(status));

await rtc.connect({ apiKey: 'ek_test' });

expect(rtc.status).toBe('connected');
expect(rtc.connectionState.peerConnection).toBe(customPC as any);
expect(closeSpy).not.toHaveBeenCalled();
expect(events).toEqual(['connecting', 'connected']);

customPC._simulateStateChange('failed');
await new Promise((resolve) => setTimeout(resolve, 0));

expect(closeSpy).toHaveBeenCalled();
expect(rtc.status).toBe('disconnected');
expect(events).toEqual(['connecting', 'connected', 'disconnected']);
});
});

describe('OpenAIRealtimeWebRTC.callId', () => {
const originals: Record<string, any> = {};
const callId = 'rtc_u1_1234567890';
Expand Down Expand Up @@ -288,6 +422,7 @@ describe('OpenAIRealtimeWebRTC.callId', () => {
await rtc.connect({ apiKey: 'ek_test' });
expect(rtc.callId).toBe(callId);
rtc.close();
await new Promise((resolve) => setTimeout(resolve, 0));
expect(rtc.callId).toBeUndefined();
});
});