Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,24 +158,28 @@ pub enum Commands {
/// Search query text — required for both --type bm25 and --type vector
query: String,

/// Search type — required (no default; choose deliberately)
/// Search type (`bm25` or `vector`). Inferred automatically when the table has exactly
/// one search index — required only when multiple indexes exist.
///
/// `vector` runs server-side `vector_distance(col, 'text')` — the server resolves the
/// embedding column, model, and metric from the index metadata.
///
/// `bm25` runs server-side `bm25_search(table, col, 'text')` and requires a BM25 index
/// on the column.
#[arg(long, value_parser = ["vector", "bm25"])]
r#type: String,
r#type: Option<String>,

/// Table to search (connection.schema.table)
/// Table to search (`connection.table` or `connection.schema.table`).
/// Schema defaults to `public` when omitted.
#[arg(long)]
table: String,

/// Column to search. For `--type vector`, name the source text column — the server
/// resolves the embedding column from the index metadata.
/// Column to search. Inferred automatically when the table has exactly one search index
/// of the resolved type — required only when multiple indexed columns exist.
/// For `--type vector`, name the source text column — the server resolves the embedding
/// column from the index metadata.
#[arg(long)]
column: String,
column: Option<String>,

/// Columns to display (comma-separated, defaults to all)
#[arg(long)]
Expand Down
189 changes: 189 additions & 0 deletions src/indexes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,104 @@ fn list_one_table_scan(
}
}

/// Pure matching logic for search inference — extracted for testability.
///
/// Keeps only the searchable entries of `indexes` (`bm25`, `vector`), narrows them by
/// `hint_type` / `hint_column` when those are provided, and resolves a single
/// `(index_type, column)` pair.
///
/// # Returns
/// `Ok((index_type, column))` when exactly one index matches. The column is `hint_column`
/// when one was given (the filter guarantees the matched index covers it); otherwise it is
/// the first column of the matched index.
///
/// # Errors
/// Returns `Err(message)` when no index matches, when several indexes match, or when the
/// single match has no columns. `location` is used only in the "no index" message
/// (e.g. `"mydb.public.listings"`).
fn resolve_search_params(
    indexes: &[Index],
    hint_type: Option<&str>,
    hint_column: Option<&str>,
    location: &str,
) -> Result<(String, String), String> {
    let matches: Vec<&Index> = indexes
        .iter()
        .filter(|i| {
            let t = i.index_type.as_str();
            (t == "bm25" || t == "vector")
                && hint_type.map_or(true, |ht| ht == t)
                && hint_column.map_or(true, |hc| i.columns.iter().any(|c| c == hc))
        })
        .collect();

    match matches.as_slice() {
        [] => {
            let what = match hint_type {
                Some(t) => format!("{} index", t),
                None => "BM25 or vector index".to_string(),
            };
            Err(format!(
                "No {} found on {} — run 'hotdata indexes create' first.",
                what, location
            ))
        }
        [one] => {
            // Honour the caller's column hint: for a multi-column index the hinted column
            // is not necessarily the first one, and returning `columns.first()` would
            // contradict what the caller asked for.
            let column = match hint_column {
                Some(hc) => hc.to_string(),
                None => one
                    .columns
                    .first()
                    .cloned()
                    .ok_or_else(|| format!("Index '{}' has no columns.", one.index_name))?,
            };
            Ok((one.index_type.clone(), column))
        }
        _ => {
            // Collect distinct types/columns in first-seen order so the message does not
            // repeat itself (e.g. two BM25 indexes used to print "bm25, bm25").
            let mut types: Vec<&str> = Vec::new();
            let mut cols: Vec<&str> = Vec::new();
            for index in &matches {
                let t = index.index_type.as_str();
                if !types.contains(&t) {
                    types.push(t);
                }
                for c in index.columns.iter() {
                    let c = c.as_str();
                    if !cols.contains(&c) {
                        cols.push(c);
                    }
                }
            }
            Err(format!(
                "Multiple search indexes found (types: {}, columns: {}) — specify --type and --column.",
                types.join(", "),
                cols.join(", ")
            ))
        }
    }
}

/// Resolves the `(index_type, column)` pair for `hotdata search` when `--type` and/or
/// `--column` were not supplied.
///
/// Looks up `connection_name` in the map returned by `connection_lookup`, lists the indexes
/// on `schema.table` for that connection, and delegates the matching to
/// `resolve_search_params`, passing `hint_type` / `hint_column` through unchanged.
///
/// Does not return on failure: prints a red message to stderr and exits the process with
/// status 1 when the connection name is unknown or when resolution yields an error
/// (no matching index, ambiguous match, or an index without columns).
pub fn infer_for_search(
    workspace_id: &str,
    connection_name: &str,
    schema: &str,
    table: &str,
    hint_type: Option<&str>,
    hint_column: Option<&str>,
) -> (String, String) {
    use crossterm::style::Stylize;

    let api = ApiClient::new(Some(workspace_id));

    // Translate the user-facing connection name into its ID.
    let conn_map = connection_lookup(&api);
    let connection_id = if let Some(id) = conn_map.get(connection_name) {
        id.clone()
    } else {
        eprintln!(
            "{}",
            format!("Connection '{}' not found.", connection_name).red()
        );
        std::process::exit(1);
    };

    // Load the table's indexes, then let the pure resolver pick the match.
    let indexes = list_one_table(&api, &connection_id, schema, table);
    let location = format!("{}.{}.{}", connection_name, schema, table);

    resolve_search_params(&indexes, hint_type, hint_column, &location).unwrap_or_else(|msg| {
        eprintln!("{}", msg.red());
        std::process::exit(1)
    })
}

pub fn list(
workspace_id: &str,
connection_id: Option<&str>,
Expand Down Expand Up @@ -574,4 +672,95 @@ mod tests {
mock.assert();
assert!(rows.is_empty());
}

/// Test fixture: builds an `Index` with the given name, type and columns.
///
/// The remaining fields are fixed placeholders (`metric` unset, status `"ready"`, and the
/// same constant timestamp for creation and update).
fn make_index(name: &str, index_type: &str, columns: &[&str]) -> Index {
    let timestamp = "2020-01-01T00:00:00Z";
    let columns = columns.iter().map(|col| col.to_string()).collect();

    Index {
        index_name: name.into(),
        index_type: index_type.into(),
        columns,
        metric: None,
        status: "ready".into(),
        created_at: timestamp.into(),
        updated_at: timestamp.into(),
    }
}

#[test]
fn resolve_search_params_single_bm25_returns_type_and_column() {
    // A lone BM25 index resolves without any hints.
    let only_bm25 = [make_index("fts", "bm25", &["description"])];
    let expected = (String::from("bm25"), String::from("description"));

    assert_eq!(
        resolve_search_params(&only_bm25, None, None, "db.public.t"),
        Ok(expected)
    );
}

#[test]
fn resolve_search_params_single_vector_returns_type_and_column() {
    // A lone vector index resolves without any hints.
    let only_vector = [make_index("vec", "vector", &["embedding"])];
    let expected = (String::from("vector"), String::from("embedding"));

    assert_eq!(
        resolve_search_params(&only_vector, None, None, "db.public.t"),
        Ok(expected)
    );
}

#[test]
fn resolve_search_params_non_search_indexes_ignored() {
    // The `sorted` index is not searchable, so only the BM25 one is a candidate.
    let sorted = make_index("sorted_idx", "sorted", &["created_at"]);
    let bm25 = make_index("fts", "bm25", &["body"]);
    let mixed = [sorted, bm25];

    let outcome = resolve_search_params(&mixed, None, None, "db.public.t");

    assert_eq!(outcome, Ok((String::from("bm25"), String::from("body"))));
}

#[test]
fn resolve_search_params_hint_type_narrows_to_single() {
    // Two searchable indexes of different types; the type hint selects one.
    let both_types = [
        make_index("fts", "bm25", &["description"]),
        make_index("vec", "vector", &["embedding"]),
    ];

    let outcome = resolve_search_params(&both_types, Some("bm25"), None, "db.public.t");

    let expected = (String::from("bm25"), String::from("description"));
    assert_eq!(outcome, Ok(expected));
}

#[test]
fn resolve_search_params_hint_column_narrows_to_single() {
    // Two BM25 indexes on different columns; the column hint selects one.
    let two_bm25 = [
        make_index("fts_desc", "bm25", &["description"]),
        make_index("fts_name", "bm25", &["name"]),
    ];

    let outcome = resolve_search_params(&two_bm25, None, Some("name"), "db.public.t");

    let expected = (String::from("bm25"), String::from("name"));
    assert_eq!(outcome, Ok(expected));
}

#[test]
fn resolve_search_params_no_search_indexes_returns_error() {
    // Only a non-searchable index exists, so resolution must fail.
    let sorted_only = [make_index("sorted_idx", "sorted", &["id"])];

    match resolve_search_params(&sorted_only, None, None, "db.public.t") {
        Err(message) => assert!(message.contains("No BM25 or vector index found")),
        Ok(pair) => panic!("expected an error, got {:?}", pair),
    }
}

#[test]
fn resolve_search_params_no_index_error_mentions_hint_type() {
    // Asking for `vector` when only BM25 exists must name the requested type.
    let only_bm25 = [make_index("fts", "bm25", &["description"])];

    match resolve_search_params(&only_bm25, Some("vector"), None, "db.public.t") {
        Err(message) => assert!(message.contains("vector index")),
        Ok(pair) => panic!("expected an error, got {:?}", pair),
    }
}

#[test]
fn resolve_search_params_multiple_matches_returns_error() {
    // Two BM25 indexes and no hints: the choice is ambiguous.
    let two_bm25 = [
        make_index("fts_desc", "bm25", &["description"]),
        make_index("fts_name", "bm25", &["name"]),
    ];

    match resolve_search_params(&two_bm25, None, None, "db.public.t") {
        Err(message) => assert!(message.contains("Multiple search indexes found")),
        Ok(pair) => panic!("expected an error, got {:?}", pair),
    }
}

#[test]
fn resolve_search_params_index_with_no_columns_returns_error() {
    // A matching index that lists no columns cannot yield a column to search.
    let columnless = [make_index("fts", "bm25", &[])];

    match resolve_search_params(&columnless, None, None, "db.public.t") {
        Err(message) => assert!(message.contains("has no columns")),
        Ok(pair) => panic!("expected an error, got {:?}", pair),
    }
}
}
47 changes: 42 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,9 +706,46 @@ fn main() {
output,
} => {
let workspace_id = resolve_workspace(workspace_id);

// Parse `connection.table` or `connection.schema.table`.
// Schema defaults to `public` when omitted.
let parts: Vec<&str> = table.splitn(4, '.').collect();
let (conn_name, schema, table_name) = match parts.as_slice() {
[conn, schema, tbl] => {
(conn.to_string(), schema.to_string(), tbl.to_string())
}
[conn, tbl] => (conn.to_string(), "public".to_string(), tbl.to_string()),
_ => {
eprintln!(
"error: --table must be 'connection.table' or 'connection.schema.table'"
);
std::process::exit(1);
}
};
let normalized_table = format!("{}.{}.{}", conn_name, schema, table_name);

// Infer --type and --column from the table's indexes when either is omitted.
let (resolved_type, resolved_column) =
if r#type.is_some() && column.is_some() {
(r#type.unwrap(), column.unwrap())
} else {
let (inferred_type, inferred_column) = indexes::infer_for_search(
&workspace_id,
&conn_name,
&schema,
&table_name,
r#type.as_deref(),
column.as_deref(),
);
(
r#type.unwrap_or(inferred_type),
column.unwrap_or(inferred_column),
)
};

let select_cols = select.as_deref().unwrap_or("*");

let sql = match r#type.as_str() {
let sql = match resolved_type.as_str() {
"bm25" => {
let bm25_columns = match select.as_deref() {
Some(cols) => format!("{}, score", cols),
Expand All @@ -717,8 +754,8 @@ fn main() {
format!(
"SELECT {} FROM bm25_search('{}', '{}', '{}') ORDER BY score DESC LIMIT {}",
bm25_columns,
table.replace('\'', "''"),
column.replace('\'', "''"),
normalized_table.replace('\'', "''"),
resolved_column.replace('\'', "''"),
query.replace('\'', "''"),
limit,
)
Expand All @@ -728,9 +765,9 @@ fn main() {
"vector" => format!(
"SELECT {}, vector_distance({}, '{}') AS dist FROM {} ORDER BY dist LIMIT {}",
select_cols,
column,
resolved_column,
query.replace('\'', "''"),
table,
normalized_table,
limit,
),
_ => unreachable!(),
Expand Down
Loading