From 78fb2fdb0ae9f849e0210437a4506a45ef353e5f Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Mon, 31 Mar 2025 00:17:58 +0000 Subject: [PATCH 1/3] refactor: expand type definitions for chat completion and tool calls Signed-off-by: Eden Reich --- README.md | 281 +++++---- openapi.yaml | 1307 ++++++++++++++++++++++++------------------ src/client.ts | 323 ++++++++--- src/types/index.ts | 148 ++++- tests/client.test.ts | 454 +++++++++------ 5 files changed, 1589 insertions(+), 924 deletions(-) diff --git a/README.md b/README.md index 8fc6426..1f63bab 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,18 @@ -# Inference Gateway Typescript SDK +# Inference Gateway TypeScript SDK -An SDK written in Typescript for the [Inference Gateway](https://github.com/edenreich/inference-gateway). +An SDK written in TypeScript for the [Inference Gateway](https://github.com/edenreich/inference-gateway). -- [Inference Gateway Typescript SDK](#inference-gateway-typescript-sdk) +- [Inference Gateway TypeScript SDK](#inference-gateway-typescript-sdk) - [Installation](#installation) - [Usage](#usage) - [Creating a Client](#creating-a-client) - - [Listing All Models](#listing-all-models) - - [List Models by Provider](#list-models-by-provider) - - [Generating Content](#generating-content) - - [Streaming Content](#streaming-content) + - [Listing Models](#listing-models) + - [Creating Chat Completions](#creating-chat-completions) + - [Streaming Chat Completions](#streaming-chat-completions) + - [Tool Calls](#tool-calls) + - [Proxying Requests](#proxying-requests) - [Health Check](#health-check) + - [Creating a Client with Custom Options](#creating-a-client-with-custom-options) - [Contributing](#contributing) - [License](#license) @@ -22,152 +24,215 @@ Run `npm i @inference-gateway/sdk`. 
### Creating a Client +```typescript +import { InferenceGatewayClient } from '@inference-gateway/sdk'; + +// Create a client with default options +const client = new InferenceGatewayClient({ + baseURL: 'http://localhost:8080/v1', + apiKey: 'your-api-key', // Optional +}); +``` + +### Listing Models + +To list all available models: + +```typescript +import { InferenceGatewayClient, Provider } from '@inference-gateway/sdk'; + +const client = new InferenceGatewayClient({ + baseURL: 'http://localhost:8080/v1', +}); + +try { + // List all models + const models = await client.listModels(); + console.log('All models:', models); + + // List models from a specific provider + const openaiModels = await client.listModels(Provider.OpenAI); + console.log('OpenAI models:', openaiModels); +} catch (error) { + console.error('Error:', error); +} +``` + +### Creating Chat Completions + +To generate content using a model: + ```typescript import { InferenceGatewayClient, - Message, + MessageRole, Provider, } from '@inference-gateway/sdk'; -async function main() { - const client = new InferenceGatewayClient('http://localhost:8080'); - - try { - // List available models - const models = await client.listModels(); - models.forEach((providerModels) => { - console.log(`Provider: ${providerModels.provider}`); - providerModels.models.forEach((model) => { - console.log(`Model: ${model.name}`); - }); - }); - - // Generate content - const response = await client.generateContent({ - provider: Provider.Ollama, - model: 'llama2', +const client = new InferenceGatewayClient({ + baseURL: 'http://localhost:8080/v1', +}); + +try { + const response = await client.createChatCompletion( + { + model: 'gpt-4o', messages: [ { role: MessageRole.System, - content: 'You are a helpful llama', + content: 'You are a helpful assistant', }, { role: MessageRole.User, content: 'Tell me a joke', }, ], - }); - - console.log('Response:', response); - } catch (error) { - console.error('Error:', error); - } -} - -main(); -``` - -### Listing All Models - -To list all available models from all providers, use the `listModels` method: + }, + Provider.OpenAI + ); // Provider is optional -```typescript -try { - const models = await client.listModels(); - models.forEach((providerModels) => { - console.log(`Provider: ${providerModels.provider}`); - providerModels.models.forEach((model) => { - console.log(`Model: ${model.name}`); - }); - }); + console.log('Response:', response.choices[0].message.content); } catch (error) { console.error('Error:', error); } ``` -### List Models by Provider +### Streaming Chat Completions -To list all available models from a specific provider, use the `listModelsByProvider` method: +To stream content from a model: ```typescript +import { + InferenceGatewayClient, + MessageRole, + Provider, +} from '@inference-gateway/sdk'; + +const client = new InferenceGatewayClient({ + baseURL: 'http://localhost:8080/v1', +}); + try { - const providerModels = await client.listModelsByProvider(Provider.OpenAI); - console.log(`Provider: ${providerModels.provider}`); - providerModels.models.forEach((model) => { - console.log(`Model: ${model.name}`); - }); + await client.streamChatCompletion( + { + model: 'llama-3.3-70b-versatile', + messages: [ + { + role: MessageRole.User, + content: 'Tell me a story', + }, + ], + }, + { + onOpen: () => console.log('Stream opened'), + onContent: (content) => process.stdout.write(content), + onChunk: (chunk) => console.log('Received chunk:', chunk.id), + onFinish: () => console.log('\nStream completed'), + 
onError: (error) => console.error('Stream error:', error), + }, + Provider.Groq // Provider is optional + ); } catch (error) { console.error('Error:', error); } ``` -### Generating Content +### Tool Calls -To generate content using a model, use the `generateContent` method: +To use tool calls with models that support them: ```typescript import { InferenceGatewayClient, - Message, MessageRole, Provider, } from '@inference-gateway/sdk'; -const client = new InferenceGatewayClient('http://localhost:8080'); +const client = new InferenceGatewayClient({ + baseURL: 'http://localhost:8080/v1', +}); - const response = await client.generateContent({ - provider: Provider.Ollama, - model: 'llama2', - messages: [ - { - role: MessageRole.System, - content: 'You are a helpful llama', - }, - { - role: MessageRole.User, - content: 'Tell me a joke', +try { + await client.streamChatCompletion( + { + model: 'gpt-4o', + messages: [ + { + role: MessageRole.User, + content: 'What's the weather in San Francisco?', + }, + ], + tools: [ + { + type: 'function', + function: { + name: 'get_weather', + parameters: { + type: 'object', + properties: { + location: { + type: 'string', + description: 'The city and state, e.g. San Francisco, CA', + }, + }, + required: ['location'], + }, + }, + }, + ], + }, + { + onTool: (toolCall) => { + console.log('Tool call:', toolCall.function.name); + console.log('Arguments:', toolCall.function.arguments); }, - ], - }); - - console.log('Provider:', response.provider); - console.log('Response:', response.response); + onContent: (content) => process.stdout.write(content), + onFinish: () => console.log('\nStream completed'), + }, + Provider.OpenAI + ); } catch (error) { console.error('Error:', error); } ``` -### Streaming Content +### Proxying Requests -To stream content using a model, use the `streamContent` method: +To proxy requests directly to a provider: ```typescript -const client = new InferenceGatewayClient('http://localhost:8080'); - -await client.generateContentStream( - { - provider: Provider.Groq, - model: 'deepseek-r1-distill-llama-70b', - messages: [ - { - role: MessageRole.User, - content: 'Tell me a story', - }, - ], - }, - { - onMessageStart: (role) => console.log('Message started:', role), - onContentDelta: (content) => process.stdout.write(content), - onStreamEnd: () => console.log('\nStream completed'), - } -); +import { InferenceGatewayClient, Provider } from '@inference-gateway/sdk'; + +const client = new InferenceGatewayClient({ + baseURL: 'http://localhost:8080/v1', +}); + +try { + const response = await client.proxy(Provider.OpenAI, 'embeddings', { + method: 'POST', + body: JSON.stringify({ + model: 'text-embedding-ada-002', + input: 'Hello world', + }), + }); + + console.log('Embeddings:', response); +} catch (error) { + console.error('Error:', error); +} ``` ### Health Check -To check if the Inference Gateway is running, use the `healthCheck` method: +To check if the Inference Gateway is running: ```typescript +import { InferenceGatewayClient } from '@inference-gateway/sdk'; + +const client = new InferenceGatewayClient({ + baseURL: 'http://localhost:8080/v1', +}); + try { const isHealthy = await client.healthCheck(); console.log('API is healthy:', isHealthy); @@ -176,6 +241,26 @@ try { } ``` +### Creating a Client with Custom Options + +You can create a new client with custom options using the `withOptions` method: + +```typescript +import { InferenceGatewayClient } from '@inference-gateway/sdk'; + +const client = new InferenceGatewayClient({ + baseURL: 
'http://localhost:8080/v1', +}); + +// Create a new client with custom headers +const clientWithHeaders = client.withOptions({ + defaultHeaders: { + 'X-Custom-Header': 'value', + }, + timeout: 60000, // 60 seconds +}); +``` + ## Contributing Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file for information about how to get involved. We welcome issues, questions, and pull requests. diff --git a/openapi.yaml b/openapi.yaml index b4abc85..f2177da 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -3,75 +3,141 @@ openapi: 3.1.0 info: title: Inference Gateway API description: | - API for interacting with various language models through the Inference Gateway. + The API for interacting with various language models and other AI services. + OpenAI, Groq, Ollama, and other providers are supported. + OpenAI compatible API for using with existing clients. + Unified API for all providers. + contact: + name: Inference Gateway + url: https://inference-gateway.github.io/docs/ version: 1.0.0 + license: + name: MIT + url: https://github.com/inference-gateway/inference-gateway/blob/main/LICENSE servers: - url: http://localhost:8080 + description: Default server without version prefix for healthcheck and proxy and points + x-server-tags: ["Health", "Proxy"] + - url: http://localhost:8080/v1 + description: Default server with version prefix for listing models and chat completions + x-server-tags: ["Models", "Completions"] + - url: https://api.inference-gateway.local/v1 + description: Local server with version prefix for listing models and chat completions + x-server-tags: ["Models", "Completions"] +tags: + - name: Models + description: List and describe the various models available in the API. + - name: Completions + description: Generate completions from the models. + - name: Proxy + description: Proxy requests to provider endpoints. + - name: Health + description: Health check paths: - /llms: + /models: get: - summary: List all language models operationId: listModels + tags: + - Models + description: | + Lists the currently available models, and provides basic information + about each one such as the owner and availability. + summary: + Lists the currently available models, and provides basic information + about each one such as the owner and availability. 
security: - bearerAuth: [] - responses: - "200": - description: A list of models by provider - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/ListModelsResponse" - "401": - $ref: "#/components/responses/Unauthorized" - /llms/{provider}: - get: - summary: List all models for a specific provider - operationId: listModelsByProvider parameters: - name: provider - in: path - required: true + in: query + required: false schema: - $ref: "#/components/schemas/Providers" - security: - - bearerAuth: [] + $ref: "#/components/schemas/Provider" + description: Specific provider to query (optional) responses: "200": - description: A list of models + description: List of available models content: application/json: schema: $ref: "#/components/schemas/ListModelsResponse" - "400": - $ref: "#/components/responses/BadRequest" + examples: + allProviders: + summary: Models from all providers + value: + object: "list" + data: + - id: "gpt-4o" + object: "model" + created: 1686935002 + owned_by: "openai" + - id: "llama-3.3-70b-versatile" + object: "model" + created: 1723651281 + owned_by: "groq" + - id: "claude-3-opus-20240229" + object: "model" + created: 1708905600 + owned_by: "anthropic" + - id: "command-r" + object: "model" + created: 1707868800 + owned_by: "cohere" + - id: "phi3:3.8b" + object: "model" + created: 1718441600 + owned_by: "ollama" + singleProvider: + summary: Models from a specific provider + value: + object: "list" + data: + - id: "gpt-4o" + object: "model" + created: 1686935002 + owned_by: "openai" + - id: "gpt-4-turbo" + object: "model" + created: 1687882410 + owned_by: "openai" + - id: "gpt-3.5-turbo" + object: "model" + created: 1677649963 + owned_by: "openai" "401": $ref: "#/components/responses/Unauthorized" - /llms/{provider}/generate: + "500": + $ref: "#/components/responses/InternalError" + /chat/completions: post: - summary: Generate content with a specific provider's LLM - operationId: generateContent + operationId: createChatCompletion + tags: + - Completions + description: | + Generates a chat completion based on the provided input. + The completion can be streamed to the client as it is generated. + summary: Create a chat completion + security: + - bearerAuth: [] parameters: - name: provider - in: path - required: true + in: query + required: false schema: - $ref: "#/components/schemas/Providers" - security: - - bearerAuth: [] + $ref: "#/components/schemas/Provider" + description: Specific provider to use (default determined by model) requestBody: - content: - application/json: - schema: - $ref: "#/components/schemas/GenerateRequest" + $ref: "#/components/requestBodies/CreateChatCompletionRequest" responses: "200": - description: Generated content + description: Successful response content: application/json: schema: - $ref: "#/components/schemas/GenerateResponse" + $ref: "#/components/schemas/CreateChatCompletionResponse" + text/event-stream: + schema: + $ref: "#/components/schemas/SSEvent" "400": $ref: "#/components/responses/BadRequest" "401": @@ -84,7 +150,7 @@ paths: in: path required: true schema: - $ref: "#/components/schemas/Providers" + $ref: "#/components/schemas/Provider" - name: path in: path required: true @@ -94,8 +160,14 @@ paths: type: string description: The remaining path to proxy to the provider get: - summary: Proxy GET request to provider operationId: proxyGet + tags: + - Proxy + description: | + Proxy GET request to provider + The request body depends on the specific provider and endpoint being called. 
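+        For example, a GET request to `/proxy/openai/models` is forwarded to the OpenAI `/v1/models` endpoint.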
+ If you decide to use this approach, please follow the provider-specific documentations. + summary: Proxy GET request to provider responses: "200": $ref: "#/components/responses/ProviderResponse" @@ -108,8 +180,14 @@ paths: security: - bearerAuth: [] post: - summary: Proxy POST request to provider operationId: proxyPost + tags: + - Proxy + description: | + Proxy POST request to provider + The request body depends on the specific provider and endpoint being called. + If you decide to use this approach, please follow the provider-specific documentations. + summary: Proxy POST request to provider requestBody: $ref: "#/components/requestBodies/ProviderRequest" responses: @@ -124,8 +202,14 @@ paths: security: - bearerAuth: [] put: - summary: Proxy PUT request to provider operationId: proxyPut + tags: + - Proxy + description: | + Proxy PUT request to provider + The request body depends on the specific provider and endpoint being called. + If you decide to use this approach, please follow the provider-specific documentations. + summary: Proxy PUT request to provider requestBody: $ref: "#/components/requestBodies/ProviderRequest" responses: @@ -140,8 +224,14 @@ paths: security: - bearerAuth: [] delete: - summary: Proxy DELETE request to provider operationId: proxyDelete + tags: + - Proxy + description: | + Proxy DELETE request to provider + The request body depends on the specific provider and endpoint being called. + If you decide to use this approach, please follow the provider-specific documentations. + summary: Proxy DELETE request to provider responses: "200": $ref: "#/components/responses/ProviderResponse" @@ -154,8 +244,14 @@ paths: security: - bearerAuth: [] patch: - summary: Proxy PATCH request to provider operationId: proxyPatch + tags: + - Proxy + description: | + Proxy PATCH request to provider + The request body depends on the specific provider and endpoint being called. + If you decide to use this approach, please follow the provider-specific documentations. + summary: Proxy PATCH request to provider requestBody: $ref: "#/components/requestBodies/ProviderRequest" responses: @@ -171,6 +267,12 @@ paths: - bearerAuth: [] /health: get: + operationId: healthCheck + tags: + - Health + description: | + Health check endpoint + Returns a 200 status code if the service is healthy summary: Health check responses: "200": @@ -200,25 +302,34 @@ components: type: string temperature: type: number - format: float64 + format: float default: 0.7 - examples: - - openai: - summary: OpenAI chat completion request - value: - model: "gpt-3.5-turbo" - messages: - - role: "user" - content: "Hello! How can I assist you today?" - temperature: 0.7 - - anthropic: - summary: Anthropic Claude request - value: - model: "claude-3-opus-20240229" - messages: - - role: "user" - content: "Explain quantum computing" - temperature: 0.5 + examples: + openai: + summary: OpenAI chat completion request + value: + model: "gpt-3.5-turbo" + messages: + - role: "user" + content: "Hello! How can I assist you today?" + temperature: 0.7 + anthropic: + summary: Anthropic Claude request + value: + model: "claude-3-opus-20240229" + messages: + - role: "user" + content: "Explain quantum computing" + temperature: 0.5 + CreateChatCompletionRequest: + required: true + description: | + ProviderRequest depends on the specific provider and endpoint being called + If you decide to use this approach, please follow the provider-specific documentations. 
+ content: + application/json: + schema: + $ref: "#/components/schemas/CreateChatCompletionRequest" responses: BadRequest: description: Bad request @@ -278,7 +389,7 @@ components: To enable authentication, set ENABLE_AUTH to true. When enabled, requests must include a valid JWT token in the Authorization header. schemas: - Providers: + Provider: type: string enum: - ollama @@ -287,36 +398,137 @@ components: - cloudflare - cohere - anthropic + - deepseek + x-provider-configs: + ollama: + id: "ollama" + url: "http://ollama:8080/v1" + auth_type: "none" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + anthropic: + id: "anthropic" + url: "https://api.anthropic.com/v1" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + cohere: + id: "cohere" + url: "https://api.cohere.ai" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/v1/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/compatibility/v1/chat/completions" + groq: + id: "groq" + url: "https://api.groq.com/openai/v1" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + openai: + id: "openai" + url: "https://api.openai.com/v1" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + cloudflare: + id: "cloudflare" + url: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/finetunes/public?limit=1000" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/v1/chat/completions" + deepseek: + id: "deepseek" + url: "https://api.deepseek.com" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" ProviderSpecificResponse: type: object description: | Provider-specific response format. 
Examples: - OpenAI GET /v1/models response: + OpenAI GET /v1/models?provider=openai response: ```json { + "provider": "openai", + "object": "list", "data": [ { "id": "gpt-4", "object": "model", - "created": 1687882410 + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai" } ] } ``` - Anthropic GET /v1/models response: + Anthropic GET /v1/models?provider=anthropic response: ```json { - "models": [ + "provider": "anthropic", + "object": "list", + "data": [ { - "name": "claude-3-opus-20240229", - "description": "Most capable model for highly complex tasks" + "id": "gpt-4", + "object": "model", + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai" } ] } ``` - additionalProperties: true ProviderAuthType: type: string description: Authentication type for providers @@ -325,6 +537,31 @@ components: - xheader - query - none + SSEvent: + type: object + properties: + event: + type: string + enum: + - message-start + - stream-start + - content-start + - content-delta + - content-end + - message-end + - stream-end + data: + type: string + format: byte + retry: + type: integer + Endpoints: + type: object + properties: + models: + type: string + chat: + type: string Error: type: object properties: @@ -337,6 +574,7 @@ components: - system - user - assistant + - tool Message: type: object description: Message structure for provider requests @@ -345,72 +583,474 @@ components: $ref: "#/components/schemas/MessageRole" content: type: string + tool_calls: + type: array + items: + $ref: "#/components/schemas/ChatCompletionMessageToolCall" + tool_call_id: + type: string + reasoning: + type: string + reasoning_content: + type: string + required: + - role + - content Model: type: object description: Common model information properties: - name: + id: + type: string + object: type: string + created: + type: integer + format: int64 + owned_by: + type: string + served_by: + $ref: "#/components/schemas/Provider" ListModelsResponse: type: object description: Response structure for listing models properties: provider: - $ref: "#/components/schemas/Providers" - models: + $ref: "#/components/schemas/Provider" + object: + type: string + data: type: array items: $ref: "#/components/schemas/Model" - GenerateRequest: + default: [] + FunctionObject: type: object - description: Request structure for token generation + properties: + description: + type: string + description: + A description of what the function does, used by the model to + choose when and how to call the function. + name: + type: string + description: + The name of the function to be called. Must be a-z, A-Z, 0-9, or + contain underscores and dashes, with a maximum length of 64. + parameters: + $ref: "#/components/schemas/FunctionParameters" + strict: + type: boolean + default: false + description: + Whether to enable strict schema adherence when generating the + function call. If set to true, the model will follow the exact + schema defined in the `parameters` field. Only a subset of JSON + Schema is supported when `strict` is `true`. Learn more about + Structured Outputs in the [function calling + guide](docs/guides/function-calling). required: - - model - - messages + - name + ChatCompletionTool: + type: object + properties: + type: + $ref: "#/components/schemas/ChatCompletionToolType" + function: + $ref: "#/components/schemas/FunctionObject" + required: + - type + - function + FunctionParameters: + type: object + description: >- + The parameters the functions accepts, described as a JSON Schema object. 
+ See the [guide](/docs/guides/function-calling) for examples, and the + [JSON Schema + reference](https://json-schema.org/understanding-json-schema/) for + documentation about the format. + + Omitting `parameters` defines a function with an empty parameter list. + properties: + type: + type: string + description: The type of the parameters. Currently, only `object` is supported. + properties: + type: object + description: The properties of the parameters. + required: + type: array + items: + type: string + description: The required properties of the parameters. + ChatCompletionToolType: + type: string + description: The type of the tool. Currently, only `function` is supported. + enum: + - function + CompletionUsage: + type: object + description: Usage statistics for the completion request. + properties: + completion_tokens: + type: integer + default: 0 + format: int64 + description: Number of tokens in the generated completion. + prompt_tokens: + type: integer + default: 0 + format: int64 + description: Number of tokens in the prompt. + total_tokens: + type: integer + default: 0 + format: int64 + description: Total number of tokens used in the request (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens + ChatCompletionStreamOptions: + description: > + Options for streaming response. Only set this when you set `stream: + true`. + type: object + properties: + include_usage: + type: boolean + description: > + If set, an additional chunk will be streamed before the `data: + [DONE]` message. The `usage` field on this chunk shows the token + usage statistics for the entire request, and the `choices` field + will always be an empty array. All other chunks will also include a + `usage` field, but with a null value. + default: true + CreateChatCompletionRequest: + type: object properties: model: type: string + description: Model ID to use messages: + description: > + A list of messages comprising the conversation so far. type: array + minItems: 1 items: $ref: "#/components/schemas/Message" + max_tokens: + description: > + An upper bound for the number of tokens that can be generated + for a completion, including visible output tokens and reasoning tokens. + type: integer stream: + description: > + If set to true, the model response data will be streamed to the + client as it is generated using [server-sent + events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format). type: boolean default: false - description: Whether to stream tokens as they are generated in raw json - ssevents: - type: boolean - default: false - description: | - Whether to use Server-Sent Events for token generation. - When enabled, the response will be streamed as SSE with the following event types: - - message-start: Initial message event with assistant role - - stream-start: Stream initialization - - content-start: Content beginning - - content-delta: Content update with new tokens - - content-end: Content completion - - message-end: Message completion - - stream-end: Stream completion + stream_options: + $ref: "#/components/schemas/ChatCompletionStreamOptions" + tools: + type: array + description: > + A list of tools the model may call. Currently, only functions + are supported as a tool. Use this to provide a list of functions + the model may generate JSON inputs for. A max of 128 functions + are supported. 
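+          Tool call results are typically sent back to the model as messages with `role: tool` and the matching `tool_call_id`.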
+ items: + $ref: "#/components/schemas/ChatCompletionTool" + required: + - model + - messages + ChatCompletionMessageToolCallFunction: + type: object + description: The function that the model called. + properties: + name: + type: string + description: The name of the function to call. + arguments: + type: string + description: + The arguments to call the function with, as generated by the model + in JSON format. Note that the model does not always generate + valid JSON, and may hallucinate parameters not defined by your + function schema. Validate the arguments in your code before + calling your function. + required: + - name + - arguments + ChatCompletionMessageToolCall: + type: object + properties: + id: + type: string + description: The ID of the tool call. + type: + $ref: "#/components/schemas/ChatCompletionToolType" + function: + $ref: "#/components/schemas/ChatCompletionMessageToolCallFunction" + required: + - id + - type + - function + ChatCompletionChoice: + type: object + properties: + finish_reason: + type: string + description: > + The reason the model stopped generating tokens. This will be + `stop` if the model hit a natural stop point or a provided + stop sequence, + + `length` if the maximum number of tokens specified in the + request was reached, - **Note:** Depending on the provider, some events may not be present. - ResponseTokens: + `content_filter` if content was omitted due to a flag from our + content filters, + + `tool_calls` if the model called a tool. + enum: + - stop + - length + - tool_calls + - content_filter + - function_call + index: + type: integer + description: The index of the choice in the list of choices. + message: + $ref: "#/components/schemas/Message" + required: + - finish_reason + - index + - message + - logprobs + ChatCompletionStreamChoice: type: object - description: Token response structure + required: + - delta + - finish_reason + - index properties: - role: + delta: + $ref: "#/components/schemas/ChatCompletionStreamResponseDelta" + logprobs: + description: Log probability information for the choice. + type: object + properties: + content: + description: A list of message content tokens with log probability information. + type: array + items: + $ref: "#/components/schemas/ChatCompletionTokenLogprob" + refusal: + description: A list of message refusal tokens with log probability information. + type: array + items: + $ref: "#/components/schemas/ChatCompletionTokenLogprob" + required: + - content + - refusal + finish_reason: + $ref: "#/components/schemas/FinishReason" + index: + type: integer + description: The index of the choice in the list of choices. + CreateChatCompletionResponse: + type: object + description: + Represents a chat completion response returned by model, based on + the provided input. + properties: + id: type: string + description: A unique identifier for the chat completion. + choices: + type: array + description: + A list of chat completion choices. Can be more than one if `n` is + greater than 1. + items: + $ref: "#/components/schemas/ChatCompletionChoice" + created: + type: integer + description: + The Unix timestamp (in seconds) of when the chat completion was + created. model: type: string + description: The model used for the chat completion. + object: + type: string + description: The object type, which is always `chat.completion`. 
+ x-stainless-const: true + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - choices + - created + - id + - model + - object + ChatCompletionStreamResponseDelta: + type: object + description: A chat completion delta generated by streamed model responses. + properties: content: type: string - GenerateResponse: + description: The contents of the chunk message. + tool_calls: + type: array + items: + $ref: "#/components/schemas/ChatCompletionMessageToolCallChunk" + role: + $ref: "#/components/schemas/MessageRole" + refusal: + type: string + description: The refusal message generated by the model. + ChatCompletionMessageToolCallChunk: type: object - description: Response structure for token generation properties: - provider: + index: + type: integer + id: + type: string + description: The ID of the tool call. + type: + type: string + description: The type of the tool. Currently, only `function` is supported. + function: + type: object + properties: + name: + type: string + description: The name of the function to call. + arguments: + type: string + description: + The arguments to call the function with, as generated by the model + in JSON format. Note that the model does not always generate + valid JSON, and may hallucinate parameters not defined by your + function schema. Validate the arguments in your code before + calling your function. + required: + - index + ChatCompletionTokenLogprob: + type: object + properties: + token: &a1 + description: The token. + type: string + logprob: &a2 + description: + The log probability of this token, if it is within the top 20 most + likely tokens. Otherwise, the value `-9999.0` is used to signify + that the token is very unlikely. + type: number + bytes: &a3 + description: + A list of integers representing the UTF-8 bytes representation of + the token. Useful in instances where characters are represented by + multiple tokens and their byte representations must be combined to + generate the correct text representation. Can be `null` if there is + no bytes representation for the token. + type: array + items: + type: integer + top_logprobs: + description: + List of the most likely tokens and their log probability, at this + token position. In rare cases, there may be fewer than the number of + requested `top_logprobs` returned. + type: array + items: + type: object + properties: + token: *a1 + logprob: *a2 + bytes: *a3 + required: + - token + - logprob + - bytes + required: + - token + - logprob + - bytes + - top_logprobs + FinishReason: + type: string + description: > + The reason the model stopped generating tokens. This will be + `stop` if the model hit a natural stop point or a provided + stop sequence, + + `length` if the maximum number of tokens specified in the + request was reached, + + `content_filter` if content was omitted due to a flag from our + content filters, + + `tool_calls` if the model called a tool. + enum: + - stop + - length + - tool_calls + - content_filter + - function_call + CreateChatCompletionStreamResponse: + type: object + description: | + Represents a streamed chunk of a chat completion response returned + by the model, based on the provided input. + properties: + id: type: string - response: - $ref: "#/components/schemas/ResponseTokens" + description: + A unique identifier for the chat completion. Each chunk has the + same ID. + choices: + type: array + description: > + A list of chat completion choices. Can contain more than one + elements if `n` is greater than 1. 
Can also be empty for the + + last chunk if you set `stream_options: {"include_usage": true}`. + items: + $ref: "#/components/schemas/ChatCompletionStreamChoice" + created: + type: integer + description: + The Unix timestamp (in seconds) of when the chat completion was + created. Each chunk has the same timestamp. + model: + type: string + description: The model to generate the completion. + system_fingerprint: + type: string + description: > + This fingerprint represents the backend configuration that the model + runs with. + + Can be used in conjunction with the `seed` request parameter to + understand when backend changes have been made that might impact + determinism. + object: + type: string + description: The object type, which is always `chat.completion.chunk`. + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - choices + - created + - id + - model + - object Config: x-config: sections: @@ -526,7 +1166,7 @@ components: - name: anthropic_api_url env: "ANTHROPIC_API_URL" type: string - default: "https://api.anthropic.com" + default: "https://api.anthropic.com/v1" description: "Anthropic API URL" - name: anthropic_api_key env: "ANTHROPIC_API_KEY" @@ -536,7 +1176,7 @@ components: - name: cloudflare_api_url env: "CLOUDFLARE_API_URL" type: string - default: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}" + default: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai" description: "Cloudflare API URL" - name: cloudflare_api_key env: "CLOUDFLARE_API_KEY" @@ -546,7 +1186,7 @@ components: - name: cohere_api_url env: "COHERE_API_URL" type: string - default: "https://api.cohere.com" + default: "https://api.cohere.ai" description: "Cohere API URL" - name: cohere_api_key env: "COHERE_API_KEY" @@ -556,7 +1196,7 @@ components: - name: groq_api_url env: "GROQ_API_URL" type: string - default: "https://api.groq.com" + default: "https://api.groq.com/openai/v1" description: "Groq API URL" - name: groq_api_key env: "GROQ_API_KEY" @@ -566,7 +1206,7 @@ components: - name: ollama_api_url env: "OLLAMA_API_URL" type: string - default: "http://ollama:8080" + default: "http://ollama:8080/v1" description: "Ollama API URL" - name: ollama_api_key env: "OLLAMA_API_KEY" @@ -576,451 +1216,20 @@ components: - name: openai_api_url env: "OPENAI_API_URL" type: string - default: "https://api.openai.com" + default: "https://api.openai.com/v1" description: "OpenAI API URL" - name: openai_api_key env: "OPENAI_API_KEY" type: string description: "OpenAI API Key" secret: true - x-provider-configs: - ollama: - id: "ollama" - url: "http://ollama:8080" - auth_type: "none" - endpoints: - list: - endpoint: "/api/tags" - method: "GET" - schema: - response: - type: object - properties: - models: - type: array - items: - type: object - properties: - name: - type: string - modified_at: - type: string - size: - type: integer - digest: - type: string - details: - type: object - properties: - format: - type: string - family: - type: string - families: - type: array - items: - type: string - parameter_size: - type: string - generate: - endpoint: "/api/generate" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - prompt: - type: string - stream: - type: boolean - system: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - provider: - type: string - response: - type: object - properties: - role: - type: string - model: - type: string - content: - type: string - openai: - id: "openai" - url: 
"https://api.openai.com" - auth_type: "bearer" - endpoints: - list: - endpoint: "/v1/models" - method: "GET" - schema: - response: - type: object - properties: - object: - type: string - data: - type: array - items: - type: object - properties: - id: - type: string - object: - type: string - created: - type: integer - format: int64 - owned_by: - type: string - permission: - type: array - items: - type: object - properties: - id: - type: string - object: - type: string - created: - type: integer - format: int64 - allow_create_engine: - type: boolean - allow_sampling: - type: boolean - allow_logprobs: - type: boolean - allow_search_indices: - type: boolean - allow_view: - type: boolean - allow_fine_tuning: - type: boolean - root: - type: string - parent: - type: string - generate: - endpoint: "/v1/chat/completions" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - model: - type: string - choices: - type: array - items: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: string - groq: - id: "groq" - url: "https://api.groq.com" - auth_type: "bearer" - endpoints: - list: - endpoint: "/openai/v1/models" - method: "GET" - schema: - response: - type: object - properties: - object: - type: string - data: - type: array - items: - type: object - properties: - id: - type: string - object: - type: string - created: - type: integer - format: int64 - owned_by: - type: string - active: - type: boolean - context_window: - type: integer - public_apps: - type: object - generate: - endpoint: "/openai/v1/chat/completions" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - model: - type: string - choices: - type: array - items: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: string - cloudflare: - id: "cloudflare" - url: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}" - auth_type: "bearer" - endpoints: - list: - endpoint: "/ai/finetunes/public" - method: "GET" - schema: - response: - type: object - properties: - result: - type: array - items: - type: object - properties: - id: - type: string - name: - type: string - description: - type: string - created_at: - type: string - modified_at: - type: string - public: - type: integer - model: - type: string - generate: - endpoint: "/v1/chat/completions" - method: "POST" - schema: - request: - type: object - properties: - prompt: - type: string - model: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - result: - type: object - properties: - response: - type: string - cohere: - id: "cohere" - url: "https://api.cohere.com" - auth_type: "bearer" - endpoints: - list: - endpoint: "/v1/models" - method: "GET" - schema: - response: - type: object - properties: - models: - type: array - items: - type: object - properties: - name: - type: string - endpoints: - type: array - items: - type: string - finetuned: - type: boolean - context_length: - type: number - format: 
float64 - tokenizer_url: - type: string - default_endpoints: - type: array - items: - type: string - next_page_token: - type: string - generate: - endpoint: "/v2/chat" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: array - items: - type: object - properties: - type: - type: string - text: - type: string - anthropic: - id: "anthropic" - url: "https://api.anthropic.com" - auth_type: "xheader" - extra_headers: - anthropic-version: "2023-06-01" - endpoints: - list: - endpoint: "/v1/models" - method: "GET" - schema: - response: - type: object - properties: - models: - type: array - items: - type: object - properties: - type: - type: string - id: - type: string - display_name: - type: string - created_at: - type: string - has_more: - type: boolean - first_id: - type: string - last_id: - type: string - generate: - endpoint: "/v1/messages" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - model: - type: string - choices: - type: array - items: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: string + - name: deepseek_api_url + env: "DEEPSEEK_API_URL" + type: string + default: "https://api.deepseek.com" + description: "DeepSeek API URL" + - name: deepseek_api_key + env: "DEEPSEEK_API_KEY" + type: string + description: "DeepSeek API Key" + secret: true diff --git a/src/client.ts b/src/client.ts index 6ec974c..26bf294 100644 --- a/src/client.ts +++ b/src/client.ts @@ -1,140 +1,281 @@ import { - GenerateContentOptions, - GenerateContentRequest, - GenerateContentResponse, + Error as ApiError, + ChatCompletionMessageToolCall, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionStreamCallbacks, + ChatCompletionStreamResponse, + ListModelsResponse, Provider, - ProviderModels, } from './types'; +export interface ClientOptions { + baseURL?: string; + apiKey?: string; + defaultHeaders?: Record; + defaultQuery?: Record; + timeout?: number; + fetch?: typeof globalThis.fetch; +} + export class InferenceGatewayClient { - private baseUrl: string; - private authToken?: string; + private baseURL: string; + private apiKey?: string; + private defaultHeaders: Record; + private defaultQuery: Record; + private timeout: number; + private fetchFn: typeof globalThis.fetch; - constructor(baseUrl: string, authToken?: string) { - this.baseUrl = baseUrl.replace(/\/$/, ''); - this.authToken = authToken; + constructor(options: ClientOptions = {}) { + this.baseURL = options.baseURL || 'http://localhost:8080/v1'; + this.apiKey = options.apiKey; + this.defaultHeaders = options.defaultHeaders || {}; + this.defaultQuery = options.defaultQuery || {}; + this.timeout = options.timeout || 30000; + this.fetchFn = options.fetch || globalThis.fetch; } + /** + * Creates a new instance of the client with the given options merged with the existing options. 
+ */ + withOptions(options: ClientOptions): InferenceGatewayClient { + return new InferenceGatewayClient({ + baseURL: options.baseURL || this.baseURL, + apiKey: options.apiKey || this.apiKey, + defaultHeaders: { ...this.defaultHeaders, ...options.defaultHeaders }, + defaultQuery: { ...this.defaultQuery, ...options.defaultQuery }, + timeout: options.timeout || this.timeout, + fetch: options.fetch || this.fetchFn, + }); + } + + /** + * Makes a request to the API. + */ private async request( path: string, - options: RequestInit = {} + options: RequestInit = {}, + query: Record = {} ): Promise { const headers = new Headers({ 'Content-Type': 'application/json', + ...this.defaultHeaders, ...(options.headers as Record), }); - if (this.authToken) { - headers.set('Authorization', `Bearer ${this.authToken}`); + if (this.apiKey) { + headers.set('Authorization', `Bearer ${this.apiKey}`); } - const response = await fetch(`${this.baseUrl}${path}`, { - ...options, - headers, + // Combine default query parameters with provided ones + const queryParams = new URLSearchParams({ + ...this.defaultQuery, + ...query, }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.error || `HTTP error! status: ${response.status}`); - } + const queryString = queryParams.toString(); + const url = `${this.baseURL}${path}${queryString ? `?${queryString}` : ''}`; - return response.json(); - } + const controller = new AbortController(); + const timeoutId = globalThis.setTimeout( + () => controller.abort(), + this.timeout + ); + + try { + const response = await this.fetchFn(url, { + ...options, + headers, + signal: controller.signal, + }); + + if (!response.ok) { + const error = (await response.json()) as ApiError; + throw new Error( + error.error || `HTTP error! status: ${response.status}` + ); + } - async listModels(): Promise { - return this.request('/llms'); + return response.json(); + } finally { + globalThis.clearTimeout(timeoutId); + } } - async listModelsByProvider(provider: Provider): Promise { - return this.request(`/llms/${provider}`); + /** + * Lists the currently available models. + */ + async listModels(provider?: Provider): Promise { + const query: Record = {}; + if (provider) { + query.provider = provider; + } + return this.request( + '/models', + { method: 'GET' }, + query + ); } - async generateContent( - params: GenerateContentRequest - ): Promise { - return this.request( - `/llms/${params.provider}/generate`, + /** + * Creates a chat completion. + */ + async createChatCompletion( + request: ChatCompletionRequest, + provider?: Provider + ): Promise { + const query: Record = {}; + if (provider) { + query.provider = provider; + } + return this.request( + '/chat/completions', { method: 'POST', - body: JSON.stringify({ - model: params.model, - messages: params.messages, - }), - } + body: JSON.stringify(request), + }, + query ); } - async generateContentStream( - params: GenerateContentRequest, - options?: GenerateContentOptions + /** + * Creates a streaming chat completion. + */ + async streamChatCompletion( + request: ChatCompletionRequest, + callbacks: ChatCompletionStreamCallbacks, + provider?: Provider ): Promise { - const response = await fetch( - `${this.baseUrl}/llms/${params.provider}/generate`, - { + const query: Record = {}; + if (provider) { + query.provider = provider; + } + + const queryParams = new URLSearchParams({ + ...this.defaultQuery, + ...query, + }); + + const queryString = queryParams.toString(); + const url = `${this.baseURL}/chat/completions${queryString ? 
`?${queryString}` : ''}`; + + const headers = new Headers({ + 'Content-Type': 'application/json', + ...this.defaultHeaders, + }); + + if (this.apiKey) { + headers.set('Authorization', `Bearer ${this.apiKey}`); + } + + const controller = new AbortController(); + const timeoutId = globalThis.setTimeout( + () => controller.abort(), + this.timeout + ); + + try { + const response = await this.fetchFn(url, { method: 'POST', - headers: { - 'Content-Type': 'application/json', - ...(this.authToken && { Authorization: `Bearer ${this.authToken}` }), - }, + headers, body: JSON.stringify({ - model: params.model, - messages: params.messages, + ...request, stream: true, - ssevents: true, }), + signal: controller.signal, + }); + + if (!response.ok) { + const error = (await response.json()) as ApiError; + throw new Error( + error.error || `HTTP error! status: ${response.status}` + ); } - ); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.error || `HTTP error! status: ${response.status}`); - } + if (!response.body) { + throw new Error('Response body is not readable'); + } - const reader = response.body?.getReader(); - if (!reader) throw new Error('Response body is not readable'); - - const decoder = new TextDecoder(); - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - const events = decoder.decode(value).split('\n\n'); - for (const event of events) { - if (!event.trim()) continue; - - const [eventType, ...data] = event.split('\n'); - const eventData = JSON.parse(data.join('\n').replace('data: ', '')); - - switch (eventType.replace('event: ', '')) { - case 'message-start': - options?.onMessageStart?.(eventData.role); - break; - case 'stream-start': - options?.onStreamStart?.(); - break; - case 'content-start': - options?.onContentStart?.(); - break; - case 'content-delta': - options?.onContentDelta?.(eventData.content); - break; - case 'content-end': - options?.onContentEnd?.(); - break; - case 'message-end': - options?.onMessageEnd?.(); - break; - case 'stream-end': - options?.onStreamEnd?.(); - break; + callbacks.onOpen?.(); + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split('\n'); + buffer = lines.pop() || ''; + + for (const line of lines) { + if (line.startsWith('data: ')) { + const data = line.slice(5).trim(); + + if (data === '[DONE]') { + callbacks.onFinish?.( + null as unknown as ChatCompletionStreamResponse + ); + return; + } + + try { + const chunk = JSON.parse(data) as ChatCompletionStreamResponse; + callbacks.onChunk?.(chunk); + + const content = chunk.choices[0]?.delta?.content; + if (content) { + callbacks.onContent?.(content); + } + + const toolCalls = chunk.choices[0]?.delta?.tool_calls; + if (toolCalls && toolCalls.length > 0) { + const toolCall: ChatCompletionMessageToolCall = { + id: toolCalls[0].id || '', + type: 'function', + function: { + name: toolCalls[0].function?.name || '', + arguments: toolCalls[0].function?.arguments || '', + }, + }; + callbacks.onTool?.(toolCall); + } + } catch (e) { + globalThis.console.error('Error parsing SSE data:', e); + } + } } } + } catch (error) { + const apiError: ApiError = { + error: (error as Error).message || 'Unknown error', + }; + callbacks.onError?.(apiError); + throw error; + } finally { + globalThis.clearTimeout(timeoutId); } } + /** + * Proxy a request to 
a specific provider. + */ + async proxy( + provider: Provider, + path: string, + options: RequestInit = {} + ): Promise { + return this.request(`/proxy/${provider}/${path}`, options); + } + + /** + * Health check endpoint. + */ async healthCheck(): Promise { try { - await this.request('/health'); + await this.fetchFn(`${this.baseURL.replace('/v1', '')}/health`); return true; } catch { return false; diff --git a/src/types/index.ts b/src/types/index.ts index 4222ed3..bbac205 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -2,52 +2,158 @@ export enum Provider { Ollama = 'ollama', Groq = 'groq', OpenAI = 'openai', - Google = 'google', Cloudflare = 'cloudflare', Cohere = 'cohere', Anthropic = 'anthropic', + DeepSeek = 'deepseek', } export enum MessageRole { System = 'system', User = 'user', Assistant = 'assistant', + Tool = 'tool', } + export interface Message { role: MessageRole; content: string; + tool_calls?: ChatCompletionMessageToolCall[]; + tool_call_id?: string; } export interface Model { + id: string; + object: string; + created: number; + owned_by: string; +} + +export interface ListModelsResponse { + object: string; + data: Model[]; +} + +export interface ChatCompletionMessageToolCallFunction { + name: string; + arguments: string; +} + +export interface ChatCompletionMessageToolCall { + id: string; + type: 'function'; + function: ChatCompletionMessageToolCallFunction; +} + +export interface ChatCompletionMessageToolCallChunk { + index: number; + id?: string; + type?: string; + function?: { + name?: string; + arguments?: string; + }; +} + +export interface FunctionParameters { + type: string; + properties?: Record; + required?: string[]; +} + +export interface FunctionObject { + description?: string; name: string; + parameters: FunctionParameters; + strict?: boolean; } -export interface ProviderModels { - provider: Provider; - models: Model[]; +export interface ChatCompletionTool { + type: 'function'; + function: FunctionObject; } -export interface GenerateContentRequest { - provider: Provider; +export interface ChatCompletionRequest { model: string; messages: Message[]; + max_tokens?: number; + stream?: boolean; + stream_options?: ChatCompletionStreamOptions; + tools?: ChatCompletionTool[]; + temperature?: number; + top_p?: number; + top_k?: number; } -export interface GenerateContentResponse { - provider: string; - response: { - role: 'assistant'; - model: string; - content: string; - }; +export interface ChatCompletionStreamOptions { + include_usage?: boolean; +} + +export interface ChatCompletionChoice { + finish_reason: + | 'stop' + | 'length' + | 'tool_calls' + | 'content_filter' + | 'function_call'; + index: number; + message: Message; + logprobs?: Record; +} + +export interface CompletionUsage { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; +} + +export interface ChatCompletionResponse { + id: string; + choices: ChatCompletionChoice[]; + created: number; + model: string; + object: string; + usage?: CompletionUsage; +} + +export interface ChatCompletionStreamChoice { + delta: ChatCompletionStreamResponseDelta; + finish_reason: + | 'stop' + | 'length' + | 'tool_calls' + | 'content_filter' + | 'function_call' + | null; + index: number; + logprobs?: Record; +} + +export interface ChatCompletionStreamResponseDelta { + content?: string; + tool_calls?: ChatCompletionMessageToolCallChunk[]; + role?: MessageRole; + refusal?: string; +} + +export interface ChatCompletionStreamResponse { + id: string; + choices: ChatCompletionStreamChoice[]; + 
created: number; + model: string; + object: string; + usage?: CompletionUsage; +} + +export interface ChatCompletionStreamCallbacks { + onOpen?: () => void; + onChunk?: (chunk: ChatCompletionStreamResponse) => void; + onContent?: (content: string) => void; + onTool?: (toolCall: ChatCompletionMessageToolCall) => void; + onFinish?: (response: ChatCompletionStreamResponse) => void; + onError?: (error: Error) => void; } -export interface GenerateContentOptions { - onMessageStart?: (role: string) => void; - onStreamStart?: () => void; - onContentStart?: () => void; - onContentDelta?: (content: string) => void; - onContentEnd?: () => void; - onMessageEnd?: () => void; - onStreamEnd?: () => void; +export interface Error { + error: string; } diff --git a/tests/client.test.ts b/tests/client.test.ts index ca8369d..987e165 100644 --- a/tests/client.test.ts +++ b/tests/client.test.ts @@ -1,308 +1,432 @@ import { InferenceGatewayClient } from '@/client'; import { - GenerateContentResponse, + ChatCompletionResponse, + ListModelsResponse, MessageRole, Provider, - ProviderModels, } from '@/types'; +import { TransformStream } from 'node:stream/web'; +import { TextEncoder } from 'node:util'; describe('InferenceGatewayClient', () => { let client: InferenceGatewayClient; - const mockBaseUrl = 'http://localhost:8080'; + const mockFetch = jest.fn(); beforeEach(() => { - client = new InferenceGatewayClient(mockBaseUrl); - global.fetch = jest.fn(); + client = new InferenceGatewayClient({ + baseURL: 'http://localhost:8080/v1', + fetch: mockFetch, + }); + }); + + afterEach(() => { + jest.clearAllMocks(); }); describe('listModels', () => { it('should fetch available models', async () => { - const mockResponse: ProviderModels[] = [ - { - provider: Provider.Ollama, - models: [ - { - name: 'llama2', - }, - ], - }, - ]; + const mockResponse: ListModelsResponse = { + object: 'list', + data: [ + { + id: 'gpt-4o', + object: 'model', + created: 1686935002, + owned_by: 'openai', + }, + { + id: 'llama-3.3-70b-versatile', + object: 'model', + created: 1723651281, + owned_by: 'groq', + }, + ], + }; - (global.fetch as jest.Mock).mockResolvedValueOnce({ + mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve(mockResponse), }); const result = await client.listModels(); expect(result).toEqual(mockResponse); - expect(global.fetch).toHaveBeenCalledWith( - `${mockBaseUrl}/llms`, + expect(mockFetch).toHaveBeenCalledWith( + 'http://localhost:8080/v1/models', expect.objectContaining({ + method: 'GET', headers: expect.any(Headers), }) ); }); - }); - describe('listModelsByProvider', () => { it('should fetch models for a specific provider', async () => { - const mockResponse: ProviderModels = { - provider: Provider.OpenAI, - models: [ + const mockResponse: ListModelsResponse = { + object: 'list', + data: [ { - name: 'gpt-4', + id: 'gpt-4o', + object: 'model', + created: 1686935002, + owned_by: 'openai', }, ], }; - (global.fetch as jest.Mock).mockResolvedValueOnce({ + mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve(mockResponse), }); - const result = await client.listModelsByProvider(Provider.OpenAI); + const result = await client.listModels(Provider.OpenAI); expect(result).toEqual(mockResponse); - expect(global.fetch).toHaveBeenCalledWith( - `${mockBaseUrl}/llms/${Provider.OpenAI}`, + expect(mockFetch).toHaveBeenCalledWith( + 'http://localhost:8080/v1/models?provider=openai', expect.objectContaining({ + method: 'GET', headers: expect.any(Headers), }) ); }); - it('should throw error when provider 
request fails', async () => { + it('should throw error when request fails', async () => { const errorMessage = 'Provider not found'; - (global.fetch as jest.Mock).mockResolvedValueOnce({ + mockFetch.mockResolvedValueOnce({ ok: false, status: 404, json: () => Promise.resolve({ error: errorMessage }), }); - await expect( - client.listModelsByProvider(Provider.OpenAI) - ).rejects.toThrow(errorMessage); + await expect(client.listModels(Provider.OpenAI)).rejects.toThrow( + errorMessage + ); }); }); - describe('generateContent', () => { - it('should generate content with the specified provider', async () => { + describe('createChatCompletion', () => { + it('should create a chat completion', async () => { const mockRequest = { - provider: Provider.Ollama, - model: 'llama2', + model: 'gpt-4o', messages: [ { role: MessageRole.System, content: 'You are a helpful assistant' }, { role: MessageRole.User, content: 'Hello' }, ], }; - const mockResponse: GenerateContentResponse = { - provider: Provider.Ollama, - response: { - role: MessageRole.Assistant, - model: 'llama2', - content: 'Hi there!', + const mockResponse: ChatCompletionResponse = { + id: 'chatcmpl-123', + object: 'chat.completion', + created: 1677652288, + model: 'gpt-4o', + choices: [ + { + index: 0, + message: { + role: MessageRole.Assistant, + content: 'Hello! How can I help you today?', + }, + finish_reason: 'stop', + }, + ], + usage: { + prompt_tokens: 10, + completion_tokens: 8, + total_tokens: 18, }, }; - (global.fetch as jest.Mock).mockResolvedValueOnce({ + mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve(mockResponse), }); - const result = await client.generateContent(mockRequest); + const result = await client.createChatCompletion(mockRequest); expect(result).toEqual(mockResponse); - expect(global.fetch).toHaveBeenCalledWith( - `${mockBaseUrl}/llms/${mockRequest.provider}/generate`, + expect(mockFetch).toHaveBeenCalledWith( + 'http://localhost:8080/v1/chat/completions', expect.objectContaining({ method: 'POST', - body: JSON.stringify({ - model: mockRequest.model, - messages: mockRequest.messages, - }), + body: JSON.stringify(mockRequest), }) ); }); - }); - - describe('healthCheck', () => { - it('should return true when API is healthy', async () => { - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - json: () => Promise.resolve({}), - }); - const result = await client.healthCheck(); - expect(result).toBe(true); - expect(global.fetch).toHaveBeenCalledWith( - `${mockBaseUrl}/health`, - expect.any(Object) - ); - }); - - it('should return false when API is unhealthy', async () => { - (global.fetch as jest.Mock).mockRejectedValueOnce(new Error('API error')); + it('should create a chat completion with a specific provider', async () => { + const mockRequest = { + model: 'claude-3-opus-20240229', + messages: [{ role: MessageRole.User, content: 'Hello' }], + }; - const result = await client.healthCheck(); - expect(result).toBe(false); - }); - }); + const mockResponse: ChatCompletionResponse = { + id: 'chatcmpl-456', + object: 'chat.completion', + created: 1677652288, + model: 'claude-3-opus-20240229', + choices: [ + { + index: 0, + message: { + role: MessageRole.Assistant, + content: 'Hello! 
How can I assist you today?', + }, + finish_reason: 'stop', + }, + ], + usage: { + prompt_tokens: 5, + completion_tokens: 8, + total_tokens: 13, + }, + }; - describe('error handling', () => { - it('should throw error when API request fails', async () => { - const errorMessage = 'Bad Request'; - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: false, - status: 400, - json: () => Promise.resolve({ error: errorMessage }), + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockResponse), }); - await expect(client.listModels()).rejects.toThrow(errorMessage); + const result = await client.createChatCompletion( + mockRequest, + Provider.Anthropic + ); + expect(result).toEqual(mockResponse); + expect(mockFetch).toHaveBeenCalledWith( + 'http://localhost:8080/v1/chat/completions?provider=anthropic', + expect.objectContaining({ + method: 'POST', + body: JSON.stringify(mockRequest), + }) + ); }); }); - describe('generateContentStream', () => { - it('should handle SSE events correctly', async () => { + describe('streamChatCompletion', () => { + it('should handle streaming chat completions', async () => { const mockRequest = { - provider: Provider.Ollama, - model: 'llama2', - messages: [ - { role: MessageRole.System, content: 'You are a helpful assistant' }, - { role: MessageRole.User, content: 'Hello' }, - ], + model: 'gpt-4o', + messages: [{ role: MessageRole.User, content: 'Hello' }], }; const mockStream = new TransformStream(); const writer = mockStream.writable.getWriter(); const encoder = new TextEncoder(); - (global.fetch as jest.Mock).mockResolvedValueOnce({ + mockFetch.mockResolvedValueOnce({ ok: true, body: mockStream.readable, }); const callbacks = { - onMessageStart: jest.fn(), - onStreamStart: jest.fn(), - onContentStart: jest.fn(), - onContentDelta: jest.fn(), - onContentEnd: jest.fn(), - onMessageEnd: jest.fn(), - onStreamEnd: jest.fn(), + onOpen: jest.fn(), + onChunk: jest.fn(), + onContent: jest.fn(), + onFinish: jest.fn(), + onError: jest.fn(), }; - const streamPromise = client.generateContentStream( - mockRequest, - callbacks - ); + const streamPromise = client.streamChatCompletion(mockRequest, callbacks); + // Simulate SSE events await writer.write( encoder.encode( - 'event: message-start\ndata: {"role": "assistant"}\n\n' + - 'event: stream-start\ndata: {}\n\n' + - 'event: content-start\ndata: {}\n\n' + - 'event: content-delta\ndata: {"content": "Hello"}\n\n' + - 'event: content-delta\ndata: {"content": " there!"}\n\n' + - 'event: content-end\ndata: {}\n\n' + - 'event: message-end\ndata: {}\n\n' + - 'event: stream-end\ndata: {}\n\n' + 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' + + 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}\n\n' + + 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}\n\n' + + 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}\n\n' + + 'data: [DONE]\n\n' ) ); await writer.close(); await streamPromise; - expect(callbacks.onMessageStart).toHaveBeenCalledWith('assistant'); - expect(callbacks.onStreamStart).toHaveBeenCalledTimes(1); - 
expect(callbacks.onContentStart).toHaveBeenCalledTimes(1); - expect(callbacks.onContentDelta).toHaveBeenCalledWith('Hello'); - expect(callbacks.onContentDelta).toHaveBeenCalledWith(' there!'); - expect(callbacks.onContentEnd).toHaveBeenCalledTimes(1); - expect(callbacks.onMessageEnd).toHaveBeenCalledTimes(1); - expect(callbacks.onStreamEnd).toHaveBeenCalledTimes(1); - expect(global.fetch).toHaveBeenCalledWith( - `${mockBaseUrl}/llms/${mockRequest.provider}/generate`, + expect(callbacks.onOpen).toHaveBeenCalledTimes(1); + expect(callbacks.onChunk).toHaveBeenCalledTimes(4); + expect(callbacks.onContent).toHaveBeenCalledWith('Hello'); + expect(callbacks.onContent).toHaveBeenCalledWith('!'); + expect(callbacks.onFinish).toHaveBeenCalledTimes(1); + expect(mockFetch).toHaveBeenCalledWith( + 'http://localhost:8080/v1/chat/completions', expect.objectContaining({ method: 'POST', body: JSON.stringify({ - model: mockRequest.model, - messages: mockRequest.messages, + ...mockRequest, stream: true, - ssevents: true, }), }) ); }); - it('should handle errors in the stream response', async () => { + it('should handle tool calls in streaming chat completions', async () => { + const mockRequest = { + model: 'gpt-4o', + messages: [ + { + role: MessageRole.User, + content: 'What is the weather in San Francisco?', + }, + ], + tools: [ + { + type: 'function' as const, + function: { + name: 'get_weather', + parameters: { + type: 'object', + properties: { + location: { + type: 'string', + description: 'The city and state, e.g. San Francisco, CA', + }, + }, + required: ['location'], + }, + }, + }, + ], + }; + + const mockStream = new TransformStream(); + const writer = mockStream.writable.getWriter(); + const encoder = new TextEncoder(); + + mockFetch.mockResolvedValueOnce({ + ok: true, + body: mockStream.readable, + }); + + const callbacks = { + onOpen: jest.fn(), + onChunk: jest.fn(), + onTool: jest.fn(), + onFinish: jest.fn(), + }; + + const streamPromise = client.streamChatCompletion(mockRequest, callbacks); + + // Simulate SSE events with tool calls + await writer.write( + encoder.encode( + 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' + + 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_123","type":"function","function":{"name":"get_weather"}}]},"finish_reason":null}]}\n\n' + + 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"location\\""}}]},"finish_reason":null}]}\n\n' + + 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":":\\"San Francisco, CA\\""}}]},"finish_reason":null}]}\n\n' + + 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"finish_reason":null}]}\n\n' + + 'data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1677652288,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}\n\n' + + 'data: [DONE]\n\n' + ) + ); + + await writer.close(); + await streamPromise; + + expect(callbacks.onOpen).toHaveBeenCalledTimes(1); + 
expect(callbacks.onChunk).toHaveBeenCalledTimes(6); + expect(callbacks.onTool).toHaveBeenCalledTimes(4); // Called for each chunk with tool_calls + expect(callbacks.onFinish).toHaveBeenCalledTimes(1); + }); + + it('should handle errors in streaming chat completions', async () => { const mockRequest = { - provider: Provider.Ollama, - model: 'llama2', + model: 'gpt-4o', messages: [{ role: MessageRole.User, content: 'Hello' }], }; - (global.fetch as jest.Mock).mockResolvedValueOnce({ + mockFetch.mockResolvedValueOnce({ ok: false, status: 400, json: () => Promise.resolve({ error: 'Bad Request' }), }); + const callbacks = { + onError: jest.fn(), + }; + await expect( - client.generateContentStream(mockRequest, {}) + client.streamChatCompletion(mockRequest, callbacks) ).rejects.toThrow('Bad Request'); + + expect(callbacks.onError).toHaveBeenCalledTimes(1); }); + }); - it('should handle non-readable response body', async () => { - const mockRequest = { - provider: Provider.Ollama, - model: 'llama2', - messages: [{ role: MessageRole.User, content: 'Hello' }], - }; + describe('proxy', () => { + it('should proxy requests to a specific provider', async () => { + const mockResponse = { result: 'success' }; - (global.fetch as jest.Mock).mockResolvedValueOnce({ + mockFetch.mockResolvedValueOnce({ ok: true, - body: null, + json: () => Promise.resolve(mockResponse), }); - await expect( - client.generateContentStream(mockRequest, {}) - ).rejects.toThrow('Response body is not readable'); + const result = await client.proxy(Provider.OpenAI, 'embeddings', { + method: 'POST', + body: JSON.stringify({ + model: 'text-embedding-ada-002', + input: 'Hello world', + }), + }); + + expect(result).toEqual(mockResponse); + expect(mockFetch).toHaveBeenCalledWith( + 'http://localhost:8080/v1/proxy/openai/embeddings', + expect.objectContaining({ + method: 'POST', + body: JSON.stringify({ + model: 'text-embedding-ada-002', + input: 'Hello world', + }), + }) + ); }); + }); - it('should handle empty events in the stream', async () => { - const mockRequest = { - provider: Provider.Ollama, - model: 'llama2', - messages: [{ role: MessageRole.User, content: 'Hello' }], - }; + describe('healthCheck', () => { + it('should return true when API is healthy', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + }); - const mockStream = new TransformStream(); - const writer = mockStream.writable.getWriter(); - const encoder = new TextEncoder(); + const result = await client.healthCheck(); + expect(result).toBe(true); + expect(mockFetch).toHaveBeenCalledWith('http://localhost:8080/health'); + }); - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - body: mockStream.readable, + it('should return false when API is unhealthy', async () => { + mockFetch.mockRejectedValueOnce(new Error('API error')); + + const result = await client.healthCheck(); + expect(result).toBe(false); + }); + }); + + describe('withOptions', () => { + it('should create a new client with merged options', () => { + const originalClient = new InferenceGatewayClient({ + baseURL: 'http://localhost:8080/v1', + apiKey: 'test-key', + fetch: mockFetch, }); - const callbacks = { - onContentDelta: jest.fn(), - }; + const newClient = originalClient.withOptions({ + defaultHeaders: { 'X-Custom-Header': 'value' }, + }); - const streamPromise = client.generateContentStream( - mockRequest, - callbacks - ); + expect(newClient).toBeInstanceOf(InferenceGatewayClient); + expect(newClient).not.toBe(originalClient); - await writer.write(encoder.encode('\n\n')); - 
await writer.write( - encoder.encode('event: content-delta\ndata: {"content": "Hello"}\n\n') - ); + // We can't directly test private properties, but we can test behavior + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({}), + }); - await writer.close(); - await streamPromise; + newClient.listModels(); - expect(callbacks.onContentDelta).toHaveBeenCalledTimes(1); - expect(callbacks.onContentDelta).toHaveBeenCalledWith('Hello'); + expect(mockFetch).toHaveBeenCalledWith( + 'http://localhost:8080/v1/models', + expect.objectContaining({ + headers: expect.any(Headers), + }) + ); }); }); }); From 650fb915740f827b62d435b090dd7cf7e589908c Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Mon, 31 Mar 2025 00:26:58 +0000 Subject: [PATCH 2/3] build: add GitHub CLI installation to Dockerfile Signed-off-by: Eden Reich --- .devcontainer/Dockerfile | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index da0a9f2..58ebe82 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -6,6 +6,13 @@ ENV ZSH_CUSTOM=/home/node/.oh-my-zsh/custom \ RUN apt-get update && \ # Install Task curl -s https://taskfile.dev/install.sh | sh -s -- -b /usr/local/bin ${TASK_VERSION} && \ + # Install GitHub CLI + curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg | dd of=/usr/share/keyrings/githubcli-archive-keyring.gpg && \ + chmod go+r /usr/share/keyrings/githubcli-archive-keyring.gpg && \ + echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null && \ + apt-get update && \ + apt-get install -y gh && \ + # Cleanup apt-get clean && \ rm -rf /var/lib/apt/lists/* From b31058bd65afa4b6a20cfd41a2889830e4a9fc27 Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Mon, 31 Mar 2025 00:31:29 +0000 Subject: [PATCH 3/3] chore: add deepseek keyword to package.json Signed-off-by: Eden Reich --- package.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/package.json b/package.json index 4e3dddf..334745b 100644 --- a/package.json +++ b/package.json @@ -18,7 +18,8 @@ "ollama", "cloudflare", "cohere", - "typescript" + "typescript", + "deepseek" ], "author": "Eden Reich ", "license": "MIT",