Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 59 additions & 160 deletions codex-rs/core/src/tools/handlers/tool_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ use codex_tools::ToolName;
use codex_tools::ToolSearchSourceInfo;
use codex_tools::ToolSpec;
use codex_tools::coalesce_loadable_tool_specs;
use std::collections::HashMap;

const COMPUTER_USE_MCP_SERVER_NAME: &str = "computer-use";
const COMPUTER_USE_TOOL_SEARCH_LIMIT: usize = 20;

pub struct ToolSearchHandler {
entries: Vec<ToolSearchEntry>,
Expand Down Expand Up @@ -88,8 +84,7 @@ impl ToolHandler for ToolSearchHandler {
"query must not be empty".to_string(),
));
}
let requested_limit = args.limit;
let limit = requested_limit.unwrap_or(TOOL_SEARCH_DEFAULT_LIMIT);
let limit = args.limit.unwrap_or(TOOL_SEARCH_DEFAULT_LIMIT);

if limit == 0 {
return Err(FunctionCallError::RespondToModel(
Expand All @@ -101,7 +96,7 @@ impl ToolHandler for ToolSearchHandler {
return Ok(ToolSearchOutput { tools: Vec::new() });
}

let tools = self.search(query, limit, requested_limit.is_none())?;
let tools = self.search(query, limit)?;

Ok(ToolSearchOutput { tools })
}
Expand All @@ -112,44 +107,14 @@ impl ToolSearchHandler {
&self,
query: &str,
limit: usize,
use_default_limit: bool,
) -> Result<Vec<LoadableToolSpec>, FunctionCallError> {
let results = self.search_result_entries(query, limit, use_default_limit);
self.search_output_tools(results)
}

fn search_result_entries(
&self,
query: &str,
limit: usize,
use_default_limit: bool,
) -> Vec<&ToolSearchEntry> {
let mut results = self
let results = self
.search_engine
.search(query, limit)
.into_iter()
.map(|result| result.document.id)
.filter_map(|id| self.entries.get(id))
.collect::<Vec<_>>();
if !use_default_limit {
return results;
}

if results.iter().any(|entry| {
entry
.limit_bucket
.as_deref()
.is_some_and(|bucket| bucket == COMPUTER_USE_MCP_SERVER_NAME)
}) {
results = self
.search_engine
.search(query, COMPUTER_USE_TOOL_SEARCH_LIMIT)
.into_iter()
.map(|result| result.document.id)
.filter_map(|id| self.entries.get(id))
.collect();
}
limit_results_by_bucket(results)
.filter_map(|id| self.entries.get(id));
self.search_output_tools(results)
}

fn search_output_tools<'a>(
Expand All @@ -162,45 +127,23 @@ impl ToolSearchHandler {
}
}

/// Applies the per-bucket default caps to an ordered list of search results.
///
/// Results are visited in their incoming order. An entry with no
/// `limit_bucket` is always kept. An entry with a bucket is kept only while
/// fewer than `default_limit_for_bucket(bucket)` entries from that same bucket
/// have already been kept; later entries from a full bucket are dropped.
/// The relative order of the surviving entries is unchanged.
fn limit_results_by_bucket(results: Vec<&ToolSearchEntry>) -> Vec<&ToolSearchEntry> {
    // Number of entries already accepted for each bucket name.
    let mut kept_per_bucket: HashMap<&str, usize> = HashMap::new();
    let mut kept = Vec::new();

    for entry in results {
        if let Some(bucket) = entry.limit_bucket.as_deref() {
            let accepted = kept_per_bucket.entry(bucket).or_default();
            if *accepted >= default_limit_for_bucket(bucket) {
                // Bucket is full: skip this entry but keep scanning the rest.
                continue;
            }
            *accepted += 1;
        }
        kept.push(entry);
    }

    kept
}

/// Returns the maximum number of results a single limit bucket may contribute.
///
/// The bucket named `COMPUTER_USE_MCP_SERVER_NAME` is allowed
/// `COMPUTER_USE_TOOL_SEARCH_LIMIT` results; every other bucket is allowed
/// `TOOL_SEARCH_DEFAULT_LIMIT`.
fn default_limit_for_bucket(bucket: &str) -> usize {
    match bucket {
        COMPUTER_USE_MCP_SERVER_NAME => COMPUTER_USE_TOOL_SEARCH_LIMIT,
        _ => TOOL_SEARCH_DEFAULT_LIMIT,
    }
}

#[cfg(test)]
mod tests {
use super::*;
use crate::session::tests::make_session_and_context;
use crate::tools::context::ToolCallSource;
use crate::tools::tool_search_entry::build_tool_search_entries;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_mcp::ToolInfo;
use codex_protocol::dynamic_tools::DynamicToolSpec;
use codex_protocol::models::SearchToolCallParams;
use codex_tools::ResponsesApiNamespace;
use codex_tools::ResponsesApiNamespaceTool;
use codex_tools::ResponsesApiTool;
use pretty_assertions::assert_eq;
use rmcp::model::Tool;
use std::sync::Arc;
use tokio::sync::Mutex;

#[test]
fn mixed_search_results_coalesce_mcp_namespaces() {
Expand Down Expand Up @@ -290,103 +233,66 @@ mod tests {
);
}

#[test]
fn computer_use_tool_search_uses_larger_limit() {
let tools = numbered_tools(
COMPUTER_USE_MCP_SERVER_NAME,
"computer use",
/*count*/ 100,
);
let handler = handler_from_tools(Some(&tools), &[]);

let results = handler.search_result_entries(
"computer use",
TOOL_SEARCH_DEFAULT_LIMIT,
/*use_default_limit*/ true,
);

assert_eq!(results.len(), COMPUTER_USE_TOOL_SEARCH_LIMIT);
assert!(
results
.iter()
.all(|entry| entry.limit_bucket.as_deref() == Some(COMPUTER_USE_MCP_SERVER_NAME))
);
#[tokio::test]
async fn omitted_limit_uses_default_tool_search_result_limit() {
let tool_count = TOOL_SEARCH_DEFAULT_LIMIT + 5;
let dynamic_tools = numbered_dynamic_tools(tool_count);
let handler = handler_from_tools(/*mcp_tools*/ None, &dynamic_tools);

let explicit_results = handler.search_result_entries(
"computer use",
/*limit*/ 100,
/*use_default_limit*/ false,
);
let output = tool_search_output(&handler, /*limit*/ None).await;

assert_eq!(explicit_results.len(), 100);
assert_eq!(output.tools.len(), TOOL_SEARCH_DEFAULT_LIMIT);
}

#[test]
fn non_computer_use_query_keeps_default_limit_with_computer_use_tools_installed() {
let mut tools = numbered_tools(
COMPUTER_USE_MCP_SERVER_NAME,
"computer use",
/*count*/ 100,
);
tools.extend(numbered_tools(
"other-server",
"calendar",
/*count*/ 100,
));
let handler = handler_from_tools(Some(&tools), &[]);

let results = handler.search_result_entries(
"calendar",
TOOL_SEARCH_DEFAULT_LIMIT,
/*use_default_limit*/ true,
);

assert_eq!(results.len(), TOOL_SEARCH_DEFAULT_LIMIT);
assert!(
results
.iter()
.all(|entry| entry.limit_bucket.as_deref() == Some("other-server"))
);
#[tokio::test]
async fn explicit_limit_controls_tool_search_result_count() {
let explicit_limit = 3;
let tool_count = TOOL_SEARCH_DEFAULT_LIMIT + explicit_limit;
let dynamic_tools = numbered_dynamic_tools(tool_count);
let handler = handler_from_tools(/*mcp_tools*/ None, &dynamic_tools);

let explicit_results = handler.search_result_entries(
"calendar", /*limit*/ 100, /*use_default_limit*/ false,
);
let output = tool_search_output(&handler, Some(explicit_limit)).await;

assert_eq!(explicit_results.len(), 100);
assert_eq!(output.tools.len(), explicit_limit);
}

#[test]
fn expanded_search_keeps_non_computer_use_servers_at_default_limit() {
let mut tools = numbered_tools(
COMPUTER_USE_MCP_SERVER_NAME,
"computer use",
/*count*/ 100,
);
tools.extend(numbered_tools(
"other-server",
"computer use",
/*count*/ 100,
));
let handler = handler_from_tools(Some(&tools), &[]);

let results = handler.search_result_entries(
"computer use",
TOOL_SEARCH_DEFAULT_LIMIT,
/*use_default_limit*/ true,
);

assert!(
count_results_for_server(&results, COMPUTER_USE_MCP_SERVER_NAME)
<= COMPUTER_USE_TOOL_SEARCH_LIMIT
);
assert!(count_results_for_server(&results, "other-server") <= TOOL_SEARCH_DEFAULT_LIMIT);
/// Test helper: invokes `handler` through `ToolHandler::handle` and returns its output.
///
/// The invocation always searches for the query `"calendar"`; `limit` is
/// forwarded unchanged so callers can exercise both the omitted-limit and
/// explicit-limit paths. Panics if the handler returns an error.
async fn tool_search_output(
    handler: &ToolSearchHandler,
    limit: Option<usize>,
) -> ToolSearchOutput {
    let (session, turn) = make_session_and_context().await;

    let arguments = SearchToolCallParams {
        query: "calendar".to_string(),
        limit,
    };
    let invocation = ToolInvocation {
        session: Arc::new(session),
        turn: Arc::new(turn),
        cancellation_token: tokio_util::sync::CancellationToken::new(),
        tracker: Arc::new(Mutex::new(TurnDiffTracker::new())),
        call_id: "call-tool-search".to_string(),
        tool_name: ToolName::plain(TOOL_SEARCH_TOOL_NAME),
        source: ToolCallSource::Direct,
        payload: ToolPayload::ToolSearch { arguments },
    };

    let outcome = handler.handle(invocation).await;
    outcome.expect("tool_search should succeed")
}

fn numbered_tools(server_name: &str, description_prefix: &str, count: usize) -> Vec<ToolInfo> {
fn numbered_dynamic_tools(count: usize) -> Vec<DynamicToolSpec> {
(0..count)
.map(|index| {
let tool_name = format!("tool_{index:03}");
tool_info(server_name, &tool_name, description_prefix)
.map(|index| DynamicToolSpec {
namespace: None,
name: format!("calendar_tool_{index:03}"),
description: "Calendar search helper.".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {},
"additionalProperties": false,
}),
defer_loading: true,
})
.collect()
}
Expand Down Expand Up @@ -420,13 +326,6 @@ mod tests {
}
}

/// Test helper: counts the entries in `results` whose `limit_bucket` equals `server_name`.
///
/// Entries with no bucket never match.
fn count_results_for_server(results: &[&ToolSearchEntry], server_name: &str) -> usize {
    let mut hits = 0;
    for entry in results {
        if entry.limit_bucket.as_deref() == Some(server_name) {
            hits += 1;
        }
    }
    hits
}

fn handler_from_tools(
mcp_tools: Option<&[ToolInfo]>,
dynamic_tools: &[DynamicToolSpec],
Expand Down
3 changes: 0 additions & 3 deletions codex-rs/core/src/tools/tool_search_entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use codex_tools::tool_search_result_source_to_loadable_tool_spec;
/// One searchable tool: the text describing it plus the spec handed back when it is found.
pub(crate) struct ToolSearchEntry {
/// Text describing the tool; dynamic entries fill it from `build_dynamic_search_text`.
/// NOTE(review): presumably this is what the search engine indexes — confirm.
pub(crate) search_text: String,
/// Loadable spec for this tool; dynamic entries build it with
/// `dynamic_tool_to_loadable_tool_spec`.
pub(crate) output: LoadableToolSpec,
/// Grouping key for per-bucket result caps. MCP entries set it to the tool's
/// server name; dynamic entries leave it `None`, so the caps never apply to them.
pub(crate) limit_bucket: Option<String>,
}

pub(crate) fn build_tool_search_entries(
Expand Down Expand Up @@ -82,15 +81,13 @@ fn mcp_tool_search_entry(info: &ToolInfo) -> Result<ToolSearchEntry, serde_json:
connector_name: info.connector_name.as_deref(),
description: info.namespace_description.as_deref(),
})?,
limit_bucket: Some(info.server_name.clone()),
})
}

/// Builds the search entry for a dynamic tool.
///
/// The entry carries no `limit_bucket`. Returns the `serde_json::Error`
/// produced by `dynamic_tool_to_loadable_tool_spec` if the spec conversion fails.
fn dynamic_tool_search_entry(tool: &DynamicToolSpec) -> Result<ToolSearchEntry, serde_json::Error> {
    let search_text = build_dynamic_search_text(tool);
    let output = dynamic_tool_to_loadable_tool_spec(tool)?;
    let entry = ToolSearchEntry {
        search_text,
        output,
        limit_bucket: None,
    };
    Ok(entry)
}

Expand Down
Loading