Skip to content

Commit

Permalink
fix(bigquery): escape the schema (project ID) for BQ builtin UDFs
Browse files Browse the repository at this point in the history
  • Loading branch information
tswast authored and cpcloud committed Dec 16, 2023
1 parent ec979f0 commit 8096552
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 4 deletions.
9 changes: 6 additions & 3 deletions ibis/backends/base/sql/__init__.py
Expand Up @@ -4,7 +4,7 @@
import contextlib
import os
from functools import lru_cache
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional

import toolz

Expand Down Expand Up @@ -255,18 +255,21 @@ def _register_udfs(self, expr: ir.Expr) -> None:
if self.supports_python_udfs:
raise NotImplementedError(self.name)

def _gen_udf_name(self, name: str, schema: Optional[str]) -> str:
return ".".join(filter(None, (schema, name)))

def _gen_udf_rule(self, op: ops.ScalarUDF):
@self.add_operation(type(op))
def _(t, op):
func = ".".join(filter(None, (op.__udf_namespace__, op.__func_name__)))
func = self._gen_udf_name(op.__func_name__, schema=op.__udf_namespace__)
return f"{func}({', '.join(map(t.translate, op.args))})"

def _gen_udaf_rule(self, op: ops.AggUDF):
from ibis import NA

@self.add_operation(type(op))
def _(t, op):
func = ".".join(filter(None, (op.__udf_namespace__, op.__func_name__)))
func = self._gen_udf_name(op.__func_name__, schema=op.__udf_namespace__)
args = ", ".join(
t.translate(
ops.IfElse(where, arg, NA)
Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/bigquery/__init__.py
Expand Up @@ -9,7 +9,7 @@
import re
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Optional
from urllib.parse import parse_qs, urlparse

import google.auth.credentials
Expand Down Expand Up @@ -785,6 +785,12 @@ def to_pyarrow_batches(
)
return pa.RecordBatchReader.from_batches(schema.to_pyarrow(), batch_iter)

def _gen_udf_name(self, name: str, schema: Optional[str]) -> str:
func = ".".join(filter(None, (schema, name)))
if "." in func:
return ".".join(f"`{part}`" for part in func.split("."))
return func

def get_schema(self, name, schema: str | None = None, database: str | None = None):
table_ref = bq.TableReference(
bq.DatasetReference(
Expand Down
@@ -0,0 +1,2 @@
SELECT
`bqutil`.`fn`.from_hex('face') AS `from_hex_'face'`
@@ -0,0 +1,2 @@
SELECT
farm_fingerprint(b'Hello, World!') AS `farm_fingerprint_b'Hello_ World_'`
30 changes: 30 additions & 0 deletions ibis/backends/bigquery/tests/unit/udf/test_builtin.py
@@ -0,0 +1,30 @@

import ibis

to_sql = ibis.bigquery.compile


@ibis.udf.scalar.builtin
def farm_fingerprint(value: bytes) -> int:
...


@ibis.udf.scalar.builtin(schema="bqutil.fn")
def from_hex(value: str) -> int:
"""Community function to convert from hex string to integer.
See:
https://github.com/GoogleCloudPlatform/bigquery-utils/tree/master/udfs/community#from_hexvalue-string
"""


def test_bqutil_fn_from_hex(snapshot):
# Project ID should be enclosed in backticks.
expr = from_hex("face")
snapshot.assert_match(to_sql(expr), "out.sql")


def test_farm_fingerprint(snapshot):
# No backticks needed if there is no schema defined.
expr = farm_fingerprint(b"Hello, World!")
snapshot.assert_match(to_sql(expr), "out.sql")

0 comments on commit 8096552

Please sign in to comment.