Skip to content

Commit

Permalink
refactor(clickhouse): apply repetitive transformations as pattern rep…
Browse files Browse the repository at this point in the history
…lacements
  • Loading branch information
cpcloud authored and kszucs committed Oct 16, 2023
1 parent 973133b commit e966af8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 70 deletions.
16 changes: 12 additions & 4 deletions ibis/backends/clickhouse/compiler/core.py
Expand Up @@ -101,19 +101,27 @@ def fn(node, _, **kwargs):
False, dtype="bool"
)

# subtract one from ranking functions to convert from 1-indexed to 0-indexed
subtract_one_from_ranking_functions = p.WindowFunction(
p.RankBase | p.NTile
# subtract one from one-based functions to convert to zero-based indexing
subtract_one_from_one_indexed_functions = (
p.WindowFunction(p.RankBase | p.NTile)
| p.StringFind
| p.FindInSet
| p.ArrayPosition
) >> c.Subtract(_, 1)

add_one_to_nth_value_input = p.NthValue >> _.copy(nth=c.Add(_.nth, 1))

nullify_empty_string_results = (p.ExtractURLField | p.DayOfWeekName) >> c.NullIf(
_, ""
)

op = op.replace(
replace_literals
| replace_in_column_with_table_array_view
| replace_empty_in_values_with_false
| subtract_one_from_ranking_functions
| subtract_one_from_one_indexed_functions
| add_one_to_nth_value_input
| nullify_empty_string_results
)
# apply translate rules in topological order
node = op.map(fn)[op]
Expand Down
88 changes: 22 additions & 66 deletions ibis/backends/clickhouse/compiler/values.py
Expand Up @@ -206,9 +206,9 @@ def _string_find(op, *, arg, substr, start, end, **_):
raise com.UnsupportedOperationError("String find doesn't support end argument")

if start is not None:
return F.locate(arg, substr, start) - 1
return F.locate(arg, substr, start)

return F.locate(arg, substr) - 1
return F.locate(arg, substr)


@translate_val.register(ops.RegexSearch)
Expand All @@ -234,7 +234,7 @@ def _regex_extract(op, *, arg, pattern, index, **_):

@translate_val.register(ops.FindInSet)
def _index_of(op, *, needle, values, **_):
return F.indexOf(F.array(*values), needle) - 1
return F.indexOf(F.array(*values), needle)


@translate_val.register(ops.Round)
Expand Down Expand Up @@ -567,11 +567,6 @@ def _struct_field(op, *, arg, field: str, **_):
return cast(sg.exp.Dot(this=arg, expression=sg.exp.convert(idx + 1)), op.dtype)


@translate_val.register(ops.NthValue)
def _nth_value(op, *, arg, nth, **_):
return F.nth_value(arg, _parenthesize(op.nth, nth))


@translate_val.register(ops.Repeat)
def _repeat(op, *, arg, times, **_):
return F.repeat(arg, F.accurateCast(times, "UInt64"))
Expand All @@ -597,7 +592,8 @@ def _in_column(op, *, value, options, **_):
return value.isin(options.this if isinstance(options, sg.exp.Subquery) else options)


_NUM_WEEKDAYS = 7
_DAYS = calendar.day_name
_NUM_WEEKDAYS = len(_DAYS)


@translate_val.register(ops.DayOfWeekIndex)
Expand All @@ -618,15 +614,11 @@ def day_of_week_name(op, *, arg, **_):
#
# We test against 20 in CI, so we implement day_of_week_name as follows
num_weekdays = _NUM_WEEKDAYS
weekdays = range(num_weekdays)
base = (((F.toDayOfWeek(arg) - 1) % num_weekdays) + num_weekdays) % num_weekdays
return F.nullIf(
sg.exp.Case(
this=base,
ifs=[if_(day, calendar.day_name[day]) for day in weekdays],
default=sg.exp.convert(""),
),
"",
return sg.exp.Case(
this=base,
ifs=[if_(i, day) for i, day in enumerate(_DAYS)],
default=sg.exp.convert(""),
)


Expand Down Expand Up @@ -804,6 +796,16 @@ def formatter(op, *, left, right, **_):
ops.NTile: "ntile",
ops.ArrayIntersect: "arrayIntersect",
ops.ExtractEpochSeconds: "toRelativeSecondNum",
ops.NthValue: "nth_value",
ops.MinRank: "rank",
ops.DenseRank: "dense_rank",
ops.RowNumber: "row_number",
ops.ExtractProtocol: "protocol",
ops.ExtractAuthority: "netloc",
ops.ExtractHost: "domain",
ops.ExtractPath: "path",
ops.ExtractFragment: "fragment",
ops.ArrayPosition: "indexOf",
}


Expand Down Expand Up @@ -925,58 +927,17 @@ def formatter(op, *, arg, offset, default, **_):
shift_like(ops.Lead, F.leadInFrame)


@translate_val.register(ops.RowNumber)
def _row_number(op, **_):
return F.row_number()


@translate_val.register(ops.DenseRank)
def _dense_rank(op, **_):
return F.dense_rank()


@translate_val.register(ops.MinRank)
def _rank(op, **_):
return F.rank()


@translate_val.register(ops.ExtractProtocol)
def _extract_protocol(op, *, arg, **_):
return F.nullIf(F.protocol(arg), "")


@translate_val.register(ops.ExtractAuthority)
def _extract_authority(op, *, arg, **_):
return F.nullIf(F.netloc(arg), "")


@translate_val.register(ops.ExtractHost)
def _extract_host(op, *, arg, **_):
return F.nullIf(F.domain(arg), "")


@translate_val.register(ops.ExtractFile)
def _extract_file(op, *, arg, **_):
return F.nullIf(F.cutFragment(F.pathFull(arg)), "")


@translate_val.register(ops.ExtractPath)
def _extract_path(op, *, arg, **_):
return F.nullIf(F.path(arg), "")
return F.cutFragment(F.pathFull(arg))


@translate_val.register(ops.ExtractQuery)
def _extract_query(op, *, arg, key, **_):
if key is not None:
input = F.extractURLParameter(arg, key)
return F.extractURLParameter(arg, key)
else:
input = F.queryString(arg)
return F.nullIf(input, "")


@translate_val.register(ops.ExtractFragment)
def _extract_fragment(op, *, arg, **_):
return F.nullIf(F.fragment(arg), "")
return F.queryString(arg)


@translate_val.register(ops.ArrayStringJoin)
Expand All @@ -1001,11 +962,6 @@ def _array_filter(op, *, arg, param, body, **_):
return F.arrayFilter(func, arg)


@translate_val.register(ops.ArrayPosition)
def _array_position(op, *, arg, other, **_):
return F.indexOf(arg, other) - 1


@translate_val.register(ops.ArrayRemove)
def _array_remove(op, *, arg, other, **_):
x = sg.to_identifier("x")
Expand Down

0 comments on commit e966af8

Please sign in to comment.