feat(snowflake/postgres): scalar UDFs
BREAKING CHANGE: Postgres UDFs now use the new `@udf.scalar.python` API. This should be a low-effort replacement for the existing API.
cpcloud committed Jun 21, 2023
1 parent bb5c075 commit dbf5b62
Showing 15 changed files with 607 additions and 153 deletions.
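For context, a minimal sketch of the new decorator-based API this commit introduces; the function, table, and connection details below are illustrative, not part of the diff itself:

import ibis
from ibis import udf

@udf.scalar.python
def add_one(x: int) -> int:
    return x + 1

# Hypothetical usage against a Postgres connection; the table name is invented.
con = ibis.postgres.connect(host="localhost", user="me", database="db")
t = con.table("events")
expr = t.mutate(y=add_one(t.x))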
17 changes: 17 additions & 0 deletions ibis/backends/base/__init__.py
@@ -268,6 +268,7 @@ def to_pyarrow(
A pyarrow table holding the results of the executed expression.
"""
pa = self._import_pyarrow()
self._run_pre_execute_hooks(expr)
try:
# Can't construct an array from record batches
# so construct at one column table (if applicable)
@@ -503,6 +504,8 @@ class BaseBackend(abc.ABC, _FileIOHandler):
name: ClassVar[str]

supports_temporary_tables = False
supports_python_udfs = False
supports_in_memory_tables = True

def __init__(self, *args, **kwargs):
self._con_args: tuple[Any] = args
@@ -754,6 +757,20 @@ def register_options(cls) -> None:
except ValueError as e:
raise exc.BackendConfigurationNotRegistered(backend_name) from e

def _register_udfs(self, expr: ir.Expr) -> None:
"""Register UDFs contained in `expr` with the backend."""
if self.supports_python_udfs:
raise NotImplementedError(self.name)

def _register_in_memory_tables(self, expr: ir.Expr):
if self.supports_in_memory_tables:
raise NotImplementedError(self.name)

def _run_pre_execute_hooks(self, expr: ir.Expr) -> None:
"""Backend-specific hooks to run before an expression is executed."""
self._register_udfs(expr)
self._register_in_memory_tables(expr)

def compile(
self,
expr: ir.Expr,
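Note the direction of the guards above: the base class raises when a backend *claims* support (the flag is True) yet still inherits the default hook, so supporting backends are expected to override these methods. A hypothetical backend illustrating the pattern (names and body are illustrative only):

class MyBackend(BaseBackend):
    supports_python_udfs = True

    def _register_udfs(self, expr: ir.Expr) -> None:
        # Without this override, the inherited hook raises
        # NotImplementedError because the flag claims support.
        ...  # backend-specific UDF registration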
19 changes: 10 additions & 9 deletions ibis/backends/base/sql/__init__.py
@@ -21,9 +21,7 @@
import pandas as pd
import pyarrow as pa

__all__ = [
'BaseSQLBackend',
]
__all__ = ['BaseSQLBackend']


class BaseSQLBackend(BaseBackend):
@@ -206,6 +204,12 @@ def to_pyarrow_batches(

return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), batches)

def _compile_udfs(self, expr: ir.Expr) -> Iterable[str]:
"""Return an iterator of DDL strings, once for each UDFs contained within `expr`."""
if self.supports_python_udfs:
raise NotImplementedError(self.name)
return []

def execute(
self,
expr: ir.Expr,
@@ -243,15 +247,15 @@ def execute(
# `external_tables` in clickhouse, but better to deprecate that
# feature than all this magic.
# we don't want to pass `timecontext` to `raw_sql`
self._run_pre_execute_hooks(expr)

kwargs.pop('timecontext', None)
query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
sql = query_ast.compile()
self._log(sql)

schema = self.ast_schema(query_ast, **kwargs)

self._run_pre_execute_hooks(expr)

with self._safe_raw_sql(sql, **kwargs) as cursor:
result = self.fetch_from_cursor(cursor, schema)

@@ -268,10 +272,6 @@ def _register_in_memory_tables(self, expr: ir.Expr) -> None:
for memtable in an.find_memtables(expr.op()):
self._register_in_memory_table(memtable)

def _run_pre_execute_hooks(self, expr: ir.Expr) -> None:
"""Backend-specific hooks to run before an expression is executed."""
self._register_in_memory_tables(expr)

@abc.abstractmethod
def fetch_from_cursor(self, cursor, schema):
"""Fetch data from cursor."""
@@ -344,6 +344,7 @@ def compile(
The output of compilation. The type of this value depends on the
backend.
"""
util.consume(self._compile_udfs(expr))
return self.compiler.to_ast_ensure_limit(expr, limit, params=params).compile()

def _to_sql(self, expr: ir.Expr, **kwargs) -> str:
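`_compile_udfs` is driven lazily, so `compile` drains it with `util.consume` purely for its side effects: compiling each UDF and surfacing NotImplementedError early when the backend cannot. A generic sketch of the consume idiom, which may differ in detail from ibis's own `util.consume`:

from collections import deque

def consume(iterator) -> None:
    # Exhaust an iterator for its side effects; a zero-length deque
    # discards every item without building an intermediate list.
    deque(iterator, maxlen=0)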
53 changes: 53 additions & 0 deletions ibis/backends/base/sql/alchemy/__init__.py
@@ -101,6 +101,12 @@ class BaseAlchemyBackend(BaseSQLBackend):
supports_temporary_tables = True
_temporary_prefix = "TEMPORARY"

def _compile_type(self, dtype) -> str:
dialect = self.con.dialect
return sa.types.to_instance(
self.compiler.translator_class.get_sqla_type(dtype)
).compile(dialect=dialect)

def _build_alchemy_url(self, url, host, port, user, password, database, driver):
if url is not None:
return sa.engine.url.make_url(url)
@@ -701,6 +707,53 @@ def insert(
f"The given obj is of type {type(obj).__name__} ."
)

def _compile_udfs(self, expr: ir.Expr) -> Iterable[str]:
for udf_node in expr.op().find(ops.ScalarUDF):
udf_node_type = type(udf_node)

if udf_node_type not in self.compiler.translator_class._registry:

@self.add_operation(udf_node_type)
def _(t, op):
generator = sa.func
if (namespace := op.__udf_namespace__) is not None:
generator = getattr(generator, namespace)
func = getattr(generator, type(op).__name__)
return func(*map(t.translate, op.args))

compile_func = getattr(
self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
)
compiled = compile_func(udf_node)
if compiled is not None:
yield compiled

def _compile_opaque_udf(self, udf_node: ops.ScalarUDF) -> str:
return None

def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str:
if self.supports_python_udfs:
raise NotImplementedError(
f"The {self.name} backend does not support Python scalar UDFs"
)

def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str:
if self.supports_python_udfs:
raise NotImplementedError(
f"The {self.name} backend does not support Pandas-based vectorized scalar UDFs"
)

def _compile_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> str:
if self.supports_python_udfs:
raise NotImplementedError(
f"The {self.name} backend does not support PyArrow-based vectorized scalar UDFs"
)

def _register_udfs(self, expr: ir.Expr) -> None:
with self.begin() as con:
for sql in self._compile_udfs(expr):
con.exec_driver_sql(sql)

def _quote(self, name: str) -> str:
"""Quote an identifier."""
preparer = self.con.dialect.identifier_preparer
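The fallback translation rule registered in `_compile_udfs` leans on SQLAlchemy's generic function generator: attribute access on `sa.func` builds a named function call, and chaining a leading attribute adds a schema-style prefix. A standalone sketch with invented names:

import sqlalchemy as sa

# Attribute access on sa.func produces a generic SQL function call;
# the leading attribute becomes a dotted prefix in the rendered SQL.
expr = sa.func.my_schema.add_one(sa.column("x"))
print(expr)  # my_schema.add_one(x)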
2 changes: 2 additions & 0 deletions ibis/backends/bigquery/__init__.py
@@ -74,6 +74,8 @@ def _create_client_info_gapic(application_name):
class Backend(BaseSQLBackend):
name = "bigquery"
compiler = BigQueryCompiler
supports_in_memory_tables = False
supports_python_udfs = False

def _from_url(self, url: str, **kwargs):
result = urlparse(url)
1 change: 1 addition & 0 deletions ibis/backends/datafusion/__init__.py
@@ -30,6 +30,7 @@
class Backend(BaseBackend):
name = 'datafusion'
builder = None
supports_in_memory_tables = False

@property
def version(self):
146 changes: 100 additions & 46 deletions ibis/backends/postgres/__init__.py
@@ -2,24 +2,35 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Literal
import inspect
import textwrap
from typing import TYPE_CHECKING, Callable, Iterable, Literal

import sqlalchemy as sa

import ibis.common.exceptions as exc
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
from ibis.backends.postgres.compiler import PostgreSQLCompiler
from ibis.backends.postgres.datatypes import _get_type
from ibis.backends.postgres.udf import udf as _udf
from ibis.common.exceptions import InvalidDecoratorError

if TYPE_CHECKING:
import ibis.expr.datatypes as dt


def _verify_source_line(func_name: str, line: str):
if line.startswith("@"):
raise InvalidDecoratorError(func_name, line)
return line


class Backend(BaseAlchemyBackend):
name = "postgres"
compiler = PostgreSQLCompiler
supports_create_or_replace = False
supports_python_udfs = True

def do_connect(
self,
@@ -132,53 +143,96 @@ def list_databases(self, like=None):
]
return self._filter_with_like(databases, like)

def udf(
self,
pyfunc,
in_types,
out_type,
schema=None,
replace=False,
name=None,
language="plpythonu",
):
"""Decorator that defines a PL/Python UDF in-database.

Parameters
----------
pyfunc
Python function
in_types
Input types
out_type
Output type
schema
The postgres schema in which to define the UDF
replace
replace UDF in database if already exists
name
name for the UDF to be defined in database
language
Language extension to use for PL/Python

Returns
-------
Callable
A callable ibis expression

Function that takes in Column arguments and returns an instance
inheriting from PostgresUDFNode
"""
return _udf(
client=self,
python_func=pyfunc,
in_types=in_types,
out_type=out_type,
schema=schema,
replace=replace,
name=name,
language=language,
)

def function(self, name: str, *, schema: str | None = None) -> Callable:
query = sa.text(
"""
SELECT
n.nspname as schema,
pg_catalog.pg_get_function_result(p.oid) as return_type,
string_to_array(pg_catalog.pg_get_function_arguments(p.oid), ', ') as signature,
CASE p.prokind
WHEN 'a' THEN 'agg'
WHEN 'w' THEN 'window'
WHEN 'p' THEN 'proc'
ELSE 'func'
END as "Type"
FROM pg_catalog.pg_proc p
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
WHERE p.proname = :name
"""
+ "AND n.nspname OPERATOR(pg_catalog.~) :schema COLLATE pg_catalog.default"
* (schema is not None)
).bindparams(name=name, schema=f"^({schema})$")

def split_name_type(arg: str) -> tuple[str, dt.DataType]:
name, typ = arg.split(" ", 1)
return name, _get_type(typ)

with self.begin() as con:
rows = con.execute(query).mappings().fetchall()

if not rows:
name = f"{schema}.{name}" if schema else name
raise exc.MissingUDFError(name)
elif len(rows) > 1:
raise exc.AmbiguousUDFError(name)

[row] = rows
return_type = _get_type(row["return_type"])
signature = list(map(split_name_type, row["signature"]))

# dummy callable
def fake_func(*args, **kwargs):
...

fake_func.__name__ = name
fake_func.__signature__ = inspect.Signature(
[
inspect.Parameter(
name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=typ
)
for name, typ in signature
],
return_annotation=return_type,
)
fake_func.__annotations__ = {"return": return_type, **dict(signature)}
op = ops.udf.scalar._opaque(fake_func, schema=schema)
return op

def _get_udf_source(self, udf_node: ops.ScalarUDF):
config = udf_node.__config__["kwargs"]
func = udf_node.__func__
func_name = func.__name__
schema = config.get("schema", "")
name = type(udf_node).__name__
ident = ".".join(filter(None, [schema, name]))
return dict(
name=name,
ident=ident,
signature=", ".join(
f"{name} {self._compile_type(arg.output_dtype)}"
for name, arg in zip(udf_node.argnames, udf_node.args)
),
return_type=self._compile_type(udf_node.output_dtype),
language=config.get("language", "plpython3u"),
source="\n".join(
_verify_source_line(func_name, line)
for line in textwrap.dedent(inspect.getsource(func)).splitlines()
if not line.strip().startswith("@udf")
),
args=", ".join(udf_node.argnames),
)

def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str:
return """\
CREATE OR REPLACE FUNCTION {ident}({signature})
RETURNS {return_type}
LANGUAGE {language}
AS $$
{source}
return {name}({args})
$$""".format(
**self._get_udf_source(udf_node)
)
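For a simple UDF like the `add_one` sketch near the top of this commit, this template would render roughly as follows, assuming an int64 argument compiled to BIGINT; the exact output is an assumption based on the template above:

CREATE OR REPLACE FUNCTION add_one(x BIGINT)
RETURNS BIGINT
LANGUAGE plpython3u
AS $$
def add_one(x: int) -> int:
    return x + 1
return add_one(x)
$$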

def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
