diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 1ed49b89eb..60366b02c9 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -23,6 +23,7 @@ import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op +register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op register_ternary_op = scalar_compiler.scalar_op_compiler.register_ternary_op @@ -67,23 +68,18 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: return _cast(sg_expr, sg_to_type, op.safe) -@register_ternary_op(ops.clip_op) -def _( - original: TypedExpr, - lower: TypedExpr, - upper: TypedExpr, -) -> sge.Expression: - return sge.Greatest( - this=sge.Least(this=original.expr, expressions=[upper.expr]), - expressions=[lower.expr], - ) - - @register_unary_op(ops.hash_op) def _(expr: TypedExpr) -> sge.Expression: return sge.func("FARM_FINGERPRINT", expr.expr) +@register_unary_op(ops.invert_op) +def _(expr: TypedExpr) -> sge.Expression: + if expr.dtype == dtypes.BOOL_DTYPE: + return sge.Not(this=expr.expr) + return sge.BitwiseNot(this=expr.expr) + + @register_unary_op(ops.isnull_op) def _(expr: TypedExpr) -> sge.Expression: return sge.Is(this=expr.expr, expression=sge.Null()) @@ -114,6 +110,44 @@ def _( return sge.If(this=condition.expr, true=original.expr, false=replacement.expr) +@register_ternary_op(ops.clip_op) +def _( + original: TypedExpr, + lower: TypedExpr, + upper: TypedExpr, +) -> sge.Expression: + return sge.Greatest( + this=sge.Least(this=original.expr, expressions=[upper.expr]), + expressions=[lower.expr], + ) + + +@register_nary_op(ops.case_when_op) +def _(*cases_and_outputs: TypedExpr) -> sge.Expression: + # Need to upcast BOOL to INT if any output is numeric + result_values = cases_and_outputs[1::2] + do_upcast_bool = any( + dtypes.is_numeric(t.dtype, include_bool=False) for t in result_values + ) + if do_upcast_bool: + result_values = tuple( + TypedExpr( + sge.Cast(this=val.expr, to="INT64"), + dtypes.INT_DTYPE, + ) + if val.dtype == dtypes.BOOL_DTYPE + else val + for val in result_values + ) + + return sge.Case( + ifs=[ + sge.If(this=predicate.expr, true=output.expr) + for predicate, output in zip(cases_and_outputs[::2], result_values) + ], + ) + + # Helper functions def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: from_type = expr.dtype diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index ac40e4a667..3bbe2623ea 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -148,11 +148,6 @@ def _(expr: TypedExpr) -> sge.Expression: return sge.Floor(this=expr.expr) -@register_unary_op(ops.invert_op) -def _(expr: TypedExpr) -> sge.Expression: - return sge.BitwiseNot(this=expr.expr) - - @register_unary_op(ops.ln_op) def _(expr: TypedExpr) -> sge.Expression: return sge.Case( diff --git a/tests/system/small/engines/test_generic_ops.py b/tests/system/small/engines/test_generic_ops.py index ae7eafd347..f209b95496 100644 --- a/tests/system/small/engines/test_generic_ops.py +++ b/tests/system/small/engines/test_generic_ops.py @@ -357,7 +357,7 @@ def test_engines_fillna_op(scalars_array_value: array_value.ArrayValue, engine): assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_casewhen_op_single_case( scalars_array_value: array_value.ArrayValue, engine ): @@ -373,7 +373,7 @@ def test_engines_casewhen_op_single_case( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_casewhen_op_double_case( scalars_array_value: array_value.ArrayValue, engine ): @@ -391,7 +391,7 @@ def test_engines_casewhen_op_double_case( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_isnull_op(scalars_array_value: array_value.ArrayValue, engine): arr, _ = scalars_array_value.compute_values( [ops.isnull_op.as_expr(expression.deref("string_col"))] @@ -400,7 +400,7 @@ def test_engines_isnull_op(scalars_array_value: array_value.ArrayValue, engine): assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_notnull_op(scalars_array_value: array_value.ArrayValue, engine): arr, _ = scalars_array_value.compute_values( [ops.notnull_op.as_expr(expression.deref("string_col"))] @@ -409,7 +409,7 @@ def test_engines_notnull_op(scalars_array_value: array_value.ArrayValue, engine) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine): arr, _ = scalars_array_value.compute_values( [ diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_case_when_op/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_case_when_op/out.sql new file mode 100644 index 0000000000..08db34a632 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_case_when_op/out.sql @@ -0,0 +1,29 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `int64_too` AS `bfcol_2`, + `float64_col` AS `bfcol_3` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE WHEN `bfcol_0` THEN `bfcol_1` END AS `bfcol_4`, + CASE WHEN `bfcol_0` THEN `bfcol_1` WHEN `bfcol_0` THEN `bfcol_2` END AS `bfcol_5`, + CASE WHEN `bfcol_0` THEN `bfcol_0` WHEN `bfcol_0` THEN `bfcol_0` END AS `bfcol_6`, + CASE + WHEN `bfcol_0` + THEN `bfcol_1` + WHEN `bfcol_0` + THEN CAST(`bfcol_0` AS INT64) + WHEN `bfcol_0` + THEN `bfcol_3` + END AS `bfcol_7` + FROM `bfcte_0` +) +SELECT + `bfcol_4` AS `single_case`, + `bfcol_5` AS `double_case`, + `bfcol_6` AS `bool_types_case`, + `bfcol_7` AS `mixed_types_cast` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_invert/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_invert/out.sql new file mode 100644 index 0000000000..b5a5b92b52 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_invert/out.sql @@ -0,0 +1,19 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `bytes_col` AS `bfcol_1`, + `int64_col` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + ~`bfcol_2` AS `bfcol_6`, + ~`bfcol_1` AS `bfcol_7`, + NOT `bfcol_0` AS `bfcol_8` + FROM `bfcte_0` +) +SELECT + `bfcol_6` AS `int64_col`, + `bfcol_7` AS `bytes_col`, + `bfcol_8` AS `bool_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_invert/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_invert/out.sql deleted file mode 100644 index 28f2aa6e06..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_invert/out.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `int64_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - ~`bfcol_0` AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `int64_col` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py index 261a630d3a..b7abc63213 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py @@ -168,6 +168,47 @@ def test_astype_json_invalid( ) +def test_case_when_op(scalar_types_df: bpd.DataFrame, snapshot): + ops_map = { + "single_case": ops.case_when_op.as_expr( + "bool_col", + "int64_col", + ), + "double_case": ops.case_when_op.as_expr( + "bool_col", + "int64_col", + "bool_col", + "int64_too", + ), + "bool_types_case": ops.case_when_op.as_expr( + "bool_col", + "bool_col", + "bool_col", + "bool_col", + ), + "mixed_types_cast": ops.case_when_op.as_expr( + "bool_col", + "int64_col", + "bool_col", + "bool_col", + "bool_col", + "float64_col", + ), + } + + array_value = scalar_types_df._block.expr + result, col_ids = array_value.compute_values(list(ops_map.values())) + + # Rename columns for deterministic golden SQL results. + assert len(col_ids) == len(ops_map.keys()) + result = result.rename_columns( + {col_id: key for col_id, key in zip(col_ids, ops_map.keys())} + ).select_columns(list(ops_map.keys())) + + sql = result.session._executor.to_sql(result, enable_cache=False) + snapshot.assert_match(sql, "out.sql") + + def test_clip(scalar_types_df: bpd.DataFrame, snapshot): op_expr = ops.clip_op.as_expr("rowindex", "int64_col", "int64_too") @@ -192,6 +233,18 @@ def test_hash(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_invert(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bytes_col", "bool_col"]] + ops_map = { + "int64_col": ops.invert_op.as_expr("int64_col"), + "bytes_col": ops.invert_op.as_expr("bytes_col"), + "bool_col": ops.invert_op.as_expr("bool_col"), + } + sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) + + snapshot.assert_match(sql, "out.sql") + + def test_isnull(scalar_types_df: bpd.DataFrame, snapshot): col_name = "float64_col" bf_df = scalar_types_df[[col_name]] diff --git a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py index 231d9d5bf0..59726da73b 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py @@ -126,14 +126,6 @@ def test_floor(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") -def test_invert(scalar_types_df: bpd.DataFrame, snapshot): - col_name = "int64_col" - bf_df = scalar_types_df[[col_name]] - sql = utils._apply_unary_ops(bf_df, [ops.invert_op.as_expr(col_name)], [col_name]) - - snapshot.assert_match(sql, "out.sql") - - def test_ln(scalar_types_df: bpd.DataFrame, snapshot): col_name = "float64_col" bf_df = scalar_types_df[[col_name]]