diff --git a/src/command.rs b/src/command.rs
index 74b9914..341cb4a 100644
--- a/src/command.rs
+++ b/src/command.rs
@@ -158,7 +158,8 @@ pub enum Commands {
         /// Search query text — required for both --type bm25 and --type vector
         query: String,
 
-        /// Search type — required (no default; choose deliberately)
+        /// Search type (`bm25` or `vector`). Inferred automatically when the table has exactly
+        /// one search index — required only when multiple indexes exist.
         ///
         /// `vector` runs server-side `vector_distance(col, 'text')` — the server resolves the
         /// embedding column, model, and metric from the index metadata.
@@ -166,16 +167,19 @@ pub enum Commands {
         /// `bm25` runs server-side `bm25_search(table, col, 'text')` and requires a BM25 index
         /// on the column.
         #[arg(long, value_parser = ["vector", "bm25"])]
-        r#type: String,
+        r#type: Option<String>,
 
-        /// Table to search (connection.schema.table)
+        /// Table to search (`connection.table` or `connection.schema.table`).
+        /// Schema defaults to `public` when omitted.
         #[arg(long)]
         table: String,
 
-        /// Column to search. For `--type vector`, name the source text column — the server
-        /// resolves the embedding column from the index metadata.
+        /// Column to search. Inferred automatically when the table has exactly one search index
+        /// of the resolved type — required only when multiple indexed columns exist.
+        /// For `--type vector`, name the source text column — the server resolves the embedding
+        /// column from the index metadata.
         #[arg(long)]
-        column: String,
+        column: Option<String>,
 
         /// Columns to display (comma-separated, defaults to all)
         #[arg(long)]
diff --git a/src/indexes.rs b/src/indexes.rs
index b9473fa..2465b2b 100644
--- a/src/indexes.rs
+++ b/src/indexes.rs
@@ -147,6 +147,104 @@ fn list_one_table_scan(
     }
 }
 
+/// Pure matching logic for search inference — extracted for testability.
+///
+/// Keeps the `bm25` / `vector` entries of `indexes` that agree with `hint_type` and
+/// `hint_column` (when given). Exactly one survivor yields `Ok((index_type, column))`:
+/// `column` is `hint_column` if supplied, else the index's first column. Anything else
+/// (none, several, no columns) yields `Err(message)`; `location` only labels the no-match text.
+fn resolve_search_params(
+    indexes: &[Index],
+    hint_type: Option<&str>,
+    hint_column: Option<&str>,
+    location: &str,
+) -> Result<(String, String), String> {
+    // Keep only search-capable indexes that agree with every hint that was supplied.
+    let candidates: Vec<&Index> = indexes
+        .iter()
+        .filter(|i| {
+            matches!(i.index_type.as_str(), "bm25" | "vector")
+                && hint_type.map_or(true, |ht| ht == i.index_type.as_str())
+                && hint_column.map_or(true, |hc| i.columns.iter().any(|c| c == hc))
+        })
+        .collect();
+
+    match candidates.as_slice() {
+        [] => {
+            let what = match hint_type {
+                Some(t) => format!("{} index", t),
+                None => "BM25 or vector index".to_string(),
+            };
+            Err(format!(
+                "No {} found on {} — run 'hotdata indexes create' first.",
+                what, location
+            ))
+        }
+        [one] => {
+            // Honour the hint: on a multi-column index it need not be the first column.
+            let column = match hint_column {
+                Some(hc) => hc.to_string(),
+                None => one.columns.first().cloned().ok_or_else(|| {
+                    format!("Index '{}' has no columns.", one.index_name)
+                })?,
+            };
+            Ok((one.index_type.clone(), column))
+        }
+        _ => {
+            let types: Vec<&str> = candidates.iter().map(|i| i.index_type.as_str()).collect();
+            let cols: Vec<_> = candidates.iter().flat_map(|i| i.columns.iter().cloned()).collect();
+            Err(format!(
+                "Multiple search indexes found (types: {}, columns: {}) — specify --type and --column.",
+                types.join(", "),
+                cols.join(", ")
+            ))
+        }
+    }
+}
+
+/// Infers `(index_type, column)` for `hotdata search` when `--type` or `--column` are omitted.
+///
+/// Fetches the indexes on `connection_name.schema.table`, filters to searchable types
+/// (`bm25`, `vector`), and narrows further by `hint_type` / `hint_column` when provided.
+/// Exits with an error when the result is ambiguous (multiple matches) or no index exists.
+pub fn infer_for_search(
+    workspace_id: &str,
+    connection_name: &str,
+    schema: &str,
+    table: &str,
+    hint_type: Option<&str>,
+    hint_column: Option<&str>,
+) -> (String, String) {
+    use crossterm::style::Stylize;
+
+    // Client built from the caller's workspace ID; shared by both lookups below.
+    let client = ApiClient::new(Some(workspace_id));
+
+    // `list_one_table` takes a connection ID, so translate the name first. An unknown
+    // name is fatal: report it in red on stderr and exit with status 1.
+    let connections = connection_lookup(&client);
+    let connection_id = if let Some(id) = connections.get(connection_name) {
+        id.clone()
+    } else {
+        let message = format!("Connection '{}' not found.", connection_name);
+        eprintln!("{}", message.red());
+        std::process::exit(1);
+    };
+
+    // Every index on the target table; filtering happens in `resolve_search_params`.
+    let table_indexes = list_one_table(&client, &connection_id, schema, table);
+
+    // `location` only labels the matcher's error text. A matcher error is handled like
+    // the unknown-connection case above: red message on stderr, exit status 1.
+    let location = format!("{}.{}.{}", connection_name, schema, table);
+    let resolved = resolve_search_params(&table_indexes, hint_type, hint_column, &location);
+    resolved.unwrap_or_else(|msg| {
+        // `exit` diverges, so the closure still type-checks as returning the tuple.
+        eprintln!("{}", msg.red());
+        std::process::exit(1)
+    })
+}
+
 pub fn list(
     workspace_id: &str,
     connection_id: Option<&str>,
@@ -574,4 +672,106 @@ mod tests {
         mock.assert();
         assert!(rows.is_empty());
     }
+
+    // Fixture: a `ready` index with the given name, type and columns, and no metric.
+    fn make_index(name: &str, index_type: &str, columns: &[&str]) -> Index {
+        let timestamp = "2020-01-01T00:00:00Z";
+        Index {
+            index_name: name.into(),
+            index_type: index_type.into(),
+            columns: columns.iter().map(|col| col.to_string()).collect(),
+            metric: None,
+            status: "ready".into(),
+            created_at: timestamp.into(),
+            updated_at: timestamp.into(),
+        }
+    }
+
+    // A lone BM25 index with no hints resolves to that index's type and column.
+    #[test]
+    fn resolve_search_params_single_bm25_returns_type_and_column() {
+        let available = [make_index("fts", "bm25", &["description"])];
+        let outcome = resolve_search_params(&available, None, None, "db.public.t");
+        assert_eq!(outcome, Ok(("bm25".into(), "description".into())));
+    }
+
+    // A lone vector index with no hints resolves to that index's type and column.
+    #[test]
+    fn resolve_search_params_single_vector_returns_type_and_column() {
+        let available = [make_index("vec", "vector", &["embedding"])];
+        let outcome = resolve_search_params(&available, None, None, "db.public.t");
+        assert_eq!(outcome, Ok(("vector".into(), "embedding".into())));
+    }
+
+    // `sorted` is not a search index type, so only the BM25 index is considered.
+    #[test]
+    fn resolve_search_params_non_search_indexes_ignored() {
+        let available = [
+            make_index("sorted_idx", "sorted", &["created_at"]),
+            make_index("fts", "bm25", &["body"]),
+        ];
+        let outcome = resolve_search_params(&available, None, None, "db.public.t");
+        assert_eq!(outcome, Ok(("bm25".into(), "body".into())));
+    }
+
+    // With one index of each type, a type hint selects the matching one.
+    #[test]
+    fn resolve_search_params_hint_type_narrows_to_single() {
+        let available = [
+            make_index("fts", "bm25", &["description"]),
+            make_index("vec", "vector", &["embedding"]),
+        ];
+        let outcome = resolve_search_params(&available, Some("bm25"), None, "db.public.t");
+        assert_eq!(outcome, Ok(("bm25".into(), "description".into())));
+    }
+
+    // With two BM25 indexes, a column hint selects the one covering that column.
+    #[test]
+    fn resolve_search_params_hint_column_narrows_to_single() {
+        let available = [
+            make_index("fts_desc", "bm25", &["description"]),
+            make_index("fts_name", "bm25", &["name"]),
+        ];
+        let outcome = resolve_search_params(&available, None, Some("name"), "db.public.t");
+        assert_eq!(outcome, Ok(("bm25".into(), "name".into())));
+    }
+
+    // No searchable index at all: the error names both supported index types.
+    #[test]
+    fn resolve_search_params_no_search_indexes_returns_error() {
+        let available = [make_index("sorted_idx", "sorted", &["id"])];
+        let outcome = resolve_search_params(&available, None, None, "db.public.t");
+        let message = outcome.expect_err("resolution must fail without a search index");
+        assert!(message.contains("No BM25 or vector index found"));
+    }
+
+    // A type hint that matches nothing is echoed back in the error text.
+    #[test]
+    fn resolve_search_params_no_index_error_mentions_hint_type() {
+        let available = [make_index("fts", "bm25", &["description"])];
+        let outcome = resolve_search_params(&available, Some("vector"), None, "db.public.t");
+        let message = outcome.expect_err("resolution must fail without a vector index");
+        assert!(message.contains("vector index"));
+    }
+
+    // Two candidates and no hints: the ambiguity is reported instead of guessing.
+    #[test]
+    fn resolve_search_params_multiple_matches_returns_error() {
+        let available = [
+            make_index("fts_desc", "bm25", &["description"]),
+            make_index("fts_name", "bm25", &["name"]),
+        ];
+        let outcome = resolve_search_params(&available, None, None, "db.public.t");
+        let message = outcome.expect_err("resolution must fail when ambiguous");
+        assert!(message.contains("Multiple search indexes found"));
+    }
+
+    // A single matching index that lists no columns cannot yield a column.
+    #[test]
+    fn resolve_search_params_index_with_no_columns_returns_error() {
+        let available = [make_index("fts", "bm25", &[])];
+        let outcome = resolve_search_params(&available, None, None, "db.public.t");
+        let message = outcome.expect_err("resolution must fail for a column-less index");
+        assert!(message.contains("has no columns"));
+    }
 }
diff --git a/src/main.rs b/src/main.rs
index a26e7d9..aa07b0a 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -706,9 +706,46 @@ fn main() {
                 output,
             } => {
                 let workspace_id = resolve_workspace(workspace_id);
+
+                // Parse `connection.table` or `connection.schema.table`.
+                // Schema defaults to `public` when omitted.
+                let parts: Vec<&str> = table.splitn(4, '.').collect();
+                let (conn_name, schema, table_name) = match parts.as_slice() {
+                    [conn, schema, tbl] => {
+                        (conn.to_string(), schema.to_string(), tbl.to_string())
+                    }
+                    [conn, tbl] => (conn.to_string(), "public".to_string(), tbl.to_string()),
+                    _ => {
+                        eprintln!(
+                            "error: --table must be 'connection.table' or 'connection.schema.table'"
+                        );
+                        std::process::exit(1);
+                    }
+                };
+                let normalized_table = format!("{}.{}.{}", conn_name, schema, table_name);
+
+                // Infer --type and --column from the table's indexes when either is omitted.
+                let (resolved_type, resolved_column) = match (r#type, column) {
+                    // Both flags supplied: nothing to infer, so skip the index lookup.
+                    (Some(ty), Some(col)) => (ty, col),
+                    // At least one flag omitted: resolve it from the table's search indexes.
+                    (ty, col) => {
+                        let (inferred_type, inferred_column) = indexes::infer_for_search(
+                            &workspace_id,
+                            &conn_name,
+                            &schema,
+                            &table_name,
+                            ty.as_deref(),
+                            col.as_deref(),
+                        );
+                        // Explicit flags always win over inferred values.
+                        (ty.unwrap_or(inferred_type), col.unwrap_or(inferred_column))
+                    }
+                };
+
                 let select_cols = select.as_deref().unwrap_or("*");
 
-                let sql = match r#type.as_str() {
+                let sql = match resolved_type.as_str() {
                     "bm25" => {
                         let bm25_columns = match select.as_deref() {
                             Some(cols) => format!("{}, score", cols),
@@ -717,8 +754,8 @@ fn main() {
                         format!(
                             "SELECT {} FROM bm25_search('{}', '{}', '{}') ORDER BY score DESC LIMIT {}",
                             bm25_columns,
-                            table.replace('\'', "''"),
-                            column.replace('\'', "''"),
+                            normalized_table.replace('\'', "''"),
+                            resolved_column.replace('\'', "''"),
                             query.replace('\'', "''"),
                             limit,
                         )
@@ -728,9 +765,9 @@ fn main() {
                     "vector" => format!(
                         "SELECT {}, vector_distance({}, '{}') AS dist FROM {} ORDER BY dist LIMIT {}",
                         select_cols,
-                        column,
+                        resolved_column,
                         query.replace('\'', "''"),
-                        table,
+                        normalized_table,
                         limit,
                     ),
                     _ => unreachable!(),