Skip to content

Commit

Permalink
fix(flink): fix compilation of over aggregation query in flink backend (
Browse files Browse the repository at this point in the history
  • Loading branch information
chloeh13q authored Feb 23, 2024
1 parent b51eb2d commit de174a2
Show file tree
Hide file tree
Showing 14 changed files with 56 additions and 18 deletions.
7 changes: 7 additions & 0 deletions ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ def _minimize_spec(start, end, spec):
and end.following
):
return None
elif (
isinstance(getattr(end, "value", None), ops.Cast)
and end.value.arg.value == 0
and end.following
):
spec.args["end"] = "CURRENT ROW"
spec.args["end_side"] = None
return spec

def visit_TumbleWindowingTVF(self, op, *, table, time_col, window_size, offset):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ SELECT
FROM (
SELECT
`t0`.*
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE)) AS `t0`
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE(2))) AS `t0`
) AS `t1`
GROUP BY
`t1`.`window_start`,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ FROM (
FROM (
SELECT
`t0`.*
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '600' SECOND)) AS `t0`
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '600' SECOND(3))) AS `t0`
) AS `t1`
) AS `t2`
) AS `t3`
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SELECT
`t0`.*
FROM TABLE(
CUMULATE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '10' SECOND, INTERVAL '1' MINUTE)
CUMULATE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '10' SECOND(2), INTERVAL '1' MINUTE(2))
) AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
SELECT
`t0`.*
FROM TABLE(HOP(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '1' MINUTE, INTERVAL '15' MINUTE)) AS `t0`
FROM TABLE(
HOP(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '1' MINUTE(2), INTERVAL '15' MINUTE(2))
) AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
`t0`.*
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE)) AS `t0`
FROM TABLE(TUMBLE(TABLE `table`, DESCRIPTOR(`i`), INTERVAL '15' MINUTE(2))) AS `t0`

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST RANGE BETWEEN INTERVAL '500' MINUTE preceding AND CAST(0 AS INTERVAL MINUTE) following) AS `Sum(f)`
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST RANGE BETWEEN INTERVAL '500' MINUTE(3) preceding AND CURRENT ROW) AS `Sum(f)`
FROM `table` AS `t0`
41 changes: 41 additions & 0 deletions ibis/backends/sql/dialects.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import math

import sqlglot.expressions as sge
from sqlglot import transforms
Expand Down Expand Up @@ -68,6 +69,45 @@ class Generator(Postgres.Generator):
}


def _calculate_precision(interval_value: int) -> int:
"""Calculate interval precision.
FlinkSQL interval data types use leading precision and fractional-
seconds precision. Because the leading precision defaults to 2, we need to
specify a different precision when the value exceeds 2 digits.
(see
https://learn.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literals)
"""
# log10(interval_value) + 1 is equivalent to len(str(interval_value)), but is significantly
# faster and more memory-efficient
if interval_value == 0:
return 0
if interval_value < 0:
raise ValueError(
f"Expecting value to be a non-negative integer, got {interval_value}"
)
return int(math.log10(interval_value)) + 1


def _interval_with_precision(self, e):
    """Render an interval expression, adding leading precision for scalars.

    When the interval's inner value is a plain string it is single-quoted and
    the unit is suffixed with a leading precision of at least 2. Any other
    value is rendered through its own ``sql`` method when it has one, and the
    unit is left untouched.
    """
    value = e.args["this"].this

    # Values without a ``sql`` method (e.g. plain strings) are emitted as-is.
    try:
        rendered = value.sql(self.dialect)
    except AttributeError:
        rendered = value

    unit = e.args["unit"]

    # Non-scalar intervals need neither quoting nor a precision suffix.
    if not isinstance(value, str):
        return f"INTERVAL {rendered} {unit}"

    # Interval scalars: quote the value and never go below a precision of 2.
    digits = max(_calculate_precision(int(value)), 2)
    unit += f"({digits})"
    rendered = f"'{rendered}'"
    return f"INTERVAL {rendered} {unit}"


class Flink(Hive):
class Generator(Hive.Generator):
TYPE_MAPPING = Hive.Generator.TYPE_MAPPING.copy() | {
Expand All @@ -91,6 +131,7 @@ class Generator(Hive.Generator):
sge.DayOfYear: rename_func("dayofyear"),
sge.DayOfWeek: rename_func("dayofweek"),
sge.DayOfMonth: rename_func("dayofmonth"),
sge.Interval: _interval_with_precision,
}

class Tokenizer(Hive.Tokenizer):
Expand Down

0 comments on commit de174a2

Please sign in to comment.