Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions src/ops/py_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use pythonize::pythonize;
use crate::{
base::{schema, value},
builder::plan,
py,
py::{self, FromPyResult},
};
use anyhow::{anyhow, Result};

Expand Down Expand Up @@ -74,14 +74,17 @@ impl PyFunctionExecutor {
Some(kwargs)
};

let result = self.py_function_executor.call(
py,
PyTuple::new(py, args.into_iter())?,
kwargs
.map(|kwargs| -> Result<_> { Ok(kwargs.into_py_dict(py)?) })
.transpose()?
.as_ref(),
)?;
let result = self
.py_function_executor
.call(
py,
PyTuple::new(py, args.into_iter())?,
kwargs
.map(|kwargs| -> Result<_> { Ok(kwargs.into_py_dict(py)?) })
.transpose()?
.as_ref(),
)
.from_py_result(py)?;
Ok(result.into_bound(py))
}
}
Expand All @@ -99,8 +102,9 @@ impl SimpleFunctionExecutor for Arc<PyFunctionExecutor> {
result_coro,
)?)
})?;
let result = result_fut.await?;
let result = result_fut.await;
Python::with_gil(|py| -> Result<_> {
let result = result.from_py_result(py)?;
Ok(py::value_from_py_object(
&self.result_type.typ,
&result.into_bound(py),
Expand Down Expand Up @@ -156,11 +160,14 @@ impl SimpleFunctionFactory for PyFunctionFactory {
.iter()
.map(|(name, _)| PyString::new(py, name).unbind())
.collect::<Vec<_>>();
let result = self.py_function_factory.call(
py,
PyTuple::new(py, args.into_iter())?,
Some(&kwargs.into_py_dict(py)?),
)?;
let result = self
.py_function_factory
.call(
py,
PyTuple::new(py, args.into_iter())?,
Some(&kwargs.into_py_dict(py)?),
)
.from_py_result(py)?;
let (result_type, executor) = result
.extract::<(crate::py::Pythonized<schema::EnrichedValueType>, Py<PyAny>)>(py)?;
Ok((
Expand All @@ -181,18 +188,22 @@ impl SimpleFunctionFactory for PyFunctionFactory {
.clone();
let (prepare_fut, enable_cache, behavior_version) =
Python::with_gil(|py| -> anyhow::Result<_> {
let prepare_coro = executor.call_method(py, "prepare", (), None)?;
let prepare_coro = executor
.call_method(py, "prepare", (), None)
.from_py_result(py)?;
let prepare_fut = pyo3_async_runtimes::into_future_with_locals(
&pyo3_async_runtimes::TaskLocals::new(
py_exec_ctx.event_loop.bind(py).clone(),
),
prepare_coro.into_bound(py),
)?;
let enable_cache = executor
.call_method(py, "enable_cache", (), None)?
.call_method(py, "enable_cache", (), None)
.from_py_result(py)?
.extract::<bool>(py)?;
let behavior_version = executor
.call_method(py, "behavior_version", (), None)?
.call_method(py, "behavior_version", (), None)
.from_py_result(py)?
.extract::<Option<u32>>(py)?;
Ok((prepare_fut, enable_cache, behavior_version))
})?;
Expand Down
4 changes: 2 additions & 2 deletions src/py/convert.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use bytes::Bytes;
use pyo3::types::{PyList, PyTuple};
use pyo3::IntoPyObjectExt;
use pyo3::{exceptions::PyException, prelude::*};
Expand All @@ -6,8 +7,7 @@ use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::BTreeMap;
use std::ops::Deref;
use std::sync::Arc;
use bytes::Bytes;
use std::sync::Arc;

use super::IntoPyResult;
use crate::base::{schema, value};
Expand Down
19 changes: 19 additions & 0 deletions src/py/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::setup;
use pyo3::{exceptions::PyException, prelude::*};
use pyo3_async_runtimes::tokio::future_into_py;
use std::collections::btree_map;
use std::fmt::Write;

mod convert;
pub use convert::*;
Expand All @@ -26,6 +27,24 @@ impl PythonExecutionContext {
}
}

pub trait FromPyResult<T> {
fn from_py_result(self, py: Python<'_>) -> anyhow::Result<T>;
}

impl<T> FromPyResult<T> for Result<T, PyErr> {
fn from_py_result(self, py: Python<'_>) -> anyhow::Result<T> {
match self {
Ok(value) => Ok(value),
Err(err) => {
let mut err_str = format!("Error calling Python function: {}", err);
if let Some(tb) = err.traceback(py) {
write!(&mut err_str, "\n{}", tb.format()?)?;
}
Err(anyhow::anyhow!(err_str))
}
}
}
}
pub trait IntoPyResult<T> {
fn into_py_result(self) -> PyResult<T>;
}
Expand Down