Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 26 additions & 19 deletions bigframes/core/compile/sqlglot/sqlglot_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@

from google.cloud import bigquery
import numpy as np
import pandas as pd
import pyarrow as pa
import sqlglot as sg
import sqlglot.dialects.bigquery
import sqlglot.expressions as sge

from bigframes import dtypes
from bigframes.core import guid, local_data, schema, utils
from bigframes.core.compile.sqlglot.expressions import typed_expr
from bigframes.core.compile.sqlglot.expressions import constants, typed_expr
import bigframes.core.compile.sqlglot.sqlglot_types as sgt

# shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0.
Expand Down Expand Up @@ -639,12 +640,30 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
if sqlglot_type is None:
if value is not None:
raise ValueError("Cannot infer SQLGlot type from None dtype.")
if not pd.isna(value):
raise ValueError(f"Cannot infer SQLGlot type from None dtype: {value}")
return sge.Null()

if value is None:
return _cast(sge.Null(), sqlglot_type)
if dtypes.is_struct_like(dtype):
items = [
_literal(value=value[field_name], dtype=field_dtype).as_(
field_name, quoted=True
)
for field_name, field_dtype in dtypes.get_struct_fields(dtype).items()
]
return sge.Struct.from_arg_list(items)
elif dtypes.is_array_like(dtype):
value_type = dtypes.get_array_inner_type(dtype)
values = sge.Array(
expressions=[_literal(value=v, dtype=value_type) for v in value]
)
return values if len(value) > 0 else _cast(values, sqlglot_type)
elif pd.isna(value):
return _cast(sge.Null(), sqlglot_type)
elif dtype == dtypes.JSON_DTYPE:
return sge.ParseJSON(this=sge.convert(str(value)))
elif dtype == dtypes.BYTES_DTYPE:
return _cast(str(value), sqlglot_type)
elif dtypes.is_time_like(dtype):
Expand All @@ -658,24 +677,12 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
elif dtypes.is_geo_like(dtype):
wkt = value if isinstance(value, str) else to_wkt(value)
return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt))
elif dtype == dtypes.JSON_DTYPE:
return sge.ParseJSON(this=sge.convert(str(value)))
elif dtype == dtypes.TIMEDELTA_DTYPE:
return sge.convert(utils.timedelta_to_micros(value))
elif dtypes.is_struct_like(dtype):
items = [
_literal(value=value[field_name], dtype=field_dtype).as_(
field_name, quoted=True
)
for field_name, field_dtype in dtypes.get_struct_fields(dtype).items()
]
return sge.Struct.from_arg_list(items)
elif dtypes.is_array_like(dtype):
value_type = dtypes.get_array_inner_type(dtype)
values = sge.Array(
expressions=[_literal(value=v, dtype=value_type) for v in value]
)
return values if len(value) > 0 else _cast(values, sqlglot_type)
elif dtype == dtypes.FLOAT_DTYPE:
if np.isinf(value):
return constants._INF if value > 0 else constants._NEG_INF
return sge.convert(value)
else:
if isinstance(value, np.generic):
value = value.item()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
WITH `bfcte_0` AS (
SELECT
*
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` FLOAT64, `bfcol_1` FLOAT64, `bfcol_2` FLOAT64, `bfcol_3` FLOAT64, `bfcol_4` STRUCT<foo INT64>, `bfcol_5` STRUCT<foo INT64>, `bfcol_6` ARRAY<INT64>, `bfcol_7` INT64>>[STRUCT(
CAST(NULL AS FLOAT64),
CAST('Infinity' AS FLOAT64),
CAST('-Infinity' AS FLOAT64),
CAST(NULL AS FLOAT64),
CAST(NULL AS STRUCT<foo INT64>),
STRUCT(CAST(NULL AS INT64) AS `foo`),
ARRAY<INT64>[],
0
), STRUCT(1.0, 1.0, 1.0, 1.0, STRUCT(1 AS `foo`), STRUCT(1 AS `foo`), [1, 2], 1), STRUCT(2.0, 2.0, 2.0, 2.0, STRUCT(2 AS `foo`), STRUCT(2 AS `foo`), [3, 4], 2)])
)
SELECT
`bfcol_0` AS `col_none`,
`bfcol_1` AS `col_inf`,
`bfcol_2` AS `col_neginf`,
`bfcol_3` AS `col_nan`,
`bfcol_4` AS `col_struct_none`,
`bfcol_5` AS `col_struct_w_none`,
`bfcol_6` AS `col_list_none`
FROM `bfcte_0`
ORDER BY
`bfcol_7` ASC NULLS LAST
23 changes: 23 additions & 0 deletions tests/unit/core/compile/sqlglot/test_compile_readlocal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -58,3 +61,23 @@ def test_compile_readlocal_w_json_df(
):
bf_df = bpd.DataFrame(json_pandas_df, session=compiler_session_w_json_types)
snapshot.assert_match(bf_df.sql, "out.sql")


def test_compile_readlocal_w_special_values(
    compiler_session: bigframes.Session, snapshot
):
    """Snapshot the SQL compiled for a local frame holding special values.

    Covers None, +/-infinity, NaN, plus None nested inside struct and
    list columns, so literal lowering for each case is pinned by the
    snapshot.
    """
    # Older interpreters emit slightly different SQL text, so the snapshot
    # is only asserted on Python 3.12+.
    if sys.version_info < (3, 12):
        pytest.skip("Skipping test due to inconsistent SQL formatting")
    frame_data = {
        "col_none": [None, 1, 2],
        "col_inf": [np.inf, 1.0, 2.0],
        "col_neginf": [-np.inf, 1.0, 2.0],
        "col_nan": [np.nan, 1.0, 2.0],
        "col_struct_none": [None, {"foo": 1}, {"foo": 2}],
        "col_struct_w_none": [{"foo": None}, {"foo": 1}, {"foo": 2}],
        "col_list_none": [None, [1, 2], [3, 4]],
    }
    local_df = pd.DataFrame(frame_data)
    bf_frame = bpd.DataFrame(local_df, session=compiler_session)
    snapshot.assert_match(bf_frame.sql, "out.sql")