Skip to content
70 changes: 51 additions & 19 deletions host/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@ use std::{any::Any, ops::DerefMut, sync::Arc};

use ::http::HeaderName;
use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature};
use datafusion_common::{DataFusionError, Result as DataFusionResult, config::ConfigOptions};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl},
};
use tokio::{runtime::Handle, sync::Mutex};
use wasmtime::{
Engine, Store,
component::{Component, ResourceAny},
};
use wasmtime_wasi::{ResourceTable, WasiCtx, WasiCtxView, WasiView, p2::pipe::MemoryOutputPipe};
use wasmtime_wasi::{
ResourceTable, WasiCtx, WasiCtxView, WasiView, async_trait, p2::pipe::MemoryOutputPipe,
};
use wasmtime_wasi_http::{
HttpResult, WasiHttpCtx, WasiHttpView,
bindings::http::types::ErrorCode as HttpErrorCode,
Expand Down Expand Up @@ -394,6 +399,11 @@ impl WasmScalarUdf {

Ok(udfs)
}

/// Wrap this [`WasmScalarUdf`] in an [`AsyncScalarUDF`].
///
/// Consumes `self`, moves it behind an [`Arc`], and hands that shared
/// pointer to [`AsyncScalarUDF::new`].
pub fn as_async_udf(self) -> AsyncScalarUDF {
    let shared = Arc::new(self);
    AsyncScalarUDF::new(shared)
}
}

impl std::fmt::Debug for WasmScalarUdf {
Expand Down Expand Up @@ -450,21 +460,43 @@ impl ScalarUDFImpl for WasmScalarUdf {
})
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
async_in_sync_context(async {
let args = args.try_into()?;
let mut store_guard = self.store.lock().await;
let return_type = self
.bindings
.datafusion_udf_wasm_udf_types()
.scalar_udf()
.call_invoke_with_args(store_guard.deref_mut(), self.resource, &args)
.await
.context(
"call ScalarUdf::invoke_with_args",
Some(&store_guard.data().stderr.contents()),
)??;
return_type.try_into()
})
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
    // The synchronous entry point always fails with `NotImplemented`; the
    // message directs callers to `invoke_async_with_args`.
    let message = String::from(
        "synchronous invocation of WasmScalarUdf is not supported, use invoke_async_with_args instead",
    );
    Err(DataFusionError::NotImplemented(message))
}
}

/// Asynchronous execution path for [`WasmScalarUdf`].
#[async_trait]
impl AsyncScalarUDFImpl for WasmScalarUdf {
    /// No preferred batch size is advertised.
    fn ideal_batch_size(&self) -> Option<usize> {
        None
    }

    /// Invoke the guest UDF and return its result as an array.
    ///
    /// The host arguments are converted into the guest representation, the
    /// store is locked for the duration of the guest call, and the guest's
    /// `invoke_with_args` export is awaited. A scalar result is expanded to an
    /// array with `number_rows` entries so that the return value is always an
    /// array.
    ///
    /// # Errors
    ///
    /// Returns an error if the argument conversion fails, if the guest call
    /// fails (the captured stderr contents are attached as context), or if the
    /// returned value cannot be converted back into a [`ColumnarValue`].
    async fn invoke_async_with_args(
        &self,
        args: ScalarFunctionArgs,
        _option: &ConfigOptions,
    ) -> DataFusionResult<arrow::array::ArrayRef> {
        // Capture the row count from the host-side arguments before they are
        // consumed by the conversion below. Reading it back from the converted
        // value would require a narrowing `as usize` cast.
        let number_rows = args.number_rows;

        let args = args.try_into()?;
        let mut store_guard = self.store.lock().await;
        let return_type = self
            .bindings
            .datafusion_udf_wasm_udf_types()
            .scalar_udf()
            .call_invoke_with_args(store_guard.deref_mut(), self.resource, &args)
            .await
            .context(
                "call ScalarUdf::invoke_with_args",
                Some(&store_guard.data().stderr.contents()),
            )??;

        // Release the store lock before doing host-only conversion work.
        drop(store_guard);

        let columnar_value: ColumnarValue = return_type.try_into()?;
        match columnar_value {
            ColumnarValue::Array(v) => Ok(v),
            ColumnarValue::Scalar(v) => v.to_array_of_size(number_rows),
        }
    }
}
1 change: 0 additions & 1 deletion host/tests/integration_tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
mod python;
mod rust;
mod test_utils;
95 changes: 53 additions & 42 deletions host/tests/integration_tests/python/argument_forms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ use arrow::{
array::{Array, Int64Array},
datatypes::{DataType, Field},
};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};

use crate::integration_tests::{
python::test_utils::python_scalar_udf, test_utils::ColumnarValueExt,
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
async_udf::AsyncScalarUDFImpl,
};

use crate::integration_tests::python::test_utils::python_scalar_udf;

#[tokio::test(flavor = "multi_thread")]
async fn test_positional_or_keyword() {
const CODE: &str = "
Expand All @@ -32,18 +34,21 @@ def foo(x: int) -> int:
);

let array = udf
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
Some(3),
None,
Some(-10),
])))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
})
.unwrap()
.unwrap_array();
.invoke_async_with_args(
ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
Some(3),
None,
Some(-10),
])))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
},
&ConfigOptions::default(),
)
.await
.unwrap();
assert_eq!(
array.as_ref(),
&Int64Array::from_iter([Some(4), None, Some(-9)]) as &dyn Array,
Expand Down Expand Up @@ -91,18 +96,21 @@ def foo(x: int, /) -> int:
);

let array = udf
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
Some(3),
None,
Some(-10),
])))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
})
.unwrap()
.unwrap_array();
.invoke_async_with_args(
ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
Some(3),
None,
Some(-10),
])))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
},
&ConfigOptions::default(),
)
.await
.unwrap();
assert_eq!(
array.as_ref(),
&Int64Array::from_iter([Some(4), None, Some(-9)]) as &dyn Array,
Expand Down Expand Up @@ -151,20 +159,23 @@ def foo(x: int, /, y: int) -> int:
);

let array = udf
.invoke_with_args(ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(Int64Array::from_iter([Some(3)]))),
ColumnarValue::Array(Arc::new(Int64Array::from_iter([Some(4)]))),
],
arg_fields: vec![
Arc::new(Field::new("a1", DataType::Int64, true)),
Arc::new(Field::new("a2", DataType::Int64, true)),
],
number_rows: 1,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
})
.unwrap()
.unwrap_array();
.invoke_async_with_args(
ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(Int64Array::from_iter([Some(3)]))),
ColumnarValue::Array(Arc::new(Int64Array::from_iter([Some(4)]))),
],
arg_fields: vec![
Arc::new(Field::new("a1", DataType::Int64, true)),
Arc::new(Field::new("a2", DataType::Int64, true)),
],
number_rows: 1,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
},
&ConfigOptions::default(),
)
.await
.unwrap();
assert_eq!(
array.as_ref(),
&Int64Array::from_iter([Some(7)]) as &dyn Array,
Expand Down
56 changes: 32 additions & 24 deletions host/tests/integration_tests/python/examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ use arrow::{
datatypes::{DataType, Field},
};
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};

use crate::integration_tests::{
python::test_utils::python_scalar_udf, test_utils::ColumnarValueExt,
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
async_udf::AsyncScalarUDFImpl,
};

use crate::integration_tests::python::test_utils::python_scalar_udf;

#[tokio::test(flavor = "multi_thread")]
async fn test_add_one() {
const CODE: &str = "
Expand All @@ -37,33 +39,39 @@ def add_one(x: int) -> int:

// call with array
let array = udf
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
Some(3),
None,
Some(1),
])))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
})
.unwrap()
.unwrap_array();
.invoke_async_with_args(
ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
Some(3),
None,
Some(1),
])))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
},
&ConfigOptions::default(),
)
.await
.unwrap();
assert_eq!(
array.as_ref(),
&Int64Array::from_iter([Some(4), None, Some(2)]) as &dyn Array,
);

// call with scalar, output will still be an array
let array = udf
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Scalar(ScalarValue::Int64(Some(3)))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
})
.unwrap()
.unwrap_array();
.invoke_async_with_args(
ScalarFunctionArgs {
args: vec![ColumnarValue::Scalar(ScalarValue::Int64(Some(3)))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
},
&ConfigOptions::default(),
)
.await
.unwrap();
assert_eq!(
array.as_ref(),
&Int64Array::from_iter([Some(4), Some(4), Some(4)]) as &dyn Array,
Expand Down
61 changes: 33 additions & 28 deletions host/tests/integration_tests/python/runtime/dependencies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ use arrow::{
array::{Array, Int64Array},
datatypes::{DataType, Field},
};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, async_udf::AsyncScalarUDFImpl};

use crate::integration_tests::{
python::test_utils::python_scalar_udf, test_utils::ColumnarValueExt,
};
use crate::integration_tests::python::test_utils::python_scalar_udf;

#[tokio::test(flavor = "multi_thread")]
async fn call_other_function() {
Expand All @@ -25,18 +24,21 @@ def foo(x: int) -> int:

let udf = python_scalar_udf(CODE).await.unwrap();
let array = udf
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
Some(1),
Some(2),
Some(3),
])))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
})
.unwrap()
.unwrap_array();
.invoke_async_with_args(
ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
Some(1),
Some(2),
Some(3),
])))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
},
&ConfigOptions::default(),
)
.await
.unwrap();
assert_eq!(
array.as_ref(),
&Int64Array::from_iter([Some(12), Some(23), Some(34)]) as &dyn Array,
Expand All @@ -59,18 +61,21 @@ def foo(x: int) -> int:

let udf = python_scalar_udf(CODE).await.unwrap();
let array = udf
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
Some(10),
Some(20),
Some(10),
])))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
})
.unwrap()
.unwrap_array();
.invoke_async_with_args(
ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter([
Some(10),
Some(20),
Some(10),
])))],
arg_fields: vec![Arc::new(Field::new("a1", DataType::Int64, true))],
number_rows: 3,
return_field: Arc::new(Field::new("r", DataType::Int64, true)),
},
&ConfigOptions::default(),
)
.await
.unwrap();
assert_eq!(
array.as_ref(),
&Int64Array::from_iter([Some(11), Some(22), Some(11)]) as &dyn Array,
Expand Down
Loading