From 59b6080496c6e888ffd033a4f5b075fd5a3eca65 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 15 Oct 2025 22:24:56 +0000 Subject: [PATCH 1/3] refactor: support ops.case_when_op for the sqlglot compiler --- .../sqlglot/expressions/generic_ops.py | 51 ++++++++++++++----- .../system/small/engines/test_generic_ops.py | 4 +- .../test_case_when_op/out.sql | 29 +++++++++++ .../sqlglot/expressions/test_generic_ops.py | 41 +++++++++++++++ 4 files changed, 111 insertions(+), 14 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_case_when_op/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 1ed49b89eb..132403f18d 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -23,6 +23,7 @@ import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op +register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op register_ternary_op = scalar_compiler.scalar_op_compiler.register_ternary_op @@ -67,18 +68,6 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: return _cast(sg_expr, sg_to_type, op.safe) -@register_ternary_op(ops.clip_op) -def _( - original: TypedExpr, - lower: TypedExpr, - upper: TypedExpr, -) -> sge.Expression: - return sge.Greatest( - this=sge.Least(this=original.expr, expressions=[upper.expr]), - expressions=[lower.expr], - ) - - @register_unary_op(ops.hash_op) def _(expr: TypedExpr) -> sge.Expression: return sge.func("FARM_FINGERPRINT", expr.expr) @@ -114,6 +103,44 @@ def _( return sge.If(this=condition.expr, true=original.expr, false=replacement.expr) +@register_ternary_op(ops.clip_op) +def _( + original: TypedExpr, + lower: TypedExpr, + upper: TypedExpr, +) -> sge.Expression: + return sge.Greatest( + this=sge.Least(this=original.expr, expressions=[upper.expr]), + expressions=[lower.expr], + ) + + +@register_nary_op(ops.case_when_op) +def _(*cases_and_outputs: TypedExpr) -> sge.Expression: + # Need to upcast BOOL to INT if any output is numeric + result_values = cases_and_outputs[1::2] + do_upcast_bool = any( + dtypes.is_numeric(t.dtype, include_bool=False) for t in result_values + ) + if do_upcast_bool: + result_values = tuple( + TypedExpr( + sge.Cast(this=val.expr, to="INT64"), + dtypes.INT_DTYPE, + ) + if val.dtype == dtypes.BOOL_DTYPE + else val + for val in result_values + ) + + return sge.Case( + ifs=[ + sge.If(this=predicate.expr, true=output.expr) + for predicate, output in zip(cases_and_outputs[::2], result_values) + ], + ) + + # Helper functions def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: from_type = expr.dtype diff --git a/tests/system/small/engines/test_generic_ops.py b/tests/system/small/engines/test_generic_ops.py index ae7eafd347..e311ce2589 100644 --- a/tests/system/small/engines/test_generic_ops.py +++ b/tests/system/small/engines/test_generic_ops.py @@ -357,7 +357,7 @@ def test_engines_fillna_op(scalars_array_value: array_value.ArrayValue, engine): assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_casewhen_op_single_case( scalars_array_value: array_value.ArrayValue, engine ): @@ -373,7 +373,7 @@ def test_engines_casewhen_op_single_case( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_casewhen_op_double_case( scalars_array_value: array_value.ArrayValue, engine ): diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_case_when_op/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_case_when_op/out.sql new file mode 100644 index 0000000000..08db34a632 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_case_when_op/out.sql @@ -0,0 +1,29 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `int64_too` AS `bfcol_2`, + `float64_col` AS `bfcol_3` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE WHEN `bfcol_0` THEN `bfcol_1` END AS `bfcol_4`, + CASE WHEN `bfcol_0` THEN `bfcol_1` WHEN `bfcol_0` THEN `bfcol_2` END AS `bfcol_5`, + CASE WHEN `bfcol_0` THEN `bfcol_0` WHEN `bfcol_0` THEN `bfcol_0` END AS `bfcol_6`, + CASE + WHEN `bfcol_0` + THEN `bfcol_1` + WHEN `bfcol_0` + THEN CAST(`bfcol_0` AS INT64) + WHEN `bfcol_0` + THEN `bfcol_3` + END AS `bfcol_7` + FROM `bfcte_0` +) +SELECT + `bfcol_4` AS `single_case`, + `bfcol_5` AS `double_case`, + `bfcol_6` AS `bool_types_case`, + `bfcol_7` AS `mixed_types_cast` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py index 261a630d3a..2336cd0c92 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py @@ -168,6 +168,47 @@ def test_astype_json_invalid( ) +def test_case_when_op(scalar_types_df: bpd.DataFrame, snapshot): + ops_map = { + "single_case": ops.case_when_op.as_expr( + "bool_col", + "int64_col", + ), + "double_case": ops.case_when_op.as_expr( + "bool_col", + "int64_col", + "bool_col", + "int64_too", + ), + "bool_types_case": ops.case_when_op.as_expr( + "bool_col", + "bool_col", + "bool_col", + "bool_col", + ), + "mixed_types_cast": ops.case_when_op.as_expr( + "bool_col", + "int64_col", + "bool_col", + "bool_col", + "bool_col", + "float64_col", + ), + } + + array_value = scalar_types_df._block.expr + result, col_ids = array_value.compute_values(list(ops_map.values())) + + # Rename columns for deterministic golden SQL results. + assert len(col_ids) == len(ops_map.keys()) + result = result.rename_columns( + {col_id: key for col_id, key in zip(col_ids, ops_map.keys())} + ).select_columns(list(ops_map.keys())) + + sql = result.session._executor.to_sql(result, enable_cache=False) + snapshot.assert_match(sql, "out.sql") + + def test_clip(scalar_types_df: bpd.DataFrame, snapshot): op_expr = ops.clip_op.as_expr("rowindex", "int64_col", "int64_too") From 136ec2fcc8d4b4b2ef32bee8dd3184ca7ddc4bc3 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 15 Oct 2025 22:44:41 +0000 Subject: [PATCH 2/3] fix ops.invert_op in engine tests --- .../sqlglot/expressions/generic_ops.py | 7 +++++++ .../sqlglot/expressions/numeric_ops.py | 5 ----- .../system/small/engines/test_generic_ops.py | 8 ++++---- .../test_generic_ops/test_invert/out.sql | 19 +++++++++++++++++++ .../test_numeric_ops/test_invert/out.sql | 13 ------------- .../sqlglot/expressions/test_generic_ops.py | 12 ++++++++++++ .../sqlglot/expressions/test_numeric_ops.py | 8 -------- 7 files changed, 42 insertions(+), 30 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_invert/out.sql delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_invert/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 132403f18d..60366b02c9 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -73,6 +73,13 @@ def _(expr: TypedExpr) -> sge.Expression: return sge.func("FARM_FINGERPRINT", expr.expr) +@register_unary_op(ops.invert_op) +def _(expr: TypedExpr) -> sge.Expression: + if expr.dtype == dtypes.BOOL_DTYPE: + return sge.Not(this=expr.expr) + return sge.BitwiseNot(this=expr.expr) + + @register_unary_op(ops.isnull_op) def _(expr: TypedExpr) -> sge.Expression: return sge.Is(this=expr.expr, expression=sge.Null()) diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index ac40e4a667..3bbe2623ea 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -148,11 +148,6 @@ def _(expr: TypedExpr) -> sge.Expression: return sge.Floor(this=expr.expr) -@register_unary_op(ops.invert_op) -def _(expr: TypedExpr) -> sge.Expression: - return sge.BitwiseNot(this=expr.expr) - - @register_unary_op(ops.ln_op) def _(expr: TypedExpr) -> sge.Expression: return sge.Case( diff --git a/tests/system/small/engines/test_generic_ops.py b/tests/system/small/engines/test_generic_ops.py index e311ce2589..5641f91a9a 100644 --- a/tests/system/small/engines/test_generic_ops.py +++ b/tests/system/small/engines/test_generic_ops.py @@ -391,7 +391,7 @@ def test_engines_casewhen_op_double_case( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_isnull_op(scalars_array_value: array_value.ArrayValue, engine): arr, _ = scalars_array_value.compute_values( [ops.isnull_op.as_expr(expression.deref("string_col"))] @@ -400,7 +400,7 @@ def test_engines_isnull_op(scalars_array_value: array_value.ArrayValue, engine): assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_notnull_op(scalars_array_value: array_value.ArrayValue, engine): arr, _ = scalars_array_value.compute_values( [ops.notnull_op.as_expr(expression.deref("string_col"))] @@ -409,7 +409,7 @@ def test_engines_notnull_op(scalars_array_value: array_value.ArrayValue, engine) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine): arr, _ = scalars_array_value.compute_values( [ @@ -454,7 +454,7 @@ def test_engines_isin_op(scalars_array_value: array_value.ArrayValue, engine): assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_isin_op_nested_filter( scalars_array_value: array_value.ArrayValue, engine ): diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_invert/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_invert/out.sql new file mode 100644 index 0000000000..b5a5b92b52 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_invert/out.sql @@ -0,0 +1,19 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `bytes_col` AS `bfcol_1`, + `int64_col` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + ~`bfcol_2` AS `bfcol_6`, + ~`bfcol_1` AS `bfcol_7`, + NOT `bfcol_0` AS `bfcol_8` + FROM `bfcte_0` +) +SELECT + `bfcol_6` AS `int64_col`, + `bfcol_7` AS `bytes_col`, + `bfcol_8` AS `bool_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_invert/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_invert/out.sql deleted file mode 100644 index 28f2aa6e06..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_invert/out.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `int64_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - ~`bfcol_0` AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `int64_col` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py index 2336cd0c92..b7abc63213 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py @@ -233,6 +233,18 @@ def test_hash(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_invert(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bytes_col", "bool_col"]] + ops_map = { + "int64_col": ops.invert_op.as_expr("int64_col"), + "bytes_col": ops.invert_op.as_expr("bytes_col"), + "bool_col": ops.invert_op.as_expr("bool_col"), + } + sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + + snapshot.assert_match(sql, "out.sql") + + def test_isnull(scalar_types_df: bpd.DataFrame, snapshot): col_name = "float64_col" bf_df = scalar_types_df[[col_name]] diff --git a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py index 231d9d5bf0..59726da73b 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py @@ -126,14 +126,6 @@ def test_floor(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") -def test_invert(scalar_types_df: bpd.DataFrame, snapshot): - col_name = "int64_col" - bf_df = scalar_types_df[[col_name]] - sql = utils._apply_unary_ops(bf_df, [ops.invert_op.as_expr(col_name)], [col_name]) - - snapshot.assert_match(sql, "out.sql") - - def test_ln(scalar_types_df: bpd.DataFrame, snapshot): col_name = "float64_col" bf_df = scalar_types_df[[col_name]] From 4858b9dfdc0cc8ff756543283ebf14ad0df4cc8e Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 16 Oct 2025 21:35:45 +0000 Subject: [PATCH 3/3] exclude test_engines_isin_op_nested_filter due to lack of parantheses --- tests/system/small/engines/test_generic_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/system/small/engines/test_generic_ops.py b/tests/system/small/engines/test_generic_ops.py index 5641f91a9a..f209b95496 100644 --- a/tests/system/small/engines/test_generic_ops.py +++ b/tests/system/small/engines/test_generic_ops.py @@ -454,7 +454,7 @@ def test_engines_isin_op(scalars_array_value: array_value.ArrayValue, engine): assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) def test_engines_isin_op_nested_filter( scalars_array_value: array_value.ArrayValue, engine ):