Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions python/cocoindex/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,16 @@ def internal_handler(self) -> _engine.SimpleSemanticsQueryHandler:
return self._lazy_query_handler()

def search(self, query: str, limit: int, vector_field_name: str | None = None,
similarity_metric: index.VectorSimilarityMetric | None = None) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
similarity_metric: index.VectorSimilarityMetric | None = None
) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
"""
Search the index with the given query, limit, vector field name, and similarity metric.
"""
internal_results, internal_info = self.internal_handler().search(
query, limit, vector_field_name,
similarity_metric.value if similarity_metric is not None else None)
fields = [field['name'] for field in internal_results['fields']]
results = [QueryResult(data=dict(zip(fields, result['data'])), score=result['score']) for result in internal_results['results']]
results = [QueryResult(data=result['data'], score=result['score'])
for result in internal_results]
info = SimpleSemanticsQueryInfo(
similarity_metric=index.VectorSimilarityMetric(internal_info['similarity_metric']),
query_vector=internal_info['query_vector'],
Expand Down
34 changes: 29 additions & 5 deletions src/ops/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,20 +228,44 @@ pub struct VectorMatchQuery {
}

#[derive(Debug, Clone, Serialize)]
pub struct QueryResult {
pub data: Vec<Value>,
pub struct QueryResult<Row = Vec<Value>> {
pub data: Row,
pub score: f64,
}

#[derive(Debug, Clone, Serialize)]
pub struct QueryResults {
pub struct QueryResults<Row = Vec<Value>> {
pub fields: Vec<FieldSchema>,
pub results: Vec<QueryResult>,
pub results: Vec<QueryResult<Row>>,
}

impl TryFrom<QueryResults<Vec<Value>>> for QueryResults<serde_json::Value> {
type Error = anyhow::Error;

fn try_from(values: QueryResults<Vec<Value>>) -> Result<Self, Self::Error> {
let results = values
.results
.into_iter()
.map(|r| {
let data = serde_json::to_value(TypedFieldsValue {
schema: &values.fields,
values_iter: r.data.iter(),
})?;
Ok(QueryResult {
data,
score: r.score,
})
})
.collect::<Result<Vec<_>>>()?;
Ok(QueryResults {
fields: values.fields,
results,
})
}
}
#[derive(Debug, Clone, Serialize)]
pub struct QueryResponse {
pub results: QueryResults,
pub results: QueryResults<serde_json::Value>,
pub info: serde_json::Value,
}

Expand Down
30 changes: 15 additions & 15 deletions src/py/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::prelude::*;
use crate::base::spec::VectorSimilarityMetric;
use crate::execution::query;
use crate::lib_context::{clear_lib_context, get_auth_registry, init_lib_context};
use crate::ops::interface::QueryResults;
use crate::ops::interface::{QueryResult, QueryResults};
use crate::ops::py_factory::PyOpArgSchema;
use crate::ops::{interface::ExecutorFactory, py_factory::PyFunctionFactory, register_factory};
use crate::server::{self, ServerSettings};
Expand Down Expand Up @@ -264,24 +264,24 @@ impl SimpleSemanticsQueryHandler {
vector_field_name: Option<String>,
similarity_metric: Option<Pythonized<VectorSimilarityMetric>>,
) -> PyResult<(
Pythonized<QueryResults>,
Pythonized<Vec<QueryResult<serde_json::Value>>>,
Pythonized<query::SimpleSemanticsQueryInfo>,
)> {
py.allow_threads(|| {
let (results, info) = get_runtime()
.block_on(async move {
self.0
.search(
query,
limit,
vector_field_name,
similarity_metric.map(|m| m.0),
)
.await
})
.into_py_result()?;
Ok((Pythonized(results), Pythonized(info)))
let (results, info) = get_runtime().block_on(async move {
self.0
.search(
query,
limit,
vector_field_name,
similarity_metric.map(|m| m.0),
)
.await
})?;
let results = QueryResults::<serde_json::Value>::try_from(results)?;
anyhow::Ok((Pythonized(results.results), Pythonized(info)))
})
.into_py_result()
}
}

Expand Down
13 changes: 4 additions & 9 deletions src/service/search.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
use std::sync::Arc;
use crate::prelude::*;

use axum::extract::Path;
use axum::http::StatusCode;
use serde::Deserialize;

use axum::{extract::State, Json};
use axum_extra::extract::Query;

use crate::base::spec;
use crate::lib_context::LibContext;
use crate::ops::interface::QueryResponse;

use super::error::ApiError;
use axum::{extract::State, Json};
use axum_extra::extract::Query;

#[derive(Debug, Deserialize)]
pub struct SearchParams {
Expand Down Expand Up @@ -51,7 +46,7 @@ pub async fn search(
.search(query.query, query.limit, query.field, query.metric)
.await?;
let response = QueryResponse {
results,
results: results.try_into()?,
info: serde_json::to_value(info).map_err(|e| {
ApiError::new(
&format!("Failed to serialize query info: {e}"),
Expand Down