Skip to content

Commit

Permalink
fix(api): make re_extract conform to semantics of Python's re.match
Browse files Browse the repository at this point in the history
Alter various backend implementations of `StringValue.re_extract` to
match the behavior of `re.match`.

BREAKING CHANGE: `re_extract` now follows `re.match` behavior. In particular, when there is a match, group `0` is now the entire string, and the capture groups are numbered starting from 1.
  • Loading branch information
cpcloud authored and kszucs committed Nov 22, 2022
1 parent 6bb5b4f commit 5981227
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 36 deletions.
30 changes: 22 additions & 8 deletions ibis/backends/clickhouse/registry.py
Expand Up @@ -220,14 +220,28 @@ def _string_find(translator, op):


def _regex_extract(translator, op):
# Compile a regex-extract operation into a ClickHouse SQL string.
#
# NOTE: this span is a rendered diff. The hunk reports 8 deletions and 22
# additions for this function: the first body below is the code the commit
# removed, the second body is the code that replaced it.
#
# --- removed by this commit (8 lines) ---
# Old behaviour: build `extractAll(CAST(arg AS String), pattern)` and, when an
# index is given, subscript it with `index + 1`; with no index the whole
# extractAll(...) expression is returned.
arg_ = translator.translate(op.arg)
pattern_ = translator.translate(op.pattern)
index = op.index

base = f"extractAll(CAST({arg_} AS String), {pattern_})"
if index is not None:
return f"{base}[{translator.translate(index)} + 1]"
return base
# --- added by this commit (22 lines) ---
# New behaviour: emit
#   if(match(arg, pattern), if(idx IS NULL, arg, extracted[idx]), NULL)
# where a missing index is rendered as the SQL literal `Null`.
arg = translator.translate(op.arg)
pattern = translator.translate(op.pattern)
index = "Null" if op.index is None else translator.translate(op.index)

# arg can be Nullable, which is not allowed in extractAll, so cast to non
# nullable type
arg = f"CAST({arg} AS String)"

# extract all matches in pattern; the outer CAST makes the array elements
# Nullable(String)
extracted = f"CAST(extractAll({arg}, {pattern}) AS Array(Nullable(String)))"

# if there's a match
#   if the index IS zero or null
#     return the full string
#   else
#     return the Nth match group
# else
#   return null
does_match = f"match({arg}, {pattern})"
# nullIf(index, 0) maps an index of 0 to NULL, so "index is 0" and "index is
# missing" both take the `idx IS NULL` branch below
idx = f"CAST(nullIf({index}, 0) AS Nullable(Int64))"
then = f"if({idx} IS NULL, {arg}, {extracted}[{idx}])"
return f"if({does_match}, {then}, NULL)"


def _parse_url(translator, op):
Expand Down
@@ -1 +1 @@
extractAll(CAST(`string_col` AS String), '[\d]+')[3 + 1]
if(match(CAST(`string_col` AS String), '[\d]+'), if(CAST(nullIf(3, 0) AS Nullable(Int64)) IS NULL, CAST(`string_col` AS String), CAST(extractAll(CAST(`string_col` AS String), '[\d]+') AS Array(Nullable(String)))[CAST(nullIf(3, 0) AS Nullable(Int64))]), NULL)
28 changes: 21 additions & 7 deletions ibis/backends/datafusion/compiler.py
Expand Up @@ -4,8 +4,10 @@
import datafusion as df
import datafusion.functions
import pyarrow as pa
import pyarrow.compute as pc

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.datafusion.datatypes import to_pyarrow_type

Expand Down Expand Up @@ -259,13 +261,6 @@ def substring(op):
return df.functions.substr(arg, start, length)


@translate.register(ops.RegexExtract)
def regex_extract(op):
    """Translate ``ops.RegexExtract`` into a DataFusion ``regexp_match`` call.

    Only ``op.arg`` and ``op.pattern`` are translated; ``op.index`` is not
    consulted, so the result is whatever ``regexp_match`` itself returns.

    NOTE: in this diff view these lines are the pre-commit version (they
    account for the hunk's 7 deletions, blank lines included); a replacement
    registered for the same operation appears at the end of the file.
    """
    # Translate the operands in the same order as before: arg, then pattern.
    return df.functions.regexp_match(translate(op.arg), translate(op.pattern))


@translate.register(ops.Repeat)
def repeat(op):
arg = translate(op.arg)
Expand Down Expand Up @@ -447,3 +442,22 @@ def elementwise_udf(op):
@translate.register(ops.StringConcat)
def string_concat(op):
    """Translate ``ops.StringConcat`` into a DataFusion ``concat`` call.

    Every entry of ``op.args`` is translated, in order, and the results are
    passed positionally to ``df.functions.concat``.
    """
    translated_parts = [translate(part) for part in op.args]
    return df.functions.concat(*translated_parts)


@translate.register(ops.RegexExtract)
def regex_extract(op):
    """Translate ``ops.RegexExtract`` for the DataFusion backend.

    The pattern is wrapped in parentheses, matched with ``regexp_match``, and
    one element of the resulting string array is selected by a small UDF
    built around ``pyarrow.compute.list_element``.

    Raises
    ------
    ValueError
        If ``op.index`` has no ``value`` attribute, or that value is ``None``
        (i.e. the index is not a literal).
    """
    haystack = translate(op.arg)
    # Surround the pattern with "(" ... ")" -- presumably so that element 0 of
    # the match array is the whole match, in line with the re.match semantics
    # described by this commit (TODO confirm against regexp_match's output).
    grouped_pattern = translate(ops.StringConcat("(", op.pattern, ")"))

    group_index = getattr(op.index, "value", None)
    if group_index is None:
        raise ValueError(
            "re_extract `index` expressions must be literals. "
            "Arbitrary expressions are not supported in the DataFusion backend"
        )

    def _pick_element(matches, position=group_index):
        # The index is bound as a default argument, as in the original lambda.
        return pc.list_element(matches, position)

    array_of_strings = to_pyarrow_type(dt.Array(dt.string))
    single_string = to_pyarrow_type(dt.string)
    element_getter = df.udf(
        _pick_element,
        input_types=[array_of_strings],
        return_type=single_string,
        volatility="immutable",
        name="string_array_get",
    )

    all_matches = df.functions.regexp_match(haystack, grouped_pattern)
    return element_getter(all_matches)
4 changes: 1 addition & 3 deletions ibis/backends/duckdb/registry.py
Expand Up @@ -131,9 +131,7 @@ def _regex_extract(t, op):
# DuckDB requires the index to be a constant so we compile
# the value and inline it using sa.text
sa.text(
str(
(index + 1).compile(compile_kwargs=dict(literal_binds=True))
)
str(index.compile(compile_kwargs=dict(literal_binds=True)))
),
),
)
Expand Down
21 changes: 7 additions & 14 deletions ibis/backends/postgres/registry.py
Expand Up @@ -298,20 +298,13 @@ def _log(t, op):

def _regex_extract(t, op):
# Compile a regex-extract operation into a SQLAlchemy expression for Postgres.
#
# NOTE: this span is a rendered diff. The hunk reports 14 deletions and 7
# additions for this function: after the shared first line, the first body
# below is the removed code and the second body is its replacement.
arg = t.translate(op.arg)
# --- removed by this commit (14 lines) ---
# Old behaviour: when the string matches, subscript regexp_match(...) with the
# translated index + 1; when it does not match, yield the empty string.
pattern = t.translate(op.pattern)
return sa.case(
[
(
sa.func.textregexeq(arg, pattern),
sa.func.regexp_match(
arg,
pattern,
type_=postgresql.ARRAY(sa.TEXT),
)[t.translate(op.index) + 1],
)
],
else_="",
)
# --- added by this commit (7 lines) ---
# New behaviour: the whole pattern becomes an extra outer group, and the
# non-matching branch yields NULL (else_=None) instead of "".
# wrap in parens to support 0th group being the whole string
pattern = "(" + t.translate(op.pattern) + ")"
# arrays are 1-based in postgres
index = t.translate(op.index) + 1
does_match = sa.func.textregexeq(arg, pattern)
matches = sa.func.regexp_match(arg, pattern, type_=postgresql.ARRAY(sa.TEXT))
return sa.case([(does_match, matches[index])], else_=None)


def _cardinality(array):
Expand Down
11 changes: 8 additions & 3 deletions ibis/expr/types/strings.py
Expand Up @@ -541,14 +541,19 @@ def re_extract(
Parameters
----------
pattern
Reguar expression string
Regular expression pattern string
index
Zero-based index of match to return
The index of the match group to return.
The behavior of this function follows the behavior of Python's
[`re.match`](https://docs.python.org/3/library/re.html#match-objects):
when `index` is zero and there's a match, return the entire string,
otherwise return the content of the `index`-th match group.
Returns
-------
StringValue
Extracted match
Extracted match or whole string if `index` is zero
"""
return ops.RegexExtract(self, pattern, index).to_expr()

Expand Down

0 comments on commit 5981227

Please sign in to comment.