From 1addedcebe8286718a92bda8afce872256760f3a Mon Sep 17 00:00:00 2001 From: LJ Date: Thu, 6 Mar 2025 23:36:28 -0800 Subject: [PATCH] Support Python->Rust Struct/Table type bindings. --- python/cocoindex/op.py | 17 +++++++++--- src/ops/py_factory.rs | 60 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index a51e4b3a..9afb6ba6 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -1,10 +1,10 @@ """ Facilities for defining cocoindex operations. """ +import dataclasses import inspect from typing import get_type_hints, Protocol, Any, Callable, dataclass_transform -from dataclasses import dataclass from enum import Enum from threading import Lock @@ -28,7 +28,7 @@ def __new__(mcs, name, bases, attrs, category: OpCategory | None = None): setattr(cls, '_op_category', category) else: # It's the specific class providing specific fields. - cls = dataclass(cls) + cls = dataclasses.dataclass(cls) return cls class SourceSpec(metaclass=SpecMeta, category=OpCategory.SOURCE): # pylint: disable=too-few-public-methods @@ -59,6 +59,14 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs): result_type = executor.analyze(*args, **kwargs) return (dump_type(result_type), executor) +def to_engine_value(value: Any) -> Any: + """Convert a Python value to an engine value.""" + if dataclasses.is_dataclass(value): + return [to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)] + elif isinstance(value, list) or isinstance(value, tuple): + return [to_engine_value(v) for v in value] + return value + _gpu_dispatch_lock = Lock() def executor_class(gpu: bool = False, cache: bool = False, behavior_version: int | None = None) -> Callable[[type], type]: @@ -162,9 +170,10 @@ def __call__(self, *args, **kwargs): # For now, we use a lock to ensure only one task is executed at a time. # TODO: Implement multi-processing dispatching. with _gpu_dispatch_lock: - return super().__call__(*args, **kwargs) + output = super().__call__(*args, **kwargs) else: - return super().__call__(*args, **kwargs) + output = super().__call__(*args, **kwargs) + return to_engine_value(output) _WrappedClass.__name__ = cls.__name__ diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs index 4c21636c..db820934 100644 --- a/src/ops/py_factory.rs +++ b/src/ops/py_factory.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use axum::async_trait; use blocking::unblock; @@ -6,13 +6,14 @@ use futures::FutureExt; use pyo3::{ exceptions::PyException, pyclass, pymethods, - types::{IntoPyDict, PyAnyMethods, PyString, PyTuple}, + types::{IntoPyDict, PyAnyMethods, PyList, PyString, PyTuple}, Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python, }; use crate::{ base::{schema, value}, builder::plan, + py::IntoPyResult, }; use anyhow::Result; @@ -89,6 +90,28 @@ fn basic_value_from_py_object<'py>( Ok(result) } +fn field_values_from_py_object<'py>( + schema: &schema::StructSchema, + v: &Bound<'py, PyAny>, +) -> PyResult { + let list = v.extract::>>()?; + if list.len() != schema.fields.len() { + return Err(PyException::new_err(format!( + "struct field number mismatch, expected {}, got {}", + schema.fields.len(), + list.len() + ))); + } + Ok(value::FieldValues { + fields: schema + .fields + .iter() + .zip(list.into_iter()) + .map(|(f, v)| value_from_py_object(&f.value_type.typ, &v)) + .collect::>>()?, + }) +} + fn value_from_py_object<'py>( typ: &schema::ValueType, v: &Bound<'py, PyAny>, @@ -100,6 +123,39 @@ fn value_from_py_object<'py>( schema::ValueType::Basic(typ) => { value::Value::Basic(basic_value_from_py_object(typ, v)?) } + schema::ValueType::Struct(schema) => { + value::Value::Struct(field_values_from_py_object(schema, v)?) + } + schema::ValueType::Collection(schema) => { + let list = v.extract::>>()?; + let values = list + .into_iter() + .map(|v| field_values_from_py_object(&schema.row, &v)) + .collect::>>()?; + match schema.kind { + schema::CollectionKind::Collection => { + value::Value::Collection(values.into_iter().map(|v| v.into()).collect()) + } + schema::CollectionKind::List => { + value::Value::List(values.into_iter().map(|v| v.into()).collect()) + } + schema::CollectionKind::Table => value::Value::Table( + values + .into_iter() + .map(|v| { + let mut iter = v.fields.into_iter(); + let key = iter.next().unwrap().to_key().into_py_result()?; + Ok(( + key, + value::ScopeValue(value::FieldValues { + fields: iter.collect::>(), + }), + )) + }) + .collect::>>()?, + ), + } + } _ => { return Err(PyException::new_err(format!( "unsupported value type: {}",