Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions examples/code_embedding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,10 @@
import cocoindex
import os

class ExtractExtension(cocoindex.op.FunctionSpec):
"""Spec for the op that extracts the extension of a filename."""

@cocoindex.op.executor_class()
class ExtractExtensionExecutor:
"""Executor for ExtractExtension: maps a filename to its extension."""

spec: ExtractExtension

def __call__(self, filename: str) -> str:
# os.path.splitext keeps the leading dot (e.g. ".py"); the result is "" when
# the filename has no extension.
return os.path.splitext(filename)[1]
@cocoindex.op.function()
def extract_extension(filename: str) -> str:
    """Return the extension of `filename`, including its leading dot.

    An empty string is returned when the name has no extension.
    """
    _, extension = os.path.splitext(filename)
    return extension

def code_to_embedding(text: cocoindex.DataSlice) -> cocoindex.DataSlice:
"""
Expand All @@ -35,7 +28,7 @@ def code_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
code_embeddings = data_scope.add_collector()

with data_scope["files"].row() as file:
file["extension"] = file["filename"].transform(ExtractExtension())
file["extension"] = file["filename"].transform(extract_extension)
file["chunks"] = file["content"].transform(
cocoindex.functions.SplitRecursively(),
language=file["extension"], chunk_size=1000, chunk_overlap=300)
Expand Down
22 changes: 7 additions & 15 deletions examples/manuals_llm_extraction/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,13 @@ class ModuleSummary:
num_classes: int
num_methods: int

@dataclasses.dataclass
class SummarizeModule(cocoindex.op.FunctionSpec):
@cocoindex.op.function()
def summarize_module(module_info: ModuleInfo) -> ModuleSummary:
"""Summarize a Python module."""

@cocoindex.op.executor_class()
class SummarizeModuleExecutor:
"""Executor for SummarizeModule: counts the classes and methods of a module."""

spec: SummarizeModule

def __call__(self, module_info: ModuleInfo) -> ModuleSummary:
# The summary is just the sizes of the two collections on `module_info`.
return ModuleSummary(
num_classes=len(module_info.classes),
num_methods=len(module_info.methods),
)
return ModuleSummary(
num_classes=len(module_info.classes),
num_methods=len(module_info.methods),
)

@cocoindex.flow_def(name="ManualExtraction")
def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):
Expand All @@ -103,7 +95,7 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
# api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"),
output_type=ModuleInfo,
instruction="Please extract Python module information from the manual."))
doc["module_summary"] = doc["module_info"].transform(SummarizeModule())
doc["module_summary"] = doc["module_info"].transform(summarize_module)
modules_index.collect(
filename=doc["filename"],
module_info=doc["module_info"],
Expand Down
269 changes: 165 additions & 104 deletions python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,140 @@ def _make_engine_value_converter(

_gpu_dispatch_lock = Lock()

def executor_class(gpu: bool = False, cache: bool = False, behavior_version: int | None = None) -> Callable[[type], type]:
@dataclasses.dataclass
class OpArgs:
"""
Options accepted by the `executor_class` and `function` decorators.
- gpu: Whether the executor will be executed on GPU.
- cache: Whether the executor will be cached.
- behavior_version: The behavior version of the executor. Cache will be invalidated if it
changes. Must be provided if `cache` is True.
"""
gpu: bool = False
cache: bool = False
behavior_version: int | None = None

def _register_op_factory(
category: OpCategory,
expected_args: list[tuple[str, inspect.Parameter]],
expected_return,
executor_cls: type,
spec_cls: type,
op_args: OpArgs,
) -> type:
"""
Register an op factory.

Wraps `executor_cls` in a class that validates the arguments in `analyze`, converts
engine values before calling the executor, and registers a factory for it with the
engine under `spec_cls.__name__`.

- category: The op category. Only `OpCategory.FUNCTION` is supported.
- expected_args: `(name, parameter)` pairs of the executor's call signature, without `self`.
- expected_return: The result type returned by `analyze` when `executor_cls` does not
define its own `analyze`.
- executor_cls: The class whose `__call__` performs the actual work.
- spec_cls: The spec class; its `__name__` is used as the registered op name.
- op_args: Options that control GPU dispatching and the caching fallbacks.

Returns the wrapped executor class.
Raises ValueError if `category` is not supported.
"""
# Listed after `executor_cls` in the bases below, so these are only used when
# `executor_cls` does not define the methods itself.
class _Fallback:
def enable_cache(self):
return op_args.cache

def behavior_version(self):
return op_args.behavior_version

class _WrappedClass(executor_cls, _Fallback):
_args_converters: list[Callable[[Any], Any]]
# Each converter takes the raw engine value only (see `__call__`).
_kwargs_converters: dict[str, Callable[[Any], Any]]

def __init__(self, spec):
super().__init__()
self.spec = spec

def analyze(self, *args, **kwargs):
"""
Analyze the spec and arguments. In this phase, argument types should be validated.
It should return the expected result type for the current op.
"""
self._args_converters = []
self._kwargs_converters = {}

# Match arguments with parameters.
next_param_idx = 0
for arg in args:
if next_param_idx >= len(expected_args):
raise ValueError(
f"Too many arguments passed in: {len(args)} > {len(expected_args)}")
arg_name, arg_param = expected_args[next_param_idx]
if arg_param.kind in (
inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD):
raise ValueError(
f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
self._args_converters.append(
_make_engine_value_converter(
[arg_name], arg.value_type['type'], arg_param.annotation))
# A `*args` parameter keeps absorbing positional arguments, so don't advance past it.
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
next_param_idx += 1

# Parameters not consumed positionally can still be supplied by keyword.
expected_kwargs = expected_args[next_param_idx:]

for kwarg_name, kwarg in kwargs.items():
expected_arg = next(
(arg for arg in expected_kwargs
if (arg[0] == kwarg_name and arg[1].kind in (
inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD))
or arg[1].kind == inspect.Parameter.VAR_KEYWORD),
None)
if expected_arg is None:
raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
arg_param = expected_arg[1]
self._kwargs_converters[kwarg_name] = _make_engine_value_converter(
[kwarg_name], kwarg.value_type['type'], arg_param.annotation)

# A remaining parameter without a default is missing if it is positional-only, or if
# it can be passed by keyword but was not.
missing_args = [name for (name, arg) in expected_kwargs
if arg.default is inspect.Parameter.empty
and (arg.kind == inspect.Parameter.POSITIONAL_ONLY or
(arg.kind in (inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD)
and name not in kwargs))]
if len(missing_args) > 0:
raise ValueError(f"Missing arguments: {', '.join(missing_args)}")

# `prepare_method` actually holds the executor's own `analyze`, if it defines one.
prepare_method = getattr(executor_cls, 'analyze', None)
if prepare_method is not None:
return prepare_method(self, *args, **kwargs)
else:
return expected_return

def prepare(self):
"""
Prepare for execution.
It's executed after `analyze` and before any `__call__` execution.
"""
# `setup_method` is the executor's own `prepare`, if it defines one.
setup_method = getattr(executor_cls, 'prepare', None)
if setup_method is not None:
setup_method(self)

def __call__(self, *args, **kwargs):
# Lazy generator; it is consumed when unpacked into the call below.
converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg)
for arg_name, arg in kwargs.items()}
if op_args.gpu:
# For GPU executions, data-level parallelism is applied, so we don't want to
# execute different tasks in parallel.
# Besides, multiprocessing is more appropriate for pytorch.
# For now, we use a lock to ensure only one task is executed at a time.
# TODO: Implement multi-processing dispatching.
with _gpu_dispatch_lock:
output = super().__call__(*converted_args, **converted_kwargs)
else:
output = super().__call__(*converted_args, **converted_kwargs)
return to_engine_value(output)

_WrappedClass.__name__ = executor_cls.__name__

Args:
gpu: Whether the executor will be executed on GPU.
cache: Whether the executor will be cached.
behavior_version: The behavior version of the executor. Cache will be invalidated if it changes. Must be provided if `cache` is True.
if category == OpCategory.FUNCTION:
_engine.register_function_factory(
spec_cls.__name__,
_FunctionExecutorFactory(spec_cls, _WrappedClass))
else:
raise ValueError(f"Unsupported executor type {category}")

return _WrappedClass

def executor_class(**args) -> Callable[[type], type]:
"""
Decorate a class to provide an executor for an op.
"""
op_args = OpArgs(**args)

def _inner(cls: type[Executor]) -> type:
"""
Expand All @@ -149,110 +274,46 @@ def _inner(cls: type[Executor]) -> type:
type_hints = get_type_hints(cls)
if 'spec' not in type_hints:
raise TypeError("Expect a `spec` field with type hint")

spec_cls = type_hints['spec']
op_name = spec_cls.__name__
category = spec_cls._op_category

sig = inspect.signature(cls.__call__)
expected_args = list(sig.parameters.items())[1:] # First argument is `self`
expected_return = sig.return_annotation

cls_type: type = cls

class _Fallback:
def enable_cache(self):
return cache

def behavior_version(self):
return behavior_version

class _WrappedClass(cls_type, _Fallback):
_args_converters: list[Callable[[Any], Any]]
_kwargs_converters: dict[str, Callable[[str, Any], Any]]

def __init__(self, spec):
super().__init__()
self.spec = spec

def analyze(self, *args, **kwargs):
"""
Analyze the spec and arguments. In this phase, argument types should be validated.
It should return the expected result type for the current op.
"""
self._args_converters = []
self._kwargs_converters = {}

# Match arguments with parameters.
next_param_idx = 0
for arg in args:
if next_param_idx >= len(expected_args):
raise ValueError(f"Too many arguments passed in: {len(args)} > {len(expected_args)}")
arg_name, arg_param = expected_args[next_param_idx]
if arg_param.kind == inspect.Parameter.KEYWORD_ONLY or arg_param.kind == inspect.Parameter.VAR_KEYWORD:
raise ValueError(f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
self._args_converters.append(
_make_engine_value_converter([arg_name], arg.value_type['type'], arg_param.annotation))
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
next_param_idx += 1

expected_kwargs = expected_args[next_param_idx:]

for kwarg_name, kwarg in kwargs.items():
expected_arg = next(
(arg for arg in expected_kwargs
if (arg[0] == kwarg_name and arg[1].kind in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD))
or arg[1].kind == inspect.Parameter.VAR_KEYWORD),
None)
if expected_arg is None:
raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
arg_param = expected_arg[1]
self._kwargs_converters[kwarg_name] = _make_engine_value_converter(
[kwarg_name], kwarg.value_type['type'], arg_param.annotation)

missing_args = [name for (name, arg) in expected_kwargs
if arg.default is inspect.Parameter.empty
and (arg.kind == inspect.Parameter.POSITIONAL_ONLY or
(arg.kind in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) and name not in kwargs))]
if len(missing_args) > 0:
raise ValueError(f"Missing arguments: {', '.join(missing_args)}")

prepare_method = getattr(cls_type, 'analyze', None)
if prepare_method is not None:
return prepare_method(self, *args, **kwargs)
else:
return expected_return

def prepare(self):
"""
Prepare for execution.
It's executed after `analyze` and before any `__call__` execution.
"""
setup_method = getattr(cls_type, 'prepare', None)
if setup_method is not None:
setup_method(self)
return _register_op_factory(
category=spec_cls._op_category,
expected_args=list(sig.parameters.items())[1:], # First argument is `self`
expected_return=sig.return_annotation,
executor_cls=cls,
spec_cls=spec_cls,
op_args=op_args)

return _inner

def function(**args) -> Callable[[Callable], FunctionSpec]:
"""
Decorate a function to provide a function for an op.
"""
op_args = OpArgs(**args)

def _inner(fn: Callable) -> FunctionSpec:

# Convert snake case to camel case.
op_name = ''.join(word.capitalize() for word in fn.__name__.split('_'))
sig = inspect.signature(fn)

class _Executor:
def __call__(self, *args, **kwargs):
converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg) for arg_name, arg in kwargs.items()}
if gpu:
# For GPU executions, data-level parallelism is applied, so we don't want to execute different tasks in parallel.
# Besides, multiprocessing is more appropriate for pytorch.
# For now, we use a lock to ensure only one task is executed at a time.
# TODO: Implement multi-processing dispatching.
with _gpu_dispatch_lock:
output = super().__call__(*converted_args, **converted_kwargs)
else:
output = super().__call__(*converted_args, **converted_kwargs)
return to_engine_value(output)
return fn(*args, **kwargs)

_WrappedClass.__name__ = cls.__name__
class _Spec(FunctionSpec):
pass
_Spec.__name__ = op_name

if category == OpCategory.FUNCTION:
_engine.register_function_factory(op_name, _FunctionExecutorFactory(spec_cls, _WrappedClass))
else:
raise ValueError(f"Unsupported executor type {category}")
_register_op_factory(
category=OpCategory.FUNCTION,
expected_args=list(sig.parameters.items()),
expected_return=sig.return_annotation,
executor_cls=_Executor,
spec_cls=_Spec,
op_args=op_args)

return _WrappedClass
return _Spec()

return _inner