From 92a454c8b2bbdbd5f7f1cb52b24c24d98fcdd4fc Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Wed, 17 Sep 2025 23:47:25 -0700 Subject: [PATCH] feat(query-handler): minor adjustment for the API --- python/cocoindex/flow.py | 5 ++++- python/cocoindex/query_handler.py | 2 ++ src/lib_context.rs | 4 ++-- src/py/mod.rs | 4 ++-- src/service/flows.rs | 4 ++-- src/service/query_handler.rs | 10 +++++++--- 6 files changed, 19 insertions(+), 10 deletions(-) diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index aa56c10b..84d4f664 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -874,7 +874,10 @@ def add_query_handler( async def _handler(query: str) -> dict[str, Any]: handler_result = await async_handler(query) return { - "results": dump_engine_object(handler_result.results), + "results": [ + [(k, dump_engine_object(v)) for (k, v) in result.items()] + for result in handler_result.results + ], "query_info": dump_engine_object(handler_result.query_info), } diff --git a/python/cocoindex/query_handler.py b/python/cocoindex/query_handler.py index 35c92850..da566128 100644 --- a/python/cocoindex/query_handler.py +++ b/python/cocoindex/query_handler.py @@ -2,6 +2,7 @@ import numpy as np from numpy import typing as npt from typing import Generic, TypeVar +from .index import VectorSimilarityMetric @dataclasses.dataclass @@ -30,6 +31,7 @@ class QueryInfo: """ embedding: list[float] | npt.NDArray[np.float32] | None = None + similarity_metric: VectorSimilarityMetric | None = None R = TypeVar("R") diff --git a/src/lib_context.rs b/src/lib_context.rs index e0675ae5..ce542162 100644 --- a/src/lib_context.rs +++ b/src/lib_context.rs @@ -5,7 +5,7 @@ use crate::prelude::*; use crate::builder::AnalyzedFlow; use crate::execution::source_indexer::SourceIndexingContext; use crate::service::error::ApiError; -use crate::service::query_handler::{QueryHandler, QueryHandlerInfo}; +use crate::service::query_handler::{QueryHandler, QueryHandlerSpec}; use crate::settings; use crate::setup::ObjectSetupChange; use axum::http::StatusCode; @@ -99,7 +99,7 @@ impl FlowExecutionContext { } pub struct QueryHandlerContext { - pub info: Arc, + pub info: Arc, pub handler: Arc, } diff --git a/src/py/mod.rs b/src/py/mod.rs index 666ee0c8..85821bd2 100644 --- a/src/py/mod.rs +++ b/src/py/mod.rs @@ -9,7 +9,7 @@ use crate::lib_context::{ use crate::ops::py_factory::{PyExportTargetFactory, PyOpArgSchema}; use crate::ops::{interface::ExecutorFactory, py_factory::PyFunctionFactory, register_factory}; use crate::server::{self, ServerSettings}; -use crate::service::query_handler::QueryHandlerInfo; +use crate::service::query_handler::QueryHandlerSpec; use crate::settings::Settings; use crate::setup::{self}; use pyo3::IntoPyObjectExt; @@ -438,7 +438,7 @@ impl Flow { &self, name: String, handler: Py, - handler_info: Pythonized>, + handler_info: Pythonized>, ) -> PyResult<()> { struct PyQueryHandler { handler: Py, diff --git a/src/service/flows.rs b/src/service/flows.rs index d151a05a..436eadab 100644 --- a/src/service/flows.rs +++ b/src/service/flows.rs @@ -2,7 +2,7 @@ use crate::prelude::*; use crate::execution::{evaluator, indexing_status, memoization, row_indexer, stats}; use crate::lib_context::LibContext; -use crate::service::query_handler::{QueryHandlerInfo, QueryInput, QueryOutput}; +use crate::service::query_handler::{QueryHandlerSpec, QueryInput, QueryOutput}; use crate::{base::schema::FlowSchema, ops::interface::SourceExecutorReadOptions}; use axum::{ Json, @@ -31,7 +31,7 @@ pub async fn get_flow_schema( pub struct GetFlowResponseData { flow_spec: spec::FlowInstanceSpec, data_schema: FlowSchema, - query_handlers_spec: HashMap>, + query_handlers_spec: HashMap>, } #[derive(Serialize)] diff --git a/src/service/query_handler.rs b/src/service/query_handler.rs index 6b6d4650..e2781490 100644 --- a/src/service/query_handler.rs +++ b/src/service/query_handler.rs @@ -1,4 +1,7 @@ -use crate::prelude::*; +use crate::{ + base::spec::{FieldName, VectorSimilarityMetric}, + prelude::*, +}; #[derive(Serialize, Deserialize, Default)] pub struct QueryHandlerResultFields { @@ -7,7 +10,7 @@ pub struct QueryHandlerResultFields { } #[derive(Serialize, Deserialize, Default)] -pub struct QueryHandlerInfo { +pub struct QueryHandlerSpec { #[serde(default)] result_fields: QueryHandlerResultFields, } @@ -20,11 +23,12 @@ pub struct QueryInput { #[derive(Serialize, Deserialize, Default)] pub struct QueryInfo { pub embedding: Option, + pub similarity_metric: Option, } #[derive(Serialize, Deserialize)] pub struct QueryOutput { - pub results: Vec>, + pub results: Vec>, pub query_info: QueryInfo, }