Patch summary: Flink backend operation registry (ibis/backends/flink/registry.py)

  * _cast translates its argument once. Casts to timestamp wrap numeric
    arguments in FROM_UNIXTIME and emit TO_TIMESTAMP, with CONVERT_TZ only
    when the target type carries a timezone. Casts to JSON return the
    translated argument unchanged.
  * New translators: _array_length (CARDINALITY), _json_get_item (JSON_QUERY),
    _map (MAP_FROM_ARRAYS), _map_get, _time and _time_from_hms.
  * _date_from_ymd builds the date from translated sub-expressions instead of
    reading the .value of each component.
  * extract_epoch_seconds is renamed to _extract_epoch_seconds.
  * The _timestamp_* translators move below _string_to_timestamp and their
    registry entries move into the "Temporal functions" group.
  * ops.Translate, ops.FindInSet, ops.IsNan and ops.IsInf are filtered out of
    operation_registry after the update.

diff --git a/ibis/backends/flink/registry.py b/ibis/backends/flink/registry.py
index 6283269c5973..91f20f21d2ac 100644
--- a/ibis/backends/flink/registry.py
+++ b/ibis/backends/flink/registry.py
@@ -40,9 +40,19 @@ def extract_field_formatter(translator: ExprTranslator, op: ops.Node) -> str:
 
 
 def _cast(translator: ExprTranslator, op: ops.generic.Cast) -> str:
-    if op.to.is_timestamp() and op.to.timezone:
-        arg_translated = translator.translate(op.arg)
-        return f"TO_TIMESTAMP(CONVERT_TZ(CAST({arg_translated} AS STRING), 'UTC+0', '{op.to.timezone}'))"
+    arg, to = op.arg, op.to
+    arg_translated = translator.translate(arg)
+    if to.is_timestamp():
+        if arg.dtype.is_numeric():
+            arg_translated = f"FROM_UNIXTIME({arg_translated})"
+
+        if to.timezone:
+            return f"TO_TIMESTAMP(CONVERT_TZ(CAST({arg_translated} AS STRING), 'UTC+0', '{to.timezone}'))"
+        else:
+            return f"TO_TIMESTAMP({arg_translated})"
+
+    elif to.is_json():
+        return arg_translated
 
     from ibis.backends.base.sql.registry.main import cast
 
@@ -101,48 +111,6 @@ def _filter(translator: ExprTranslator, op: ops.Node) -> str:
     return f"CASE WHEN {bool_expr} THEN {true_expr} ELSE {false_null_expr} END"
 
 
-def _timestamp_add(translator: ExprTranslator, op: ops.temporal.TimestampAdd) -> str:
-    return _left_op_right(translator=translator, op_node=op, op_sign="+")
-
-
-def _timestamp_diff(translator: ExprTranslator, op: ops.temporal.TimestampDiff) -> str:
-    return _left_op_right(translator=translator, op_node=op, op_sign="-")
-
-
-def _timestamp_sub(translator: ExprTranslator, op: ops.temporal.TimestampSub) -> str:
-    table_column = op.left
-    interval = op.right
-
-    table_column_translated = translator.translate(table_column)
-    interval_translated = translator.translate(interval)
-    return f"{table_column_translated} - {interval_translated}"
-
-
-def _timestamp_from_unix(translator: ExprTranslator, op: ops.Node) -> str:
-    arg, unit = op.args
-
-    numeric = helpers.quote_identifier(arg.name, force=True)
-    if unit == TimestampUnit.MILLISECOND:
-        precision = 3
-    elif unit == TimestampUnit.SECOND:
-        precision = 0
-    else:
-        raise ValueError(f"{unit!r} unit is not supported!")
-
-    return f"TO_TIMESTAMP_LTZ({numeric}, {precision})"
-
-
-def _timestamp_from_ymdhms(
-    translator: ExprTranslator, op: ops.temporal.TimestampFromYMDHMS
-) -> str:
-    year, month, day, hours, minutes, seconds = (
-        f"CAST({translator.translate(e)} AS STRING)"
-        for e in [op.year, op.month, op.day, op.hours, op.minutes, op.seconds]
-    )
-    concat_string = f"CONCAT({year}, '-', {month}, '-', {day}, ' ', {hours}, ':', {minutes}, ':', {seconds})"
-    return f"CAST({concat_string} AS TIMESTAMP)"
-
-
 def _format_window_start(translator: ExprTranslator, boundary):
     if boundary is None:
         return "UNBOUNDED PRECEDING"
@@ -281,6 +249,35 @@ def _array_index(translator: ExprTranslator, op: ops.arrays.ArrayIndex):
     return f"{table_column_translated} [ {index_translated} + 1 ]"
 
 
+def _array_length(translator: ExprTranslator, op: ops.arrays.ArrayLength) -> str:
+    return f"CARDINALITY({translator.translate(op.arg)})"
+
+
+def _json_get_item(translator: ExprTranslator, op: ops.json.JSONGetItem) -> str:
+    arg_translated = translator.translate(op.arg)
+    if op.index.dtype.is_integer():
+        query_path = f"$[{op.index.value}]"
+    else:  # is string
+        query_path = f"$.{op.index.value}"
+
+    return (
+        f"JSON_QUERY({arg_translated}, '{query_path}' WITH CONDITIONAL ARRAY WRAPPER)"
+    )
+
+
+def _map(translator: ExprTranslator, op: ops.maps.Map) -> str:
+    key_array = translator.translate(op.keys)
+    value_array = translator.translate(op.values)
+
+    return f"MAP_FROM_ARRAYS({key_array}, {value_array})"
+
+
+def _map_get(translator: ExprTranslator, op: ops.maps.MapGet) -> str:
+    map_ = translator.translate(op.arg)
+    key = translator.translate(op.key)
+    return f"{map_} [ {key} ]"
+
+
 def _day_of_week_index(
     translator: ExprTranslator, op: ops.temporal.DayOfWeekIndex
 ) -> str:
@@ -297,16 +294,19 @@ def _date_diff(translator: ExprTranslator, op: ops.temporal.DateDiff) -> str:
 
 
 def _date_from_ymd(translator: ExprTranslator, op: ops.temporal.DateFromYMD) -> str:
-    year, month, day = op.year, op.month, op.day
-    date_string = f"{year.value}-{month.value}-{day.value}"
-    return f"CAST('{date_string}' AS DATE)"
+    year, month, day = (
+        f"CAST({translator.translate(e)} AS STRING)"
+        for e in [op.year, op.month, op.day]
+    )
+    concat_string = f"CONCAT({year}, '-', {month}, '-', {day})"
+    return f"CAST({concat_string} AS DATE)"
 
 
 def _date_sub(translator: ExprTranslator, op: ops.temporal.DateSub) -> str:
     return _left_op_right(translator=translator, op_node=op, op_sign="-")
 
 
-def extract_epoch_seconds(translator: ExprTranslator, op: ops.Node) -> str:
+def _extract_epoch_seconds(translator: ExprTranslator, op: ops.Node) -> str:
     arg = translator.translate(op.arg)
     return f"UNIX_TIMESTAMP(CAST({arg} AS STRING))"
 
@@ -319,6 +319,66 @@ def _string_to_timestamp(
     return f"TO_TIMESTAMP({arg}, {format_string})"
 
 
+def _time(translator: ExprTranslator, op: ops.temporal.Time) -> str:
+    if op.arg.dtype.is_timestamp():
+        datetime = op.arg.value
+        return f"TIME '{datetime.hour}:{datetime.minute}:{datetime.second}'"
+
+    else:
+        raise com.UnsupportedOperationError(f"Does NOT support dtype= {op.arg.dtype}")
+
+
+def _time_from_hms(translator: ExprTranslator, op: ops.temporal.TimeFromHMS) -> str:
+    hours, minutes, seconds = (
+        f"CAST({translator.translate(e)} AS STRING)"
+        for e in [op.hours, op.minutes, op.seconds]
+    )
+    concat_string = f"CONCAT({hours}, ':', {minutes}, ':', {seconds})"
+    return f"CAST({concat_string} AS TIME)"
+
+
+def _timestamp_add(translator: ExprTranslator, op: ops.temporal.TimestampAdd) -> str:
+    return _left_op_right(translator=translator, op_node=op, op_sign="+")
+
+
+def _timestamp_diff(translator: ExprTranslator, op: ops.temporal.TimestampDiff) -> str:
+    return _left_op_right(translator=translator, op_node=op, op_sign="-")
+
+
+def _timestamp_sub(translator: ExprTranslator, op: ops.temporal.TimestampSub) -> str:
+    table_column = op.left
+    interval = op.right
+
+    table_column_translated = translator.translate(table_column)
+    interval_translated = translator.translate(interval)
+    return f"{table_column_translated} - {interval_translated}"
+
+
+def _timestamp_from_unix(translator: ExprTranslator, op: ops.Node) -> str:
+    arg, unit = op.args
+
+    numeric = helpers.quote_identifier(arg.name, force=True)
+    if unit == TimestampUnit.MILLISECOND:
+        precision = 3
+    elif unit == TimestampUnit.SECOND:
+        precision = 0
+    else:
+        raise ValueError(f"{unit!r} unit is not supported!")
+
+    return f"TO_TIMESTAMP_LTZ({numeric}, {precision})"
+
+
+def _timestamp_from_ymdhms(
+    translator: ExprTranslator, op: ops.temporal.TimestampFromYMDHMS
+) -> str:
+    year, month, day, hours, minutes, seconds = (
+        f"CAST({translator.translate(e)} AS STRING)"
+        for e in [op.year, op.month, op.day, op.hours, op.minutes, op.seconds]
+    )
+    concat_string = f"CONCAT({year}, '-', {month}, '-', {day}, ' ', {hours}, ':', {minutes}, ':', {seconds})"
+    return f"CAST({concat_string} AS TIMESTAMP)"
+
+
 operation_registry.update(
     {
         # Unary operations
@@ -335,7 +395,7 @@ def _string_to_timestamp(
         ops.RegexSearch: fixed_arity("regexp", 2),
         # Timestamp operations
         ops.Date: _date,
-        ops.ExtractEpochSeconds: extract_epoch_seconds,
+        ops.ExtractEpochSeconds: _extract_epoch_seconds,
         ops.ExtractYear: _extract_field("year"),  # equivalent to YEAR(date)
         ops.ExtractMonth: _extract_field("month"),  # equivalent to MONTH(date)
         ops.ExtractDay: _extract_field("day"),  # equivalent to DAYOFMONTH(date)
@@ -354,23 +414,43 @@ def _string_to_timestamp(
         ops.Literal: _literal,
         ops.TryCast: _try_cast,
         ops.IfElse: _filter,
-        ops.TimestampAdd: _timestamp_add,
-        ops.TimestampDiff: _timestamp_diff,
-        ops.TimestampFromUNIX: _timestamp_from_unix,
-        ops.TimestampFromYMDHMS: _timestamp_from_ymdhms,
-        ops.TimestampSub: _timestamp_sub,
         ops.Window: _window,
         ops.Clip: _clip,
         # Binary operations
         ops.Power: fixed_arity("power", 2),
         ops.FloorDivide: _floor_divide,
-        # Temporal functions
+        # Collection functions
         ops.ArrayIndex: _array_index,
+        ops.ArrayLength: _array_length,
+        ops.JSONGetItem: _json_get_item,
+        ops.Map: _map,
+        ops.MapGet: _map_get,
+        # Temporal functions
         ops.DateAdd: _date_add,
         ops.DateDiff: _date_diff,
         ops.DateFromYMD: _date_from_ymd,
         ops.DateSub: _date_sub,
         ops.DayOfWeekIndex: _day_of_week_index,
         ops.StringToTimestamp: _string_to_timestamp,
+        ops.Time: _time,
+        ops.TimeFromHMS: _time_from_hms,
+        ops.TimestampAdd: _timestamp_add,
+        ops.TimestampDiff: _timestamp_diff,
+        ops.TimestampFromUNIX: _timestamp_from_unix,
+        ops.TimestampFromYMDHMS: _timestamp_from_ymdhms,
+        ops.TimestampSub: _timestamp_sub,
     }
 )
+
+_invalid_operations = {
+    # ibis.expr.operations.strings
+    ops.Translate,
+    ops.FindInSet,
+    # ibis.expr.operations.numeric
+    ops.IsNan,
+    ops.IsInf,
+}
+
+operation_registry = {
+    k: v for k, v in operation_registry.items() if k not in _invalid_operations
+}