Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install pipenv
run: curl https://raw.githubusercontent.com/pypa/pipenv/master/get-pipenv.py | python
- uses: actions/setup-python@v5
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Lint
python-version: "3.12"
- name: Install dependencies
working-directory: python
run: pip install ruff
- name: Check format
working-directory: python
run: |
pipenv install --dev
pipenv run ruff check databend_udf
pipenv run ruff format --check databend_udf
ruff format --check .
- name: build
working-directory: python
run: |
Expand Down
3 changes: 3 additions & 0 deletions python/.gitignore → .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ dist/
__pycache__/
Pipfile.lock
.ruff_cache/
.vscode
python/example/test.py

123 changes: 118 additions & 5 deletions python/databend_udf/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import inspect
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator, Callable, Optional, Union, List, Dict
from prometheus_client import Counter, Gauge, Histogram
from prometheus_client import start_http_server
import threading

import pyarrow as pa
from pyarrow.flight import FlightServerBase, FlightInfo
Expand Down Expand Up @@ -229,11 +232,80 @@ class UDFServer(FlightServerBase):
_location: str
_functions: Dict[str, UserDefinedFunction]

def __init__(self, location="0.0.0.0:8815", **kwargs):
def __init__(self, location="0.0.0.0:8815", metric_location=None, **kwargs):
super(UDFServer, self).__init__("grpc://" + location, **kwargs)
self._location = location
self._metric_location = metric_location
self._functions = {}

# Initialize Prometheus metrics
self.requests_count = Counter(
"udf_server_requests_count",
"Total number of UDF requests processed",
["function_name"],
)
self.rows_count = Counter(
"udf_server_rows_count", "Total number of rows processed", ["function_name"]
)
self.running_requests = Gauge(
"udf_server_running_requests_count",
"Number of currently running UDF requests",
["function_name"],
)
self.running_rows = Gauge(
"udf_server_running_rows_count",
"Number of currently processing rows",
["function_name"],
)
self.response_duration = Histogram(
"udf_server_response_duration_seconds",
"Time spent processing UDF requests",
["function_name"],
buckets=(
0.005,
0.01,
0.025,
0.05,
0.075,
0.1,
0.25,
0.5,
0.75,
1.0,
2.5,
5.0,
7.5,
10.0,
),
)

self.error_count = Counter(
"udf_server_errors_count",
"Total number of UDF processing errors",
["function_name", "error_type"],
)

self.add_function(builtin_echo)
self.add_function(builtin_healthy)

def _start_metrics_server(self):
"""Start Prometheus metrics HTTP server if metric_location is provided"""
try:
host, port = self._metric_location.split(":")
port = int(port)

def start_server():
start_http_server(port, host)
logger.info(
f"Prometheus metrics server started on {self._metric_location}"
)

metrics_thread = threading.Thread(target=start_server, daemon=True)
metrics_thread.start()
except Exception as e:
logger.error(f"Failed to start metrics server: {e}")
raise

def get_flight_info(self, context, descriptor):
"""Return the result schema of a function."""
func_name = descriptor.path[0].decode("utf-8")
Expand All @@ -257,13 +329,38 @@ def do_exchange(self, context, descriptor, reader, writer):
raise ValueError(f"Function {func_name} does not exists")
udf = self._functions[func_name]
writer.begin(udf._result_schema)

# Increment request counter
self.requests_count.labels(function_name=func_name).inc()
# Increment running requests gauge
self.running_requests.labels(function_name=func_name).inc()

try:
for batch in reader:
for output_batch in udf.eval_batch(batch.data):
writer.write_batch(output_batch)
with self.response_duration.labels(function_name=func_name).time():
for batch in reader:
# Update row metrics
batch_rows = batch.data.num_rows
self.rows_count.labels(function_name=func_name).inc(batch_rows)
self.running_rows.labels(function_name=func_name).inc(batch_rows)

try:
for output_batch in udf.eval_batch(batch.data):
writer.write_batch(output_batch)
finally:
# Decrease running rows gauge after processing
self.running_rows.labels(function_name=func_name).dec(
batch_rows
)

except Exception as e:
self.error_count.labels(
function_name=func_name, error_type=e.__class__.__name__
).inc()
logger.exception(e)
raise e
finally:
# Decrease running requests gauge
self.running_requests.labels(function_name=func_name).dec()

def add_function(self, udf: UserDefinedFunction):
"""Add a function to the server."""
Expand All @@ -284,7 +381,13 @@ def add_function(self, udf: UserDefinedFunction):

def serve(self):
"""Start the server."""
logger.info(f"listening on {self._location}")
logger.info(f"UDF server listening on {self._location}")
if self._metric_location:
self._start_metrics_server()
logger.info(
f"Prometheus metrics available at http://{self._metric_location}/metrics"
)

super(UDFServer, self).serve()


Expand Down Expand Up @@ -586,3 +689,13 @@ def _field_type_to_string(field: pa.Field) -> str:
return f"TUPLE({args_str})"
else:
raise ValueError(f"Unsupported type: {t}")


@udf(input_types=["VARCHAR"], result_type="VARCHAR")
def builtin_echo(a):
return a


@udf(input_types=[], result_type="INT")
def builtin_healthy():
return 1
3 changes: 2 additions & 1 deletion python/example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import List, Dict, Any, Tuple, Optional

from databend_udf import udf, UDFServer
# from test import udf, UDFServer

logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -313,7 +314,7 @@ def wait_concurrent(x):


if __name__ == "__main__":
udf_server = UDFServer("0.0.0.0:8815")
udf_server = UDFServer("0.0.0.0:8815", metric_location="0.0.0.0:8816")
udf_server.add_function(add_signed)
udf_server.add_function(add_unsigned)
udf_server.add_function(add_float)
Expand Down
6 changes: 4 additions & 2 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ name = "databend-udf"
version = "0.2.5"
readme = "README.md"
requires-python = ">=3.7"
dependencies = ["pyarrow"]

dependencies = [
"pyarrow",
"prometheus-client>=0.17.0"
]
[project.optional-dependencies]
lint = ["ruff"]

Expand Down