diff --git a/crates/sprout-agent/README.md b/crates/sprout-agent/README.md index e3b48302..5a6c2cf7 100644 --- a/crates/sprout-agent/README.md +++ b/crates/sprout-agent/README.md @@ -131,6 +131,7 @@ Everything is environment variables. No flags, no config files. (We are a subpro | `OPENAI_COMPAT_API_KEY` | — | Required when provider=openai. | | `OPENAI_COMPAT_MODEL` | — | Required when provider=openai. | | `OPENAI_COMPAT_BASE_URL` | `https://api.openai.com/v1` | Point at vLLM, llama.cpp, OpenRouter, Ollama, etc. | +| `OPENAI_COMPAT_API` | `auto` | `auto` \| `chat` \| `responses`. `auto` picks Responses for `*.openai.com`, Chat Completions everywhere else. | | `SPROUT_AGENT_SYSTEM_PROMPT` | built-in | Inline system prompt. | | `SPROUT_AGENT_SYSTEM_PROMPT_FILE` | — | File path. Mutually exclusive with the above. | | `SPROUT_AGENT_MAX_ROUNDS` | `0` | Tool-loop iteration cap. 0 = unlimited. | @@ -147,17 +148,19 @@ Everything is environment variables. No flags, no config files. (We are a subpro `sprout-agent` speaks two HTTP dialects. Pick with `SPROUT_AGENT_PROVIDER`. -| Provider | `SPROUT_AGENT_PROVIDER` | Endpoint | Tested with | +| Provider | `SPROUT_AGENT_PROVIDER` | Endpoint (auto) | Tested with | |---|---|---|---| | Anthropic | `anthropic` | `POST {base}/v1/messages` | claude-sonnet-4-5, claude-opus-4 | -| OpenAI | `openai` | `POST {base}/chat/completions` | gpt-5, gpt-4o | -| vLLM | `openai` | OpenAI-compatible endpoint | any tool-calling model | -| llama.cpp | `openai` | `--api-server` mode | any tool-calling GGUF | -| Ollama | `openai` | `http://localhost:11434/v1` | llama3.1, qwen2.5-coder | -| OpenRouter | `openai` | `https://openrouter.ai/api/v1` | anything they route | -| Block Gateway | `openai` | internal | gpt-5, claude | - -The "OpenAI" path is wire-compatible with the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). If a provider claims OpenAI compatibility and supports `tools` + `tool_choice: auto`, it works here. 
+| OpenAI | `openai` | `POST {base}/responses` | gpt-5, gpt-5-mini, o4-mini, gpt-4o | +| vLLM | `openai` | `POST {base}/chat/completions` | any tool-calling model | +| llama.cpp | `openai` | `POST {base}/chat/completions` | any tool-calling GGUF | +| Ollama | `openai` | `POST {base}/chat/completions` | llama3.1, qwen2.5-coder | +| OpenRouter | `openai` | `POST {base}/chat/completions` | anything they route | +| Block Gateway | `openai` | `POST {base}/chat/completions` | gpt-5, claude | + +`provider=openai` speaks two HTTP dialects: the [Responses API](https://platform.openai.com/docs/api-reference/responses) (`/v1/responses`, required for GPT-5 / o-series tool-calling on OpenAI's own service) and the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat) (`/chat/completions`, the broadly-supported OpenAI-compatible wire format). + +By default (`OPENAI_COMPAT_API=auto`) the agent picks **Responses** when `OPENAI_COMPAT_BASE_URL` points at an `*.openai.com` host and **Chat Completions** everywhere else. Pin the choice explicitly with `OPENAI_COMPAT_API=chat` or `OPENAI_COMPAT_API=responses` for providers that diverge from the default (e.g. a Responses-compatible self-hosted gateway). `Provider` is a Rust `enum` with one `match` in `Llm::complete`. There is no trait, no `Box`, no async-trait. Adding a third provider is a `match` arm and one `body`/`parse` pair in `llm.rs`. diff --git a/crates/sprout-agent/src/config.rs b/crates/sprout-agent/src/config.rs index 923b69e9..3c1c48f0 100644 --- a/crates/sprout-agent/src/config.rs +++ b/crates/sprout-agent/src/config.rs @@ -28,6 +28,18 @@ pub enum Provider { OpenAi, } +/// Which OpenAI-family HTTP API to call. Set via `OPENAI_COMPAT_API` +/// (`auto|chat|responses`); ignored when `provider = Anthropic`. `Auto` +/// picks Responses for `*.openai.com`, Chat Completions otherwise, and +/// permits a one-shot chat→responses upgrade on a "use /v1/responses" +/// provider error. 
+#[derive(Debug, Clone, Copy, PartialEq)] +pub enum OpenAiApi { + Chat, + Responses, + Auto, +} + #[derive(Debug, Clone)] pub struct Config { pub provider: Provider, @@ -57,6 +69,8 @@ pub struct Config { pub model: String, pub base_url: String, pub anthropic_api_version: String, + /// OpenAI endpoint selection. See [`OpenAiApi`]. + pub openai_api: OpenAiApi, } impl Config { @@ -66,16 +80,20 @@ impl Config { "openai" | "openai-compat" => Provider::OpenAi, o => return Err(format!("config: SPROUT_AGENT_PROVIDER={o} not supported")), }; - let (api_key, model, base_url) = match provider { + // OPENAI_COMPAT_API is only read when provider=openai, so a stray + // bad value can't break an Anthropic-only deployment. + let (api_key, model, base_url, openai_api) = match provider { Provider::Anthropic => ( req("ANTHROPIC_API_KEY")?, req("ANTHROPIC_MODEL")?, env_or("ANTHROPIC_BASE_URL", "https://api.anthropic.com"), + OpenAiApi::Auto, // unused for Anthropic ), Provider::OpenAi => ( req("OPENAI_COMPAT_API_KEY")?, req("OPENAI_COMPAT_MODEL")?, env_or("OPENAI_COMPAT_BASE_URL", "https://api.openai.com/v1"), + parse_openai_api(env("OPENAI_COMPAT_API").as_deref())?, ), }; let system_prompt = match (env("SPROUT_AGENT_SYSTEM_PROMPT"), env("SPROUT_AGENT_SYSTEM_PROMPT_FILE")) { @@ -92,6 +110,7 @@ impl Config { model, base_url, anthropic_api_version: env_or("ANTHROPIC_API_VERSION", "2023-06-01"), + openai_api, max_rounds: parse_env("SPROUT_AGENT_MAX_ROUNDS", 0)?, max_output_tokens: parse_env("SPROUT_AGENT_MAX_OUTPUT_TOKENS", 32_768)?, llm_timeout: Duration::from_secs(parse_env("SPROUT_AGENT_LLM_TIMEOUT_SECS", 120)?), @@ -182,6 +201,35 @@ fn req(k: &str) -> Result { env(k).ok_or_else(|| format!("config: {k} required")) } +/// Parse `OPENAI_COMPAT_API`. Pure (env-free) for testability; the +/// caller hands in the raw value. 
+fn parse_openai_api(raw: Option<&str>) -> Result { + match raw.unwrap_or("auto").trim().to_ascii_lowercase().as_str() { + "chat" | "chat-completions" | "chat_completions" => Ok(OpenAiApi::Chat), + "responses" => Ok(OpenAiApi::Responses), + "auto" | "" => Ok(OpenAiApi::Auto), + other => Err(format!( + "config: OPENAI_COMPAT_API={other} not supported (use auto|chat|responses)" + )), + } +} + +/// `true` when `base_url` is an official OpenAI host (`*.openai.com`). Such +/// hosts get Responses under `Auto`; everything else (vLLM, Ollama, +/// OpenRouter, Block Gateway, …) gets Chat Completions. The host ends at the +/// first `/`, `:`, `?` or `#`, so `evil.example?.openai.com` is not a match. +pub fn is_openai_host(base_url: &str) -> bool { + let rest = match base_url + .strip_prefix("https://") + .or_else(|| base_url.strip_prefix("http://")) + { + Some(r) => r, + None => return false, + }; + let host = &rest[..rest.find(['/', ':', '?', '#']).unwrap_or(rest.len())]; + host == "api.openai.com" || host.ends_with(".openai.com") +} + fn parse_env(key: &str, default: T) -> Result where T::Err: std::fmt::Display, @@ -337,4 +385,40 @@ mod tests { // expectation for callers. assert!(hs.allows("*")); } + + #[test] + fn parse_openai_api_values() { + use OpenAiApi::*; + for (raw, want) in [ + (None, Ok(Auto)), + (Some("auto"), Ok(Auto)), + (Some(" AUTO "), Ok(Auto)), + (Some(""), Ok(Auto)), + (Some("chat"), Ok(Chat)), + (Some("chat-completions"), Ok(Chat)), + (Some("Responses"), Ok(Responses)), + ] { + assert_eq!(parse_openai_api(raw), want, "raw={raw:?}"); + } + let err = parse_openai_api(Some("nope")).unwrap_err(); + assert!(err.contains("OPENAI_COMPAT_API=nope"), "{err}"); + } + + #[test] + fn is_openai_host_matrix() { + // Lookalike-safe: `api.openai.com.evil.example` and malformed URLs + // are treated as non-OpenAI (which falls back to Chat Completions). 
+ for (url, want) in [ + ("https://api.openai.com/v1", true), + ("https://api.openai.com", true), + ("http://eu.api.openai.com/v1", true), + ("http://localhost:11434/v1", false), + ("https://openrouter.ai/api/v1", false), + ("https://gateway.block.example/v1", false), + ("https://api.openai.com.evil.example/v1", false), + ("not a url", false), + ] { + assert_eq!(is_openai_host(url), want, "url={url}"); + } + } } diff --git a/crates/sprout-agent/src/llm.rs b/crates/sprout-agent/src/llm.rs index ddf85f63..4cd3442c 100644 --- a/crates/sprout-agent/src/llm.rs +++ b/crates/sprout-agent/src/llm.rs @@ -1,7 +1,9 @@ +use std::sync::atomic::{AtomicBool, Ordering}; + use reqwest::Client; use serde_json::{json, Value}; -use crate::config::{Config, Provider}; +use crate::config::{is_openai_host, Config, OpenAiApi, Provider}; use crate::types::{ AgentError, HistoryItem, LlmResponse, ProviderStop, ToolCall, ToolDef, ToolResultContent, }; @@ -9,8 +11,17 @@ use crate::types::{ const MAX_LLM_RESPONSE_BYTES: usize = 16 * 1024 * 1024; const MAX_LLM_ERROR_BODY_BYTES: usize = 4 * 1024; +/// Parser for an OpenAI-family JSON response. Per-endpoint pair lives +/// alongside its `_body` serializer. +type OpenAiParse = fn(Value) -> Result; + pub struct Llm { http: Client, + /// One-shot sticky flag: set when a Chat Completions request comes + /// back with a "use /v1/responses" provider error while `cfg.openai_api + /// == Auto`. Subsequent OpenAI calls then go straight to Responses + /// for the lifetime of the process. 
+ auto_upgraded: AtomicBool, } impl Llm { @@ -20,7 +31,10 @@ impl Llm { .timeout(cfg.llm_timeout) .build() .map_err(|e| AgentError::Llm(format!("http: {e}")))?; - Ok(Self { http }) + Ok(Self { + http, + auto_upgraded: AtomicBool::new(false), + }) } pub async fn complete( @@ -31,20 +45,26 @@ impl Llm { ) -> Result { match cfg.provider { Provider::Anthropic => { - let body = anthropic_body(cfg, history, tools); - let url = format!("{}/v1/messages", cfg.base_url.trim_end_matches('/')); - let v = post(&self.http, &url, &body, |r| { - r.header("x-api-key", &cfg.api_key) - .header("anthropic-version", &cfg.anthropic_api_version) - }) - .await?; + let v = self + .post_anthropic(cfg, &anthropic_body(cfg, history, tools)) + .await?; parse_anthropic(v) } Provider::OpenAi => { - let body = openai_body(cfg, history, tools); - let url = format!("{}/chat/completions", cfg.base_url.trim_end_matches('/')); - let v = post(&self.http, &url, &body, |r| r.bearer_auth(&cfg.api_key)).await?; - parse_openai(v) + self.openai_request(cfg, |use_responses| { + if use_responses { + ( + responses_body(cfg, history, tools), + parse_responses as OpenAiParse, + ) + } else { + ( + openai_body(cfg, history, tools), + parse_openai as OpenAiParse, + ) + } + }) + .await } } } @@ -67,29 +87,106 @@ impl Llm { "content": [{ "type": "text", "text": user_prompt }], }], }); - let url = format!("{}/v1/messages", cfg.base_url.trim_end_matches('/')); - let v = post(&self.http, &url, &body, |r| { - r.header("x-api-key", &cfg.api_key) - .header("anthropic-version", &cfg.anthropic_api_version) - }) - .await?; - Ok(parse_anthropic(v)?.text) + Ok(parse_anthropic(self.post_anthropic(cfg, &body).await?)?.text) } Provider::OpenAi => { - let body = json!({ - "model": cfg.model, - "stream": false, - "max_completion_tokens": max_output_tokens, - "messages": [ - { "role": "system", "content": system_prompt }, - { "role": "user", "content": user_prompt }, - ], - }); - let url = format!("{}/chat/completions", 
cfg.base_url.trim_end_matches('/')); - let v = post(&self.http, &url, &body, |r| r.bearer_auth(&cfg.api_key)).await?; - Ok(parse_openai(v)?.text) + let r = self + .openai_request(cfg, |use_responses| { + if use_responses { + ( + json!({ + "model": cfg.model, + "max_output_tokens": max_output_tokens, + "instructions": system_prompt, + "input": user_prompt, + }), + parse_responses as OpenAiParse, + ) + } else { + ( + json!({ + "model": cfg.model, + "stream": false, + "max_completion_tokens": max_output_tokens, + "messages": [ + { "role": "system", "content": system_prompt }, + { "role": "user", "content": user_prompt }, + ], + }), + parse_openai as OpenAiParse, + ) + } + }) + .await?; + Ok(r.text) + } + } + } + + async fn post_anthropic(&self, cfg: &Config, body: &Value) -> Result { + let url = format!("{}/v1/messages", cfg.base_url.trim_end_matches('/')); + post(&self.http, &url, body, |r| { + r.header("x-api-key", &cfg.api_key) + .header("anthropic-version", &cfg.anthropic_api_version) + }) + .await + } + + /// OpenAI dispatch: resolve endpoint (pinned > sticky-upgraded > auto by + /// host), POST, and on `auto` retry once on Responses if the provider + /// asks for it. `build` is called with `use_responses` so callers + /// only construct the body actually needed. + async fn openai_request(&self, cfg: &Config, mut build: F) -> Result + where + F: FnMut(bool) -> (Value, OpenAiParse) + Send, + { + let use_responses = self.auto_upgraded.load(Ordering::Relaxed) + || matches!(cfg.openai_api, OpenAiApi::Responses) + || matches!(cfg.openai_api, OpenAiApi::Auto) && is_openai_host(&cfg.base_url); + + if use_responses { + let (b, p) = build(true); + return p(self.post_openai(cfg, "/responses", &b).await?); + } + let (b, p) = build(false); + match self.post_openai(cfg, "/chat/completions", &b).await { + Ok(v) => p(v), + Err(e) if cfg.openai_api == OpenAiApi::Auto && self.try_upgrade(&e) => { + let (b, p) = build(true); + p(self.post_openai(cfg, "/responses", &b).await?) 
} + Err(e) => Err(e), + } + } + + async fn post_openai( + &self, + cfg: &Config, + path: &str, + body: &Value, + ) -> Result { + let url = format!("{}{}", cfg.base_url.trim_end_matches('/'), path); + post(&self.http, &url, body, |r| r.bearer_auth(&cfg.api_key)).await + } + + /// If `err` names `/v1/responses` / "use the Responses API", latch a + /// sticky upgrade so subsequent OpenAI calls hit Responses. Logged once. + fn try_upgrade(&self, err: &AgentError) -> bool { + let body = match err { + AgentError::Llm(s) => s.as_str(), + _ => return false, // auth/transport aren't "use the other endpoint" signals + }; + if !is_responses_required_error(body) { + return false; + } + if !self.auto_upgraded.swap(true, Ordering::Relaxed) { + tracing::warn!( + provider_message = body, + "openai: provider asked for the Responses API; \ + routing subsequent OpenAI calls to /v1/responses for this process" + ); } + true } } @@ -241,6 +338,174 @@ fn openai_image_user_content(content: &[ToolResultContent]) -> Vec { .collect() } +// ── OpenAI Responses API ─────────────────────────────────────────────────── +// Spec: https://platform.openai.com/docs/api-reference/responses +// +// Replay invariant: each assistant `function_call` input item **must** +// precede its matching `function_call_output`, or the API rejects with +// "No tool call found for call_id ...". `HistoryItem` ordering already +// guarantees this. 
+ +fn responses_body(cfg: &Config, history: &[HistoryItem], tools: &[ToolDef]) -> Value { + let mut input: Vec = Vec::with_capacity(history.len()); + for item in history { + match item { + HistoryItem::User(text) => input.push(json!({ + "role": "user", + "content": [{ "type": "input_text", "text": text }], + })), + HistoryItem::Assistant { text, tool_calls } => { + if !text.is_empty() { + input.push(json!({ + "role": "assistant", + "content": [{ "type": "output_text", "text": text }], + })); + } + for c in tool_calls { + input.push(json!({ + "type": "function_call", + "call_id": c.provider_id, + "name": c.name, + "arguments": serde_json::to_string(&c.arguments) + .unwrap_or_else(|_| "{}".into()), + })); + } + } + HistoryItem::ToolResult(r) => { + input.push(json!({ + "type": "function_call_output", + "call_id": r.provider_id, + "output": openai_tool_text_content(&r.content), + })); + // Responses takes images as `input_image` parts on a user message. + let images: Vec = r + .content + .iter() + .filter_map(|c| match c { + ToolResultContent::Image { data, mime_type } => Some(json!({ + "type": "input_image", + "image_url": format!("data:{mime_type};base64,{data}"), + })), + ToolResultContent::Text(_) => None, + }) + .collect(); + if !images.is_empty() { + input.push(json!({ "role": "user", "content": images })); + } + } + } + } + + let tools_json: Vec = tools + .iter() + .map(|t| { + json!({ + "type": "function", + "name": t.name, + "description": t.description, + "parameters": t.input_schema, + }) + }) + .collect(); + + let mut body = json!({ + "model": cfg.model, + "instructions": cfg.system_prompt, + "max_output_tokens": cfg.max_output_tokens, + "input": input, + }); + if !tools_json.is_empty() { + body["tools"] = Value::Array(tools_json); + body["tool_choice"] = json!("auto"); + } + body +} + +/// Narrow matcher for "you should be on the Responses API" provider errors, +/// the signal we use to auto-upgrade. 
Triggers on the literal path +/// `/v1/responses` (Databricks GPT-5.5 phrasing) or the prose +/// "use the Responses API" / "Responses API instead". +fn is_responses_required_error(body: &str) -> bool { + let b = body.to_ascii_lowercase(); + b.contains("/v1/responses") + || b.contains("responses api instead") + || b.contains("use the responses api") +} + +fn parse_responses(v: Value) -> Result { + let mut text = String::new(); + let mut tool_calls = Vec::new(); + let mut saw_function_call = false; + + for item in v + .get("output") + .and_then(Value::as_array) + .into_iter() + .flatten() + { + match item.get("type").and_then(Value::as_str) { + Some("message") => { + for p in item + .get("content") + .and_then(Value::as_array) + .into_iter() + .flatten() + { + // Responses emits "output_text"; accept "text" forward-compat. + if matches!( + p.get("type").and_then(Value::as_str), + Some("output_text" | "text") + ) { + if let Some(t) = p.get("text").and_then(Value::as_str) { + text.push_str(t); + } + } + } + } + Some("function_call") => { + saw_function_call = true; + let raw = item + .get("arguments") + .and_then(Value::as_str) + .unwrap_or("{}"); + let args: Value = serde_json::from_str(raw).map_err(|e| { + AgentError::Llm(format!("function_call.arguments not valid JSON: {e}")) + })?; + tool_calls.push(make_tool_call( + str_field(item, "call_id"), + str_field(item, "name"), + args, + )?); + } + // Reasoning items are opaque/internal; we don't replay them. + // Unknown types ignored for forward-compat. 
+ _ => {} + } + } + + let stop = match v.get("status").and_then(Value::as_str) { + Some("incomplete") => { + let reason = v + .get("incomplete_details") + .and_then(|d| d.get("reason")) + .and_then(Value::as_str); + if reason == Some("max_output_tokens") { + ProviderStop::MaxTokens + } else { + ProviderStop::Other + } + } + Some("completed") if saw_function_call => ProviderStop::ToolUse, + Some("completed") => ProviderStop::EndTurn, + _ => ProviderStop::Other, + }; + Ok(LlmResponse { + text, + tool_calls, + stop, + }) +} + fn map_stop(s: Option<&str>) -> ProviderStop { match s { Some("end_turn" | "stop") => ProviderStop::EndTurn, @@ -442,7 +707,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::config::{Config, HookServers, Provider}; + use crate::config::{Config, HookServers, OpenAiApi, Provider}; use crate::types::{HistoryItem, ToolCall, ToolResult, ToolResultContent}; use std::time::Duration; @@ -470,6 +735,7 @@ mod tests { model: "model".into(), base_url: "http://example.invalid".into(), anthropic_api_version: "2023-06-01".into(), + openai_api: OpenAiApi::Chat, } } @@ -509,6 +775,219 @@ mod tests { assert_eq!(content[1]["source"]["data"], "aW1n"); } + // ── Responses API unit tests ─────────────────────────────────────── + + fn cfg_responses() -> Config { + let mut c = cfg(Provider::OpenAi); + c.openai_api = OpenAiApi::Responses; + c + } + + fn tool_call_history() -> Vec { + vec![ + HistoryItem::User("call the tool".into()), + HistoryItem::Assistant { + text: "ok, calling".into(), + tool_calls: vec![ToolCall { + provider_id: "call_abc".into(), + name: "dev__shell".into(), + arguments: serde_json::json!({"command": "ls"}), + }], + }, + HistoryItem::ToolResult(ToolResult { + provider_id: "call_abc".into(), + content: vec![ToolResultContent::Text("file.txt".into())], + is_error: false, + }), + ] + } + + #[test] + fn responses_body_top_level_shape() { + let tools = vec![ToolDef { + name: "dev__shell".into(), + description: "run a shell command".into(), 
+ input_schema: serde_json::json!({ + "type": "object", + "properties": {"command": {"type": "string"}}, + }), + }]; + let body = responses_body(&cfg_responses(), &[HistoryItem::User("hi".into())], &tools); + assert_eq!(body["model"], "model"); + assert_eq!(body["instructions"], "system"); + assert_eq!(body["max_output_tokens"], 1024); + assert!( + body.get("messages").is_none(), + "must use `input`, not `messages`" + ); + assert!(body.get("max_tokens").is_none()); + assert!(body.get("max_completion_tokens").is_none()); + + // Tools are flat — top-level type/name/description/parameters. + let tool = &body["tools"][0]; + assert_eq!(tool["type"], "function"); + assert_eq!(tool["name"], "dev__shell"); + assert!( + tool.get("function").is_none(), + "Responses tool schema is flat" + ); + assert_eq!(body["tool_choice"], "auto"); + } + + #[test] + fn responses_body_replay_emits_function_call_before_output() { + // Replay requirement from the live API: the assistant's prior + // function_call item *must* appear in `input[]` before its matching + // function_call_output, otherwise the API rejects with + // "No tool call found for call_id ...". + let body = responses_body(&cfg_responses(), &tool_call_history(), &[]); + let input = body["input"].as_array().unwrap(); + + // [0] user, [1] assistant text, [2] function_call, [3] function_call_output + assert_eq!(input[0]["role"], "user"); + assert_eq!(input[0]["content"][0]["type"], "input_text"); + assert_eq!(input[0]["content"][0]["text"], "call the tool"); + + assert_eq!(input[1]["role"], "assistant"); + assert_eq!(input[1]["content"][0]["type"], "output_text"); + assert_eq!(input[1]["content"][0]["text"], "ok, calling"); + + assert_eq!(input[2]["type"], "function_call"); + assert_eq!(input[2]["call_id"], "call_abc"); + assert_eq!(input[2]["name"], "dev__shell"); + // Arguments are a JSON-encoded string per spec. 
+ assert_eq!(input[2]["arguments"], "{\"command\":\"ls\"}"); + + assert_eq!(input[3]["type"], "function_call_output"); + assert_eq!(input[3]["call_id"], "call_abc"); + assert_eq!(input[3]["output"], "file.txt"); + } + + #[test] + fn responses_body_skips_empty_assistant_text() { + // Mirrors the Chat Completions behavior (#559/#560): empty assistant + // turns are skipped so we don't emit an empty `output_text` block, + // but the tool_call(s) on that assistant turn still go through. + let history = vec![ + HistoryItem::User("u".into()), + HistoryItem::Assistant { + text: String::new(), + tool_calls: vec![ToolCall { + provider_id: "call_x".into(), + name: "t".into(), + arguments: serde_json::json!({}), + }], + }, + ]; + let body = responses_body(&cfg_responses(), &history, &[]); + let input = body["input"].as_array().unwrap(); + assert_eq!(input.len(), 2); + assert_eq!(input[0]["role"], "user"); + assert_eq!(input[1]["type"], "function_call"); + } + + #[test] + fn responses_body_image_tool_result_attaches_input_image() { + let body = responses_body(&cfg_responses(), &image_history(), &[]); + let input = body["input"].as_array().unwrap(); + // function_call_output carries the text part; image rides on a + // trailing user message as `input_image`. 
+ let fco = input + .iter() + .find(|i| i["type"] == "function_call_output") + .unwrap(); + assert_eq!(fco["call_id"], "toolu_1"); + let img_msg = input.iter().rev().find(|i| i["role"] == "user").unwrap(); + assert_eq!(img_msg["content"][0]["type"], "input_image"); + assert_eq!( + img_msg["content"][0]["image_url"], + "data:image/png;base64,aW1n" + ); + } + + #[test] + fn parse_responses_completed_with_text_is_end_turn() { + let v = serde_json::json!({ + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "hello"}], + }], + }); + let r = parse_responses(v).unwrap(); + assert_eq!(r.text, "hello"); + assert!(r.tool_calls.is_empty()); + assert_eq!(r.stop, ProviderStop::EndTurn); + } + + #[test] + fn parse_responses_completed_with_function_call_is_tool_use() { + let v = serde_json::json!({ + "status": "completed", + "output": [ + {"type": "reasoning", "id": "rs_1", "summary": []}, + { + "type": "function_call", + "call_id": "call_z", + "name": "dev__shell", + "arguments": "{\"command\":\"ls\"}", + }, + ], + }); + let r = parse_responses(v).unwrap(); + assert_eq!(r.text, ""); + assert_eq!(r.tool_calls.len(), 1); + assert_eq!(r.tool_calls[0].provider_id, "call_z"); + assert_eq!(r.tool_calls[0].name, "dev__shell"); + assert_eq!( + r.tool_calls[0].arguments, + serde_json::json!({"command": "ls"}) + ); + assert_eq!(r.stop, ProviderStop::ToolUse); + } + + #[test] + fn parse_responses_incomplete_max_output_tokens() { + let v = serde_json::json!({ + "status": "incomplete", + "incomplete_details": {"reason": "max_output_tokens"}, + "output": [], + }); + let r = parse_responses(v).unwrap(); + assert_eq!(r.stop, ProviderStop::MaxTokens); + } + + #[test] + fn is_responses_required_error_matrix() { + for (body, want) in [ + // Databricks GPT-5.5 (the actual case we observed). + ("Function tools with reasoning_effort are not supported for gpt-5.5 in /v1/chat/completions. 
Please use /v1/responses instead.", true), + // Forward-compat: OpenAI saying the same thing in prose. + ("This model requires the Responses API. Please use the Responses API instead.", true), + // Negatives — must NOT trigger on unrelated 4xx. + ("{\"error\":\"invalid_api_key\"}", false), + ("max_tokens is not supported with this model", false), + ("", false), + ] { + assert_eq!(is_responses_required_error(body), want, "body={body:?}"); + } + } + + #[test] + fn parse_responses_rejects_malformed_function_arguments() { + let v = serde_json::json!({ + "status": "completed", + "output": [{ + "type": "function_call", + "call_id": "call_z", + "name": "t", + "arguments": "not json {", + }], + }); + assert!(matches!(parse_responses(v), Err(AgentError::Llm(_)))); + } + #[test] fn openai_tool_result_adds_followup_image_user_message() { let body = openai_body(&cfg(Provider::OpenAi), &image_history(), &[]); diff --git a/crates/sprout-agent/tests/openai_auto_upgrade.rs b/crates/sprout-agent/tests/openai_auto_upgrade.rs new file mode 100644 index 00000000..69f55917 --- /dev/null +++ b/crates/sprout-agent/tests/openai_auto_upgrade.rs @@ -0,0 +1,229 @@ +//! Integration test for OpenAI auto-upgrade chat→responses. +//! +//! Starts a tiny HTTP server that: +//! 1. accepts a POST to /chat/completions, replies 400 with a body that +//! mentions `/v1/responses` (mirrors the Databricks GPT-5.5 signal); +//! 2. accepts a POST to /responses, replies 200 with a Responses-shaped +//! JSON envelope. +//! +//! Spawns `sprout-agent` with `provider=openai` + `OPENAI_COMPAT_API=auto` +//! pointed at the fake server, drives one prompt through the ACP wire +//! protocol, and verifies the prompt completes with `stopReason=end_turn` +//! — which can only happen if the second (Responses) request succeeded. 
+ +use std::io::{Read, Write}; +use std::net::TcpListener; +use std::process::Stdio; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use serde_json::json; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::Command; +use tokio::time::timeout; + +/// Spawns a fake provider on `127.0.0.1:0`, one thread per connection. +/// Returns `(base_url, chat_hits, responses_hits)`. The accept loop runs for +/// the lifetime of the process — we don't need to clean it up explicitly. +fn spawn_fake_provider() -> (String, Arc, Arc) { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + listener.set_nonblocking(false).unwrap(); + let url = format!("http://{}", listener.local_addr().unwrap()); + let chat_hits = Arc::new(AtomicUsize::new(0)); + let responses_hits = Arc::new(AtomicUsize::new(0)); + let chat = chat_hits.clone(); + let resp = responses_hits.clone(); + + std::thread::spawn(move || { + loop { + let (mut sock, _) = match listener.accept() { + Ok(p) => p, + Err(_) => return, + }; + let chat = chat.clone(); + let resp = resp.clone(); + std::thread::spawn(move || { + sock.set_read_timeout(Some(Duration::from_secs(5))).ok(); + // Read request head + body. Naive: read until we have the + // request line + headers, then read Content-Length bytes. + let mut buf = Vec::with_capacity(4096); + let mut tmp = [0u8; 4096]; + loop { + if buf.windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + match sock.read(&mut tmp) { + Ok(0) | Err(_) => return, + Ok(n) => buf.extend_from_slice(&tmp[..n]), + } + if buf.len() > 256 * 1024 { + return; + } + } + let head_end = buf.windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4; + let head = String::from_utf8_lossy(&buf[..head_end]).to_string(); + // Drain the body to satisfy keep-alive; we don't actually + // need it. 
+ let cl = head + .lines() + .find_map(|l| { + l.strip_prefix("content-length:") + .or_else(|| l.strip_prefix("Content-Length:")) + }) + .and_then(|s| s.trim().parse::().ok()) + .unwrap_or(0); + while buf.len() < head_end + cl { + match sock.read(&mut tmp) { + Ok(0) | Err(_) => break, + Ok(n) => buf.extend_from_slice(&tmp[..n]), + } + } + + let (status, body) = if head.contains("POST /chat/completions") { + chat.fetch_add(1, Ordering::SeqCst); + let body = json!({ + "error": { + "code": "BAD_REQUEST", + "message": "Function tools with reasoning_effort are not supported for gpt-5.5 in /v1/chat/completions. Please use /v1/responses instead." + } + }) + .to_string(); + (400u16, body) + } else if head.contains("POST /responses") { + resp.fetch_add(1, Ordering::SeqCst); + let body = json!({ + "status": "completed", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "ok from responses"}] + }] + }) + .to_string(); + (200u16, body) + } else { + (404u16, "{}".to_string()) + }; + let reason = match status { + 200 => "OK", + 400 => "Bad Request", + _ => "Not Found", + }; + let resp_text = format!( + "HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), body + ); + let _ = sock.write_all(resp_text.as_bytes()); + let _ = sock.shutdown(std::net::Shutdown::Write); + }); + } + }); + + (url, chat_hits, responses_hits) +} + +#[tokio::test] +async fn openai_auto_upgrades_chat_to_responses_on_databricks_signal() { + let (base_url, chat_hits, resp_hits) = spawn_fake_provider(); + + let bin = env!("CARGO_BIN_EXE_sprout-agent"); + let mut cmd = Command::new(bin); + cmd.env("SPROUT_AGENT_PROVIDER", "openai") + .env("OPENAI_COMPAT_API_KEY", "test") + .env("OPENAI_COMPAT_MODEL", "gpt-5.5") + .env("OPENAI_COMPAT_BASE_URL", &base_url) + // No OPENAI_COMPAT_API — must default to "auto" so the upgrade + // path is enabled. 
+ .env_remove("OPENAI_COMPAT_API") + .env("SPROUT_AGENT_LLM_TIMEOUT_SECS", "5") + .env("SPROUT_AGENT_MAX_ROUNDS", "4") + .env("SPROUT_AGENT_MCP_INIT_TIMEOUT_SECS", "2") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()) + .kill_on_drop(true); + + let mut child = cmd.spawn().expect("spawn sprout-agent"); + let mut stdin = child.stdin.take().unwrap(); + let mut stdout = BufReader::new(child.stdout.take().unwrap()); + + async fn send(stdin: &mut tokio::process::ChildStdin, v: serde_json::Value) { + let line = format!("{v}\n"); + stdin.write_all(line.as_bytes()).await.unwrap(); + stdin.flush().await.unwrap(); + } + async fn recv(stdout: &mut BufReader) -> serde_json::Value { + let mut line = String::new(); + timeout(Duration::from_secs(8), stdout.read_line(&mut line)) + .await + .expect("recv timed out") + .expect("recv io"); + serde_json::from_str(&line).expect("recv json") + } + + send( + &mut stdin, + json!({ + "jsonrpc": "2.0", "id": 1, "method": "initialize", + "params": {"protocolVersion": 1, "clientCapabilities": {}, + "clientInfo": {"name": "auto-upgrade-test"}} + }), + ) + .await; + let init = recv(&mut stdout).await; + assert!(init.get("result").is_some(), "initialize: {init}"); + + let cwd = std::env::current_dir().unwrap(); + send( + &mut stdin, + json!({ + "jsonrpc": "2.0", "id": 2, "method": "session/new", + "params": {"cwd": cwd.to_string_lossy(), "mcpServers": []} + }), + ) + .await; + let sess = recv(&mut stdout).await; + let sid = sess["result"]["sessionId"] + .as_str() + .unwrap_or_else(|| panic!("session/new failed: {sess}")) + .to_string(); + + send( + &mut stdin, + json!({ + "jsonrpc": "2.0", "id": 3, "method": "session/prompt", + "params": {"sessionId": sid, + "prompt": [{"type": "text", "text": "hi"}]} + }), + ) + .await; + + // Drain notifications until we see the response for id=3. 
+ let mut stop_reason: Option = None; + for _ in 0..40 { + let msg = recv(&mut stdout).await; + if msg.get("id") == Some(&json!(3)) { + if let Some(r) = msg.get("result") { + stop_reason = r + .get("stopReason") + .and_then(|v| v.as_str()) + .map(String::from); + } + break; + } + } + assert_eq!(stop_reason.as_deref(), Some("end_turn")); + assert_eq!( + chat_hits.load(Ordering::SeqCst), + 1, + "must have tried chat first" + ); + assert!( + resp_hits.load(Ordering::SeqCst) >= 1, + "must have upgraded to responses" + ); + + drop(stdin); + let _ = child.wait().await; +}