diff --git a/src/ops/py_factory.rs b/src/ops/py_factory.rs index dfafee0b..d14468ae 100644 --- a/src/ops/py_factory.rs +++ b/src/ops/py_factory.rs @@ -1,19 +1,18 @@ -use std::{collections::BTreeMap, sync::Arc}; +use std::sync::Arc; use axum::async_trait; use futures::FutureExt; use pyo3::{ - exceptions::PyException, pyclass, pymethods, - types::{IntoPyDict, PyAnyMethods, PyList, PyString, PyTuple}, - Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python, + types::{IntoPyDict, PyString, PyTuple}, + IntoPyObjectExt, Py, PyAny, Python, }; use pythonize::pythonize; use crate::{ base::{schema, value}, builder::plan, - py::IntoPyResult, + py, }; use anyhow::Result; @@ -21,170 +20,6 @@ use super::sdk::{ ExecutorFuture, FlowInstanceContext, SimpleFunctionExecutor, SimpleFunctionFactory, }; -fn basic_value_to_py_object<'py>( - py: Python<'py>, - v: &value::BasicValue, -) -> PyResult> { - let result = match v { - value::BasicValue::Bytes(v) => v.into_bound_py_any(py)?, - value::BasicValue::Str(v) => v.into_bound_py_any(py)?, - value::BasicValue::Bool(v) => v.into_bound_py_any(py)?, - value::BasicValue::Int64(v) => v.into_bound_py_any(py)?, - value::BasicValue::Float32(v) => v.into_bound_py_any(py)?, - value::BasicValue::Float64(v) => v.into_bound_py_any(py)?, - value::BasicValue::Vector(v) => v - .iter() - .map(|v| basic_value_to_py_object(py, v)) - .collect::>>()? - .into_bound_py_any(py)?, - _ => { - return Err(PyException::new_err(format!( - "unsupported value type: {}", - v.kind() - ))) - } - }; - Ok(result) -} - -fn field_values_to_py_object<'py, 'a>( - py: Python<'py>, - values: impl Iterator, -) -> PyResult> { - let fields = values - .map(|v| value_to_py_object(py, v)) - .collect::>>()?; - Ok(PyTuple::new(py, fields)?.into_any()) -} - -fn value_to_py_object<'py>(py: Python<'py>, v: &value::Value) -> PyResult> { - let result = match v { - value::Value::Null => py.None().into_bound(py), - value::Value::Basic(v) => basic_value_to_py_object(py, v)?, - value::Value::Struct(v) => field_values_to_py_object(py, v.fields.iter())?, - value::Value::Collection(v) | value::Value::List(v) => { - let rows = v - .iter() - .map(|v| field_values_to_py_object(py, v.0.fields.iter())) - .collect::>>()?; - PyList::new(py, rows)?.into_any() - } - value::Value::Table(v) => { - let rows = v - .iter() - .map(|(k, v)| { - field_values_to_py_object( - py, - std::iter::once(&value::Value::from(k.clone())).chain(v.0.fields.iter()), - ) - }) - .collect::>>()?; - PyList::new(py, rows)?.into_any() - } - }; - Ok(result) -} - -fn basic_value_from_py_object<'py>( - typ: &schema::BasicValueType, - v: &Bound<'py, PyAny>, -) -> PyResult { - let result = match typ { - schema::BasicValueType::Bytes => { - value::BasicValue::Bytes(Arc::from(v.extract::>()?)) - } - schema::BasicValueType::Str => value::BasicValue::Str(Arc::from(v.extract::()?)), - schema::BasicValueType::Bool => value::BasicValue::Bool(v.extract::()?), - schema::BasicValueType::Int64 => value::BasicValue::Int64(v.extract::()?), - schema::BasicValueType::Float32 => value::BasicValue::Float32(v.extract::()?), - schema::BasicValueType::Float64 => value::BasicValue::Float64(v.extract::()?), - schema::BasicValueType::Vector(elem) => value::BasicValue::Vector(Arc::from( - v.extract::>>()? - .into_iter() - .map(|v| basic_value_from_py_object(&elem.element_type, &v)) - .collect::>>()?, - )), - _ => { - return Err(PyException::new_err(format!( - "unsupported value type: {}", - typ - ))) - } - }; - Ok(result) -} - -fn field_values_from_py_object<'py>( - schema: &schema::StructSchema, - v: &Bound<'py, PyAny>, -) -> PyResult { - let list = v.extract::>>()?; - if list.len() != schema.fields.len() { - return Err(PyException::new_err(format!( - "struct field number mismatch, expected {}, got {}", - schema.fields.len(), - list.len() - ))); - } - Ok(value::FieldValues { - fields: schema - .fields - .iter() - .zip(list.into_iter()) - .map(|(f, v)| value_from_py_object(&f.value_type.typ, &v)) - .collect::>>()?, - }) -} - -fn value_from_py_object<'py>( - typ: &schema::ValueType, - v: &Bound<'py, PyAny>, -) -> PyResult { - let result = if v.is_none() { - value::Value::Null - } else { - match typ { - schema::ValueType::Basic(typ) => { - value::Value::Basic(basic_value_from_py_object(typ, v)?) - } - schema::ValueType::Struct(schema) => { - value::Value::Struct(field_values_from_py_object(schema, v)?) - } - schema::ValueType::Collection(schema) => { - let list = v.extract::>>()?; - let values = list - .into_iter() - .map(|v| field_values_from_py_object(&schema.row, &v)) - .collect::>>()?; - match schema.kind { - schema::CollectionKind::Collection => { - value::Value::Collection(values.into_iter().map(|v| v.into()).collect()) - } - schema::CollectionKind::List => { - value::Value::List(values.into_iter().map(|v| v.into()).collect()) - } - schema::CollectionKind::Table => value::Value::Table( - values - .into_iter() - .map(|v| { - let mut iter = v.fields.into_iter(); - let key = iter.next().unwrap().to_key().into_py_result()?; - Ok(( - key, - value::ScopeValue(value::FieldValues { - fields: iter.collect::>(), - }), - )) - }) - .collect::>>()?, - ), - } - } - } - }; - Ok(result) -} - #[pyclass(name = "OpArgSchema")] pub struct PyOpArgSchema { value_type: crate::py::Pythonized, @@ -222,7 +57,7 @@ impl SimpleFunctionExecutor for Arc { Python::with_gil(|py| -> Result<_> { let mut args = Vec::with_capacity(self.num_positional_args); for v in input[0..self.num_positional_args].iter() { - args.push(value_to_py_object(py, v)?); + args.push(py::value_to_py_object(py, v)?); } let kwargs = if self.kw_args_names.is_empty() { @@ -234,7 +69,7 @@ impl SimpleFunctionExecutor for Arc { .iter() .zip(input[self.num_positional_args..].iter()) { - kwargs.push((name.bind(py), value_to_py_object(py, v)?)); + kwargs.push((name.bind(py), py::value_to_py_object(py, v)?)); } Some(kwargs) }; @@ -248,7 +83,7 @@ impl SimpleFunctionExecutor for Arc { .as_ref(), )?; - Ok(value_from_py_object( + Ok(py::value_from_py_object( &self.result_type.typ, result.bind(py), )?) diff --git a/src/py/convert.rs b/src/py/convert.rs new file mode 100644 index 00000000..c1a5174d --- /dev/null +++ b/src/py/convert.rs @@ -0,0 +1,217 @@ +use pyo3::types::{PyList, PyTuple}; +use pyo3::IntoPyObjectExt; +use pyo3::{exceptions::PyException, prelude::*}; +use pythonize::{depythonize, pythonize}; +use serde::de::DeserializeOwned; +use serde::Serialize; +use std::collections::BTreeMap; +use std::ops::Deref; +use std::sync::Arc; + +use super::IntoPyResult; +use crate::base::{schema, value}; + +pub struct Pythonized(pub T); + +impl<'py, T: DeserializeOwned> FromPyObject<'py> for Pythonized { + fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult { + Ok(Pythonized(depythonize(obj).into_py_result()?)) + } +} + +impl<'py, T: Serialize> IntoPyObject<'py> for &Pythonized { + type Target = PyAny; + type Output = Bound<'py, PyAny>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> PyResult { + pythonize(py, &self.0).into_py_result() + } +} + +impl<'py, T: Serialize> IntoPyObject<'py> for Pythonized { + type Target = PyAny; + type Output = Bound<'py, PyAny>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> PyResult { + (&self).into_pyobject(py) + } +} + +impl Pythonized { + pub fn into_inner(self) -> T { + self.0 + } +} + +impl Deref for Pythonized { + type Target = T; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +fn basic_value_to_py_object<'py>( + py: Python<'py>, + v: &value::BasicValue, +) -> PyResult> { + let result = match v { + value::BasicValue::Bytes(v) => v.into_bound_py_any(py)?, + value::BasicValue::Str(v) => v.into_bound_py_any(py)?, + value::BasicValue::Bool(v) => v.into_bound_py_any(py)?, + value::BasicValue::Int64(v) => v.into_bound_py_any(py)?, + value::BasicValue::Float32(v) => v.into_bound_py_any(py)?, + value::BasicValue::Float64(v) => v.into_bound_py_any(py)?, + value::BasicValue::Vector(v) => v + .iter() + .map(|v| basic_value_to_py_object(py, v)) + .collect::>>()? + .into_bound_py_any(py)?, + _ => { + return Err(PyException::new_err(format!( + "unsupported value type: {}", + v.kind() + ))) + } + }; + Ok(result) +} + +fn field_values_to_py_object<'py, 'a>( + py: Python<'py>, + values: impl Iterator, +) -> PyResult> { + let fields = values + .map(|v| value_to_py_object(py, v)) + .collect::>>()?; + Ok(PyTuple::new(py, fields)?.into_any()) +} + +pub fn value_to_py_object<'py>(py: Python<'py>, v: &value::Value) -> PyResult> { + let result = match v { + value::Value::Null => py.None().into_bound(py), + value::Value::Basic(v) => basic_value_to_py_object(py, v)?, + value::Value::Struct(v) => field_values_to_py_object(py, v.fields.iter())?, + value::Value::Collection(v) | value::Value::List(v) => { + let rows = v + .iter() + .map(|v| field_values_to_py_object(py, v.0.fields.iter())) + .collect::>>()?; + PyList::new(py, rows)?.into_any() + } + value::Value::Table(v) => { + let rows = v + .iter() + .map(|(k, v)| { + field_values_to_py_object( + py, + std::iter::once(&value::Value::from(k.clone())).chain(v.0.fields.iter()), + ) + }) + .collect::>>()?; + PyList::new(py, rows)?.into_any() + } + }; + Ok(result) +} + +fn basic_value_from_py_object<'py>( + typ: &schema::BasicValueType, + v: &Bound<'py, PyAny>, +) -> PyResult { + let result = match typ { + schema::BasicValueType::Bytes => { + value::BasicValue::Bytes(Arc::from(v.extract::>()?)) + } + schema::BasicValueType::Str => value::BasicValue::Str(Arc::from(v.extract::()?)), + schema::BasicValueType::Bool => value::BasicValue::Bool(v.extract::()?), + schema::BasicValueType::Int64 => value::BasicValue::Int64(v.extract::()?), + schema::BasicValueType::Float32 => value::BasicValue::Float32(v.extract::()?), + schema::BasicValueType::Float64 => value::BasicValue::Float64(v.extract::()?), + schema::BasicValueType::Vector(elem) => value::BasicValue::Vector(Arc::from( + v.extract::>>()? + .into_iter() + .map(|v| basic_value_from_py_object(&elem.element_type, &v)) + .collect::>>()?, + )), + _ => { + return Err(PyException::new_err(format!( + "unsupported value type: {}", + typ + ))) + } + }; + Ok(result) +} + +fn field_values_from_py_object<'py>( + schema: &schema::StructSchema, + v: &Bound<'py, PyAny>, +) -> PyResult { + let list = v.extract::>>()?; + if list.len() != schema.fields.len() { + return Err(PyException::new_err(format!( + "struct field number mismatch, expected {}, got {}", + schema.fields.len(), + list.len() + ))); + } + Ok(value::FieldValues { + fields: schema + .fields + .iter() + .zip(list.into_iter()) + .map(|(f, v)| value_from_py_object(&f.value_type.typ, &v)) + .collect::>>()?, + }) +} + +pub fn value_from_py_object<'py>( + typ: &schema::ValueType, + v: &Bound<'py, PyAny>, +) -> PyResult { + let result = if v.is_none() { + value::Value::Null + } else { + match typ { + schema::ValueType::Basic(typ) => { + value::Value::Basic(basic_value_from_py_object(typ, v)?) + } + schema::ValueType::Struct(schema) => { + value::Value::Struct(field_values_from_py_object(schema, v)?) + } + schema::ValueType::Collection(schema) => { + let list = v.extract::>>()?; + let values = list + .into_iter() + .map(|v| field_values_from_py_object(&schema.row, &v)) + .collect::>>()?; + match schema.kind { + schema::CollectionKind::Collection => { + value::Value::Collection(values.into_iter().map(|v| v.into()).collect()) + } + schema::CollectionKind::List => { + value::Value::List(values.into_iter().map(|v| v.into()).collect()) + } + schema::CollectionKind::Table => value::Value::Table( + values + .into_iter() + .map(|v| { + let mut iter = v.fields.into_iter(); + let key = iter.next().unwrap().to_key().into_py_result()?; + Ok(( + key, + value::ScopeValue(value::FieldValues { + fields: iter.collect::>(), + }), + )) + }) + .collect::>>()?, + ), + } + } + } + }; + Ok(result) +} diff --git a/src/py.rs b/src/py/mod.rs similarity index 89% rename from src/py.rs rename to src/py/mod.rs index 4bad52db..c1941376 100644 --- a/src/py.rs +++ b/src/py/mod.rs @@ -12,13 +12,12 @@ use crate::{api_error, setup}; use crate::{builder, execution}; use anyhow::anyhow; use pyo3::{exceptions::PyException, prelude::*}; -use pythonize::{depythonize, pythonize}; -use serde::de::DeserializeOwned; -use serde::Serialize; use std::collections::btree_map; -use std::ops::Deref; use std::sync::Arc; +mod convert; +pub use convert::*; + pub trait IntoPyResult { fn into_py_result(self) -> PyResult; } @@ -32,47 +31,6 @@ impl IntoPyResult for Result { } } -pub struct Pythonized(pub T); - -impl<'py, T: DeserializeOwned> FromPyObject<'py> for Pythonized { - fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult { - Ok(Pythonized(depythonize(obj).into_py_result()?)) - } -} - -impl<'py, T: Serialize> IntoPyObject<'py> for &Pythonized { - type Target = PyAny; - type Output = Bound<'py, PyAny>; - type Error = PyErr; - - fn into_pyobject(self, py: Python<'py>) -> PyResult { - pythonize(py, &self.0).into_py_result() - } -} - -impl<'py, T: Serialize> IntoPyObject<'py> for Pythonized { - type Target = PyAny; - type Output = Bound<'py, PyAny>; - type Error = PyErr; - - fn into_pyobject(self, py: Python<'py>) -> PyResult { - (&self).into_pyobject(py) - } -} - -impl Pythonized { - pub fn into_inner(self) -> T { - self.0 - } -} - -impl Deref for Pythonized { - type Target = T; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - #[pyfunction] fn init(py: Python<'_>, settings: Pythonized) -> PyResult<()> { py.allow_threads(|| -> anyhow::Result<()> {