From 7286cd92678c6cead0ccf41bda4960dd290b1b08 Mon Sep 17 00:00:00 2001
From: LJ
Date: Sun, 30 Mar 2025 18:13:58 -0700
Subject: [PATCH] Add `SourceIndexingState` into `FlowContext`.

---
 src/builder/flow_builder.rs     | 23 +++++------
 src/execution/source_indexer.rs | 29 ++++++--------
 src/lib_context.rs              | 71 ++++++++++++++++++++++++---------
 src/prelude.rs                  |  5 ++-
 src/py/mod.rs                   | 31 +++++++-------
 src/service/flows.rs            | 29 +++++++-------
 src/service/search.rs           | 47 +++++++++-------------
 src/setup/driver.rs             | 12 +++---
 8 files changed, 136 insertions(+), 111 deletions(-)

diff --git a/src/builder/flow_builder.rs b/src/builder/flow_builder.rs
index 6b451874..6c49ff1b 100644
--- a/src/builder/flow_builder.rs
+++ b/src/builder/flow_builder.rs
@@ -1,9 +1,9 @@
-use anyhow::{anyhow, bail, Result};
+use crate::prelude::*;
+
 use pyo3::{exceptions::PyException, prelude::*};
 use std::{
     collections::{btree_map, hash_map::Entry, HashMap},
     ops::Deref,
-    sync::{Arc, Mutex, Weak},
 };
 
 use super::analyzer::{
@@ -11,10 +11,9 @@
     ExecutionScope, ValueTypeBuilder,
 };
 use crate::{
-    api_bail,
     base::{
-        schema::{self, CollectorSchema, FieldSchema},
-        spec::{self, FieldName, NamedSpec},
+        schema::{CollectorSchema, FieldSchema},
+        spec::{FieldName, NamedSpec},
     },
     get_lib_context,
     lib_context::LibContext,
@@ -649,10 +648,8 @@ impl FlowBuilder {
                 ))
             })
             .into_py_result()?;
-        let analyzed_flow = Arc::new(analyzed_flow);
-
-        let mut analyzed_flows = self.lib_context.flows.write().unwrap();
-        match analyzed_flows.entry(self.flow_instance_name.clone()) {
+        let mut flow_ctxs = self.lib_context.flows.lock().unwrap();
+        let flow_ctx = match flow_ctxs.entry(self.flow_instance_name.clone()) {
             btree_map::Entry::Occupied(_) => {
                 return Err(PyException::new_err(format!(
                     "flow instance name already exists: {}",
@@ -660,10 +657,12 @@
                 )));
             }
             btree_map::Entry::Vacant(entry) => {
-                entry.insert(FlowContext::new(analyzed_flow.clone()));
+                let flow_ctx = Arc::new(FlowContext::new(Arc::new(analyzed_flow)));
+                entry.insert(flow_ctx.clone());
+                flow_ctx
             }
-        }
-        Ok(py::Flow(analyzed_flow))
+        };
+        Ok(py::Flow(flow_ctx))
     }
 
     pub fn build_transient_flow(&self, py: Python<'_>) -> PyResult {

diff --git a/src/execution/source_indexer.rs b/src/execution/source_indexer.rs
index 40bb82d1..530fb23a 100644
--- a/src/execution/source_indexer.rs
+++ b/src/execution/source_indexer.rs
@@ -10,13 +10,13 @@ use super::{
 use futures::future::try_join_all;
 use sqlx::PgPool;
 use tokio::{sync::Semaphore, task::JoinSet};
-struct SourceRowState {
+struct SourceRowIndexingState {
     source_version: SourceVersion,
     processing_sem: Arc<Semaphore>,
     touched_generation: usize,
 }
 
-impl Default for SourceRowState {
+impl Default for SourceRowIndexingState {
     fn default() -> Self {
         Self {
             source_version: SourceVersion::default(),
@@ -26,17 +26,17 @@
     }
 }
 
-struct SourceState {
-    rows: HashMap<value::KeyValue, SourceRowState>,
+struct SourceIndexingState {
+    rows: HashMap<value::KeyValue, SourceRowIndexingState>,
     scan_generation: usize,
 }
 
-pub struct SourceContext {
+pub struct SourceIndexingContext {
     flow: Arc<builder::AnalyzedFlow>,
     source_idx: usize,
-    state: Mutex<SourceState>,
+    state: Mutex<SourceIndexingState>,
 }
-impl SourceContext {
+impl SourceIndexingContext {
     pub async fn load(
         flow: Arc<builder::AnalyzedFlow>,
         source_idx: usize,
@@ -58,7 +58,7 @@ impl SourceContext {
                 .into_key()?;
             rows.insert(
                 source_key,
-                SourceRowState {
+                SourceRowIndexingState {
                     source_version: SourceVersion::from_stored(
                         key_metadata.processed_source_ordinal,
                         &key_metadata.process_logic_fingerprint,
@@ -72,7 +72,7 @@ impl SourceContext {
         Ok(Self {
             flow,
             source_idx,
-            state: Mutex::new(SourceState {
+            state: Mutex::new(SourceIndexingState {
                 rows,
                 scan_generation,
             }),
@@ -144,7 +144,7 @@
                 }
             }
             hash_map::Entry::Vacant(entry) => {
-                entry.insert(SourceRowState {
+                entry.insert(SourceRowIndexingState {
                     source_version: target_source_version,
                     touched_generation: scan_generation,
                     ..Default::default()
@@ -259,15 +259,12 @@
     }
 }
 
-pub async fn update(
-    flow: &Arc<builder::AnalyzedFlow>,
-    pool: &PgPool,
-) -> Result {
-    let plan = flow.get_execution_plan().await?;
+pub async fn update(flow_context: &FlowContext, pool: &PgPool) -> Result {
+    let plan = flow_context.flow.get_execution_plan().await?;
     let source_update_stats = try_join_all(
         (0..plan.source_ops.len())
             .map(|idx| async move {
-                let source_context = Arc::new(SourceContext::load(flow.clone(), idx, pool).await?);
+                let source_context = flow_context.get_source_indexing_context(idx, pool).await?;
                 source_context.update_source(pool).await
             })
            .collect::<Vec<_>>(),

diff --git a/src/lib_context.rs b/src/lib_context.rs
index f27ac7e0..becb2a9e 100644
--- a/src/lib_context.rs
+++ b/src/lib_context.rs
@@ -1,50 +1,85 @@
+use crate::prelude::*;
+
 use std::collections::BTreeMap;
 use std::sync::{Arc, RwLock};
 
+use crate::execution::source_indexer::SourceIndexingContext;
 use crate::service::error::ApiError;
 use crate::settings;
 use crate::setup;
 use crate::{builder::AnalyzedFlow, execution::query::SimpleSemanticsQueryHandler};
-use anyhow::Result;
+use async_lock::OnceCell;
 use axum::http::StatusCode;
 use sqlx::PgPool;
 use tokio::runtime::Runtime;
 
 pub struct FlowContext {
     pub flow: Arc<AnalyzedFlow>,
-    pub query_handlers: BTreeMap<String, Arc<SimpleSemanticsQueryHandler>>,
+    pub source_indexing_contexts: Vec<OnceCell<Arc<SourceIndexingContext>>>,
+    pub query_handlers: Mutex<BTreeMap<String, Arc<SimpleSemanticsQueryHandler>>>,
 }
 
 impl FlowContext {
     pub fn new(flow: Arc<AnalyzedFlow>) -> Self {
+        let mut source_indexing_contexts = Vec::new();
+        source_indexing_contexts
+            .resize_with(flow.flow_instance.source_ops.len(), || OnceCell::new());
         Self {
             flow,
-            query_handlers: BTreeMap::new(),
+            source_indexing_contexts,
+            query_handlers: Mutex::new(BTreeMap::new()),
         }
     }
+
+    pub async fn get_source_indexing_context(
+        &self,
+        source_idx: usize,
+        pool: &PgPool,
+    ) -> Result<&Arc<SourceIndexingContext>> {
+        self.source_indexing_contexts[source_idx]
+            .get_or_try_init(|| async move {
+                Ok(Arc::new(
+                    SourceIndexingContext::load(self.flow.clone(), source_idx, pool).await?,
+                ))
+            })
+            .await
+    }
+
+    pub fn get_query_handler(&self, name: &str) -> Result<Arc<SimpleSemanticsQueryHandler>> {
+        let query_handlers = self.query_handlers.lock().unwrap();
+        let query_handler = query_handlers
+            .get(name)
+            .ok_or_else(|| {
+                ApiError::new(
+                    &format!("Query handler not found: {name}"),
+                    StatusCode::NOT_FOUND,
+                )
+            })?
+            .clone();
+        Ok(query_handler)
+    }
 }
 
 pub struct LibContext {
     pub runtime: Runtime,
     pub pool: PgPool,
-    pub flows: RwLock<BTreeMap<String, FlowContext>>,
+    pub flows: Mutex<BTreeMap<String, Arc<FlowContext>>>,
     pub combined_setup_states: RwLock<setup::AllSetupState<setup::ExistingMode>>,
 }
 
 impl LibContext {
-    pub fn with_flow_context<R>(
-        &self,
-        flow_name: &str,
-        f: impl FnOnce(&FlowContext) -> R,
-    ) -> Result<R> {
-        let flows = self.flows.read().unwrap();
-        let flow_context = flows.get(flow_name).ok_or_else(|| {
-            ApiError::new(
-                &format!("Flow instance not found: {flow_name}"),
-                StatusCode::NOT_FOUND,
-            )
-        })?;
-        Ok(f(flow_context))
+    pub fn get_flow_context(&self, flow_name: &str) -> Result<Arc<FlowContext>> {
+        let flows = self.flows.lock().unwrap();
+        let flow_ctx = flows
+            .get(flow_name)
+            .ok_or_else(|| {
+                ApiError::new(
+                    &format!("Flow instance not found: {flow_name}"),
+                    StatusCode::NOT_FOUND,
+                )
+            })?
+            .clone();
+        Ok(flow_ctx)
     }
 }
 
@@ -62,6 +97,6 @@ pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
         runtime,
         pool,
         combined_setup_states: RwLock::new(all_css),
-        flows: RwLock::new(BTreeMap::new()),
+        flows: Mutex::new(BTreeMap::new()),
     })
 }

diff --git a/src/prelude.rs b/src/prelude.rs
index 9efd0362..e47a6b8c 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -6,13 +6,16 @@
 pub(crate) use futures::{future::BoxFuture, prelude::*, stream::BoxStream};
 pub(crate) use futures::{FutureExt, StreamExt};
 pub(crate) use itertools::Itertools;
 pub(crate) use serde::{Deserialize, Serialize};
-pub(crate) use std::sync::{Arc, Mutex};
+pub(crate) use std::sync::{Arc, Mutex, Weak};
 
 pub(crate) use crate::base::{schema, spec, value};
 pub(crate) use crate::builder::{self, plan};
+pub(crate) use crate::execution;
+pub(crate) use crate::lib_context::{FlowContext, LibContext};
 pub(crate) use crate::ops::interface;
 pub(crate) use crate::service::error::ApiError;
 pub(crate) use crate::{api_bail, api_error};
+pub(crate) use anyhow::{anyhow, bail};
 
 pub(crate) use log::{debug, error, info, trace, warn};

diff --git a/src/py/mod.rs b/src/py/mod.rs
index 597c2cbd..4cb63c9c 100644
--- a/src/py/mod.rs
+++ b/src/py/mod.rs
@@ -1,3 +1,5 @@
+use crate::prelude::*;
+
 use crate::base::spec::VectorSimilarityMetric;
 use crate::execution::query;
 use crate::get_lib_context;
@@ -7,13 +9,10 @@ use crate::ops::py_factory::PyOpArgSchema;
 use crate::ops::{interface::ExecutorFactory, py_factory::PyFunctionFactory, register_factory};
 use crate::server::{self, ServerSettings};
 use crate::settings::Settings;
+use crate::setup;
 use crate::LIB_CONTEXT;
-use crate::{api_error, setup};
-use crate::{builder, execution};
-use anyhow::anyhow;
 use pyo3::{exceptions::PyException, prelude::*};
 use std::collections::btree_map;
-use std::sync::Arc;
 
 mod convert;
 pub use convert::*;
@@ -91,12 +90,12 @@ impl IndexUpdateInfo {
 }
 
 #[pyclass]
-pub struct Flow(pub Arc<builder::AnalyzedFlow>);
+pub struct Flow(pub Arc<FlowContext>);
 
 #[pymethods]
 impl Flow {
     pub fn __str__(&self) -> String {
-        serde_json::to_string_pretty(&self.0.flow_instance).unwrap()
+        serde_json::to_string_pretty(&self.0.flow.flow_instance).unwrap()
     }
 
     pub fn __repr__(&self) -> String {
@@ -104,7 +103,7 @@ impl Flow {
     }
 
     pub fn name(&self) -> &str {
-        &self.0.flow_instance.name
+        &self.0.flow.flow_instance.name
     }
 
     pub fn update(&self, py: Python<'_>) -> PyResult {
@@ -132,10 +131,10 @@
         lib_context
             .runtime
             .block_on(async {
-                let exec_plan = self.0.get_execution_plan().await?;
+                let exec_plan = self.0.flow.get_execution_plan().await?;
                 execution::dumper::evaluate_and_dump(
                     &exec_plan,
-                    &self.0.data_schema,
+                    &self.0.flow.data_schema,
                     options.into_inner(),
                     &lib_context.pool,
                 )
@@ -181,7 +180,7 @@
         let handler = lib_context
             .runtime
             .block_on(query::SimpleSemanticsQueryHandler::new(
-                flow.0.clone(),
+                flow.0.flow.clone(),
                 target_name,
                 query_transform_flow.0.clone(),
                 default_similarity_metric.0,
@@ -194,11 +193,11 @@
     pub fn register_query_handler(&self, name: String) -> PyResult<()> {
         let lib_context = get_lib_context()
             .ok_or_else(|| PyException::new_err("cocoindex library not initialized"))?;
-        let mut flows = lib_context.flows.write().unwrap();
-        let flow_ctx = flows
-            .get_mut(&self.0.flow_name)
-            .ok_or_else(|| PyException::new_err(format!("flow not found: {}", self.0.flow_name)))?;
-        match flow_ctx.query_handlers.entry(name) {
+        let flow_ctx = lib_context
+            .get_flow_context(&self.0.flow_name)
+            .into_py_result()?;
+        let mut query_handlers = flow_ctx.query_handlers.lock().unwrap();
+        match query_handlers.entry(name) {
             btree_map::Entry::Occupied(entry) => {
                 return Err(PyException::new_err(format!(
                     "query handler name already exists: {}",
@@ -270,8 +269,8 @@ fn check_setup_status(
 ) -> PyResult<SetupStatusCheck> {
     let lib_context = get_lib_context()
         .ok_or_else(|| PyException::new_err("cocoindex library not initialized"))?;
+    let flows = lib_context.flows.lock().unwrap();
     let all_css = lib_context.combined_setup_states.read().unwrap();
-    let flows = lib_context.flows.read().unwrap();
     let setup_status =
         setup::check_setup_status(&flows, &all_css, options.into_inner()).into_py_result()?;
     Ok(SetupStatusCheck(setup_status))

diff --git a/src/service/flows.rs b/src/service/flows.rs
index c324db57..db451177 100644
--- a/src/service/flows.rs
+++ b/src/service/flows.rs
@@ -17,7 +17,7 @@ pub async fn list_flows(
     State(lib_context): State<Arc<LibContext>>,
 ) -> Result<Json<Vec<String>>, ApiError> {
     Ok(Json(
-        lib_context.flows.read().unwrap().keys().cloned().collect(),
+        lib_context.flows.lock().unwrap().keys().cloned().collect(),
     ))
 }
 
@@ -25,16 +25,16 @@ pub async fn get_flow_spec(
     Path(flow_name): Path<String>,
     State(lib_context): State<Arc<LibContext>>,
 ) -> Result, ApiError> {
-    let fl = &lib_context.with_flow_context(&flow_name, |ctx| ctx.flow.clone())?;
-    Ok(Json(fl.flow_instance.clone()))
+    let flow_ctx = lib_context.get_flow_context(&flow_name)?;
+    Ok(Json(flow_ctx.flow.flow_instance.clone()))
 }
 
 pub async fn get_flow_schema(
     Path(flow_name): Path<String>,
     State(lib_context): State<Arc<LibContext>>,
 ) -> Result, ApiError> {
-    let fl = &lib_context.with_flow_context(&flow_name, |ctx| ctx.flow.clone())?;
-    Ok(Json(fl.data_schema.clone()))
+    let flow_ctx = lib_context.get_flow_context(&flow_name)?;
+    Ok(Json(flow_ctx.flow.data_schema.clone()))
 }
 
 #[derive(Deserialize)]
@@ -53,8 +53,8 @@ pub async fn get_keys(
     Query(query): Query,
     State(lib_context): State<Arc<LibContext>>,
 ) -> Result, ApiError> {
-    let fl = &lib_context.with_flow_context(&flow_name, |ctx| ctx.flow.clone())?;
-    let schema = &fl.data_schema;
+    let flow_ctx = lib_context.get_flow_context(&flow_name)?;
+    let schema = &flow_ctx.flow.data_schema;
 
     let field_idx = schema
         .fields
@@ -77,7 +77,7 @@
         )
     })?;
 
-    let execution_plan = fl.get_execution_plan().await?;
+    let execution_plan = flow_ctx.flow.get_execution_plan().await?;
     let source_op = execution_plan
         .source_ops
         .iter()
@@ -119,10 +119,11 @@ pub async fn evaluate_data(
     Query(query): Query,
     State(lib_context): State<Arc<LibContext>>,
 ) -> Result, ApiError> {
-    let fl = &lib_context.with_flow_context(&flow_name, |ctx| ctx.flow.clone())?;
-    let schema = &fl.data_schema;
+    let flow_ctx = lib_context.get_flow_context(&flow_name)?;
+    let schema = &flow_ctx.flow.data_schema;
 
-    let source_op_idx = fl
+    let source_op_idx = flow_ctx
+        .flow
         .flow_instance
         .source_ops
         .iter()
@@ -133,7 +134,7 @@
             StatusCode::BAD_REQUEST,
         )
     })?;
-    let plan = fl.get_execution_plan().await?;
+    let plan = flow_ctx.flow.get_execution_plan().await?;
     let source_op = &plan.source_ops[source_op_idx];
     let field_schema = &schema.fields[source_op.output.field_idx as usize];
     let collection_schema = match &field_schema.value_type.typ {
@@ -169,7 +170,7 @@ pub async fn update(
     Path(flow_name): Path<String>,
     State(lib_context): State<Arc<LibContext>>,
 ) -> Result, ApiError> {
-    let fl = &lib_context.with_flow_context(&flow_name, |ctx| ctx.flow.clone())?;
-    let update_info = source_indexer::update(&fl, &lib_context.pool).await?;
+    let flow_ctx = lib_context.get_flow_context(&flow_name)?;
+    let update_info = source_indexer::update(&flow_ctx, &lib_context.pool).await?;
     Ok(Json(update_info))
 }

diff --git a/src/service/search.rs b/src/service/search.rs
index ea45a209..449b9947 100644
--- a/src/service/search.rs
+++ b/src/service/search.rs
@@ -27,35 +27,26 @@ pub async fn search(
     Query(query): Query,
     State(lib_context): State<Arc<LibContext>>,
 ) -> Result, ApiError> {
-    let query_handler = lib_context.with_flow_context(&flow_name, |flow_ctx| {
-        Ok(match &query.handler {
-            Some(handler) => flow_ctx
-                .query_handlers
-                .get(handler)
-                .ok_or_else(|| {
-                    ApiError::new(
-                        &format!("Query handler not found: {handler}"),
-                        StatusCode::NOT_FOUND,
-                    )
-                })?
-                .clone(),
-            None => {
-                if flow_ctx.query_handlers.is_empty() {
-                    return Err(ApiError::new(
-                        &format!("No query handler found for flow: {flow_name}"),
-                        StatusCode::NOT_FOUND,
-                    ));
-                } else if flow_ctx.query_handlers.len() == 1 {
-                    flow_ctx.query_handlers.values().next().unwrap().clone()
-                } else {
-                    return Err(ApiError::new(
-                        "Found multiple query handlers for flow {}",
-                        StatusCode::BAD_REQUEST,
-                    ));
-                }
+    let flow_ctx = lib_context.get_flow_context(&flow_name)?;
+    let query_handler = match &query.handler {
+        Some(handler) => flow_ctx.get_query_handler(handler)?,
+        None => {
+            let query_handlers = flow_ctx.query_handlers.lock().unwrap();
+            if query_handlers.is_empty() {
+                return Err(ApiError::new(
+                    &format!("No query handler found for flow: {flow_name}"),
+                    StatusCode::NOT_FOUND,
+                ));
+            } else if query_handlers.len() == 1 {
+                query_handlers.values().next().unwrap().clone()
+            } else {
+                return Err(ApiError::new(
+                    "Found multiple query handlers for flow {}",
+                    StatusCode::BAD_REQUEST,
+                ));
             }
-        })
-    })??;
+        }
+    };
     let (results, info) = query_handler
         .search(query.query, query.limit, query.field, query.metric)
         .await?;

diff --git a/src/setup/driver.rs b/src/setup/driver.rs
index f46bec03..b002e2b7 100644
--- a/src/setup/driver.rs
+++ b/src/setup/driver.rs
@@ -1,3 +1,8 @@
+use crate::prelude::*;
+
+use indexmap::IndexMap;
+use serde::de::DeserializeOwned;
+use sqlx::PgPool;
 use std::{
     borrow::Cow,
     collections::{BTreeMap, BTreeSet, HashMap, HashSet},
@@ -5,11 +10,6 @@
     str::FromStr,
 };
 
-use anyhow::{bail, Result};
-use indexmap::IndexMap;
-use serde::{de::DeserializeOwned, Deserialize};
-use sqlx::PgPool;
-
 use super::{
     db_metadata, CombinedState, DesiredMode, ExistingMode, FlowSetupState, FlowSetupStatusCheck,
     ObjectSetupStatusCheck, ObjectStatus, ResourceIdentifier, ResourceSetupStatusCheck,
@@ -292,7 +292,7 @@ pub struct CheckSetupStatusOptions {
 }
 
 pub fn check_setup_status(
-    flows: &BTreeMap<String, FlowContext>,
+    flows: &BTreeMap<String, Arc<FlowContext>>,
     all_setup_state: &AllSetupState<ExistingMode>,
     options: CheckSetupStatusOptions,
 ) -> Result {
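
Illustrative sketch (not part of the diff to apply): the new `FlowContext::get_source_indexing_context` above lazily builds one `SourceIndexingContext` per source op through a `Vec` of `async_lock::OnceCell` slots, so each context is created at most once even when several callers race. The standalone example below mirrors that pattern only; the crate choices (`async_lock`, `futures`, `anyhow`) and the names `Flow`, `SourceCtx`, and `get_source_context` are assumptions for illustration, not APIs from this repository.

    // Minimal sketch of the per-slot lazy async init pattern (assumed crates:
    // async_lock, futures, anyhow; names below are hypothetical).
    use std::sync::Arc;

    use anyhow::Result;
    use async_lock::OnceCell;

    // Stand-in for SourceIndexingContext.
    struct SourceCtx {
        source_idx: usize,
    }

    struct Flow {
        // One cell per source op, so each context is built at most once, on first use.
        source_contexts: Vec<OnceCell<Arc<SourceCtx>>>,
    }

    impl Flow {
        fn new(num_sources: usize) -> Self {
            let mut source_contexts = Vec::new();
            source_contexts.resize_with(num_sources, OnceCell::new);
            Self { source_contexts }
        }

        async fn get_source_context(&self, source_idx: usize) -> Result<&Arc<SourceCtx>> {
            self.source_contexts[source_idx]
                .get_or_try_init(|| async move {
                    // Expensive async setup would go here (e.g. loading tracking state from the DB).
                    Ok(Arc::new(SourceCtx { source_idx }))
                })
                .await
        }
    }

    fn main() -> Result<()> {
        futures::executor::block_on(async {
            let flow = Flow::new(2);
            // Both calls return the same Arc; the init future runs only once per slot.
            let a = flow.get_source_context(0).await?.clone();
            let b = flow.get_source_context(0).await?.clone();
            assert!(Arc::ptr_eq(&a, &b));
            Ok(())
        })
    }

As in the patch, keeping one cell per slot (rather than a single map behind a lock) lets different sources initialize independently while repeated lookups stay cheap.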