diff --git a/src/client.ts b/src/client.ts index f588e55..0396a83 100644 --- a/src/client.ts +++ b/src/client.ts @@ -213,6 +213,18 @@ export class InferenceGatewayClient { const decoder = new TextDecoder(); let buffer = ''; + const incompleteToolCalls = new Map< + number, + { + id: string; + type: ChatCompletionToolType; + function: { + name: string; + arguments: string; + }; + } + >(); + while (true) { const { done, value } = await reader.read(); if (done) break; @@ -226,6 +238,16 @@ export class InferenceGatewayClient { const data = line.slice(5).trim(); if (data === '[DONE]') { + for (const [, toolCall] of incompleteToolCalls.entries()) { + callbacks.onTool?.({ + id: toolCall.id, + type: toolCall.type, + function: { + name: toolCall.function.name, + arguments: toolCall.function.arguments, + }, + }); + } callbacks.onFinish?.(null); return; } @@ -242,15 +264,54 @@ export class InferenceGatewayClient { const toolCalls = chunk.choices[0]?.delta?.tool_calls; if (toolCalls && toolCalls.length > 0) { - const toolCall: SchemaChatCompletionMessageToolCall = { - id: toolCalls[0].id || '', - type: ChatCompletionToolType.function, - function: { - name: toolCalls[0].function?.name || '', - arguments: toolCalls[0].function?.arguments || '', - }, - }; - callbacks.onTool?.(toolCall); + for (const toolCallChunk of toolCalls) { + const index = toolCallChunk.index; + + if (!incompleteToolCalls.has(index)) { + incompleteToolCalls.set(index, { + id: toolCallChunk.id || '', + type: ChatCompletionToolType.function, + function: { + name: toolCallChunk.function?.name || '', + arguments: toolCallChunk.function?.arguments || '', + }, + }); + } else { + const existingToolCall = incompleteToolCalls.get(index)!; + + if (toolCallChunk.id) { + existingToolCall.id = toolCallChunk.id; + } + + if (toolCallChunk.function?.name) { + existingToolCall.function.name = + toolCallChunk.function.name; + } + + if (toolCallChunk.function?.arguments) { + existingToolCall.function.arguments += + toolCallChunk.function.arguments; + } + } + } + } + + const finishReason = chunk.choices[0]?.finish_reason; + if ( + finishReason === 'tool_calls' && + incompleteToolCalls.size > 0 + ) { + for (const [, toolCall] of incompleteToolCalls.entries()) { + callbacks.onTool?.({ + id: toolCall.id, + type: toolCall.type, + function: { + name: toolCall.function.name, + arguments: toolCall.function.arguments, + }, + }); + } + incompleteToolCalls.clear(); } } catch (e) { globalThis.console.error('Error parsing SSE data:', e); diff --git a/tests/client.test.ts b/tests/client.test.ts index 471b090..c0b35e4 100644 --- a/tests/client.test.ts +++ b/tests/client.test.ts @@ -320,7 +320,15 @@ describe('InferenceGatewayClient', () => { expect(callbacks.onOpen).toHaveBeenCalledTimes(1); expect(callbacks.onChunk).toHaveBeenCalledTimes(6); - expect(callbacks.onTool).toHaveBeenCalledTimes(4); + expect(callbacks.onTool).toHaveBeenCalledTimes(1); + expect(callbacks.onTool).toHaveBeenCalledWith({ + id: 'call_123', + type: 'function', + function: { + name: 'get_weather', + arguments: '{"location":"San Francisco, CA"}' + } + }); expect(callbacks.onFinish).toHaveBeenCalledTimes(1); });