diff --git a/Taskfile.yml b/Taskfile.yml
index 25f5b4a..ed323e6 100644
--- a/Taskfile.yml
+++ b/Taskfile.yml
@@ -4,7 +4,7 @@ tasks:
   oas-download:
     desc: Download OpenAPI specification
     cmds:
-      - curl -o openapi.yaml https://raw.githubusercontent.com/inference-gateway/inference-gateway/refs/heads/main/openapi.yaml
+      - curl -o openapi.yaml https://raw.githubusercontent.com/inference-gateway/schemas/refs/heads/main/openapi.yaml
 
   lint:
     desc: Run linter
diff --git a/openapi.yaml b/openapi.yaml
index 317a74c..0efb48b 100644
--- a/openapi.yaml
+++ b/openapi.yaml
@@ -31,8 +31,6 @@ tags:
     description: Generate completions from the models.
   - name: MCP
     description: List and manage MCP tools.
-  - name: A2A
-    description: List and manage A2A agents.
   - name: Proxy
     description: Proxy requests to provider endpoints.
   - name: Health
@@ -96,6 +94,16 @@ paths:
                       created: 1718441600
                       owned_by: 'ollama'
                       served_by: 'ollama'
+                    - id: 'ollama_cloud/gpt-oss:20b'
+                      object: 'model'
+                      created: 1730419200
+                      owned_by: 'ollama_cloud'
+                      served_by: 'ollama_cloud'
+                    - id: 'mistral/mistral-large-latest'
+                      object: 'model'
+                      created: 1698019200
+                      owned_by: 'mistral'
+                      served_by: 'mistral'
                 singleProvider:
                   summary: Models from a specific provider
                   value:
@@ -179,65 +187,6 @@ paths:
           $ref: '#/components/responses/MCPNotExposed'
         '500':
           $ref: '#/components/responses/InternalError'
-  /a2a/agents:
-    get:
-      operationId: listAgents
-      tags:
-        - A2A
-      description: |
-        Lists the currently available A2A agents. Only accessible when EXPOSE_A2A is enabled.
-      summary: Lists the currently available A2A agents
-      security:
-        - bearerAuth: []
-      responses:
-        '200':
-          description: Successful response
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/ListAgentsResponse'
-        '401':
-          $ref: '#/components/responses/Unauthorized'
-        '403':
-          $ref: '#/components/responses/A2ANotExposed'
-        '500':
-          $ref: '#/components/responses/InternalError'
-  /a2a/agents/{id}:
-    get:
-      operationId: getAgent
-      tags:
-        - A2A
-      description: |
-        Gets a specific A2A agent by its unique identifier. Only accessible when EXPOSE_A2A is enabled.
-      summary: Gets a specific A2A agent by ID
-      security:
-        - bearerAuth: []
-      parameters:
-        - name: id
-          in: path
-          required: true
-          schema:
-            type: string
-          description: The unique identifier of the agent
-      responses:
-        '200':
-          description: Successful response
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/A2AAgentCard'
-        '401':
-          $ref: '#/components/responses/Unauthorized'
-        '403':
-          $ref: '#/components/responses/A2ANotExposed'
-        '404':
-          description: Agent not found
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/Error'
-        '500':
-          $ref: '#/components/responses/InternalError'
   /proxy/{provider}/{path}:
     parameters:
       - name: provider
@@ -415,6 +364,14 @@ components:
             - role: 'user'
               content: 'Explain quantum computing'
           temperature: 0.5
+        mistral:
+          summary: Mistral AI request
+          value:
+            model: 'mistral-large-latest'
+            messages:
+              - role: 'user'
+                content: 'Write a Python function to calculate fibonacci numbers'
+            temperature: 0.3
     CreateChatCompletionRequest:
       required: true
      description: |
@@ -451,14 +408,6 @@ components:
            $ref: '#/components/schemas/Error'
          example:
            error: 'MCP tools endpoint is not exposed. Set EXPOSE_MCP=true to enable.'
-    A2ANotExposed:
-      description: A2A agents endpoint is not exposed
-      content:
-        application/json:
-          schema:
-            $ref: '#/components/schemas/Error'
-          example:
-            error: 'A2A agents endpoint is not exposed. Set EXPOSE_A2A=true to enable.'
     ProviderResponse:
       description: |
         ProviderResponse depends on the specific provider and endpoint being called
@@ -489,6 +438,27 @@ components:
                 },
               ],
             }
+        mistral:
+          summary: Mistral AI response
+          value:
+            {
+              'id': 'cmpl-123',
+              'object': 'chat.completion',
+              'created': 1677652288,
+              'model': 'mistral-large-latest',
+              'choices':
+                [
+                  {
+                    'index': 0,
+                    'message':
+                      {
+                        'role': 'assistant',
+                        'content': 'def fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)',
+                      },
+                    'finish_reason': 'stop',
+                  },
+                ],
+            }
   securitySchemes:
     bearerAuth:
       type: http
@@ -503,6 +473,7 @@
       type: string
       enum:
        - ollama
+        - ollama_cloud
        - groq
        - openai
        - cloudflare
@@ -510,11 +481,27 @@
        - cohere
        - anthropic
        - deepseek
        - google
+        - mistral
    x-provider-configs:
      ollama:
        id: 'ollama'
        url: 'http://ollama:8080/v1'
        auth_type: 'none'
+        supports_vision: true
+        endpoints:
+          models:
+            name: 'list_models'
+            method: 'GET'
+            endpoint: '/models'
+          chat:
+            name: 'chat_completions'
+            method: 'POST'
+            endpoint: '/chat/completions'
+      ollama_cloud:
+        id: 'ollama_cloud'
+        url: 'https://ollama.com/v1'
+        auth_type: 'bearer'
+        supports_vision: true
        endpoints:
          models:
            name: 'list_models'
@@ -528,6 +515,7 @@
        id: 'anthropic'
        url: 'https://api.anthropic.com/v1'
        auth_type: 'xheader'
+        supports_vision: true
        endpoints:
          models:
            name: 'list_models'
@@ -541,6 +529,7 @@
        id: 'cohere'
        url: 'https://api.cohere.ai'
        auth_type: 'bearer'
+        supports_vision: true
        endpoints:
          models:
            name: 'list_models'
@@ -554,6 +543,7 @@
        id: 'groq'
        url: 'https://api.groq.com/openai/v1'
        auth_type: 'bearer'
+        supports_vision: true
        endpoints:
          models:
            name: 'list_models'
@@ -567,6 +557,7 @@
        id: 'openai'
        url: 'https://api.openai.com/v1'
        auth_type: 'bearer'
+        supports_vision: true
        endpoints:
          models:
            name: 'list_models'
@@ -580,6 +571,7 @@
        id: 'cloudflare'
        url: 'https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai'
        auth_type: 'bearer'
+        supports_vision: false
        endpoints:
          models:
            name: 'list_models'
@@ -593,6 +585,7 @@
        id: 'deepseek'
        url: 'https://api.deepseek.com'
        auth_type: 'bearer'
+        supports_vision: false
        endpoints:
          models:
            name: 'list_models'
@@ -606,6 +599,21 @@
        id: 'google'
        url: 'https://generativelanguage.googleapis.com/v1beta/openai'
        auth_type: 'bearer'
+        supports_vision: true
+        endpoints:
+          models:
+            name: 'list_models'
+            method: 'GET'
+            endpoint: '/models'
+          chat:
+            name: 'chat_completions'
+            method: 'POST'
+            endpoint: '/chat/completions'
+      mistral:
+        id: 'mistral'
+        url: 'https://api.mistral.ai/v1'
+        auth_type: 'bearer'
+        supports_vision: true
        endpoints:
          models:
            name: 'list_models'
@@ -709,7 +717,13 @@ components:
         role:
           $ref: '#/components/schemas/MessageRole'
         content:
-          type: string
+          oneOf:
+            - type: string
+              description: Text content (backward compatibility)
+            - type: array
+              items:
+                $ref: '#/components/schemas/ContentPart'
+              description: Array of content parts for multimodal messages
         tool_calls:
           type: array
           items:
@@ -725,6 +739,53 @@ components:
       required:
         - role
         - content
+    ContentPart:
+      type: object
+      description: A content part within a multimodal message
+      oneOf:
+        - $ref: '#/components/schemas/TextContentPart'
+        - $ref: '#/components/schemas/ImageContentPart'
+    TextContentPart:
+      type: object
+      description: Text content part
+      properties:
+        type:
+          type: string
+          enum: [text]
+          description: Content type identifier
+        text:
+          type: string
+          description: The text content
+      required:
+        - type
+        - text
+    ImageContentPart:
+      type: object
+      description: Image content part
+      properties:
+        type:
+          type: string
+          enum: [image_url]
+          description: Content type identifier
+        image_url:
+          $ref: '#/components/schemas/ImageURL'
+      required:
+        - type
+        - image_url
+    ImageURL:
+      type: object
+      description: Image URL configuration
+      properties:
+        url:
+          type: string
+          description: URL of the image (data URLs supported)
+        detail:
+          type: string
+          enum: [auto, low, high]
+          default: auto
+          description: Image detail level for vision processing
+      required:
+        - url
     Model:
       type: object
       description: Common model information
@@ -779,103 +840,6 @@ components:
       required:
         - object
         - data
-    ListAgentsResponse:
-      type: object
-      description: Response structure for listing A2A agents
-      properties:
-        object:
-          type: string
-          description: Always "list"
-          example: 'list'
-        data:
-          type: array
-          items:
-            $ref: '#/components/schemas/A2AAgentCard'
-          default: []
-          description: Array of available A2A agents
-      required:
-        - object
-        - data
-    A2AAgentCard:
-      description: |-
-        An AgentCard conveys key information:
-        - Overall details (version, name, description, uses)
-        - Skills: A set of capabilities the agent can perform
-        - Default modalities/content types supported by the agent.
-        - Authentication requirements
-      properties:
-        capabilities:
-          additionalProperties: true
-          description: Optional capabilities supported by the agent.
-        defaultInputModes:
-          description: |-
-            The set of interaction modes that the agent supports across all skills. This can be overridden per-skill.
-            Supported media types for input.
-          items:
-            type: string
-          type: array
-        defaultOutputModes:
-          description: Supported media types for output.
-          items:
-            type: string
-          type: array
-        description:
-          description: |-
-            A human-readable description of the agent. Used to assist users and
-            other agents in understanding what the agent can do.
-          type: string
-        documentationUrl:
-          description: A URL to documentation for the agent.
-          type: string
-        iconUrl:
-          description: A URL to an icon for the agent.
-          type: string
-        id:
-          description: Unique identifier for the agent (base64-encoded SHA256 hash of the agent URL).
-          type: string
-        name:
-          description: Human readable name of the agent.
-          type: string
-        provider:
-          additionalProperties: true
-          description: The service provider of the agent
-        security:
-          description: Security requirements for contacting the agent.
-          items:
-            additionalProperties: true
-            type: object
-          type: array
-        securitySchemes:
-          additionalProperties: true
-          description: Security scheme details used for authenticating with this agent.
-          type: object
-        skills:
-          description: Skills are a unit of capability that an agent can perform.
-          items:
-            additionalProperties: true
-          type: array
-        supportsAuthenticatedExtendedCard:
-          description: |-
-            true if the agent supports providing an extended agent card when the user is authenticated.
-            Defaults to false if not specified.
-          type: boolean
-        url:
-          description: A URL to the address the agent is hosted at.
-          type: string
-        version:
-          description: The version of the agent - format is up to the provider.
-          type: string
-      required:
-        - capabilities
-        - defaultInputModes
-        - defaultOutputModes
-        - description
-        - id
-        - name
-        - skills
-        - url
-        - version
-      type: object
     MCPTool:
       type: object
       description: An MCP tool definition
@@ -1352,6 +1316,21 @@ components:
          type: string
          default: ''
          description: 'Comma-separated list of models to allow. If empty, all models will be available'
+        - name: enable_vision
+          env: 'ENABLE_VISION'
+          type: bool
+          default: 'false'
+          description: 'Enable vision/multimodal support for all providers. When disabled, image inputs will be rejected even if the provider and model support vision'
+        - name: debug_content_truncate_words
+          env: 'DEBUG_CONTENT_TRUNCATE_WORDS'
+          type: int
+          default: '10'
+          description: 'Number of words to truncate per content section in debug logs (development mode only)'
+        - name: debug_max_messages
+          env: 'DEBUG_MAX_MESSAGES'
+          type: int
+          default: '100'
+          description: 'Maximum number of messages to show in debug logs (development mode only)'
    - telemetry:
        title: 'Telemetry'
        settings:
@@ -1457,78 +1436,6 @@ components:
          type: bool
          default: 'true'
          description: 'Disable health check log messages to reduce noise'
-    - a2a:
-        title: 'Agent-to-Agent (A2A) Protocol'
-        settings:
-          - name: a2a_enable
-            env: 'A2A_ENABLE'
-            type: bool
-            default: 'false'
-            description: 'Enable A2A protocol support'
-          - name: a2a_expose
-            env: 'A2A_EXPOSE'
-            type: bool
-            default: 'false'
-            description: 'Expose A2A agents list cards endpoint'
-          - name: a2a_agents
-            env: 'A2A_AGENTS'
-            type: string
-            description: 'Comma-separated list of A2A agent URLs'
-          - name: a2a_client_timeout
-            env: 'A2A_CLIENT_TIMEOUT'
-            type: time.Duration
-            default: '30s'
-            description: 'A2A client timeout'
-          - name: a2a_polling_enable
-            env: 'A2A_POLLING_ENABLE'
-            type: bool
-            default: 'true'
-            description: 'Enable task status polling'
-          - name: a2a_polling_interval
-            env: 'A2A_POLLING_INTERVAL'
-            type: time.Duration
-            default: '1s'
-            description: 'Interval between polling requests'
-          - name: a2a_polling_timeout
-            env: 'A2A_POLLING_TIMEOUT'
-            type: time.Duration
-            default: '30s'
-            description: 'Maximum time to wait for task completion'
-          - name: a2a_max_poll_attempts
-            env: 'A2A_MAX_POLL_ATTEMPTS'
-            type: int
-            default: '30'
-            description: 'Maximum number of polling attempts'
-          - name: a2a_max_retries
-            env: 'A2A_MAX_RETRIES'
-            type: int
-            default: '3'
-            description: 'Maximum number of connection retry attempts'
-          - name: a2a_retry_interval
-            env: 'A2A_RETRY_INTERVAL'
-            type: time.Duration
-            default: '5s'
-            description: 'Interval between connection retry attempts'
-          - name: a2a_initial_backoff
-            env: 'A2A_INITIAL_BACKOFF'
-            type: time.Duration
-            default: '1s'
-            description: 'Initial backoff duration for exponential backoff retry'
-          - name: a2a_enable_reconnect
-            env: 'A2A_ENABLE_RECONNECT'
-            type: bool
-            default: 'true'
-            description: 'Enable automatic reconnection for failed agents'
-          - name: a2a_reconnect_interval
-            env: 'A2A_RECONNECT_INTERVAL'
-            type: time.Duration
-            default: '30s'
-            description: 'Interval between reconnection attempts'
-          - name: a2a_disable_healthcheck_logs
-            env: 'A2A_DISABLE_HEALTHCHECK_LOGS'
-            type: bool
-            default: 'true'
-            description: 'Disable health check log messages to reduce noise'
    - auth:
        title: 'Authentication'
        settings:
@@ -1685,6 +1592,16 @@ components:
          type: string
          description: 'Ollama API Key'
          secret: true
+        - name: ollama_cloud_api_url
+          env: 'OLLAMA_CLOUD_API_URL'
+          type: string
+          default: 'https://ollama.com/v1'
+          description: 'Ollama Cloud API URL'
+        - name: ollama_cloud_api_key
+          env: 'OLLAMA_CLOUD_API_KEY'
+          type: string
+          description: 'Ollama Cloud API Key'
+          secret: true
        - name: openai_api_url
          env: 'OPENAI_API_URL'
          type: string
@@ -1715,3 +1632,13 @@ components:
          type: string
          description: 'Google API Key'
          secret: true
+        - name: mistral_api_url
+          env: 'MISTRAL_API_URL'
+          type: string
+          default: 'https://api.mistral.ai/v1'
+          description: 'Mistral API URL'
+        - name: mistral_api_key
+          env: 'MISTRAL_API_KEY'
+          type: string
+          description: 'Mistral API Key'
+          secret: true
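Editor's note: the sketch below is illustrative and not part of the patch. It shows the kind of request body the new `ContentPart`/`ImageURL` schema accepts, built with `serde_json` (already used by the SDK's tests). The model id and image URL are placeholders, and it assumes the gateway is running with `ENABLE_VISION=true`.

use serde_json::json;

/// Build a multimodal chat-completion body: `content` may now be either a
/// plain string (backward compatible) or an array of typed content parts.
fn multimodal_request_body() -> serde_json::Value {
    json!({
        "model": "openai/gpt-4o",
        "messages": [{
            "role": "user",
            "content": [
                { "type": "text", "text": "What is shown in this image?" },
                {
                    "type": "image_url",
                    "image_url": {
                        // Plain http(s) URLs and data: URLs are both accepted
                        "url": "https://example.com/cat.png",
                        "detail": "auto"
                    }
                }
            ]
        }]
    })
}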
diff --git a/src/lib.rs b/src/lib.rs
index 492ab7a..9f0dde3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -30,6 +30,9 @@ pub enum GatewayError {
     #[error("Forbidden: {0}")]
     Forbidden(String),
 
+    #[error("Not found: {0}")]
+    NotFound(String),
+
     #[error("Bad request: {0}")]
     BadRequest(String),
 
@@ -66,13 +69,13 @@ pub struct Model {
     /// The model identifier
     pub id: String,
     /// The object type, usually "model"
-    pub object: Option<String>,
+    pub object: String,
     /// The Unix timestamp (in seconds) of when the model was created
-    pub created: Option<i64>,
+    pub created: i64,
     /// The organization that owns the model
-    pub owned_by: Option<String>,
+    pub owned_by: String,
     /// The provider that serves the model
-    pub served_by: Option<Provider>,
+    pub served_by: Provider,
 }
 
 /// Response structure for listing models
@@ -116,6 +119,8 @@ pub struct ListToolsResponse {
 pub enum Provider {
     #[serde(alias = "Ollama", alias = "OLLAMA")]
     Ollama,
+    #[serde(alias = "OllamaCloud", alias = "OLLAMA_CLOUD", rename = "ollama_cloud")]
+    OllamaCloud,
     #[serde(alias = "Groq", alias = "GROQ")]
     Groq,
     #[serde(alias = "OpenAI", alias = "OPENAI")]
     OpenAI,
@@ -130,12 +135,15 @@ pub enum Provider {
     Deepseek,
     #[serde(alias = "Google", alias = "GOOGLE")]
     Google,
+    #[serde(alias = "Mistral", alias = "MISTRAL")]
+    Mistral,
 }
 
 impl fmt::Display for Provider {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
             Provider::Ollama => write!(f, "ollama"),
+            Provider::OllamaCloud => write!(f, "ollama_cloud"),
             Provider::Groq => write!(f, "groq"),
             Provider::OpenAI => write!(f, "openai"),
             Provider::Cloudflare => write!(f, "cloudflare"),
@@ -143,6 +151,7 @@
             Provider::Anthropic => write!(f, "anthropic"),
             Provider::Deepseek => write!(f, "deepseek"),
             Provider::Google => write!(f, "google"),
+            Provider::Mistral => write!(f, "mistral"),
         }
     }
 }
@@ -153,6 +162,7 @@
     fn try_from(s: &str) -> Result<Self, Self::Error> {
         match s.to_lowercase().as_str() {
             "ollama" => Ok(Self::Ollama),
+            "ollama_cloud" => Ok(Self::OllamaCloud),
             "groq" => Ok(Self::Groq),
             "openai" => Ok(Self::OpenAI),
             "cloudflare" => Ok(Self::Cloudflare),
@@ -160,6 +170,7 @@
             "anthropic" => Ok(Self::Anthropic),
             "deepseek" => Ok(Self::Deepseek),
             "google" => Ok(Self::Google),
+            "mistral" => Ok(Self::Mistral),
             _ => Err(GatewayError::BadRequest(format!("Unknown provider: {s}"))),
         }
     }
@@ -199,7 +210,10 @@ pub struct Message {
     /// Unique identifier of the tool call
     #[serde(skip_serializing_if = "Option::is_none")]
     pub tool_call_id: Option<String>,
-    /// Reasoning behind the message
+    /// The reasoning content of the message
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_content: Option<String>,
+    /// The reasoning of the message (same as reasoning_content)
     #[serde(skip_serializing_if = "Option::is_none")]
     pub reasoning: Option<String>,
 }
@@ -277,25 +291,50 @@ struct CreateChatCompletionRequest {
     /// Maximum number of tokens to generate
     #[serde(skip_serializing_if = "Option::is_none")]
     max_tokens: Option<i32>,
+    /// The format of the reasoning content. Can be `raw` or `parsed`.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    reasoning_format: Option<String>,
 }
 
-/// A tool call in the response
-#[derive(Debug, Deserialize, Clone)]
-pub struct ToolCallResponse {
+/// A tool call chunk in streaming responses
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct ChatCompletionMessageToolCallChunk {
+    /// Index of the tool call in the array
+    pub index: i32,
     /// Unique identifier of the tool call
-    pub id: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub id: Option<String>,
     /// Type of tool that was called
-    #[serde(rename = "type")]
-    pub r#type: ToolType,
+    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
+    pub r#type: Option<ToolType>,
     /// Function that the LLM wants to call
-    pub function: ChatCompletionMessageToolCallFunction,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub function: Option<ChatCompletionMessageToolCallFunction>,
+}
+
+/// The reason the model stopped generating tokens
+#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
+#[serde(rename_all = "snake_case")]
+pub enum FinishReason {
+    /// Model hit a natural stop point or a provided stop sequence
+    Stop,
+    /// Maximum number of tokens specified in the request was reached
+    Length,
+    /// Model called a tool
+    ToolCalls,
+    /// Content was omitted due to a flag from content filters
+    ContentFilter,
+    /// Function call (deprecated, use tool_calls)
+    FunctionCall,
 }
 
 #[derive(Debug, Deserialize, Clone)]
 pub struct ChatCompletionChoice {
-    pub finish_reason: String,
+    pub finish_reason: FinishReason,
     pub message: Message,
     pub index: i32,
+    /// Log probability information for the choice
+    pub logprobs: Option<ChoiceLogprobs>,
 }
 
 /// The response from generating content
@@ -327,6 +366,42 @@ pub struct CreateChatCompletionStreamResponse {
     /// Usage statistics for the completion request.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub usage: Option<CompletionUsage>,
+    /// The format of the reasoning content. Can be `raw` or `parsed`.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_format: Option<String>,
+}
+
+/// Token log probability information
+#[derive(Debug, Deserialize, Clone)]
+pub struct ChatCompletionTokenLogprob {
+    /// The token
+    pub token: String,
+    /// The log probability of this token
+    pub logprob: f64,
+    /// UTF-8 bytes representation of the token
+    pub bytes: Option<Vec<u8>>,
+    /// List of the most likely tokens and their log probability
+    pub top_logprobs: Vec<TopLogprob>,
+}
+
+/// Top log probability entry
+#[derive(Debug, Deserialize, Clone)]
+pub struct TopLogprob {
+    /// The token
+    pub token: String,
+    /// The log probability of this token
+    pub logprob: f64,
+    /// UTF-8 bytes representation of the token
+    pub bytes: Option<Vec<u8>>,
+}
+
+/// Log probability information for a choice
+#[derive(Debug, Deserialize, Clone)]
+pub struct ChoiceLogprobs {
+    /// A list of message content tokens with log probability information
+    pub content: Option<Vec<ChatCompletionTokenLogprob>>,
+    /// A list of message refusal tokens with log probability information
+    pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
 }
 
 /// Choice in a streaming completion response
@@ -338,7 +413,10 @@ pub struct ChatCompletionStreamChoice {
     pub index: i32,
     /// The reason the model stopped generating tokens
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub finish_reason: Option<String>,
+    pub finish_reason: Option<FinishReason>,
+    /// Log probability information for the choice
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option<ChoiceLogprobs>,
 }
 
 /// Delta content for streaming responses
@@ -350,9 +428,18 @@ pub struct ChatCompletionStreamDelta {
     /// Content of the message delta
     #[serde(skip_serializing_if = "Option::is_none")]
     pub content: Option<String>,
+    /// The reasoning content of the chunk message
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_content: Option<String>,
+    /// The reasoning of the chunk message (same as reasoning_content)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning: Option<String>,
     /// Tool calls for this delta
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub tool_calls: Option<Vec<ToolCallResponse>>,
+    pub tool_calls: Option<Vec<ChatCompletionMessageToolCallChunk>>,
+    /// The refusal message generated by the model
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub refusal: Option<String>,
 }
 
 /// Usage statistics for the completion request
@@ -628,6 +715,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
             stream: false,
             tools: self.tools.clone(),
             max_tokens: self.max_tokens,
+            reasoning_format: None,
         };
 
         let response = request.json(&request_payload).send().await?;
@@ -673,6 +761,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
             stream: true,
             tools: None,
             max_tokens: None,
+            reasoning_format: None,
         };
 
         async_stream::try_stream! {
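Editor's note: another illustrative sketch, not part of the patch. It shows how the new typed streaming fields might be consumed; it assumes the chunk exposes `choices` and that each choice carries its delta in a `delta` field (neither is visible in these hunks), with the chunk already deserialized as in the client code above.

/// React to one streaming chunk, using the typed `FinishReason` and the
/// `reasoning_content` delta added in this change.
fn handle_chunk(chunk: &CreateChatCompletionStreamResponse) {
    for choice in &chunk.choices {
        // Reasoning tokens stream separately from the answer itself
        if let Some(reasoning) = &choice.delta.reasoning_content {
            eprint!("[reasoning] {reasoning}");
        }
        if let Some(content) = &choice.delta.content {
            print!("{content}");
        }
        // `finish_reason` is now an enum instead of a raw string
        match &choice.finish_reason {
            Some(FinishReason::Stop) => println!(),
            Some(FinishReason::ToolCalls) => println!("(model requested a tool call)"),
            Some(other) => println!("(stopped: {other:?})"),
            None => {}
        }
    }
}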
@@ -756,8 +845,9 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
 mod tests {
     use crate::{
         CreateChatCompletionRequest, CreateChatCompletionResponse,
-        CreateChatCompletionStreamResponse, FunctionObject, GatewayError, InferenceGatewayAPI,
-        InferenceGatewayClient, Message, MessageRole, Provider, Tool, ToolType,
+        CreateChatCompletionStreamResponse, FinishReason, FunctionObject, GatewayError,
+        InferenceGatewayAPI, InferenceGatewayClient, Message, MessageRole, Provider, Tool,
+        ToolType,
     };
     use futures_util::{pin_mut, StreamExt};
     use mockito::{Matcher, Server};
@@ -767,6 +857,7 @@ mod tests {
     fn test_provider_serialization() {
         let providers = vec![
             (Provider::Ollama, "ollama"),
+            (Provider::OllamaCloud, "ollama_cloud"),
             (Provider::Groq, "groq"),
             (Provider::OpenAI, "openai"),
             (Provider::Cloudflare, "cloudflare"),
@@ -774,6 +865,7 @@ mod tests {
             (Provider::Anthropic, "anthropic"),
             (Provider::Deepseek, "deepseek"),
             (Provider::Google, "google"),
+            (Provider::Mistral, "mistral"),
         ];
 
         for (provider, expected) in providers {
@@ -786,6 +878,7 @@ mod tests {
     fn test_provider_deserialization() {
         let test_cases = vec![
             ("\"ollama\"", Provider::Ollama),
+            ("\"ollama_cloud\"", Provider::OllamaCloud),
             ("\"groq\"", Provider::Groq),
             ("\"openai\"", Provider::OpenAI),
             ("\"cloudflare\"", Provider::Cloudflare),
@@ -793,6 +886,7 @@ mod tests {
             ("\"anthropic\"", Provider::Anthropic),
             ("\"deepseek\"", Provider::Deepseek),
             ("\"google\"", Provider::Google),
+            ("\"mistral\"", Provider::Mistral),
         ];
 
         for (json, expected) in test_cases {
@@ -840,6 +934,7 @@ mod tests {
     fn test_provider_display() {
         let providers = vec![
             (Provider::Ollama, "ollama"),
+            (Provider::OllamaCloud, "ollama_cloud"),
             (Provider::Groq, "groq"),
             (Provider::OpenAI, "openai"),
             (Provider::Cloudflare, "cloudflare"),
@@ -847,6 +942,7 @@ mod tests {
             (Provider::Anthropic, "anthropic"),
             (Provider::Deepseek, "deepseek"),
             (Provider::Google, "google"),
+            (Provider::Mistral, "mistral"),
         ];
 
         for (provider, expected) in providers {
@@ -909,6 +1005,7 @@ mod tests {
                 },
             }]),
             max_tokens: None,
+            reasoning_format: None,
         };
 
         let serialized = serde_json::to_string_pretty(&request_payload).unwrap();
@@ -1107,6 +1204,7 @@ mod tests {
                 {
                     "index": 0,
                     "finish_reason": "stop",
+                    "logprobs": null,
                     "message": {
                         "role": "assistant",
                         "content": "Hellloooo"
@@ -1154,6 +1252,7 @@ mod tests {
                 {
                     "index": 0,
                     "finish_reason": "stop",
+                    "logprobs": null,
                     "message": {
                         "role": "assistant",
                         "content": "Hello"
@@ -1295,6 +1394,7 @@ mod tests {
                 {
                     "index": 0,
                     "finish_reason": "stop",
+                    "logprobs": null,
                     "message": {
                         "role": "assistant",
                         "content": "Hello"
@@ -1374,7 +1474,7 @@ mod tests {
             if generate_response.choices[0].finish_reason.is_some() {
                 assert_eq!(
                     generate_response.choices[0].finish_reason.as_ref().unwrap(),
-                    "stop"
+                    &FinishReason::Stop
                 );
                 break;
             }
@@ -1449,6 +1549,7 @@ mod tests {
                 {
                     "index": 0,
                     "finish_reason": "tool_calls",
+                    "logprobs": null,
                     "message": {
                         "role": "assistant",
                         "content": "Let me check the weather for you.",
@@ -1538,6 +1639,7 @@ mod tests {
                 {
                     "index": 0,
                     "finish_reason": "stop",
+                    "logprobs": null,
                     "message": {
                         "role": "assistant",
                         "content": "Hello!"
@@ -1622,6 +1724,7 @@ mod tests {
                 {
                     "index": 0,
                     "finish_reason": "stop",
+                    "logprobs": null,
                     "message": {
                         "role": "assistant",
                         "content": "Let me check the weather for you",
@@ -1719,6 +1822,7 @@ mod tests {
                 {
                     "index": 0,
                     "finish_reason": "stop",
+                    "logprobs": null,
                     "message": {
                         "role": "assistant",
                         "content": "Here's a poem with 100 tokens..."
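Editor's note: a final illustrative sketch, not part of the patch, using only items defined in this diff. It shows the two user-visible consequences of the Rust changes: the new providers round-trip through `TryFrom`/`Display`, and `Model` fields are plain values rather than `Option`s.

use std::convert::TryFrom;

fn demo() -> Result<(), GatewayError> {
    // "ollama_cloud" and "mistral" now parse as first-class providers
    let provider = Provider::try_from("ollama_cloud")?;
    assert_eq!(provider.to_string(), "ollama_cloud");

    // Every field is required now; a partially populated Model no longer compiles
    let model = Model {
        id: "mistral/mistral-large-latest".to_string(),
        object: "model".to_string(),
        created: 1698019200,
        owned_by: "mistral".to_string(),
        served_by: Provider::Mistral,
    };
    println!("{} (served by {})", model.id, model.served_by);
    Ok(())
}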