4 changes: 2 additions & 2 deletions examples/code_embedding/code_embedding.py
@@ -21,8 +21,8 @@ def code_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
 
     with data_scope["files"].row() as file:
         file["chunks"] = file["content"].transform(
-            cocoindex.functions.SplitRecursively(
-                language="javascript", chunk_size=300, chunk_overlap=100))
+            cocoindex.functions.SplitRecursively(),
+            language="javascript", chunk_size=300, chunk_overlap=100)
         with file["chunks"].row() as chunk:
             chunk["embedding"] = chunk["text"].call(code_to_embedding)
             code_embeddings.collect(filename=file["filename"], location=chunk["location"],
4 changes: 2 additions & 2 deletions examples/pdf_embedding/pdf_embedding.py
@@ -50,8 +50,8 @@ def pdf_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoinde
     with data_scope["documents"].row() as doc:
         doc["markdown"] = doc["content"].transform(PdfToMarkdown())
         doc["chunks"] = doc["markdown"].transform(
-            cocoindex.functions.SplitRecursively(
-                language="markdown", chunk_size=300, chunk_overlap=100))
+            cocoindex.functions.SplitRecursively(),
+            language="markdown", chunk_size=300, chunk_overlap=100)
 
         with doc["chunks"].row() as chunk:
             chunk["embedding"] = chunk["text"].call(text_to_embedding)
8 changes: 2 additions & 6 deletions examples/text_embedding/text_embedding.py
@@ -23,12 +23,8 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
 
     with data_scope["documents"].row() as doc:
         doc["chunks"] = doc["content"].transform(
-            cocoindex.functions.SplitRecursively(
-                language="markdown", chunk_size=300, chunk_overlap=100))
-
-        doc["chunks"] = flow_builder.call(
-            cocoindex.functions.SplitRecursively(),
-            doc["content"], language="markdown", chunk_size=300, chunk_overlap=100);
+            cocoindex.functions.SplitRecursively(),
+            language="markdown", chunk_size=300, chunk_overlap=100)
 
         with doc["chunks"].row() as chunk:
             chunk["embedding"] = text_to_embedding(chunk["text"])
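Taken together, the three example changes are the same migration: chunking parameters move off the `SplitRecursively` spec and onto the `transform(...)` call. A condensed before/after, reusing the values from the diffs above:

```python
# Before: parameters were fields on the function spec.
doc["chunks"] = doc["content"].transform(
    cocoindex.functions.SplitRecursively(
        language="markdown", chunk_size=300, chunk_overlap=100))

# After: the spec is parameterless; arguments are supplied per call.
doc["chunks"] = doc["content"].transform(
    cocoindex.functions.SplitRecursively(),
    language="markdown", chunk_size=300, chunk_overlap=100)
```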
13 changes: 13 additions & 0 deletions python/cocoindex/convert.py
@@ -0,0 +1,13 @@
+"""
+Utilities to convert between Python and engine values.
+"""
+import dataclasses
+from typing import Any
+
+def to_engine_value(value: Any) -> Any:
+    """Convert a Python value to an engine value."""
+    if dataclasses.is_dataclass(value):
+        return [to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
+    if isinstance(value, (list, tuple)):
+        return [to_engine_value(v) for v in value]
+    return value
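A quick sketch of what `to_engine_value` does: dataclasses flatten into positional lists of their field values, recursing through nested dataclasses, lists, and tuples. The `Chunk`/`Doc` types below are illustrative, not part of the PR:

```python
import dataclasses
from cocoindex.convert import to_engine_value

@dataclasses.dataclass
class Chunk:
    location: str
    text: str

@dataclasses.dataclass
class Doc:
    filename: str
    chunks: list[Chunk]

doc = Doc("a.md", [Chunk("0:5", "hello")])
print(to_engine_value(doc))  # ['a.md', [['0:5', 'hello']]]
```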
22 changes: 16 additions & 6 deletions python/cocoindex/flow.py
@@ -162,24 +162,27 @@ def for_each(self, f: Callable[[DataScope], None]) -> None:
         with self.row() as scope:
             f(scope)
 
-    def transform(self, fn_spec: op.FunctionSpec, /, name: str | None = None) -> DataSlice:
+    def transform(self, fn_spec: op.FunctionSpec, *args, **kwargs) -> DataSlice:
         """
         Apply a function to the data slice.
         """
-        args = [(self._state.engine_data_slice, None)]
+        transform_args = [(self._state.engine_data_slice, None)]
+        transform_args += [(self._state.flow_builder_state.get_data_slice(v), None) for v in args]
+        transform_args += [(self._state.flow_builder_state.get_data_slice(v), k)
+                           for (k, v) in kwargs.items()]
 
         flow_builder_state = self._state.flow_builder_state
         return _create_data_slice(
             flow_builder_state,
             lambda target_scope, name:
             flow_builder_state.engine_flow_builder.transform(
                 _spec_kind(fn_spec),
                 _spec_value_dump(fn_spec),
-                args,
+                transform_args,
                 target_scope,
                 flow_builder_state.field_name_builder.build_name(
                     name, prefix=_to_snake_case(_spec_kind(fn_spec))+'_'),
-            ),
-            name)
+            ))
 
     def call(self, func: Callable[[DataSlice], T]) -> T:
         """
@@ -282,6 +285,14 @@ def __init__(self, /, name: str | None = None):
         self.engine_flow_builder = _engine.FlowBuilder(flow_name)
         self.field_name_builder = _NameBuilder()
 
+    def get_data_slice(self, v: Any) -> _engine.DataSlice:
+        """
+        Return a data slice that represents the given value.
+        """
+        if isinstance(v, DataSlice):
+            return v._state.engine_data_slice
+        return self.engine_flow_builder.constant(encode_enriched_type(type(v)), v)
+
 class FlowBuilder:
     """
     A flow builder is used to build a flow.
Expand Down Expand Up @@ -313,7 +324,6 @@ def add_source(self, spec: op.SourceSpec, /, name: str | None = None) -> DataSli
name
)


class Flow:
"""
A flow describes an indexing pipeline.
Expand Down
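With the reworked `transform`, extra positional and keyword arguments ride along with the slice, and `get_data_slice` decides how each is encoded: an existing `DataSlice` passes through unchanged, while a plain Python value is lifted into an engine constant. A hedged sketch of both forms (the `doc["lang"]` field is hypothetical):

```python
doc["chunks"] = doc["content"].transform(
    cocoindex.functions.SplitRecursively(),
    chunk_size=300,        # plain int -> constant engine data slice
    language=doc["lang"],  # an existing DataSlice passes through as-is
)
```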
3 changes: 0 additions & 3 deletions python/cocoindex/functions.py
@@ -7,9 +7,6 @@
 
 class SplitRecursively(op.FunctionSpec):
     """Split a document (in string) recursively."""
-    chunk_size: int
-    chunk_overlap: int
-    language: str | None = None
 
 class ExtractByLlm(op.FunctionSpec):
     """Extract information from a text using a LLM."""
11 changes: 2 additions & 9 deletions python/cocoindex/op.py
@@ -9,6 +9,7 @@
 from threading import Lock
 
 from .typing import encode_enriched_type, analyze_type_info, COLLECTION_TYPES
+from .convert import to_engine_value
 from . import _engine
 
 
@@ -59,14 +60,6 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
         result_type = executor.analyze(*args, **kwargs)
         return (encode_enriched_type(result_type), executor)
 
-def _to_engine_value(value: Any) -> Any:
-    """Convert a Python value to an engine value."""
-    if dataclasses.is_dataclass(value):
-        return [_to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
-    if isinstance(value, (list, tuple)):
-        return [_to_engine_value(v) for v in value]
-    return value
-
 def _make_engine_struct_value_converter(
     field_path: list[str],
     src_fields: list[dict[str, Any]],
@@ -251,7 +244,7 @@ def __call__(self, *args, **kwargs):
                 output = super().__call__(*converted_args, **converted_kwargs)
             else:
                 output = super().__call__(*converted_args, **converted_kwargs)
-            return _to_engine_value(output)
+            return to_engine_value(output)
 
 _WrappedClass.__name__ = cls.__name__
136 changes: 78 additions & 58 deletions src/ops/functions/split_recursively.rs
@@ -5,19 +5,13 @@ use std::{collections::HashMap, sync::Arc};
 use crate::base::field_attrs;
 use crate::{fields_value, ops::sdk::*};
 
-#[derive(Debug, Deserialize)]
-pub struct Spec {
-    #[serde(default)]
-    language: Option<String>,
-
-    chunk_size: usize,
-
-    #[serde(default)]
-    chunk_overlap: usize,
-}
+type Spec = EmptySpec;
 
 pub struct Args {
     text: ResolvedOpArg,
+    chunk_size: ResolvedOpArg,
+    chunk_overlap: Option<ResolvedOpArg>,
+    language: Option<ResolvedOpArg>,
 }
 
 static DEFAULT_SEPARATORS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
@@ -97,36 +91,13 @@ static SEPARATORS_BY_LANG: LazyLock<HashMap<&'static str, Vec<Regex>>> = LazyLoc
         .collect()
 });
 
-struct Executor {
-    spec: Spec,
-    args: Args,
+struct SplitTask {
     separators: &'static [Regex],
+    chunk_size: usize,
+    chunk_overlap: usize,
 }
 
-impl Executor {
-    fn new(spec: Spec, args: Args) -> Result<Self> {
-        let separators = spec
-            .language
-            .as_ref()
-            .and_then(|lang| {
-                SEPARATORS_BY_LANG
-                    .get(lang.to_lowercase().as_str())
-                    .map(|v| v.as_slice())
-            })
-            .unwrap_or(DEFAULT_SEPARATORS.as_slice());
-        Ok(Self {
-            spec,
-            args,
-            separators,
-        })
-    }
-
-    fn add_output<'s>(pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) {
-        if !text.trim().is_empty() {
-            output.push((RangeValue::new(pos, pos + text.len()), text));
-        }
-    }
-
+impl SplitTask {
     fn split_substring<'s>(
         &self,
         s: &'s str,
@@ -135,7 +106,7 @@ impl Executor {
         output: &mut Vec<(RangeValue, &'s str)>,
     ) {
         if next_sep_id >= self.separators.len() {
-            Self::add_output(base_pos, s, output);
+            self.add_output(base_pos, s, output);
             return;
         }
 
@@ -147,17 +118,17 @@ impl Executor {
             let mut start_pos = chunks[0].start;
             for i in 1..chunks.len() - 1 {
                 let chunk = &chunks[i];
-                if chunk.end - start_pos > self.spec.chunk_size {
-                    Self::add_output(base_pos + start_pos, &s[start_pos..chunk.end], output);
+                if chunk.end - start_pos > self.chunk_size {
+                    self.add_output(base_pos + start_pos, &s[start_pos..chunk.end], output);
 
                     // Find the new start position, allowing overlap within the threshold.
                     let mut new_start_idx = i + 1;
                     let next_chunk = &chunks[i + 1];
                     while new_start_idx > 0 {
                         let prev_pos = chunks[new_start_idx - 1].start;
                         if prev_pos <= start_pos
-                            || chunk.end - prev_pos > self.spec.chunk_overlap
-                            || next_chunk.end - prev_pos > self.spec.chunk_size
+                            || chunk.end - prev_pos > self.chunk_overlap
+                            || next_chunk.end - prev_pos > self.chunk_size
                         {
                             break;
                         }
@@ -168,32 +139,49 @@ impl Executor {
             }
 
             let last_chunk = &chunks[chunks.len() - 1];
-            Self::add_output(base_pos + start_pos, &s[start_pos..last_chunk.end], output);
+            self.add_output(base_pos + start_pos, &s[start_pos..last_chunk.end], output);
         };
 
         let mut small_chunks = Vec::new();
-        let mut process_chunk = |start: usize, end: usize| {
-            let chunk = &s[start..end];
-            if chunk.len() <= self.spec.chunk_size {
-                small_chunks.push(RangeValue::new(start, start + chunk.len()));
-            } else {
-                flush_small_chunks(&small_chunks, output);
-                small_chunks.clear();
-                self.split_substring(chunk, base_pos + start, next_sep_id + 1, output);
-            }
-        };
+        let mut process_chunk =
+            |start: usize, end: usize, output: &mut Vec<(RangeValue, &'s str)>| {
+                let chunk = &s[start..end];
+                if chunk.len() <= self.chunk_size {
+                    small_chunks.push(RangeValue::new(start, start + chunk.len()));
+                } else {
+                    flush_small_chunks(&small_chunks, output);
+                    small_chunks.clear();
+                    self.split_substring(chunk, base_pos + start, next_sep_id + 1, output);
+                }
+            };
 
         let mut next_start_pos = 0;
        for cap in self.separators[next_sep_id].find_iter(s) {
-            process_chunk(next_start_pos, cap.start());
+            process_chunk(next_start_pos, cap.start(), output);
             next_start_pos = cap.end();
         }
         if next_start_pos < s.len() {
-            process_chunk(next_start_pos, s.len());
+            process_chunk(next_start_pos, s.len(), output);
         }
 
         flush_small_chunks(&small_chunks, output);
     }
+
+    fn add_output<'s>(&self, pos: usize, text: &'s str, output: &mut Vec<(RangeValue, &'s str)>) {
+        if !text.trim().is_empty() {
+            output.push((RangeValue::new(pos, pos + text.len()), text));
+        }
+    }
 }
 
+struct Executor {
+    args: Args,
+}
+
+impl Executor {
+    fn new(args: Args) -> Result<Self> {
+        Ok(Self { args })
+    }
+}
 
 fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator<Item = &'a mut usize>) {
@@ -229,9 +217,32 @@ fn translate_bytes_to_chars<'a>(text: &str, offsets: impl Iterator<Item = &'a mu
 #[async_trait]
 impl SimpleFunctionExecutor for Executor {
     async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
+        let task = SplitTask {
+            separators: self
+                .args
+                .language
+                .value(&input)?
+                .map(|v| v.as_str())
+                .transpose()?
+                .and_then(|lang| {
+                    SEPARATORS_BY_LANG
+                        .get(lang.to_lowercase().as_str())
+                        .map(|v| v.as_slice())
+                })
+                .unwrap_or(DEFAULT_SEPARATORS.as_slice()),
+            chunk_size: self.args.chunk_size.value(&input)?.as_int64()? as usize,
+            chunk_overlap: self
+                .args
+                .chunk_overlap
+                .value(&input)?
+                .map(|v| v.as_int64())
+                .transpose()?
+                .unwrap_or(0) as usize,
+        };
+
         let text = self.args.text.value(&input)?.as_str()?;
         let mut output = Vec::new();
-        self.split_substring(text, 0, 0, &mut output);
+        task.split_substring(text, 0, 0, &mut output);
 
         translate_bytes_to_chars(
             text,
@@ -271,6 +282,15 @@ impl SimpleFunctionFactoryBase for Factory {
             text: args_resolver
                 .next_arg("text")?
                 .expect_type(&ValueType::Basic(BasicValueType::Str))?,
+            chunk_size: args_resolver
+                .next_arg("chunk_size")?
+                .expect_type(&ValueType::Basic(BasicValueType::Int64))?,
+            chunk_overlap: args_resolver
+                .next_optional_arg("chunk_overlap")?
+                .expect_type(&ValueType::Basic(BasicValueType::Int64))?,
+            language: args_resolver
+                .next_optional_arg("language")?
+                .expect_type(&ValueType::Basic(BasicValueType::Str))?,
         };
         let output_schema = make_output_type(CollectionSchema::new(
             CollectionKind::Table,
@@ -288,10 +308,10 @@ impl SimpleFunctionFactoryBase for Factory {
 
     async fn build_executor(
         self: Arc<Self>,
-        spec: Spec,
+        _spec: Spec,
         args: Args,
         _context: Arc<FlowInstanceContext>,
     ) -> Result<Box<dyn SimpleFunctionExecutor>> {
-        Ok(Box::new(Executor::new(spec, args)?))
+        Ok(Box::new(Executor::new(args)?))
     }
 }
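On the Rust side, `chunk_size` resolves as a required argument while `chunk_overlap` and `language` come from `next_optional_arg`, with the executor falling back to an overlap of 0 and the default separators when they are omitted. Assuming those defaults hold, the minimal call from Python should need only `chunk_size`:

```python
doc["chunks"] = doc["content"].transform(
    cocoindex.functions.SplitRecursively(),
    chunk_size=300)  # chunk_overlap -> 0; language -> DEFAULT_SEPARATORS
```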