From 17df8c4a50289fe4e6dd8d7572aaf49d6e979f12 Mon Sep 17 00:00:00 2001 From: Jiangzhou He Date: Fri, 4 Jul 2025 23:56:34 -0700 Subject: [PATCH] feat(flow-control): basic flow control for source #rows --- python/cocoindex/flow.py | 11 ++++++++++- src/base/spec.rs | 8 ++++++++ src/builder/analyzer.rs | 3 +++ src/builder/flow_builder.rs | 6 +++++- src/builder/plan.rs | 1 + src/execution/source_indexer.rs | 1 + src/utils/concur_control.rs | 30 ++++++++++++++++++++++++++++++ src/utils/mod.rs | 3 +++ 8 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 src/utils/concur_control.rs diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index 605e771d..c12d8872 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -416,6 +416,11 @@ class _SourceRefreshOptions: refresh_interval: datetime.timedelta | None = None +@dataclass +class _ExecutionOptions: + max_inflight_count: int | None = None + + class FlowBuilder: """ A flow builder is used to build a flow. @@ -439,6 +444,7 @@ def add_source( *, name: str | None = None, refresh_interval: datetime.timedelta | None = None, + max_inflight_count: int | None = None, ) -> DataSlice[T]: """ Import a source to the flow. @@ -454,9 +460,12 @@ def add_source( self._state.field_name_builder.build_name( name, prefix=_to_snake_case(_spec_kind(spec)) + "_" ), - dump_engine_object( + refresh_options=dump_engine_object( _SourceRefreshOptions(refresh_interval=refresh_interval) ), + execution_options=dump_engine_object( + _ExecutionOptions(max_inflight_count=max_inflight_count) + ), ), name, ) diff --git a/src/base/spec.rs b/src/base/spec.rs index 86613779..9ec43346 100644 --- a/src/base/spec.rs +++ b/src/base/spec.rs @@ -253,6 +253,11 @@ impl SpecFormatter for OpSpec { } } +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ExecutionOptions { + pub max_inflight_count: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct SourceRefreshOptions { pub refresh_interval: Option, @@ -274,6 +279,9 @@ pub struct ImportOpSpec { #[serde(default)] pub refresh_options: SourceRefreshOptions, + + #[serde(default)] + pub execution_options: ExecutionOptions, } impl SpecFormatter for ImportOpSpec { diff --git a/src/builder/analyzer.rs b/src/builder/analyzer.rs index 2d888673..e7d2cf09 100644 --- a/src/builder/analyzer.rs +++ b/src/builder/analyzer.rs @@ -697,6 +697,9 @@ impl AnalyzerContext { primary_key_type, name: op_name, refresh_options: import_op.spec.refresh_options, + concurrency_controller: utils::ConcurrencyController::new( + import_op.spec.execution_options.max_inflight_count, + ), }) }; Ok(result_fut) diff --git a/src/builder/flow_builder.rs b/src/builder/flow_builder.rs index 333e0a35..9ff13ce6 100644 --- a/src/builder/flow_builder.rs +++ b/src/builder/flow_builder.rs @@ -288,7 +288,7 @@ impl FlowBuilder { OpScopeRef(self.root_op_scope.clone()) } - #[pyo3(signature = (kind, op_spec, target_scope, name, refresh_options=None))] + #[pyo3(signature = (kind, op_spec, target_scope, name, refresh_options=None, execution_options=None))] pub fn add_source( &mut self, py: Python<'_>, @@ -297,6 +297,7 @@ impl FlowBuilder { target_scope: Option, name: String, refresh_options: Option>, + execution_options: Option>, ) -> PyResult { if let Some(target_scope) = target_scope { if *target_scope != self.root_op_scope { @@ -313,6 +314,9 @@ impl FlowBuilder { spec: op_spec.into_inner(), }, refresh_options: refresh_options.map(|o| o.into_inner()).unwrap_or_default(), + execution_options: execution_options + .map(|o| o.into_inner()) + .unwrap_or_default(), }, }; let analyzer_ctx = AnalyzerContext { diff --git a/src/builder/plan.rs b/src/builder/plan.rs index ab68bbf8..9e487f82 100644 --- a/src/builder/plan.rs +++ b/src/builder/plan.rs @@ -56,6 +56,7 @@ pub struct AnalyzedImportOp { pub output: AnalyzedOpOutput, pub primary_key_type: schema::ValueType, pub refresh_options: spec::SourceRefreshOptions, + pub concurrency_controller: utils::ConcurrencyController, } pub struct AnalyzedFunctionExecInfo { diff --git a/src/execution/source_indexer.rs b/src/execution/source_indexer.rs index 5982425b..c8a6f688 100644 --- a/src/execution/source_indexer.rs +++ b/src/execution/source_indexer.rs @@ -282,6 +282,7 @@ impl SourceIndexingContext { state.scan_generation }; while let Some(row) = rows_stream.next().await { + let _ = import_op.concurrency_controller.acquire().await?; for row in row? { self.process_source_key_if_newer( row.key, diff --git a/src/utils/concur_control.rs b/src/utils/concur_control.rs new file mode 100644 index 00000000..eec06a7f --- /dev/null +++ b/src/utils/concur_control.rs @@ -0,0 +1,30 @@ +use crate::prelude::*; + +use tokio::sync::{Semaphore, SemaphorePermit}; + +pub struct ConcurrencyController { + inflight_count_sem: Option, +} + +pub struct ConcurrencyControllerPermit<'a> { + _inflight_count_permit: Option>, +} + +impl ConcurrencyController { + pub fn new(max_inflight_count: Option) -> Self { + Self { + inflight_count_sem: max_inflight_count.map(|max| Semaphore::new(max as usize)), + } + } + + pub async fn acquire<'a>(&'a self) -> Result> { + let inflight_count_permit = if let Some(sem) = &self.inflight_count_sem { + Some(sem.acquire().await?) + } else { + None + }; + Ok(ConcurrencyControllerPermit { + _inflight_count_permit: inflight_count_permit, + }) + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 8923dedf..81c5e38e 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -3,3 +3,6 @@ pub mod fingerprint; pub mod immutable; pub mod retryable; pub mod yaml_ser; + +mod concur_control; +pub use concur_control::ConcurrencyController;