From 287a6536ab7e8bee08065042524f6d5b53c5bc9c Mon Sep 17 00:00:00 2001
From: sundy-li <543950155@qq.com>
Date: Mon, 19 Feb 2024 06:27:00 +0800
Subject: [PATCH 1/3] fix type conversion

---
 python/databend_udf/udf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/databend_udf/udf.py b/python/databend_udf/udf.py
index 85851bd..9775fd1 100644
--- a/python/databend_udf/udf.py
+++ b/python/databend_udf/udf.py
@@ -543,9 +543,9 @@ def _field_type_to_string(field: pa.Field) -> str:
         return "DATE"
     elif pa.types.is_timestamp(t):
         return "TIMESTAMP"
-    elif pa.types.is_large_unicode(t):
+    elif pa.types.is_large_unicode(t) or pa.types.is_unicode(t):
         return "VARCHAR"
-    elif pa.types.is_large_binary(t):
+    elif pa.types.is_large_binary(t) or pa.types.is_binary(t):
         if _field_is_variant(field):
             return "VARIANT"
         else:

From ae390cdbc67bf780ead56591c79c9a7c7a47a34b Mon Sep 17 00:00:00 2001
From: sundy-li <543950155@qq.com>
Date: Sun, 1 Sep 2024 11:36:53 +0800
Subject: [PATCH 2/3] feat(udf): support batch mode

---
 python/databend_udf/udf.py | 23 +++++++++++++++++++----
 python/example/server.py   | 15 +++++++++++++++
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/python/databend_udf/udf.py b/python/databend_udf/udf.py
index 9775fd1..247b386 100644
--- a/python/databend_udf/udf.py
+++ b/python/databend_udf/udf.py
@@ -59,9 +59,10 @@ class ScalarFunction(UserDefinedFunction):
     _io_threads: Optional[int]
     _executor: Optional[ThreadPoolExecutor]
     _skip_null: bool
+    _batch_mode: bool
 
     def __init__(
-        self, func, input_types, result_type, name=None, io_threads=None, skip_null=None
+        self, func, input_types, result_type, name=None, io_threads=None, skip_null=None, batch_mode=False
     ):
         self._func = func
         self._input_schema = pa.schema(
@@ -78,6 +79,7 @@ def __init__(
             func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
         )
         self._io_threads = io_threads
+        self._batch_mode = batch_mode
         self._executor = (
             ThreadPoolExecutor(max_workers=self._io_threads)
             if self._io_threads is not None
@@ -98,7 +100,11 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
             _input_process_func(_list_field(field))(array)
             for array, field in zip(inputs, self._input_schema)
         ]
-        if self._executor is not None:
+        
+        # evaluate the function for each row 
+        if self._batch_mode:
+            column = self._func(*inputs)
+        elif self._executor is not None:
             # concurrently evaluate the function for each row
             if self._skip_null:
                 tasks = []
@@ -113,7 +119,6 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
                 ]
                 column = [future.result() for future in tasks]
         else:
-            # evaluate the function for each row
             if self._skip_null:
                 column = []
                 for row in range(batch.num_rows):
@@ -140,6 +145,7 @@ def udf(
     name: Optional[str] = None,
     io_threads: Optional[int] = None,
     skip_null: Optional[bool] = False,
+    batch_mode: Optional[bool] = False,
 ) -> Callable:
     """
     Annotation for creating a user-defined scalar function.
@@ -153,6 +159,7 @@ def udf(
     - skip_null: A boolean value specifying whether to skip NULL value. If it is set to True,
         NULL values will not be passed to the function,
         and the corresponding return value is set to NULL. Default to False.
+    - batch_mode: A boolean value specifying whether to use batch mode. Default to False.
 
     Example:
     ```
     @udf(input_types=['INT'], result_type='INT', skip_null=True)
     def add_one(x):
         return x + 1
 
     @udf(input_types=['VARCHAR', 'VARCHAR', 'VARCHAR'], result_type='VARCHAR', io_threads=32)
     def external_api(x):
         response = requests.get(my_endpoint + '?param=' + x)
         return response["data"]
     ```
+    
+    Batch mode example:
+    ```
+    @udf(input_types=['INT', 'INT'], result_type='INT', batch_mode=True)
+    def gcd(x, y):
+        return [math.gcd(x_i, y_i) for x_i, y_i in zip(x, y)]
+    ```
     """
 
     if io_threads is not None and io_threads > 1:
@@ -180,10 +194,11 @@ def external_api(x):
             name,
             io_threads=io_threads,
             skip_null=skip_null,
+            batch_mode=batch_mode
         )
     else:
         return lambda f: ScalarFunction(
-            f, input_types, result_type, name, skip_null=skip_null
+            f, input_types, result_type, name, skip_null=skip_null, batch_mode=batch_mode
         )
 
diff --git a/python/example/server.py b/python/example/server.py
index 42ad118..8d68898 100644
--- a/python/example/server.py
+++ b/python/example/server.py
@@ -54,6 +54,20 @@ def gcd(x: int, y: int) -> int:
         (x, y) = (y, x % y)
     return x
 
+@udf(
+    name="gcd_batch",
+    input_types=["INT", "INT"],
+    result_type="INT",
+    batch_mode=True,
+)
+def gcd_batch(x: list[int], y: list[int]) -> list[int]:
+    def gcd_single(x_i, y_i):
+        if x_i is None or y_i is None:
+            return None
+        while y_i != 0:
+            (x_i, y_i) = (y_i, x_i % y_i)
+        return x_i
+    return [gcd_single(x_i, y_i) for x_i, y_i in zip(x, y)]
 
 @udf(input_types=["VARCHAR", "VARCHAR", "VARCHAR"], result_type="VARCHAR")
 def split_and_join(s: str, split_s: str, join_s: str) -> str:
     return join_s.join(s.split(split_s))
@@ -303,6 +317,7 @@ def wait_concurrent(x):
     udf_server.add_function(binary_reverse)
     udf_server.add_function(bool_select)
     udf_server.add_function(gcd)
+    udf_server.add_function(gcd_batch)
     udf_server.add_function(split_and_join)
     udf_server.add_function(decimal_div)
     udf_server.add_function(hex_to_dec)

From e0ccad9eef11fb9948777377a6ab7704785f10e2 Mon Sep 17 00:00:00 2001
From: sundy-li <543950155@qq.com>
Date: Sun, 1 Sep 2024 11:40:03 +0800
Subject: [PATCH 3/3] feat(udf): support batch mode

---
 python/databend_udf/udf.py | 24 ++++++++++++++++++------
 python/example/server.py   |  3 +++
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/python/databend_udf/udf.py b/python/databend_udf/udf.py
index 247b386..1e092f2 100644
--- a/python/databend_udf/udf.py
+++ b/python/databend_udf/udf.py
@@ -62,7 +62,14 @@ class ScalarFunction(UserDefinedFunction):
     _batch_mode: bool
 
     def __init__(
-        self, func, input_types, result_type, name=None, io_threads=None, skip_null=None, batch_mode=False
+        self,
+        func,
+        input_types,
+        result_type,
+        name=None,
+        io_threads=None,
+        skip_null=None,
+        batch_mode=False,
     ):
         self._func = func
         self._input_schema = pa.schema(
@@ -100,8 +107,8 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
             _input_process_func(_list_field(field))(array)
             for array, field in zip(inputs, self._input_schema)
         ]
-        
-        # evaluate the function for each row 
+
+        # evaluate the function for each row
         if self._batch_mode:
             column = self._func(*inputs)
         elif self._executor is not None:
@@ -177,7 +184,7 @@ def external_api(x):
         response = requests.get(my_endpoint + '?param=' + x)
         return response["data"]
     ```
-    
+
     Batch mode example:
     ```
     @udf(input_types=['INT', 'INT'], result_type='INT', batch_mode=True)
     def gcd(x, y):
         return [math.gcd(x_i, y_i) for x_i, y_i in zip(x, y)]
     ```
     """
 
     if io_threads is not None and io_threads > 1:
@@ -194,11 +201,16 @@ def gcd(x, y):
             name,
             io_threads=io_threads,
             skip_null=skip_null,
-            batch_mode=batch_mode
+            batch_mode=batch_mode,
         )
     else:
         return lambda f: ScalarFunction(
-            f, input_types, result_type, name, skip_null=skip_null, batch_mode=batch_mode
+            f,
+            input_types,
+            result_type,
+            name,
+            skip_null=skip_null,
+            batch_mode=batch_mode,
         )

diff --git a/python/example/server.py b/python/example/server.py
index 8d68898..d58cc84 100644
--- a/python/example/server.py
+++ b/python/example/server.py
@@ -54,6 +54,7 @@ def gcd(x: int, y: int) -> int:
         (x, y) = (y, x % y)
     return x
 
+
 @udf(
     name="gcd_batch",
     input_types=["INT", "INT"],
     result_type="INT",
     batch_mode=True,
 )
@@ -67,8 +68,10 @@ def gcd_single(x_i, y_i):
         while y_i != 0:
             (x_i, y_i) = (y_i, x_i % y_i)
         return x_i
+
     return [gcd_single(x_i, y_i) for x_i, y_i in zip(x, y)]
+
 
 @udf(input_types=["VARCHAR", "VARCHAR", "VARCHAR"], result_type="VARCHAR")
 def split_and_join(s: str, split_s: str, join_s: str) -> str:
     return join_s.join(s.split(split_s))
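
For readers trying out the feature added by this series, here is a minimal, self-contained sketch of a batch-mode UDF. In batch mode the wrapped function is called once per RecordBatch: each argument is a whole input column as a Python list (NULLs arrive as None), and the function must return a list with one value per row. The server wiring assumes the `UDFServer` class and `serve()` call used in `python/example/server.py`; the function name `add_batch` and the listen address are illustrative only.

```
from databend_udf import udf, UDFServer


@udf(input_types=["INT", "INT"], result_type="INT", batch_mode=True)
def add_batch(x, y):
    # x and y are whole columns (Python lists); return one value per row.
    # None inputs are propagated back as SQL NULL.
    return [
        None if x_i is None or y_i is None else x_i + y_i
        for x_i, y_i in zip(x, y)
    ]


if __name__ == "__main__":
    # Assumed to mirror the example server's setup.
    server = UDFServer("0.0.0.0:8815")
    server.add_function(add_batch)
    server.serve()
```

Compared with the default per-row path (optionally run on an io_threads thread pool), batch mode makes a single Python call per batch, which removes per-row call overhead and lets user code process the whole column at once.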