diff --git a/python/cocoindex/query.py b/python/cocoindex/query.py index befd643e..9b5f1056 100644 --- a/python/cocoindex/query.py +++ b/python/cocoindex/query.py @@ -66,15 +66,16 @@ def internal_handler(self) -> _engine.SimpleSemanticsQueryHandler: return self._lazy_query_handler() def search(self, query: str, limit: int, vector_field_name: str | None = None, - similarity_metric: index.VectorSimilarityMetric | None = None) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]: + similarity_metric: index.VectorSimilarityMetric | None = None + ) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]: """ Search the index with the given query, limit, vector field name, and similarity metric. """ internal_results, internal_info = self.internal_handler().search( query, limit, vector_field_name, similarity_metric.value if similarity_metric is not None else None) - fields = [field['name'] for field in internal_results['fields']] - results = [QueryResult(data=dict(zip(fields, result['data'])), score=result['score']) for result in internal_results['results']] + results = [QueryResult(data=result['data'], score=result['score']) + for result in internal_results] info = SimpleSemanticsQueryInfo( similarity_metric=index.VectorSimilarityMetric(internal_info['similarity_metric']), query_vector=internal_info['query_vector'], diff --git a/src/ops/interface.rs b/src/ops/interface.rs index 0f12da50..357592b2 100644 --- a/src/ops/interface.rs +++ b/src/ops/interface.rs @@ -228,20 +228,44 @@ pub struct VectorMatchQuery { } #[derive(Debug, Clone, Serialize)] -pub struct QueryResult { - pub data: Vec, +pub struct QueryResult> { + pub data: Row, pub score: f64, } #[derive(Debug, Clone, Serialize)] -pub struct QueryResults { +pub struct QueryResults> { pub fields: Vec, - pub results: Vec, + pub results: Vec>, } +impl TryFrom>> for QueryResults { + type Error = anyhow::Error; + + fn try_from(values: QueryResults>) -> Result { + let results = values + .results + .into_iter() + .map(|r| { + let data = serde_json::to_value(TypedFieldsValue { + schema: &values.fields, + values_iter: r.data.iter(), + })?; + Ok(QueryResult { + data, + score: r.score, + }) + }) + .collect::>>()?; + Ok(QueryResults { + fields: values.fields, + results, + }) + } +} #[derive(Debug, Clone, Serialize)] pub struct QueryResponse { - pub results: QueryResults, + pub results: QueryResults, pub info: serde_json::Value, } diff --git a/src/py/mod.rs b/src/py/mod.rs index 21400d4d..081569fa 100644 --- a/src/py/mod.rs +++ b/src/py/mod.rs @@ -3,7 +3,7 @@ use crate::prelude::*; use crate::base::spec::VectorSimilarityMetric; use crate::execution::query; use crate::lib_context::{clear_lib_context, get_auth_registry, init_lib_context}; -use crate::ops::interface::QueryResults; +use crate::ops::interface::{QueryResult, QueryResults}; use crate::ops::py_factory::PyOpArgSchema; use crate::ops::{interface::ExecutorFactory, py_factory::PyFunctionFactory, register_factory}; use crate::server::{self, ServerSettings}; @@ -264,24 +264,24 @@ impl SimpleSemanticsQueryHandler { vector_field_name: Option, similarity_metric: Option>, ) -> PyResult<( - Pythonized, + Pythonized>>, Pythonized, )> { py.allow_threads(|| { - let (results, info) = get_runtime() - .block_on(async move { - self.0 - .search( - query, - limit, - vector_field_name, - similarity_metric.map(|m| m.0), - ) - .await - }) - .into_py_result()?; - Ok((Pythonized(results), Pythonized(info))) + let (results, info) = get_runtime().block_on(async move { + self.0 + .search( + query, + limit, + vector_field_name, + similarity_metric.map(|m| m.0), + ) + .await + })?; + let results = QueryResults::::try_from(results)?; + anyhow::Ok((Pythonized(results.results), Pythonized(info))) }) + .into_py_result() } } diff --git a/src/service/search.rs b/src/service/search.rs index 449b9947..d42d1f2d 100644 --- a/src/service/search.rs +++ b/src/service/search.rs @@ -1,17 +1,12 @@ -use std::sync::Arc; +use crate::prelude::*; use axum::extract::Path; use axum::http::StatusCode; -use serde::Deserialize; -use axum::{extract::State, Json}; -use axum_extra::extract::Query; - -use crate::base::spec; use crate::lib_context::LibContext; use crate::ops::interface::QueryResponse; - -use super::error::ApiError; +use axum::{extract::State, Json}; +use axum_extra::extract::Query; #[derive(Debug, Deserialize)] pub struct SearchParams { @@ -51,7 +46,7 @@ pub async fn search( .search(query.query, query.limit, query.field, query.metric) .await?; let response = QueryResponse { - results, + results: results.try_into()?, info: serde_json::to_value(info).map_err(|e| { ApiError::new( &format!("Failed to serialize query info: {e}"),