Skip to content

Commit

Permalink
fix(bigquery): move sql code to proper argument
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Oct 17, 2023
1 parent 343067c commit abb0bdd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 53 deletions.
89 changes: 38 additions & 51 deletions ibis/backends/bigquery/udf/__init__.py
Expand Up @@ -20,24 +20,10 @@
_udf_name_cache: dict[str, Iterable[int]] = collections.defaultdict(itertools.count)


def _create_udf_node(name, fields):
"""Create a new UDF node type.
Parameters
----------
name : str
Then name of the UDF node
fields : OrderedDict
Mapping of class member name to definition
Returns
-------
result : type
A new BigQueryUDFNode subclass
"""
def _make_udf_name(name):
definition = next(_udf_name_cache[name])
external_name = f"{name}_{definition:d}"
return type(external_name, (BigQueryUDFNode,), fields)
return external_name


class _BigQueryUDF:
Expand Down Expand Up @@ -274,24 +260,6 @@ def js(
if libraries is None:
libraries = []

udf_node_fields = {
name: rlz.ValueOf(None if type_ == "ANY TYPE" else type_)
for name, type_ in params.items()
}

udf_node_fields["dtype"] = output_type
udf_node_fields["shape"] = rlz.shape_like("args")
udf_node_fields["__slots__"] = ("sql",)

udf_node = _create_udf_node(name, udf_node_fields)

from ibis.backends.bigquery.compiler import compiles

@compiles(udf_node)
def compiles_udf_node(t, op):
args = ", ".join(map(t.translate, op.args))
return f"{udf_node.__name__}({args})"

bigquery_signature = ", ".join(
f"{name} {BigQueryType.from_ibis(dt.dtype(type_))}"
for name, type_ in params.items()
Expand All @@ -305,16 +273,35 @@ def compiles_udf_node(t, op):
False: "NOT DETERMINISTIC\n",
None: "",
}.get(determinism)

name = _make_udf_name(name)
sql_code = f'''\
CREATE TEMPORARY FUNCTION {udf_node.__name__}({bigquery_signature})
CREATE TEMPORARY FUNCTION {name}({bigquery_signature})
RETURNS {return_type}
{determinism_formatted}LANGUAGE js AS """
{body}
"""{libraries_opts};'''

udf_node_fields = {
name: rlz.ValueOf(None if type_ == "ANY TYPE" else type_)
for name, type_ in params.items()
}

udf_node_fields["dtype"] = output_type
udf_node_fields["shape"] = rlz.shape_like("args")
udf_node_fields["sql"] = sql_code

udf_node = type(name, (BigQueryUDFNode,), udf_node_fields)

from ibis.backends.bigquery.compiler import compiles

@compiles(udf_node)
def compiles_udf_node(t, op):
args = ", ".join(map(t.translate, op.args))
return f"{udf_node.__name__}({args})"

def wrapped(*args, **kwargs):
node = udf_node(*args, **kwargs)
object.__setattr__(node, "sql", sql_code)
return node.to_expr()

wrapped.__signature__ = inspect.Signature(
Expand Down Expand Up @@ -376,19 +363,6 @@ def sql(
}
return_type = BigQueryType.from_ibis(dt.dtype(output_type))

udf_node_fields["dtype"] = output_type
udf_node_fields["shape"] = rlz.shape_like("args")
udf_node_fields["__slots__"] = ("sql",)

udf_node = _create_udf_node(name, udf_node_fields)

from ibis.backends.bigquery.compiler import compiles

@compiles(udf_node)
def compiles_udf_node(t, op):
args_formatted = ", ".join(map(t.translate, op.args))
return f"{udf_node.__name__}({args_formatted})"

bigquery_signature = ", ".join(
"{name} {type}".format(
name=name,
Expand All @@ -398,14 +372,27 @@ def compiles_udf_node(t, op):
)
for name, type_ in params.items()
)
name = _make_udf_name(name)
sql_code = f"""\
CREATE TEMPORARY FUNCTION {udf_node.__name__}({bigquery_signature})
CREATE TEMPORARY FUNCTION {name}({bigquery_signature})
RETURNS {return_type}
AS ({sql_expression});"""

udf_node_fields["dtype"] = output_type
udf_node_fields["shape"] = rlz.shape_like("args")
udf_node_fields["sql"] = sql_code

udf_node = type(name, (BigQueryUDFNode,), udf_node_fields)

from ibis.backends.bigquery.compiler import compiles

@compiles(udf_node)
def compiles_udf_node(t, op):
args = ", ".join(map(t.translate, op.args))
return f"{udf_node.__name__}({args})"

def wrapper(*args, **kwargs):
node = udf_node(*args, **kwargs)
object.__setattr__(node, "sql", sql_code)
return node.to_expr()

return wrapper
Expand Down
3 changes: 1 addition & 2 deletions ibis/backends/tests/test_export.py
Expand Up @@ -329,7 +329,6 @@ def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype):

@pytest.mark.notyet(
[
"bigquery",
"impala",
"mysql",
"oracle",
Expand All @@ -343,7 +342,7 @@ def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype):
)
@pytest.mark.notyet(["clickhouse"], raises=Exception)
@pytest.mark.notyet(["mssql", "pandas"], raises=PyDeltaTableError)
@pytest.mark.notyet(["dask"], raises=NotImplementedError)
@pytest.mark.notyet(["bigquery", "dask"], raises=NotImplementedError)
@pytest.mark.notyet(
["druid"],
raises=pa.lib.ArrowTypeError,
Expand Down

0 comments on commit abb0bdd

Please sign in to comment.