22 changes: 22 additions & 0 deletions docs/docs/custom_ops/custom_functions.mdx
@@ -145,6 +145,8 @@ Custom functions take the following additional parameters:
* `batching: bool`: Whether the executor will consume requests in batches.
  See the [Batching](#batching) section below for details.

* `max_batch_size: int | None`: The maximum batch size for the executor. Only takes effect when `batching` is `True`.

* `behavior_version: int`: The version of the function's behavior.
  When the version changes, the function is re-executed even if caching is enabled.
  It must be set if `cache` is `True`.
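For example, a batched function that also enables caching must declare a `behavior_version` (a minimal sketch mirroring the example below; `compute_something` is a placeholder name):

```python
@cocoindex.op.function(
    batching=True,
    max_batch_size=32,
    cache=True,
    behavior_version=1,  # bump this to invalidate previously cached results
)
def compute_something(args: list[str]) -> list[str]:
    ...
```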
@@ -221,5 +223,25 @@ class ComputeSomethingExecutor:
...
```

### Controlling Batch Size

You can control the maximum batch size using the `max_batch_size` parameter. This is useful for:
* Limiting memory usage when processing large batches
* Reducing latency by flushing batches before they grow too large
* Working with APIs that have request size limits

```python
@cocoindex.op.function(batching=True, max_batch_size=32)
def compute_something(args: list[str]) -> list[str]:
...
```

With `max_batch_size` set, a pending batch is flushed when either:

1. No batch is currently being processed, or
2. The pending batch size reaches `max_batch_size`.

This ensures that requests don't wait indefinitely for a batch to fill up, while still allowing efficient batch processing.
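To make the flush policy concrete, here is a minimal illustrative sketch of the decision logic (not the engine's actual implementation; `pending` and `in_flight` are hypothetical names):

```python
def should_flush(pending: int, in_flight: int, max_batch_size: int | None) -> bool:
    """Decide whether the pending batch should be dispatched now."""
    # Flush immediately when no batch is being processed, so requests
    # never sit idle waiting for the batch to fill.
    if in_flight == 0:
        return True
    # Otherwise, flush once the pending batch reaches the size cap.
    return max_batch_size is not None and pending >= max_batch_size
```

For instance, with `max_batch_size=32`, requests that arrive while a batch is in flight accumulate until that batch completes or 32 requests are pending, whichever comes first.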

</TabItem>
</Tabs>
12 changes: 6 additions & 6 deletions examples/code_embedding/main.py
@@ -16,12 +16,12 @@ def code_to_embedding(
Embed the text using a SentenceTransformer model.
"""
# You can also switch to Voyage embedding model:
# return text.transform(
# cocoindex.functions.EmbedText(
# api_type=cocoindex.LlmApiType.VOYAGE,
# model="voyage-code-3",
# )
# )
# return text.transform(
# cocoindex.functions.EmbedText(
# api_type=cocoindex.LlmApiType.GEMINI,
# model="text-embedding-004",
# )
# )
return text.transform(
cocoindex.functions.SentenceTransformerEmbed(
model="sentence-transformers/all-MiniLM-L6-v2"
2 changes: 2 additions & 0 deletions python/cocoindex/functions/colpali.py
@@ -125,6 +125,7 @@ class ColPaliEmbedImage(op.FunctionSpec):
gpu=True,
cache=True,
batching=True,
max_batch_size=32,
behavior_version=1,
)
class ColPaliEmbedImageExecutor:
@@ -204,6 +205,7 @@ class ColPaliEmbedQuery(op.FunctionSpec):
cache=True,
behavior_version=1,
batching=True,
max_batch_size=32,
)
class ColPaliEmbedQueryExecutor:
"""Executor for ColVision query embedding (ColPali, ColQwen2, ColSmol, etc.)."""
1 change: 1 addition & 0 deletions python/cocoindex/functions/sbert.py
@@ -31,6 +31,7 @@ class SentenceTransformerEmbed(op.FunctionSpec):
gpu=True,
cache=True,
batching=True,
max_batch_size=512,
behavior_version=1,
arg_relationship=(op.ArgRelationship.EMBEDDING_ORIGIN_TEXT, "text"),
)
14 changes: 11 additions & 3 deletions python/cocoindex/op.py
@@ -151,6 +151,7 @@ class OpArgs:
- gpu: Whether the executor will be executed on GPU.
- cache: Whether the executor will be cached.
- batching: Whether the executor will be batched.
- max_batch_size: The maximum batch size for the executor. Only valid if `batching` is True.
- behavior_version: The behavior version of the executor. Cache will be invalidated if it
changes. Must be provided if `cache` is True.
- arg_relationship: It specifies the relationship between an input argument and the output,
@@ -161,6 +162,7 @@
gpu: bool = False
cache: bool = False
batching: bool = False
max_batch_size: int | None = None
behavior_version: int | None = None
arg_relationship: tuple[ArgRelationship, str] | None = None

@@ -389,11 +391,17 @@ def enable_cache(self) -> bool:
def behavior_version(self) -> int | None:
return op_args.behavior_version

def batching_options(self) -> dict[str, Any] | None:
if op_args.batching:
return {
"max_batch_size": op_args.max_batch_size,
}
else:
return None

if category == OpCategory.FUNCTION:
_engine.register_function_factory(
op_kind,
_EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor),
op_args.batching,
op_kind, _EngineFunctionExecutorFactory(spec_loader, _WrappedExecutor)
)
else:
raise ValueError(f"Unsupported executor type {category}")
5 changes: 4 additions & 1 deletion src/execution/source_indexer.rs
@@ -304,7 +304,10 @@ impl SourceIndexingContext {
rows_to_retry,
}),
setup_execution_ctx,
update_once_batcher: batching::Batcher::new(UpdateOnceRunner),
update_once_batcher: batching::Batcher::new(
UpdateOnceRunner,
batching::BatchingOptions::default(),
),
}))
}

5 changes: 4 additions & 1 deletion src/ops/factory_bases.rs
@@ -381,6 +381,8 @@ pub trait BatchedFunctionExecutor: Send + Sync + Sized + 'static {
fn into_fn_executor(self) -> impl SimpleFunctionExecutor {
BatchedFunctionExecutorWrapper::new(self)
}

fn batching_options(&self) -> batching::BatchingOptions;
}

#[async_trait]
@@ -404,10 +406,11 @@ struct BatchedFunctionExecutorWrapper<E: BatchedFunctionExecutor> {

impl<E: BatchedFunctionExecutor> BatchedFunctionExecutorWrapper<E> {
fn new(executor: E) -> Self {
let batching_options = executor.batching_options();
Self {
enable_cache: executor.enable_cache(),
behavior_version: executor.behavior_version(),
batcher: batching::Batcher::new(executor),
batcher: batching::Batcher::new(executor, batching_options),
}
}
}
8 changes: 8 additions & 0 deletions src/ops/functions/embed_text.rs
@@ -36,6 +36,14 @@ impl BatchedFunctionExecutor for Executor {
true
}

fn batching_options(&self) -> batching::BatchingOptions {
// A safe default for most embedding providers.
// We may tune it for specific providers later.
batching::BatchingOptions {
max_batch_size: Some(64),
}
}

async fn evaluate_batch(&self, args: Vec<Vec<Value>>) -> Result<Vec<Value>> {
let texts = args
.iter()
57 changes: 37 additions & 20 deletions src/ops/py_factory.rs
@@ -121,6 +121,7 @@ struct PyBatchedFunctionExecutor {

enable_cache: bool,
behavior_version: Option<u32>,
batching_options: batching::BatchingOptions,
}

#[async_trait]
@@ -168,11 +169,13 @@ impl BatchedFunctionExecutor for PyBatchedFunctionExecutor {
fn behavior_version(&self) -> Option<u32> {
self.behavior_version
}
fn batching_options(&self) -> batching::BatchingOptions {
self.batching_options.clone()
}
}

pub(crate) struct PyFunctionFactory {
pub py_function_factory: Py<PyAny>,
pub batching: bool,
}

#[async_trait]
@@ -237,7 +240,7 @@ impl interface::SimpleFunctionFactory for PyFunctionFactory {
.as_ref()
.ok_or_else(|| anyhow!("Python execution context is missing"))?
.clone();
let (prepare_fut, enable_cache, behavior_version) =
let (prepare_fut, enable_cache, behavior_version, batching_options) =
Python::with_gil(|py| -> anyhow::Result<_> {
let prepare_coro = executor
.call_method(py, "prepare", (), None)
@@ -257,31 +260,45 @@
.call_method(py, "behavior_version", (), None)
.to_result_with_py_trace(py)?
.extract::<Option<u32>>(py)?;
Ok((prepare_fut, enable_cache, behavior_version))
let batching_options = executor
.call_method(py, "batching_options", (), None)
.to_result_with_py_trace(py)?
.extract::<crate::py::Pythonized<Option<batching::BatchingOptions>>>(
py,
)?
.into_inner();
Ok((
prepare_fut,
enable_cache,
behavior_version,
batching_options,
))
})?;
prepare_fut.await?;
let executor: Box<dyn interface::SimpleFunctionExecutor> = if self.batching {
Box::new(
PyBatchedFunctionExecutor {
let executor: Box<dyn interface::SimpleFunctionExecutor> =
if let Some(batching_options) = batching_options {
Box::new(
PyBatchedFunctionExecutor {
py_function_executor: executor,
py_exec_ctx,
result_type,
enable_cache,
behavior_version,
batching_options,
}
.into_fn_executor(),
)
} else {
Box::new(Arc::new(PyFunctionExecutor {
py_function_executor: executor,
py_exec_ctx,
num_positional_args,
kw_args_names,
result_type,
enable_cache,
behavior_version,
}
.into_fn_executor(),
)
} else {
Box::new(Arc::new(PyFunctionExecutor {
py_function_executor: executor,
py_exec_ctx,
num_positional_args,
kw_args_names,
result_type,
enable_cache,
behavior_version,
}))
};
}))
};
Ok(executor)
}
};
7 changes: 1 addition & 6 deletions src/py/mod.rs
@@ -156,14 +156,9 @@ fn register_source_connector(name: String, py_source_connector: Py<PyAny>) -> Py
}

#[pyfunction]
fn register_function_factory(
name: String,
py_function_factory: Py<PyAny>,
batching: bool,
) -> PyResult<()> {
fn register_function_factory(name: String, py_function_factory: Py<PyAny>) -> PyResult<()> {
let factory = PyFunctionFactory {
py_function_factory,
batching,
};
register_factory(name, ExecutorFactory::SimpleFunction(Arc::new(factory))).into_py_result()
}