diff --git a/docs/docs/custom_ops/custom_functions.mdx b/docs/docs/custom_ops/custom_functions.mdx
index e05abf67..9b63379c 100644
--- a/docs/docs/custom_ops/custom_functions.mdx
+++ b/docs/docs/custom_ops/custom_functions.mdx
@@ -145,6 +145,8 @@ Custom functions take the following additional parameters:
 * `batching: bool`: Whether the executor will consume requests in batch.
   See the [Batching](#batching) section below for details.
 
+* `max_batch_size: int | None`: The maximum batch size for the executor. Only takes effect when `batching` is `True`.
+
 * `behavior_version: int`: The version of the behavior of the function.
   When the version is changed, the function will be re-executed even if cache is enabled.
   It's required to be set if `cache` is `True`.
@@ -221,5 +223,25 @@ class ComputeSomethingExecutor:
     ...
 ```
 
+### Controlling Batch Size
+
+You can control the maximum batch size using the `max_batch_size` parameter. This is useful for:
+* Limiting memory usage when processing large batches
+* Reducing latency by flushing batches before they grow too large
+* Working with APIs that have request size limits
+
+```python
+@cocoindex.op.function(batching=True, max_batch_size=32)
+def compute_something(args: list[str]) -> list[str]:
+    ...
+```
+
+With `max_batch_size` set, a pending batch is flushed when either:
+
+1. All ongoing batches have completed, or
+2. The pending batch size reaches `max_batch_size`.
+
+This caps the size of each batch, bounding memory usage and per-request latency, while still allowing efficient batch processing.
+
diff --git a/examples/code_embedding/main.py b/examples/code_embedding/main.py
index 65c3943e..01ab1c53 100644
--- a/examples/code_embedding/main.py
+++ b/examples/code_embedding/main.py
@@ -16,12 +16,12 @@ def code_to_embedding(
     Embed the text using a SentenceTransformer model.
     """
-    # You can also switch to Voyage embedding model:
-    # return text.transform(
-    #     cocoindex.functions.EmbedText(
-    #         api_type=cocoindex.LlmApiType.VOYAGE,
-    #         model="voyage-code-3",
-    #     )
-    # )
+    # You can also switch to a Gemini embedding model:
+    # return text.transform(
+    #     cocoindex.functions.EmbedText(
+    #         api_type=cocoindex.LlmApiType.GEMINI,
+    #         model="text-embedding-004",
+    #     )
+    # )
     return text.transform(
         cocoindex.functions.SentenceTransformerEmbed(
             model="sentence-transformers/all-MiniLM-L6-v2"
diff --git a/python/cocoindex/functions/colpali.py b/python/cocoindex/functions/colpali.py
index 35d04e20..eaa31a11 100644
--- a/python/cocoindex/functions/colpali.py
+++ b/python/cocoindex/functions/colpali.py
@@ -125,6 +125,7 @@ class ColPaliEmbedImage(op.FunctionSpec):
     gpu=True,
     cache=True,
     batching=True,
+    max_batch_size=32,
     behavior_version=1,
 )
 class ColPaliEmbedImageExecutor:
@@ -204,6 +205,7 @@ class ColPaliEmbedQuery(op.FunctionSpec):
     cache=True,
     behavior_version=1,
     batching=True,
+    max_batch_size=32,
 )
 class ColPaliEmbedQueryExecutor:
     """Executor for ColVision query embedding (ColPali, ColQwen2, ColSmol, etc.)."""
diff --git a/python/cocoindex/functions/sbert.py b/python/cocoindex/functions/sbert.py
index b4d8c3b0..94cfbf1e 100644
--- a/python/cocoindex/functions/sbert.py
+++ b/python/cocoindex/functions/sbert.py
@@ -31,6 +31,7 @@ class SentenceTransformerEmbed(op.FunctionSpec):
     gpu=True,
     cache=True,
     batching=True,
+    max_batch_size=512,
     behavior_version=1,
     arg_relationship=(op.ArgRelationship.EMBEDDING_ORIGIN_TEXT, "text"),
 )
diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py
index 781349dd..694b79d1 100644
--- a/python/cocoindex/op.py
+++ b/python/cocoindex/op.py
@@ -151,6 +151,7 @@ class OpArgs:
     - gpu: Whether the executor will be executed on GPU.
- cache: Whether the executor will be cached. - batching: Whether the executor will be batched. + - max_batch_size: The maximum batch size for the executor. Only valid if `batching` is True. - behavior_version: The behavior version of the executor. Cache will be invalidated if it changes. Must be provided if `cache` is True. - arg_relationship: It specifies the relationship between an input argument and the output, @@ -161,6 +162,7 @@ class OpArgs: gpu: bool = False cache: bool = False batching: bool = False + max_batch_size: int | None = None behavior_version: int | None = None arg_relationship: tuple[ArgRelationship, str] | None = None @@ -389,11 +391,17 @@ def enable_cache(self) -> bool: def behavior_version(self) -> int | None: return op_args.behavior_version + def batching_options(self) -> dict[str, Any] | None: + if op_args.batching: + return { + "max_batch_size": op_args.max_batch_size, + } + else: + return None + if category == OpCategory.FUNCTION: _engine.register_function_factory( - op_kind, - _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor), - op_args.batching, + op_kind, _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor) ) else: raise ValueError(f"Unsupported executor type {category}") diff --git a/src/execution/source_indexer.rs b/src/execution/source_indexer.rs index df1f9720..65ff5f04 100644 --- a/src/execution/source_indexer.rs +++ b/src/execution/source_indexer.rs @@ -304,7 +304,10 @@ impl SourceIndexingContext { rows_to_retry, }), setup_execution_ctx, - update_once_batcher: batching::Batcher::new(UpdateOnceRunner), + update_once_batcher: batching::Batcher::new( + UpdateOnceRunner, + batching::BatchingOptions::default(), + ), })) } diff --git a/src/ops/factory_bases.rs b/src/ops/factory_bases.rs index 23739034..9be8ebab 100644 --- a/src/ops/factory_bases.rs +++ b/src/ops/factory_bases.rs @@ -381,6 +381,8 @@ pub trait BatchedFunctionExecutor: Send + Sync + Sized + 'static { fn into_fn_executor(self) -> impl SimpleFunctionExecutor { BatchedFunctionExecutorWrapper::new(self) } + + fn batching_options(&self) -> batching::BatchingOptions; } #[async_trait] @@ -404,10 +406,11 @@ struct BatchedFunctionExecutorWrapper { impl BatchedFunctionExecutorWrapper { fn new(executor: E) -> Self { + let batching_options = executor.batching_options(); Self { enable_cache: executor.enable_cache(), behavior_version: executor.behavior_version(), - batcher: batching::Batcher::new(executor), + batcher: batching::Batcher::new(executor, batching_options), } } } diff --git a/src/ops/functions/embed_text.rs b/src/ops/functions/embed_text.rs index 2efb3fa5..1a1ce735 100644 --- a/src/ops/functions/embed_text.rs +++ b/src/ops/functions/embed_text.rs @@ -36,6 +36,14 @@ impl BatchedFunctionExecutor for Executor { true } + fn batching_options(&self) -> batching::BatchingOptions { + // A safe default for most embeddings providers. + // May tune it for specific providers later. 
+ batching::BatchingOptions { + max_batch_size: Some(64), + } + } + async fn evaluate_batch(&self, args: Vec>) -> Result> { let texts = args .iter() diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs index 885797f9..7bcf3b3a 100644 --- a/src/ops/py_factory.rs +++ b/src/ops/py_factory.rs @@ -121,6 +121,7 @@ struct PyBatchedFunctionExecutor { enable_cache: bool, behavior_version: Option, + batching_options: batching::BatchingOptions, } #[async_trait] @@ -168,11 +169,13 @@ impl BatchedFunctionExecutor for PyBatchedFunctionExecutor { fn behavior_version(&self) -> Option { self.behavior_version } + fn batching_options(&self) -> batching::BatchingOptions { + self.batching_options.clone() + } } pub(crate) struct PyFunctionFactory { pub py_function_factory: Py, - pub batching: bool, } #[async_trait] @@ -237,7 +240,7 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory { .as_ref() .ok_or_else(|| anyhow!("Python execution context is missing"))? .clone(); - let (prepare_fut, enable_cache, behavior_version) = + let (prepare_fut, enable_cache, behavior_version, batching_options) = Python::with_gil(|py| -> anyhow::Result<_> { let prepare_coro = executor .call_method(py, "prepare", (), None) @@ -257,31 +260,45 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory { .call_method(py, "behavior_version", (), None) .to_result_with_py_trace(py)? .extract::>(py)?; - Ok((prepare_fut, enable_cache, behavior_version)) + let batching_options = executor + .call_method(py, "batching_options", (), None) + .to_result_with_py_trace(py)? + .extract::>>( + py, + )? + .into_inner(); + Ok(( + prepare_fut, + enable_cache, + behavior_version, + batching_options, + )) })?; prepare_fut.await?; - let executor: Box = if self.batching { - Box::new( - PyBatchedFunctionExecutor { + let executor: Box = + if let Some(batching_options) = batching_options { + Box::new( + PyBatchedFunctionExecutor { + py_function_executor: executor, + py_exec_ctx, + result_type, + enable_cache, + behavior_version, + batching_options, + } + .into_fn_executor(), + ) + } else { + Box::new(Arc::new(PyFunctionExecutor { py_function_executor: executor, py_exec_ctx, + num_positional_args, + kw_args_names, result_type, enable_cache, behavior_version, - } - .into_fn_executor(), - ) - } else { - Box::new(Arc::new(PyFunctionExecutor { - py_function_executor: executor, - py_exec_ctx, - num_positional_args, - kw_args_names, - result_type, - enable_cache, - behavior_version, - })) - }; + })) + }; Ok(executor) } }; diff --git a/src/py/mod.rs b/src/py/mod.rs index 2a38c9c0..5e95783c 100644 --- a/src/py/mod.rs +++ b/src/py/mod.rs @@ -156,14 +156,9 @@ fn register_source_connector(name: String, py_source_connector: Py) -> Py } #[pyfunction] -fn register_function_factory( - name: String, - py_function_factory: Py, - batching: bool, -) -> PyResult<()> { +fn register_function_factory(name: String, py_function_factory: Py) -> PyResult<()> { let factory = PyFunctionFactory { py_function_factory, - batching, }; register_factory(name, ExecutorFactory::SimpleFunction(Arc::new(factory))).into_py_result() } diff --git a/src/utils/batching.rs b/src/utils/batching.rs index 14a5bfbe..6b9eace4 100644 --- a/src/utils/batching.rs +++ b/src/utils/batching.rs @@ -36,7 +36,10 @@ impl Default for Batch { enum BatcherState { #[default] Idle, - Busy(Option>), + Busy { + pending_batch: Option>, + ongoing_count: usize, + }, } struct BatcherData { @@ -95,6 +98,7 @@ impl BatcherData { pub struct Batcher { data: Arc>, + options: BatchingOptions, } enum 
BatchExecutionAction { @@ -106,13 +110,19 @@ enum BatchExecutionAction { num_cancelled_tx: watch::Sender, }, } + +#[derive(Default, Clone, Serialize, Deserialize)] +pub struct BatchingOptions { + pub max_batch_size: Option, +} impl Batcher { - pub fn new(runner: R) -> Self { + pub fn new(runner: R, options: BatchingOptions) -> Self { Self { data: Arc::new(BatcherData { runner, state: Mutex::new(BatcherState::Idle), }), + options, } } pub async fn run(&self, input: R::Input) -> Result { @@ -120,19 +130,42 @@ impl Batcher { let mut state = self.data.state.lock().unwrap(); match &mut *state { state @ BatcherState::Idle => { - *state = BatcherState::Busy(None); + *state = BatcherState::Busy { + pending_batch: None, + ongoing_count: 1, + }; BatchExecutionAction::Inline { input } } - BatcherState::Busy(batch) => { - let batch = batch.get_or_insert_default(); + BatcherState::Busy { + pending_batch, + ongoing_count, + } => { + let batch = pending_batch.get_or_insert_default(); batch.inputs.push(input); let (output_tx, output_rx) = oneshot::channel(); batch.output_txs.push(output_tx); + let num_cancelled_tx = batch.num_cancelled_tx.clone(); + + // Check if we've reached max_batch_size and need to flush immediately + let should_flush = self + .options + .max_batch_size + .map(|max_size| batch.inputs.len() >= max_size) + .unwrap_or(false); + + if should_flush { + // Take the batch and trigger execution + let batch_to_run = pending_batch.take().unwrap(); + *ongoing_count += 1; + let data = self.data.clone(); + tokio::spawn(async move { data.run_batch(batch_to_run).await }); + } + BatchExecutionAction::Batched { output_rx, - num_cancelled_tx: batch.num_cancelled_tx.clone(), + num_cancelled_tx, } } } @@ -173,13 +206,33 @@ struct BatchKickOffNext<'a, R: Runner + 'static> { impl<'a, R: Runner + 'static> Drop for BatchKickOffNext<'a, R> { fn drop(&mut self) { let mut state = self.batcher_data.state.lock().unwrap(); - let existing_state = std::mem::take(&mut *state); - let BatcherState::Busy(Some(batch)) = existing_state else { - return; - }; - *state = BatcherState::Busy(None); - let data = self.batcher_data.clone(); - tokio::spawn(async move { data.run_batch(batch).await }); + + match &mut *state { + BatcherState::Idle => { + // Nothing to do, already idle + return; + } + BatcherState::Busy { + pending_batch, + ongoing_count, + } => { + // Decrement the ongoing count first + *ongoing_count -= 1; + + if *ongoing_count == 0 { + // All batches done, check if there's a pending batch + if let Some(batch) = pending_batch.take() { + // Kick off the pending batch and set ongoing_count to 1 + *ongoing_count = 1; + let data = self.batcher_data.clone(); + tokio::spawn(async move { data.run_batch(batch).await }); + } else { + // No pending batch, transition to Idle + *state = BatcherState::Idle; + } + } + } + } } } @@ -263,7 +316,7 @@ mod tests { let runner = TestRunner { recorded_calls: recorded_calls.clone(), }; - let batcher = Arc::new(Batcher::new(runner)); + let batcher = Arc::new(Batcher::new(runner, BatchingOptions::default())); let (n1_tx, n1_rx) = oneshot::channel::<()>(); let (n2_tx, n2_rx) = oneshot::channel::<()>(); @@ -319,4 +372,216 @@ mod tests { Ok(()) } + + #[tokio::test(flavor = "current_thread")] + async fn respects_max_batch_size() -> Result<()> { + let recorded_calls = Arc::new(Mutex::new(Vec::>::new())); + let runner = TestRunner { + recorded_calls: recorded_calls.clone(), + }; + let batcher = Arc::new(Batcher::new( + runner, + BatchingOptions { + max_batch_size: Some(2), + }, + )); + + let 
(n1_tx, n1_rx) = oneshot::channel::<()>(); + let (n2_tx, n2_rx) = oneshot::channel::<()>(); + let (n3_tx, n3_rx) = oneshot::channel::<()>(); + let (n4_tx, n4_rx) = oneshot::channel::<()>(); + + // Submit first call; it should execute inline and block on n1 + let b1 = batcher.clone(); + let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await }); + + // Wait until the runner has recorded the first inline call + wait_until_len(&recorded_calls, 1).await; + + // Submit second call; it should be batched + let b2 = batcher.clone(); + let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await }); + + // Submit third call; this should trigger a flush because max_batch_size=2 + // The batch [2, 3] should be executed immediately + let b3 = batcher.clone(); + let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await }); + + // Wait for the second batch to be recorded + wait_until_len(&recorded_calls, 2).await; + + // Verify that the second batch was triggered by max_batch_size + { + let calls = recorded_calls.lock().unwrap(); + assert_eq!(calls.len(), 2, "second batch should have started"); + assert_eq!(calls[1], vec![2, 3], "second batch should contain [2, 3]"); + } + + // Submit fourth call; it should wait because there are still ongoing batches + let b4 = batcher.clone(); + let f4 = tokio::spawn(async move { b4.run((4_i64, n4_rx)).await }); + + // Give it a moment to ensure no new batch starts + sleep(Duration::from_millis(50)).await; + { + let len_now = recorded_calls.lock().unwrap().len(); + assert_eq!( + len_now, 2, + "third batch should not start until all ongoing batches complete" + ); + } + + // Unblock the first inline call + let _ = n1_tx.send(()); + + // Wait for first result + let v1 = f1.await??; + assert_eq!(v1, 2); + + // Batch [2,3] is still running, so batch [4] shouldn't start yet + sleep(Duration::from_millis(50)).await; + { + let len_now = recorded_calls.lock().unwrap().len(); + assert_eq!( + len_now, 2, + "third batch should not start until all ongoing batches complete" + ); + } + + // Unblock batch [2,3] - this should trigger batch [4] to start + let _ = n2_tx.send(()); + let _ = n3_tx.send(()); + + let v2 = f2.await??; + let v3 = f3.await??; + assert_eq!(v2, 4); + assert_eq!(v3, 6); + + // Now batch [4] should start since all previous batches are done + wait_until_len(&recorded_calls, 3).await; + + // Unblock batch [4] + let _ = n4_tx.send(()); + let v4 = f4.await??; + assert_eq!(v4, 8); + + // Validate the call recording: [1], [2, 3] (flushed by max_batch_size), [4] + let calls = recorded_calls.lock().unwrap().clone(); + assert_eq!(calls.len(), 3); + assert_eq!(calls[0], vec![1]); + assert_eq!(calls[1], vec![2, 3]); + assert_eq!(calls[2], vec![4]); + + Ok(()) + } + + #[tokio::test(flavor = "current_thread")] + async fn tracks_multiple_concurrent_batches() -> Result<()> { + let recorded_calls = Arc::new(Mutex::new(Vec::>::new())); + let runner = TestRunner { + recorded_calls: recorded_calls.clone(), + }; + let batcher = Arc::new(Batcher::new( + runner, + BatchingOptions { + max_batch_size: Some(2), + }, + )); + + let (n1_tx, n1_rx) = oneshot::channel::<()>(); + let (n2_tx, n2_rx) = oneshot::channel::<()>(); + let (n3_tx, n3_rx) = oneshot::channel::<()>(); + let (n4_tx, n4_rx) = oneshot::channel::<()>(); + let (n5_tx, n5_rx) = oneshot::channel::<()>(); + let (n6_tx, n6_rx) = oneshot::channel::<()>(); + + // Submit first call - executes inline + let b1 = batcher.clone(); + let f1 = tokio::spawn(async move { b1.run((1_i64, n1_rx)).await }); + 
wait_until_len(&recorded_calls, 1).await; + + // Submit calls 2-3 - should batch and flush at max_batch_size + let b2 = batcher.clone(); + let f2 = tokio::spawn(async move { b2.run((2_i64, n2_rx)).await }); + let b3 = batcher.clone(); + let f3 = tokio::spawn(async move { b3.run((3_i64, n3_rx)).await }); + wait_until_len(&recorded_calls, 2).await; + + // Submit calls 4-5 - should batch and flush at max_batch_size + let b4 = batcher.clone(); + let f4 = tokio::spawn(async move { b4.run((4_i64, n4_rx)).await }); + let b5 = batcher.clone(); + let f5 = tokio::spawn(async move { b5.run((5_i64, n5_rx)).await }); + wait_until_len(&recorded_calls, 3).await; + + // Submit call 6 - should be batched but not flushed yet + let b6 = batcher.clone(); + let f6 = tokio::spawn(async move { b6.run((6_i64, n6_rx)).await }); + + // Give it a moment to ensure no new batch starts + sleep(Duration::from_millis(50)).await; + { + let len_now = recorded_calls.lock().unwrap().len(); + assert_eq!( + len_now, 3, + "fourth batch should not start with ongoing batches" + ); + } + + // Unblock batch [2, 3] - should not cause [6] to execute yet (batch 1 still ongoing) + let _ = n2_tx.send(()); + let _ = n3_tx.send(()); + let v2 = f2.await??; + let v3 = f3.await??; + assert_eq!(v2, 4); + assert_eq!(v3, 6); + + sleep(Duration::from_millis(50)).await; + { + let len_now = recorded_calls.lock().unwrap().len(); + assert_eq!( + len_now, 3, + "batch [6] should still not start (batch 1 and batch [4,5] still ongoing)" + ); + } + + // Unblock batch [4, 5] - should not cause [6] to execute yet (batch 1 still ongoing) + let _ = n4_tx.send(()); + let _ = n5_tx.send(()); + let v4 = f4.await??; + let v5 = f5.await??; + assert_eq!(v4, 8); + assert_eq!(v5, 10); + + sleep(Duration::from_millis(50)).await; + { + let len_now = recorded_calls.lock().unwrap().len(); + assert_eq!( + len_now, 3, + "batch [6] should still not start (batch 1 still ongoing)" + ); + } + + // Unblock batch 1 - NOW batch [6] should start + let _ = n1_tx.send(()); + let v1 = f1.await??; + assert_eq!(v1, 2); + + wait_until_len(&recorded_calls, 4).await; + + // Unblock batch [6] + let _ = n6_tx.send(()); + let v6 = f6.await??; + assert_eq!(v6, 12); + + // Validate the call recording + let calls = recorded_calls.lock().unwrap().clone(); + assert_eq!(calls.len(), 4); + assert_eq!(calls[0], vec![1]); + assert_eq!(calls[1], vec![2, 3]); + assert_eq!(calls[2], vec![4, 5]); + assert_eq!(calls[3], vec![6]); + + Ok(()) + } }