Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/cocoindex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
from .flow import update_all_flows, FlowLiveUpdater, FlowLiveUpdaterOptions
from .llm import LlmSpec, LlmApiType
from .vector import VectorSimilarityMetric
from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry
from .lib import *
from ._engine import OpArgSchema
from ._engine import OpArgSchema
22 changes: 22 additions & 0 deletions python/cocoindex/auth_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Auth registry is used to register and reference auth entries.
"""

from dataclasses import dataclass

from . import _engine
from .convert import dump_engine_object

@dataclass
class AuthEntryReference:
    """Handle that points at an auth entry in the registry via its key."""

    # Key under which the referenced auth entry is registered.
    key: str

def add_auth_entry(key: str, value) -> AuthEntryReference:
    """Register `value` in the auth registry under `key` and return a reference to it.

    The value is converted with `dump_engine_object` before it is handed to
    the engine's `add_auth_entry`.
    """
    engine_value = dump_engine_object(value)
    _engine.add_auth_entry(key, engine_value)
    return AuthEntryReference(key=key)

def ref_auth_entry(key: str) -> AuthEntryReference:
    """Build a reference to the auth entry identified by `key`.

    This only constructs the reference object; nothing is registered or
    looked up here.
    """
    return AuthEntryReference(key=key)
27 changes: 25 additions & 2 deletions python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
Utilities to convert between Python and engine values.
"""
import dataclasses
import datetime
import inspect
import uuid

from typing import Any, Callable
from .typing import analyze_type_info, COLLECTION_TYPES
from enum import Enum
from typing import Any, Callable, get_origin
from .typing import analyze_type_info, encode_enriched_type, COLLECTION_TYPES

def to_engine_value(value: Any) -> Any:
"""Convert a Python value to an engine value."""
Expand Down Expand Up @@ -100,3 +102,24 @@ def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[lis

return lambda values: dst_dataclass_type(
*(converter(values) for converter in field_value_converters))

def dump_engine_object(v: Any) -> Any:
    """Recursively dump an object for engine. Engine side uses `Pythonized` to catch.

    Conversion rules, checked in this order:

    * `None` is returned as-is.
    * Types and parameterized generics are encoded with `encode_enriched_type`.
    * `Enum` members are replaced by their `.value`.
    * `datetime.timedelta` becomes `{'secs': int, 'nanos': int}`.
    * Lists and tuples become lists with every item dumped recursively.
    * Dicts keep their keys; every value is dumped recursively.
    * Any other object exposing `__dict__` becomes a dict of its dumped attributes.
    * Everything else is returned unchanged.
    """
    if v is None:
        return None
    if isinstance(v, type) or get_origin(v) is not None:
        return encode_enriched_type(v)
    if isinstance(v, Enum):
        return v.value
    if isinstance(v, datetime.timedelta):
        # Use exact integer arithmetic. Going through float `total_seconds()`
        # drops a nanosecond on values such as 2.3s (-> 299999999 nanos).
        total_micros = v // datetime.timedelta(microseconds=1)
        # Split the magnitude and re-apply the sign so negative durations keep
        # the truncate-toward-zero layout (both parts share the sign).
        sign = -1 if total_micros < 0 else 1
        secs, micros = divmod(abs(total_micros), 1_000_000)
        return {'secs': sign * secs, 'nanos': sign * micros * 1000}
    # Containers are handled before the generic `__dict__` branch: subclasses
    # of list/dict carry an instance `__dict__` too, and dumping that instead
    # of the contents would silently lose the data.
    if isinstance(v, (list, tuple)):
        return [dump_engine_object(item) for item in v]
    if isinstance(v, dict):
        return {key: dump_engine_object(item) for key, item in v.items()}
    if hasattr(v, '__dict__'):
        return {name: dump_engine_object(attr) for name, attr in v.__dict__.items()}
    return v
38 changes: 10 additions & 28 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
import re
import inspect
import datetime
from typing import Any, Callable, Sequence, TypeVar, get_origin

from typing import Any, Callable, Sequence, TypeVar
from threading import Lock
from enum import Enum
from dataclasses import dataclass

from . import _engine
from . import vector
from . import op
from .convert import dump_engine_object
from .typing import encode_enriched_type

class _NameBuilder:
Expand Down Expand Up @@ -64,27 +66,6 @@ def _create_data_slice(
def _spec_kind(spec: Any) -> str:
    """Return the class name of `spec`, which serves as the spec's kind string."""
    spec_cls = spec.__class__
    return spec_cls.__name__

def _dump_engine_object(v: Any) -> Any:
    """Recursively dump an object for engine. Engine side uses `Pythonized` to catch.

    Conversion rules, checked in this order:

    * `None` is returned as-is.
    * Types and parameterized generics are encoded with `encode_enriched_type`.
    * `Enum` members are replaced by their `.value`.
    * `datetime.timedelta` becomes `{'secs': int, 'nanos': int}`.
    * Lists and tuples become lists with every item dumped recursively.
    * Dicts keep their keys; every value is dumped recursively.
    * Any other object exposing `__dict__` becomes a dict of its dumped attributes.
    * Everything else is returned unchanged.
    """
    if v is None:
        return None
    if isinstance(v, type) or get_origin(v) is not None:
        return encode_enriched_type(v)
    if isinstance(v, Enum):
        return v.value
    if isinstance(v, datetime.timedelta):
        # Use exact integer arithmetic. Going through float `total_seconds()`
        # drops a nanosecond on values such as 2.3s (-> 299999999 nanos).
        total_micros = v // datetime.timedelta(microseconds=1)
        # Split the magnitude and re-apply the sign so negative durations keep
        # the truncate-toward-zero layout (both parts share the sign).
        sign = -1 if total_micros < 0 else 1
        secs, micros = divmod(abs(total_micros), 1_000_000)
        return {'secs': sign * secs, 'nanos': sign * micros * 1000}
    # Containers are handled before the generic `__dict__` branch: subclasses
    # of list/dict carry an instance `__dict__` too, and dumping that instead
    # of the contents would silently lose the data.
    if isinstance(v, (list, tuple)):
        return [_dump_engine_object(item) for item in v]
    if isinstance(v, dict):
        return {key: _dump_engine_object(item) for key, item in v.items()}
    if hasattr(v, '__dict__'):
        return {name: _dump_engine_object(attr) for name, attr in v.__dict__.items()}
    return v

T = TypeVar('T')

class _DataSliceState:
Expand Down Expand Up @@ -176,6 +157,7 @@ def transform(self, fn_spec: op.FunctionSpec, *args, **kwargs) -> DataSlice:
"""
Apply a function to the data slice.
"""
transform_args: list[tuple[Any, str | None]]
transform_args = [(self._state.engine_data_slice, None)]
transform_args += [(self._state.flow_builder_state.get_data_slice(v), None) for v in args]
transform_args += [(self._state.flow_builder_state.get_data_slice(v), k)
Expand All @@ -187,7 +169,7 @@ def transform(self, fn_spec: op.FunctionSpec, *args, **kwargs) -> DataSlice:
lambda target_scope, name:
flow_builder_state.engine_flow_builder.transform(
_spec_kind(fn_spec),
_dump_engine_object(fn_spec),
dump_engine_object(fn_spec),
transform_args,
target_scope,
flow_builder_state.field_name_builder.build_name(
Expand Down Expand Up @@ -298,7 +280,7 @@ def export(self, name: str, target_spec: op.StorageSpec, /, *,
{"field_name": field_name, "metric": metric.value}
for field_name, metric in vector_index]
self._flow_builder_state.engine_flow_builder.export(
name, _spec_kind(target_spec), _dump_engine_object(target_spec),
name, _spec_kind(target_spec), dump_engine_object(target_spec),
index_options, self._engine_data_collector, setup_by_user)


Expand Down Expand Up @@ -357,11 +339,11 @@ def add_source(self, spec: op.SourceSpec, /, *,
self._state,
lambda target_scope, name: self._state.engine_flow_builder.add_source(
_spec_kind(spec),
_dump_engine_object(spec),
dump_engine_object(spec),
target_scope,
self._state.field_name_builder.build_name(
name, prefix=_to_snake_case(_spec_kind(spec))+'_'),
_dump_engine_object(_SourceRefreshOptions(refresh_interval=refresh_interval)),
dump_engine_object(_SourceRefreshOptions(refresh_interval=refresh_interval)),
),
name
)
Expand All @@ -382,7 +364,7 @@ class FlowLiveUpdater:

def __init__(self, fl: Flow, options: FlowLiveUpdaterOptions | None = None):
self._engine_live_updater = _engine.FlowLiveUpdater(
fl._lazy_engine_flow(), _dump_engine_object(options or FlowLiveUpdaterOptions()))
fl._lazy_engine_flow(), dump_engine_object(options or FlowLiveUpdaterOptions()))

def __enter__(self) -> FlowLiveUpdater:
    """Enter the `with` block; the updater itself is the value bound by `as`."""
    return self
Expand Down Expand Up @@ -469,7 +451,7 @@ def evaluate_and_dump(self, options: EvaluateAndDumpOptions):
"""
Evaluate the flow and dump flow outputs to files.
"""
return self._lazy_engine_flow().evaluate_and_dump(_dump_engine_object(options))
return self._lazy_engine_flow().evaluate_and_dump(dump_engine_object(options))

def internal_flow(self) -> _engine.Flow:
"""
Expand Down
11 changes: 11 additions & 0 deletions src/py/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,16 @@ fn apply_setup_changes(py: Python<'_>, setup_status: &SetupStatusCheck) -> PyRes
})
}

/// Adds `value` to the lib context's auth registry under `key`.
///
/// Failures from obtaining the lib context or from the registry's `add`
/// are converted with `into_py_result` and returned as the `PyResult` error.
#[pyfunction]
fn add_auth_entry(key: String, value: Pythonized<serde_json::Value>) -> PyResult<()> {
    let ctx = get_lib_context().into_py_result()?;
    // Unwrap the `Pythonized` wrapper to get the plain JSON value to store.
    let entry = value.into_inner();
    ctx.auth_registry.add(key, entry).into_py_result()?;
    Ok(())
}

/// A Python module implemented in Rust.
#[pymodule]
#[pyo3(name = "_engine")]
Expand All @@ -333,6 +343,7 @@ fn cocoindex_engine(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(drop_setup, m)?)?;
m.add_function(wrap_pyfunction!(apply_setup_changes, m)?)?;
m.add_function(wrap_pyfunction!(flow_names_with_setup, m)?)?;
m.add_function(wrap_pyfunction!(add_auth_entry, m)?)?;

m.add_class::<builder::flow_builder::FlowBuilder>()?;
m.add_class::<builder::flow_builder::DataCollector>()?;
Expand Down