Skip to content

Commit

Permalink
feat(flink): add remaining operators for Flink to pass/skip the commo…
Browse files Browse the repository at this point in the history
…n tests
  • Loading branch information
mfatihaktas authored and cpcloud committed Oct 19, 2023
1 parent 5d32c91 commit b27adc6
Showing 1 changed file with 136 additions and 56 deletions.
192 changes: 136 additions & 56 deletions ibis/backends/flink/registry.py
Expand Up @@ -40,9 +40,19 @@ def extract_field_formatter(translator: ExprTranslator, op: ops.Node) -> str:


def _cast(translator: ExprTranslator, op: ops.generic.Cast) -> str:
if op.to.is_timestamp() and op.to.timezone:
arg_translated = translator.translate(op.arg)
return f"TO_TIMESTAMP(CONVERT_TZ(CAST({arg_translated} AS STRING), 'UTC+0', '{op.to.timezone}'))"
arg, to = op.arg, op.to
arg_translated = translator.translate(arg)
if to.is_timestamp():
if arg.dtype.is_numeric():
arg_translated = f"FROM_UNIXTIME({arg_translated})"

if to.timezone:
return f"TO_TIMESTAMP(CONVERT_TZ(CAST({arg_translated} AS STRING), 'UTC+0', '{to.timezone}'))"
else:
return f"TO_TIMESTAMP({arg_translated})"

elif to.is_json():
return arg_translated

from ibis.backends.base.sql.registry.main import cast

Expand Down Expand Up @@ -101,48 +111,6 @@ def _filter(translator: ExprTranslator, op: ops.Node) -> str:
return f"CASE WHEN {bool_expr} THEN {true_expr} ELSE {false_null_expr} END"


def _timestamp_add(translator: ExprTranslator, op: ops.temporal.TimestampAdd) -> str:
return _left_op_right(translator=translator, op_node=op, op_sign="+")


def _timestamp_diff(translator: ExprTranslator, op: ops.temporal.TimestampDiff) -> str:
return _left_op_right(translator=translator, op_node=op, op_sign="-")


def _timestamp_sub(translator: ExprTranslator, op: ops.temporal.TimestampSub) -> str:
table_column = op.left
interval = op.right

table_column_translated = translator.translate(table_column)
interval_translated = translator.translate(interval)
return f"{table_column_translated} - {interval_translated}"


def _timestamp_from_unix(translator: ExprTranslator, op: ops.Node) -> str:
arg, unit = op.args

numeric = helpers.quote_identifier(arg.name, force=True)
if unit == TimestampUnit.MILLISECOND:
precision = 3
elif unit == TimestampUnit.SECOND:
precision = 0
else:
raise ValueError(f"{unit!r} unit is not supported!")

return f"TO_TIMESTAMP_LTZ({numeric}, {precision})"


def _timestamp_from_ymdhms(
translator: ExprTranslator, op: ops.temporal.TimestampFromYMDHMS
) -> str:
year, month, day, hours, minutes, seconds = (
f"CAST({translator.translate(e)} AS STRING)"
for e in [op.year, op.month, op.day, op.hours, op.minutes, op.seconds]
)
concat_string = f"CONCAT({year}, '-', {month}, '-', {day}, ' ', {hours}, ':', {minutes}, ':', {seconds})"
return f"CAST({concat_string} AS TIMESTAMP)"


def _format_window_start(translator: ExprTranslator, boundary):
if boundary is None:
return "UNBOUNDED PRECEDING"
Expand Down Expand Up @@ -281,6 +249,35 @@ def _array_index(translator: ExprTranslator, op: ops.arrays.ArrayIndex):
return f"{table_column_translated} [ {index_translated} + 1 ]"


def _array_length(translator: ExprTranslator, op: ops.arrays.ArrayLength) -> str:
return f"CARDINALITY({translator.translate(op.arg)})"


def _json_get_item(translator: ExprTranslator, op: ops.json.JSONGetItem) -> str:
arg_translated = translator.translate(op.arg)
if op.index.dtype.is_integer():
query_path = f"$[{op.index.value}]"
else: # is string
query_path = f"$.{op.index.value}"

return (
f"JSON_QUERY({arg_translated}, '{query_path}' WITH CONDITIONAL ARRAY WRAPPER)"
)


def _map(translator: ExprTranslator, op: ops.maps.Map) -> str:
key_array = translator.translate(op.keys)
value_array = translator.translate(op.values)

return f"MAP_FROM_ARRAYS({key_array}, {value_array})"


def _map_get(translator: ExprTranslator, op: ops.maps.MapGet) -> str:
map_ = translator.translate(op.arg)
key = translator.translate(op.key)
return f"{map_} [ {key} ]"


def _day_of_week_index(
translator: ExprTranslator, op: ops.temporal.DayOfWeekIndex
) -> str:
Expand All @@ -297,16 +294,19 @@ def _date_diff(translator: ExprTranslator, op: ops.temporal.DateDiff) -> str:


def _date_from_ymd(translator: ExprTranslator, op: ops.temporal.DateFromYMD) -> str:
year, month, day = op.year, op.month, op.day
date_string = f"{year.value}-{month.value}-{day.value}"
return f"CAST('{date_string}' AS DATE)"
year, month, day = (
f"CAST({translator.translate(e)} AS STRING)"
for e in [op.year, op.month, op.day]
)
concat_string = f"CONCAT({year}, '-', {month}, '-', {day})"
return f"CAST({concat_string} AS DATE)"


def _date_sub(translator: ExprTranslator, op: ops.temporal.DateSub) -> str:
return _left_op_right(translator=translator, op_node=op, op_sign="-")


def extract_epoch_seconds(translator: ExprTranslator, op: ops.Node) -> str:
def _extract_epoch_seconds(translator: ExprTranslator, op: ops.Node) -> str:
arg = translator.translate(op.arg)
return f"UNIX_TIMESTAMP(CAST({arg} AS STRING))"

Expand All @@ -319,6 +319,66 @@ def _string_to_timestamp(
return f"TO_TIMESTAMP({arg}, {format_string})"


def _time(translator: ExprTranslator, op: ops.temporal.Time) -> str:
if op.arg.dtype.is_timestamp():
datetime = op.arg.value
return f"TIME '{datetime.hour}:{datetime.minute}:{datetime.second}'"

else:
raise com.UnsupportedOperationError(f"Does NOT support dtype= {op.arg.dtype}")


def _time_from_hms(translator: ExprTranslator, op: ops.temporal.TimeFromHMS) -> str:
hours, minutes, seconds = (
f"CAST({translator.translate(e)} AS STRING)"
for e in [op.hours, op.minutes, op.seconds]
)
concat_string = f"CONCAT({hours}, ':', {minutes}, ':', {seconds})"
return f"CAST({concat_string} AS TIME)"


def _timestamp_add(translator: ExprTranslator, op: ops.temporal.TimestampAdd) -> str:
return _left_op_right(translator=translator, op_node=op, op_sign="+")


def _timestamp_diff(translator: ExprTranslator, op: ops.temporal.TimestampDiff) -> str:
return _left_op_right(translator=translator, op_node=op, op_sign="-")


def _timestamp_sub(translator: ExprTranslator, op: ops.temporal.TimestampSub) -> str:
table_column = op.left
interval = op.right

table_column_translated = translator.translate(table_column)
interval_translated = translator.translate(interval)
return f"{table_column_translated} - {interval_translated}"


def _timestamp_from_unix(translator: ExprTranslator, op: ops.Node) -> str:
arg, unit = op.args

numeric = helpers.quote_identifier(arg.name, force=True)
if unit == TimestampUnit.MILLISECOND:
precision = 3
elif unit == TimestampUnit.SECOND:
precision = 0
else:
raise ValueError(f"{unit!r} unit is not supported!")

return f"TO_TIMESTAMP_LTZ({numeric}, {precision})"


def _timestamp_from_ymdhms(
translator: ExprTranslator, op: ops.temporal.TimestampFromYMDHMS
) -> str:
year, month, day, hours, minutes, seconds = (
f"CAST({translator.translate(e)} AS STRING)"
for e in [op.year, op.month, op.day, op.hours, op.minutes, op.seconds]
)
concat_string = f"CONCAT({year}, '-', {month}, '-', {day}, ' ', {hours}, ':', {minutes}, ':', {seconds})"
return f"CAST({concat_string} AS TIMESTAMP)"


operation_registry.update(
{
# Unary operations
Expand All @@ -335,7 +395,7 @@ def _string_to_timestamp(
ops.RegexSearch: fixed_arity("regexp", 2),
# Timestamp operations
ops.Date: _date,
ops.ExtractEpochSeconds: extract_epoch_seconds,
ops.ExtractEpochSeconds: _extract_epoch_seconds,
ops.ExtractYear: _extract_field("year"), # equivalent to YEAR(date)
ops.ExtractMonth: _extract_field("month"), # equivalent to MONTH(date)
ops.ExtractDay: _extract_field("day"), # equivalent to DAYOFMONTH(date)
Expand All @@ -354,23 +414,43 @@ def _string_to_timestamp(
ops.Literal: _literal,
ops.TryCast: _try_cast,
ops.IfElse: _filter,
ops.TimestampAdd: _timestamp_add,
ops.TimestampDiff: _timestamp_diff,
ops.TimestampFromUNIX: _timestamp_from_unix,
ops.TimestampFromYMDHMS: _timestamp_from_ymdhms,
ops.TimestampSub: _timestamp_sub,
ops.Window: _window,
ops.Clip: _clip,
# Binary operations
ops.Power: fixed_arity("power", 2),
ops.FloorDivide: _floor_divide,
# Temporal functions
# Collection functions
ops.ArrayIndex: _array_index,
ops.ArrayLength: _array_length,
ops.JSONGetItem: _json_get_item,
ops.Map: _map,
ops.MapGet: _map_get,
# Temporal functions
ops.DateAdd: _date_add,
ops.DateDiff: _date_diff,
ops.DateFromYMD: _date_from_ymd,
ops.DateSub: _date_sub,
ops.DayOfWeekIndex: _day_of_week_index,
ops.StringToTimestamp: _string_to_timestamp,
ops.Time: _time,
ops.TimeFromHMS: _time_from_hms,
ops.TimestampAdd: _timestamp_add,
ops.TimestampDiff: _timestamp_diff,
ops.TimestampFromUNIX: _timestamp_from_unix,
ops.TimestampFromYMDHMS: _timestamp_from_ymdhms,
ops.TimestampSub: _timestamp_sub,
}
)

_invalid_operations = {
# ibis.expr.operations.strings
ops.Translate,
ops.FindInSet,
# ibis.expr.operations.numeric
ops.IsNan,
ops.IsInf,
}

operation_registry = {
k: v for k, v in operation_registry.items() if k not in _invalid_operations
}

0 comments on commit b27adc6

Please sign in to comment.