From d0bc28f78c4f867c9e391f527602333f6e1e5f16 Mon Sep 17 00:00:00 2001
From: LJ
Date: Thu, 13 Mar 2025 16:03:21 -0700
Subject: [PATCH] Update `SplitRecursively` to take language and chunk sizes dynamically.

---
 examples/code_embedding/code_embedding.py |   4 +-
 examples/pdf_embedding/pdf_embedding.py   |   4 +-
 examples/text_embedding/text_embedding.py |   8 +-
 python/cocoindex/convert.py               |  13 +++
 python/cocoindex/flow.py                  |  22 +++-
 python/cocoindex/functions.py             |   3 -
 python/cocoindex/op.py                    |  11 +-
 src/ops/functions/split_recursively.rs    | 136 +++++++++++++---------
 8 files changed, 115 insertions(+), 86 deletions(-)
 create mode 100644 python/cocoindex/convert.py

diff --git a/examples/code_embedding/code_embedding.py b/examples/code_embedding/code_embedding.py
index 9b07d216..80b4b288 100644
--- a/examples/code_embedding/code_embedding.py
+++ b/examples/code_embedding/code_embedding.py
@@ -21,8 +21,8 @@ def code_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
 
     with data_scope["files"].row() as file:
         file["chunks"] = file["content"].transform(
-            cocoindex.functions.SplitRecursively(
-                language="javascript", chunk_size=300, chunk_overlap=100))
+            cocoindex.functions.SplitRecursively(),
+            language="javascript", chunk_size=300, chunk_overlap=100)
         with file["chunks"].row() as chunk:
             chunk["embedding"] = chunk["text"].call(code_to_embedding)
             code_embeddings.collect(filename=file["filename"], location=chunk["location"],
diff --git a/examples/pdf_embedding/pdf_embedding.py b/examples/pdf_embedding/pdf_embedding.py
index 8d6d5c63..ae0833ae 100644
--- a/examples/pdf_embedding/pdf_embedding.py
+++ b/examples/pdf_embedding/pdf_embedding.py
@@ -50,8 +50,8 @@ def pdf_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoinde
     with data_scope["documents"].row() as doc:
         doc["markdown"] = doc["content"].transform(PdfToMarkdown())
         doc["chunks"] = doc["markdown"].transform(
-            cocoindex.functions.SplitRecursively(
-                language="markdown", chunk_size=300, chunk_overlap=100))
+            cocoindex.functions.SplitRecursively(),
+            language="markdown", chunk_size=300, chunk_overlap=100)
 
         with doc["chunks"].row() as chunk:
             chunk["embedding"] = chunk["text"].call(text_to_embedding)
diff --git a/examples/text_embedding/text_embedding.py b/examples/text_embedding/text_embedding.py
index 510340f7..7a05bbcf 100644
--- a/examples/text_embedding/text_embedding.py
+++ b/examples/text_embedding/text_embedding.py
@@ -23,12 +23,8 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
 
     with data_scope["documents"].row() as doc:
         doc["chunks"] = doc["content"].transform(
-            cocoindex.functions.SplitRecursively(
-                language="markdown", chunk_size=300, chunk_overlap=100))
-
-        doc["chunks"] = flow_builder.call(
-            cocoindex.functions.SplitRecursively(),
-            doc["content"], language="markdown", chunk_size=300, chunk_overlap=100);
+            cocoindex.functions.SplitRecursively(),
+            language="markdown", chunk_size=300, chunk_overlap=100)
 
         with doc["chunks"].row() as chunk:
             chunk["embedding"] = text_to_embedding(chunk["text"])
diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py
new file mode 100644
index 00000000..f5701c10
--- /dev/null
+++ b/python/cocoindex/convert.py
@@ -0,0 +1,13 @@
+"""
+Utilities to convert between Python and engine values.
+""" +import dataclasses +from typing import Any + +def to_engine_value(value: Any) -> Any: + """Convert a Python value to an engine value.""" + if dataclasses.is_dataclass(value): + return [to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)] + if isinstance(value, (list, tuple)): + return [to_engine_value(v) for v in value] + return value diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index 2e3f0f5f..163605eb 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -162,11 +162,15 @@ def for_each(self, f: Callable[[DataScope], None]) -> None: with self.row() as scope: f(scope) - def transform(self, fn_spec: op.FunctionSpec, /, name: str | None = None) -> DataSlice: + def transform(self, fn_spec: op.FunctionSpec, *args, **kwargs) -> DataSlice: """ Apply a function to the data slice. """ - args = [(self._state.engine_data_slice, None)] + transform_args = [(self._state.engine_data_slice, None)] + transform_args += [(self._state.flow_builder_state.get_data_slice(v), None) for v in args] + transform_args += [(self._state.flow_builder_state.get_data_slice(v), k) + for (k, v) in kwargs.items()] + flow_builder_state = self._state.flow_builder_state return _create_data_slice( flow_builder_state, @@ -174,12 +178,11 @@ def transform(self, fn_spec: op.FunctionSpec, /, name: str | None = None) -> Dat flow_builder_state.engine_flow_builder.transform( _spec_kind(fn_spec), _spec_value_dump(fn_spec), - args, + transform_args, target_scope, flow_builder_state.field_name_builder.build_name( name, prefix=_to_snake_case(_spec_kind(fn_spec))+'_'), - ), - name) + )) def call(self, func: Callable[[DataSlice], T]) -> T: """ @@ -282,6 +285,14 @@ def __init__(self, /, name: str | None = None): self.engine_flow_builder = _engine.FlowBuilder(flow_name) self.field_name_builder = _NameBuilder() + def get_data_slice(self, v: Any) -> _engine.DataSlice: + """ + Return a data slice that represents the given value. + """ + if isinstance(v, DataSlice): + return v._state.engine_data_slice + return self.engine_flow_builder.constant(encode_enriched_type(type(v)), v) + class FlowBuilder: """ A flow builder is used to build a flow. @@ -313,7 +324,6 @@ def add_source(self, spec: op.SourceSpec, /, name: str | None = None) -> DataSli name ) - class Flow: """ A flow describes an indexing pipeline. diff --git a/python/cocoindex/functions.py b/python/cocoindex/functions.py index b0753893..fa21b91a 100644 --- a/python/cocoindex/functions.py +++ b/python/cocoindex/functions.py @@ -7,9 +7,6 @@ class SplitRecursively(op.FunctionSpec): """Split a document (in string) recursively.""" - chunk_size: int - chunk_overlap: int - language: str | None = None class ExtractByLlm(op.FunctionSpec): """Extract information from a text using a LLM.""" diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 1f625255..34892197 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -9,6 +9,7 @@ from threading import Lock from .typing import encode_enriched_type, analyze_type_info, COLLECTION_TYPES +from .convert import to_engine_value from . 
@@ -59,14 +60,6 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
         result_type = executor.analyze(*args, **kwargs)
         return (encode_enriched_type(result_type), executor)
 
-def _to_engine_value(value: Any) -> Any:
-    """Convert a Python value to an engine value."""
-    if dataclasses.is_dataclass(value):
-        return [_to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
-    if isinstance(value, (list, tuple)):
-        return [_to_engine_value(v) for v in value]
-    return value
-
 def _make_engine_struct_value_converter(
     field_path: list[str],
     src_fields: list[dict[str, Any]],
@@ -251,7 +244,7 @@ def __call__(self, *args, **kwargs):
                     output = super().__call__(*converted_args, **converted_kwargs)
             else:
                 output = super().__call__(*converted_args, **converted_kwargs)
-            return _to_engine_value(output)
+            return to_engine_value(output)
 
         _WrappedClass.__name__ = cls.__name__
diff --git a/src/ops/functions/split_recursively.rs b/src/ops/functions/split_recursively.rs
index 798f876f..35a302bc 100644
--- a/src/ops/functions/split_recursively.rs
+++ b/src/ops/functions/split_recursively.rs
@@ -5,19 +5,13 @@ use std::{collections::HashMap, sync::Arc};
 use crate::base::field_attrs;
 use crate::{fields_value, ops::sdk::*};
 
-#[derive(Debug, Deserialize)]
-pub struct Spec {
-    #[serde(default)]
-    language: Option<String>,
-
-    chunk_size: usize,
-
-    #[serde(default)]
-    chunk_overlap: usize,
-}
+type Spec = EmptySpec;
 
 pub struct Args {
     text: ResolvedOpArg,
+    chunk_size: ResolvedOpArg,
+    chunk_overlap: Option<ResolvedOpArg>,
+    language: Option<ResolvedOpArg>,
 }
 
 static DEFAULT_SEPARATORS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
@@ -97,36 +91,13 @@ static SEPARATORS_BY_LANG: LazyLock<HashMap<&'static str, Vec<Regex>>> = LazyLoc
         .collect()
 });
 
-struct Executor {
-    spec: Spec,
-    args: Args,
+struct SplitTask {
     separators: &'static [Regex],
+    chunk_size: usize,
+    chunk_overlap: usize,
 }
 
-impl Executor {
-    fn new(spec: Spec, args: Args) -> Result<Self> {
-        let separators = spec
-            .language
-            .as_ref()
-            .and_then(|lang| {
-                SEPARATORS_BY_LANG
-                    .get(lang.to_lowercase().as_str())
-                    .map(|v| v.as_slice())
-            })
-            .unwrap_or(DEFAULT_SEPARATORS.as_slice());
-        Ok(Self {
-            spec,
-            args,
-            separators,
-        })
-    }
-
-    fn add_output<'s>(pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) {
-        if !text.trim().is_empty() {
-            output.push((RangeValue::new(pos, pos + text.len()), text));
-        }
-    }
-
+impl SplitTask {
     fn split_substring<'s>(
         &self,
         s: &'s str,
@@ -135,7 +106,7 @@ impl Executor {
         output: &mut Vec<(RangeValue, &'s str)>,
     ) {
         if next_sep_id >= self.separators.len() {
-            Self::add_output(base_pos, s, output);
+            self.add_output(base_pos, s, output);
             return;
         }
 
@@ -147,8 +118,8 @@ impl Executor {
             let mut start_pos = chunks[0].start;
             for i in 1..chunks.len() - 1 {
                 let chunk = &chunks[i];
-                if chunk.end - start_pos > self.spec.chunk_size {
-                    Self::add_output(base_pos + start_pos, &s[start_pos..chunk.end], output);
+                if chunk.end - start_pos > self.chunk_size {
+                    self.add_output(base_pos + start_pos, &s[start_pos..chunk.end], output);
 
                     // Find the new start position, allowing overlap within the threshold.
let mut new_start_idx = i + 1; @@ -156,8 +127,8 @@ impl Executor { while new_start_idx > 0 { let prev_pos = chunks[new_start_idx - 1].start; if prev_pos <= start_pos - || chunk.end - prev_pos > self.spec.chunk_overlap - || next_chunk.end - prev_pos > self.spec.chunk_size + || chunk.end - prev_pos > self.chunk_overlap + || next_chunk.end - prev_pos > self.chunk_size { break; } @@ -168,32 +139,49 @@ impl Executor { } let last_chunk = &chunks[chunks.len() - 1]; - Self::add_output(base_pos + start_pos, &s[start_pos..last_chunk.end], output); + self.add_output(base_pos + start_pos, &s[start_pos..last_chunk.end], output); }; let mut small_chunks = Vec::new(); - let mut process_chunk = |start: usize, end: usize| { - let chunk = &s[start..end]; - if chunk.len() <= self.spec.chunk_size { - small_chunks.push(RangeValue::new(start, start + chunk.len())); - } else { - flush_small_chunks(&small_chunks, output); - small_chunks.clear(); - self.split_substring(chunk, base_pos + start, next_sep_id + 1, output); - } - }; + let mut process_chunk = + |start: usize, end: usize, output: &mut Vec<(RangeValue, &'s str)>| { + let chunk = &s[start..end]; + if chunk.len() <= self.chunk_size { + small_chunks.push(RangeValue::new(start, start + chunk.len())); + } else { + flush_small_chunks(&small_chunks, output); + small_chunks.clear(); + self.split_substring(chunk, base_pos + start, next_sep_id + 1, output); + } + }; let mut next_start_pos = 0; for cap in self.separators[next_sep_id].find_iter(s) { - process_chunk(next_start_pos, cap.start()); + process_chunk(next_start_pos, cap.start(), output); next_start_pos = cap.end(); } if next_start_pos < s.len() { - process_chunk(next_start_pos, s.len()); + process_chunk(next_start_pos, s.len(), output); } flush_small_chunks(&small_chunks, output); } + + fn add_output<'s>(&self, pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) { + if !text.trim().is_empty() { + output.push((RangeValue::new(pos, pos + text.len()), text)); + } + } +} + +struct Executor { + args: Args, +} + +impl Executor { + fn new(args: Args) -> Result { + Ok(Self { args }) + } } fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator) { @@ -229,9 +217,32 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator) -> Result { + let task = SplitTask { + separators: self + .args + .language + .value(&input)? + .map(|v| v.as_str()) + .transpose()? + .and_then(|lang| { + SEPARATORS_BY_LANG + .get(lang.to_lowercase().as_str()) + .map(|v| v.as_slice()) + }) + .unwrap_or(DEFAULT_SEPARATORS.as_slice()), + chunk_size: self.args.chunk_size.value(&input)?.as_int64()? as usize, + chunk_overlap: self + .args + .chunk_overlap + .value(&input)? + .map(|v| v.as_int64()) + .transpose()? + .unwrap_or(0) as usize, + }; + let text = self.args.text.value(&input)?.as_str()?; let mut output = Vec::new(); - self.split_substring(text, 0, 0, &mut output); + task.split_substring(text, 0, 0, &mut output); translate_bytes_to_chars( text, @@ -271,6 +282,15 @@ impl SimpleFunctionFactoryBase for Factory { text: args_resolver .next_arg("text")? .expect_type(&ValueType::Basic(BasicValueType::Str))?, + chunk_size: args_resolver + .next_arg("chunk_size")? + .expect_type(&ValueType::Basic(BasicValueType::Int64))?, + chunk_overlap: args_resolver + .next_optional_arg("chunk_overlap")? + .expect_type(&ValueType::Basic(BasicValueType::Int64))?, + language: args_resolver + .next_optional_arg("language")? 
+                .expect_type(&ValueType::Basic(BasicValueType::Str))?,
         };
         let output_schema = make_output_type(CollectionSchema::new(
             CollectionKind::Table,
@@ -288,10 +308,10 @@ impl SimpleFunctionFactoryBase for Factory {
 
     async fn build_executor(
         self: Arc<Self>,
-        spec: Spec,
+        _spec: Spec,
         args: Args,
         _context: Arc<FlowInstanceContext>,
     ) -> Result<Box<dyn SimpleFunctionExecutor>> {
-        Ok(Box::new(Executor::new(spec, args)?))
+        Ok(Box::new(Executor::new(args)?))
     }
 }
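For reference, the call style after this change: the chunking parameters move off the `SplitRecursively` spec and are passed per `transform()` call, as in the updated examples above. A minimal sketch, assuming a flow with a `content` field like the text_embedding example:

    doc["chunks"] = doc["content"].transform(
        cocoindex.functions.SplitRecursively(),
        language="markdown", chunk_size=300, chunk_overlap=100)

`chunk_size` is required; `language` and `chunk_overlap` are optional, and `chunk_overlap` falls back to 0 when omitted.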