Skip to content

Commit

Permalink
Merge pull request #100 from chdb-io/binUdf
Browse files Browse the repository at this point in the history
Support UDF in Python
  • Loading branch information
auxten committed Sep 4, 2023
2 parents 4061f73 + afc3faa commit 409350e
Show file tree
Hide file tree
Showing 9 changed files with 280 additions and 42 deletions.
20 changes: 19 additions & 1 deletion README-zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ sess.query(
print("Select from view:\n")
print(sess.query("SELECT * FROM db_xxx.view_xxx", "Pretty"))
```


参见: [test_stateful.py](tests/test_stateful.py)
</details>

<details>
Expand All @@ -126,6 +127,23 @@ conn1.close()
```
</details>

<details>
<summary><h4>🗂️ Query with UDF(User Defined Functions)</h4></summary>

```python
from chdb.udf import chdb_udf
from chdb import query

@chdb_udf()
def sum_udf(lhs, rhs):
return int(lhs) + int(rhs)

print(query("select sum_udf(12,22)"))
```

参见: [test_udf.py](tests/test_udf.py).
</details>

更多示例,请参见 [examples](examples)[tests](tests)

## 演示和示例
Expand Down
20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ sess.query(
print("Select from view:\n")
print(sess.query("SELECT * FROM db_xxx.view_xxx", "Pretty"))
```


see also: [test_stateful.py](tests/test_stateful.py).
</details>

<details>
Expand All @@ -132,6 +133,23 @@ conn1.close()
</details>


<details>
<summary><h4>🗂️ Query with UDF(User Defined Functions)</h4></summary>

```python
from chdb.udf import chdb_udf
from chdb import query

@chdb_udf()
def sum_udf(lhs, rhs):
return int(lhs) + int(rhs)

print(query("select sum_udf(12,22)"))
```

see also: [test_udf.py](tests/test_udf.py).
</details>

For more examples, see [examples](examples) and [tests](tests).

<br>
Expand Down
36 changes: 18 additions & 18 deletions chdb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import sys
import os


# If any UDF is defined, the path of the UDF will be set to this variable
# and the path will be deleted when the process exits
# UDF config path will be f"{g_udf_path}/udf_config.xml"
# UDF script path will be f"{g_udf_path}/{func_name}.py"
g_udf_path = ""

chdb_version = (0, 6, 0)
if sys.version_info[:2] >= (3, 7):
# get the path of the current file
Expand Down Expand Up @@ -32,37 +39,30 @@ def to_arrowTable(res):
import pyarrow as pa
import pandas
except ImportError as e:
print(f'ImportError: {e}')
print(f"ImportError: {e}")
print('Please install pyarrow and pandas via "pip install pyarrow pandas"')
raise ImportError('Failed to import pyarrow or pandas') from None
raise ImportError("Failed to import pyarrow or pandas") from None
if len(res) == 0:
return pa.Table.from_batches([], schema=pa.schema([]))
return pa.RecordBatchFileReader(res.bytes()).read_all()


# return pandas dataframe
def to_df(r):
""""convert arrow table to Dataframe"""
"""convert arrow table to Dataframe"""
t = to_arrowTable(r)
return t.to_pandas(use_threads=True)


# wrap _chdb functions
def query(sql, output_format="CSV"):
lower_output_format = output_format.lower()
if lower_output_format == "dataframe":
return to_df(_chdb.query(sql, "Arrow"))
elif lower_output_format == 'arrowtable':
return to_arrowTable(_chdb.query(sql, "Arrow"))
else:
return _chdb.query(sql, output_format)


def query_stateful(sql, output_format="CSV", path=None):
def query(sql, output_format="CSV", path="", udf_path=""):
global g_udf_path
if udf_path != "":
g_udf_path = udf_path
lower_output_format = output_format.lower()
if lower_output_format == "dataframe":
return to_df(_chdb.query_stateful(sql, "Arrow", path))
elif lower_output_format == 'arrowtable':
return to_arrowTable(_chdb.query_stateful(sql, "Arrow", path))
return to_df(_chdb.query(sql, "Arrow", path=path, udf_path=g_udf_path))
elif lower_output_format == "arrowtable":
return to_arrowTable(_chdb.query(sql, "Arrow", path=path, udf_path=g_udf_path))
else:
return _chdb.query_stateful(sql, output_format, path)
return _chdb.query(sql, output_format, path=path, udf_path=g_udf_path)
21 changes: 15 additions & 6 deletions chdb/session/state.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import tempfile
import shutil

from chdb import query_stateful
from chdb import query, g_udf_path


class Session():
class Session:
"""
Session will keep the state of query. All DDL and DML state will be kept in a dir.
Dir path could be passed in as an argument. If not, a temporary dir will be created.
If path is not specified, the temporary dir will be deleted when the Session object is deleted.
Otherwise path will be kept.
Note: The default database is "_local" and the default engine is "Memory" which means all data
Note: The default database is "_local" and the default engine is "Memory" which means all data
will be stored in memory. If you want to store data in disk, you should create another database.
"""

Expand All @@ -28,11 +28,20 @@ def __del__(self):
if self._cleanup:
self.cleanup()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.cleanup()

def cleanup(self):
shutil.rmtree(self._path)
try:
shutil.rmtree(self._path)
except:
pass

def query(self, sql, fmt="CSV"):
"""
Execute a query.
"""
return query_stateful(sql, fmt, path=self._path)
return query(sql, fmt, path=self._path)
1 change: 1 addition & 0 deletions chdb/udf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .udf import *
106 changes: 106 additions & 0 deletions chdb/udf/udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import functools
import inspect
import os
import sys
import tempfile
import atexit
import shutil
import textwrap
from xml.etree import ElementTree as ET
import chdb


def generate_udf(func_name, args, return_type, udf_body):
# generate python script
with open(f"{chdb.g_udf_path}/{func_name}.py", "w") as f:
f.write(f"#!{sys.executable}\n")
f.write("import sys\n")
f.write("\n")
for line in udf_body.split("\n"):
f.write(f"{line}\n")
f.write("\n")
f.write("if __name__ == '__main__':\n")
f.write(" for line in sys.stdin:\n")
f.write(" args = line.strip().split('\t')\n")
for i, arg in enumerate(args):
f.write(f" {arg} = args[{i}]\n")
f.write(f" print({func_name}({', '.join(args)}))\n")
f.write(" sys.stdout.flush()\n")
os.chmod(f"{chdb.g_udf_path}/{func_name}.py", 0o755)
# generate xml file
xml_file = f"{chdb.g_udf_path}/udf_config.xml"
root = ET.Element("functions")
if os.path.exists(xml_file):
tree = ET.parse(xml_file)
root = tree.getroot()
function = ET.SubElement(root, "function")
ET.SubElement(function, "type").text = "executable"
ET.SubElement(function, "name").text = func_name
ET.SubElement(function, "return_type").text = return_type
ET.SubElement(function, "format").text = "TabSeparated"
ET.SubElement(function, "command").text = f"{func_name}.py"
for arg in args:
argument = ET.SubElement(function, "argument")
# We use TabSeparated format, so assume all arguments are strings
ET.SubElement(argument, "type").text = "String"
ET.SubElement(argument, "name").text = arg
tree = ET.ElementTree(root)
tree.write(xml_file)


def chdb_udf(return_type="String"):
"""
Decorator for chDB Python UDF(User Defined Function).
1. The function should be stateless. So, only UDFs are supported, not UDAFs(User Defined Aggregation Function).
2. Default return type is String. If you want to change the return type, you can pass in the return type as an argument.
The return type should be one of the following: https://clickhouse.com/docs/en/sql-reference/data-types
3. The function should take in arguments of type String. As the input is TabSeparated, all arguments are strings.
4. The function will be called for each line of input. Something like this:
```
def sum_udf(lhs, rhs):
return int(lhs) + int(rhs)
for line in sys.stdin:
args = line.strip().split('\t')
lhs = args[0]
rhs = args[1]
print(sum_udf(lhs, rhs))
sys.stdout.flush()
```
5. The function should be pure python function. You SHOULD import all python modules used IN THE FUNCTION.
```
def func_use_json(arg):
import json
...
```
6. Python interpertor used is the same as the one used to run the script. Get from `sys.executable`
"""

def decorator(func):
func_name = func.__name__
sig = inspect.signature(func)
args = list(sig.parameters.keys())
src = inspect.getsource(func)
src = textwrap.dedent(src)
udf_body = src.split("\n", 1)[1] # remove the first line "@chdb_udf()"
# create tmp dir and make sure the dir is deleted when the process exits
if chdb.g_udf_path == "":
chdb.g_udf_path = tempfile.mkdtemp()

# clean up the tmp dir on exit
@atexit.register
def _cleanup():
try:
shutil.rmtree(chdb.g_udf_path)
except:
pass

generate_udf(func_name, args, return_type, udf_body)

@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper

return decorator
50 changes: 34 additions & 16 deletions programs/local/LocalChdb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
extern bool inside_main = true;


local_result * queryToBuffer(const std::string & queryStr, const std::string & format = "CSV", const std::string & path = {})
local_result * queryToBuffer(
const std::string & queryStr,
const std::string & output_format = "CSV",
const std::string & path = {},
const std::string & udfPath = {})
{
std::vector<std::string> argv = {"clickhouse", "--multiquery"};
std::vector<std::string> argv = {"clickhouse", "--", "--multiquery"};

// if format is "Debug" or "debug", then we will add --verbose and --log-level=trace to argv
if (format == "Debug" || format == "debug")
// If format is "Debug" or "debug", then we will add `--verbose` and `--log-level=trace` to argv
if (output_format == "Debug" || output_format == "debug")
{
argv.push_back("--verbose");
argv.push_back("--log-level=trace");
Expand All @@ -21,10 +25,19 @@ local_result * queryToBuffer(const std::string & queryStr, const std::string & f
else
{
// Add format string
argv.push_back("--output-format=" + format);
argv.push_back("--output-format=" + output_format);
}

if (!path.empty())
// If udfPath is not empty, then we will add `--user_scripts_path` and `--user_defined_executable_functions_config` to argv
// the path should be a one time thing, so the caller should take care of the temporary files deletion
if (!udfPath.empty())
{
argv.push_back("--user_scripts_path=" + udfPath);
argv.push_back("--user_defined_executable_functions_config=" + udfPath + "/*.xml");
}

// If path is not empty, then we will add `--path` to argv. This is used for chdb.Session to support stateful query
if (!path.empty())
{
// Add path string
argv.push_back("--path=" + path);
Expand All @@ -42,14 +55,13 @@ local_result * queryToBuffer(const std::string & queryStr, const std::string & f

// Pybind11 will take over the ownership of the `query_result` object
// using smart ptr will cause early free of the object
query_result * query(const std::string & queryStr, const std::string & format = "CSV")
query_result * query(
const std::string & queryStr,
const std::string & output_format = "CSV",
const std::string & path = {},
const std::string & udfPath = {})
{
return new query_result(queryToBuffer(queryStr, format));
}

query_result * query_stateful(const std::string & queryStr, const std::string & format = "CSV", const std::string & path = {})
{
return new query_result(queryToBuffer(queryStr, format, path));
return new query_result(queryToBuffer(queryStr, output_format, path, udfPath));
}

// The `query_result` and `memoryview_wrapper` will hold `local_result_wrapper` with shared_ptr
Expand Down Expand Up @@ -132,9 +144,15 @@ PYBIND11_MODULE(_chdb, m)
.def("get_memview", &query_result::get_memview);


m.def("query", &query, "Stateless query Clickhouse and return a query_result object");

m.def("query_stateful", &query_stateful, "Stateful query Clickhouse and return a query_result object");
m.def(
"query",
&query,
py::arg("queryStr"),
py::arg("output_format") = "CSV",
py::kw_only(),
py::arg("path") = "",
py::arg("udf_path") = "",
"Query chDB and return a query_result object");
}

#endif // PY_TEST_MAIN
10 changes: 10 additions & 0 deletions tests/test_stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ def test_tmp(self):
ret = sess2.query("SELECT chdb_xxx()", "CSV")
self.assertEqual(str(ret), "")

def test_context_mgr(self):
with session.Session() as sess:
sess.query("CREATE FUNCTION chdb_xxx AS () -> '0.12.0'", "CSV")
ret = sess.query("SELECT chdb_xxx()", "CSV")
self.assertEqual(str(ret), '"0.12.0"\n')

with session.Session() as sess:
ret = sess.query("SELECT chdb_xxx()", "CSV")
self.assertEqual(str(ret), "")

def test_zfree_thread_count(self):
time.sleep(3)
thread_count = current_process.num_threads()
Expand Down
Loading

0 comments on commit 409350e

Please sign in to comment.