Skip to content

Commit

Permalink
feat(bigquery): implement a few URL ops (#9210)
Browse files (browse the repository at this point in the history)
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
  • Loading branch information
krzysztof-kwitt and cpcloud authored Jun 4, 2024
1 parent 6b14c20 commit 3d0f9bc
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 38 deletions.
63 changes: 28 additions & 35 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ class BigQueryCompiler(SQLGlotCompiler):
ops.CountDistinctStar,
ops.DateDiff,
ops.ExtractAuthority,
ops.ExtractFile,
ops.ExtractFragment,
ops.ExtractHost,
ops.ExtractPath,
ops.ExtractProtocol,
ops.ExtractQuery,
ops.ExtractUserInfo,
ops.FindInSet,
ops.Median,
Expand Down Expand Up @@ -120,6 +114,7 @@ class BigQueryCompiler(SQLGlotCompiler):
ops.TimeFromHMS: "time",
ops.TimestampFromYMDHMS: "datetime",
ops.TimestampNow: "current_timestamp",
ops.ExtractHost: "net.host",
}

@staticmethod
Expand Down Expand Up @@ -627,33 +622,7 @@ def visit_Correlation(self, op, *, left, right, how, where):
return self.agg.corr(left, right, where=where)

def visit_TypeOf(self, op, *, arg):
    """Compile a type-name lookup for ``arg``.

    As written, the value returned is a subquery selecting a CASE expression
    that classifies the ``"%T"``-formatted text of ``arg`` into a type name.
    """
    # NOTE(review): this body contains two ``return`` statements, so the
    # trailing ``self._pudf("typeof", arg)`` is unreachable as written. The
    # surrounding text looks like a rendered diff whose +/- markers were
    # dropped; the CASE-based body is presumably the removed version and the
    # ``_pudf`` call its replacement -- confirm against the repository.

    # Generated identifier used as the alias of the unnested formatted value.
    name = sg.to_identifier(util.gen_name("bq_typeof"))
    # Format ``arg`` with "%T", wrap it in a one-element array and unnest it
    # so the formatted text can be referenced through ``name``.
    from_ = self._unnest(self.f.array(self.f.format("%T", arg)), as_=name)
    ifs = [
        # Uppercase keyword followed by a space and a double quote: the
        # keyword itself is extracted and used as the result.
        self.if_(
            self.f.regexp_contains(name, '^[A-Z]+ "'),
            self.f.regexp_extract(name, '^([A-Z]+) "'),
        ),
        # Optional minus sign followed only by digits.
        self.if_(self.f.regexp_contains(name, "^-?[0-9]*$"), "INT64"),
        # Digits followed by "." or "e", or an explicit CAST("...") AS FLOAT64
        # (presumably how non-finite floats are rendered -- TODO confirm).
        self.if_(
            self.f.regexp_contains(
                name, r'^(-?[0-9]+[.e].*|CAST\("([^"]*)" AS FLOAT64\))$'
            ),
            "FLOAT64",
        ),
        # Exactly the text "true" or "false".
        self.if_(name.isin(sge.convert("true"), sge.convert("false")), "BOOL"),
        # Text that starts with a double or single quote.
        self.if_(
            sg.or_(self.f.starts_with(name, '"'), self.f.starts_with(name, "'")),
            "STRING",
        ),
        # Text that starts with b".
        self.if_(self.f.starts_with(name, 'b"'), "BYTES"),
        # Text that starts with an opening square bracket.
        self.if_(self.f.starts_with(name, "["), "ARRAY"),
        # Opening parenthesis, optionally preceded by the word STRUCT.
        self.if_(self.f.regexp_contains(name, r"^(STRUCT)?\("), "STRUCT"),
        # Text that starts with "ST_".
        self.if_(self.f.starts_with(name, "ST_"), "GEOGRAPHY"),
        # Exactly the text "NULL".
        self.if_(name.eq(sge.convert("NULL")), "NULL"),
    ]
    # Anything that matched none of the branches is reported as "UNKNOWN".
    case = sge.Case(ifs=ifs, default=sge.convert("UNKNOWN"))
    return sg.select(case).from_(from_).subquery()
    return self._pudf("typeof", arg)

def visit_Xor(self, op, *, left, right):
    """Compile exclusive-or as ``(left AND NOT right) OR (NOT left AND right)``."""
    # True only when ``left`` holds and ``right`` does not.
    left_only = sg.and_(left, sg.not_(right))
    # True only when ``right`` holds and ``left`` does not.
    right_only = sg.and_(sg.not_(left), right)
    return sg.or_(left_only, right_only)
Expand All @@ -673,10 +642,10 @@ def visit_CountStar(self, op, *, arg, where):
return self.f.count(STAR)

def visit_Degrees(self, op, *, arg):
    """Compile the conversion of ``arg`` from radians to degrees."""
    # NOTE(review): two ``return`` statements follow, so the second one is
    # unreachable as written. This text looks like a rendered diff whose +/-
    # markers were dropped; the arithmetic form is presumably the removed line
    # and the ``_pudf`` call its replacement -- confirm against the repository.

    # acos(-1) evaluates to pi, so this is ``arg * 180 / pi`` in parentheses.
    return sge.paren(180 * arg / self.f.acos(-1), copy=False)
    return self._pudf("degrees", arg)

def visit_Radians(self, op, *, arg):
    """Compile the conversion of ``arg`` from degrees to radians."""
    # NOTE(review): two ``return`` statements follow, so the second one is
    # unreachable as written. This text looks like a rendered diff whose +/-
    # markers were dropped; the arithmetic form is presumably the removed line
    # and the ``_pudf`` call its replacement -- confirm against the repository.

    # acos(-1) evaluates to pi, so this is ``arg * pi / 180`` in parentheses.
    return sge.paren(self.f.acos(-1) * arg / 180, copy=False)
    return self._pudf("radians", arg)

def visit_CountDistinct(self, op, *, arg, where):
if where is not None:
Expand All @@ -685,3 +654,27 @@ def visit_CountDistinct(self, op, *, arg, where):

def visit_RandomUUID(self, op, **kwargs):
    """Compile a random UUID as a zero-argument ``generate_uuid`` call.

    Any keyword arguments are accepted and ignored.
    """
    generate_uuid = self.f.generate_uuid
    return generate_uuid()

def visit_ExtractFile(self, op, *, arg):
    """Compile the URL file extraction by delegating to ``self._pudf``."""
    udf_name = "cw_url_extract_file"
    return self._pudf(udf_name, arg)

def visit_ExtractFragment(self, op, *, arg):
    """Compile the URL fragment extraction by delegating to ``self._pudf``."""
    udf_name = "cw_url_extract_fragment"
    return self._pudf(udf_name, arg)

def visit_ExtractPath(self, op, *, arg):
    """Compile the URL path extraction by delegating to ``self._pudf``."""
    udf_name = "cw_url_extract_path"
    return self._pudf(udf_name, arg)

def visit_ExtractProtocol(self, op, *, arg):
    """Compile the URL protocol extraction by delegating to ``self._pudf``."""
    udf_name = "cw_url_extract_protocol"
    return self._pudf(udf_name, arg)

def visit_ExtractQuery(self, op, *, arg, key):
    """Compile the URL query extraction by delegating to ``self._pudf``.

    With ``key`` set to ``None`` the ``cw_url_extract_query`` function is
    called on ``arg`` alone; for any other ``key`` the
    ``cw_url_extract_parameter`` function is called with ``arg`` and ``key``.
    """
    if key is None:
        return self._pudf("cw_url_extract_query", arg)
    return self._pudf("cw_url_extract_parameter", arg, key)

def _pudf(self, name, *args):
    """Call the function ``name`` from ``bigquery-public-data.persistent_udfs``.

    ``name`` is qualified with the ``persistent_udfs`` db and the
    ``bigquery-public-data`` catalog, rendered to SQL text with
    ``self.dialect``, looked up on ``self.f`` and called with ``args``.
    """
    reference = sg.table(name, db="persistent_udfs", catalog="bigquery-public-data")
    qualified_name = reference.sql(self.dialect)
    func = self.f[qualified_name]
    return func(*args)
9 changes: 6 additions & 3 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,14 +719,18 @@ def test_substr_with_null_values(backend, alltypes, df):
lambda d: d.authority(),
"user:pass@example.com:80",
id="authority",
marks=[pytest.mark.notyet(["trino"], raises=com.OperationNotDefinedError)],
marks=[
pytest.mark.notyet(
["bigquery", "trino"], raises=com.OperationNotDefinedError
)
],
),
param(
lambda d: d.userinfo(),
"user:pass",
marks=[
pytest.mark.notyet(
["clickhouse", "snowflake", "trino"],
["bigquery", "clickhouse", "snowflake", "trino"],
raises=com.OperationNotDefinedError,
reason="doesn't support `USERINFO`",
)
Expand Down Expand Up @@ -775,7 +779,6 @@ def test_substr_with_null_values(backend, alltypes, df):
)
@pytest.mark.notimpl(
[
"bigquery",
"duckdb",
"exasol",
"mssql",
Expand Down

0 comments on commit 3d0f9bc

Please sign in to comment.