diff --git a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py index 61f1eba607..84e783bb66 100644 --- a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py @@ -38,21 +38,15 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.Concat(expressions=[left.expr, right.expr]) if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): - left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE: - left_expr = sge.Cast(this=left_expr, to="INT64") - right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE: - right_expr = sge.Cast(this=right_expr, to="INT64") + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) return sge.Add(this=left_expr, expression=right_expr) if ( dtypes.is_time_or_date_like(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE ): - left_expr = left.expr - if left.dtype == dtypes.DATE_DTYPE: - left_expr = sge.Cast(this=left_expr, to="DATETIME") + left_expr = _coerce_date_to_datetime(left) return sge.TimestampAdd( this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND") ) @@ -60,9 +54,7 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: dtypes.is_time_or_date_like(right.dtype) and left.dtype == dtypes.TIMEDELTA_DTYPE ): - right_expr = right.expr - if right.dtype == dtypes.DATE_DTYPE: - right_expr = sge.Cast(this=right_expr, to="DATETIME") + right_expr = _coerce_date_to_datetime(right) return sge.TimestampAdd( this=right_expr, expression=left.expr, unit=sge.Var(this="MICROSECOND") ) @@ -74,14 +66,37 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: ) -@BINARY_OP_REGISTRATION.register(ops.div_op) +@BINARY_OP_REGISTRATION.register(ops.eq_op) +def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + return sge.EQ(this=left_expr, expression=right_expr) + + +@BINARY_OP_REGISTRATION.register(ops.eq_null_match_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE: - left_expr = sge.Cast(this=left_expr, to="INT64") + if right.dtype != dtypes.BOOL_DTYPE: + left_expr = _coerce_bool_to_int(left) + right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE: - right_expr = sge.Cast(this=right_expr, to="INT64") + if left.dtype != dtypes.BOOL_DTYPE: + right_expr = _coerce_bool_to_int(right) + + sentinel = sge.convert("$NULL_SENTINEL$") + left_coalesce = sge.Coalesce( + this=sge.Cast(this=left_expr, to="STRING"), expressions=[sentinel] + ) + right_coalesce = sge.Coalesce( + this=sge.Cast(this=right_expr, to="STRING"), expressions=[sentinel] + ) + return sge.EQ(this=left_coalesce, expression=right_coalesce) + + +@BINARY_OP_REGISTRATION.register(ops.div_op) +def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) result = sge.func("IEEE_DIVIDE", left_expr, right_expr) if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype): @@ -92,12 +107,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: @BINARY_OP_REGISTRATION.register(ops.floordiv_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE: - left_expr = sge.Cast(this=left_expr, to="INT64") - right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE: - right_expr = sge.Cast(this=right_expr, to="INT64") + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) result: sge.Expression = sge.Cast( this=sge.Floor(this=sge.func("IEEE_DIVIDE", left_expr, right_expr)), to="INT64" @@ -139,12 +150,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: @BINARY_OP_REGISTRATION.register(ops.mul_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE: - left_expr = sge.Cast(this=left_expr, to="INT64") - right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE: - right_expr = sge.Cast(this=right_expr, to="INT64") + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) result = sge.Mul(this=left_expr, expression=right_expr) @@ -156,36 +163,33 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return result +@BINARY_OP_REGISTRATION.register(ops.ne_op) +def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + return sge.NEQ(this=left_expr, expression=right_expr) + + @BINARY_OP_REGISTRATION.register(ops.sub_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): - left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE: - left_expr = sge.Cast(this=left_expr, to="INT64") - right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE: - right_expr = sge.Cast(this=right_expr, to="INT64") + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) return sge.Sub(this=left_expr, expression=right_expr) if ( dtypes.is_time_or_date_like(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE ): - left_expr = left.expr - if left.dtype == dtypes.DATE_DTYPE: - left_expr = sge.Cast(this=left_expr, to="DATETIME") + left_expr = _coerce_date_to_datetime(left) return sge.TimestampSub( this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND") ) if dtypes.is_time_or_date_like(left.dtype) and dtypes.is_time_or_date_like( right.dtype ): - left_expr = left.expr - if left.dtype == dtypes.DATE_DTYPE: - left_expr = sge.Cast(this=left_expr, to="DATETIME") - right_expr = right.expr - if right.dtype == dtypes.DATE_DTYPE: - right_expr = sge.Cast(this=right_expr, to="DATETIME") + left_expr = _coerce_date_to_datetime(left) + right_expr = _coerce_date_to_datetime(right) return sge.TimestampDiff( this=left_expr, expression=right_expr, unit=sge.Var(this="MICROSECOND") ) @@ -201,3 +205,17 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: @BINARY_OP_REGISTRATION.register(ops.obj_make_ref_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.func("OBJ.MAKE_REF", left.expr, right.expr) + + +def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression: + """Coerce boolean expression to integer.""" + if typed_expr.dtype == dtypes.BOOL_DTYPE: + return sge.Cast(this=typed_expr.expr, to="INT64") + return typed_expr.expr + + +def _coerce_date_to_datetime(typed_expr: TypedExpr) -> sge.Expression: + """Coerce date expression to datetime.""" + if typed_expr.dtype == dtypes.DATE_DTYPE: + return sge.Cast(this=typed_expr.expr, to="DATETIME") + return typed_expr.expr diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_null_match/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_null_match/out.sql new file mode 100644 index 0000000000..90cbcfe5c7 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_null_match/out.sql @@ -0,0 +1,14 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + COALESCE(CAST(`bfcol_1` AS STRING), '$NULL_SENTINEL$') = COALESCE(CAST(CAST(`bfcol_0` AS INT64) AS STRING), '$NULL_SENTINEL$') AS `bfcol_4` + FROM `bfcte_0` +) +SELECT + `bfcol_4` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_numeric/out.sql new file mode 100644 index 0000000000..8e3c52310d --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_numeric/out.sql @@ -0,0 +1,54 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `rowindex` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_2` AS `bfcol_6`, + `bfcol_1` AS `bfcol_7`, + `bfcol_0` AS `bfcol_8`, + `bfcol_1` = `bfcol_1` AS `bfcol_9` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + `bfcol_6` AS `bfcol_14`, + `bfcol_7` AS `bfcol_15`, + `bfcol_8` AS `bfcol_16`, + `bfcol_9` AS `bfcol_17`, + `bfcol_7` = 1 AS `bfcol_18` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + *, + `bfcol_14` AS `bfcol_24`, + `bfcol_15` AS `bfcol_25`, + `bfcol_16` AS `bfcol_26`, + `bfcol_17` AS `bfcol_27`, + `bfcol_18` AS `bfcol_28`, + `bfcol_15` = CAST(`bfcol_16` AS INT64) AS `bfcol_29` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + *, + `bfcol_24` AS `bfcol_36`, + `bfcol_25` AS `bfcol_37`, + `bfcol_26` AS `bfcol_38`, + `bfcol_27` AS `bfcol_39`, + `bfcol_28` AS `bfcol_40`, + `bfcol_29` AS `bfcol_41`, + CAST(`bfcol_26` AS INT64) = `bfcol_25` AS `bfcol_42` + FROM `bfcte_3` +) +SELECT + `bfcol_36` AS `rowindex`, + `bfcol_37` AS `int64_col`, + `bfcol_38` AS `bool_col`, + `bfcol_39` AS `int_ne_int`, + `bfcol_40` AS `int_ne_1`, + `bfcol_41` AS `int_ne_bool`, + `bfcol_42` AS `bool_ne_int` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_ne_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_ne_numeric/out.sql new file mode 100644 index 0000000000..6fba4b960f --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_ne_numeric/out.sql @@ -0,0 +1,54 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `rowindex` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_2` AS `bfcol_6`, + `bfcol_1` AS `bfcol_7`, + `bfcol_0` AS `bfcol_8`, + `bfcol_1` <> `bfcol_1` AS `bfcol_9` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + `bfcol_6` AS `bfcol_14`, + `bfcol_7` AS `bfcol_15`, + `bfcol_8` AS `bfcol_16`, + `bfcol_9` AS `bfcol_17`, + `bfcol_7` <> 1 AS `bfcol_18` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + *, + `bfcol_14` AS `bfcol_24`, + `bfcol_15` AS `bfcol_25`, + `bfcol_16` AS `bfcol_26`, + `bfcol_17` AS `bfcol_27`, + `bfcol_18` AS `bfcol_28`, + `bfcol_15` <> CAST(`bfcol_16` AS INT64) AS `bfcol_29` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + *, + `bfcol_24` AS `bfcol_36`, + `bfcol_25` AS `bfcol_37`, + `bfcol_26` AS `bfcol_38`, + `bfcol_27` AS `bfcol_39`, + `bfcol_28` AS `bfcol_40`, + `bfcol_29` AS `bfcol_41`, + CAST(`bfcol_26` AS INT64) <> `bfcol_25` AS `bfcol_42` + FROM `bfcte_3` +) +SELECT + `bfcol_36` AS `rowindex`, + `bfcol_37` AS `int64_col`, + `bfcol_38` AS `bool_col`, + `bfcol_39` AS `int_ne_int`, + `bfcol_40` AS `int_ne_1`, + `bfcol_41` AS `int_ne_bool`, + `bfcol_42` AS `bool_ne_int` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py index 49426fe6c3..11586cad02 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py @@ -107,6 +107,24 @@ def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(bf_df.sql, "out.sql") +def test_eq_null_match(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + sql = _apply_binary_op(bf_df, ops.eq_null_match_op, "int64_col", "bool_col") + snapshot.assert_match(sql, "out.sql") + + +def test_eq_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_ne_int"] = bf_df["int64_col"] == bf_df["int64_col"] + bf_df["int_ne_1"] = bf_df["int64_col"] == 1 + + bf_df["int_ne_bool"] = bf_df["int64_col"] == bf_df["bool_col"] + bf_df["bool_ne_int"] = bf_df["bool_col"] == bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col"]] @@ -121,8 +139,6 @@ def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot): bf_df["int_div_bool"] = bf_df["int64_col"] // bf_df["bool_col"] bf_df["bool_div_int"] = bf_df["bool_col"] // bf_df["int64_col"] - snapshot.assert_match(bf_df.sql, "out.sql") - def test_floordiv_timedelta(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["timestamp_col", "date_col"]] @@ -200,3 +216,15 @@ def test_mul_timedelta(scalar_types_df: bpd.DataFrame, snapshot): def test_obj_make_ref(scalar_types_df: bpd.DataFrame, snapshot): blob_df = scalar_types_df["string_col"].str.to_blob() snapshot.assert_match(blob_df.to_frame().sql, "out.sql") + + +def test_ne_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_ne_int"] = bf_df["int64_col"] != bf_df["int64_col"] + bf_df["int_ne_1"] = bf_df["int64_col"] != 1 + + bf_df["int_ne_bool"] = bf_df["int64_col"] != bf_df["bool_col"] + bf_df["bool_ne_int"] = bf_df["bool_col"] != bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql")